//! `harness_webfetch/engine.rs`: the web-fetch engine abstraction and its
//! reqwest-backed implementation.

use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use url::Url;
8
9pub struct WebFetchEngineInput {
10    pub url: String,
11    pub method: String,
12    pub body: Option<String>,
13    pub headers: HashMap<String, String>,
14    pub timeout_ms: u64,
15    pub max_redirects: u32,
16    pub max_body_bytes: usize,
17    /// Called BEFORE each hop (including the first) with the target host.
18    /// Returning Err aborts the fetch with INVALID_URL-shaped FetchError.
19    pub check_host: HostCheckFn,
20}
21
/// Shared, thread-safe async callback that vets a host name.
///
/// The callback receives the host as an owned `String` and returns a boxed
/// `Send` future resolving to `Ok(())` when the host is acceptable, or
/// `Err(reason)` with a human-readable explanation when it is not.
pub type HostCheckFn = Arc<
    dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send>>
        + Send
        + Sync,
>;
27
/// Outcome of a fetch that reached a terminal (non-redirect) response.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so results can be logged, compared
/// in tests, and used with `Result::unwrap_err()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebFetchEngineResult {
    /// HTTP status code of the terminal response.
    pub status: u16,
    /// URL that produced the terminal response.
    pub final_url: String,
    /// URLs visited, in order. `ReqwestEngine` appends `final_url` as the
    /// last element, so the chain is never empty on success.
    pub redirect_chain: Vec<String>,
    /// Value of the `Content-Type` response header. `ReqwestEngine` uses an
    /// empty string when the header is missing or not valid visible ASCII.
    pub content_type: String,
    /// Response body, capped at the requested `max_body_bytes`.
    pub body: Vec<u8>,
    /// `true` when the response body was longer than the cap and was cut.
    pub body_truncated: bool,
}
36
37#[async_trait]
38pub trait WebFetchEngine: Send + Sync {
39    async fn fetch(
40        &self,
41        input: WebFetchEngineInput,
42    ) -> Result<WebFetchEngineResult, FetchError>;
43}
44
/// Coarse classification of a failed fetch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FetchErrorCode {
    /// The URL could not be used: `ReqwestEngine` reports this for an
    /// unparsable URL, a rejected host check, and an unsupported method.
    InvalidUrl,
    /// TLS-level failure; `ReqwestEngine` also uses it when refusing an
    /// HTTPS→HTTP downgrade redirect.
    TlsError,
    /// The redirect limit was exceeded.
    RedirectLoop,
    /// Host name resolution failed.
    DnsError,
    /// The request timed out.
    Timeout,
    /// The connection could not be established (and the failure was not
    /// identified as DNS-related).
    ConnectionReset,
    /// Any other transport failure.
    IoError,
}

/// Error returned by a [`WebFetchEngine`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchError {
    /// Machine-readable classification.
    pub code: FetchErrorCode,
    /// Human-readable description.
    pub message: String,
    /// URLs visited before the failure, when known. `ReqwestEngine` fills
    /// this in only for `RedirectLoop`.
    pub chain: Option<Vec<String>>,
}

impl FetchError {
    /// Creates an error with the given code and message and no redirect chain.
    pub fn new(code: FetchErrorCode, message: impl Into<String>) -> Self {
        Self {
            code,
            message: message.into(),
            chain: None,
        }
    }
}

impl fmt::Display for FetchError {
    /// Formats as `<Code>: <message>`, e.g. `Timeout: operation timed out`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}: {}", self.code, self.message)
    }
}

