use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use bytes::{Bytes, BytesMut};
use futures_util::StreamExt;
use reqwest::redirect::Policy;
use reqwest::{Client, ClientBuilder, Url};
use thiserror::Error;
use crate::{PDF_MAX_BYTES, VERSION};
/// Leading bytes of every PDF file (`%PDF-`); `fetch_pdf` rejects bodies that differ.
const PDF_MAGIC: [u8; 5] = *b"%PDF-";
/// Redirect hops followed before the policy stops and hands back the 3xx response.
const MAX_REDIRECTS: usize = 10;
/// Budget for establishing a connection.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Value passed to reqwest's per-read `read_timeout`.
const READ_TIMEOUT: Duration = Duration::from_secs(60);
/// Whole-request budget (reqwest `timeout`); every retry attempt gets a fresh budget.
const TOTAL_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Retries after the first attempt, so at most `MAX_FETCH_RETRIES + 1` requests are sent.
const MAX_FETCH_RETRIES: u32 = 3;
/// Initial exponential-backoff delay; also the width of the jitter window.
const RETRY_BASE_DELAY: Duration = Duration::from_millis(500);
/// Ceiling for computed backoff and for honoured `Retry-After` values.
const RETRY_MAX_DELAY: Duration = Duration::from_secs(30);
/// Whether an HTTP status code is worth retrying.
///
/// Covers request timeout (408), rate limiting (429), and 500/502/503/504. Every other
/// code, including 501, is treated as permanent.
fn is_transient_status(code: u16) -> bool {
    const RETRYABLE: [u16; 6] = [408, 429, 500, 502, 503, 504];
    RETRYABLE.contains(&code)
}
/// Whether a reqwest failure is worth retrying.
///
/// Redirect errors are never retried, including rejections raised by our own redirect
/// policy. Otherwise, timeouts, connection failures and body-stream errors are retryable.
fn reqwest_is_transient(e: &reqwest::Error) -> bool {
    if e.is_redirect() {
        return false;
    }
    e.is_timeout() || e.is_connect() || e.is_body()
}
/// Reads a delta-seconds `Retry-After` header and clamps it to [`RETRY_MAX_DELAY`].
///
/// Returns `None` in these cases, so the caller falls back to exponential backoff:
/// - the header is absent;
/// - it is not a valid header string;
/// - it is not a plain non-negative integer. The HTTP-date form is not supported.
fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
    let raw = headers.get(reqwest::header::RETRY_AFTER)?;
    let text = raw.to_str().ok()?;
    let secs = text.trim().parse::<u64>().ok()?;
    let requested = Duration::from_secs(secs);
    Some(requested.min(RETRY_MAX_DELAY))
}
/// Exponential backoff for retry number `attempt` (0-based), plus a little jitter.
///
/// The exponential part is `RETRY_BASE_DELAY * 2^attempt`. The exponent saturates at 20,
/// and the result is capped at `RETRY_MAX_DELAY`. Jitter in `[0, RETRY_BASE_DELAY)` is
/// then added on top, derived from the current sub-second clock nanos. Because jitter is
/// added after the cap, the returned delay can exceed `RETRY_MAX_DELAY` by less than one
/// base delay.
fn backoff_delay(attempt: u32) -> Duration {
    let base_ms = RETRY_BASE_DELAY.as_millis() as u64;
    let ceiling_ms = RETRY_MAX_DELAY.as_millis() as u64;
    let multiplier = 1u64 << attempt.min(20);
    let exponential_ms = base_ms.saturating_mul(multiplier).min(ceiling_ms);

    // Guard against a zero base delay so the modulo below can never divide by zero.
    let jitter_window = base_ms.max(1);
    let jitter_ms = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => u64::from(since_epoch.subsec_nanos()) % jitter_window,
        Err(_) => 0,
    };

    Duration::from_millis(exponential_ms.saturating_add(jitter_ms))
}
/// Redirect allowlist for one source: the hosts a redirect may land on while fetching
/// on behalf of `source`.
///
/// Each pattern is either an exact host name or a `*.domain` glob. A glob matches
/// `domain` itself as well as any subdomain of it. Matching is ASCII case-insensitive.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SourceAllowlist {
    /// Source key (e.g. `"crossref"`) this allowlist belongs to.
    pub source: String,
    /// Host patterns (exact names or `*.domain` globs) a redirect may target.
    pub redirect_hosts: Vec<String>,
}

impl SourceAllowlist {
    /// Creates an allowlist for `source` that permits the given host patterns.
    pub fn new(source: impl Into<String>, redirect_hosts: Vec<String>) -> Self {
        let source = source.into();
        Self {
            source,
            redirect_hosts,
        }
    }

    /// Returns `true` when `host` matches at least one pattern.
    ///
    /// Comparison ignores ASCII case. An empty pattern list matches nothing.
    pub fn matches(&self, host: &str) -> bool {
        let normalized = host.to_ascii_lowercase();
        for pattern in &self.redirect_hosts {
            if host_matches_pattern(&normalized, pattern) {
                return true;
            }
        }
        false
    }
}

