1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Duration;
7use url::Url;
8
9pub struct WebFetchEngineInput {
10 pub url: String,
11 pub method: String,
12 pub body: Option<String>,
13 pub headers: HashMap<String, String>,
14 pub timeout_ms: u64,
15 pub max_redirects: u32,
16 pub max_body_bytes: usize,
17 pub check_host: HostCheckFn,
20}
21
/// Shared, thread-safe asynchronous host check.
///
/// The callback receives a host name as an owned `String` and returns a boxed
/// `Send` future resolving to `Ok(())` or to `Err` with a message.
/// `ReqwestEngine::fetch` awaits it before every request it sends and turns an
/// `Err` into a `FetchErrorCode::InvalidUrl` error carrying that message.
pub type HostCheckFn = Arc<
    dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<(), String>> + Send>>
        + Send
        + Sync,
>;
27
/// Outcome of a successful fetch.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so results can be logged,
/// duplicated and compared directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebFetchEngineResult {
    /// HTTP status code of the response that ended the fetch.
    pub status: u16,
    /// URL of the request that produced the returned response.
    pub final_url: String,
    /// Every URL that was requested, in order; the last entry is `final_url`.
    pub redirect_chain: Vec<String>,
    /// Value of the `Content-Type` response header, or an empty string when
    /// the header is absent or not representable as a string.
    pub content_type: String,
    /// Response body bytes, cut to at most the requested maximum.
    pub body: Vec<u8>,
    /// `true` when the response body was longer than the requested maximum.
    pub body_truncated: bool,
}
36
/// Abstraction over the component that actually performs a web fetch.
///
/// Implementations must be `Send + Sync` so an engine can be shared behind an
/// `Arc<dyn WebFetchEngine>` (see `default_engine`).
#[async_trait]
pub trait WebFetchEngine: Send + Sync {
    /// Performs the fetch described by `input`, consuming it.
    ///
    /// # Errors
    ///
    /// Returns a [`FetchError`] describing why no response could be produced.
    async fn fetch(
        &self,
        input: WebFetchEngineInput,
    ) -> Result<WebFetchEngineResult, FetchError>;
}
44
/// Coarse classification of a failed fetch.
///
/// `PartialEq`/`Eq` are derived so callers can compare codes directly, e.g.
/// `err.code == FetchErrorCode::Timeout`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FetchErrorCode {
    /// The URL could not be parsed. `ReqwestEngine` also uses this code for a
    /// rejected host and for an unsupported method.
    InvalidUrl,
    /// TLS-related failure, including a refused HTTPS to HTTP redirect.
    TlsError,
    /// The redirect limit was exceeded.
    RedirectLoop,
    /// The host name could not be resolved.
    DnsError,
    /// The request timed out.
    Timeout,
    /// A connection could not be established for a reason other than DNS.
    ConnectionReset,
    /// Any other transport failure.
    IoError,
}
55
56#[derive(Debug, Clone)]
57pub struct FetchError {
58 pub code: FetchErrorCode,
59 pub message: String,
60 pub chain: Option<Vec<String>>,
61}
62
63impl FetchError {
64 pub fn new(code: FetchErrorCode, message: impl Into<String>) -> Self {
65 Self {
66 code,
67 message: message.into(),
68 chain: None,
69 }
70 }
71}
72
73pub struct ReqwestEngine {
74 client: reqwest::Client,
75}
76
77impl ReqwestEngine {
78 pub fn new() -> Self {
79 let client = reqwest::Client::builder()
80 .redirect(reqwest::redirect::Policy::none())
81 .build()
82 .expect("reqwest client build");
83 Self { client }
84 }
85}
86
87impl Default for ReqwestEngine {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93pub fn default_engine() -> Arc<dyn WebFetchEngine> {
94 Arc::new(ReqwestEngine::new())
95}
96
97#[async_trait]
98impl WebFetchEngine for ReqwestEngine {
99 async fn fetch(
100 &self,
101 input: WebFetchEngineInput,
102 ) -> Result<WebFetchEngineResult, FetchError> {
103 let mut current_url = input.url.clone();
104 let mut chain: Vec<String> = Vec::new();
105 let mut hops: u32 = 0;
106
107 loop {
108 let parsed = Url::parse(¤t_url).map_err(|_| {
109 FetchError::new(
110 FetchErrorCode::InvalidUrl,
111 format!("Invalid URL: {}", current_url),
112 )
113 })?;
114 let host = parsed.host_str().unwrap_or("").to_string();
115
116 (input.check_host)(host.clone())
118 .await
119 .map_err(|msg| FetchError::new(FetchErrorCode::InvalidUrl, msg))?;
120
121 let method = match input.method.as_str() {
122 "GET" => reqwest::Method::GET,
123 "POST" => reqwest::Method::POST,
124 other => {
125 return Err(FetchError::new(
126 FetchErrorCode::InvalidUrl,
127 format!("unsupported method: {}", other),
128 ));
129 }
130 };
131
132 let mut req = self
133 .client
134 .request(method, ¤t_url)
135 .timeout(Duration::from_millis(input.timeout_ms));
136 for (k, v) in &input.headers {
137 req = req.header(k, v);
138 }
139 if let Some(body) = &input.body {
140 if hops == 0 {
144 req = req.body(body.clone());
145 }
146 }
147
148 let res = req.send().await.map_err(classify_reqwest_error)?;
149 let status = res.status().as_u16();
150
151 if matches!(status, 301 | 302 | 303 | 307 | 308) {
153 let loc = res
154 .headers()
155 .get(reqwest::header::LOCATION)
156 .and_then(|v| v.to_str().ok())
157 .map(|s| s.to_string());
158 let next_url = match loc {
159 Some(loc) => match Url::parse(&loc) {
160 Ok(abs) => abs.to_string(),
161 Err(_) => {
162 match parsed.join(&loc) {
164 Ok(resolved) => resolved.to_string(),
165 Err(_) => {
166 return finalize(
168 res,
169 &input.url,
170 ¤t_url,
171 chain,
172 input.max_body_bytes,
173 )
174 .await;
175 }
176 }
177 }
178 },
179 None => {
180 return finalize(
181 res,
182 &input.url,
183 ¤t_url,
184 chain,
185 input.max_body_bytes,
186 )
187 .await;
188 }
189 };
190
191 if current_url.starts_with("https://") && next_url.starts_with("http://") {
193 return Err(FetchError::new(
194 FetchErrorCode::TlsError,
195 format!(
196 "Refusing HTTPS→HTTP downgrade redirect: {} -> {}",
197 current_url, next_url
198 ),
199 ));
200 }
201
202 chain.push(current_url.clone());
203 hops += 1;
204 if hops > input.max_redirects {
205 let mut full_chain = chain.clone();
206 full_chain.push(next_url);
207 return Err(FetchError {
208 code: FetchErrorCode::RedirectLoop,
209 message: format!(
210 "Redirect limit ({}) exceeded",
211 input.max_redirects
212 ),
213 chain: Some(full_chain),
214 });
215 }
216 current_url = next_url;
217 continue;
218 }
219
220 return finalize(res, &input.url, ¤t_url, chain, input.max_body_bytes)
222 .await;
223 }
224 }
225}
226
227async fn finalize(
228 res: reqwest::Response,
229 _original: &str,
230 final_url: &str,
231 chain: Vec<String>,
232 max_body_bytes: usize,
233) -> Result<WebFetchEngineResult, FetchError> {
234 let status = res.status().as_u16();
235 let content_type = res
236 .headers()
237 .get(reqwest::header::CONTENT_TYPE)
238 .and_then(|v| v.to_str().ok())
239 .unwrap_or("")
240 .to_string();
241
242 let raw = res.bytes().await.map_err(classify_reqwest_error)?;
247 let truncated = raw.len() > max_body_bytes;
248 let body: Vec<u8> = if truncated {
249 raw[..max_body_bytes].to_vec()
250 } else {
251 raw.to_vec()
252 };
253 let mut final_chain = chain;
254 final_chain.push(final_url.to_string());
255 Ok(WebFetchEngineResult {
256 status,
257 final_url: final_url.to_string(),
258 redirect_chain: final_chain,
259 content_type,
260 body,
261 body_truncated: truncated,
262 })
263}
264
265fn classify_reqwest_error(e: reqwest::Error) -> FetchError {
266 let msg = e.to_string();
267 if e.is_timeout() {
268 return FetchError::new(FetchErrorCode::Timeout, msg);
269 }
270 if e.is_connect() {
271 let lower = msg.to_lowercase();
274 if lower.contains("dns")
275 || lower.contains("resolve")
276 || lower.contains("lookup")
277 || lower.contains("not known")
278 || lower.contains("no such host")
279 {
280 return FetchError::new(FetchErrorCode::DnsError, msg);
281 }
282 return FetchError::new(FetchErrorCode::ConnectionReset, msg);
283 }
284 let lower = msg.to_lowercase();
285 if lower.contains("tls") || lower.contains("certificate") || lower.contains("ssl") {
286 return FetchError::new(FetchErrorCode::TlsError, msg);
287 }
288 FetchError::new(FetchErrorCode::IoError, msg)
289}
290