// polyoxide_core/client.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use reqwest::StatusCode;
5use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6use url::Url;
7
8use reqwest::header::RETRY_AFTER;
9
10use crate::error::ApiError;
11use crate::rate_limit::{RateLimiter, RetryConfig};
12
13/// Extract the `Retry-After` header value as a string, if present and valid UTF-8.
14pub fn retry_after_header(response: &reqwest::Response) -> Option<String> {
15    response
16        .headers()
17        .get(RETRY_AFTER)?
18        .to_str()
19        .ok()
20        .map(String::from)
21}
22
/// Default request timeout in milliseconds (30 seconds).
///
/// Used by `HttpClientBuilder` unless overridden with `timeout_ms`.
pub const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Default connection pool size per host.
///
/// Used by `HttpClientBuilder` unless overridden with `pool_size`; the value is
/// handed to reqwest's `pool_max_idle_per_host` when the client is built.
pub const DEFAULT_POOL_SIZE: usize = 10;
27
/// Shared HTTP client with base URL, optional rate limiter, and retry config.
///
/// This is the common structure used by all API clients to hold
/// the configured reqwest client, base URL, and rate-limiting state.
///
/// The concurrency semaphore sits behind an `Arc`, so clones of this struct
/// draw permits from the same semaphore.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// The underlying reqwest HTTP client
    pub client: reqwest::Client,
    /// Base URL for API requests
    pub base_url: Url,
    /// Rate limiter awaited by `acquire_rate_limit`; `None` disables rate limiting.
    rate_limiter: Option<RateLimiter>,
    /// Retry limit and backoff settings consulted by `should_retry`.
    retry_config: RetryConfig,
    /// Semaphore handing out permits in `acquire_concurrency`; `None` means no limit.
    concurrency_limiter: Option<Arc<Semaphore>>,
}
42
43impl HttpClient {
44    /// Await rate limiter for the given endpoint path + method.
45    pub async fn acquire_rate_limit(&self, path: &str, method: Option<&reqwest::Method>) {
46        if let Some(rl) = &self.rate_limiter {
47            rl.acquire(path, method).await;
48        }
49    }
50
51    /// Acquire a concurrency permit, if a limiter is configured.
52    ///
53    /// The returned permit **must** be held until the HTTP response has been
54    /// received. Dropping the permit releases the concurrency slot.
55    /// Returns `None` when no concurrency limit is set.
56    pub async fn acquire_concurrency(&self) -> Option<OwnedSemaphorePermit> {
57        let sem = self.concurrency_limiter.as_ref()?;
58        Some(
59            sem.clone()
60                .acquire_owned()
61                .await
62                .expect("concurrency semaphore is never closed"),
63        )
64    }
65
66    /// Check if a 429 response should be retried; returns backoff duration if yes.
67    ///
68    /// When `retry_after` is `Some`, the server-provided delay is used instead of
69    /// the client-computed exponential backoff (clamped to `max_backoff_ms`).
70    pub fn should_retry(
71        &self,
72        status: StatusCode,
73        attempt: u32,
74        retry_after: Option<&str>,
75    ) -> Option<Duration> {
76        if status == StatusCode::TOO_MANY_REQUESTS && attempt < self.retry_config.max_retries {
77            if let Some(delay) = retry_after.and_then(|v| v.parse::<f64>().ok()) {
78                let ms = (delay * 1000.0) as u64;
79                Some(Duration::from_millis(
80                    ms.min(self.retry_config.max_backoff_ms),
81                ))
82            } else {
83                Some(self.retry_config.backoff(attempt))
84            }
85        } else {
86            None
87        }
88    }
89}
90
91/// Builder for configuring HTTP clients.
92///
93/// Provides a consistent way to configure HTTP clients across all API crates
94/// with sensible defaults.
95///
96/// # Example
97///
98/// ```
99/// use polyoxide_core::HttpClientBuilder;
100///
101/// let client = HttpClientBuilder::new("https://api.example.com")
102///     .timeout_ms(60_000)
103///     .pool_size(20)
104///     .build()
105///     .unwrap();
106/// ```
107pub struct HttpClientBuilder {
108    base_url: String,
109    timeout_ms: u64,
110    pool_size: usize,
111    rate_limiter: Option<RateLimiter>,
112    retry_config: RetryConfig,
113    max_concurrent: Option<usize>,
114}
115
116impl HttpClientBuilder {
117    /// Create a new HTTP client builder with the given base URL.
118    pub fn new(base_url: impl Into<String>) -> Self {
119        Self {
120            base_url: base_url.into(),
121            timeout_ms: DEFAULT_TIMEOUT_MS,
122            pool_size: DEFAULT_POOL_SIZE,
123            rate_limiter: None,
124            retry_config: RetryConfig::default(),
125            max_concurrent: None,
126        }
127    }
128
129    /// Set request timeout in milliseconds.
130    ///
131    /// Default: 30,000ms (30 seconds)
132    pub fn timeout_ms(mut self, timeout: u64) -> Self {
133        self.timeout_ms = timeout;
134        self
135    }
136
137    /// Set connection pool size per host.
138    ///
139    /// Default: 10 connections
140    pub fn pool_size(mut self, size: usize) -> Self {
141        self.pool_size = size;
142        self
143    }
144
145    /// Set a rate limiter for this client.
146    pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
147        self.rate_limiter = Some(limiter);
148        self
149    }
150
151    /// Set retry configuration for 429 responses.
152    pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
153        self.retry_config = config;
154        self
155    }
156
157    /// Set the maximum number of concurrent in-flight HTTP requests.
158    ///
159    /// Prevents Cloudflare 1015 rate-limit errors caused by request bursts
160    /// when many callers share the same client concurrently.
161    pub fn with_max_concurrent(mut self, max: usize) -> Self {
162        self.max_concurrent = Some(max);
163        self
164    }
165
166    /// Build the HTTP client.
167    pub fn build(self) -> Result<HttpClient, ApiError> {
168        let client = reqwest::Client::builder()
169            .timeout(Duration::from_millis(self.timeout_ms))
170            .connect_timeout(Duration::from_secs(10))
171            .redirect(reqwest::redirect::Policy::none())
172            .pool_max_idle_per_host(self.pool_size)
173            .build()?;
174
175        let base_url = Url::parse(&self.base_url)?;
176
177        Ok(HttpClient {
178            client,
179            base_url,
180            rate_limiter: self.rate_limiter,
181            retry_config: self.retry_config,
182            concurrency_limiter: self.max_concurrent.map(|n| Arc::new(Semaphore::new(n))),
183        })
184    }
185}
186
187impl Default for HttpClientBuilder {
188    fn default() -> Self {
189        Self {
190            base_url: String::new(),
191            timeout_ms: DEFAULT_TIMEOUT_MS,
192            pool_size: DEFAULT_POOL_SIZE,
193            rate_limiter: None,
194            retry_config: RetryConfig::default(),
195            max_concurrent: None,
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Base URL used for every test client.
    const TEST_URL: &str = "https://example.com";

    /// Client built with all builder defaults.
    fn default_client() -> HttpClient {
        HttpClientBuilder::new(TEST_URL).build().unwrap()
    }

    /// Client whose concurrency limiter is configured with `max` permits.
    fn limited_client(max: usize) -> HttpClient {
        HttpClientBuilder::new(TEST_URL)
            .with_max_concurrent(max)
            .build()
            .unwrap()
    }

    /// Shorthand for calling `should_retry` with a 429 status.
    fn retry_429(
        client: &HttpClient,
        attempt: u32,
        retry_after: Option<&str>,
    ) -> Option<Duration> {
        client.should_retry(StatusCode::TOO_MANY_REQUESTS, attempt, retry_after)
    }

    // ── should_retry() ───────────────────────────────────────────

    #[test]
    fn test_should_retry_429_under_max() {
        let client = default_client();
        // Default max_retries=3, so attempts 0 and 2 should retry
        for attempt in [0, 2] {
            assert!(retry_429(&client, attempt, None).is_some());
        }
    }

    #[test]
    fn test_should_retry_429_at_max() {
        let client = default_client();
        // attempt == max_retries → no retry
        assert!(retry_429(&client, 3, None).is_none());
    }

    #[test]
    fn test_should_retry_non_429_returns_none() {
        let client = default_client();
        let non_retryable = [
            StatusCode::OK,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
        ];
        for status in non_retryable {
            assert!(
                client.should_retry(status, 0, None).is_none(),
                "expected None for {status}"
            );
        }
    }

    #[test]
    fn test_should_retry_custom_config() {
        let single_retry = RetryConfig {
            max_retries: 1,
            ..RetryConfig::default()
        };
        let client = HttpClientBuilder::new(TEST_URL)
            .with_retry_config(single_retry)
            .build()
            .unwrap();
        assert!(retry_429(&client, 0, None).is_some());
        assert!(retry_429(&client, 1, None).is_none());
    }

    #[test]
    fn test_should_retry_uses_retry_after_header() {
        let delay = retry_429(&default_client(), 0, Some("2")).unwrap();
        assert_eq!(delay, Duration::from_millis(2000));
    }

    #[test]
    fn test_should_retry_retry_after_fractional_seconds() {
        let delay = retry_429(&default_client(), 0, Some("0.5")).unwrap();
        assert_eq!(delay, Duration::from_millis(500));
    }

    #[test]
    fn test_should_retry_retry_after_clamped_to_max_backoff() {
        // Default max_backoff_ms = 10_000; header says 60s
        let delay = retry_429(&default_client(), 0, Some("60")).unwrap();
        assert_eq!(delay, Duration::from_millis(10_000));
    }

    #[test]
    fn test_should_retry_retry_after_invalid_falls_back() {
        // Non-numeric Retry-After (HTTP-date format) falls back to computed backoff
        let http_date = "Wed, 21 Oct 2025 07:28:00 GMT";
        let delay = retry_429(&default_client(), 0, Some(http_date)).unwrap();
        // Should be in the jitter range for attempt 0: [375, 625]ms
        let ms = delay.as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "expected fallback backoff in [375, 625], got {ms}"
        );
    }

    // ── Builder wiring ───────────────────────────────────────────

    #[tokio::test]
    async fn test_builder_with_rate_limiter() {
        let client = HttpClientBuilder::new(TEST_URL)
            .with_rate_limiter(RateLimiter::clob_default())
            .build()
            .unwrap();
        let method = reqwest::Method::POST;
        let started = std::time::Instant::now();
        client.acquire_rate_limit("/order", Some(&method)).await;
        assert!(started.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_builder_without_rate_limiter() {
        let client = default_client();
        let method = reqwest::Method::POST;
        let started = std::time::Instant::now();
        client.acquire_rate_limit("/order", Some(&method)).await;
        assert!(started.elapsed() < Duration::from_millis(10));
    }

    // ── Concurrency limiter ─────────────────────────────────────

    #[tokio::test]
    async fn test_acquire_concurrency_none_when_not_configured() {
        let client = default_client();
        assert!(client.acquire_concurrency().await.is_none());
    }

    #[tokio::test]
    async fn test_acquire_concurrency_returns_permit() {
        let client = limited_client(2);
        let permit = client.acquire_concurrency().await;
        assert!(permit.is_some());
    }

    #[tokio::test]
    async fn test_concurrency_shared_across_clones() {
        let original = limited_client(1);
        let twin = original.clone();

        // Hold the only permit from the original
        let _permit = original.acquire_concurrency().await.unwrap();

        // Clone should block because concurrency=1 and permit is held
        let window = Duration::from_millis(50);
        let outcome = tokio::time::timeout(window, twin.acquire_concurrency()).await;
        assert!(outcome.is_err(), "clone should block when permit is held");
    }

    #[tokio::test]
    async fn test_concurrency_limits_parallel_tasks() {
        let client = limited_client(2);

        let start = std::time::Instant::now();
        let tasks: Vec<_> = (0..4)
            .map(|_| {
                let worker = client.clone();
                tokio::spawn(async move {
                    let _permit = worker.acquire_concurrency().await;
                    tokio::time::sleep(Duration::from_millis(50)).await;
                })
            })
            .collect();
        for task in tasks {
            task.await.unwrap();
        }
        // 4 tasks, concurrency 2, 50ms each => ~100ms minimum
        assert!(
            start.elapsed() >= Duration::from_millis(90),
            "expected ~100ms, got {:?}",
            start.elapsed()
        );
    }

    #[tokio::test]
    async fn test_builder_with_max_concurrent() {
        let client = limited_client(5);
        // Should be able to acquire 5 permits
        let mut held = Vec::new();
        for _ in 0..5 {
            held.push(client.acquire_concurrency().await);
        }
        assert!(held.iter().all(|p| p.is_some()));

        // 6th should block
        let window = Duration::from_millis(50);
        let sixth = tokio::time::timeout(window, client.acquire_concurrency()).await;
        assert!(sixth.is_err());
    }
}