/// Tests a single `host` against a single `pattern`.
///
/// `host` is expected to be lower-cased already; `pattern` is lower-cased here. A
/// `*.domain` pattern accepts `domain` exactly, or any host whose label boundary sits
/// right before `domain`. So `sub.domain` is accepted but `notdomain` is not.
fn host_matches_pattern(host: &str, pattern: &str) -> bool {
    let pattern = pattern.to_ascii_lowercase();
    match pattern.strip_prefix("*.") {
        None => host == pattern,
        Some(domain) => {
            host == domain
                || matches!(host.strip_suffix(domain), Some(label) if label.ends_with('.'))
        }
    }
}
/// Tier-1 metadata sources (Crossref, Unpaywall, arXiv) and their redirect allowlists.
pub fn tier_1_allowlist() -> Vec<SourceAllowlist> {
    let hosts = |patterns: &[&str]| -> Vec<String> {
        patterns.iter().copied().map(String::from).collect()
    };
    vec![
        SourceAllowlist::new("crossref", hosts(&["api.crossref.org", "*.crossref.org"])),
        SourceAllowlist::new("unpaywall", hosts(&["api.unpaywall.org"])),
        SourceAllowlist::new(
            "arxiv",
            hosts(&["arxiv.org", "export.arxiv.org", "*.arxiv.org"]),
        ),
    ]
}
/// Tier-2 metadata sources (OpenAlex, Semantic Scholar, DOAJ) and their redirect allowlists.
pub fn tier_2_allowlist() -> Vec<SourceAllowlist> {
    let hosts = |patterns: &[&str]| -> Vec<String> {
        patterns.iter().copied().map(String::from).collect()
    };
    vec![
        SourceAllowlist::new("openalex", hosts(&["api.openalex.org"])),
        SourceAllowlist::new("semantic_scholar", hosts(&["api.semanticscholar.org"])),
        SourceAllowlist::new("doaj", hosts(&["doaj.org", "*.doaj.org"])),
    ]
}
#[cfg(feature = "tdm-springer")]
/// Springer Nature text-and-data-mining source (enabled by the `tdm-springer` feature).
pub fn tier_3_springer_allowlist() -> Vec<SourceAllowlist> {
    let patterns = ["api.springernature.com", "*.springernature.com"];
    let redirect_hosts: Vec<String> = patterns.iter().copied().map(String::from).collect();
    vec![SourceAllowlist::new("tdm-springer", redirect_hosts)]
}
#[cfg(feature = "tdm-aps")]
/// APS harvest text-and-data-mining source (enabled by the `tdm-aps` feature).
pub fn tier_3_aps_allowlist() -> Vec<SourceAllowlist> {
    let patterns = ["harvest.aps.org", "*.aps.org"];
    let redirect_hosts: Vec<String> = patterns.iter().copied().map(String::from).collect();
    vec![SourceAllowlist::new("tdm-aps", redirect_hosts)]
}
#[cfg(feature = "tdm-elsevier")]
/// Elsevier text-and-data-mining source (enabled by the `tdm-elsevier` feature).
pub fn tier_3_elsevier_allowlist() -> Vec<SourceAllowlist> {
    let patterns = ["api.elsevier.com", "*.elsevier.com"];
    let redirect_hosts: Vec<String> = patterns.iter().copied().map(String::from).collect();
    vec![SourceAllowlist::new("tdm-elsevier", redirect_hosts)]
}
/// One synthetic `"oa-publisher"` source whose allowlist spans the open-access
/// publisher and repository domains that PDF links commonly redirect to.
pub fn oa_publisher_allowlist() -> Vec<SourceAllowlist> {
    const OA_HOST_PATTERNS: &[&str] = &[
        "*.springer.com",
        "*.springeropen.com",
        "*.springernature.com",
        "*.nature.com",
        "*.wiley.com",
        "*.elsevier.com",
        "*.sciencedirect.com",
        "*.frontiersin.org",
        "*.mdpi.com",
        "*.plos.org",
        "*.biorxiv.org",
        "*.medrxiv.org",
        "europepmc.org",
        "*.europepmc.org",
        "*.nih.gov",
        "*.ncbi.nlm.nih.gov",
        "arxiv.org",
        "*.arxiv.org",
    ];
    let redirect_hosts: Vec<String> = OA_HOST_PATTERNS
        .iter()
        .copied()
        .map(String::from)
        .collect();
    vec![SourceAllowlist::new("oa-publisher", redirect_hosts)]
}
/// Errors produced by `HttpClient` fetches and by the per-source redirect policies.
///
/// `RedirectDenied` and `InsecureRedirect` are raised from inside reqwest's redirect
/// callback. Callers therefore usually see them wrapped in [`HttpError::Network`].
/// Converting with `Option::<DenialContext>::from(&err)` walks the cause chain to
/// recover them.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum HttpError {
/// Failure reported by reqwest: send, body read, or a redirect-policy rejection wrapped
/// by reqwest.
#[error("network error: {0}")]
Network(#[from] reqwest::Error),
/// A redirect targeted a host outside the source's allowlist. When the redirect URL had
/// no host at all, `host` is empty.
#[error("redirect target {host} not in allowlist for source {source_key}")]
RedirectDenied {
/// Source whose allowlist was consulted.
source_key: String,
/// Lower-cased host of the rejected redirect target.
host: String,
/// The allowlist patterns that were in force, kept for diagnostics.
expected_hosts: Vec<String>,
},
/// A redirect pointed at a scheme other than `https`.
#[error("redirect to non-HTTPS scheme: {scheme}")]
InsecureRedirect {
scheme: String,
},
/// The body exceeded `PDF_MAX_BYTES`. This is detected either from the declared
/// `Content-Length` (then `actual` is that value) or while streaming (then `actual` is
/// the projected size at the chunk that crossed the cap).
#[error("body too large: {actual} bytes (cap = {cap})")]
OversizedBody {
actual: u64,
cap: u64,
},
/// The body did not start with `%PDF-`. `got` holds the first five bytes, zero-padded
/// when the body is shorter than five bytes.
#[error("PDF magic-byte mismatch: got {got:?}")]
NotAPdf {
got: [u8; 5],
},
/// Non-success status: returned immediately for permanent codes, or after retries are
/// exhausted for transient ones. `url` is the final URL with any `api_key` query value
/// replaced by `REDACTED`.
#[error("HTTP {status} from {url}")]
HttpStatus {
status: u16,
url: String,
},
/// No client was registered for the requested source key.
#[error("no allowlist registered for source {source_key}")]
UnknownSource {
source_key: String,
},
/// A caller-supplied header could not be converted. `reason` is `"name"` or `"value"`,
/// indicating which part was rejected.
#[error("invalid HTTP header `{name}`: {reason}")]
InvalidHeader {
name: String,
reason: String,
},
}
impl From<&HttpError> for Option<crate::DenialContext> {
    /// Maps policy denials to a structured [`crate::DenialContext`].
    ///
    /// `RedirectDenied`, `OversizedBody`, `NotAPdf` and `InsecureRedirect` produce `Some`.
    /// For `Network`, the reqwest error's `source()` chain is walked, and the first nested
    /// `HttpError` found (e.g. one raised by the redirect policy) is converted instead.
    /// Everything else, including a network error with no nested `HttpError`, yields `None`.
    fn from(e: &HttpError) -> Self {
        use crate::{DenialContext, DenialReason};
        let denial = match e {
            HttpError::RedirectDenied {
                source_key,
                host,
                expected_hosts,
            } => DenialContext {
                reason: DenialReason::RedirectNotInAllowlist,
                source: Some(source_key.clone()),
                attempted: Some(host.clone()),
                expected: Some(expected_hosts.clone()),
                hop_index: None,
                cap: None,
                actual: None,
            },
            HttpError::OversizedBody { actual, cap } => DenialContext {
                reason: DenialReason::SizeCapExceeded,
                source: None,
                attempted: None,
                expected: None,
                hop_index: None,
                cap: Some(*cap),
                actual: Some(*actual),
            },
            HttpError::NotAPdf { got } => {
                // Hex-encode the sniffed prefix so binary junk stays printable.
                let [b0, b1, b2, b3, b4] = *got;
                DenialContext {
                    reason: DenialReason::ContentTypeMismatch,
                    source: None,
                    attempted: Some(format!("{b0:02x}{b1:02x}{b2:02x}{b3:02x}{b4:02x}")),
                    expected: Some(vec!["%PDF-".to_string()]),
                    hop_index: None,
                    cap: None,
                    actual: None,
                }
            }
            HttpError::InsecureRedirect { scheme } => DenialContext {
                reason: DenialReason::InsecureScheme,
                source: None,
                attempted: Some(format!("{}:...", scheme)),
                expected: Some(vec!["https".to_string()]),
                hop_index: None,
                cap: None,
                actual: None,
            },
            HttpError::Network(network_err) => {
                let mut cause = std::error::Error::source(network_err);
                while let Some(current) = cause {
                    if let Some(inner) = current.downcast_ref::<HttpError>() {
                        return Self::from(inner);
                    }
                    cause = current.source();
                }
                return None;
            }
            HttpError::HttpStatus { .. }
            | HttpError::UnknownSource { .. }
            | HttpError::InvalidHeader { .. } => return None,
        };
        Some(denial)
    }
}
/// Fetcher that holds one preconfigured reqwest client per source key.
///
/// Both maps sit behind `Arc`, so cloning an `HttpClient` only bumps reference counts
/// and the clones share the same underlying clients.
#[derive(Clone, Debug)]
pub struct HttpClient {
/// Source key -> client whose redirect policy enforces that source's allowlist.
clients: Arc<HashMap<String, Client>>,
/// Source key -> the allowlist the matching client was built from.
allowlists: Arc<HashMap<String, SourceAllowlist>>,
}
impl HttpClient {
    /// Builds one HTTPS-only client per allowlist entry, each with its own redirect policy.
    ///
    /// Entries are keyed by [`SourceAllowlist::source`]. When two entries share a key, the
    /// later one replaces the earlier one.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`reqwest::Error`] if any client fails to build.
    pub fn new(allowlists: Vec<SourceAllowlist>) -> Result<Self, reqwest::Error> {
        let mut clients = HashMap::with_capacity(allowlists.len());
        let mut allowlist_map = HashMap::with_capacity(allowlists.len());
        for entry in allowlists {
            let client = build_client(entry.clone())?;
            clients.insert(entry.source.clone(), client);
            allowlist_map.insert(entry.source.clone(), entry);
        }
        Ok(Self {
            clients: Arc::new(clients),
            allowlists: Arc::new(allowlist_map),
        })
    }

