1use async_trait::async_trait;
2use once_cell::sync::Lazy;
3use reqwest::{Client, Method, RequestBuilder, Response, StatusCode};
4use serde::{Serialize, de::DeserializeOwned};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tracing::{debug, error, info, warn};
10
11use crate::constants::USER_AGENT;
12use crate::utils::rate_limiter::app_non_trading_limiter;
13use crate::{config::Config, error::AppError, session::interface::IgSession};
14
15static API_SEMAPHORE: Lazy<Arc<Semaphore>> = Lazy::new(|| {
19 Arc::new(Semaphore::new(3)) });
21
22static RATE_LIMITED: Lazy<Arc<AtomicBool>> = Lazy::new(|| Arc::new(AtomicBool::new(false)));
24
25const DEFAULT_MAX_RETRIES: u32 = 10; const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1000; const DEFAULT_MAX_BACKOFF_MS: u64 = 60000; const DEFAULT_BACKOFF_FACTOR: f64 = 2.0; #[async_trait]
33pub trait IgHttpClient: Send + Sync {
34 async fn request<T, R>(
36 &self,
37 method: Method,
38 path: &str,
39 session: &IgSession,
40 body: Option<&T>,
41 version: &str,
42 ) -> Result<R, AppError>
43 where
44 for<'de> R: DeserializeOwned + 'static,
45 T: Serialize + Send + Sync + 'static;
46
47 async fn request_no_auth<T, R>(
49 &self,
50 method: Method,
51 path: &str,
52 body: Option<&T>,
53 version: &str,
54 ) -> Result<R, AppError>
55 where
56 for<'de> R: DeserializeOwned + 'static,
57 T: Serialize + Send + Sync + 'static;
58}
59
60pub struct IgHttpClientImpl {
62 config: Arc<Config>,
63 client: Client,
64 max_retries: u32,
65 initial_backoff_ms: u64,
66 max_backoff_ms: u64,
67 backoff_factor: f64,
68}
69
70impl IgHttpClientImpl {
71 pub fn new(config: Arc<Config>) -> Self {
73 let client = Client::builder()
74 .user_agent(USER_AGENT)
75 .timeout(Duration::from_secs(config.rest_api.timeout))
76 .build()
77 .expect("Failed to create HTTP client");
78
79 Self {
80 config,
81 client,
82 max_retries: DEFAULT_MAX_RETRIES,
83 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
84 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
85 backoff_factor: DEFAULT_BACKOFF_FACTOR,
86 }
87 }
88
89 pub fn with_retry_config(
91 mut self,
92 max_retries: u32,
93 initial_backoff_ms: u64,
94 max_backoff_ms: u64,
95 backoff_factor: f64,
96 ) -> Self {
97 self.max_retries = max_retries;
98 self.initial_backoff_ms = initial_backoff_ms;
99 self.max_backoff_ms = max_backoff_ms;
100 self.backoff_factor = backoff_factor;
101 self
102 }
103
104 fn calculate_backoff_duration(&self, retry_count: u32) -> Duration {
106 use rand::Rng;
107 let base_backoff_ms =
108 (self.initial_backoff_ms as f64 * self.backoff_factor.powi(retry_count as i32)) as u64;
109 let capped_backoff_ms = base_backoff_ms.min(self.max_backoff_ms);
110
111 let jitter_factor = rand::rng().random_range(0.8..1.2);
113 let jittered_backoff_ms = (capped_backoff_ms as f64 * jitter_factor) as u64;
114
115 Duration::from_millis(jittered_backoff_ms)
116 }
117
118 fn is_retryable_error(&self, error: &AppError) -> bool {
120 match error {
121 AppError::RateLimitExceeded => true,
122 AppError::OAuthTokenExpired => true,
123 AppError::Network(e) => {
124 e.is_timeout() || e.is_connect() || e.status().is_some_and(|s| s.is_server_error())
126 }
127 _ => false,
128 }
129 }
130
131 fn build_url(&self, path: &str) -> String {
133 format!(
134 "{}/{}",
135 self.config.rest_api.base_url.trim_end_matches('/'),
136 path.trim_start_matches('/')
137 )
138 }
139
140 fn add_common_headers(&self, builder: RequestBuilder, version: &str) -> RequestBuilder {
142 let api_key = self.config.credentials.api_key.trim();
143 debug!("Adding X-IG-API-KEY header (length: {})", api_key.len());
144 if api_key.is_empty() {
145 error!("API key is empty!");
146 }
147 builder
148 .header("X-IG-API-KEY", api_key)
149 .header("Content-Type", "application/json; charset=UTF-8")
150 .header("Accept", "application/json; charset=UTF-8")
151 .header("Version", version)
152 }
153
154 fn add_auth_headers(&self, builder: RequestBuilder, session: &IgSession) -> RequestBuilder {
156 if let Some(oauth_token) = &session.oauth_token {
158 debug!("Using OAuth authentication (Bearer token)");
161 debug!(
162 " Access token: {}...",
163 &oauth_token.access_token[..10.min(oauth_token.access_token.len())]
164 );
165 debug!(" Account ID: {}", session.account_id);
166 builder
167 .header(
168 "Authorization",
169 format!("Bearer {}", oauth_token.access_token),
170 )
171 .header("IG-ACCOUNT-ID", &session.account_id)
172 } else {
173 debug!("Using CST authentication");
175 debug!(
176 " CST length: {}, Token length: {}",
177 session.cst.len(),
178 session.token.len()
179 );
180 builder
181 .header("CST", &session.cst)
182 .header("X-SECURITY-TOKEN", &session.token)
183 }
184 }
185
186 async fn process_response<R>(&self, response: Response) -> Result<R, AppError>
188 where
189 for<'de> R: DeserializeOwned + 'static,
190 {
191 let status = response.status();
192 let url = response.url().to_string();
193
194 if status == StatusCode::TOO_MANY_REQUESTS {
196 self.handle_rate_limit(&url, "TOO_MANY_REQUESTS status code")
197 .await;
198 return Err(AppError::RateLimitExceeded);
199 }
200
201 match status {
202 StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED => {
203 let body = response.text().await?;
204 match serde_json::from_str::<R>(&body) {
205 Ok(data) => Ok(data),
206 Err(e) => {
207 error!("Error deserializing response from {}: {}", url, e);
208 error!("Response body: {}", body);
209 Err(AppError::Json(e))
210 }
211 }
212 }
213 StatusCode::UNAUTHORIZED => {
214 let body = response
215 .text()
216 .await
217 .unwrap_or_else(|_| "Unable to read response body".to_string());
218 error!("Unauthorized request to {}", url);
219 error!("Response body: {}", body);
220
221 if body.contains("error.security.oauth-token-invalid") {
223 debug!("Detected expired OAuth token");
224 Err(AppError::OAuthTokenExpired)
225 } else {
226 Err(AppError::Unauthorized)
227 }
228 }
229 StatusCode::NOT_FOUND => {
230 error!("Resource not found at {}", url);
231 Err(AppError::NotFound)
232 }
233 StatusCode::FORBIDDEN => {
234 let body = response.text().await?;
235 if body.contains("exceeded-api-key-allowance")
236 || body.contains("exceeded-account-allowance")
237 {
238 self.handle_rate_limit(
239 &url,
240 "FORBIDDEN with exceeded-api-key-allowance or exceeded-account-allowance",
241 )
242 .await;
243 Err(AppError::RateLimitExceeded)
244 } else {
245 error!("Forbidden access to {}: {}", url, body);
246 Err(AppError::Unauthorized)
247 }
248 }
249 _ => {
250 let body = response.text().await?;
251 error!(
252 "Unexpected status code {} for request to {}: {}",
253 status, url, body
254 );
255 Err(AppError::Unexpected(status))
256 }
257 }
258 }
259
260 async fn handle_rate_limit(&self, url: &str, reason: &str) {
262 RATE_LIMITED.store(true, Ordering::SeqCst);
264 error!("Rate limit exceeded for request to {} ({})", url, reason);
265
266 let non_trading_limiter = app_non_trading_limiter();
269 non_trading_limiter.notify_rate_limit_exceeded().await;
270
271 let rate_limited = RATE_LIMITED.clone();
274 tokio::spawn(async move {
275 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
276 rate_limited.store(false, Ordering::SeqCst);
277 debug!("Rate limit flag reset after 60 second cooldown");
278 });
279 }
280}
281
282#[async_trait]
283impl IgHttpClient for IgHttpClientImpl {
284 async fn request<T, R>(
285 &self,
286 method: Method,
287 path: &str,
288 session: &IgSession,
289 body: Option<&T>,
290 version: &str,
291 ) -> Result<R, AppError>
292 where
293 for<'de> R: DeserializeOwned + 'static,
294 T: Serialize + Send + Sync + 'static,
295 {
296 let url = self.build_url(path);
297 let method_str = method.as_str().to_string(); debug!("Making {} request to {}", method_str, url);
299
300 let mut retry_count = 0;
301
302 loop {
304 if retry_count > 0 {
306 if retry_count > self.max_retries {
307 warn!(
308 "Max retries ({}) exceeded for {} request to {}",
309 self.max_retries, method_str, url
310 );
311 break; }
313
314 let backoff = self.calculate_backoff_duration(retry_count - 1);
316 debug!(
317 "Retry attempt {} for {} request to {}. Waiting for {:?} before retrying",
318 retry_count, method_str, url, backoff
319 );
320 tokio::time::sleep(backoff).await;
321 }
322
323 if RATE_LIMITED.load(Ordering::SeqCst) {
325 warn!("System is currently rate limited. Adding extra delay before request.");
326 let rate_limit_delay = 2000 + (retry_count * 1000) as u64;
329 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
330 }
331
332 let permit = API_SEMAPHORE.acquire().await.unwrap();
335 debug!(
336 "Acquired API semaphore permit for {} request to {}",
337 method_str, url
338 );
339
340 match session.respect_rate_limit().await {
343 Ok(()) => {}
344 Err(e) => {
345 drop(permit);
346 if self.is_retryable_error(&e) {
347 retry_count += 1;
348 continue;
349 }
350 return Err(e);
351 }
352 }
353
354 let mut builder = self.client.request(method.clone(), &url);
355 builder = self.add_common_headers(builder, version);
356 builder = self.add_auth_headers(builder, session);
357
358 if let Some(data) = body {
359 builder = builder.json(data);
360 }
361
362 let response_result = builder.send().await;
364
365 let response = match response_result {
367 Ok(resp) => resp,
368 Err(e) => {
369 error!("Network error for {} request to {}: {}", method_str, url, e);
370 drop(permit);
372
373 let app_error = AppError::Network(e);
375 if self.is_retryable_error(&app_error) {
376 retry_count += 1;
377 continue;
378 }
379 return Err(app_error);
380 }
381 };
382
383 let result = self.process_response::<R>(response).await;
385
386 if result.is_ok() {
388 session.refresh_token_timer();
390
391 if RATE_LIMITED.load(Ordering::SeqCst) {
392 RATE_LIMITED.store(false, Ordering::SeqCst);
393 debug!("Rate limit flag reset after successful request to {}", url);
394 }
395 }
396
397 drop(permit);
400
401 match &result {
403 Err(e) if self.is_retryable_error(e) => {
404 retry_count += 1;
405 continue;
406 }
407 _ => return result,
408 }
409 }
410
411 info!(
413 "Making final attempt for {} request to {} after max retries",
414 method_str, url
415 );
416
417 let permit = API_SEMAPHORE.acquire().await.unwrap();
419
420 session.respect_rate_limit().await?;
422
423 let mut builder = self.client.request(method, &url);
424 builder = self.add_common_headers(builder, version);
425 builder = self.add_auth_headers(builder, session);
426
427 if let Some(data) = body {
428 builder = builder.json(data);
429 }
430
431 let response = builder.send().await?;
432 let result = self.process_response::<R>(response).await;
433
434 if result.is_ok() {
436 session.refresh_token_timer();
437 }
438
439 drop(permit);
440 result
441 }
442
443 async fn request_no_auth<T, R>(
444 &self,
445 method: Method,
446 path: &str,
447 body: Option<&T>,
448 version: &str,
449 ) -> Result<R, AppError>
450 where
451 for<'de> R: DeserializeOwned + 'static,
452 T: Serialize + Send + Sync + 'static,
453 {
454 let url = self.build_url(path);
455 let method_str = method.as_str().to_string(); info!("Making unauthenticated {} request to {}", method_str, url);
457
458 let mut retry_count = 0;
459
460 loop {
462 if retry_count > 0 {
464 if retry_count > self.max_retries {
465 warn!(
466 "Max retries ({}) exceeded for unauthenticated {} request to {}",
467 self.max_retries, method_str, url
468 );
469 break; }
471
472 let backoff = self.calculate_backoff_duration(retry_count - 1);
474 debug!(
475 "Retry attempt {} for unauthenticated {} request to {}. Waiting for {:?} before retrying",
476 retry_count, method_str, url, backoff
477 );
478 tokio::time::sleep(backoff).await;
479 }
480
481 if RATE_LIMITED.load(Ordering::SeqCst) {
483 warn!(
484 "System is currently rate limited. Adding extra delay before unauthenticated request."
485 );
486 let rate_limit_delay = 1000 + (retry_count * 500) as u64;
489 tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
490 }
491
492 let permit = API_SEMAPHORE.acquire().await.unwrap();
494 debug!(
495 "Acquired API semaphore permit for unauthenticated {} request to {}",
496 method_str, url
497 );
498
499 let limiter = app_non_trading_limiter();
502 limiter.wait().await;
503
504 let mut builder = self.client.request(method.clone(), &url);
505 builder = self.add_common_headers(builder, version);
506
507 if let Some(data) = body {
508 builder = builder.json(data);
509 }
510
511 let response_result = builder.send().await;
513
514 let response = match response_result {
516 Ok(resp) => resp,
517 Err(e) => {
518 error!(
519 "Network error for unauthenticated {} request to {}: {}",
520 method_str, url, e
521 );
522 drop(permit);
524
525 let app_error = AppError::Network(e);
527 if self.is_retryable_error(&app_error) {
528 retry_count += 1;
529 continue;
530 }
531 return Err(app_error);
532 }
533 };
534
535 let result = self.process_response::<R>(response).await;
537
538 if result.is_ok() && RATE_LIMITED.load(Ordering::SeqCst) {
540 RATE_LIMITED.store(false, Ordering::SeqCst);
541 info!(
542 "Rate limit flag reset after successful unauthenticated request to {}",
543 url
544 );
545 }
546
547 drop(permit);
549
550 match &result {
552 Err(e) if self.is_retryable_error(e) => {
553 retry_count += 1;
554 continue;
555 }
556 _ => return result,
557 }
558 }
559
560 info!(
562 "Making final attempt for unauthenticated {} request to {} after max retries",
563 method_str, url
564 );
565
566 let permit = API_SEMAPHORE.acquire().await.unwrap();
568
569 let limiter = app_non_trading_limiter();
571 limiter.wait().await;
572
573 let mut builder = self.client.request(method, &url);
574 builder = self.add_common_headers(builder, version);
575
576 if let Some(data) = body {
577 builder = builder.json(data);
578 }
579
580 let response = builder.send().await?;
581 let result = self.process_response::<R>(response).await;
582
583 drop(permit);
584 result
585 }
586}