Skip to main content

kithara_net/
client.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use futures::TryStreamExt;
6use reqwest::Client;
7use tokio_util::sync::CancellationToken;
8use url::Url;
9
10use crate::{
11    error::{NetError, NetResult},
12    retry::{DefaultRetryPolicy, RetryNet},
13    traits::{Net, NetExt},
14    types::{Compression, Headers, NetOptions, RangeSpec},
15};
16
/// HTTP 206 Partial Content status code. Compared against the numeric
/// response status when deciding whether a ranged reply is acceptable.
const HTTP_PARTIAL_CONTENT: u16 = 206;
19
/// Shorten an HTTP error body before it is stored in an error value, so
/// log output stays readable when a server answers with a large HTML
/// page (rate-limit stubs, anti-bot challenges).
///
/// Bodies of at most 200 characters are returned untouched. Longer
/// bodies keep their first 200 characters — the cut always falls on a
/// `char` boundary, never inside a UTF-8 sequence — followed by a
/// `…(truncated, N chars total)` marker, where `N` is the character
/// count of the original body.
fn truncate_error_body(mut body: String) -> String {
    /// Maximum characters of an HTTP error body kept in
    /// [`NetError::HttpError`].
    const MAX_CHARS: usize = 200;

    // Byte offset of the first character past the limit. `None` means
    // the body has at most `MAX_CHARS` characters and needs no cut.
    let overflow_offset = body
        .char_indices()
        .nth(MAX_CHARS)
        .map(|(offset, _)| offset);
    let cut_at = match overflow_offset {
        Some(offset) => offset,
        None => return body,
    };

    // Count before truncating: the suffix reports the original length.
    let total = body.chars().count();
    body.truncate(cut_at);
    let suffix = format!("…(truncated, {total} chars total)");
    body.push_str(&suffix);
    body
}
41
42/// Build a `reqwest::Client` with our default configuration. Native
43/// build applies pool / TLS / read-timeout knobs; wasm32 takes the
44/// builder defaults because most options aren't supported there.
45#[cfg(not(target_arch = "wasm32"))]
46type ClientBuilderMod = fn(reqwest::ClientBuilder) -> reqwest::ClientBuilder;
47
48#[cfg(not(target_arch = "wasm32"))]
49impl From<Compression> for Vec<ClientBuilderMod> {
50    fn from(c: Compression) -> Self {
51        [
52            (
53                Compression::GZIP,
54                reqwest::ClientBuilder::no_gzip as ClientBuilderMod,
55            ),
56            (Compression::DEFLATE, reqwest::ClientBuilder::no_deflate),
57            (Compression::BROTLI, reqwest::ClientBuilder::no_brotli),
58            (Compression::ZSTD, reqwest::ClientBuilder::no_zstd),
59        ]
60        .into_iter()
61        .filter(|(flag, _)| !c.contains(*flag))
62        .map(|(_, disable)| disable)
63        .collect()
64    }
65}
66
67#[cfg(not(target_arch = "wasm32"))]
68fn build_client(options: &NetOptions) -> reqwest::Result<Client> {
69    let base = Client::builder()
70        .cookie_store(true)
71        .pool_max_idle_per_host(options.pool_max_idle_per_host)
72        .pool_idle_timeout(Some(std::time::Duration::from_secs(5)))
73        .danger_accept_invalid_certs(options.is_insecure)
74        .read_timeout(options.inactivity_timeout);
75    Vec::<ClientBuilderMod>::from(options.compression)
76        .into_iter()
77        .fold(base, |b, disable| disable(b))
78        .build()
79}
80
81#[cfg(target_arch = "wasm32")]
82fn build_client(_options: &NetOptions) -> reqwest::Result<Client> {
83    Client::builder().build()
84}
85
86/// Extract response headers into our [`Headers`] type.
87fn extract_headers(resp: &reqwest::Response) -> Headers {
88    let mut headers = Headers::new();
89    let str_pairs = resp
90        .headers()
91        .iter()
92        .filter_map(|(name, value)| value.to_str().ok().map(|v| (name.as_str(), v)));
93    for (name, value) in str_pairs {
94        headers.insert(name, value);
95    }
96    headers
97}
98
99/// Raw HTTP client (one `reqwest::Client`, no retry layer). Lives
100/// behind [`HttpClient`]'s [`RetryNet`] decorator — exposed only via
101/// the [`Net`] trait, never constructed by callers directly.
102#[derive(Clone)]
103struct RawHttp {
104    inner: Client,
105    options: NetOptions,
106}
107
108impl RawHttp {
109    fn apply_headers(
110        mut req: reqwest::RequestBuilder,
111        headers: Option<Headers>,
112    ) -> reqwest::RequestBuilder {
113        if let Some(headers) = headers {
114            for (k, v) in headers.iter() {
115                req = req.header(k, v);
116            }
117        }
118        req
119    }
120
121    #[cfg(not(target_arch = "wasm32"))]
122    fn head_request(&self, url: Url) -> reqwest::RequestBuilder {
123        self.inner.head(url)
124    }
125
126    #[cfg(target_arch = "wasm32")]
127    fn head_request(&self, url: Url) -> reqwest::RequestBuilder {
128        self.inner.get(url).header("Range", "bytes=0-0")
129    }
130
131    fn response_to_stream(resp: reqwest::Response) -> crate::ByteStream {
132        let headers = extract_headers(&resp);
133        let stream = resp.bytes_stream().map_err(NetError::from);
134        crate::ByteStream::new(headers, Box::pin(stream))
135    }
136
137    async fn send_checked(
138        &self,
139        req: reqwest::RequestBuilder,
140        headers: Option<Headers>,
141        url: Url,
142        accept_partial: bool,
143    ) -> Result<reqwest::Response, NetError> {
144        let req = Self::apply_headers(req, headers);
145        let req = if let Some(total) = self.options.total_timeout {
146            req.timeout(total)
147        } else {
148            req
149        };
150        let resp = req.send().await.map_err(NetError::from)?;
151        let status = resp.status();
152
153        let ok = status.is_success() || (accept_partial && status.as_u16() == HTTP_PARTIAL_CONTENT);
154        if !ok {
155            let body = truncate_error_body(resp.text().await.unwrap_or_default());
156            return Err(NetError::HttpError {
157                url,
158                status: status.as_u16(),
159                body: Some(body),
160            });
161        }
162
163        Ok(resp)
164    }
165}
166
167/// Production HTTP client used across the workspace. Wraps a raw
168/// `reqwest::Client` with the workspace's [`RetryNet`] decorator so
169/// every [`Net`] method (`head`/`get_bytes`/`get_range`/`stream`) honours
170/// `options.retry_policy` — retryable errors (TLS-close, timeout,
171/// 5xx, IO) are re-issued with exponential backoff; non-retryable
172/// errors (HTTP 4xx, cancellation) propagate immediately.
173#[derive(Clone)]
174pub struct HttpClient {
175    net: Arc<RetryNet<RawHttp, DefaultRetryPolicy>>,
176    options: NetOptions,
177}
178
179impl HttpClient {
180    /// Build a retry-decorated HTTP client rooted on `cancel`. The
181    /// `RetryNet` layer aborts pending retries when that token is
182    /// cancelled. Callers MUST pass a token that lives in the
183    /// consumer-crate's cancel tree — typically
184    /// `master_cancel.child_token()` derived at the consumer-crate top
185    /// (`App`, `Queue`, FFI player). The workspace cancel hierarchy
186    /// forbids orphan tokens in production code.
187    ///
188    /// # Panics
189    ///
190    /// Panics if the `reqwest::Client` builder fails to build.
191    #[must_use]
192    pub fn new(options: NetOptions, cancel: CancellationToken) -> Self {
193        let inner = build_client(&options)
194            .expect("BUG: reqwest::Client::builder().build() with our defaults cannot fail");
195        let raw = RawHttp {
196            inner,
197            options: options.clone(),
198        };
199        let net = Arc::new(raw.with_retry(options.retry_policy.clone(), cancel));
200        Self { net, options }
201    }
202
203    /// # Errors
204    ///
205    /// Returns [`NetError`] on HTTP failure, timeout, or network error.
206    pub async fn get_bytes(&self, url: Url, headers: Option<Headers>) -> NetResult<Bytes> {
207        self.net.get_bytes(url, headers).await
208    }
209
210    /// # Errors
211    ///
212    /// Returns [`NetError`] on HTTP failure or network error.
213    pub async fn get_range(
214        &self,
215        url: Url,
216        range: RangeSpec,
217        headers: Option<Headers>,
218    ) -> NetResult<crate::ByteStream> {
219        self.net.get_range(url, range, headers).await
220    }
221
222    /// # Errors
223    ///
224    /// Returns [`NetError`] on HTTP failure or network error.
225    pub async fn head(&self, url: Url, headers: Option<Headers>) -> NetResult<Headers> {
226        self.net.head(url, headers).await
227    }
228
229    #[must_use]
230    pub fn options(&self) -> &NetOptions {
231        &self.options
232    }
233
234    /// # Errors
235    ///
236    /// Returns [`NetError`] on HTTP failure or network error.
237    pub async fn stream(&self, url: Url, headers: Option<Headers>) -> NetResult<crate::ByteStream> {
238        self.net.stream(url, headers).await
239    }
240}
241
242impl std::fmt::Debug for HttpClient {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        f.debug_struct("HttpClient")
245            .field("options", &self.options)
246            .finish_non_exhaustive()
247    }
248}
249
250#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
251#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
252impl Net for HttpClient {
253    async fn get_bytes(&self, url: Url, headers: Option<Headers>) -> Result<Bytes, NetError> {
254        self.net.get_bytes(url, headers).await
255    }
256
257    async fn get_range(
258        &self,
259        url: Url,
260        range: RangeSpec,
261        headers: Option<Headers>,
262    ) -> Result<crate::ByteStream, NetError> {
263        self.net.get_range(url, range, headers).await
264    }
265
266    async fn head(&self, url: Url, headers: Option<Headers>) -> Result<Headers, NetError> {
267        self.net.head(url, headers).await
268    }
269
270    async fn stream(
271        &self,
272        url: Url,
273        headers: Option<Headers>,
274    ) -> Result<crate::ByteStream, NetError> {
275        self.net.stream(url, headers).await
276    }
277}
278
279#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
280#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
281impl Net for RawHttp {
282    #[cfg_attr(feature = "perf", hotpath::measure)]
283    async fn get_bytes(&self, url: Url, headers: Option<Headers>) -> Result<Bytes, NetError> {
284        let req = self.inner.get(url.clone());
285        let resp = self.send_checked(req, headers, url, false).await?;
286        resp.bytes().await.map_err(NetError::from)
287    }
288
289    #[cfg_attr(feature = "perf", hotpath::measure)]
290    async fn get_range(
291        &self,
292        url: Url,
293        range: RangeSpec,
294        headers: Option<Headers>,
295    ) -> Result<crate::ByteStream, NetError> {
296        let req = self
297            .inner
298            .get(url.clone())
299            .header("Range", range.to_string());
300        let resp = self.send_checked(req, headers, url, true).await?;
301        Ok(Self::response_to_stream(resp))
302    }
303
304    #[cfg_attr(feature = "perf", hotpath::measure)]
305    async fn head(&self, url: Url, headers: Option<Headers>) -> Result<Headers, NetError> {
306        let req = self.head_request(url.clone());
307        let req = Self::apply_headers(req, headers);
308        let req = if let Some(total) = self.options.total_timeout {
309            req.timeout(total)
310        } else {
311            req
312        };
313        let resp = req.send().await.map_err(NetError::from)?;
314
315        let status = resp.status();
316
317        if !status.is_success() && status.as_u16() != HTTP_PARTIAL_CONTENT {
318            let body = truncate_error_body(resp.text().await.unwrap_or_default());
319            return Err(NetError::HttpError {
320                url,
321                status: status.as_u16(),
322                body: Some(body),
323            });
324        }
325
326        let mut out = Headers::new();
327        let str_pairs = resp
328            .headers()
329            .iter()
330            .filter_map(|(name, value)| value.to_str().ok().map(|v| (name.as_str(), v)));
331        for (name, v) in str_pairs {
332            out.insert(name, v);
333        }
334
335        if out.get("content-length").is_none() {
336            let total_from_range = out
337                .get("content-range")
338                .and_then(|h| h.split('/').nth(1))
339                .filter(|s| *s != "*")
340                .map(str::to_owned);
341            if let Some(total) = total_from_range {
342                out.insert("content-length", total);
343            }
344        }
345
346        Ok(out)
347    }
348
349    #[cfg_attr(feature = "perf", hotpath::measure)]
350    async fn stream(
351        &self,
352        url: Url,
353        headers: Option<Headers>,
354    ) -> Result<crate::ByteStream, NetError> {
355        let req = self.inner.get(url.clone());
356        let resp = self.send_checked(req, headers, url, false).await?;
357        Ok(Self::response_to_stream(resp))
358    }
359}
360
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    // Local alias so the test attribute can be written `#[kithara::test]`.
    mod kithara {
        pub(crate) use kithara_test_macros::test;
    }

    use std::{
        net::SocketAddr,
        sync::{
            Arc,
            atomic::{AtomicU32, Ordering},
        },
        time::Duration,
    };

    use axum::{Router, http::StatusCode, routing::get};
    use tokio::net::TcpListener;

    use super::*;
    use crate::types::RetryPolicy;

    /// Start a local axum server whose `/probe` route answers 503
    /// `"busy"` for the first `fail_count` requests and 200 `"ok"` for
    /// all later ones. Returns the probe URL together with the request
    /// counter the handler increments.
    async fn server_failing_first_n(fail_count: u32) -> (Url, Arc<AtomicU32>) {
        let hits = Arc::new(AtomicU32::new(0));
        let handler_hits = Arc::clone(&hits);
        let probe = move || {
            let hits = Arc::clone(&handler_hits);
            async move {
                // `fetch_add` returns the count BEFORE this request.
                match hits.fetch_add(1, Ordering::SeqCst) {
                    earlier if earlier < fail_count => (StatusCode::SERVICE_UNAVAILABLE, "busy"),
                    _ => (StatusCode::OK, "ok"),
                }
            }
        };
        let app = Router::new().route("/probe", get(probe));

        // Port 0: let the OS choose a free port, then read it back.
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr: SocketAddr = listener.local_addr().expect("local_addr");
        tokio::spawn(async move {
            axum::serve(listener, app.into_make_service())
                .await
                .expect("serve");
        });

        let probe_url = Url::parse(&format!("http://{addr}/probe")).expect("url");
        (probe_url, hits)
    }

    /// Options with millisecond-scale backoff so retry tests finish fast.
    fn fast_options(max_retries: u32) -> NetOptions {
        let policy = RetryPolicy {
            max_retries,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
        };
        NetOptions::builder().retry_policy(policy).build()
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn http_client_retries_503_until_ok() {
        let (url, counter) = server_failing_first_n(2).await;
        let client = HttpClient::new(fast_options(3), CancellationToken::new());

        let body = client
            .get_bytes(url, None)
            .await
            .expect("get_bytes must succeed after retries");

        assert_eq!(&body[..], b"ok");
        let attempts = counter.load(Ordering::SeqCst);
        assert_eq!(attempts, 3, "exactly 3 attempts: 2 failed (503) + 1 ok");
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn http_client_no_retry_propagates_5xx() {
        let (url, counter) = server_failing_first_n(2).await;
        let client = HttpClient::new(fast_options(0), CancellationToken::new());

        let err = client
            .get_bytes(url, None)
            .await
            .expect_err("max_retries=0 must propagate the 503");

        assert!(
            matches!(err, NetError::HttpError { status: 503, .. }),
            "expected HttpError(503), got {err:?}"
        );
        let attempts = counter.load(Ordering::SeqCst);
        assert_eq!(attempts, 1, "max_retries=0 issues exactly one attempt");
    }

    #[kithara::test(tokio, timeout(Duration::from_secs(5)))]
    async fn http_client_head_retries_503_until_ok() {
        let (url, counter) = server_failing_first_n(1).await;
        let client = HttpClient::new(fast_options(2), CancellationToken::new());

        client.head(url, None).await.expect("HEAD must retry");

        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }
}