1use std::sync::Arc;
2use std::time::Duration;
3
4use reqwest::StatusCode;
5use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6use url::Url;
7
8use reqwest::header::RETRY_AFTER;
9
10use crate::error::ApiError;
11use crate::rate_limit::{RateLimiter, RetryConfig};
12
13pub fn retry_after_header(response: &reqwest::Response) -> Option<String> {
15 response
16 .headers()
17 .get(RETRY_AFTER)?
18 .to_str()
19 .ok()
20 .map(String::from)
21}
22
/// Default overall request timeout, in milliseconds, applied by
/// `HttpClientBuilder` unless overridden with `timeout_ms`.
pub const DEFAULT_TIMEOUT_MS: u64 = 30_000;

/// Default value handed to reqwest's `pool_max_idle_per_host` by
/// `HttpClientBuilder` unless overridden with `pool_size`.
pub const DEFAULT_POOL_SIZE: usize = 10;
27
28#[derive(Debug, Clone)]
33pub struct HttpClient {
34 pub client: reqwest::Client,
36 pub base_url: Url,
38 rate_limiter: Option<RateLimiter>,
39 retry_config: RetryConfig,
40 concurrency_limiter: Option<Arc<Semaphore>>,
41}
42
43impl HttpClient {
44 pub async fn acquire_rate_limit(&self, path: &str, method: Option<&reqwest::Method>) {
46 if let Some(rl) = &self.rate_limiter {
47 rl.acquire(path, method).await;
48 }
49 }
50
51 pub async fn acquire_concurrency(&self) -> Option<OwnedSemaphorePermit> {
57 let sem = self.concurrency_limiter.as_ref()?;
58 Some(
59 sem.clone()
60 .acquire_owned()
61 .await
62 .expect("concurrency semaphore is never closed"),
63 )
64 }
65
66 pub fn should_retry(
71 &self,
72 status: StatusCode,
73 attempt: u32,
74 retry_after: Option<&str>,
75 ) -> Option<Duration> {
76 if status == StatusCode::TOO_MANY_REQUESTS && attempt < self.retry_config.max_retries {
77 if let Some(delay) = retry_after.and_then(|v| v.parse::<f64>().ok()) {
78 let ms = (delay * 1000.0) as u64;
79 Some(Duration::from_millis(
80 ms.min(self.retry_config.max_backoff_ms),
81 ))
82 } else {
83 Some(self.retry_config.backoff(attempt))
84 }
85 } else {
86 None
87 }
88 }
89}
90
91pub struct HttpClientBuilder {
108 base_url: String,
109 timeout_ms: u64,
110 pool_size: usize,
111 rate_limiter: Option<RateLimiter>,
112 retry_config: RetryConfig,
113 max_concurrent: Option<usize>,
114}
115
116impl HttpClientBuilder {
117 pub fn new(base_url: impl Into<String>) -> Self {
119 Self {
120 base_url: base_url.into(),
121 timeout_ms: DEFAULT_TIMEOUT_MS,
122 pool_size: DEFAULT_POOL_SIZE,
123 rate_limiter: None,
124 retry_config: RetryConfig::default(),
125 max_concurrent: None,
126 }
127 }
128
129 pub fn timeout_ms(mut self, timeout: u64) -> Self {
133 self.timeout_ms = timeout;
134 self
135 }
136
137 pub fn pool_size(mut self, size: usize) -> Self {
141 self.pool_size = size;
142 self
143 }
144
145 pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
147 self.rate_limiter = Some(limiter);
148 self
149 }
150
151 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
153 self.retry_config = config;
154 self
155 }
156
157 pub fn with_max_concurrent(mut self, max: usize) -> Self {
162 self.max_concurrent = Some(max);
163 self
164 }
165
166 pub fn build(self) -> Result<HttpClient, ApiError> {
168 let client = reqwest::Client::builder()
169 .timeout(Duration::from_millis(self.timeout_ms))
170 .connect_timeout(Duration::from_secs(10))
171 .redirect(reqwest::redirect::Policy::none())
172 .pool_max_idle_per_host(self.pool_size)
173 .build()?;
174
175 let base_url = Url::parse(&self.base_url)?;
176
177 Ok(HttpClient {
178 client,
179 base_url,
180 rate_limiter: self.rate_limiter,
181 retry_config: self.retry_config,
182 concurrency_limiter: self.max_concurrent.map(|n| Arc::new(Semaphore::new(n))),
183 })
184 }
185}
186
187impl Default for HttpClientBuilder {
188 fn default() -> Self {
189 Self {
190 base_url: String::new(),
191 timeout_ms: DEFAULT_TIMEOUT_MS,
192 pool_size: DEFAULT_POOL_SIZE,
193 rate_limiter: None,
194 retry_config: RetryConfig::default(),
195 max_concurrent: None,
196 }
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with every builder setting left at its default.
    fn default_client() -> HttpClient {
        HttpClientBuilder::new("https://example.com")
            .build()
            .unwrap()
    }

    /// Builds a client whose only customisation is a concurrency cap.
    fn capped_client(max: usize) -> HttpClient {
        HttpClientBuilder::new("https://example.com")
            .with_max_concurrent(max)
            .build()
            .unwrap()
    }

    /// Asks `client` how a 429 response should be handled.
    fn on_429(client: &HttpClient, attempt: u32, retry_after: Option<&str>) -> Option<Duration> {
        client.should_retry(StatusCode::TOO_MANY_REQUESTS, attempt, retry_after)
    }

    // A 429 is retried while the attempt number is below the default maximum.
    #[test]
    fn test_should_retry_429_under_max() {
        let client = default_client();
        for attempt in [0, 2] {
            assert!(on_429(&client, attempt, None).is_some());
        }
    }

    // Once the attempt number reaches the default maximum, no retry happens.
    #[test]
    fn test_should_retry_429_at_max() {
        let client = default_client();
        assert!(on_429(&client, 3, None).is_none());
    }

    // Statuses other than 429 are never retried.
    #[test]
    fn test_should_retry_non_429_returns_none() {
        let client = default_client();
        let statuses = [
            StatusCode::OK,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
        ];
        for status in statuses {
            assert!(
                client.should_retry(status, 0, None).is_none(),
                "expected None for {status}"
            );
        }
    }

    // A custom `max_retries` moves the retry cut-off.
    #[test]
    fn test_should_retry_custom_config() {
        let config = RetryConfig {
            max_retries: 1,
            ..RetryConfig::default()
        };
        let client = HttpClientBuilder::new("https://example.com")
            .with_retry_config(config)
            .build()
            .unwrap();
        assert!(on_429(&client, 0, None).is_some());
        assert!(on_429(&client, 1, None).is_none());
    }

    // Whole-second `Retry-After` values are honoured.
    #[test]
    fn test_should_retry_uses_retry_after_header() {
        let delay = on_429(&default_client(), 0, Some("2")).unwrap();
        assert_eq!(delay, Duration::from_millis(2000));
    }

    // Fractional seconds are honoured as well.
    #[test]
    fn test_should_retry_retry_after_fractional_seconds() {
        let delay = on_429(&default_client(), 0, Some("0.5")).unwrap();
        assert_eq!(delay, Duration::from_millis(500));
    }

    // Server-requested delays are capped at the configured maximum backoff.
    #[test]
    fn test_should_retry_retry_after_clamped_to_max_backoff() {
        let delay = on_429(&default_client(), 0, Some("60")).unwrap();
        assert_eq!(delay, Duration::from_millis(10_000));
    }

    // A non-numeric `Retry-After` (here an HTTP-date) uses the computed backoff.
    #[test]
    fn test_should_retry_retry_after_invalid_falls_back() {
        let delay = on_429(
            &default_client(),
            0,
            Some("Wed, 21 Oct 2025 07:28:00 GMT"),
        )
        .unwrap();
        let ms = delay.as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "expected fallback backoff in [375, 625], got {ms}"
        );
    }

    // With a rate limiter installed, the first acquisition completes quickly.
    #[tokio::test]
    async fn test_builder_with_rate_limiter() {
        let client = HttpClientBuilder::new("https://example.com")
            .with_rate_limiter(RateLimiter::clob_default())
            .build()
            .unwrap();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(50));
    }

    // Without a rate limiter, acquisition returns without waiting.
    #[tokio::test]
    async fn test_builder_without_rate_limiter() {
        let client = default_client();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(10));
    }

    // No concurrency cap configured means no permit is handed out.
    #[tokio::test]
    async fn test_acquire_concurrency_none_when_not_configured() {
        let client = default_client();
        assert!(client.acquire_concurrency().await.is_none());
    }

    // A configured cap yields a permit.
    #[tokio::test]
    async fn test_acquire_concurrency_returns_permit() {
        let client = capped_client(2);
        let permit = client.acquire_concurrency().await;
        assert!(permit.is_some());
    }

    // Clones draw from the same semaphore as the original client.
    #[tokio::test]
    async fn test_concurrency_shared_across_clones() {
        let client = capped_client(1);
        let clone = client.clone();

        let _permit = client.acquire_concurrency().await.unwrap();

        let attempt = clone.acquire_concurrency();
        let result = tokio::time::timeout(Duration::from_millis(50), attempt).await;
        assert!(result.is_err(), "clone should block when permit is held");
    }

    // Four 50 ms tasks under a cap of two must run in at least two waves.
    #[tokio::test]
    async fn test_concurrency_limits_parallel_tasks() {
        let client = capped_client(2);

        let started = std::time::Instant::now();
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let c = client.clone();
                tokio::spawn(async move {
                    let _permit = c.acquire_concurrency().await;
                    tokio::time::sleep(Duration::from_millis(50)).await;
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        assert!(
            started.elapsed() >= Duration::from_millis(90),
            "expected ~100ms, got {:?}",
            started.elapsed()
        );
    }

    // Exactly `max` permits are available; the next acquisition has to wait.
    #[tokio::test]
    async fn test_builder_with_max_concurrent() {
        let client = capped_client(5);
        let mut permits = Vec::new();
        for _ in 0..5 {
            permits.push(client.acquire_concurrency().await);
        }
        assert!(permits.iter().all(Option::is_some));

        let extra = client.acquire_concurrency();
        let result = tokio::time::timeout(Duration::from_millis(50), extra).await;
        assert!(result.is_err());
    }
}