1use std::sync::Arc;
2use std::time::Duration;
3
4use reqwest::StatusCode;
5use tokio::sync::{OwnedSemaphorePermit, Semaphore};
6use url::Url;
7
8use reqwest::header::RETRY_AFTER;
9
10use crate::error::ApiError;
11use crate::rate_limit::{RateLimiter, RetryConfig};
12
13pub fn retry_after_header(response: &reqwest::Response) -> Option<String> {
15 response
16 .headers()
17 .get(RETRY_AFTER)?
18 .to_str()
19 .ok()
20 .map(String::from)
21}
22
/// Default overall request timeout handed to `reqwest::ClientBuilder::timeout`, in milliseconds (30 s).
pub const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
/// Default value passed to `reqwest::ClientBuilder::pool_max_idle_per_host`.
pub const DEFAULT_POOL_SIZE: usize = 10;
27
28#[derive(Debug, Clone)]
33pub struct HttpClient {
34 pub client: reqwest::Client,
36 pub base_url: Url,
38 rate_limiter: Option<RateLimiter>,
39 retry_config: RetryConfig,
40 concurrency_limiter: Option<Arc<Semaphore>>,
41}
42
43impl HttpClient {
44 pub async fn acquire_rate_limit(&self, path: &str, method: Option<&reqwest::Method>) {
46 if let Some(rl) = &self.rate_limiter {
47 rl.acquire(path, method).await;
48 }
49 }
50
51 pub async fn acquire_concurrency(&self) -> Option<OwnedSemaphorePermit> {
57 let sem = self.concurrency_limiter.as_ref()?;
58 Some(
59 sem.clone()
60 .acquire_owned()
61 .await
62 .expect("concurrency semaphore is never closed"),
63 )
64 }
65
66 pub fn should_retry(
71 &self,
72 status: StatusCode,
73 attempt: u32,
74 retry_after: Option<&str>,
75 ) -> Option<Duration> {
76 if status == StatusCode::TOO_MANY_REQUESTS && attempt < self.retry_config.max_retries {
77 if let Some(delay) = retry_after.and_then(|v| v.parse::<f64>().ok()) {
78 let ms = (delay * 1000.0) as u64;
79 Some(Duration::from_millis(
80 ms.min(self.retry_config.max_backoff_ms),
81 ))
82 } else {
83 Some(self.retry_config.backoff(attempt))
84 }
85 } else {
86 None
87 }
88 }
89
90 pub async fn get_bytes(
104 &self,
105 path: &str,
106 query: &[(String, String)],
107 ) -> Result<Vec<u8>, ApiError> {
108 let url = self.base_url.join(path)?;
109 let mut attempt = 0u32;
110
111 loop {
112 let _permit = self.acquire_concurrency().await;
113 self.acquire_rate_limit(path, None).await;
114
115 let mut request = self.client.get(url.clone());
116 if !query.is_empty() {
117 request = request.query(query);
118 }
119
120 let response = request.send().await?;
121 let status = response.status();
122 let retry_after = retry_after_header(&response);
123
124 if let Some(backoff) = self.should_retry(status, attempt, retry_after.as_deref()) {
125 attempt += 1;
126 tracing::warn!(
127 "Rate limited (429) on {}, retry {} after {}ms",
128 path,
129 attempt,
130 backoff.as_millis()
131 );
132 drop(_permit);
133 tokio::time::sleep(backoff).await;
134 continue;
135 }
136
137 if !status.is_success() {
138 return Err(ApiError::from_response(response).await);
139 }
140
141 let bytes = response.bytes().await?;
142 return Ok(bytes.to_vec());
143 }
144 }
145}
146
147pub struct HttpClientBuilder {
164 base_url: String,
165 timeout_ms: u64,
166 pool_size: usize,
167 rate_limiter: Option<RateLimiter>,
168 retry_config: RetryConfig,
169 max_concurrent: Option<usize>,
170}
171
172impl HttpClientBuilder {
173 pub fn new(base_url: impl Into<String>) -> Self {
175 Self {
176 base_url: base_url.into(),
177 timeout_ms: DEFAULT_TIMEOUT_MS,
178 pool_size: DEFAULT_POOL_SIZE,
179 rate_limiter: None,
180 retry_config: RetryConfig::default(),
181 max_concurrent: None,
182 }
183 }
184
185 pub fn timeout_ms(mut self, timeout: u64) -> Self {
189 self.timeout_ms = timeout;
190 self
191 }
192
193 pub fn pool_size(mut self, size: usize) -> Self {
197 self.pool_size = size;
198 self
199 }
200
201 pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
203 self.rate_limiter = Some(limiter);
204 self
205 }
206
207 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
209 self.retry_config = config;
210 self
211 }
212
213 pub fn with_max_concurrent(mut self, max: usize) -> Self {
218 self.max_concurrent = Some(max);
219 self
220 }
221
222 pub fn build(self) -> Result<HttpClient, ApiError> {
224 let client = reqwest::Client::builder()
225 .timeout(Duration::from_millis(self.timeout_ms))
226 .connect_timeout(Duration::from_secs(10))
227 .redirect(reqwest::redirect::Policy::none())
228 .pool_max_idle_per_host(self.pool_size)
229 .build()?;
230
231 let base_url = Url::parse(&self.base_url)?;
232
233 Ok(HttpClient {
234 client,
235 base_url,
236 rate_limiter: self.rate_limiter,
237 retry_config: self.retry_config,
238 concurrency_limiter: self.max_concurrent.map(|n| Arc::new(Semaphore::new(n))),
239 })
240 }
241}
242
243impl Default for HttpClientBuilder {
244 fn default() -> Self {
245 Self {
246 base_url: String::new(),
247 timeout_ms: DEFAULT_TIMEOUT_MS,
248 pool_size: DEFAULT_POOL_SIZE,
249 rate_limiter: None,
250 retry_config: RetryConfig::default(),
251 max_concurrent: None,
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client for `base_url`, leaving every other setting at its default.
    fn client_for(base_url: impl Into<String>) -> HttpClient {
        HttpClientBuilder::new(base_url).build().unwrap()
    }

    /// Default-configured client pointing at a placeholder host.
    fn example_client() -> HttpClient {
        client_for("https://example.com")
    }

    /// Placeholder-host client whose in-flight requests are capped at `max`.
    fn limited_client(max: usize) -> HttpClient {
        HttpClientBuilder::new("https://example.com")
            .with_max_concurrent(max)
            .build()
            .unwrap()
    }

    #[test]
    fn test_should_retry_429_under_max() {
        let client = example_client();
        for attempt in [0, 2] {
            assert!(client
                .should_retry(StatusCode::TOO_MANY_REQUESTS, attempt, None)
                .is_some());
        }
    }

    #[test]
    fn test_should_retry_429_at_max() {
        let verdict = example_client().should_retry(StatusCode::TOO_MANY_REQUESTS, 3, None);
        assert!(verdict.is_none());
    }

    #[test]
    fn test_should_retry_non_429_returns_none() {
        let client = example_client();
        let statuses = [
            StatusCode::OK,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
        ];
        for status in statuses {
            assert!(
                client.should_retry(status, 0, None).is_none(),
                "expected None for {status}"
            );
        }
    }

    #[test]
    fn test_should_retry_custom_config() {
        let config = RetryConfig {
            max_retries: 1,
            ..RetryConfig::default()
        };
        let client = HttpClientBuilder::new("https://example.com")
            .with_retry_config(config)
            .build()
            .unwrap();
        assert!(client
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, None)
            .is_some());
        assert!(client
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 1, None)
            .is_none());
    }

    #[test]
    fn test_should_retry_uses_retry_after_header() {
        let delay = example_client()
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, Some("2"))
            .unwrap();
        assert_eq!(delay, Duration::from_millis(2000));
    }

    #[test]
    fn test_should_retry_retry_after_fractional_seconds() {
        let delay = example_client()
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, Some("0.5"))
            .unwrap();
        assert_eq!(delay, Duration::from_millis(500));
    }

    #[test]
    fn test_should_retry_retry_after_clamped_to_max_backoff() {
        let delay = example_client()
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, Some("60"))
            .unwrap();
        assert_eq!(delay, Duration::from_millis(10_000));
    }

    #[test]
    fn test_should_retry_retry_after_invalid_falls_back() {
        let http_date = "Wed, 21 Oct 2025 07:28:00 GMT";
        let delay = example_client()
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0, Some(http_date))
            .unwrap();
        let ms = delay.as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "expected fallback backoff in [375, 625], got {ms}"
        );
    }

    #[tokio::test]
    async fn test_builder_with_rate_limiter() {
        let client = HttpClientBuilder::new("https://example.com")
            .with_rate_limiter(RateLimiter::clob_default())
            .build()
            .unwrap();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_builder_without_rate_limiter() {
        let client = example_client();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(10));
    }

    #[tokio::test]
    async fn test_acquire_concurrency_none_when_not_configured() {
        let client = example_client();
        assert!(client.acquire_concurrency().await.is_none());
    }

    #[tokio::test]
    async fn test_acquire_concurrency_returns_permit() {
        let client = limited_client(2);
        let permit = client.acquire_concurrency().await;
        assert!(permit.is_some());
    }

    #[tokio::test]
    async fn test_concurrency_shared_across_clones() {
        let client = limited_client(1);
        let clone = client.clone();

        // Hold the only permit through the original handle.
        let _held = client.acquire_concurrency().await.unwrap();

        let blocked =
            tokio::time::timeout(Duration::from_millis(50), clone.acquire_concurrency()).await;
        assert!(blocked.is_err(), "clone should block when permit is held");
    }

    #[tokio::test]
    async fn test_concurrency_limits_parallel_tasks() {
        let client = limited_client(2);

        let started = std::time::Instant::now();
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let worker = client.clone();
                tokio::spawn(async move {
                    let _permit = worker.acquire_concurrency().await;
                    tokio::time::sleep(Duration::from_millis(50)).await;
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        // Four 50 ms tasks through two slots need at least two rounds.
        let elapsed = started.elapsed();
        assert!(
            elapsed >= Duration::from_millis(90),
            "expected ~100ms, got {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_builder_with_max_concurrent() {
        let client = limited_client(5);
        let mut permits = Vec::with_capacity(5);
        for _ in 0..5 {
            permits.push(client.acquire_concurrency().await);
        }
        assert!(permits.iter().all(Option::is_some));

        // All five slots are taken, so a sixth acquisition must time out.
        let sixth =
            tokio::time::timeout(Duration::from_millis(50), client.acquire_concurrency()).await;
        assert!(sixth.is_err());
    }

    #[tokio::test]
    async fn test_get_bytes_returns_body_verbatim() {
        let mut server = mockito::Server::new_async().await;
        // Arbitrary binary payload (zip magic plus non-UTF-8 bytes).
        let body: Vec<u8> = vec![0x50, 0x4B, 0x03, 0x04, 0x00, 0xFF, 0xFE, 0x42];
        let mock = server
            .mock("GET", "/v1/accounting/snapshot")
            .match_query(mockito::Matcher::UrlEncoded("user".into(), "0xabc".into()))
            .with_status(200)
            .with_header("content-type", "application/zip")
            .with_body(body.clone())
            .create_async()
            .await;

        let client = client_for(server.url());
        let query = [("user".to_string(), "0xabc".to_string())];
        let out = client
            .get_bytes("/v1/accounting/snapshot", &query)
            .await
            .unwrap();
        assert_eq!(out, body);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_bytes_maps_non_2xx_to_api_error() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/does-not-exist")
            .with_status(404)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "not found"}"#)
            .create_async()
            .await;

        let client = client_for(server.url());
        let failure = client.get_bytes("/does-not-exist", &[]).await.unwrap_err();
        match failure {
            ApiError::Api { status, message } => {
                assert_eq!(status, 404);
                assert_eq!(message, "not found");
            }
            other => panic!("expected ApiError::Api, got {other:?}"),
        }
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_bytes_no_query_params() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/raw")
            .with_status(200)
            .with_body(&b"hello"[..])
            .create_async()
            .await;

        let client = client_for(server.url());
        let out = client.get_bytes("/raw", &[]).await.unwrap();
        assert_eq!(out, b"hello");
        mock.assert_async().await;
    }
}