1use std::sync::Arc;
17
18use bytes::Bytes;
19use futures::future::BoxFuture;
20use http::header::{CONTENT_TYPE, LOCATION, USER_AGENT};
21use http::{HeaderValue, Method, Request, Uri};
22use http_body_util::{BodyExt, Empty};
23use hyper::body::Incoming;
24use hyper_util::client::legacy::Client as HyperClient;
25use hyper_util::rt::TokioExecutor;
26
27use defect_agent::error::BoxError;
28use defect_agent::http::{HttpClient, HttpClientError, HttpRequest, HttpResponse};
29
30use super::proxy::{ProxyAwareConnector, build_proxy_connector};
31use super::user_agent::default_user_agent;
32use super::{HttpStackConfig, HttpStackError, ProxyConfig};
33
34type FetchHyperClient = HyperClient<ProxyAwareConnector, Empty<Bytes>>;
37
38pub struct FetchHttpClient {
43 inner: FetchHyperClient,
44 user_agent: HeaderValue,
45}
46
47impl std::fmt::Debug for FetchHttpClient {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("FetchHttpClient")
50 .field("user_agent", &self.user_agent)
51 .finish()
52 }
53}
54
55pub fn build_fetch_client(config: &HttpStackConfig) -> Result<FetchHttpClient, HttpStackError> {
68 let connector = build_proxy_connector(&config.proxy)?;
69 let inner = HyperClient::builder(TokioExecutor::default()).build::<_, Empty<Bytes>>(connector);
70
71 let user_agent = match &config.user_agent {
72 Some(s) => HeaderValue::from_str(s).map_err(|e| HttpStackError::Config {
73 hint: format!("invalid user_agent: {e}"),
74 })?,
75 None => default_user_agent(),
76 };
77
78 Ok(FetchHttpClient { inner, user_agent })
79}
80
81pub fn build_fetch_client_arc(
88 config: &HttpStackConfig,
89) -> Result<Arc<dyn HttpClient>, HttpStackError> {
90 Ok(Arc::new(build_fetch_client(config)?))
91}
92
93impl HttpClient for FetchHttpClient {
94 fn fetch(&self, req: HttpRequest) -> BoxFuture<'_, Result<HttpResponse, HttpClientError>> {
95 Box::pin(async move {
96 let timeout = req.timeout;
97 let fut = self.execute(req);
98 match timeout {
99 Some(d) => match tokio::time::timeout(d, fut).await {
100 Ok(res) => res,
101 Err(_) => Err(HttpClientError::Timeout),
102 },
103 None => fut.await,
104 }
105 })
106 }
107}
108
109impl FetchHttpClient {
110 async fn execute(&self, req: HttpRequest) -> Result<HttpResponse, HttpClientError> {
111 let mut current = parse_http_uri(&req.url)?;
112 let mut redirects: u32 = 0;
113
114 loop {
115 let response = self.send_one(¤t).await?;
116
117 let status = response.status().as_u16();
118 let is_redirect = (300..400).contains(&status) && status != 304;
119
120 if is_redirect && req.follow_redirects {
121 if redirects >= req.max_redirects {
122 return Err(HttpClientError::TooManyRedirects(redirects));
123 }
124 let Some(location) = response.headers().get(LOCATION) else {
125 return collect_response(response, ¤t, redirects, req.max_response_bytes)
128 .await;
129 };
130 let raw = location
131 .to_str()
132 .map_err(|e| HttpClientError::Transport(BoxError::new(e)))?
133 .to_string();
134 let next = resolve_redirect(¤t, &raw)?;
135 drop(response);
136 current = next;
137 redirects += 1;
138 continue;
139 }
140
141 return collect_response(response, ¤t, redirects, req.max_response_bytes).await;
142 }
143 }
144
145 async fn send_one(&self, uri: &Uri) -> Result<http::Response<Incoming>, HttpClientError> {
146 let request = Request::builder()
147 .method(Method::GET)
148 .uri(uri.clone())
149 .header(USER_AGENT, self.user_agent.clone())
150 .body(Empty::<Bytes>::new())
151 .map_err(|e| HttpClientError::Transport(BoxError::new(e)))?;
152
153 self.inner
154 .request(request)
155 .await
156 .map_err(|e| HttpClientError::Transport(BoxError::new(e)))
157 }
158}
159
160fn parse_http_uri(raw: &str) -> Result<Uri, HttpClientError> {
161 let uri: Uri = raw
162 .parse()
163 .map_err(|e: http::uri::InvalidUri| HttpClientError::InvalidUrl(e.to_string()))?;
164 let scheme = uri
165 .scheme_str()
166 .ok_or_else(|| HttpClientError::InvalidUrl(format!("missing scheme in `{raw}`")))?;
167 if !matches!(scheme, "http" | "https") {
168 return Err(HttpClientError::InvalidUrl(format!(
169 "unsupported scheme `{scheme}`: only http/https allowed"
170 )));
171 }
172 if uri.host().is_none() {
173 return Err(HttpClientError::InvalidUrl(format!(
174 "missing host in `{raw}`"
175 )));
176 }
177 Ok(uri)
178}
179
180fn resolve_redirect(base: &Uri, location: &str) -> Result<Uri, HttpClientError> {
188 let trimmed = location.trim();
189 if trimmed.is_empty() {
190 return Err(HttpClientError::Transport(BoxError::new(
191 std::io::Error::other("empty Location header"),
192 )));
193 }
194
195 if trimmed.contains("://") {
196 return parse_http_uri(trimmed);
198 }
199
200 let base_scheme = base.scheme_str().ok_or_else(|| {
201 HttpClientError::Transport(BoxError::new(std::io::Error::other(
202 "base URI missing scheme",
203 )))
204 })?;
205 let base_authority = base.authority().ok_or_else(|| {
206 HttpClientError::Transport(BoxError::new(std::io::Error::other(
207 "base URI missing authority",
208 )))
209 })?;
210
211 let composed = if let Some(rest) = trimmed.strip_prefix("//") {
212 format!("{base_scheme}://{rest}")
214 } else if trimmed.starts_with('/') {
215 format!("{base_scheme}://{base_authority}{trimmed}")
217 } else {
218 return Err(HttpClientError::Transport(BoxError::new(
220 std::io::Error::other(format!("relative redirect not supported: `{trimmed}`")),
221 )));
222 };
223
224 parse_http_uri(&composed)
225}
226
227async fn collect_response(
228 response: http::Response<Incoming>,
229 final_uri: &Uri,
230 redirects: u32,
231 max_response_bytes: u64,
232) -> Result<HttpResponse, HttpClientError> {
233 let status = response.status().as_u16();
234 let content_type = response
235 .headers()
236 .get(CONTENT_TYPE)
237 .and_then(|v| v.to_str().ok())
238 .map(str::to_string);
239 let final_url = final_uri.to_string();
240
241 let mut body_buf: Vec<u8> = Vec::new();
242 let mut bytes_received: u64 = 0;
243 let mut truncated = false;
244
245 let mut frames = response.into_body();
246 while let Some(frame) = frames
247 .frame()
248 .await
249 .transpose()
250 .map_err(|e| HttpClientError::Transport(BoxError::new(e)))?
251 {
252 if let Ok(data) = frame.into_data() {
253 let len = data.len() as u64;
254 bytes_received = bytes_received.saturating_add(len);
255 if !truncated {
256 let remaining = max_response_bytes.saturating_sub(body_buf.len() as u64);
257 if (data.len() as u64) <= remaining {
258 body_buf.extend_from_slice(&data);
259 } else {
260 let take = remaining as usize;
261 body_buf.extend_from_slice(&data[..take]);
262 truncated = true;
263 }
264 }
265 }
266 }
267
268 Ok(HttpResponse {
269 status,
270 content_type,
271 body: body_buf,
272 bytes_received,
273 truncated,
274 redirects,
275 final_url,
276 })
277}
278
279pub fn build_default_fetch_client_arc() -> Result<Arc<dyn HttpClient>, HttpStackError> {
288 build_fetch_client_arc(&HttpStackConfig {
289 proxy: ProxyConfig::FromEnv,
290 ..HttpStackConfig::default()
291 })
292}
293
294#[cfg(test)]
295mod tests;