1use async_trait::async_trait;
2use once_cell::sync::Lazy;
3use reqwest::{Client, Method, RequestBuilder, Response, StatusCode};
4use serde::{Serialize, de::DeserializeOwned};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tracing::{debug, error, info, warn};
10
11use crate::constants::USER_AGENT;
12use crate::utils::rate_limiter::app_non_trading_limiter;
13use crate::{config::Config, error::AppError, session::interface::IgSession};
14
15static API_SEMAPHORE: Lazy<Arc<Semaphore>> = Lazy::new(|| {
19 Arc::new(Semaphore::new(3)) });
21
22static RATE_LIMITED: Lazy<Arc<AtomicBool>> = Lazy::new(|| Arc::new(AtomicBool::new(false)));
24
25const DEFAULT_MAX_RETRIES: u32 = 10; const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1000; const DEFAULT_MAX_BACKOFF_MS: u64 = 60000; const DEFAULT_BACKOFF_FACTOR: f64 = 2.0; #[async_trait]
33pub trait IgHttpClient: Send + Sync {
34 async fn request<T, R>(
36 &self,
37 method: Method,
38 path: &str,
39 session: &IgSession,
40 body: Option<&T>,
41 version: &str,
42 ) -> Result<R, AppError>
43 where
44 for<'de> R: DeserializeOwned + 'static,
45 T: Serialize + Send + Sync + 'static;
46
47 async fn request_no_auth<T, R>(
49 &self,
50 method: Method,
51 path: &str,
52 body: Option<&T>,
53 version: &str,
54 ) -> Result<R, AppError>
55 where
56 for<'de> R: DeserializeOwned + 'static,
57 T: Serialize + Send + Sync + 'static;
58}
59
60pub struct IgHttpClientImpl {
62 config: Arc<Config>,
63 client: Client,
64 max_retries: u32,
65 initial_backoff_ms: u64,
66 max_backoff_ms: u64,
67 backoff_factor: f64,
68}
69
70impl IgHttpClientImpl {
71 pub fn new(config: Arc<Config>) -> Self {
73 let client = Client::builder()
74 .user_agent(USER_AGENT)
75 .timeout(Duration::from_secs(config.rest_api.timeout))
76 .build()
77 .expect("Failed to create HTTP client");
78
79 Self {
80 config,
81 client,
82 max_retries: DEFAULT_MAX_RETRIES,
83 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
84 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
85 backoff_factor: DEFAULT_BACKOFF_FACTOR,
86 }
87 }
88
89 pub fn with_retry_config(
91 mut self,
92 max_retries: u32,
93 initial_backoff_ms: u64,
94 max_backoff_ms: u64,
95 backoff_factor: f64,
96 ) -> Self {
97 self.max_retries = max_retries;
98 self.initial_backoff_ms = initial_backoff_ms;
99 self.max_backoff_ms = max_backoff_ms;
100 self.backoff_factor = backoff_factor;
101 self
102 }
103
104 fn calculate_backoff_duration(&self, retry_count: u32) -> Duration {
106 use rand::Rng;
107 let base_backoff_ms =
108 (self.initial_backoff_ms as f64 * self.backoff_factor.powi(retry_count as i32)) as u64;
109 let capped_backoff_ms = base_backoff_ms.min(self.max_backoff_ms);
110
111 let jitter_factor = rand::rng().random_range(0.8..1.2);
113 let jittered_backoff_ms = (capped_backoff_ms as f64 * jitter_factor) as u64;
114
115 Duration::from_millis(jittered_backoff_ms)
116 }
117
118 fn is_retryable_error(&self, error: &AppError) -> bool {
120 match error {
121 AppError::RateLimitExceeded => true,
122 AppError::Network(e) => {
123 e.is_timeout() || e.is_connect() || e.status().is_some_and(|s| s.is_server_error())
125 }
126 _ => false,
127 }
128 }
129
130 fn build_url(&self, path: &str) -> String {
132 format!(
133 "{}/{}",
134 self.config.rest_api.base_url.trim_end_matches('/'),
135 path.trim_start_matches('/')
136 )
137 }
138
139 fn add_common_headers(&self, builder: RequestBuilder, version: &str) -> RequestBuilder {
141 let api_key = self.config.credentials.api_key.trim();
142 debug!("Adding X-IG-API-KEY header (length: {})", api_key.len());
143 if api_key.is_empty() {
144 error!("API key is empty!");
145 }
146 builder
147 .header("X-IG-API-KEY", api_key)
148 .header("Content-Type", "application/json; charset=UTF-8")
149 .header("Accept", "application/json; charset=UTF-8")
150 .header("Version", version)
151 }
152
153 fn add_auth_headers(&self, builder: RequestBuilder, session: &IgSession) -> RequestBuilder {
155 if let Some(oauth_token) = &session.oauth_token {
157 debug!("Using OAuth authentication (Bearer token)");
160 debug!(
161 " Access token: {}...",
162 &oauth_token.access_token[..10.min(oauth_token.access_token.len())]
163 );
164 debug!(" Account ID: {}", session.account_id);
165 builder
166 .header(
167 "Authorization",
168 format!("Bearer {}", oauth_token.access_token),
169 )
170 .header("IG-ACCOUNT-ID", &session.account_id)
171 } else {
172 debug!("Using CST authentication");
174 debug!(
175 " CST length: {}, Token length: {}",
176 session.cst.len(),
177 session.token.len()
178 );
179 builder
180 .header("CST", &session.cst)
181 .header("X-SECURITY-TOKEN", &session.token)
182 }
183 }
184
185 async fn process_response<R>(&self, response: Response) -> Result<R, AppError>
187 where
188 for<'de> R: DeserializeOwned + 'static,
189 {
190 let status = response.status();
191 let url = response.url().to_string();
192
193 if status == StatusCode::TOO_MANY_REQUESTS {
195 self.handle_rate_limit(&url, "TOO_MANY_REQUESTS status code")
196 .await;
197 return Err(AppError::RateLimitExceeded);
198 }
199
200 match status {
201 StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED => {
202 let body = response.text().await?;
203 match serde_json::from_str::<R>(&body) {
204 Ok(data) => Ok(data),
205 Err(e) => {
206 error!("Error deserializing response from {}: {}", url, e);
207 error!("Response body: {}", body);
208 Err(AppError::Json(e))
209 }
210 }
211 }
212 StatusCode::UNAUTHORIZED => {
213 let body = response
214 .text()
215 .await
216 .unwrap_or_else(|_| "Unable to read response body".to_string());
217 error!("Unauthorized request to {}", url);
218 error!("Response body: {}", body);
219 Err(AppError::Unauthorized)
220 }
221 StatusCode::NOT_FOUND => {
222 error!("Resource not found at {}", url);
223 Err(AppError::NotFound)
224 }
225 StatusCode::FORBIDDEN => {
226 let body = response.text().await?;
227 if body.contains("exceeded-api-key-allowance")
228 || body.contains("exceeded-account-allowance")
229 {
230 self.handle_rate_limit(
231 &url,
232 "FORBIDDEN with exceeded-api-key-allowance or exceeded-account-allowance",
233 )
234 .await;
235 Err(AppError::RateLimitExceeded)
236 } else {
237 error!("Forbidden access to {}: {}", url, body);
238 Err(AppError::Unauthorized)
239 }
240 }
241 _ => {
242 let body = response.text().await?;
243 error!(
244 "Unexpected status code {} for request to {}: {}",
245 status, url, body
246 );
247 Err(AppError::Unexpected(status))
248 }
249 }
250 }
251
252 async fn handle_rate_limit(&self, url: &str, reason: &str) {
254 RATE_LIMITED.store(true, Ordering::SeqCst);
256 error!("Rate limit exceeded for request to {} ({})", url, reason);
257
258 let non_trading_limiter = app_non_trading_limiter();
261 non_trading_limiter.notify_rate_limit_exceeded().await;
262
263 let rate_limited = RATE_LIMITED.clone();
266 tokio::spawn(async move {
267 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
268 rate_limited.store(false, Ordering::SeqCst);
269 debug!("Rate limit flag reset after 60 second cooldown");
270 });
271 }
272}
273
274#[async_trait]
275impl IgHttpClient for IgHttpClientImpl {
276 async fn request<T, R>(
277 &self,
278 method: Method,
279 path: &str,
280 session: &IgSession,
281 body: Option<&T>,
282 version: &str,
283 ) -> Result<R, AppError>
284 where
285 for<'de> R: DeserializeOwned + 'static,
286 T: Serialize + Send + Sync + 'static,
287 {
288 let url = self.build_url(path);
289 let method_str = method.as_str().to_string(); debug!("Making {} request to {}", method_str, url);
291
292 let mut retry_count = 0;
293
294 loop {
296 if retry_count > 0 {
298 if retry_count > self.max_retries {
299 warn!(
300 "Max retries ({}) exceeded for {} request to {}",
301 self.max_retries, method_str, url
302 );
303 break; }
305
306 let backoff = self.calculate_backoff_duration(retry_count - 1);
308 debug!(
309 "Retry attempt {} for {} request to {}. Waiting for {:?} before retrying",
310 retry_count, method_str, url, backoff
311 );
312 tokio::time::sleep(backoff).await;
313 }
314
315 if RATE_LIMITED.load(Ordering::SeqCst) {
317 warn!("System is currently rate limited. Adding extra delay before request.");
318 let rate_limit_delay = 2000 + (retry_count * 1000) as u64;
321 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
322 }
323
324 let permit = API_SEMAPHORE.acquire().await.unwrap();
327 debug!(
328 "Acquired API semaphore permit for {} request to {}",
329 method_str, url
330 );
331
332 match session.respect_rate_limit().await {
335 Ok(()) => {}
336 Err(e) => {
337 drop(permit);
338 if self.is_retryable_error(&e) {
339 retry_count += 1;
340 continue;
341 }
342 return Err(e);
343 }
344 }
345
346 let mut builder = self.client.request(method.clone(), &url);
347 builder = self.add_common_headers(builder, version);
348 builder = self.add_auth_headers(builder, session);
349
350 if let Some(data) = body {
351 builder = builder.json(data);
352 }
353
354 let response_result = builder.send().await;
356
357 let response = match response_result {
359 Ok(resp) => resp,
360 Err(e) => {
361 error!("Network error for {} request to {}: {}", method_str, url, e);
362 drop(permit);
364
365 let app_error = AppError::Network(e);
367 if self.is_retryable_error(&app_error) {
368 retry_count += 1;
369 continue;
370 }
371 return Err(app_error);
372 }
373 };
374
375 let result = self.process_response::<R>(response).await;
377
378 if result.is_ok() {
380 session.refresh_token_timer();
382
383 if RATE_LIMITED.load(Ordering::SeqCst) {
384 RATE_LIMITED.store(false, Ordering::SeqCst);
385 debug!("Rate limit flag reset after successful request to {}", url);
386 }
387 }
388
389 drop(permit);
392
393 match &result {
395 Err(e) if self.is_retryable_error(e) => {
396 retry_count += 1;
397 continue;
398 }
399 _ => return result,
400 }
401 }
402
403 info!(
405 "Making final attempt for {} request to {} after max retries",
406 method_str, url
407 );
408
409 let permit = API_SEMAPHORE.acquire().await.unwrap();
411
412 session.respect_rate_limit().await?;
414
415 let mut builder = self.client.request(method, &url);
416 builder = self.add_common_headers(builder, version);
417 builder = self.add_auth_headers(builder, session);
418
419 if let Some(data) = body {
420 builder = builder.json(data);
421 }
422
423 let response = builder.send().await?;
424 let result = self.process_response::<R>(response).await;
425
426 if result.is_ok() {
428 session.refresh_token_timer();
429 }
430
431 drop(permit);
432 result
433 }
434
435 async fn request_no_auth<T, R>(
436 &self,
437 method: Method,
438 path: &str,
439 body: Option<&T>,
440 version: &str,
441 ) -> Result<R, AppError>
442 where
443 for<'de> R: DeserializeOwned + 'static,
444 T: Serialize + Send + Sync + 'static,
445 {
446 let url = self.build_url(path);
447 let method_str = method.as_str().to_string(); info!("Making unauthenticated {} request to {}", method_str, url);
449
450 let mut retry_count = 0;
451
452 loop {
454 if retry_count > 0 {
456 if retry_count > self.max_retries {
457 warn!(
458 "Max retries ({}) exceeded for unauthenticated {} request to {}",
459 self.max_retries, method_str, url
460 );
461 break; }
463
464 let backoff = self.calculate_backoff_duration(retry_count - 1);
466 debug!(
467 "Retry attempt {} for unauthenticated {} request to {}. Waiting for {:?} before retrying",
468 retry_count, method_str, url, backoff
469 );
470 tokio::time::sleep(backoff).await;
471 }
472
473 if RATE_LIMITED.load(Ordering::SeqCst) {
475 warn!(
476 "System is currently rate limited. Adding extra delay before unauthenticated request."
477 );
478 let rate_limit_delay = 1000 + (retry_count * 500) as u64;
481 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
482 }
483
484 let permit = API_SEMAPHORE.acquire().await.unwrap();
486 debug!(
487 "Acquired API semaphore permit for unauthenticated {} request to {}",
488 method_str, url
489 );
490
491 let limiter = app_non_trading_limiter();
494 limiter.wait().await;
495
496 let mut builder = self.client.request(method.clone(), &url);
497 builder = self.add_common_headers(builder, version);
498
499 if let Some(data) = body {
500 builder = builder.json(data);
501 }
502
503 let response_result = builder.send().await;
505
506 let response = match response_result {
508 Ok(resp) => resp,
509 Err(e) => {
510 error!(
511 "Network error for unauthenticated {} request to {}: {}",
512 method_str, url, e
513 );
514 drop(permit);
516
517 let app_error = AppError::Network(e);
519 if self.is_retryable_error(&app_error) {
520 retry_count += 1;
521 continue;
522 }
523 return Err(app_error);
524 }
525 };
526
527 let result = self.process_response::<R>(response).await;
529
530 if result.is_ok() && RATE_LIMITED.load(Ordering::SeqCst) {
532 RATE_LIMITED.store(false, Ordering::SeqCst);
533 info!(
534 "Rate limit flag reset after successful unauthenticated request to {}",
535 url
536 );
537 }
538
539 drop(permit);
541
542 match &result {
544 Err(e) if self.is_retryable_error(e) => {
545 retry_count += 1;
546 continue;
547 }
548 _ => return result,
549 }
550 }
551
552 info!(
554 "Making final attempt for unauthenticated {} request to {} after max retries",
555 method_str, url
556 );
557
558 let permit = API_SEMAPHORE.acquire().await.unwrap();
560
561 let limiter = app_non_trading_limiter();
563 limiter.wait().await;
564
565 let mut builder = self.client.request(method, &url);
566 builder = self.add_common_headers(builder, version);
567
568 if let Some(data) = body {
569 builder = builder.json(data);
570 }
571
572 let response = builder.send().await?;
573 let result = self.process_response::<R>(response).await;
574
575 drop(permit);
576 result
577 }
578}