use futures::Stream;
pub use futures::StreamExt;
pub use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::header::{ACCEPT, AUTHORIZATION};
pub use reqwest::{StatusCode, Url};
use serde::de::DeserializeOwned;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// A response returned by [`Builder::send`].
///
/// Besides the HTTP response itself it keeps the shared client state and the authorization
/// settings it was fetched with, so [`Response::array`] can request further pages the same way.
pub struct Response {
    /// The underlying HTTP response.
    res: reqwest::Response,
    /// Shared client state used to fetch follow-up pages.
    client: Arc<ClientInner>,
    /// Authorization settings reused for follow-up pages.
    config: ConfigInner,
}
impl Response {
    /// Deserializes the entire response body as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`GHError`] if the body cannot be read or does not deserialize into `T`.
    pub async fn obj<T: DeserializeOwned>(self) -> Result<T, GHError> {
        let parsed = self.res.json::<T>().await?;
        Ok(parsed)
    }

    /// Streams every element of a (possibly paginated) JSON array response.
    ///
    /// Each page body is decoded as `Vec<T>` and its items are yielded in order. The next
    /// page is located through the `rel="next"` URL of the `link` header and fetched with
    /// the same client and authorization. The stream ends after a page without a next link,
    /// or after yielding the first error.
    pub fn array<T: DeserializeOwned + Unpin + 'static>(self) -> impl Stream<Item = Result<T, GHError>> {
        let Self { res, client, config } = self;
        Box::pin(async_stream::try_stream! {
            let mut page = res;
            #[allow(clippy::large_futures)]
            loop {
                // Read the link header before `json()` consumes the response.
                let next_url = page
                    .headers()
                    .get("link")
                    .and_then(|value| value.to_str().ok())
                    .and_then(parse_next_link);
                let batch = page.json::<Vec<T>>().await?;
                for element in batch {
                    yield element;
                }
                if let Some(url) = next_url {
                    page = client.raw_get(&url, &config).await?;
                } else {
                    break;
                }
            }
        })
    }

    /// Response headers.
    pub fn headers(&self) -> &HeaderMap {
        self.res.headers()
    }

    /// HTTP status code of the response.
    pub fn status(&self) -> StatusCode {
        self.res.status()
    }

    /// Final URL of the response.
    pub fn url(&self) -> &Url {
        self.res.url()
    }
}
/// Incrementally builds the URL of a GET request.
///
/// Created by [`Client::get`]; extended with [`Builder::path`], [`Builder::arg`] and
/// [`Builder::query`]; executed with [`Builder::send`].
pub struct Builder {
    /// Shared client state that will perform the request.
    client: Arc<ClientInner>,
    /// Authorization settings to send with the request.
    config: ConfigInner,
    /// URL built so far; [`Client::get`] starts it at `https://api.github.com`.
    url: String,
    /// Whether a `?` has already been appended to `url`.
    query_string_started: bool,
}
impl Builder {
    /// Appends a literal path segment (or several, such as `"users/octocat"`) to the URL.
    ///
    /// `url_part` is appended verbatim, without percent-encoding; use [`Builder::arg`] for
    /// dynamic values.
    ///
    /// # Panics
    ///
    /// Panics if called after [`Builder::query`]. In debug builds, also panics if `url_part`
    /// starts or ends with `/`.
    #[must_use]
    pub fn path(mut self, url_part: &'static str) -> Self {
        debug_assert_eq!(url_part, url_part.trim_matches('/'));
        assert!(!self.query_string_started);
        self.url += "/";
        self.url += url_part;
        self
    }

    /// Appends a percent-encoded argument.
    ///
    /// Before any query string has been added, the argument becomes its own path segment.
    /// Afterwards, it is appended to the query string with no separator, so it continues
    /// whatever the last [`Builder::query`] call ended with.
    #[must_use]
    pub fn arg(mut self, arg: &str) -> Self {
        use std::fmt::Write;
        if !self.query_string_started {
            self.url += "/";
        }
        write!(self.url, "{}", urlencoding::Encoded(arg)).unwrap();
        self
    }

    /// Appends a raw `key=value` query fragment, inserting `?` the first time and `&` after.
    ///
    /// The fragment is not encoded. In debug builds, panics if it starts with `?` or `&`.
    #[must_use]
    pub fn query(mut self, query_string: &str) -> Self {
        debug_assert!(!query_string.starts_with('?'));
        debug_assert!(!query_string.starts_with('&'));
        let separator = if self.query_string_started { '&' } else { '?' };
        self.query_string_started = true;
        self.url.push(separator);
        self.url.push_str(query_string);
        self
    }

