1use async_trait::async_trait;
2use once_cell::sync::Lazy;
3use reqwest::{Client, Method, RequestBuilder, Response, StatusCode};
4use serde::{Serialize, de::DeserializeOwned};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tracing::{debug, error, info, warn};
10
11use crate::utils::rate_limiter::{app_non_trading_limiter, one_per_second_limiter};
12use crate::{config::Config, error::AppError, session::interface::IgSession};
13
14static API_SEMAPHORE: Lazy<Arc<Semaphore>> = Lazy::new(|| {
18 Arc::new(Semaphore::new(3)) });
20
21static RATE_LIMITED: Lazy<Arc<AtomicBool>> = Lazy::new(|| Arc::new(AtomicBool::new(false)));
23
24const DEFAULT_MAX_RETRIES: u32 = 10; const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1000; const DEFAULT_MAX_BACKOFF_MS: u64 = 60000; const DEFAULT_BACKOFF_FACTOR: f64 = 2.0; #[async_trait]
32pub trait IgHttpClient: Send + Sync {
33 async fn request<T, R>(
35 &self,
36 method: Method,
37 path: &str,
38 session: &IgSession,
39 body: Option<&T>,
40 version: &str,
41 ) -> Result<R, AppError>
42 where
43 for<'de> R: DeserializeOwned + 'static,
44 T: Serialize + Send + Sync + 'static;
45
46 async fn request_no_auth<T, R>(
48 &self,
49 method: Method,
50 path: &str,
51 body: Option<&T>,
52 version: &str,
53 ) -> Result<R, AppError>
54 where
55 for<'de> R: DeserializeOwned + 'static,
56 T: Serialize + Send + Sync + 'static;
57}
58
59pub struct IgHttpClientImpl {
61 config: Arc<Config>,
62 client: Client,
63 max_retries: u32,
64 initial_backoff_ms: u64,
65 max_backoff_ms: u64,
66 backoff_factor: f64,
67}
68
69impl IgHttpClientImpl {
70 pub fn new(config: Arc<Config>) -> Self {
72 let client = Client::builder()
73 .user_agent("ig-client/0.1.0")
74 .timeout(Duration::from_secs(config.rest_api.timeout))
75 .build()
76 .expect("Failed to create HTTP client");
77
78 Self {
79 config,
80 client,
81 max_retries: DEFAULT_MAX_RETRIES,
82 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
83 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
84 backoff_factor: DEFAULT_BACKOFF_FACTOR,
85 }
86 }
87
88 pub fn with_retry_config(
90 mut self,
91 max_retries: u32,
92 initial_backoff_ms: u64,
93 max_backoff_ms: u64,
94 backoff_factor: f64,
95 ) -> Self {
96 self.max_retries = max_retries;
97 self.initial_backoff_ms = initial_backoff_ms;
98 self.max_backoff_ms = max_backoff_ms;
99 self.backoff_factor = backoff_factor;
100 self
101 }
102
103 fn calculate_backoff_duration(&self, retry_count: u32) -> Duration {
105 use rand::Rng;
106 let base_backoff_ms =
107 (self.initial_backoff_ms as f64 * self.backoff_factor.powi(retry_count as i32)) as u64;
108 let capped_backoff_ms = base_backoff_ms.min(self.max_backoff_ms);
109
110 let jitter_factor = rand::rng().random_range(0.8..1.2);
112 let jittered_backoff_ms = (capped_backoff_ms as f64 * jitter_factor) as u64;
113
114 Duration::from_millis(jittered_backoff_ms)
115 }
116
117 fn is_retryable_error(&self, error: &AppError) -> bool {
119 match error {
120 AppError::RateLimitExceeded => true,
121 AppError::Network(e) => {
122 e.is_timeout() || e.is_connect() || e.status().is_some_and(|s| s.is_server_error())
124 }
125 _ => false,
126 }
127 }
128
129 fn build_url(&self, path: &str) -> String {
131 format!(
132 "{}/{}",
133 self.config.rest_api.base_url.trim_end_matches('/'),
134 path.trim_start_matches('/')
135 )
136 }
137
138 fn add_common_headers(&self, builder: RequestBuilder, version: &str) -> RequestBuilder {
140 builder
141 .header("X-IG-API-KEY", &self.config.credentials.api_key)
142 .header("Content-Type", "application/json; charset=UTF-8")
143 .header("Accept", "application/json; charset=UTF-8")
144 .header("Version", version)
145 }
146
147 fn add_auth_headers(&self, builder: RequestBuilder, session: &IgSession) -> RequestBuilder {
149 builder
150 .header("CST", &session.cst)
151 .header("X-SECURITY-TOKEN", &session.token)
152 }
153
154 async fn process_response<R>(&self, response: Response) -> Result<R, AppError>
156 where
157 for<'de> R: DeserializeOwned + 'static,
158 {
159 let status = response.status();
160 let url = response.url().to_string();
161
162 if status == StatusCode::TOO_MANY_REQUESTS {
164 self.handle_rate_limit(&url, "TOO_MANY_REQUESTS status code")
165 .await;
166 return Err(AppError::RateLimitExceeded);
167 }
168
169 match status {
170 StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED => {
171 let body = response.text().await?;
172 match serde_json::from_str::<R>(&body) {
173 Ok(data) => Ok(data),
174 Err(e) => {
175 error!("Error deserializing response from {}: {}", url, e);
176 error!("Response body: {}", body);
177 Err(AppError::Json(e))
178 }
179 }
180 }
181 StatusCode::UNAUTHORIZED => {
182 error!("Unauthorized request to {}", url);
183 Err(AppError::Unauthorized)
184 }
185 StatusCode::NOT_FOUND => {
186 error!("Resource not found at {}", url);
187 Err(AppError::NotFound)
188 }
189 StatusCode::FORBIDDEN => {
190 let body = response.text().await?;
191 if body.contains("exceeded-api-key-allowance")
192 || body.contains("exceeded-account-allowance")
193 {
194 self.handle_rate_limit(
195 &url,
196 "FORBIDDEN with exceeded-api-key-allowance or exceeded-account-allowance",
197 )
198 .await;
199 Err(AppError::RateLimitExceeded)
200 } else {
201 error!("Forbidden access to {}: {}", url, body);
202 Err(AppError::Unauthorized)
203 }
204 }
205 _ => {
206 let body = response.text().await?;
207 error!(
208 "Unexpected status code {} for request to {}: {}",
209 status, url, body
210 );
211 Err(AppError::Unexpected(status))
212 }
213 }
214 }
215
216 async fn handle_rate_limit(&self, url: &str, reason: &str) {
218 RATE_LIMITED.store(true, Ordering::SeqCst);
220 error!("Rate limit exceeded for request to {} ({})", url, reason);
221
222 let non_trading_limiter = app_non_trading_limiter();
225 non_trading_limiter.notify_rate_limit_exceeded().await;
226
227 let rate_limited = RATE_LIMITED.clone();
230 tokio::spawn(async move {
231 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
232 rate_limited.store(false, Ordering::SeqCst);
233 info!("Rate limit flag reset after 60 second cooldown");
234 });
235 }
236}
237
238#[async_trait]
239impl IgHttpClient for IgHttpClientImpl {
240 async fn request<T, R>(
241 &self,
242 method: Method,
243 path: &str,
244 session: &IgSession,
245 body: Option<&T>,
246 version: &str,
247 ) -> Result<R, AppError>
248 where
249 for<'de> R: DeserializeOwned + 'static,
250 T: Serialize + Send + Sync + 'static,
251 {
252 let url = self.build_url(path);
253 let method_str = method.as_str().to_string(); debug!("Making {} request to {}", method_str, url);
255
256 let mut retry_count = 0;
257
258 loop {
260 if retry_count > 0 {
262 if retry_count > self.max_retries {
263 warn!(
264 "Max retries ({}) exceeded for {} request to {}",
265 self.max_retries, method_str, url
266 );
267 break; }
269
270 let backoff = self.calculate_backoff_duration(retry_count - 1);
272 info!(
273 "Retry attempt {} for {} request to {}. Waiting for {:?} before retrying",
274 retry_count, method_str, url, backoff
275 );
276 tokio::time::sleep(backoff).await;
277 }
278
279 if RATE_LIMITED.load(Ordering::SeqCst) {
281 warn!("System is currently rate limited. Adding extra delay before request.");
282 let rate_limit_delay = 2000 + (retry_count * 1000) as u64;
285 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
286 }
287
288 let permit = API_SEMAPHORE.acquire().await.unwrap();
291 debug!(
292 "Acquired API semaphore permit for {} request to {}",
293 method_str, url
294 );
295
296 match session.respect_rate_limit().await {
299 Ok(()) => {}
300 Err(e) => {
301 drop(permit);
302 if self.is_retryable_error(&e) {
303 retry_count += 1;
304 continue;
305 }
306 return Err(e);
307 }
308 }
309
310 one_per_second_limiter().wait().await;
313
314 let mut builder = self.client.request(method.clone(), &url);
315 builder = self.add_common_headers(builder, version);
316 builder = self.add_auth_headers(builder, session);
317
318 if let Some(data) = body {
319 builder = builder.json(data);
320 }
321
322 let response_result = builder.send().await;
324
325 let response = match response_result {
327 Ok(resp) => resp,
328 Err(e) => {
329 error!("Network error for {} request to {}: {}", method_str, url, e);
330 drop(permit);
332
333 let app_error = AppError::Network(e);
335 if self.is_retryable_error(&app_error) {
336 retry_count += 1;
337 continue;
338 }
339 return Err(app_error);
340 }
341 };
342
343 let result = self.process_response::<R>(response).await;
345
346 if result.is_ok() && RATE_LIMITED.load(Ordering::SeqCst) {
348 RATE_LIMITED.store(false, Ordering::SeqCst);
349 info!("Rate limit flag reset after successful request to {}", url);
350 }
351
352 drop(permit);
355
356 match &result {
358 Err(e) if self.is_retryable_error(e) => {
359 retry_count += 1;
360 continue;
361 }
362 _ => return result,
363 }
364 }
365
366 info!(
368 "Making final attempt for {} request to {} after max retries",
369 method_str, url
370 );
371
372 let permit = API_SEMAPHORE.acquire().await.unwrap();
374
375 session.respect_rate_limit().await?;
377
378 one_per_second_limiter().wait().await;
380
381 let mut builder = self.client.request(method, &url);
382 builder = self.add_common_headers(builder, version);
383 builder = self.add_auth_headers(builder, session);
384
385 if let Some(data) = body {
386 builder = builder.json(data);
387 }
388
389 let response = builder.send().await?;
390 let result = self.process_response::<R>(response).await;
391
392 drop(permit);
393 result
394 }
395
396 async fn request_no_auth<T, R>(
397 &self,
398 method: Method,
399 path: &str,
400 body: Option<&T>,
401 version: &str,
402 ) -> Result<R, AppError>
403 where
404 for<'de> R: DeserializeOwned + 'static,
405 T: Serialize + Send + Sync + 'static,
406 {
407 let url = self.build_url(path);
408 let method_str = method.as_str().to_string(); info!("Making unauthenticated {} request to {}", method_str, url);
410
411 let mut retry_count = 0;
412
413 loop {
415 if retry_count > 0 {
417 if retry_count > self.max_retries {
418 warn!(
419 "Max retries ({}) exceeded for unauthenticated {} request to {}",
420 self.max_retries, method_str, url
421 );
422 break; }
424
425 let backoff = self.calculate_backoff_duration(retry_count - 1);
427 info!(
428 "Retry attempt {} for unauthenticated {} request to {}. Waiting for {:?} before retrying",
429 retry_count, method_str, url, backoff
430 );
431 tokio::time::sleep(backoff).await;
432 }
433
434 if RATE_LIMITED.load(Ordering::SeqCst) {
436 warn!(
437 "System is currently rate limited. Adding extra delay before unauthenticated request."
438 );
439 let rate_limit_delay = 1000 + (retry_count * 500) as u64;
442 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
443 }
444
445 let permit = API_SEMAPHORE.acquire().await.unwrap();
447 debug!(
448 "Acquired API semaphore permit for unauthenticated {} request to {}",
449 method_str, url
450 );
451
452 let limiter = app_non_trading_limiter();
455 limiter.wait().await;
456
457 one_per_second_limiter().wait().await;
460
461 let mut builder = self.client.request(method.clone(), &url);
462 builder = self.add_common_headers(builder, version);
463
464 if let Some(data) = body {
465 builder = builder.json(data);
466 }
467
468 let response_result = builder.send().await;
470
471 let response = match response_result {
473 Ok(resp) => resp,
474 Err(e) => {
475 error!(
476 "Network error for unauthenticated {} request to {}: {}",
477 method_str, url, e
478 );
479 drop(permit);
481
482 let app_error = AppError::Network(e);
484 if self.is_retryable_error(&app_error) {
485 retry_count += 1;
486 continue;
487 }
488 return Err(app_error);
489 }
490 };
491
492 let result = self.process_response::<R>(response).await;
494
495 if result.is_ok() && RATE_LIMITED.load(Ordering::SeqCst) {
497 RATE_LIMITED.store(false, Ordering::SeqCst);
498 info!(
499 "Rate limit flag reset after successful unauthenticated request to {}",
500 url
501 );
502 }
503
504 drop(permit);
506
507 match &result {
509 Err(e) if self.is_retryable_error(e) => {
510 retry_count += 1;
511 continue;
512 }
513 _ => return result,
514 }
515 }
516
517 info!(
519 "Making final attempt for unauthenticated {} request to {} after max retries",
520 method_str, url
521 );
522
523 let permit = API_SEMAPHORE.acquire().await.unwrap();
525
526 let limiter = app_non_trading_limiter();
528 limiter.wait().await;
529
530 one_per_second_limiter().wait().await;
532
533 let mut builder = self.client.request(method, &url);
534 builder = self.add_common_headers(builder, version);
535
536 if let Some(data) = body {
537 builder = builder.json(data);
538 }
539
540 let response = builder.send().await?;
541 let result = self.process_response::<R>(response).await;
542
543 drop(permit);
544 result
545 }
546}