// Lets callers use `?` into `Box<dyn Error>` / `anyhow`-style error types.
impl std::error::Error for FetchError {}
72
73pub struct ReqwestEngine {
74    client: reqwest::Client,
75}
76
77impl ReqwestEngine {
78    pub fn new() -> Self {
79        let client = reqwest::Client::builder()
80            .redirect(reqwest::redirect::Policy::none())
81            .build()
82            .expect("reqwest client build");
83        Self { client }
84    }
85}
86
87impl Default for ReqwestEngine {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93pub fn default_engine() -> Arc<dyn WebFetchEngine> {
94    Arc::new(ReqwestEngine::new())
95}
96
97#[async_trait]
98impl WebFetchEngine for ReqwestEngine {
99    async fn fetch(
100        &self,
101        input: WebFetchEngineInput,
102    ) -> Result<WebFetchEngineResult, FetchError> {
103        let mut current_url = input.url.clone();
104        let mut chain: Vec<String> = Vec::new();
105        let mut hops: u32 = 0;
106
107        loop {
108            let parsed = Url::parse(&current_url).map_err(|_| {
109                FetchError::new(
110                    FetchErrorCode::InvalidUrl,
111                    format!("Invalid URL: {}", current_url),
112                )
113            })?;
114            let host = parsed.host_str().unwrap_or("").to_string();
115
116            // SSRF check before every hop
117            (input.check_host)(host.clone())
118                .await
119                .map_err(|msg| FetchError::new(FetchErrorCode::InvalidUrl, msg))?;
120
121            let method = match input.method.as_str() {
122                "GET" => reqwest::Method::GET,
123                "POST" => reqwest::Method::POST,
124                other => {
125                    return Err(FetchError::new(
126                        FetchErrorCode::InvalidUrl,
127                        format!("unsupported method: {}", other),
128                    ));
129                }
130            };
131
132            let mut req = self
133                .client
134                .request(method, &current_url)
135                .timeout(Duration::from_millis(input.timeout_ms));
136            for (k, v) in &input.headers {
137                req = req.header(k, v);
138            }
139            if let Some(body) = &input.body {
140                // Only send body on the very first hop (don't replay on
141                // redirects). reqwest drops bodies on 303 anyway, but we
142                // simulate the stateless safe default: no body after first hop.
143                if hops == 0 {
144                    req = req.body(body.clone());
145                }
146            }
147
148            let res = req.send().await.map_err(classify_reqwest_error)?;
149            let status = res.status().as_u16();
150
151            // Redirect handling
152            if matches!(status, 301 | 302 | 303 | 307 | 308) {
153                let loc = res
154                    .headers()
155                    .get(reqwest::header::LOCATION)
156                    .and_then(|v| v.to_str().ok())
157                    .map(|s| s.to_string());
158                let next_url = match loc {
159                    Some(loc) => match Url::parse(&loc) {
160                        Ok(abs) => abs.to_string(),
161                        Err(_) => {
162                            // Try resolving relative to the current URL.
163                            match parsed.join(&loc) {
164                                Ok(resolved) => resolved.to_string(),
165                                Err(_) => {
166                                    // No Location — treat as terminal.
167                                    return finalize(
168                                        res,
169                                        &input.url,
170                                        &current_url,
171                                        chain,
172                                        input.max_body_bytes,
173                                    )
174                                    .await;
175                                }
176                            }
177                        }
178                    },
179                    None => {
180                        return finalize(
181                            res,
182                            &input.url,
183                            &current_url,
184                            chain,
185                            input.max_body_bytes,
186                        )
187                        .await;
188                    }
189                };
190
191                // Block https→http downgrade
192                if current_url.starts_with("https://") && next_url.starts_with("http://") {
193                    return Err(FetchError::new(
194                        FetchErrorCode::TlsError,
195                        format!(
196                            "Refusing HTTPS→HTTP downgrade redirect: {} -> {}",
197                            current_url, next_url
198                        ),
199                    ));
200                }
201
202                chain.push(current_url.clone());
203                hops += 1;
204                if hops > input.max_redirects {
205                    let mut full_chain = chain.clone();
206                    full_chain.push(next_url);
207                    return Err(FetchError {
208                        code: FetchErrorCode::RedirectLoop,
209                        message: format!(
210                            "Redirect limit ({}) exceeded",
211                            input.max_redirects
212                        ),
213                        chain: Some(full_chain),
214                    });
215                }
216                current_url = next_url;
217                continue;
218            }
219
220            // Terminal response
221            return finalize(res, &input.url, &current_url, chain, input.max_body_bytes)
222                .await;
223        }
224    }
225}
226
227async fn finalize(
228    res: reqwest::Response,
229    _original: &str,
230    final_url: &str,
231    chain: Vec<String>,
232    max_body_bytes: usize,
233) -> Result<WebFetchEngineResult, FetchError> {
234    let status = res.status().as_u16();
235    let content_type = res
236        .headers()
237        .get(reqwest::header::CONTENT_TYPE)
238        .and_then(|v| v.to_str().ok())
239        .unwrap_or("")
240        .to_string();
241
242    // Collect full body then cap. reqwest doesn't expose a streaming
243    // reader without `futures-util`. For v1 parity we accept the
244    // download-then-truncate trade-off — the 10 MB spill_hard_cap still
245    // bounds the orchestrator's downstream work.
246    let raw = res.bytes().await.map_err(classify_reqwest_error)?;
247    let truncated = raw.len() > max_body_bytes;
248    let body: Vec<u8> = if truncated {
249        raw[..max_body_bytes].to_vec()
250    } else {
251        raw.to_vec()
252    };
253    let mut final_chain = chain;
254    final_chain.push(final_url.to_string());
255    Ok(WebFetchEngineResult {
256        status,
257        final_url: final_url.to_string(),
258        redirect_chain: final_chain,
259        content_type,
260        body,
261        body_truncated: truncated,
262    })
263}
264
265fn classify_reqwest_error(e: reqwest::Error) -> FetchError {
266    let msg = e.to_string();
267    if e.is_timeout() {
268        return FetchError::new(FetchErrorCode::Timeout, msg);
269    }
270    if e.is_connect() {
271        // DNS errors surface as connect errors; try to tell them apart
272        // via the message (reqwest doesn't give us a typed signal).
273        let lower = msg.to_lowercase();
274        if lower.contains("dns")
275            || lower.contains("resolve")
276            || lower.contains("lookup")
277            || lower.contains("not known")
278            || lower.contains("no such host")
279        {
280            return FetchError::new(FetchErrorCode::DnsError, msg);
281        }
282        return FetchError::new(FetchErrorCode::ConnectionReset, msg);
283    }
284    let lower = msg.to_lowercase();
285    if lower.contains("tls") || lower.contains("certificate") || lower.contains("ssl") {
286        return FetchError::new(FetchErrorCode::TlsError, msg);
287    }
288    FetchError::new(FetchErrorCode::IoError, msg)
289}
290