    /// Returns the allowlist registered for `source`, if any.
    pub fn source_allowlist(&self, source: &str) -> Option<&SourceAllowlist> {
        self.allowlists.get(source)
    }

    /// Fetches `url` with `source`'s client.
    ///
    /// Returns the body together with the final (post-redirect) URL. Only the size cap is
    /// enforced; the content is not inspected.
    ///
    /// # Errors
    ///
    /// See [`HttpError`]. Transient failures are retried up to `MAX_FETCH_RETRIES` times
    /// before being reported.
    pub async fn fetch_bytes(&self, source: &str, url: Url) -> Result<(Bytes, Url), HttpError> {
        self.fetch_inner(source, url, &[], false).await
    }

    /// Like [`fetch_bytes`](Self::fetch_bytes), but also sends the given request headers.
    ///
    /// Later duplicates of a header name overwrite earlier ones.
    ///
    /// # Errors
    ///
    /// [`HttpError::InvalidHeader`] if a name or value is not a legal HTTP header, plus
    /// everything `fetch_bytes` can return.
    pub async fn fetch_bytes_with_headers(
        &self,
        source: &str,
        url: Url,
        headers: &[(&str, &str)],
    ) -> Result<(Bytes, Url), HttpError> {
        self.fetch_inner(source, url, headers, false).await
    }

