Skip to main content

polyoxide_core/
client.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use reqwest::StatusCode;
5use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6use url::Url;
7
8use reqwest::header::RETRY_AFTER;
9
10use crate::error::ApiError;
11use crate::rate_limit::{RateLimiter, RetryConfig};
12
13/// Extract the `Retry-After` header value as a string, if present and valid UTF-8.
14pub fn retry_after_header(response: &reqwest::Response) -> Option<String> {
15    response
16        .headers()
17        .get(RETRY_AFTER)?
18        .to_str()
19        .ok()
20        .map(String::from)
21}
22
/// Default overall request timeout, in milliseconds (30 seconds).
pub const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Default maximum number of idle pooled connections kept per host.
pub const DEFAULT_POOL_SIZE: usize = 10;
27
28/// Shared HTTP client with base URL, optional rate limiter, and retry config.
29///
30/// This is the common structure used by all API clients to hold
31/// the configured reqwest client, base URL, and rate-limiting state.
32#[derive(Debug, Clone)]
33pub struct HttpClient {
34    /// The underlying reqwest HTTP client
35    pub client: reqwest::Client,
36    /// Base URL for API requests
37    pub base_url: Url,
38    rate_limiter: Option<RateLimiter>,
39    retry_config: RetryConfig,
40    concurrency_limiter: Option<Arc<Semaphore>>,
41}
42
43impl HttpClient {
44    /// Await rate limiter for the given endpoint path + method.
45    pub async fn acquire_rate_limit(&self, path: &str, method: Option<&reqwest::Method>) {
46        if let Some(rl) = &self.rate_limiter {
47            rl.acquire(path, method).await;
48        }
49    }
50
51    /// Acquire a concurrency permit, if a limiter is configured.
52    ///
53    /// The returned permit **must** be held until the HTTP response has been
54    /// received. Dropping the permit releases the concurrency slot.
55    /// Returns `None` when no concurrency limit is set.
56    pub async fn acquire_concurrency(&self) -> Option<OwnedSemaphorePermit> {
57        let sem = self.concurrency_limiter.as_ref()?;
58        Some(
59            sem.clone()
60                .acquire_owned()
61                .await
62                .expect("concurrency semaphore is never closed"),
63        )
64    }
65
66    /// Check if a 429 response should be retried; returns backoff duration if yes.
67    ///
68    /// When `retry_after` is `Some`, the server-provided delay is used instead of
69    /// the client-computed exponential backoff (clamped to `max_backoff_ms`).
70    pub fn should_retry(
71        &self,
72        status: StatusCode,
73        attempt: u32,
74        retry_after: Option<&str>,
75    ) -> Option<Duration> {
76        if status == StatusCode::TOO_MANY_REQUESTS && attempt < self.retry_config.max_retries {
77            if let Some(delay) = retry_after.and_then(|v| v.parse::<f64>().ok()) {
78                let ms = (delay * 1000.0) as u64;
79                Some(Duration::from_millis(
80                    ms.min(self.retry_config.max_backoff_ms),
81                ))
82            } else {
83                Some(self.retry_config.backoff(attempt))
84            }
85        } else {
86            None
87        }
88    }
89
90    /// GET a URL and return the raw response body as bytes.
91    ///
92    /// Use this for endpoints that return non-JSON payloads (e.g. `application/zip`
93    /// downloads). Applies the same rate-limiting, concurrency gating, and 429
94    /// retry behavior as the JSON-oriented [`Request`](crate::Request) helper.
95    ///
96    /// Non-2xx responses are mapped to [`ApiError`] via
97    /// [`ApiError::from_response`].
98    ///
99    /// # Errors
100    ///
101    /// Returns [`ApiError`] on URL-join failure, network errors, or non-2xx
102    /// responses.
103    pub async fn get_bytes(
104        &self,
105        path: &str,
106        query: &[(String, String)],
107    ) -> Result<Vec<u8>, ApiError> {
108        let url = self.base_url.join(path)?;
109        let mut attempt = 0u32;
110
111        loop {
112            let _permit = self.acquire_concurrency().await;
113            self.acquire_rate_limit(path, None).await;
114
115            let mut request = self.client.get(url.clone());
116            if !query.is_empty() {
117                request = request.query(query);
118            }
119
120            let response = request.send().await?;
121            let status = response.status();
122            let retry_after = retry_after_header(&response);
123
124            if let Some(backoff) = self.should_retry(status, attempt, retry_after.as_deref()) {
125                attempt += 1;
126                tracing::warn!(
127                    "Rate limited (429) on {}, retry {} after {}ms",
128                    path,
129                    attempt,
130                    backoff.as_millis()
131                );
132                drop(_permit);
133                tokio::time::sleep(backoff).await;
134                continue;
135            }
136
137            if !status.is_success() {
138                return Err(ApiError::from_response(response).await);
139            }
140
141            let bytes = response.bytes().await?;
142            return Ok(bytes.to_vec());
143        }
144    }
145}
146
147/// Builder for configuring HTTP clients.
148///
149/// Provides a consistent way to configure HTTP clients across all API crates
150/// with sensible defaults.
151///
152/// # Example
153///
154/// ```
155/// use polyoxide_core::HttpClientBuilder;
156///
157/// let client = HttpClientBuilder::new("https://api.example.com")
158///     .timeout_ms(60_000)
159///     .pool_size(20)
160///     .build()
161///     .unwrap();
162/// ```
163pub struct HttpClientBuilder {
164    base_url: String,
165    timeout_ms: u64,
166    pool_size: usize,
167    rate_limiter: Option<RateLimiter>,
168    retry_config: RetryConfig,
169    max_concurrent: Option<usize>,
170}
171
172impl HttpClientBuilder {
173    /// Create a new HTTP client builder with the given base URL.
174    pub fn new(base_url: impl Into<String>) -> Self {
175        Self {
176            base_url: base_url.into(),
177            timeout_ms: DEFAULT_TIMEOUT_MS,
178            pool_size: DEFAULT_POOL_SIZE,
179            rate_limiter: None,
180            retry_config: RetryConfig::default(),
181            max_concurrent: None,
182        }
183    }
184
185    /// Set request timeout in milliseconds.
186    ///
187    /// Default: 30,000ms (30 seconds)
188    pub fn timeout_ms(mut self, timeout: u64) -> Self {
189        self.timeout_ms = timeout;
190        self
191    }
192
193    /// Set connection pool size per host.
194    ///
195    /// Default: 10 connections
196    pub fn pool_size(mut self, size: usize) -> Self {
197        self.pool_size = size;
198        self
199    }
200
201    /// Set a rate limiter for this client.
202    pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
203        self.rate_limiter = Some(limiter);
204        self
205    }
206
207    /// Set retry configuration for 429 responses.
208    pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
209        self.retry_config = config;
210        self
211    }
212
213    /// Set the maximum number of concurrent in-flight HTTP requests.
214    ///
215    /// Prevents Cloudflare 1015 rate-limit errors caused by request bursts
216    /// when many callers share the same client concurrently.
217    pub fn with_max_concurrent(mut self, max: usize) -> Self {
218        self.max_concurrent = Some(max);
219        self
220    }
221
222    /// Build the HTTP client.
223    pub fn build(self) -> Result<HttpClient, ApiError> {
224        let client = reqwest::Client::builder()
225            .timeout(Duration::from_millis(self.timeout_ms))
226            .connect_timeout(Duration::from_secs(10))
227            .redirect(reqwest::redirect::Policy::none())
228            .pool_max_idle_per_host(self.pool_size)
229            .build()?;
230
231        let base_url = Url::parse(&self.base_url)?;
232
233        Ok(HttpClient {
234            client,
235            base_url,
236            rate_limiter: self.rate_limiter,
237            retry_config: self.retry_config,
238            concurrency_limiter: self.max_concurrent.map(|n| Arc::new(Semaphore::new(n))),
239        })
240    }
241}
242
243impl Default for HttpClientBuilder {
244    fn default() -> Self {
245        Self {
246            base_url: String::new(),
247            timeout_ms: DEFAULT_TIMEOUT_MS,
248            pool_size: DEFAULT_POOL_SIZE,
249            rate_limiter: None,
250            retry_config: RetryConfig::default(),
251            max_concurrent: None,
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Client for a dummy host with every setting left at its default.
    fn default_client() -> HttpClient {
        HttpClientBuilder::new("https://example.com")
            .build()
            .unwrap()
    }

    /// Client for a dummy host with only the concurrency cap changed.
    fn concurrent_client(max: usize) -> HttpClient {
        HttpClientBuilder::new("https://example.com")
            .with_max_concurrent(max)
            .build()
            .unwrap()
    }

    /// Default client whose base URL is the given mock server.
    fn mock_client(server: &mockito::Server) -> HttpClient {
        HttpClientBuilder::new(server.url()).build().unwrap()
    }

    // ── should_retry() ───────────────────────────────────────────

    #[test]
    fn test_should_retry_429_under_max() {
        let client = default_client();
        // Default max_retries=3, so attempts 0 and 2 should retry
        for attempt in [0, 2] {
            assert!(client
                .should_retry(StatusCode::TOO_MANY_REQUESTS, attempt, None)
                .is_some());
        }
    }

    #[test]
    fn test_should_retry_429_at_max() {
        // attempt == max_retries → no retry
        let outcome = default_client().should_retry(StatusCode::TOO_MANY_REQUESTS, 3, None);
        assert!(outcome.is_none());
    }

    #[test]
    fn test_should_retry_non_429_returns_none() {
        let client = default_client();
        let statuses = [
            StatusCode::OK,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
        ];
        for status in statuses {
            assert!(
                client.should_retry(status, 0, None).is_none(),
                "expected None for {status}"
            );
        }
    }

    #[test]
    fn test_should_retry_custom_config() {
        let config = RetryConfig {
            max_retries: 1,
            ..RetryConfig::default()
        };
        let client = HttpClientBuilder::new("https://example.com")
            .with_retry_config(config)
            .build()
            .unwrap();
        let retry = |attempt| client.should_retry(StatusCode::TOO_MANY_REQUESTS, attempt, None);
        assert!(retry(0).is_some());
        assert!(retry(1).is_none());
    }

    #[test]
    fn test_should_retry_uses_retry_after_header() {
        let delay = default_client()
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, Some("2"))
            .unwrap();
        assert_eq!(delay, Duration::from_millis(2000));
    }

    #[test]
    fn test_should_retry_retry_after_fractional_seconds() {
        let delay = default_client()
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, Some("0.5"))
            .unwrap();
        assert_eq!(delay, Duration::from_millis(500));
    }

    #[test]
    fn test_should_retry_retry_after_clamped_to_max_backoff() {
        // Default max_backoff_ms = 10_000; header says 60s
        let delay = default_client()
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, Some("60"))
            .unwrap();
        assert_eq!(delay, Duration::from_millis(10_000));
    }

    #[test]
    fn test_should_retry_retry_after_invalid_falls_back() {
        // Non-numeric Retry-After (HTTP-date format) falls back to computed backoff
        let delay = default_client()
            .should_retry(
                StatusCode::TOO_MANY_REQUESTS,
                0,
                Some("Wed, 21 Oct 2025 07:28:00 GMT"),
            )
            .unwrap();
        // Should be in the jitter range for attempt 0: [375, 625]ms
        let ms = delay.as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "expected fallback backoff in [375, 625], got {ms}"
        );
    }

    // ── Builder wiring ───────────────────────────────────────────

    #[tokio::test]
    async fn test_builder_with_rate_limiter() {
        let client = HttpClientBuilder::new("https://example.com")
            .with_rate_limiter(RateLimiter::clob_default())
            .build()
            .unwrap();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_builder_without_rate_limiter() {
        let client = default_client();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(10));
    }

    // ── Concurrency limiter ─────────────────────────────────────

    #[tokio::test]
    async fn test_acquire_concurrency_none_when_not_configured() {
        let client = default_client();
        assert!(client.acquire_concurrency().await.is_none());
    }

    #[tokio::test]
    async fn test_acquire_concurrency_returns_permit() {
        let client = concurrent_client(2);
        let permit = client.acquire_concurrency().await;
        assert!(permit.is_some());
    }

    #[tokio::test]
    async fn test_concurrency_shared_across_clones() {
        let client = concurrent_client(1);
        let clone = client.clone();

        // Hold the only permit from the original
        let _held = client.acquire_concurrency().await.unwrap();

        // Clone should block because concurrency=1 and permit is held
        let result =
            tokio::time::timeout(Duration::from_millis(50), clone.acquire_concurrency()).await;
        assert!(result.is_err(), "clone should block when permit is held");
    }

    #[tokio::test]
    async fn test_concurrency_limits_parallel_tasks() {
        let client = concurrent_client(2);

        let started = std::time::Instant::now();
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let c = client.clone();
                tokio::spawn(async move {
                    let _permit = c.acquire_concurrency().await;
                    tokio::time::sleep(Duration::from_millis(50)).await;
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        // 4 tasks, concurrency 2, 50ms each => ~100ms minimum
        assert!(
            started.elapsed() >= Duration::from_millis(90),
            "expected ~100ms, got {:?}",
            started.elapsed()
        );
    }

    #[tokio::test]
    async fn test_builder_with_max_concurrent() {
        let client = concurrent_client(5);
        // Should be able to acquire 5 permits
        let mut permits = Vec::with_capacity(5);
        for _ in 0..5 {
            permits.push(client.acquire_concurrency().await);
        }
        assert!(permits.iter().all(Option::is_some));

        // 6th should block
        let result =
            tokio::time::timeout(Duration::from_millis(50), client.acquire_concurrency()).await;
        assert!(result.is_err());
    }

    // ── get_bytes() ──────────────────────────────────────────────

    #[tokio::test]
    async fn test_get_bytes_returns_body_verbatim() {
        let mut server = mockito::Server::new_async().await;
        // Intentionally non-UTF-8 bytes to prove we're not assuming text.
        let body: Vec<u8> = vec![0x50, 0x4B, 0x03, 0x04, 0x00, 0xFF, 0xFE, 0x42];
        let mock = server
            .mock("GET", "/v1/accounting/snapshot")
            .match_query(mockito::Matcher::UrlEncoded("user".into(), "0xabc".into()))
            .with_status(200)
            .with_header("content-type", "application/zip")
            .with_body(body.clone())
            .create_async()
            .await;

        let client = mock_client(&server);
        let query = [("user".to_string(), "0xabc".to_string())];
        let out = client
            .get_bytes("/v1/accounting/snapshot", &query)
            .await
            .unwrap();
        assert_eq!(out, body);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_bytes_maps_non_2xx_to_api_error() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/does-not-exist")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "not found"}"#)
            .create_async()
            .await;

        let client = mock_client(&server);
        let err = client.get_bytes("/does-not-exist", &[]).await.unwrap_err();
        match err {
            ApiError::Api { status, message } => {
                assert_eq!(status, 404);
                assert_eq!(message, "not found");
            }
            other => panic!("expected ApiError::Api, got {other:?}"),
        }
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_bytes_no_query_params() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/raw")
            .with_status(200)
            .with_body(&b"hello"[..])
            .create_async()
            .await;

        let client = mock_client(&server);
        let out = client.get_bytes("/raw", &[]).await.unwrap();
        assert_eq!(out, b"hello");
        mock.assert_async().await;
    }
}