1use std::time::Duration;
2
3use reqwest::StatusCode;
4use url::Url;
5
6use reqwest::header::RETRY_AFTER;
7
8use crate::error::ApiError;
9use crate::rate_limit::{RateLimiter, RetryConfig};
10
11pub fn retry_after_header(response: &reqwest::Response) -> Option<String> {
13 response
14 .headers()
15 .get(RETRY_AFTER)?
16 .to_str()
17 .ok()
18 .map(String::from)
19}
20
/// Default overall request timeout in milliseconds; `HttpClientBuilder`
/// starts with this value and hands it to `reqwest`'s `timeout`.
pub const DEFAULT_TIMEOUT_MS: u64 = 30_000;
/// Default value `HttpClientBuilder` passes to `reqwest`'s
/// `pool_max_idle_per_host`.
pub const DEFAULT_POOL_SIZE: usize = 10;
25
/// HTTP client wrapper that bundles a `reqwest::Client` with a base URL, an
/// optional rate limiter, and the retry settings consulted by `should_retry`.
///
/// Instances are produced by `HttpClientBuilder::build`.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// Underlying `reqwest` client.
    pub client: reqwest::Client,
    /// Base URL parsed from the string supplied to the builder.
    pub base_url: Url,
    /// Limiter awaited by `acquire_rate_limit`; `None` means no rate limiting.
    rate_limiter: Option<RateLimiter>,
    /// Retry limit and backoff settings read by `should_retry`.
    retry_config: RetryConfig,
}
39
40impl HttpClient {
41 pub async fn acquire_rate_limit(&self, path: &str, method: Option<&reqwest::Method>) {
43 if let Some(rl) = &self.rate_limiter {
44 rl.acquire(path, method).await;
45 }
46 }
47
48 pub fn should_retry(
53 &self,
54 status: StatusCode,
55 attempt: u32,
56 retry_after: Option<&str>,
57 ) -> Option<Duration> {
58 if status == StatusCode::TOO_MANY_REQUESTS && attempt < self.retry_config.max_retries {
59 if let Some(delay) = retry_after.and_then(|v| v.parse::<f64>().ok()) {
60 let ms = (delay * 1000.0) as u64;
61 Some(Duration::from_millis(
62 ms.min(self.retry_config.max_backoff_ms),
63 ))
64 } else {
65 Some(self.retry_config.backoff(attempt))
66 }
67 } else {
68 None
69 }
70 }
71}
72
/// Builder for `HttpClient`.
///
/// Collects the settings below and turns them into a configured client in
/// `build`.
pub struct HttpClientBuilder {
    /// Base URL as an unparsed string; parsed with `Url::parse` in `build`.
    base_url: String,
    /// Overall request timeout in milliseconds.
    timeout_ms: u64,
    /// Value passed to `reqwest`'s `pool_max_idle_per_host`.
    pool_size: usize,
    /// Optional rate limiter handed over to the built client.
    rate_limiter: Option<RateLimiter>,
    /// Retry settings handed over to the built client.
    retry_config: RetryConfig,
}
96
97impl HttpClientBuilder {
98 pub fn new(base_url: impl Into<String>) -> Self {
100 Self {
101 base_url: base_url.into(),
102 timeout_ms: DEFAULT_TIMEOUT_MS,
103 pool_size: DEFAULT_POOL_SIZE,
104 rate_limiter: None,
105 retry_config: RetryConfig::default(),
106 }
107 }
108
109 pub fn timeout_ms(mut self, timeout: u64) -> Self {
113 self.timeout_ms = timeout;
114 self
115 }
116
117 pub fn pool_size(mut self, size: usize) -> Self {
121 self.pool_size = size;
122 self
123 }
124
125 pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
127 self.rate_limiter = Some(limiter);
128 self
129 }
130
131 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
133 self.retry_config = config;
134 self
135 }
136
137 pub fn build(self) -> Result<HttpClient, ApiError> {
139 let client = reqwest::Client::builder()
140 .timeout(Duration::from_millis(self.timeout_ms))
141 .connect_timeout(Duration::from_secs(10))
142 .redirect(reqwest::redirect::Policy::none())
143 .pool_max_idle_per_host(self.pool_size)
144 .build()?;
145
146 let base_url = Url::parse(&self.base_url)?;
147
148 Ok(HttpClient {
149 client,
150 base_url,
151 rate_limiter: self.rate_limiter,
152 retry_config: self.retry_config,
153 })
154 }
155}
156
157impl Default for HttpClientBuilder {
158 fn default() -> Self {
159 Self {
160 base_url: String::new(),
161 timeout_ms: DEFAULT_TIMEOUT_MS,
162 pool_size: DEFAULT_POOL_SIZE,
163 rate_limiter: None,
164 retry_config: RetryConfig::default(),
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Base URL shared by every test client.
    const BASE_URL: &str = "https://example.com";

    /// Builds a client with all-default builder settings.
    fn default_client() -> HttpClient {
        HttpClientBuilder::new(BASE_URL).build().unwrap()
    }

    /// Shorthand for asking `should_retry` about a 429 response.
    fn retry_429(
        client: &HttpClient,
        attempt: u32,
        retry_after: Option<&str>,
    ) -> Option<Duration> {
        client.should_retry(StatusCode::TOO_MANY_REQUESTS, attempt, retry_after)
    }

    // The next two tests expect the default `RetryConfig` to allow retries
    // for attempts 0..=2 and to stop at attempt 3.
    #[test]
    fn test_should_retry_429_under_max() {
        let client = default_client();
        assert!(retry_429(&client, 0, None).is_some());
        assert!(retry_429(&client, 2, None).is_some());
    }

    #[test]
    fn test_should_retry_429_at_max() {
        let client = default_client();
        assert!(retry_429(&client, 3, None).is_none());
    }

    /// Statuses other than 429 are never retried, even on the first attempt.
    #[test]
    fn test_should_retry_non_429_returns_none() {
        let client = default_client();
        let non_retryable = [
            StatusCode::OK,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
        ];
        for status in non_retryable {
            let outcome = client.should_retry(status, 0, None);
            assert!(outcome.is_none(), "expected None for {status}");
        }
    }

    /// A custom `max_retries` of 1 permits attempt 0 only.
    #[test]
    fn test_should_retry_custom_config() {
        let single_retry = RetryConfig {
            max_retries: 1,
            ..RetryConfig::default()
        };
        let client = HttpClientBuilder::new(BASE_URL)
            .with_retry_config(single_retry)
            .build()
            .unwrap();
        assert!(retry_429(&client, 0, None).is_some());
        assert!(retry_429(&client, 1, None).is_none());
    }

    /// Integer seconds in `Retry-After` are converted to milliseconds.
    #[test]
    fn test_should_retry_uses_retry_after_header() {
        let client = default_client();
        let d = retry_429(&client, 0, Some("2")).unwrap();
        assert_eq!(d, Duration::from_millis(2000));
    }

    /// Fractional seconds in `Retry-After` are honoured.
    #[test]
    fn test_should_retry_retry_after_fractional_seconds() {
        let client = default_client();
        let d = retry_429(&client, 0, Some("0.5")).unwrap();
        assert_eq!(d, Duration::from_millis(500));
    }

    /// A 60 s `Retry-After` is clamped; the expected 10 000 ms assumes that is
    /// the default `max_backoff_ms`.
    #[test]
    fn test_should_retry_retry_after_clamped_to_max_backoff() {
        let client = default_client();
        let d = retry_429(&client, 0, Some("60")).unwrap();
        assert_eq!(d, Duration::from_millis(10_000));
    }

    /// An HTTP-date `Retry-After` is not numeric, so the configured backoff is
    /// used instead. The accepted window presumably reflects a 500 ms base
    /// delay with jitter — verify against `RetryConfig::backoff`.
    #[test]
    fn test_should_retry_retry_after_invalid_falls_back() {
        let client = default_client();
        let d = retry_429(&client, 0, Some("Wed, 21 Oct 2025 07:28:00 GMT")).unwrap();
        let ms = d.as_millis() as u64;
        assert!(
            (375..=625).contains(&ms),
            "expected fallback backoff in [375, 625], got {ms}"
        );
    }

    /// With a limiter configured, the first acquire should not block noticeably.
    #[tokio::test]
    async fn test_builder_with_rate_limiter() {
        let client = HttpClientBuilder::new(BASE_URL)
            .with_rate_limiter(RateLimiter::clob_default())
            .build()
            .unwrap();
        let post = reqwest::Method::POST;
        let start = std::time::Instant::now();
        client.acquire_rate_limit("/order", Some(&post)).await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    /// Without a limiter, acquiring is a no-op and returns almost instantly.
    #[tokio::test]
    async fn test_builder_without_rate_limiter() {
        let client = default_client();
        let post = reqwest::Method::POST;
        let start = std::time::Instant::now();
        client.acquire_rate_limit("/order", Some(&post)).await;
        assert!(start.elapsed() < Duration::from_millis(10));
    }
}