    /// Like [`fetch_bytes`](Self::fetch_bytes), but also requires the body to start with
    /// `%PDF-`.
    ///
    /// # Errors
    ///
    /// [`HttpError::NotAPdf`] if the magic bytes differ, plus everything `fetch_bytes`
    /// can return.
    pub async fn fetch_pdf(&self, source: &str, url: Url) -> Result<(Bytes, Url), HttpError> {
        self.fetch_inner(source, url, &[], true).await
    }

    /// Shared fetch loop: send, classify, retry, then stream the body under the size cap.
    ///
    /// These failures are retried with backoff, up to `MAX_FETCH_RETRIES` extra attempts:
    /// - transient send errors;
    /// - transient HTTP statuses, honouring `Retry-After`;
    /// - transient mid-stream body errors.
    async fn fetch_inner(
        &self,
        source: &str,
        url: Url,
        headers: &[(&str, &str)],
        check_pdf_magic: bool,
    ) -> Result<(Bytes, Url), HttpError> {
        let client = self
            .clients
            .get(source)
            .ok_or_else(|| HttpError::UnknownSource {
                source_key: source.to_string(),
            })?;
        let header_map = build_header_map(headers)?;
        let mut attempt: u32 = 0;
        loop {
            let send_result = client
                .get(url.clone())
                .headers(header_map.clone())
                .send()
                .await;
            let response = match send_result {
                Ok(r) => r,
                Err(e) => {
                    // reqwest prints the request/redirect URL in its Display output; scrub
                    // any `api_key` before it reaches logs or callers.
                    let e = redact_reqwest_error(e);
                    if attempt < MAX_FETCH_RETRIES && reqwest_is_transient(&e) {
                        let d = backoff_delay(attempt);
                        tracing::warn!(
                            source,
                            attempt,
                            delay_ms = d.as_millis() as u64,
                            error = %e,
                            "transient send failure; retrying"
                        );
                        tokio::time::sleep(d).await;
                        attempt += 1;
                        continue;
                    }
                    return Err(HttpError::Network(e));
                }
            };
            let final_url = response.url().clone();
            let status = response.status();
            if !status.is_success() {
                let code = status.as_u16();
                if attempt < MAX_FETCH_RETRIES && is_transient_status(code) {
                    let d = parse_retry_after(response.headers())
                        .unwrap_or_else(|| backoff_delay(attempt));
                    tracing::warn!(
                        source,
                        attempt,
                        status = code,
                        delay_ms = d.as_millis() as u64,
                        "transient HTTP status; retrying"
                    );
                    tokio::time::sleep(d).await;
                    attempt += 1;
                    continue;
                }
                return Err(HttpError::HttpStatus {
                    status: code,
                    url: redact_api_key_query(&final_url),
                });
            }
            // Fail fast when the server announces an oversized body.
            if let Some(len) = response.content_length() {
                if len > PDF_MAX_BYTES {
                    return Err(HttpError::OversizedBody {
                        actual: len,
                        cap: PDF_MAX_BYTES,
                    });
                }
            }
            // Content-Length can be absent or wrong, so enforce the cap while streaming too.
            let mut buf = BytesMut::new();
            let mut stream = response.bytes_stream();
            let mut oversized_at: Option<u64> = None;
            let mut stream_err: Option<reqwest::Error> = None;
            while let Some(chunk) = stream.next().await {
                let chunk = match chunk {
                    Ok(c) => c,
                    Err(e) => {
                        stream_err = Some(e);
                        break;
                    }
                };
                let projected = (buf.len() as u64).saturating_add(chunk.len() as u64);
                if projected > PDF_MAX_BYTES {
                    oversized_at = Some(projected);
                    break;
                }
                buf.extend_from_slice(&chunk);
            }
            if let Some(actual) = oversized_at {
                return Err(HttpError::OversizedBody {
                    actual,
                    cap: PDF_MAX_BYTES,
                });
            }
            if let Some(e) = stream_err {
                let e = redact_reqwest_error(e);
                if attempt < MAX_FETCH_RETRIES && reqwest_is_transient(&e) {
                    let d = backoff_delay(attempt);
                    tracing::warn!(
                        source,
                        attempt,
                        delay_ms = d.as_millis() as u64,
                        error = %e,
                        "transient mid-stream failure; retrying"
                    );
                    tokio::time::sleep(d).await;
                    attempt += 1;
                    continue;
                }
                return Err(HttpError::Network(e));
            }
            let body = buf.freeze();
            if check_pdf_magic {
                // Short bodies leave the tail zero-filled, which can never equal `%PDF-`.
                let mut got = [0u8; 5];
                let n = body.len().min(5);
                got[..n].copy_from_slice(&body[..n]);
                if got != PDF_MAGIC {
                    return Err(HttpError::NotAPdf { got });
                }
            }
            return Ok((body, final_url));
        }
    }
}

/// Converts caller-supplied `(name, value)` pairs into a reqwest `HeaderMap`.
///
/// Later duplicates of a name overwrite earlier ones. Fails with
/// [`HttpError::InvalidHeader`], whose `reason` is `"name"` or `"value"`.
fn build_header_map(
    headers: &[(&str, &str)],
) -> Result<reqwest::header::HeaderMap, HttpError> {
    let mut header_map = reqwest::header::HeaderMap::with_capacity(headers.len());
    for &(name, value) in headers {
        let invalid = |reason: &str| HttpError::InvalidHeader {
            name: name.to_string(),
            reason: reason.to_string(),
        };
        let hn = reqwest::header::HeaderName::from_bytes(name.as_bytes())
            .map_err(|_| invalid("name"))?;
        let hv = reqwest::header::HeaderValue::from_str(value).map_err(|_| invalid("value"))?;
        header_map.insert(hn, hv);
    }
    Ok(header_map)
}

