ig_client/transport/
http_client.rs

1use async_trait::async_trait;
2use once_cell::sync::Lazy;
3use reqwest::{Client, Method, RequestBuilder, Response, StatusCode};
4use serde::{Serialize, de::DeserializeOwned};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tracing::{debug, error, info, warn};
10
11use crate::constants::USER_AGENT;
12use crate::utils::rate_limiter::app_non_trading_limiter;
13use crate::{config::Config, error::AppError, session::interface::IgSession};
14
15// Global semaphore to limit concurrent API requests
16// This ensures that we don't exceed rate limits by making too many
17// concurrent requests
18static API_SEMAPHORE: Lazy<Arc<Semaphore>> = Lazy::new(|| {
19    Arc::new(Semaphore::new(3)) // Allow up to 3 concurrent requests
20});
21
22// Flag to indicate if we're in a rate-limited situation
23static RATE_LIMITED: Lazy<Arc<AtomicBool>> = Lazy::new(|| Arc::new(AtomicBool::new(false)));
24
25// Default retry configuration
26const DEFAULT_MAX_RETRIES: u32 = 10; // Increase max retries to ensure all requests are processed
27const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1000; // 1 second
28const DEFAULT_MAX_BACKOFF_MS: u64 = 60000; // 60 seconds max backoff
29const DEFAULT_BACKOFF_FACTOR: f64 = 2.0; // Exponential backoff factor
30
/// Interface for the IG HTTP client
///
/// Implementors send a request described by an HTTP method, a path and an
/// optional serializable body, and return the response deserialized into `R`.
#[async_trait]
pub trait IgHttpClient: Send + Sync {
    /// Makes an HTTP request to the IG API
    ///
    /// # Arguments
    ///
    /// * `method` - HTTP method to use
    /// * `path` - Path of the endpoint to call
    /// * `session` - Session whose credentials authenticate the request
    /// * `body` - Optional payload to send with the request
    /// * `version` - API version string for the request
    ///
    /// # Errors
    ///
    /// Returns an [`AppError`] when the request cannot be completed or the
    /// response cannot be turned into `R`.
    async fn request<T, R>(
        &self,
        method: Method,
        path: &str,
        session: &IgSession,
        body: Option<&T>,
        version: &str,
    ) -> Result<R, AppError>
    where
        for<'de> R: DeserializeOwned + 'static,
        T: Serialize + Send + Sync + 'static;

