Skip to main content

goldrush_sdk/
rate_limit.rs

1use crate::{Error, Result};
2use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
3use std::time::Duration;
4use tracing::{debug, warn, instrument};
5
/// Tunable limits applied to outgoing API requests.
///
/// [`RateLimiter`] reads `max_requests_per_second` and `burst_capacity`;
/// the remaining fields describe the retry policy.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained request rate: the number of tokens the limiter refills per
    /// second.
    pub max_requests_per_second: f64,
    /// Upper bound on stored tokens, i.e. how many requests may be issued
    /// back-to-back after an idle period.
    pub burst_capacity: u32,
    /// Enable exponential backoff for failed requests.
    pub enable_backoff: bool,
    /// Maximum retry attempts.
    pub max_retries: u32,
}
18
19impl Default for RateLimitConfig {
20    fn default() -> Self {
21        Self {
22            max_requests_per_second: 10.0, // Conservative default
23            burst_capacity: 20,
24            enable_backoff: true,
25            max_retries: 3,
26        }
27    }
28}
29
30/// Token bucket rate limiter.
31#[derive(Debug)]
32pub struct RateLimiter {
33    config: RateLimitConfig,
34    tokens: std::sync::Arc<tokio::sync::Mutex<f64>>,
35    last_refill: std::sync::Arc<tokio::sync::Mutex<std::time::Instant>>,
36}
37
38impl RateLimiter {
39    pub fn new(config: RateLimitConfig) -> Self {
40        Self {
41            tokens: std::sync::Arc::new(tokio::sync::Mutex::new(config.burst_capacity as f64)),
42            last_refill: std::sync::Arc::new(tokio::sync::Mutex::new(std::time::Instant::now())),
43            config,
44        }
45    }
46
47    /// Check if a request can proceed, applying rate limiting if necessary.
48    pub async fn acquire(&self) -> Result<()> {
49        self.acquire_internal().await
50    }
51    
52    #[instrument(skip(self), fields(max_rps = %self.config.max_requests_per_second))]
53    async fn acquire_internal(&self) -> Result<()> {
54        let mut tokens = self.tokens.lock().await;
55        let mut last_refill = self.last_refill.lock().await;
56        
57        let now = std::time::Instant::now();
58        let time_elapsed = now.duration_since(*last_refill);
59        
60        // Refill tokens based on elapsed time
61        let tokens_to_add = time_elapsed.as_secs_f64() * self.config.max_requests_per_second;
62        *tokens = (*tokens + tokens_to_add).min(self.config.burst_capacity as f64);
63        *last_refill = now;
64        
65        if *tokens >= 1.0 {
66            *tokens -= 1.0;
67            debug!(tokens_remaining = %*tokens, "Rate limit check passed");
68            Ok(())
69        } else {
70            let wait_time = Duration::from_millis(
71                ((1.0 - *tokens) / self.config.max_requests_per_second * 1000.0) as u64
72            );
73            
74            warn!(
75                wait_time_ms = %wait_time.as_millis(),
76                "Rate limit exceeded, waiting"
77            );
78            
79            // Release locks before waiting
80            drop(tokens);
81            drop(last_refill);
82            
83            tokio::time::sleep(wait_time).await;
84            Box::pin(self.acquire_internal()).await
85        }
86    }
87}
88
89/// Create exponential backoff strategy for retries.
90pub fn create_backoff_strategy(config: &RateLimitConfig) -> ExponentialBackoff {
91    ExponentialBackoffBuilder::new()
92        .with_initial_interval(Duration::from_millis(100))
93        .with_max_interval(Duration::from_secs(30))
94        .with_multiplier(2.0)
95        .with_max_elapsed_time(Some(Duration::from_secs(300))) // 5 minutes max
96        .build()
97}
98
99/// Retry a request with exponential backoff.
100#[instrument(skip(operation), fields(max_retries = %max_retries))]
101pub async fn retry_with_backoff<F, T>(
102    operation: F,
103    max_retries: u32,
104) -> Result<T>
105where
106    F: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send + 'static>>,
107    T: Send + 'static,
108{
109    let backoff = ExponentialBackoffBuilder::new()
110        .with_initial_interval(Duration::from_millis(100))
111        .with_max_interval(Duration::from_secs(30))
112        .with_multiplier(2.0)
113        .build();
114    
115    let mut attempts = 0;
116    let mut current_wait = Duration::from_millis(100);
117    
118    loop {
119        attempts += 1;
120        
121        match operation().await {
122            Ok(result) => {
123                debug!(attempts = %attempts, "Request succeeded");
124                return Ok(result);
125            }
126            Err(err) => {
127                if attempts >= max_retries {
128                    warn!(attempts = %attempts, "Max retries exceeded");
129                    return Err(err);
130                }
131                
132                warn!(
133                    attempt = %attempts,
134                    wait_time_ms = %current_wait.as_millis(),
135                    error = %err,
136                    "Request failed, retrying with backoff"
137                );
138                
139                tokio::time::sleep(current_wait).await;
140                current_wait = std::cmp::min(current_wait * 2, Duration::from_secs(30));
141            }
142        }
143    }
144}