1use async_trait::async_trait;
2use once_cell::sync::Lazy;
3use reqwest::{Client, Method, RequestBuilder, Response, StatusCode};
4use serde::{Serialize, de::DeserializeOwned};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tracing::{debug, error, info, warn};
10
11use crate::constants::USER_AGENT;
12use crate::utils::rate_limiter::app_non_trading_limiter;
13use crate::{config::Config, error::AppError, session::interface::IgSession};
14
15static API_SEMAPHORE: Lazy<Arc<Semaphore>> = Lazy::new(|| {
19 Arc::new(Semaphore::new(3)) });
21
22static RATE_LIMITED: Lazy<Arc<AtomicBool>> = Lazy::new(|| Arc::new(AtomicBool::new(false)));
24
25const DEFAULT_MAX_RETRIES: u32 = 10; const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1000; const DEFAULT_MAX_BACKOFF_MS: u64 = 60000; const DEFAULT_BACKOFF_FACTOR: f64 = 2.0; #[async_trait]
33pub trait IgHttpClient: Send + Sync {
34 async fn request<T, R>(
36 &self,
37 method: Method,
38 path: &str,
39 session: &IgSession,
40 body: Option<&T>,
41 version: &str,
42 ) -> Result<R, AppError>
43 where
44 for<'de> R: DeserializeOwned + 'static,
45 T: Serialize + Send + Sync + 'static;
46
47 async fn request_no_auth<T, R>(
49 &self,
50 method: Method,
51 path: &str,
52 body: Option<&T>,
53 version: &str,
54 ) -> Result<R, AppError>
55 where
56 for<'de> R: DeserializeOwned + 'static,
57 T: Serialize + Send + Sync + 'static;
58}
59
60pub struct IgHttpClientImpl {
62 config: Arc<Config>,
63 client: Client,
64 max_retries: u32,
65 initial_backoff_ms: u64,
66 max_backoff_ms: u64,
67 backoff_factor: f64,
68}
69
70impl IgHttpClientImpl {
71 pub fn new(config: Arc<Config>) -> Self {
73 let client = Client::builder()
74 .user_agent(USER_AGENT)
75 .timeout(Duration::from_secs(config.rest_api.timeout))
76 .build()
77 .expect("Failed to create HTTP client");
78
79 Self {
80 config,
81 client,
82 max_retries: DEFAULT_MAX_RETRIES,
83 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
84 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
85 backoff_factor: DEFAULT_BACKOFF_FACTOR,
86 }
87 }
88
89 pub fn with_retry_config(
91 mut self,
92 max_retries: u32,
93 initial_backoff_ms: u64,
94 max_backoff_ms: u64,
95 backoff_factor: f64,
96 ) -> Self {
97 self.max_retries = max_retries;
98 self.initial_backoff_ms = initial_backoff_ms;
99 self.max_backoff_ms = max_backoff_ms;
100 self.backoff_factor = backoff_factor;
101 self
102 }
103
104 fn calculate_backoff_duration(&self, retry_count: u32) -> Duration {
106 use rand::Rng;
107 let base_backoff_ms =
108 (self.initial_backoff_ms as f64 * self.backoff_factor.powi(retry_count as i32)) as u64;
109 let capped_backoff_ms = base_backoff_ms.min(self.max_backoff_ms);
110
111 let jitter_factor = rand::rng().random_range(0.8..1.2);
113 let jittered_backoff_ms = (capped_backoff_ms as f64 * jitter_factor) as u64;
114
115 Duration::from_millis(jittered_backoff_ms)
116 }
117
118 fn is_retryable_error(&self, error: &AppError) -> bool {
120 match error {
121 AppError::RateLimitExceeded => true,
122 AppError::Network(e) => {
123 e.is_timeout() || e.is_connect() || e.status().is_some_and(|s| s.is_server_error())
125 }
126 _ => false,
127 }
128 }
129
130 fn build_url(&self, path: &str) -> String {
132 format!(
133 "{}/{}",
134 self.config.rest_api.base_url.trim_end_matches('/'),
135 path.trim_start_matches('/')
136 )
137 }
138
139 fn add_common_headers(&self, builder: RequestBuilder, version: &str) -> RequestBuilder {
141 builder
142 .header("X-IG-API-KEY", &self.config.credentials.api_key)
143 .header("Content-Type", "application/json; charset=UTF-8")
144 .header("Accept", "application/json; charset=UTF-8")
145 .header("Version", version)
146 }
147
148 fn add_auth_headers(&self, builder: RequestBuilder, session: &IgSession) -> RequestBuilder {
150 builder
151 .header("CST", &session.cst)
152 .header("X-SECURITY-TOKEN", &session.token)
153 }
154
155 async fn process_response<R>(&self, response: Response) -> Result<R, AppError>
157 where
158 for<'de> R: DeserializeOwned + 'static,
159 {
160 let status = response.status();
161 let url = response.url().to_string();
162
163 if status == StatusCode::TOO_MANY_REQUESTS {
165 self.handle_rate_limit(&url, "TOO_MANY_REQUESTS status code")
166 .await;
167 return Err(AppError::RateLimitExceeded);
168 }
169
170 match status {
171 StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED => {
172 let body = response.text().await?;
173 match serde_json::from_str::<R>(&body) {
174 Ok(data) => Ok(data),
175 Err(e) => {
176 error!("Error deserializing response from {}: {}", url, e);
177 error!("Response body: {}", body);
178 Err(AppError::Json(e))
179 }
180 }
181 }
182 StatusCode::UNAUTHORIZED => {
183 error!("Unauthorized request to {}", url);
184 Err(AppError::Unauthorized)
185 }
186 StatusCode::NOT_FOUND => {
187 error!("Resource not found at {}", url);
188 Err(AppError::NotFound)
189 }
190 StatusCode::FORBIDDEN => {
191 let body = response.text().await?;
192 if body.contains("exceeded-api-key-allowance")
193 || body.contains("exceeded-account-allowance")
194 {
195 self.handle_rate_limit(
196 &url,
197 "FORBIDDEN with exceeded-api-key-allowance or exceeded-account-allowance",
198 )
199 .await;
200 Err(AppError::RateLimitExceeded)
201 } else {
202 error!("Forbidden access to {}: {}", url, body);
203 Err(AppError::Unauthorized)
204 }
205 }
206 _ => {
207 let body = response.text().await?;
208 error!(
209 "Unexpected status code {} for request to {}: {}",
210 status, url, body
211 );
212 Err(AppError::Unexpected(status))
213 }
214 }
215 }
216
217 async fn handle_rate_limit(&self, url: &str, reason: &str) {
219 RATE_LIMITED.store(true, Ordering::SeqCst);
221 error!("Rate limit exceeded for request to {} ({})", url, reason);
222
223 let non_trading_limiter = app_non_trading_limiter();
226 non_trading_limiter.notify_rate_limit_exceeded().await;
227
228 let rate_limited = RATE_LIMITED.clone();
231 tokio::spawn(async move {
232 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
233 rate_limited.store(false, Ordering::SeqCst);
234 info!("Rate limit flag reset after 60 second cooldown");
235 });
236 }
237}
238
239#[async_trait]
240impl IgHttpClient for IgHttpClientImpl {
241 async fn request<T, R>(
242 &self,
243 method: Method,
244 path: &str,
245 session: &IgSession,
246 body: Option<&T>,
247 version: &str,
248 ) -> Result<R, AppError>
249 where
250 for<'de> R: DeserializeOwned + 'static,
251 T: Serialize + Send + Sync + 'static,
252 {
253 let url = self.build_url(path);
254 let method_str = method.as_str().to_string(); debug!("Making {} request to {}", method_str, url);
256
257 let mut retry_count = 0;
258
259 loop {
261 if retry_count > 0 {
263 if retry_count > self.max_retries {
264 warn!(
265 "Max retries ({}) exceeded for {} request to {}",
266 self.max_retries, method_str, url
267 );
268 break; }
270
271 let backoff = self.calculate_backoff_duration(retry_count - 1);
273 debug!(
274 "Retry attempt {} for {} request to {}. Waiting for {:?} before retrying",
275 retry_count, method_str, url, backoff
276 );
277 tokio::time::sleep(backoff).await;
278 }
279
280 if RATE_LIMITED.load(Ordering::SeqCst) {
282 warn!("System is currently rate limited. Adding extra delay before request.");
283 let rate_limit_delay = 2000 + (retry_count * 1000) as u64;
286 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
287 }
288
289 let permit = API_SEMAPHORE.acquire().await.unwrap();
292 debug!(
293 "Acquired API semaphore permit for {} request to {}",
294 method_str, url
295 );
296
297 match session.respect_rate_limit().await {
300 Ok(()) => {}
301 Err(e) => {
302 drop(permit);
303 if self.is_retryable_error(&e) {
304 retry_count += 1;
305 continue;
306 }
307 return Err(e);
308 }
309 }
310
311 let mut builder = self.client.request(method.clone(), &url);
312 builder = self.add_common_headers(builder, version);
313 builder = self.add_auth_headers(builder, session);
314
315 if let Some(data) = body {
316 builder = builder.json(data);
317 }
318
319 let response_result = builder.send().await;
321
322 let response = match response_result {
324 Ok(resp) => resp,
325 Err(e) => {
326 error!("Network error for {} request to {}: {}", method_str, url, e);
327 drop(permit);
329
330 let app_error = AppError::Network(e);
332 if self.is_retryable_error(&app_error) {
333 retry_count += 1;
334 continue;
335 }
336 return Err(app_error);
337 }
338 };
339
340 let result = self.process_response::<R>(response).await;
342
343 if result.is_ok() {
345 session.refresh_token_timer();
347
348 if RATE_LIMITED.load(Ordering::SeqCst) {
349 RATE_LIMITED.store(false, Ordering::SeqCst);
350 info!("Rate limit flag reset after successful request to {}", url);
351 }
352 }
353
354 drop(permit);
357
358 match &result {
360 Err(e) if self.is_retryable_error(e) => {
361 retry_count += 1;
362 continue;
363 }
364 _ => return result,
365 }
366 }
367
368 info!(
370 "Making final attempt for {} request to {} after max retries",
371 method_str, url
372 );
373
374 let permit = API_SEMAPHORE.acquire().await.unwrap();
376
377 session.respect_rate_limit().await?;
379
380 let mut builder = self.client.request(method, &url);
381 builder = self.add_common_headers(builder, version);
382 builder = self.add_auth_headers(builder, session);
383
384 if let Some(data) = body {
385 builder = builder.json(data);
386 }
387
388 let response = builder.send().await?;
389 let result = self.process_response::<R>(response).await;
390
391 if result.is_ok() {
393 session.refresh_token_timer();
394 }
395
396 drop(permit);
397 result
398 }
399
400 async fn request_no_auth<T, R>(
401 &self,
402 method: Method,
403 path: &str,
404 body: Option<&T>,
405 version: &str,
406 ) -> Result<R, AppError>
407 where
408 for<'de> R: DeserializeOwned + 'static,
409 T: Serialize + Send + Sync + 'static,
410 {
411 let url = self.build_url(path);
412 let method_str = method.as_str().to_string(); info!("Making unauthenticated {} request to {}", method_str, url);
414
415 let mut retry_count = 0;
416
417 loop {
419 if retry_count > 0 {
421 if retry_count > self.max_retries {
422 warn!(
423 "Max retries ({}) exceeded for unauthenticated {} request to {}",
424 self.max_retries, method_str, url
425 );
426 break; }
428
429 let backoff = self.calculate_backoff_duration(retry_count - 1);
431 debug!(
432 "Retry attempt {} for unauthenticated {} request to {}. Waiting for {:?} before retrying",
433 retry_count, method_str, url, backoff
434 );
435 tokio::time::sleep(backoff).await;
436 }
437
438 if RATE_LIMITED.load(Ordering::SeqCst) {
440 warn!(
441 "System is currently rate limited. Adding extra delay before unauthenticated request."
442 );
443 let rate_limit_delay = 1000 + (retry_count * 500) as u64;
446 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
447 }
448
449 let permit = API_SEMAPHORE.acquire().await.unwrap();
451 debug!(
452 "Acquired API semaphore permit for unauthenticated {} request to {}",
453 method_str, url
454 );
455
456 let limiter = app_non_trading_limiter();
459 limiter.wait().await;
460
461 let mut builder = self.client.request(method.clone(), &url);
462 builder = self.add_common_headers(builder, version);
463
464 if let Some(data) = body {
465 builder = builder.json(data);
466 }
467
468 let response_result = builder.send().await;
470
471 let response = match response_result {
473 Ok(resp) => resp,
474 Err(e) => {
475 error!(
476 "Network error for unauthenticated {} request to {}: {}",
477 method_str, url, e
478 );
479 drop(permit);
481
482 let app_error = AppError::Network(e);
484 if self.is_retryable_error(&app_error) {
485 retry_count += 1;
486 continue;
487 }
488 return Err(app_error);
489 }
490 };
491
492 let result = self.process_response::<R>(response).await;
494
495 if result.is_ok() && RATE_LIMITED.load(Ordering::SeqCst) {
497 RATE_LIMITED.store(false, Ordering::SeqCst);
498 info!(
499 "Rate limit flag reset after successful unauthenticated request to {}",
500 url
501 );
502 }
503
504 drop(permit);
506
507 match &result {
509 Err(e) if self.is_retryable_error(e) => {
510 retry_count += 1;
511 continue;
512 }
513 _ => return result,
514 }
515 }
516
517 info!(
519 "Making final attempt for unauthenticated {} request to {} after max retries",
520 method_str, url
521 );
522
523 let permit = API_SEMAPHORE.acquire().await.unwrap();
525
526 let limiter = app_non_trading_limiter();
528 limiter.wait().await;
529
530 let mut builder = self.client.request(method, &url);
531 builder = self.add_common_headers(builder, version);
532
533 if let Some(data) = body {
534 builder = builder.json(data);
535 }
536
537 let response = builder.send().await?;
538 let result = self.process_response::<R>(response).await;
539
540 drop(permit);
541 result
542 }
543}