    /// Makes an unauthenticated HTTP request (for login)
    ///
    /// Takes the same arguments as [`IgHttpClient::request`] except that no
    /// session is supplied.
    ///
    /// # Errors
    ///
    /// Returns an [`AppError`] when the request cannot be completed or the
    /// response cannot be turned into `R`.
    async fn request_no_auth<T, R>(
        &self,
        method: Method,
        path: &str,
        body: Option<&T>,
        version: &str,
    ) -> Result<R, AppError>
    where
        for<'de> R: DeserializeOwned + 'static,
        T: Serialize + Send + Sync + 'static;
}
59
/// Implementation of the HTTP client for IG
///
/// Bundles the shared configuration, the underlying `reqwest` client and the
/// retry/backoff settings used when a request fails with a retryable error.
pub struct IgHttpClientImpl {
    /// Shared configuration (REST base URL, timeout, API key).
    config: Arc<Config>,
    /// Underlying HTTP client used to send every request.
    client: Client,
    /// Number of retries allowed in the retry loop before the final attempt.
    max_retries: u32,
    /// Backoff before the first retry, in milliseconds.
    initial_backoff_ms: u64,
    /// Cap applied to the computed backoff (before jitter), in milliseconds.
    max_backoff_ms: u64,
    /// Multiplier applied to the backoff for each successive retry.
    backoff_factor: f64,
}
69
70impl IgHttpClientImpl {
71    /// Creates a new instance of the HTTP client
72    pub fn new(config: Arc<Config>) -> Self {
73        let client = Client::builder()
74            .user_agent(USER_AGENT)
75            .timeout(Duration::from_secs(config.rest_api.timeout))
76            .build()
77            .expect("Failed to create HTTP client");
78
79        Self {
80            config,
81            client,
82            max_retries: DEFAULT_MAX_RETRIES,
83            initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
84            max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
85            backoff_factor: DEFAULT_BACKOFF_FACTOR,
86        }
87    }
88
89    /// Configure retry behavior
90    pub fn with_retry_config(
91        mut self,
92        max_retries: u32,
93        initial_backoff_ms: u64,
94        max_backoff_ms: u64,
95        backoff_factor: f64,
96    ) -> Self {
97        self.max_retries = max_retries;
98        self.initial_backoff_ms = initial_backoff_ms;
99        self.max_backoff_ms = max_backoff_ms;
100        self.backoff_factor = backoff_factor;
101        self
102    }
103
104    /// Calculate backoff duration for retry attempts with jitter
105    fn calculate_backoff_duration(&self, retry_count: u32) -> Duration {
106        use rand::Rng;
107        let base_backoff_ms =
108            (self.initial_backoff_ms as f64 * self.backoff_factor.powi(retry_count as i32)) as u64;
109        let capped_backoff_ms = base_backoff_ms.min(self.max_backoff_ms);
110
111        // Add jitter (±20%) to avoid thundering herd problem
112        let jitter_factor = rand::rng().random_range(0.8..1.2);
113        let jittered_backoff_ms = (capped_backoff_ms as f64 * jitter_factor) as u64;
114
115        Duration::from_millis(jittered_backoff_ms)
116    }
117
118    /// Check if an error is retryable
119    fn is_retryable_error(&self, error: &AppError) -> bool {
120        match error {
121            AppError::RateLimitExceeded => true,
122            AppError::Network(e) => {
123                // Retry on connection errors, timeouts, and server errors
124                e.is_timeout() || e.is_connect() || e.status().is_some_and(|s| s.is_server_error())
125            }
126            _ => false,
127        }
128    }
129
130    /// Builds the complete URL for a request
131    fn build_url(&self, path: &str) -> String {
132        format!(
133            "{}/{}",
134            self.config.rest_api.base_url.trim_end_matches('/'),
135            path.trim_start_matches('/')
136        )
137    }
138
139    /// Adds common headers to all requests
140    fn add_common_headers(&self, builder: RequestBuilder, version: &str) -> RequestBuilder {
141        let api_key = self.config.credentials.api_key.trim();
142        debug!("Adding X-IG-API-KEY header (length: {})", api_key.len());
143        if api_key.is_empty() {
144            error!("API key is empty!");
145        }
146        builder
147            .header("X-IG-API-KEY", api_key)
148            .header("Content-Type", "application/json; charset=UTF-8")
149            .header("Accept", "application/json; charset=UTF-8")
150            .header("Version", version)
151    }
152
153    /// Adds authentication headers to a request
154    fn add_auth_headers(&self, builder: RequestBuilder, session: &IgSession) -> RequestBuilder {
155        // Check if using OAuth (v3) or CST (v2) authentication
156        if let Some(oauth_token) = &session.oauth_token {
157            // Use OAuth Bearer token + IG-ACCOUNT-ID header
158            // Per IG API docs: OAuth requires both Authorization and IG-ACCOUNT-ID headers
159            debug!("Using OAuth authentication (Bearer token)");
160            debug!(
161                "   Access token: {}...",
162                &oauth_token.access_token[..10.min(oauth_token.access_token.len())]
163            );
164            debug!("   Account ID: {}", session.account_id);
165            builder
166                .header(
167                    "Authorization",
168                    format!("Bearer {}", oauth_token.access_token),
169                )
170                .header("IG-ACCOUNT-ID", &session.account_id)
171        } else {
172            // Use CST and X-SECURITY-TOKEN (v2)
173            debug!("Using CST authentication");
174            debug!(
175                "   CST length: {}, Token length: {}",
176                session.cst.len(),
177                session.token.len()
178            );
179            builder
180                .header("CST", &session.cst)
181                .header("X-SECURITY-TOKEN", &session.token)
182        }
183    }
184
185    /// Processes the HTTP response and handles rate limiting centrally
186    async fn process_response<R>(&self, response: Response) -> Result<R, AppError>
187    where
188        for<'de> R: DeserializeOwned + 'static,
189    {
190        let status = response.status();
191        let url = response.url().to_string();
192
193        // Handle rate limiting centrally
194        if status == StatusCode::TOO_MANY_REQUESTS {
195            self.handle_rate_limit(&url, "TOO_MANY_REQUESTS status code")
196                .await;
197            return Err(AppError::RateLimitExceeded);
198        }
199
200        match status {
201            StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED => {
202                let body = response.text().await?;
203                match serde_json::from_str::<R>(&body) {
204                    Ok(data) => Ok(data),
205                    Err(e) => {
206                        error!("Error deserializing response from {}: {}", url, e);
207                        error!("Response body: {}", body);
208                        Err(AppError::Json(e))
209                    }
210                }
211            }
212            StatusCode::UNAUTHORIZED => {
213                let body = response.text().await.unwrap_or_else(|_| "Unable to read response body".to_string());
214                error!("Unauthorized request to {}", url);
215                error!("Response body: {}", body);
216                Err(AppError::Unauthorized)
217            }
218            StatusCode::NOT_FOUND => {
219                error!("Resource not found at {}", url);
220                Err(AppError::NotFound)
221            }
222            StatusCode::FORBIDDEN => {
223                let body = response.text().await?;
224                if body.contains("exceeded-api-key-allowance")
225                    || body.contains("exceeded-account-allowance")
226                {
227                    self.handle_rate_limit(
228                        &url,
229                        "FORBIDDEN with exceeded-api-key-allowance or exceeded-account-allowance",
230                    )
231                    .await;
232                    Err(AppError::RateLimitExceeded)
233                } else {
234                    error!("Forbidden access to {}: {}", url, body);
235                    Err(AppError::Unauthorized)
236                }
237            }
238            _ => {
239                let body = response.text().await?;
240                error!(
241                    "Unexpected status code {} for request to {}: {}",
242                    status, url, body
243                );
244                Err(AppError::Unexpected(status))
245            }
246        }
247    }
248
249    /// Helper method to handle rate limiting
250    async fn handle_rate_limit(&self, url: &str, reason: &str) {
251        // Set the rate limited flag
252        RATE_LIMITED.store(true, Ordering::SeqCst);
253        error!("Rate limit exceeded for request to {} ({})", url, reason);
254
255        // Notify all rate limiters about the exceeded limit
256        // This will cause them to enforce a mandatory cooldown period
257        let non_trading_limiter = app_non_trading_limiter();
258        non_trading_limiter.notify_rate_limit_exceeded().await;
259
260        // Schedule a task to reset the flag after a delay
261        // Increased from 30 to 60 seconds to give more time for rate limit to reset
262        let rate_limited = RATE_LIMITED.clone();
263        tokio::spawn(async move {
264            tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
265            rate_limited.store(false, Ordering::SeqCst);
266            debug!("Rate limit flag reset after 60 second cooldown");
267        });
268    }
269}
270
271#[async_trait]
272impl IgHttpClient for IgHttpClientImpl {
273    async fn request<T, R>(
274        &self,
275        method: Method,
276        path: &str,
277        session: &IgSession,
278        body: Option<&T>,
279        version: &str,
280    ) -> Result<R, AppError>
281    where
282        for<'de> R: DeserializeOwned + 'static,
283        T: Serialize + Send + Sync + 'static,
284    {
285        let url = self.build_url(path);
286        let method_str = method.as_str().to_string(); // Store method as string for logging
287        debug!("Making {} request to {}", method_str, url);
288
289        let mut retry_count = 0;
290
291        // Retry loop
292        loop {
293            // Check if we should retry
294            if retry_count > 0 {
295                if retry_count > self.max_retries {
296                    warn!(
297                        "Max retries ({}) exceeded for {} request to {}",
298                        self.max_retries, method_str, url
299                    );
300                    break; // Exit the loop and try one last time without retrying
301                }
302
303                // Calculate backoff duration
304                let backoff = self.calculate_backoff_duration(retry_count - 1);
305                debug!(
306                    "Retry attempt {} for {} request to {}. Waiting for {:?} before retrying",
307                    retry_count, method_str, url, backoff
308                );
309                tokio::time::sleep(backoff).await;
310            }
311
312            // Check if we're currently rate limited
313            if RATE_LIMITED.load(Ordering::SeqCst) {
314                warn!("System is currently rate limited. Adding extra delay before request.");
315                // Add a longer extra delay if we're in a rate-limited situation
316                // Use retry count to increase delay for subsequent retries
317                let rate_limit_delay = 2000 + (retry_count * 1000) as u64;
318                tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
319            }
320
321            // Acquire a permit from the semaphore to limit concurrent requests
322            // This ensures we don't overwhelm the API with too many concurrent requests
323            let permit = API_SEMAPHORE.acquire().await.unwrap();
324            debug!(
325                "Acquired API semaphore permit for {} request to {}",
326                method_str, url
327            );
328
329            // Respect rate limits before making the request
330            // This will handle the actual rate limiting based on request history
331            match session.respect_rate_limit().await {
332                Ok(()) => {}
333                Err(e) => {
334                    drop(permit);
335                    if self.is_retryable_error(&e) {
336                        retry_count += 1;
337                        continue;
338                    }
339                    return Err(e);
340                }
341            }
342
343            let mut builder = self.client.request(method.clone(), &url);
344            builder = self.add_common_headers(builder, version);
345            builder = self.add_auth_headers(builder, session);
346
347            if let Some(data) = body {
348                builder = builder.json(data);
349            }
350
351            // Send the request
352            let response_result = builder.send().await;
353
354            // Check for network errors
355            let response = match response_result {
356                Ok(resp) => resp,
357                Err(e) => {
358                    error!("Network error for {} request to {}: {}", method_str, url, e);
359                    // Release the permit before continuing
360                    drop(permit);
361
362                    // Check if we should retry
363                    let app_error = AppError::Network(e);
364                    if self.is_retryable_error(&app_error) {
365                        retry_count += 1;
366                        continue;
367                    }
368                    return Err(app_error);
369                }
370            };
371
372            // Process the response - rate limiting is handled inside process_response
373            let result = self.process_response::<R>(response).await;
374
375            // If the request was successful, refresh token timer and reset rate limited flag
376            if result.is_ok() {
377                // Refresh token timer to extend token validity
378                session.refresh_token_timer();
379
380                if RATE_LIMITED.load(Ordering::SeqCst) {
381                    RATE_LIMITED.store(false, Ordering::SeqCst);
382                    debug!("Rate limit flag reset after successful request to {}", url);
383                }
384            }
385
386            // Release the permit (this happens automatically when permit goes out of scope,
387            // but we do it explicitly for clarity)
388            drop(permit);
389
390            // Handle the result
391            match &result {
392                Err(e) if self.is_retryable_error(e) => {
393                    retry_count += 1;
394                    continue;
395                }
396                _ => return result,
397            }
398        }
399
400        // Final attempt without retrying
401        info!(
402            "Making final attempt for {} request to {} after max retries",
403            method_str, url
404        );
405
406        // Acquire a permit from the semaphore
407        let permit = API_SEMAPHORE.acquire().await.unwrap();
408
409        // Respect rate limits
410        session.respect_rate_limit().await?;
411
412        let mut builder = self.client.request(method, &url);
413        builder = self.add_common_headers(builder, version);
414        builder = self.add_auth_headers(builder, session);
415
416        if let Some(data) = body {
417            builder = builder.json(data);
418        }
419
420        let response = builder.send().await?;
421        let result = self.process_response::<R>(response).await;
422
423        // If the final attempt was successful, refresh token timer
424        if result.is_ok() {
425            session.refresh_token_timer();
426        }
427
428        drop(permit);
429        result
430    }
431
432    async fn request_no_auth<T, R>(
433        &self,
434        method: Method,
435        path: &str,
436        body: Option<&T>,
437        version: &str,
438    ) -> Result<R, AppError>
439    where
440        for<'de> R: DeserializeOwned + 'static,
441        T: Serialize + Send + Sync + 'static,
442    {
443        let url = self.build_url(path);
444        let method_str = method.as_str().to_string(); // Store method as string for logging
445        info!("Making unauthenticated {} request to {}", method_str, url);
446
447        let mut retry_count = 0;
448
449        // Retry loop
450        loop {
451            // Check if we should retry
452            if retry_count > 0 {
453                if retry_count > self.max_retries {
454                    warn!(
455                        "Max retries ({}) exceeded for unauthenticated {} request to {}",
456                        self.max_retries, method_str, url
457                    );
458                    break; // Exit the loop and try one last time without retrying
459                }
460
461                // Calculate backoff duration
462                let backoff = self.calculate_backoff_duration(retry_count - 1);
463                debug!(
464                    "Retry attempt {} for unauthenticated {} request to {}. Waiting for {:?} before retrying",
465                    retry_count, method_str, url, backoff
466                );
467                tokio::time::sleep(backoff).await;
468            }
469
470            // Check if we're currently rate limited
471            if RATE_LIMITED.load(Ordering::SeqCst) {
472                warn!(
473                    "System is currently rate limited. Adding extra delay before unauthenticated request."
474                );
475                // Add a longer extra delay if we're in a rate-limited situation
476                // Use retry count to increase delay for subsequent retries
477                let rate_limit_delay = 1000 + (retry_count * 500) as u64;
478                tokio::time::sleep(tokio::time::Duration::from_millis(rate_limit_delay)).await;
479            }
480
481            // Acquire a permit from the semaphore to limit concurrent requests
482            let permit = API_SEMAPHORE.acquire().await.unwrap();
483            debug!(
484                "Acquired API semaphore permit for unauthenticated {} request to {}",
485                method_str, url
486            );
487
488            // Use the global app rate limiter for unauthenticated requests
489            // This is thread-safe and can be called from multiple threads concurrently
490            let limiter = app_non_trading_limiter();
491            limiter.wait().await;
492
493            let mut builder = self.client.request(method.clone(), &url);
494            builder = self.add_common_headers(builder, version);
495
496            if let Some(data) = body {
497                builder = builder.json(data);
498            }
499
500            // Send the request
501            let response_result = builder.send().await;
502
503            // Check for network errors
504            let response = match response_result {
505                Ok(resp) => resp,
506                Err(e) => {
507                    error!(
508                        "Network error for unauthenticated {} request to {}: {}",
509                        method_str, url, e
510                    );
511                    // Release the permit before continuing
512                    drop(permit);
513
514                    // Check if we should retry
515                    let app_error = AppError::Network(e);
516                    if self.is_retryable_error(&app_error) {
517                        retry_count += 1;
518                        continue;
519                    }
520                    return Err(app_error);
521                }
522            };
523
524            // Process the response - rate limiting is handled inside process_response
525            let result = self.process_response::<R>(response).await;
526
527            // If the request was successful, reset the rate limited flag
528            if result.is_ok() && RATE_LIMITED.load(Ordering::SeqCst) {
529                RATE_LIMITED.store(false, Ordering::SeqCst);
530                info!(
531                    "Rate limit flag reset after successful unauthenticated request to {}",
532                    url
533                );
534            }
535
536            // Release the permit
537            drop(permit);
538
539            // Handle the result
540            match &result {
541                Err(e) if self.is_retryable_error(e) => {
542                    retry_count += 1;
543                    continue;
544                }
545                _ => return result,
546            }
547        }
548
549        // Final attempt without retrying
550        info!(
551            "Making final attempt for unauthenticated {} request to {} after max retries",
552            method_str, url
553        );
554
555        // Acquire a permit from the semaphore
556        let permit = API_SEMAPHORE.acquire().await.unwrap();
557
558        // Use the global app rate limiter
559        let limiter = app_non_trading_limiter();
560        limiter.wait().await;
561
562        let mut builder = self.client.request(method, &url);
563        builder = self.add_common_headers(builder, version);
564
565        if let Some(data) = body {
566            builder = builder.json(data);
567        }
568
569        let response = builder.send().await?;
570        let result = self.process_response::<R>(response).await;
571
572        drop(permit);
573        result
574    }
575}