/// Replaces any `api_key` query value in the URL carried by a reqwest error with
/// `REDACTED`, so neither the `tracing` output nor `HttpError::Network`'s Display leaks
/// credentials. Errors without a URL, or without an `api_key`, are returned unchanged.
fn redact_reqwest_error(mut e: reqwest::Error) -> reqwest::Error {
    if let Some(url) = e.url_mut() {
        if let Ok(redacted) = Url::parse(&redact_api_key_query(url)) {
            *url = redacted;
        }
    }
    e
}
/// Renders `url` for diagnostics, replacing the value of every `api_key` query parameter
/// with `REDACTED`.
///
/// URLs without an `api_key` parameter are returned exactly as serialised. Otherwise, the
/// query is rebuilt with form encoding, so other parameters keep their values but may be
/// re-encoded.
fn redact_api_key_query(url: &url::Url) -> String {
    const API_KEY_PARAM: &str = "api_key";
    let has_api_key = url.query_pairs().any(|(name, _)| name == API_KEY_PARAM);
    if !has_api_key {
        return url.to_string();
    }
    let scrubbed: Vec<(String, String)> = url
        .query_pairs()
        .map(|(name, value)| {
            let value = if name == API_KEY_PARAM {
                "REDACTED".to_string()
            } else {
                value.into_owned()
            };
            (name.into_owned(), value)
        })
        .collect();
    let mut redacted = url.clone();
    redacted.query_pairs_mut().clear().extend_pairs(scrubbed);
    redacted.to_string()
}
// `expect` is acceptable here: these constructors exist for tests, and a client that
// fails to build is a broken test environment, not a recoverable condition.
#[allow(clippy::expect_used)]
impl HttpClient {
    /// Test-only constructor: one source whose allowlist is the single `allowlist_host`.
    ///
    /// Unlike [`HttpClient::new`], the client accepts a plain-HTTP *initial* request, so it
    /// can talk to a local mock server. Redirects are still required to be HTTPS and
    /// allowlisted. Note that this is compiled into non-test builds as well.
    ///
    /// # Panics
    ///
    /// Panics if the reqwest client cannot be built.
    pub fn new_for_tests_allow_http(source: &str, allowlist_host: &str) -> Self {
        Self::new_for_tests_allow_http_multi(&[(source, allowlist_host)])
    }

