use std::time::Duration;
use bytes::Bytes;
use reqwest::{
StatusCode,
header::{ETAG, IF_MODIFIED_SINCE, IF_NONE_MATCH, LAST_MODIFIED},
};
/// Request timeout applied by `Fetcher::new` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Response-body size cap, in bytes, applied by `Fetcher::new` (64 MiB).
const DEFAULT_MAX_SIZE: usize = 64 * 1024 * 1024;
/// HTTP cache validators used to make a request conditional.
///
/// `Fetcher::fetch` sends whichever fields are `Some` as request headers and
/// returns a fresh `Validators` built from the headers of a `200 OK` response.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Validators {
/// Value sent as `If-None-Match`; read back from the response `ETag` header.
pub etag: Option<String>,
/// Value sent as `If-Modified-Since`; read back from the response
/// `Last-Modified` header.
pub last_modified: Option<String>,
}
/// Successful result of `Fetcher::fetch`.
#[derive(Debug)]
pub enum FetchOutcome {
/// The server answered `200 OK`.
Modified {
/// The complete response body.
body: Bytes,
/// Validators taken from the response headers, for use in the next fetch.
validators: Validators,
},
/// The server answered `304 Not Modified`; there is no body.
NotModified,
}
/// Failure modes of `Fetcher::fetch`.
#[derive(Debug, thiserror::Error)]
pub enum FetchError {
/// An error reported by `reqwest` that was not classified as a timeout.
#[error("HTTP request failed: {0}")]
Request(#[from] reqwest::Error),
/// The server answered with a status other than `200 OK` or `304 Not Modified`.
#[error("unexpected HTTP status {0}")]
UnexpectedStatus(StatusCode),
/// The body grew past the fetcher's size cap; the payload is that cap in bytes.
#[error("response body exceeds the {0}-byte size limit")]
BodyTooLarge(usize),
/// The request timed out, as reported by `reqwest::Error::is_timeout`.
#[error("request timed out")]
Timeout,
}
/// HTTP fetcher that issues conditional GET requests and caps the size of the
/// response body it will accept.
#[derive(Clone, Debug)]
pub struct Fetcher {
/// HTTP client used for every request; carries the configured timeout.
client: reqwest::Client,
/// Maximum accepted response-body size, in bytes.
max_size: usize,
}
impl Default for Fetcher {
fn default() -> Self {
Self::new()
}
}
impl Fetcher {
pub fn new() -> Self {
Self::build(DEFAULT_TIMEOUT, DEFAULT_MAX_SIZE)
}
pub fn with_timeout(self, timeout: Duration) -> Self {
Self::build(timeout, self.max_size)
}
pub fn with_max_size(mut self, max_size: usize) -> Self {
self.max_size = max_size;
self
}
fn build(timeout: Duration, max_size: usize) -> Self {
let _ = rustls::crypto::ring::default_provider().install_default();
let client = reqwest::Client::builder()
.timeout(timeout)
.gzip(true)
.build()
.expect("reqwest::Client build should never fail with ring installed");
Self { client, max_size }
}
pub async fn fetch(
&self,
url: &str,
validators: &Validators,
) -> Result<FetchOutcome, FetchError> {
let mut builder = self.client.get(url);
if let Some(etag) = &validators.etag {
builder = builder.header(IF_NONE_MATCH, etag);
}
if let Some(last_modified) = &validators.last_modified {
builder = builder.header(IF_MODIFIED_SINCE, last_modified);
}
let response = builder.send().await.map_err(|e| {
if e.is_timeout() {
FetchError::Timeout
} else {
FetchError::Request(e)
}
})?;
match response.status() {
StatusCode::NOT_MODIFIED => Ok(FetchOutcome::NotModified),
StatusCode::OK => {
let new_validators = Validators {
etag: response
.headers()
.get(ETAG)
.and_then(|v| v.to_str().ok())
.map(str::to_owned),
last_modified: response
.headers()
.get(LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(str::to_owned),
};
let mut accumulated = Vec::new();
let mut response = response;
while let Some(chunk) = response.chunk().await.map_err(FetchError::Request)? {
accumulated.extend_from_slice(&chunk);
if accumulated.len() > self.max_size {
return Err(FetchError::BodyTooLarge(self.max_size));
}
}
Ok(FetchOutcome::Modified {
body: Bytes::from(accumulated),
validators: new_validators,
})
}
other => Err(FetchError::UnexpectedStatus(other)),
}
}
}
// Integration-style tests that run `Fetcher::fetch` against a local
// `wiremock` mock server.
#[cfg(test)]
mod tests {
use std::time::Duration;
use bytes::Bytes;
use wiremock::matchers::{header, header_exists, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use super::{FetchError, FetchOutcome, Fetcher, Validators};
// Fetcher with a short (5-second) timeout so a misbehaving test fails quickly.
fn test_fetcher() -> Fetcher {
Fetcher::new().with_timeout(Duration::from_secs(5))
}
// A 200 response must yield `Modified` carrying the body and both validators
// taken from the `etag` / `last-modified` response headers.
#[tokio::test]
async fn ok_200_returns_modified_with_body_and_validators() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200)
.set_body_bytes(b"0.0.0.0 ads.example.com\n".to_vec())
.insert_header("etag", r#""abc123""#)
.insert_header("last-modified", "Thu, 01 Jan 2026 00:00:00 GMT"),
)
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let outcome = test_fetcher()
.fetch(&url, &Validators::default())
.await
.expect("fetch must succeed on 200");
let FetchOutcome::Modified { body, validators } = outcome else {
panic!("expected Modified, got NotModified");
};
assert_eq!(body, Bytes::from("0.0.0.0 ads.example.com\n"));
assert_eq!(validators.etag.as_deref(), Some(r#""abc123""#));
assert_eq!(
validators.last_modified.as_deref(),
Some("Thu, 01 Jan 2026 00:00:00 GMT")
);
}
// The mock only matches when `If-None-Match` carries the exact stored ETag;
// its 304 answer must surface as `NotModified`.
#[tokio::test]
async fn conditional_get_304_returns_not_modified() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.and(header(
reqwest::header::IF_NONE_MATCH.as_str(),
r#""abc123""#,
))
.respond_with(ResponseTemplate::new(304))
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let validators = Validators {
etag: Some(r#""abc123""#.to_owned()),
last_modified: None,
};
let outcome = test_fetcher()
.fetch(&url, &validators)
.await
.expect("fetch must succeed on 304");
assert!(
matches!(outcome, FetchOutcome::NotModified),
"expected NotModified"
);
}
// The mock requires the `If-None-Match` header to be present, so a successful
// fetch shows the header was sent whenever an ETag is supplied.
#[tokio::test]
async fn if_none_match_header_is_sent_when_etag_is_set() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.and(header_exists(reqwest::header::IF_NONE_MATCH.as_str()))
.respond_with(ResponseTemplate::new(304))
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let validators = Validators {
etag: Some(r#""some-etag""#.to_owned()),
last_modified: None,
};
test_fetcher()
.fetch(&url, &validators)
.await
.expect("fetch should match the 304 mock, proving If-None-Match was sent");
}
// A body served with `content-encoding: gzip` must come back decompressed.
#[tokio::test]
async fn gzip_body_is_decompressed() {
use flate2::Compression;
use flate2::write::GzEncoder;
use std::io::Write as _;
let plaintext = b"0.0.0.0 tracker.example.org\n";
// Compress the fixture up front so the mock serves real gzip bytes.
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(plaintext).unwrap();
let compressed = encoder.finish().unwrap();
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200)
.set_body_bytes(compressed)
.insert_header("content-encoding", "gzip"),
)
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let outcome = test_fetcher()
.fetch(&url, &Validators::default())
.await
.expect("fetch must succeed");
let FetchOutcome::Modified { body, .. } = outcome else {
panic!("expected Modified");
};
assert_eq!(
body.as_ref(),
plaintext,
"decompressed body must equal the original plaintext"
);
}
// With the cap lowered to 16 bytes, a longer body must be rejected with
// `BodyTooLarge` carrying that cap.
#[tokio::test]
async fn oversize_body_returns_body_too_large_error() {
let server = MockServer::start().await;
let large_body = b"0.0.0.0 ads.example.com\n0.0.0.0 x\n";
// Guard the fixture itself: the test is meaningless if the body fits the cap.
assert!(
large_body.len() > 16,
"test body must exceed the 16-byte cap"
);
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(ResponseTemplate::new(200).set_body_bytes(large_body.to_vec()))
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let fetcher = Fetcher::new()
.with_timeout(Duration::from_secs(5))
.with_max_size(16);
let result = fetcher.fetch(&url, &Validators::default()).await;
assert!(
matches!(result, Err(FetchError::BodyTooLarge(16))),
"expected BodyTooLarge(16), got: {result:?}"
);
}
// Any status other than 200/304 (here 500) must map to `UnexpectedStatus`.
#[tokio::test]
async fn unexpected_status_500_returns_error() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(ResponseTemplate::new(500))
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let result = test_fetcher().fetch(&url, &Validators::default()).await;
assert!(
matches!(result, Err(FetchError::UnexpectedStatus(s)) if s.as_u16() == 500),
"expected UnexpectedStatus(500), got: {result:?}"
);
}
// `Validators::default()` carries no validators, i.e. an unconditional request.
#[test]
fn validators_default_is_all_none() {
let v = Validators::default();
assert!(v.etag.is_none());
assert!(v.last_modified.is_none());
}
// A 200 response without `ETag` / `Last-Modified` headers must produce
// validators that are both `None`.
#[tokio::test]
async fn ok_200_without_response_validators_yields_none_validators() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/hosts.txt"))
.respond_with(
ResponseTemplate::new(200).set_body_bytes(b"0.0.0.0 example.com\n".to_vec()),
)
.mount(&server)
.await;
let url = format!("{}/hosts.txt", server.uri());
let outcome = test_fetcher()
.fetch(&url, &Validators::default())
.await
.expect("fetch must succeed");
let FetchOutcome::Modified { validators, .. } = outcome else {
panic!("expected Modified");
};
assert!(validators.etag.is_none(), "etag must be None");
assert!(
validators.last_modified.is_none(),
"last_modified must be None"
);
}
}