use std::sync::Arc;
use std::time::Duration;
use reqwest::StatusCode;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use url::Url;
use reqwest::header::RETRY_AFTER;
use crate::error::ApiError;
use crate::rate_limit::{RateLimiter, RetryConfig};
/// Extracts the `Retry-After` header of `response` as an owned string.
///
/// Returns `None` when the header is absent or when its value cannot be
/// represented as a string by `HeaderValue::to_str`.
pub fn retry_after_header(response: &reqwest::Response) -> Option<String> {
    let raw = response.headers().get(RETRY_AFTER)?;
    match raw.to_str() {
        Ok(text) => Some(text.to_owned()),
        Err(_) => None,
    }
}
/// Default request timeout in milliseconds; `HttpClientBuilder::new` starts
/// from this value and `build` passes it to reqwest's `timeout`.
pub const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Default value the builder passes to reqwest's `pool_max_idle_per_host`.
pub const DEFAULT_POOL_SIZE: usize = 10;
/// A `reqwest::Client` bundled with a base URL, a retry policy and optional
/// rate-limiting / concurrency-limiting state.
///
/// `Clone` is derived; the concurrency semaphore is held behind an `Arc`, so
/// clones of a client share the same set of permits.
#[derive(Debug, Clone)]
pub struct HttpClient {
/// Underlying reqwest client used to send requests.
pub client: reqwest::Client,
/// Parsed base URL supplied to the builder.
pub base_url: Url,
// Optional rate limiter awaited by `acquire_rate_limit`; `None` means that call is a no-op.
rate_limiter: Option<RateLimiter>,
// Retry policy consulted by `should_retry` (max retries, max backoff, backoff curve).
retry_config: RetryConfig,
// Optional semaphore bounding concurrent permits; `None` means no limit is applied.
concurrency_limiter: Option<Arc<Semaphore>>,
}
impl HttpClient {
pub async fn acquire_rate_limit(&self, path: &str, method: Option<&reqwest::Method>) {
if let Some(rl) = &self.rate_limiter {
rl.acquire(path, method).await;
}
}
pub async fn acquire_concurrency(&self) -> Option<OwnedSemaphorePermit> {
let sem = self.concurrency_limiter.as_ref()?;
Some(
sem.clone()
.acquire_owned()
.await
.expect("concurrency semaphore is never closed"),
)
}
pub fn should_retry(
&self,
status: StatusCode,
attempt: u32,
retry_after: Option<&str>,
) -> Option<Duration> {
if status == StatusCode::TOO_MANY_REQUESTS && attempt < self.retry_config.max_retries {
if let Some(delay) = retry_after.and_then(|v| v.parse::<f64>().ok()) {
let ms = (delay * 1000.0) as u64;
Some(Duration::from_millis(
ms.min(self.retry_config.max_backoff_ms),
))
} else {
Some(self.retry_config.backoff(attempt))
}
} else {
None
}
}
}
/// Builder collecting the settings from which `build` constructs an `HttpClient`.
pub struct HttpClientBuilder {
// Base URL as given by the caller; only parsed (and validated) in `build`.
base_url: String,
// Request timeout in milliseconds, passed to reqwest's `timeout`.
timeout_ms: u64,
// Passed to reqwest's `pool_max_idle_per_host`.
pool_size: usize,
// Optional rate limiter handed to the built client unchanged.
rate_limiter: Option<RateLimiter>,
// Retry policy handed to the built client unchanged.
retry_config: RetryConfig,
// When set, `build` creates a semaphore with this many permits.
max_concurrent: Option<usize>,
}
impl HttpClientBuilder {
    /// Starts a builder for `base_url` with the default timeout and pool
    /// size, the default retry configuration, and no rate or concurrency
    /// limiting. The URL is not validated until [`build`](Self::build).
    pub fn new(base_url: impl Into<String>) -> Self {
        let base_url = base_url.into();
        Self {
            base_url,
            timeout_ms: DEFAULT_TIMEOUT_MS,
            pool_size: DEFAULT_POOL_SIZE,
            rate_limiter: None,
            retry_config: RetryConfig::default(),
            max_concurrent: None,
        }
    }

