1use std::sync::Arc;
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use futures::TryStreamExt;
6use reqwest::Client;
7use tokio_util::sync::CancellationToken;
8use url::Url;
9
10use crate::{
11 error::{NetError, NetResult},
12 retry::{DefaultRetryPolicy, RetryNet},
13 traits::{Net, NetExt},
14 types::{Compression, Headers, NetOptions, RangeSpec},
15};
16
/// Status code a server returns when it honours a `Range` request
/// ("206 Partial Content").
const HTTP_PARTIAL_CONTENT: u16 = 206;

/// Caps an HTTP error body so it can be embedded in an error value.
///
/// Bodies of at most 200 characters (Unicode scalar values, not bytes) are
/// returned unchanged. Longer bodies are cut after the 200th character —
/// always on a `char` boundary — and a `…(truncated, N chars total)` marker
/// is appended, where `N` is the character count of the original body.
fn truncate_error_body(mut body: String) -> String {
    const MAX_CHARS: usize = 200;

    // Byte offset of the first character past the limit. `None` means the
    // body already fits, so it is returned without a marker.
    let Some((cut_at, _)) = body.char_indices().nth(MAX_CHARS) else {
        return body;
    };
    let total = body.chars().count();
    body.truncate(cut_at);
    body.push_str(&format!("…(truncated, {total} chars total)"));
    body
}
41
42#[cfg(not(target_arch = "wasm32"))]
46type ClientBuilderMod = fn(reqwest::ClientBuilder) -> reqwest::ClientBuilder;
47
48#[cfg(not(target_arch = "wasm32"))]
49impl From<Compression> for Vec<ClientBuilderMod> {
50 fn from(c: Compression) -> Self {
51 [
52 (
53 Compression::GZIP,
54 reqwest::ClientBuilder::no_gzip as ClientBuilderMod,
55 ),
56 (Compression::DEFLATE, reqwest::ClientBuilder::no_deflate),
57 (Compression::BROTLI, reqwest::ClientBuilder::no_brotli),
58 (Compression::ZSTD, reqwest::ClientBuilder::no_zstd),
59 ]
60 .into_iter()
61 .filter(|(flag, _)| !c.contains(*flag))
62 .map(|(_, disable)| disable)
63 .collect()
64 }
65}
66
67#[cfg(not(target_arch = "wasm32"))]
68fn build_client(options: &NetOptions) -> reqwest::Result<Client> {
69 let base = Client::builder()
70 .cookie_store(true)
71 .pool_max_idle_per_host(options.pool_max_idle_per_host)
72 .pool_idle_timeout(Some(std::time::Duration::from_secs(5)))
73 .danger_accept_invalid_certs(options.is_insecure)
74 .read_timeout(options.inactivity_timeout);
75 Vec::<ClientBuilderMod>::from(options.compression)
76 .into_iter()
77 .fold(base, |b, disable| disable(b))
78 .build()
79}
80
/// Builds the `wasm32` client with `reqwest` defaults.
///
/// `options` is ignored on this target.
#[cfg(target_arch = "wasm32")]
fn build_client(_options: &NetOptions) -> reqwest::Result<Client> {
    let builder = Client::builder();
    builder.build()
}
85
86fn extract_headers(resp: &reqwest::Response) -> Headers {
88 let mut headers = Headers::new();
89 let str_pairs = resp
90 .headers()
91 .iter()
92 .filter_map(|(name, value)| value.to_str().ok().map(|v| (name.as_str(), v)));
93 for (name, value) in str_pairs {
94 headers.insert(name, value);
95 }
96 headers
97}
98
99#[derive(Clone)]
103struct RawHttp {
104 inner: Client,
105 options: NetOptions,
106}
107
108impl RawHttp {
109 fn apply_headers(
110 mut req: reqwest::RequestBuilder,
111 headers: Option<Headers>,
112 ) -> reqwest::RequestBuilder {
113 if let Some(headers) = headers {
114 for (k, v) in headers.iter() {
115 req = req.header(k, v);
116 }
117 }
118 req
119 }
120
121 #[cfg(not(target_arch = "wasm32"))]
122 fn head_request(&self, url: Url) -> reqwest::RequestBuilder {
123 self.inner.head(url)
124 }
125
126 #[cfg(target_arch = "wasm32")]
127 fn head_request(&self, url: Url) -> reqwest::RequestBuilder {
128 self.inner.get(url).header("Range", "bytes=0-0")
129 }
130
131 fn response_to_stream(resp: reqwest::Response) -> crate::ByteStream {
132 let headers = extract_headers(&resp);
133 let stream = resp.bytes_stream().map_err(NetError::from);
134 crate::ByteStream::new(headers, Box::pin(stream))
135 }
136
137 async fn send_checked(
138 &self,
139 req: reqwest::RequestBuilder,
140 headers: Option<Headers>,
141 url: Url,
142 accept_partial: bool,
143 ) -> Result<reqwest::Response, NetError> {
144 let req = Self::apply_headers(req, headers);
145 let req = if let Some(total) = self.options.total_timeout {
146 req.timeout(total)
147 } else {
148 req
149 };
150 let resp = req.send().await.map_err(NetError::from)?;
151 let status = resp.status();
152
153 let ok = status.is_success() || (accept_partial && status.as_u16() == HTTP_PARTIAL_CONTENT);
154 if !ok {
155 let body = truncate_error_body(resp.text().await.unwrap_or_default());
156 return Err(NetError::HttpError {
157 url,
158 status: status.as_u16(),
159 body: Some(body),
160 });
161 }
162
163 Ok(resp)
164 }
165}
166
167#[derive(Clone)]
174pub struct HttpClient {
175 net: Arc<RetryNet<RawHttp, DefaultRetryPolicy>>,
176 options: NetOptions,
177}
178
179impl HttpClient {
180 #[must_use]
192 pub fn new(options: NetOptions, cancel: CancellationToken) -> Self {
193 let inner = build_client(&options)
194 .expect("BUG: reqwest::Client::builder().build() with our defaults cannot fail");
195 let raw = RawHttp {
196 inner,
197 options: options.clone(),
198 };
199 let net = Arc::new(raw.with_retry(options.retry_policy.clone(), cancel));
200 Self { net, options }
201 }
202
203 pub async fn get_bytes(&self, url: Url, headers: Option<Headers>) -> NetResult<Bytes> {
207 self.net.get_bytes(url, headers).await
208 }
209
210 pub async fn get_range(
214 &self,
215 url: Url,
216 range: RangeSpec,
217 headers: Option<Headers>,
218 ) -> NetResult<crate::ByteStream> {
219 self.net.get_range(url, range, headers).await
220 }
221
222 pub async fn head(&self, url: Url, headers: Option<Headers>) -> NetResult<Headers> {
226 self.net.head(url, headers).await
227 }
228
229 #[must_use]
230 pub fn options(&self) -> &NetOptions {
231 &self.options
232 }
233
234 pub async fn stream(&self, url: Url, headers: Option<Headers>) -> NetResult<crate::ByteStream> {
238 self.net.stream(url, headers).await
239 }
240}
241
242impl std::fmt::Debug for HttpClient {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 f.debug_struct("HttpClient")
245 .field("options", &self.options)
246 .finish_non_exhaustive()
247 }
248}
249
250#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
251#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
252impl Net for HttpClient {
253 async fn get_bytes(&self, url: Url, headers: Option<Headers>) -> Result<Bytes, NetError> {
254 self.net.get_bytes(url, headers).await
255 }
256
257 async fn get_range(
258 &self,
259 url: Url,
260 range: RangeSpec,
261 headers: Option<Headers>,
262 ) -> Result<crate::ByteStream, NetError> {
263 self.net.get_range(url, range, headers).await
264 }
265
266 async fn head(&self, url: Url, headers: Option<Headers>) -> Result<Headers, NetError> {
267 self.net.head(url, headers).await
268 }
269
270 async fn stream(
271 &self,
272 url: Url,
273 headers: Option<Headers>,
274 ) -> Result<crate::ByteStream, NetError> {
275 self.net.stream(url, headers).await
276 }
277}
278
279#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
280#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
281impl Net for RawHttp {
282 #[cfg_attr(feature = "perf", hotpath::measure)]
283 async fn get_bytes(&self, url: Url, headers: Option<Headers>) -> Result<Bytes, NetError> {
284 let req = self.inner.get(url.clone());
285 let resp = self.send_checked(req, headers, url, false).await?;
286 resp.bytes().await.map_err(NetError::from)
287 }
288
289 #[cfg_attr(feature = "perf", hotpath::measure)]
290 async fn get_range(
291 &self,
292 url: Url,
293 range: RangeSpec,
294 headers: Option<Headers>,
295 ) -> Result<crate::ByteStream, NetError> {
296 let req = self
297 .inner
298 .get(url.clone())
299 .header("Range", range.to_string());
300 let resp = self.send_checked(req, headers, url, true).await?;
301 Ok(Self::response_to_stream(resp))
302 }
303
304 #[cfg_attr(feature = "perf", hotpath::measure)]
305 async fn head(&self, url: Url, headers: Option<Headers>) -> Result<Headers, NetError> {
306 let req = self.head_request(url.clone());
307 let req = Self::apply_headers(req, headers);
308 let req = if let Some(total) = self.options.total_timeout {
309 req.timeout(total)
310 } else {
311 req
312 };
313 let resp = req.send().await.map_err(NetError::from)?;
314
315 let status = resp.status();
316
317 if !status.is_success() && status.as_u16() != HTTP_PARTIAL_CONTENT {
318 let body = truncate_error_body(resp.text().await.unwrap_or_default());
319 return Err(NetError::HttpError {
320 url,
321 status: status.as_u16(),
322 body: Some(body),
323 });
324 }
325
326 let mut out = Headers::new();
327 let str_pairs = resp
328 .headers()
329 .iter()
330 .filter_map(|(name, value)| value.to_str().ok().map(|v| (name.as_str(), v)));
331 for (name, v) in str_pairs {
332 out.insert(name, v);
333 }
334
335 if out.get("content-length").is_none() {
336 let total_from_range = out
337 .get("content-range")
338 .and_then(|h| h.split('/').nth(1))
339 .filter(|s| *s != "*")
340 .map(str::to_owned);
341 if let Some(total) = total_from_range {
342 out.insert("content-length", total);
343 }
344 }
345
346 Ok(out)
347 }
348
349 #[cfg_attr(feature = "perf", hotpath::measure)]
350 async fn stream(
351 &self,
352 url: Url,
353 headers: Option<Headers>,
354 ) -> Result<crate::ByteStream, NetError> {
355 let req = self.inner.get(url.clone());
356 let resp = self.send_checked(req, headers, url, false).await?;
357 Ok(Self::response_to_stream(resp))
358 }
359}
360
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    // Local alias so tests can be annotated as `#[kithara::test(...)]`.
    mod kithara {
        pub(crate) use kithara_test_macros::test;
    }

    use std::{
        net::SocketAddr,
        sync::{
            Arc,
            atomic::{AtomicU32, Ordering},
        },
        time::Duration,
    };

    use axum::{Router, http::StatusCode, routing::get};
    use tokio::net::TcpListener;

    use super::*;
    use crate::types::RetryPolicy;

    /// Spawns a local axum server on an ephemeral port.
    ///
    /// Its `/probe` route answers `503 Service Unavailable` ("busy") to the
    /// first `fail_count` requests and `200 OK` ("ok") afterwards.
    ///
    /// Returns the probe URL and the shared request counter.
    async fn server_failing_first_n(fail_count: u32) -> (Url, Arc<AtomicU32>) {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_c = Arc::clone(&counter);
        let app = Router::new().route(
            "/probe",
            get(move || {
                let counter = Arc::clone(&counter_c);
                async move {
                    let seen = counter.fetch_add(1, Ordering::SeqCst);
                    if seen < fail_count {
                        (StatusCode::SERVICE_UNAVAILABLE, "busy")
                    } else {
                        (StatusCode::OK, "ok")
                    }
                }
            }),
        );
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr: SocketAddr = listener.local_addr().expect("local_addr");
        tokio::spawn(async move {
            axum::serve(listener, app.into_make_service())
                .await
                .expect("serve");
        });
        let url = Url::parse(&format!("http://{addr}/probe")).expect("url");
        (url, counter)
    }

    /// `NetOptions` with `max_retries` retries and a 1–10 ms back-off, so
    /// retry tests stay fast.
    fn fast_options(max_retries: u32) -> NetOptions {
        NetOptions::builder()
            .retry_policy(RetryPolicy {
                max_retries,
                base_delay: Duration::from_millis(1),
                max_delay: Duration::from_millis(10),
            })
            .build()
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn http_client_retries_503_until_ok() {
        let (url, counter) = server_failing_first_n(2).await;
        let client = HttpClient::new(fast_options(3), CancellationToken::new());
        let bytes = client
            .get_bytes(url, None)
            .await
            .expect("get_bytes must succeed after retries");
        assert_eq!(&bytes[..], b"ok");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            3,
            "exactly 3 attempts: 2 failed (503) + 1 ok"
        );
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn http_client_no_retry_propagates_5xx() {
        let (url, counter) = server_failing_first_n(2).await;
        let client = HttpClient::new(fast_options(0), CancellationToken::new());
        let err = client
            .get_bytes(url, None)
            .await
            .expect_err("max_retries=0 must propagate the 503");
        assert!(
            matches!(err, NetError::HttpError { status: 503, .. }),
            "expected HttpError(503), got {err:?}"
        );
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "max_retries=0 issues exactly one attempt"
        );
    }

    // Relies on axum serving HEAD requests through the `get` route handler.
    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn http_client_head_retries_503_until_ok() {
        let (url, counter) = server_failing_first_n(1).await;
        let client = HttpClient::new(fast_options(2), CancellationToken::new());
        client.head(url, None).await.expect("HEAD must retry");
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn truncate_error_body_keeps_short_bodies_verbatim() {
        assert_eq!(truncate_error_body(String::new()), "");
        let exactly_max = "x".repeat(200);
        assert_eq!(truncate_error_body(exactly_max.clone()), exactly_max);
    }

    #[test]
    fn truncate_error_body_cuts_on_char_boundary() {
        // Two-byte characters: cutting at byte 200 would be wrong (it is only
        // 100 chars in), and an odd byte offset would split a character.
        let body = "é".repeat(201);
        assert_eq!(
            truncate_error_body(body),
            format!("{}…(truncated, 201 chars total)", "é".repeat(200))
        );
    }
}