#![deny(
clippy::all,
clippy::pedantic,
missing_debug_implementations,
missing_docs,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unused_extern_crates,
unused_import_braces,
unused_qualifications,
unused_results
)]
#![allow(clippy::missing_errors_doc)]
use log::debug;
use once_cell::sync::Lazy;
use regex::Regex;
use reqwest::{
blocking::{Client, ClientBuilder},
header::{self, HeaderMap},
};
use semver::Version;
use serde::Deserialize;
use thiserror::Error;
/// Errors that can occur while looking up releases through the GitHub API.
#[derive(Debug, Error)]
pub enum LookupError {
/// The underlying `reqwest` client failed (building the client, sending a
/// request, or decoding the response body).
#[error("HTTP client error")]
HttpClient(#[from] reqwest::Error),
/// A value could not be turned into a valid HTTP header value.
#[error("invalid header value")]
HeaderValue(#[from] reqwest::header::InvalidHeaderValue),
/// A response header value could not be read as a string.
#[error("could not get header value")]
HeaderToString(#[from] reqwest::header::ToStrError),
/// No release tag could be parsed as a semantic version.
#[error("no release found")]
NoReleases,
/// The API answered with HTTP 404 for the requested repository.
#[error("repository not found")]
RepositoryNotFound,
/// The API answered with HTTP 401 or 403; the payload is that status code.
#[error("authentication error")]
AuthenticationError(u16),
/// The API answered with any other non-success status; the payload is that
/// status code.
#[error("received error HTTP response code")]
ErrorHttpResponse(u16),
}
/// Crate-local shorthand for results whose error type is [`LookupError`].
type Result<T> = std::result::Result<T, LookupError>;
/// Value sent in the `User-Agent` header of every request.
const DEFAULT_USER_AGENT: &str = "github.com/celeo/github_version_check";
/// Value sent in the `Accept` header of every request.
const DEFAULT_ACCEPT_HEADER: &str = "application/vnd.github.v3+json";
/// Value sent as the `per_page` query parameter when listing releases.
const PAGINATION_REQUEST_AMOUNT: usize = 100;
/// Matches `page=<digits>` inside a pagination link. The first group captures
/// any word characters directly in front of `page=`, which lets callers tell
/// the `page` parameter (empty group) apart from `per_page` (non-empty group).
static PAGE_EXTRACT_REGEX: Lazy<Regex> =
Lazy::new(|| Regex::new(r"(\w*)page=(\d+)").expect("Could not compile regex"));
/// API root used by [`GitHub::new`]. Request paths are appended directly to
/// this string, so it ends with a slash.
pub const DEFAULT_API_ROOT: &str = "https://api.github.com/";
/// Builds the default header set attached to every request.
///
/// Always sets `User-Agent` and `Accept`. When `token` is `Some`, an
/// `Authorization: Bearer <token>` header is added as well.
///
/// # Errors
///
/// Returns [`LookupError::HeaderValue`] if `token` contains bytes that are not
/// allowed in an HTTP header value.
fn generate_headers(token: Option<&str>) -> Result<HeaderMap> {
    let mut headers = HeaderMap::new();
    // Both constants are plain visible ASCII, so the static constructors cannot fail.
    let _prev = headers.insert(
        header::USER_AGENT,
        header::HeaderValue::from_static(DEFAULT_USER_AGENT),
    );
    let _prev = headers.insert(
        header::ACCEPT,
        header::HeaderValue::from_static(DEFAULT_ACCEPT_HEADER),
    );
    if let Some(t) = token {
        let mut authorization = header::HeaderValue::from_str(&format!("Bearer {}", t))?;
        // Flag the credential as sensitive so that `Debug` formatting of the
        // header value does not print the token.
        authorization.set_sensitive(true);
        let _prev = headers.insert(header::AUTHORIZATION, authorization);
    }
    Ok(headers)
}
/// One element of the JSON array returned by the releases endpoint. Only the
/// field this crate reads is declared.
#[derive(Debug, Deserialize)]
struct GitHubReleaseItem {
/// Value of the element's `tag_name` field.
tag_name: String,
}
/// Client for querying the releases of a repository through the GitHub API.
#[derive(Debug)]
pub struct GitHub {
/// Blocking HTTP client, built with the default headers for this crate.
client: Client,
/// Base URL that request paths are appended to.
api_root: String,
}
impl GitHub {
    /// Creates an unauthenticated client that talks to [`DEFAULT_API_ROOT`].
    ///
    /// # Errors
    ///
    /// Returns an error if the default headers or the HTTP client cannot be built.
    pub fn new() -> Result<Self> {
        Self::build(DEFAULT_API_ROOT, None)
    }

    /// Creates a client for a custom API root that authenticates with `access_token`.
    ///
    /// `api_endpoint` is used verbatim as the prefix of every request URL
    /// (`{api_endpoint}repos/...`), so it must end with a `/`. The token is
    /// sent as a `Bearer` credential in the `Authorization` header.
    ///
    /// # Errors
    ///
    /// Returns an error if the token is not a valid header value or the HTTP
    /// client cannot be built.
    pub fn from_custom(api_endpoint: &str, access_token: &str) -> Result<Self> {
        Self::build(api_endpoint, Some(access_token))
    }

    /// Shared constructor: builds the HTTP client with the default headers.
    fn build(api_root: &str, token: Option<&str>) -> Result<Self> {
        let client = ClientBuilder::new()
            .default_headers(generate_headers(token)?)
            .build()?;
        Ok(Self {
            client,
            api_root: api_root.to_owned(),
        })
    }

    /// Returns the tag name of every release of `repository`, in the order the
    /// API returns them.
    ///
    /// `repository` is interpolated into the request path as
    /// `repos/{repository}/releases` (for example `"owner/name"`). All result
    /// pages are fetched, up to and including the page that the first
    /// response's `Link` header marks as `rel="last"`.
    ///
    /// # Errors
    ///
    /// * [`LookupError::RepositoryNotFound`] on HTTP 404.
    /// * [`LookupError::AuthenticationError`] on HTTP 401 or 403.
    /// * [`LookupError::ErrorHttpResponse`] on any other non-success status.
    /// * [`LookupError::HttpClient`] if a request fails or a body cannot be decoded.
    /// * [`LookupError::HeaderToString`] if the `Link` header is not readable text.
    pub fn get_all_versions(&self, repository: &str) -> Result<Vec<String>> {
        // The URL does not change between pages; only the query does.
        let url = format!("{}repos/{}/releases", self.api_root, repository);
        let mut tags = Vec::new();
        let mut page = 1usize;
        let mut last_page: Option<usize> = None;
        loop {
            let query = [("per_page", PAGINATION_REQUEST_AMOUNT), ("page", page)];
            debug!(
                "Querying GitHub at {}, page {} of {}",
                url,
                page,
                last_page.map_or_else(|| String::from("?"), |p| p.to_string())
            );
            let response = self.client.get(&url).query(&query).send()?;
            let status = response.status();
            if !status.is_success() {
                debug!("Got status \"{}\" from GitHub release check", status);
                let code = status.as_u16();
                return Err(match code {
                    404 => LookupError::RepositoryNotFound,
                    401 | 403 => LookupError::AuthenticationError(code),
                    _ => LookupError::ErrorHttpResponse(code),
                });
            }
            if last_page.is_none() {
                debug!("Determining last page from response headers");
                last_page = get_last_page(response.headers())?;
            }
            let items: Vec<GitHubReleaseItem> = response.json()?;
            tags.extend(items.into_iter().map(|item| item.tag_name));
            match last_page {
                // Advance only while the page just fetched is before the last
                // one, so the last page itself is requested before stopping.
                Some(last) if page < last => page += 1,
                Some(_) => break,
                None => {
                    debug!("No pagination header found (fewer than 100 releases)");
                    break;
                }
            }
        }
        Ok(tags)
    }

    /// Returns the highest semantic version among the release tags of `repository`.
    ///
    /// A single leading `v` is removed from each tag before parsing. Tags that
    /// are not valid semantic versions are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`LookupError::NoReleases`] if no tag parses as a semantic
    /// version, plus any error from [`GitHub::get_all_versions`].
    pub fn get_latest_version(&self, repository: &str) -> Result<Version> {
        self.get_all_versions(repository)?
            .iter()
            .filter_map(|tag| {
                let bare = tag.strip_prefix('v').unwrap_or(tag.as_str());
                Version::parse(bare).ok()
            })
            .max()
            .ok_or(LookupError::NoReleases)
    }
}
fn get_last_page(headers: &HeaderMap) -> Result<Option<usize>> {
let links = match headers.get("link") {
Some(l) => l.to_str()?,
None => return Ok(None),
};
for page_ref in links.split(',') {
if !page_ref.contains("rel=\"last\"") {
continue;
}
for cap_part in PAGE_EXTRACT_REGEX.captures_iter(page_ref) {
if cap_part[1].is_empty() {
let page = cap_part[2]
.parse::<usize>()
.expect("Could not get page version from regex");
return Ok(Some(page));
}
}
}
Ok(None)
}
#[cfg(test)]
mod tests {
    use super::{get_last_page, GitHub};
    use mockito::{mock, Matcher, Mock};
    use reqwest::header::{HeaderMap, HeaderName, HeaderValue};

    /// Repository identifier used by every mocked lookup.
    const REPOSITORY: &str = "foo/bar";

    /// Registers a mock for the releases endpoint that answers any query with `body`.
    fn mock_releases(body: &str) -> Mock {
        mock("GET", "/repos/foo/bar/releases")
            .match_query(Matcher::Any)
            .with_body(body)
            .create()
    }

    /// Builds a client pointed at the mock server, using an empty access token.
    fn mock_client() -> GitHub {
        let api_root = format!("{}/", mockito::server_url());
        GitHub::from_custom(&api_root, "").unwrap()
    }

    #[test]
    fn test_get_last_page_none() {
        let headers = HeaderMap::new();
        assert_eq!(get_last_page(&headers).unwrap(), None);
    }

    #[test]
    fn test_get_last_page_some() {
        let link = HeaderValue::from_static(r#"<https://api.github.com/repositories/275449421/releases?per_page=1&page=2>; rel="next", <https://api.github.com/repositories/275449421/releases?per_page=1&page=10>; rel="last""#);
        let mut headers = HeaderMap::new();
        let _ = headers.insert(HeaderName::from_static("link"), link);
        assert_eq!(get_last_page(&headers).unwrap(), Some(10));
    }

    #[test]
    fn test_get_all_versions_none() {
        let _m = mock_releases("[]");
        let found = mock_client().get_all_versions(REPOSITORY).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_get_all_versions_valid() {
        let _m = mock_releases(
            r#"[
{ "tag_name": "v1.0.0" },
{ "tag_name": "v1.9.10" },
{ "tag_name": "v0.3.0" }
]"#,
        );
        let found = mock_client().get_all_versions(REPOSITORY).unwrap();
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_get_latest_version_none() {
        let _m = mock_releases("[]");
        let outcome = mock_client().get_latest_version(REPOSITORY);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_get_latest_version_bad_semvers() {
        let _m = mock_releases(
            r#"[
{ "tag_name": "uhhhh" },
{ "tag_name": "v3.0.0-alpha" },
{ "tag_name": "v1.9.10" }
]"#,
        );
        let latest = mock_client().get_latest_version(REPOSITORY).unwrap();
        assert_eq!(latest, semver::Version::parse("3.0.0-alpha").unwrap());
    }
}