    /// Sets the request timeout, in milliseconds.
    pub fn timeout_ms(self, timeout: u64) -> Self {
        Self {
            timeout_ms: timeout,
            ..self
        }
    }

    /// Sets the value passed to reqwest's `pool_max_idle_per_host`.
    pub fn pool_size(self, size: usize) -> Self {
        Self {
            pool_size: size,
            ..self
        }
    }

    /// Installs a rate limiter on the client being built.
    pub fn with_rate_limiter(self, limiter: RateLimiter) -> Self {
        Self {
            rate_limiter: Some(limiter),
            ..self
        }
    }

    /// Replaces the retry configuration.
    pub fn with_retry_config(self, config: RetryConfig) -> Self {
        Self {
            retry_config: config,
            ..self
        }
    }

    /// Limits the built client to `max` concurrency permits.
    pub fn with_max_concurrent(self, max: usize) -> Self {
        Self {
            max_concurrent: Some(max),
            ..self
        }
    }

    /// Consumes the builder and constructs the [`HttpClient`].
    ///
    /// The reqwest client is configured with the requested timeout, a fixed
    /// 10-second connect timeout, redirects disabled, and the requested idle
    /// pool size per host.
    ///
    /// # Errors
    ///
    /// Returns an `ApiError` if the reqwest client cannot be built or if the
    /// base URL fails to parse. The client is built first, so its error wins
    /// when both would fail.
    pub fn build(self) -> Result<HttpClient, ApiError> {
        let Self {
            base_url: raw_base_url,
            timeout_ms,
            pool_size,
            rate_limiter,
            retry_config,
            max_concurrent,
        } = self;

        let request_timeout = Duration::from_millis(timeout_ms);
        let connect_timeout = Duration::from_secs(10);
        let client = reqwest::Client::builder()
            .timeout(request_timeout)
            .connect_timeout(connect_timeout)
            .redirect(reqwest::redirect::Policy::none())
            .pool_max_idle_per_host(pool_size)
            .build()?;

        let base_url = Url::parse(&raw_base_url)?;

        // One semaphore per built client; clones of the client share it via `Arc`.
        let concurrency_limiter =
            max_concurrent.map(|permits| Arc::new(Semaphore::new(permits)));

        Ok(HttpClient {
            client,
            base_url,
            rate_limiter,
            retry_config,
            concurrency_limiter,
        })
    }
}
impl Default for HttpClientBuilder {
fn default() -> Self {
Self {
base_url: String::new(),
timeout_ms: DEFAULT_TIMEOUT_MS,
pool_size: DEFAULT_POOL_SIZE,
rate_limiter: None,
retry_config: RetryConfig::default(),
max_concurrent: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base URL shared by every test client.
    const BASE_URL: &str = "https://example.com";

    /// Client built with all builder defaults.
    fn default_client() -> HttpClient {
        HttpClientBuilder::new(BASE_URL).build().unwrap()
    }

    /// Client built with a concurrency limit of `max` permits.
    fn client_with_limit(max: usize) -> HttpClient {
        HttpClientBuilder::new(BASE_URL)
            .with_max_concurrent(max)
            .build()
            .unwrap()
    }

    /// Shorthand for `should_retry` on a 429 response.
    fn delay_for_429(
        client: &HttpClient,
        attempt: u32,
        retry_after: Option<&str>,
    ) -> Option<Duration> {
        client.should_retry(StatusCode::TOO_MANY_REQUESTS, attempt, retry_after)
    }

    #[test]
    fn test_should_retry_429_under_max() {
        let client = default_client();
        for attempt in [0, 2] {
            assert!(delay_for_429(&client, attempt, None).is_some());
        }
    }

    #[test]
    fn test_should_retry_429_at_max() {
        let client = default_client();
        assert!(delay_for_429(&client, 3, None).is_none());
    }

    #[test]
    fn test_should_retry_non_429_returns_none() {
        let client = default_client();
        let non_retryable = [
            StatusCode::OK,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
        ];
        for status in non_retryable {
            assert!(
                client.should_retry(status, 0, None).is_none(),
                "expected None for {status}"
            );
        }
    }

    #[test]
    fn test_should_retry_custom_config() {
        let single_retry = RetryConfig {
            max_retries: 1,
            ..RetryConfig::default()
        };
        let client = HttpClientBuilder::new(BASE_URL)
            .with_retry_config(single_retry)
            .build()
            .unwrap();
        assert!(delay_for_429(&client, 0, None).is_some());
        assert!(delay_for_429(&client, 1, None).is_none());
    }

    #[test]
    fn test_should_retry_uses_retry_after_header() {
        let client = default_client();
        let delay = delay_for_429(&client, 0, Some("2")).unwrap();
        assert_eq!(delay, Duration::from_millis(2000));
    }

    #[test]
    fn test_should_retry_retry_after_fractional_seconds() {
        let client = default_client();
        let delay = delay_for_429(&client, 0, Some("0.5")).unwrap();
        assert_eq!(delay, Duration::from_millis(500));
    }

    #[test]
    fn test_should_retry_retry_after_clamped_to_max_backoff() {
        let client = default_client();
        let delay = delay_for_429(&client, 0, Some("60")).unwrap();
        assert_eq!(delay, Duration::from_millis(10_000));
    }

    #[test]
    fn test_should_retry_retry_after_invalid_falls_back() {
        let client = default_client();
        // An HTTP-date is not a number of seconds, so the configured backoff applies.
        let delay = delay_for_429(&client, 0, Some("Wed, 21 Oct 2025 07:28:00 GMT")).unwrap();
        let ms = delay.as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "expected fallback backoff in [375, 625], got {ms}"
        );
    }

    #[tokio::test]
    async fn test_builder_with_rate_limiter() {
        let client = HttpClientBuilder::new(BASE_URL)
            .with_rate_limiter(RateLimiter::clob_default())
            .build()
            .unwrap();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_builder_without_rate_limiter() {
        let client = default_client();
        let started = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(started.elapsed() < Duration::from_millis(10));
    }

    #[tokio::test]
    async fn test_acquire_concurrency_none_when_not_configured() {
        let client = default_client();
        let permit = client.acquire_concurrency().await;
        assert!(permit.is_none());
    }

    #[tokio::test]
    async fn test_acquire_concurrency_returns_permit() {
        let client = client_with_limit(2);
        assert!(client.acquire_concurrency().await.is_some());
    }

    #[tokio::test]
    async fn test_concurrency_shared_across_clones() {
        let client = client_with_limit(1);
        let clone = client.clone();
        // Hold the only permit for the rest of the test.
        let _permit = client.acquire_concurrency().await.unwrap();
        let wait = Duration::from_millis(50);
        let result = tokio::time::timeout(wait, clone.acquire_concurrency()).await;
        assert!(result.is_err(), "clone should block when permit is held");
    }

    #[tokio::test]
    async fn test_concurrency_limits_parallel_tasks() {
        let client = client_with_limit(2);
        let start = std::time::Instant::now();
        // Four 50 ms tasks with two permits must run in at least two waves.
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let worker = client.clone();
                tokio::spawn(async move {
                    let _permit = worker.acquire_concurrency().await;
                    tokio::time::sleep(Duration::from_millis(50)).await;
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        assert!(
            start.elapsed() >= Duration::from_millis(90),
            "expected ~100ms, got {:?}",
            start.elapsed()
        );
    }

    #[tokio::test]
    async fn test_builder_with_max_concurrent() {
        let client = client_with_limit(5);
        let mut held = Vec::with_capacity(5);
        for _ in 0..5 {
            held.push(client.acquire_concurrency().await);
        }
        assert!(held.iter().all(Option::is_some));
        // All five permits are still held, so a sixth acquisition must time out.
        let wait = Duration::from_millis(50);
        let result = tokio::time::timeout(wait, client.acquire_concurrency()).await;
        assert!(result.is_err());
    }
}