    /// Test-only constructor for several `(source, host)` pairs.
    ///
    /// Each pair gets a client with the same plain-HTTP allowance as
    /// [`new_for_tests_allow_http`](Self::new_for_tests_allow_http). When a source key
    /// repeats, the later pair wins.
    ///
    /// # Panics
    ///
    /// Panics if any reqwest client cannot be built.
    pub fn new_for_tests_allow_http_multi(entries: &[(&str, &str)]) -> Self {
        let mut clients = HashMap::with_capacity(entries.len());
        let mut allowlists = HashMap::with_capacity(entries.len());
        for &(source, host) in entries {
            let allowlist = SourceAllowlist::new(source, vec![host.to_string()]);
            let client =
                build_client_allow_http(allowlist.clone()).expect("test client builds");
            clients.insert(allowlist.source.clone(), client);
            allowlists.insert(allowlist.source.clone(), allowlist);
        }
        Self {
            clients: Arc::new(clients),
            allowlists: Arc::new(allowlists),
        }
    }
}
fn build_client_allow_http(allowlist: SourceAllowlist) -> Result<Client, reqwest::Error> {
let allowlist_for_closure = allowlist.clone();
let redirect_policy = Policy::custom(move |attempt| {
let scheme = attempt.url().scheme().to_string();
let host_opt = attempt.url().host_str().map(|h| h.to_ascii_lowercase());
let prev_count = attempt.previous().len();
if scheme != "https" {
return attempt.error(HttpError::InsecureRedirect { scheme });
}
if prev_count >= MAX_REDIRECTS {
return attempt.stop();
}
let host = match host_opt {
Some(h) => h,
None => {
return attempt.error(HttpError::RedirectDenied {
source_key: allowlist_for_closure.source.clone(),
host: String::new(),
expected_hosts: allowlist_for_closure.redirect_hosts.clone(),
});
}
};
if !allowlist_for_closure.matches(&host) {
return attempt.error(HttpError::RedirectDenied {
source_key: allowlist_for_closure.source.clone(),
host,
expected_hosts: allowlist_for_closure.redirect_hosts.clone(),
});
}
attempt.follow()
});
ClientBuilder::new()
.https_only(false)
.redirect(redirect_policy)
.connect_timeout(CONNECT_TIMEOUT)
.timeout(TOTAL_TIMEOUT)
.read_timeout(READ_TIMEOUT)
.user_agent(format!(
"doiget/{} (+https://github.com/sotashimozono/doiget)",
VERSION
))
.tls_backend_rustls()
.build()
}
fn build_client(allowlist: SourceAllowlist) -> Result<Client, reqwest::Error> {
let user_agent = format!(
"doiget/{} (+https://github.com/sotashimozono/doiget)",
VERSION
);
let allowlist_for_closure = allowlist.clone();
let redirect_policy = Policy::custom(move |attempt| {
let scheme = attempt.url().scheme().to_string();
let host_opt = attempt.url().host_str().map(|h| h.to_ascii_lowercase());
let prev_count = attempt.previous().len();
if scheme != "https" {
return attempt.error(HttpError::InsecureRedirect { scheme });
}
if prev_count >= MAX_REDIRECTS {
return attempt.stop();
}
let host = match host_opt {
Some(h) => h,
None => {
return attempt.error(HttpError::RedirectDenied {
source_key: allowlist_for_closure.source.clone(),
host: String::new(),
expected_hosts: allowlist_for_closure.redirect_hosts.clone(),
});
}
};
if !allowlist_for_closure.matches(&host) {
return attempt.error(HttpError::RedirectDenied {
source_key: allowlist_for_closure.source.clone(),
host,
expected_hosts: allowlist_for_closure.redirect_hosts.clone(),
});
}
attempt.follow()
});
ClientBuilder::new()
.https_only(true)
.redirect(redirect_policy)
.connect_timeout(CONNECT_TIMEOUT)
.timeout(TOTAL_TIMEOUT)
.read_timeout(READ_TIMEOUT)
.user_agent(user_agent)
.tls_backend_rustls()
.build()
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    //! Unit tests for the allowlist tables, host matching, retry classification, denial
    //! mapping, and the fetch/retry pipeline. Pipeline tests run against a local
    //! `wiremock` server over plain HTTP.

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Host component (e.g. `127.0.0.1`) of a mock server's base URI.
    fn host_of(server: &MockServer) -> String {
        server
            .uri()
            .parse::<Url>()
            .unwrap()
            .host_str()
            .unwrap()
            .to_string()
    }

    /// Plain-HTTP-capable client whose only allowlisted host is `allowlist_host`.
    fn build_test_client_for_http(source: &str, allowlist_host: &str) -> HttpClient {
        HttpClient::new_for_tests_allow_http(source, allowlist_host)
    }

    #[test]
    fn tier_1_allowlist_includes_crossref() {
        let lists = tier_1_allowlist();
        let crossref = lists
            .iter()
            .find(|a| a.source == "crossref")
            .expect("crossref entry");
        assert!(
            crossref
                .redirect_hosts
                .iter()
                .any(|h| h.contains("crossref.org")),
            "crossref allowlist must contain a crossref.org pattern; got {:?}",
            crossref.redirect_hosts,
        );
    }

    #[test]
    fn tier_1_allowlist_includes_unpaywall_and_arxiv() {
        let lists = tier_1_allowlist();
        assert!(lists.iter().any(|a| a.source == "unpaywall"));
        assert!(lists.iter().any(|a| a.source == "arxiv"));
    }

    #[test]
    fn oa_publisher_allowlist_groups_under_one_synthetic_source() {
        let lists = oa_publisher_allowlist();
        assert_eq!(lists.len(), 1, "exactly one synthetic source entry");
        assert_eq!(lists[0].source, "oa-publisher");
    }

    #[test]
    fn oa_publisher_allowlist_matches_known_oa_hosts() {
        let lists = oa_publisher_allowlist();
        let oa = lists
            .iter()
            .find(|a| a.source == "oa-publisher")
            .expect("oa-publisher entry");
        assert!(oa.matches("link.springer.com"));
        assert!(oa.matches("nature.com"));
        assert!(oa.matches("onlinelibrary.wiley.com"));
        assert!(oa.matches("www.frontiersin.org"));
        assert!(oa.matches("www.mdpi.com"));
        assert!(oa.matches("journals.plos.org"));
        assert!(oa.matches("www.biorxiv.org"));
        assert!(oa.matches("europepmc.org"));
        assert!(oa.matches("www.ncbi.nlm.nih.gov"));
        assert!(oa.matches("arxiv.org"));
        assert!(!oa.matches("attacker.test"));
        assert!(!oa.matches("notspringer.com"));
    }

    #[test]
    fn allowlist_matches_exact_fqdn() {
        let a = SourceAllowlist::new("crossref", vec!["api.crossref.org".to_string()]);
        assert!(a.matches("api.crossref.org"));
        assert!(!a.matches("crossref.org"));
        assert!(!a.matches("xapi.crossref.org"));
    }

    #[test]
    fn allowlist_matches_subdomain_glob() {
        let a = SourceAllowlist::new("crossref", vec!["*.crossref.org".to_string()]);
        assert!(a.matches("doi.crossref.org"));
        assert!(a.matches("crossref.org"));
        assert!(!a.matches("notcrossref.org"));
        assert!(!a.matches("crossref.org.attacker.test"));
    }

    #[test]
    fn allowlist_matches_is_case_insensitive() {
        let a = SourceAllowlist::new("crossref", vec!["API.crossref.ORG".to_string()]);
        assert!(a.matches("api.crossref.org"));
        assert!(a.matches("API.CROSSREF.ORG"));
    }

    #[test]
    fn allowlist_with_no_redirect_hosts_matches_nothing() {
        let a = SourceAllowlist::new("ghost", Vec::<String>::new());
        assert!(!a.matches("anything.test"));
        assert!(!a.matches(""));
    }

    #[test]
    fn transient_status_classification() {
        for code in [408u16, 429, 500, 502, 503, 504] {
            assert!(is_transient_status(code), "{code} should be retried");
        }
        for code in [200u16, 400, 401, 403, 404, 501] {
            assert!(!is_transient_status(code), "{code} must not be retried");
        }
    }

    #[test]
    fn backoff_delay_is_exponential_and_capped() {
        let d0 = backoff_delay(0);
        assert!(d0 >= RETRY_BASE_DELAY && d0 < RETRY_BASE_DELAY * 2, "got {d0:?}");
        let capped = backoff_delay(30);
        assert!(
            capped >= RETRY_MAX_DELAY && capped < RETRY_MAX_DELAY + RETRY_BASE_DELAY,
            "got {capped:?}"
        );
    }

    #[tokio::test]
    async fn pdf_magic_byte_match_succeeds() {
        let server = MockServer::start().await;
        let body = b"%PDF-1.7\n...some pdf bytes...".to_vec();
        Mock::given(method("GET"))
            .and(path("/paper.pdf"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(body.clone()))
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/paper.pdf", server.uri()).parse().unwrap();
        let (got_body, _final_url) = client.fetch_pdf("crossref", url).await.expect("ok");
        assert_eq!(&got_body[..], &body[..]);
    }

    #[tokio::test]
    async fn pdf_magic_byte_mismatch_rejects() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/not_a_pdf"))
            .respond_with(
                ResponseTemplate::new(200).set_body_bytes(b"<html>not a pdf</html>".to_vec()),
            )
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/not_a_pdf", server.uri()).parse().unwrap();
        let err = client
            .fetch_pdf("crossref", url)
            .await
            .expect_err("not pdf");
        match err {
            HttpError::NotAPdf { got } => {
                assert_eq!(&got, b"<html");
            }
            other => panic!("expected NotAPdf, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn fetch_bytes_does_not_check_pdf_magic() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/data.json"))
            .respond_with(
                ResponseTemplate::new(200).set_body_bytes(br#"{"hello":"world"}"#.to_vec()),
            )
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/data.json", server.uri()).parse().unwrap();
        let (body, _final_url) = client.fetch_bytes("crossref", url).await.expect("ok");
        assert_eq!(&body[..], br#"{"hello":"world"}"#);
    }

    #[tokio::test]
    async fn oversized_body_via_content_length_short_circuits() {
        let server = MockServer::start().await;
        let oversized = PDF_MAX_BYTES + 1;
        Mock::given(method("GET"))
            .and(path("/huge"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header("content-length", oversized.to_string().as_str())
                    .set_body_bytes(b"%PDF-".to_vec()),
            )
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/huge", server.uri()).parse().unwrap();
        let err = client
            .fetch_bytes("crossref", url)
            .await
            .expect_err("should reject");
        match err {
            HttpError::OversizedBody { actual, cap } => {
                assert!(actual > cap, "actual {} should exceed cap {}", actual, cap);
                assert_eq!(cap, PDF_MAX_BYTES);
            }
            // The mismatched Content-Length may instead surface as a transport error.
            HttpError::Network(_) => {}
            other => panic!("expected OversizedBody or Network, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn unknown_source_rejected() {
        let client = HttpClient::new(tier_1_allowlist()).expect("client builds");
        let url: Url = "https://api.crossref.org/works/10.1234/x".parse().unwrap();
        let err = client
            .fetch_bytes("not-a-source", url)
            .await
            .expect_err("unknown source");
        match err {
            HttpError::UnknownSource { source_key } => {
                assert_eq!(source_key, "not-a-source")
            }
            other => panic!("expected UnknownSource, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn http_status_error_surfaces() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/missing"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/missing", server.uri()).parse().unwrap();
        let err = client.fetch_bytes("crossref", url).await.expect_err("404");
        match err {
            HttpError::HttpStatus { status, .. } => assert_eq!(status, 404),
            other => panic!("expected HttpStatus, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn redirect_to_http_is_rejected_by_closure() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/redir"))
            .respond_with(
                ResponseTemplate::new(302).insert_header("location", "http://attacker.test/file"),
            )
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/redir", server.uri()).parse().unwrap();
        let err = client
            .fetch_bytes("crossref", url)
            .await
            .expect_err("redirect to http rejected");
        match err {
            HttpError::Network(e) => {
                let msg = format!("{:?}", e);
                assert!(
                    msg.contains("InsecureRedirect") || msg.contains("non-HTTPS"),
                    "expected insecure-redirect signal in error chain, got {}",
                    msg
                );
            }
            other => panic!("expected Network(InsecureRedirect), got {:?}", other),
        }
    }

    #[tokio::test]
    async fn redirect_outside_allowlist_is_rejected_by_closure() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/redir"))
            .respond_with(
                ResponseTemplate::new(302).insert_header("location", "https://attacker.test/file"),
            )
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/redir", server.uri()).parse().unwrap();
        let err = client
            .fetch_bytes("crossref", url)
            .await
            .expect_err("redirect to attacker rejected");
        match err {
            HttpError::Network(e) => {
                let msg = format!("{:?}", e);
                assert!(
                    msg.contains("RedirectDenied") || msg.contains("not in allowlist"),
                    "expected redirect-denied signal in error chain, got {}",
                    msg
                );
            }
            other => panic!("expected Network(RedirectDenied), got {:?}", other),
        }
    }

    /// An HTTPS redirect to an allowlisted host must be *followed*. The `.test` TLD is
    /// reserved, so the follow-up request then fails at DNS/connect time, after the
    /// policy has approved the hop. This uses the production policy builder rather than a
    /// hand-copied closure, so the test tracks future changes to the policy.
    #[tokio::test]
    async fn redirect_to_allowlisted_https_host_is_followed_by_closure() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/redir"))
            .respond_with(
                ResponseTemplate::new(302)
                    .insert_header("location", "https://mirror.allowed.test/file"),
            )
            .mount(&server)
            .await;
        let allowlist = SourceAllowlist::new(
            "crossref",
            vec![host_of(&server), "*.allowed.test".to_string()],
        );
        let client = build_client_allow_http(allowlist).expect("client builds");
        let url: Url = format!("{}/redir", server.uri()).parse().unwrap();
        let err = client.get(url).send().await.expect_err("DNS fails");
        let msg = format!("{:?}", err);
        assert!(
            !err.is_redirect(),
            "policy rejected an allowed redirect: {}",
            msg
        );
        assert!(
            !msg.contains("RedirectDenied") && !msg.contains("InsecureRedirect"),
            "closure short-circuited an allowed redirect: {}",
            msg,
        );
    }

    #[test]
    fn http_client_clone_is_cheap() {
        let a = HttpClient::new(tier_1_allowlist()).expect("builds");
        let b = a.clone();
        assert_eq!(a.clients.len(), b.clients.len());
        assert!(Arc::ptr_eq(&a.clients, &b.clients));
    }

    #[test]
    fn denial_from_redirect_denied_carries_attempted_and_expected() {
        use crate::{DenialContext, DenialReason};
        let e = HttpError::RedirectDenied {
            source_key: "crossref".to_string(),
            host: "evil.example.com".to_string(),
            expected_hosts: vec!["api.crossref.org".to_string(), "*.crossref.org".to_string()],
        };
        let dc: Option<DenialContext> = (&e).into();
        let dc = dc.expect("RedirectDenied -> Some(DenialContext)");
        assert_eq!(dc.reason, DenialReason::RedirectNotInAllowlist);
        assert_eq!(dc.source.as_deref(), Some("crossref"));
        assert_eq!(dc.attempted.as_deref(), Some("evil.example.com"));
        assert_eq!(
            dc.expected.as_deref(),
            Some(&["api.crossref.org".to_string(), "*.crossref.org".to_string()][..])
        );
        assert!(dc.cap.is_none());
        assert!(dc.actual.is_none());
        assert!(dc.hop_index.is_none());
    }

    #[test]
    fn denial_from_oversized_body_carries_cap_and_actual() {
        use crate::{DenialContext, DenialReason};
        let e = HttpError::OversizedBody {
            actual: 209_715_200,
            cap: PDF_MAX_BYTES,
        };
        let dc: Option<DenialContext> = (&e).into();
        let dc = dc.expect("OversizedBody -> Some(DenialContext)");
        assert_eq!(dc.reason, DenialReason::SizeCapExceeded);
        assert_eq!(dc.cap, Some(PDF_MAX_BYTES));
        assert_eq!(dc.actual, Some(209_715_200));
        assert!(dc.source.is_none());
        assert!(dc.attempted.is_none());
        assert!(dc.expected.is_none());
    }

    #[test]
    fn denial_from_not_a_pdf_hex_encodes_got_bytes() {
        use crate::{DenialContext, DenialReason};
        let e = HttpError::NotAPdf {
            got: [0x3c, 0x68, 0x74, 0x6d, 0x6c],
        };
        let dc: Option<DenialContext> = (&e).into();
        let dc = dc.expect("NotAPdf -> Some(DenialContext)");
        assert_eq!(dc.reason, DenialReason::ContentTypeMismatch);
        assert_eq!(dc.attempted.as_deref(), Some("3c68746d6c"));
        assert_eq!(dc.expected.as_deref(), Some(&["%PDF-".to_string()][..]));
    }

    #[test]
    fn denial_from_insecure_redirect_marks_insecure_scheme() {
        use crate::{DenialContext, DenialReason};
        let e = HttpError::InsecureRedirect {
            scheme: "http".to_string(),
        };
        let dc: Option<DenialContext> = (&e).into();
        let dc = dc.expect("InsecureRedirect -> Some(DenialContext)");
        assert_eq!(dc.reason, DenialReason::InsecureScheme);
        assert_eq!(dc.attempted.as_deref(), Some("http:..."));
        assert_eq!(dc.expected.as_deref(), Some(&["https".to_string()][..]));
    }

    #[test]
    fn denial_from_non_denial_variants_returns_none() {
        use crate::DenialContext;
        let e = HttpError::HttpStatus {
            status: 503,
            url: "https://api.crossref.org/works/x".to_string(),
        };
        let dc: Option<DenialContext> = (&e).into();
        assert!(dc.is_none(), "HttpStatus must not produce a DenialContext");
        let e = HttpError::UnknownSource {
            source_key: "ghost".to_string(),
        };
        let dc: Option<DenialContext> = (&e).into();
        assert!(
            dc.is_none(),
            "UnknownSource must not produce a DenialContext"
        );
    }

    #[tokio::test]
    async fn transient_503_then_200_succeeds() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/p"))
            .respond_with(ResponseTemplate::new(200).set_body_string(r#"{"ok":1}"#))
            .mount(&server)
            .await;
        // Mounted last so it takes precedence for its single allowed match.
        Mock::given(method("GET"))
            .and(path("/p"))
            .respond_with(ResponseTemplate::new(503))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/p", server.uri()).parse().unwrap();
        let (body, _) = client
            .fetch_bytes("crossref", url)
            .await
            .expect("503-then-200 must succeed after one retry");
        assert_eq!(&body[..], br#"{"ok":1}"#);
    }

    #[tokio::test]
    async fn persistent_503_exhausts_and_returns_httpstatus() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/p"))
            .respond_with(ResponseTemplate::new(503))
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/p", server.uri()).parse().unwrap();
        let err = client
            .fetch_bytes("crossref", url)
            .await
            .expect_err("persistent 503 must exhaust retries");
        match err {
            HttpError::HttpStatus { status, .. } => assert_eq!(status, 503),
            other => panic!("expected HttpStatus 503, got {other:?}"),
        }
        let reqs = server
            .received_requests()
            .await
            .expect("wiremock records requests");
        assert_eq!(reqs.len(), (MAX_FETCH_RETRIES + 1) as usize);
    }

    #[tokio::test]
    async fn retry_after_429_then_200_succeeds() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/p"))
            .respond_with(ResponseTemplate::new(200).set_body_string("ok"))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/p"))
            .respond_with(ResponseTemplate::new(429).insert_header("Retry-After", "1"))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/p", server.uri()).parse().unwrap();
        let (body, _) = client
            .fetch_bytes("crossref", url)
            .await
            .expect("429+Retry-After then 200 must succeed");
        assert_eq!(&body[..], b"ok");
    }

    #[tokio::test]
    async fn permanent_404_is_not_retried() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/p"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;
        let client = build_test_client_for_http("crossref", &host_of(&server));
        let url: Url = format!("{}/p", server.uri()).parse().unwrap();
        let _ = client
            .fetch_bytes("crossref", url)
            .await
            .expect_err("404 must fail");
        let reqs = server
            .received_requests()
            .await
            .expect("wiremock records requests");
        assert_eq!(reqs.len(), 1, "4xx (non-408/429) must NOT be retried");
    }
}