1use std::time::Duration;
2
3use reqwest::StatusCode;
4use url::Url;
5
6use crate::error::ApiError;
7use crate::rate_limit::{RateLimiter, RetryConfig};
8
/// Total request timeout, in milliseconds, used when the builder's
/// `timeout_ms` setter is never called (30 seconds).
pub const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;

/// Idle-connection limit per host, used when the builder's `pool_size`
/// setter is never called.
pub const DEFAULT_POOL_SIZE: usize = 10;
13
/// HTTP client bound to a single base URL, with optional client-side rate
/// limiting and a retry policy for throttled (`429`) responses.
///
/// Construct it with `HttpClientBuilder`.
///
/// NOTE(review): `Clone` copies `rate_limiter`; whether clones share limiter
/// state or get independent budgets depends on `RateLimiter`'s `Clone` impl —
/// confirm before handing clones to concurrent tasks.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// Underlying reqwest client, built with the timeout and idle-pool
    /// settings taken from the builder.
    pub client: reqwest::Client,
    /// Base URL parsed from the builder's string input.
    pub base_url: Url,
    /// Limiter awaited by `acquire_rate_limit`; `None` means no throttling.
    rate_limiter: Option<RateLimiter>,
    /// Retry budget and backoff used by `should_retry`.
    retry_config: RetryConfig,
}
27
28impl HttpClient {
29 pub async fn acquire_rate_limit(&self, path: &str, method: Option<&reqwest::Method>) {
31 if let Some(rl) = &self.rate_limiter {
32 rl.acquire(path, method).await;
33 }
34 }
35
36 pub fn should_retry(&self, status: StatusCode, attempt: u32) -> Option<Duration> {
38 if status == StatusCode::TOO_MANY_REQUESTS && attempt < self.retry_config.max_retries {
39 Some(self.retry_config.backoff(attempt))
40 } else {
41 None
42 }
43 }
44}
45
46pub struct HttpClientBuilder {
63 base_url: String,
64 timeout_ms: u64,
65 pool_size: usize,
66 rate_limiter: Option<RateLimiter>,
67 retry_config: RetryConfig,
68}
69
70impl HttpClientBuilder {
71 pub fn new(base_url: impl Into<String>) -> Self {
73 Self {
74 base_url: base_url.into(),
75 timeout_ms: DEFAULT_TIMEOUT_MS,
76 pool_size: DEFAULT_POOL_SIZE,
77 rate_limiter: None,
78 retry_config: RetryConfig::default(),
79 }
80 }
81
82 pub fn timeout_ms(mut self, timeout: u64) -> Self {
86 self.timeout_ms = timeout;
87 self
88 }
89
90 pub fn pool_size(mut self, size: usize) -> Self {
94 self.pool_size = size;
95 self
96 }
97
98 pub fn with_rate_limiter(mut self, limiter: RateLimiter) -> Self {
100 self.rate_limiter = Some(limiter);
101 self
102 }
103
104 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
106 self.retry_config = config;
107 self
108 }
109
110 pub fn build(self) -> Result<HttpClient, ApiError> {
112 let client = reqwest::Client::builder()
113 .timeout(Duration::from_millis(self.timeout_ms))
114 .pool_max_idle_per_host(self.pool_size)
115 .build()?;
116
117 let base_url = Url::parse(&self.base_url)?;
118
119 Ok(HttpClient {
120 client,
121 base_url,
122 rate_limiter: self.rate_limiter,
123 retry_config: self.retry_config,
124 })
125 }
126}
127
128impl Default for HttpClientBuilder {
129 fn default() -> Self {
130 Self {
131 base_url: String::new(),
132 timeout_ms: DEFAULT_TIMEOUT_MS,
133 pool_size: DEFAULT_POOL_SIZE,
134 rate_limiter: None,
135 retry_config: RetryConfig::default(),
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on wall-clock time for an `acquire_rate_limit` call that is
    /// expected to return without waiting.
    ///
    /// Deliberately generous: these checks only need to show the call did not
    /// block, and bounds of a few milliseconds flake on loaded CI runners.
    const NON_BLOCKING_BUDGET: Duration = Duration::from_millis(50);

    /// Builds a client for a fixed, valid base URL with default settings.
    fn default_client() -> HttpClient {
        HttpClientBuilder::new("https://example.com")
            .build()
            .expect("client with a static, valid base URL should build")
    }

    // The retry-boundary tests below rely on `RetryConfig::default().max_retries`
    // being 3.

    #[test]
    fn test_should_retry_429_under_max() {
        let client = default_client();
        assert!(client
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0)
            .is_some());
        assert!(client
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 2)
            .is_some());
    }

    #[test]
    fn test_should_retry_429_at_max() {
        let client = default_client();
        assert!(client
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 3)
            .is_none());
    }

    #[test]
    fn test_should_retry_non_429_returns_none() {
        let client = default_client();
        for status in [
            StatusCode::OK,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_REQUEST,
            StatusCode::FORBIDDEN,
        ] {
            assert!(
                client.should_retry(status, 0).is_none(),
                "expected None for {status}"
            );
        }
    }

    #[test]
    fn test_should_retry_custom_config() {
        let client = HttpClientBuilder::new("https://example.com")
            .with_retry_config(RetryConfig {
                max_retries: 1,
                ..RetryConfig::default()
            })
            .build()
            .unwrap();
        assert!(client
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 0)
            .is_some());
        assert!(client
            .should_retry(StatusCode::TOO_MANY_REQUESTS, 1)
            .is_none());
    }

    #[test]
    fn test_build_rejects_invalid_base_url() {
        assert!(HttpClientBuilder::new("not a url").build().is_err());
    }

    #[test]
    fn test_default_builder_without_base_url_fails_to_build() {
        // `Default` leaves `base_url` empty, which is not an absolute URL.
        assert!(HttpClientBuilder::default().build().is_err());
    }

    #[tokio::test]
    async fn test_builder_with_rate_limiter() {
        let client = HttpClientBuilder::new("https://example.com")
            .with_rate_limiter(RateLimiter::clob_default())
            .build()
            .unwrap();
        // The first acquisition on a fresh limiter is expected not to wait.
        let start = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(start.elapsed() < NON_BLOCKING_BUDGET);
    }

    #[tokio::test]
    async fn test_builder_without_rate_limiter() {
        let client = default_client();
        let start = std::time::Instant::now();
        client
            .acquire_rate_limit("/order", Some(&reqwest::Method::POST))
            .await;
        assert!(start.elapsed() < NON_BLOCKING_BUDGET);
    }
}