use super::robotstxt::RobotsTxt;
use super::{ForgeApi, error};
use crate::config::SecretHeader;
use arc_swap::ArcSwap;
use httpdate::parse_http_date;
use reqwest::Response;
use reqwest::StatusCode;
use reqwest::header::{HeaderName, HeaderValue, RETRY_AFTER};
use secrecy::ExposeSecret;
use serde::de::DeserializeOwned;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use url::Url;
/// A single resource exposed by a forge API, fetched with `Net::call`.
///
/// `Api` is the API the endpoint belongs to. `Value` is the type the JSON
/// response body is deserialized into. Implementors must be `Send + Sync`.
pub trait ApiEndpoint<Api: ForgeApi, Value: DeserializeOwned>: Send + Sync {
    /// Absolute URL that `Net::call` issues its GET request against.
    fn url(&self) -> Url;

    /// The owning API. `Net::call` asks it for the optional access token and
    /// for the name of the header that carries that token.
    fn api(&self) -> &Api;
}
pub struct Net {
restrictions: RobotsTxt,
timed_out_until: ArcSwap<Option<SystemTime>>,
}
impl Net {
pub fn new(restrictions: RobotsTxt) -> Self {
Self {
restrictions,
timed_out_until: ArcSwap::new(Arc::new(None)),
}
}
pub async fn call<A, V>(&self, endpoint: &dyn ApiEndpoint<A, V>) -> Result<V, error::Error>
where
A: ForgeApi,
V: DeserializeOwned,
{
if let Some(retry_after) = self.timed_out_until.load().as_ref() {
return Err(error::Error::Throttled(Some(*retry_after)));
}
let auth_header = endpoint
.api()
.access_token_value()
.map(|token| (endpoint.api().access_token_header_name(), token));
let response = self.get(&endpoint.url(), auth_header).await?;
let status = response.status();
if status == StatusCode::FORBIDDEN
|| status == StatusCode::UNAUTHORIZED
|| status == StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS
{
return Err(error::Error::Unauthorized);
}
if status == StatusCode::NOT_FOUND {
return Err(error::Error::ResourceNotFound);
}
if status == StatusCode::TOO_MANY_REQUESTS {
let retry_after: Option<SystemTime> = response
.headers()
.get_all(RETRY_AFTER)
.iter()
.flat_map(parse_retry_after)
.max();
self.timed_out_until.store(Arc::new(retry_after));
return Err(error::Error::Throttled(retry_after));
}
if !status.is_success() {
return Err(error::Error::UnexpectedResponse);
}
let text = response
.text()
.await
.map_err(|_| error::Error::UnexpectedResponse)?;
serde_json::from_str(&text).map_err(|_| error::Error::UnexpectedResponse)
}
const fn client(&self) -> &reqwest::Client {
self.restrictions.client()
}
async fn get(
&self,
url: &Url,
auth_header: Option<(HeaderName, SecretHeader)>,
) -> Result<Response, error::Error> {
if auth_header.is_some() {
println!("↗ HTTP GET {url} (authenticated)");
} else {
println!("↗ HTTP GET {url}");
}
if self.restrictions.is_restricted(url) {
return Err(error::Error::Restricted);
}
let mut req = self.client().get(url.clone());
if let Some((key, value)) = auth_header {
let value = value.expose_secret().as_header_value();
req = req.header(key, value);
}
match self
.client()
.execute(req.build().expect("valid request"))
.await
{
Err(err) => {
println!("Request failed due to error: {err}");
Err(error::Error::NetworkFailure)
}
Ok(response) => Ok(response),
}
}
}
/// Converts one `Retry-After` header value into an absolute point in time.
///
/// Two forms are accepted:
/// - an integer number of seconds, counted from now;
/// - an HTTP-date.
///
/// Returns `None` in any of these cases:
/// - the value cannot be viewed as `&str` (see `HeaderValue::to_str`);
/// - it matches neither form;
/// - adding the delay overflows `SystemTime`;
/// - the date already lies in the past.
fn parse_retry_after(value: &HeaderValue) -> Option<SystemTime> {
    let raw = value.to_str().ok()?;
    match raw.parse::<u64>() {
        Ok(secs) => SystemTime::now().checked_add(Duration::from_secs(secs)),
        Err(_) => {
            let when = parse_http_date(raw).ok()?;
            // Keep only dates at or after the current instant.
            when.duration_since(SystemTime::now())
                .is_ok()
                .then_some(when)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header value from a test literal.
    fn header(value: &str) -> HeaderValue {
        HeaderValue::from_str(value).expect("valid header value")
    }

    #[test]
    fn test_retry_after_seconds() {
        // Bracket the call between two clock readings instead of using a tiny
        // fixed tolerance, which was flaky on slow or loaded machines.
        let delay = Duration::from_secs(4200);
        let before = SystemTime::now();
        let actual = parse_retry_after(&header("4200")).expect("Should have parsed");
        let after = SystemTime::now();
        assert!(actual >= before + delay, "{actual:?} is too early");
        assert!(actual <= after + delay, "{actual:?} is too late");
    }

    #[test]
    fn test_retry_after_zero_seconds() {
        let before = SystemTime::now();
        let actual = parse_retry_after(&header("0")).expect("zero delay is valid");
        assert!(actual >= before);
    }

    #[test]
    fn test_retry_after_negative_seconds() {
        let actual = parse_retry_after(&header("-4200"));
        assert!(actual.is_none());
    }

    #[test]
    fn test_retry_after_overflowing_seconds() {
        let actual = parse_retry_after(&header(&u64::MAX.to_string()));
        assert!(actual.is_none(), "expected None, got {actual:?}");
    }

    #[test]
    fn test_retry_after_garbage() {
        assert!(parse_retry_after(&header("soon")).is_none());
        assert!(parse_retry_after(&header("")).is_none());
    }

    #[test]
    fn test_retry_after_date() {
        let future = SystemTime::now() + Duration::from_secs(4200);
        let retry_after = header(&httpdate::fmt_http_date(future));
        let actual = parse_retry_after(&retry_after).expect("Should have parsed and be future");
        // HTTP dates have whole-second resolution. The parsed value may be
        // truncated by less than a second, but never lies after `future`.
        let truncation = future
            .duration_since(actual)
            .expect("parsed date must not be after the input");
        assert!(truncation < Duration::from_secs(1));
    }

    #[test]
    fn test_retry_before_date() {
        let past = SystemTime::now() - Duration::from_secs(4200);
        let retry_after = header(&httpdate::fmt_http_date(past));
        let actual = parse_retry_after(&retry_after);
        assert!(actual.is_none(), "expected None, got {actual:?}");
    }
}