use reqwest::header::{CONTENT_TYPE, LOCATION};
use reqwest::{redirect::Policy, Client, Response};
use std::net::SocketAddr;
use std::time::Duration;
use crate::guard;
/// `User-Agent` header value sent on every request: `webfetch/<crate version>`.
const USER_AGENT: &str = concat!("webfetch/", env!("CARGO_PKG_VERSION"));
/// Total tries per request (the first try plus retries) for failures classified as transient.
const MAX_ATTEMPTS: u32 = 3;
/// Number of redirect hops followed before the fetch is abandoned.
const MAX_REDIRECTS: usize = 5;
/// Response bodies are truncated to this many bytes (5 MiB).
const MAX_BODY_BYTES: usize = 5 << 20;
/// A successfully fetched (non-redirect, non-error status) page.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchedPage {
    /// Response body decoded as UTF-8 (invalid sequences replaced with U+FFFD),
    /// truncated to the module's body-size cap.
    pub body: String,
    /// URL of the request that produced this page, i.e. the last hop after any redirects.
    pub final_url: String,
    /// Raw `Content-Type` header value, if present and representable as a string.
    pub content_type: Option<String>,
}
/// What a single request (made with automatic redirects disabled) produced.
enum Hop {
    /// The server answered with a redirect; holds the raw `Location` value,
    /// which the caller resolves against the URL that was just requested.
    Redirect(String),
    /// A final, non-redirect response that passed the status check.
    Page(FetchedPage),
}
fn build_client(
url: &reqwest::Url,
timeout_secs: u64,
pinned: &[SocketAddr],
) -> anyhow::Result<Client> {
let mut builder = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.redirect(Policy::none())
.user_agent(USER_AGENT)
.gzip(true)
.brotli(true);
if let Some(host) = url.host_str() {
if !pinned.is_empty() {
builder = builder.resolve_to_addrs(host, pinned);
}
}
Ok(builder.build()?)
}
/// Appends as much of `chunk` to `buf` as fits without `buf` exceeding `max` bytes.
///
/// Returns `true` once the cap is reached: the chunk exactly filled or overflowed the
/// remaining room, or the buffer was already full. The caller should then stop reading.
/// Returns `false` when the whole chunk fit with room to spare.
fn push_capped(buf: &mut Vec<u8>, chunk: &[u8], max: usize) -> bool {
    let room = max.saturating_sub(buf.len());
    let take = chunk.len().min(room);
    buf.extend_from_slice(&chunk[..take]);
    chunk.len() >= room
}
async fn read_body_capped(mut resp: Response) -> Result<String, (anyhow::Error, bool)> {
let mut buf: Vec<u8> = Vec::new();
if let Some(len) = resp.content_length() {
buf.reserve(len.min(MAX_BODY_BYTES as u64) as usize);
}
loop {
match resp.chunk().await {
Ok(Some(chunk)) => {
if push_capped(&mut buf, &chunk, MAX_BODY_BYTES) {
break;
}
}
Ok(None) => break,
Err(e) => {
let transient = e.is_timeout();
return Err((e.into(), transient));
}
}
}
Ok(String::from_utf8_lossy(&buf).into_owned())
}
/// Performs one GET of `url` with `client` (which must not follow redirects itself).
///
/// Returns:
/// - `Hop::Redirect(location)` for any 3xx response carrying a `Location` header. The raw
///   value is returned unresolved.
/// - `Hop::Page` for a non-error, non-redirect response, with its body read up to the
///   module's size cap.
///
/// # Errors
/// Returns `(error, transient)`. `transient` is `true` for:
/// - send failures that are timeouts, connect errors, or request errors
/// - 5xx and 429 status codes
/// - body-read timeouts
///
/// A redirect with a missing or undecodable `Location` header is a non-transient error.
async fn attempt(client: &Client, url: &str) -> Result<Hop, (anyhow::Error, bool)> {
    let resp = match client.get(url).send().await {
        Ok(r) => r,
        Err(e) => {
            let transient = e.is_timeout() || e.is_connect() || e.is_request();
            return Err((e.into(), transient));
        }
    };
    let status = resp.status();
    if status.is_redirection() {
        return match resp.headers().get(LOCATION) {
            // `HeaderValue::to_str` rejects every non-visible-ASCII byte. Servers commonly
            // send raw UTF-8 in `Location`, and browsers accept it, so decode the bytes as
            // UTF-8. `Url::join` percent-encodes any non-ASCII characters later.
            Some(value) => match std::str::from_utf8(value.as_bytes()) {
                Ok(loc) => Ok(Hop::Redirect(loc.to_string())),
                Err(_) => Err((
                    anyhow::anyhow!(
                        "redirect ({status}) has a Location header that is not valid UTF-8"
                    ),
                    false,
                )),
            },
            None => Err((
                anyhow::anyhow!("redirect ({status}) without a Location header"),
                false,
            )),
        };
    }
    let resp = match resp.error_for_status() {
        Ok(r) => r,
        Err(e) => {
            // Server-side failures and rate limiting may succeed on a later try.
            let transient = status.is_server_error() || status.as_u16() == 429;
            return Err((e.into(), transient));
        }
    };
    let final_url = resp.url().to_string();
    let content_type = resp
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());
    let body = read_body_capped(resp).await?;
    Ok(Hop::Page(FetchedPage {
        body,
        final_url,
        content_type,
    }))
}
/// Runs `attempt` for `url` up to `MAX_ATTEMPTS` times.
///
/// Only failures that `attempt` classifies as transient are retried, with exponential
/// backoff: 200 ms before the second try, doubling before each later one. A redirect counts
/// as success and is returned immediately.
///
/// # Errors
/// Returns the first non-transient error, or the error from the final attempt.
async fn fetch_with_retries(client: &Client, url: &str) -> anyhow::Result<Hop> {
    let mut backoff = Duration::from_millis(200);
    for try_number in 1..=MAX_ATTEMPTS {
        let (failure, transient) = match attempt(client, url).await {
            Ok(outcome) => return Ok(outcome),
            Err(failed) => failed,
        };
        let is_last_try = try_number == MAX_ATTEMPTS;
        if is_last_try || !transient {
            return Err(failure);
        }
        tokio::time::sleep(backoff).await;
        backoff *= 2;
    }
    unreachable!("loop returns on the final attempt")
}
/// Fetches `url` and returns the final page, following redirects manually.
///
/// Each hop (the initial URL and every redirect target) goes through the same steps:
/// 1. `guard::validate_url` is called before any request is made.
/// 2. The addresses it returns are pinned into a fresh client for that hop. These are
///    presumably the resolved addresses the guard approved — confirm against `guard`.
/// 3. Transient failures are retried per hop.
///
/// At most `MAX_REDIRECTS` redirects are followed.
///
/// # Errors
/// Fails if any of the following happens:
/// - `url` does not parse.
/// - Any hop is rejected by the guard.
/// - A request ultimately fails.
/// - More than `MAX_REDIRECTS` redirects occur.
/// - A `Location` value cannot be resolved against the current URL.
pub async fn fetch_page(url: &str, timeout_secs: u64) -> anyhow::Result<FetchedPage> {
    let mut current = reqwest::Url::parse(url)?;
    let mut hops = 0usize;
    loop {
        // Validate every hop, not just the first: redirect targets come from the server's
        // response and must pass the same guard as the caller-supplied URL.
        let pinned = guard::validate_url(&current).await?;
        let client = build_client(&current, timeout_secs, &pinned)?;
        match fetch_with_retries(&client, current.as_str()).await? {
            Hop::Page(page) => return Ok(page),
            Hop::Redirect(location) => {
                hops += 1;
                if hops > MAX_REDIRECTS {
                    anyhow::bail!("too many redirects (>{MAX_REDIRECTS})");
                }
                // `Location` may be relative; resolve it against the URL just requested.
                current = current
                    .join(&location)
                    .map_err(|e| anyhow::anyhow!("invalid redirect target `{location}`: {e}"))?;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_capped_truncates_oversized_chunk() {
        let mut buf = Vec::new();
        let stopped = push_capped(&mut buf, &[b'x'; 10], 4);
        assert!(stopped);
        assert_eq!(buf.len(), 4);
    }

    #[test]
    fn push_capped_accumulates_until_cap() {
        let mut buf = Vec::new();
        assert!(!push_capped(&mut buf, b"abc", 8));
        assert!(!push_capped(&mut buf, b"de", 8));
        assert_eq!(buf, b"abcde");
        let stopped = push_capped(&mut buf, b"fghij", 8);
        assert!(stopped);
        assert_eq!(buf.len(), 8);
        assert_eq!(buf, b"abcdefgh");
    }

    #[test]
    fn push_capped_small_body_unaffected() {
        let mut buf = Vec::new();
        let stopped = push_capped(&mut buf, b"hello", 1024);
        assert!(!stopped);
        assert_eq!(buf, b"hello");
    }

    /// A chunk that lands exactly on the cap must signal "stop". Otherwise the reader would
    /// keep pulling chunks it can no longer store.
    #[test]
    fn push_capped_exact_fit_signals_stop() {
        let mut buf = Vec::new();
        assert!(push_capped(&mut buf, b"abcd", 4));
        assert_eq!(buf, b"abcd");
    }

    #[test]
    fn push_capped_full_buffer_discards_further_input() {
        let mut buf = b"abcd".to_vec();
        assert!(push_capped(&mut buf, b"e", 4));
        assert_eq!(buf, b"abcd");
    }

    #[test]
    fn push_capped_empty_chunk_below_cap_continues() {
        let mut buf = b"ab".to_vec();
        assert!(!push_capped(&mut buf, b"", 4));
        assert_eq!(buf, b"ab");
    }

    #[test]
    fn push_capped_zero_cap_takes_nothing() {
        let mut buf = Vec::new();
        assert!(push_capped(&mut buf, b"abc", 0));
        assert!(buf.is_empty());
    }
}