    /// Sends the GET request for the URL built so far.
    ///
    /// # Errors
    ///
    /// Returns [`GHError`] when the request fails or the response is not a success.
    pub async fn send(self) -> Result<Response, GHError> {
        let Self { client, config, url, .. } = self;
        let res = Box::pin(client.raw_get(&url, &config)).await?;
        Ok(Response { res, client, config })
    }
}
/// State shared (through an `Arc`) by every [`Client`] clone and every request it creates.
struct ClientInner {
    /// Underlying HTTP client.
    client: reqwest::Client,
    /// Milliseconds since the Unix epoch before which no new request is started.
    /// 0 (the initial value) means no wait is pending.
    wait_until_timestamp_ms: AtomicU64,
}
/// Per-handle request settings; clones sharing one `ClientInner` may differ here.
#[derive(Clone)]
struct ConfigInner {
    /// Value sent as the `Authorization` header, if any.
    authorization_header: Option<HeaderValue>,
}
// `Client` is declared unwind-safe by hand. Its fields are an `Arc` of shared state
// (a `reqwest::Client` and an `AtomicU64` deadline) plus a cloneable config.
// NOTE(review): presumably added so a `Client` can be captured by `std::panic::catch_unwind`
// closures. Confirm that `reqwest::Client` tolerates being used again after a panic.
impl UnwindSafe for Client {}
impl RefUnwindSafe for Client {}
/// GitHub REST API client.
///
/// Clones share the same HTTP client and rate-limit deadline (held behind an `Arc`),
/// while each clone carries its own authorization settings.
#[derive(Clone)]
pub struct Client {
    /// HTTP client and rate-limit state shared between clones.
    inner: Arc<ClientInner>,
    /// Authorization settings of this handle.
    config: ConfigInner,
}
impl Client {
#[must_use] pub fn new_from_env() -> Self {
Self::new(std::env::var("GITHUB_TOKEN").ok().as_deref())
}
#[must_use]
pub fn new(token: Option<&str>) -> Self {
let mut default_headers = HeaderMap::with_capacity(2);
default_headers.insert(ACCEPT, HeaderValue::from_static("application/vnd.github.v3+json"));
default_headers.insert("X-GitHub-Api-Version", HeaderValue::from_static("2022-11-28"));
Self {
config: ConfigInner {
authorization_header: token.and_then(|token| HeaderValue::from_str(&format!("token {token}")).ok()),
},
inner: Arc::new(ClientInner {
client: reqwest::Client::builder()
.user_agent(concat!("rust-github-v3/{}", env!("CARGO_PKG_VERSION")))
.default_headers(default_headers)
.connect_timeout(Duration::from_secs(4))
.timeout(Duration::from_secs(20))
.build()
.unwrap(),
wait_until_timestamp_ms: AtomicU64::new(0),
}),
}
}
#[must_use]
pub fn with_authorization(&self, header: Option<&str>) -> Self {
Self {
config: ConfigInner {
authorization_header: header.and_then(|header| HeaderValue::from_str(header).ok()),
},
inner: self.inner.clone(),
}
}
#[must_use]
pub fn get(&self) -> Builder {
let mut url = String::with_capacity(100);
url.push_str("https://api.github.com");
Builder {
client: self.inner.clone(),
config: self.config.clone(),
url,
query_string_started: false,
}
}
pub fn wait_time(&self) -> Duration {
self.inner.wait_time()
}
}
impl ClientInner {
fn wait_time(&self) -> Duration {
let ts_ms = self.wait_until_timestamp_ms.load(SeqCst);
let until = SystemTime::UNIX_EPOCH + Duration::from_millis(ts_ms);
until.duration_since(SystemTime::now()).unwrap_or(Duration::ZERO)
}
async fn raw_get(&self, url: &str, config: &ConfigInner) -> Result<reqwest::Response, GHError> {
debug_assert!(url.starts_with("https://api.github.com/"));
let mut retries = 5u8;
let mut retry_delay = 1;
loop {
let wait_duration = self.wait_time();
if wait_duration > Duration::ZERO {
tokio::time::sleep(wait_duration).await;
}
let mut req = self.client.get(url);
if let Some(auth) = &config.authorization_header {
req = req.header(AUTHORIZATION, auth);
}
let res = req.send().await?;
let headers = res.headers();
let status = res.status();
let now = SystemTime::now();
let wait_duration = match (Self::rate_limit_remaining(headers), Self::rate_limit_reset(headers)) {
(Some(rl), Some(rs)) => {
rs.duration_since(now).ok()
.and_then(|d| d.checked_div(rl + 2))
.unwrap_or(Duration::ZERO)
.min(Duration::from_secs(30)) }
_ => Duration::from_secs(if status == StatusCode::TOO_MANY_REQUESTS {3} else {0}),
};
let wait_until = now + wait_duration;
let wait_until_timestamp_ms = wait_until.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_millis() as u64;
self.wait_until_timestamp_ms.store(wait_until_timestamp_ms, SeqCst);
let should_wait_for_content = status == StatusCode::ACCEPTED;
if should_wait_for_content && retries > 0 {
tokio::time::sleep(Duration::from_secs(retry_delay)).await;
retry_delay *= 2;
retries -= 1;
continue;
}
return if status.is_success() && !should_wait_for_content {
Ok(res)
} else {
Err(error_for_response(res).await)
};
}
}
pub fn rate_limit_remaining(headers: &HeaderMap) -> Option<u32> {
headers.get("x-ratelimit-remaining")
.and_then(|s| s.to_str().ok())
.and_then(|s| s.parse().ok())
}
pub fn rate_limit_reset(headers: &HeaderMap) -> Option<SystemTime> {
headers
.get("x-ratelimit-reset")
.and_then(|s| s.to_str().ok())
.and_then(|s| s.parse().ok())
.map(|s| SystemTime::UNIX_EPOCH + Duration::from_secs(s))
}
}
/// Builds a [`GHError::Response`] from `res`.
///
/// The body is parsed for the `message` field of a JSON error only when the `content-type`
/// header starts with `application/json`. Otherwise, or if parsing fails, `message` is `None`.
async fn error_for_response(res: reqwest::Response) -> GHError {
    let status = res.status();
    let is_json = res
        .headers()
        .get("content-type")
        .and_then(|value| value.to_str().ok())
        .unwrap_or("")
        .starts_with("application/json");
    let message = if is_json {
        res.json::<GitHubErrorResponse>().await.ok().map(|body| body.message)
    } else {
        None
    };
    GHError::Response { status, message }
}
/// Extracts the `rel="next"` URL from an HTTP `Link` header value.
///
/// The header is split on `,`. The first entry that contains `; rel="next"` and holds a URL
/// between `<` and `>` provides the result. Returns `None` if no entry qualifies.
fn parse_next_link(link: &str) -> Option<String> {
    link.split(',')
        .filter(|entry| entry.contains(r#"; rel="next""#))
        .find_map(|entry| {
            let (_, after_open) = entry.split_once('<')?;
            let (url, _) = after_open.split_once('>')?;
            Some(url.to_owned())
        })
}
/// JSON body of a GitHub error response; only the `message` field is read.
#[derive(serde::Deserialize)]
struct GitHubErrorResponse {
    /// The `message` string from the error body.
    message: String,
}
use thiserror::Error;
/// Errors produced by this client.
#[derive(Error, Debug)]
pub enum GHError {
    /// The underlying HTTP request timed out.
    #[error("Request timed out")]
    Timeout,
    /// A request failure that carries no HTTP status; holds the error's text.
    #[error("Request error: {}", _0)]
    Request(String),
    /// An unsuccessful HTTP response, or a `202 Accepted` that persisted after retries.
    /// `message` holds the server-provided or error text when one was available.
    #[error("{} ({})", message.as_deref().unwrap_or("HTTP error"), status)]
    Response { status: StatusCode, message: Option<String> },
    /// NOTE(review): not constructed anywhere in the code shown here; confirm its intended use.
    #[error("Internal error")]
    Internal,
}
impl From<reqwest::Error> for GHError {
fn from(e: reqwest::Error) -> Self {
if e.is_timeout() {
return Self::Timeout;
}
if let Some(status) = e.status() {
Self::Response {
status,
message: Some(e.to_string()),
}
} else {
Self::Request(e.to_string())
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Live request against the GitHub API. It uses `GITHUB_TOKEN` if set and needs
    /// network access.
    #[tokio::test]
    async fn req_test() {
        let client = Client::new_from_env();
        let outcome = client.get().path("users/octocat/orgs").send().await;
        outcome.unwrap();
    }

    /// The `rel="next"` entry is picked out of a header that also lists prev/last/first.
    #[test]
    fn parse_next_link_test() {
        let header = r#""<https://api.github.com/organizations/fakeid/repos?page=1>; rel="prev", <https://api.github.com/organizations/fakeid/repos?page=3>; rel="next", <https://api.github.com/organizations/fakeid/repos?page=38>; rel="last", <https://api.github.com/organizations/fakeid/repos?page=1>; rel="first""#;
        assert_eq!(
            parse_next_link(header).as_deref(),
            Some("https://api.github.com/organizations/fakeid/repos?page=3")
        );
    }
}