// amazon_spapi/client/rate_limiter.rs

use anyhow::Result;
use std::collections::HashMap;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::sleep;
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::client::ApiEndpoint;
9
/// State of a token bucket for rate limiting.
///
/// Timestamps are whole Unix seconds, so refill credit is granted only at
/// second boundaries; `tokens` is fractional so partial tokens still
/// accumulate across refills.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Tokens currently available; one request consumes 1.0 token.
    pub tokens: f64,
    /// Unix timestamp (seconds) of the last refill calculation.
    pub last_refill: u64,
    /// Unix timestamp (seconds) when the last response was recorded, if any;
    /// used to enforce a minimum interval between requests.
    pub last_response_time: Option<u64>,
    /// Sustained refill rate, in requests (tokens) per second.
    pub rate: f64,
    /// Maximum burst capacity — the token count is capped at this value.
    pub burst: u32,
}
19
/// In-memory rate limiter that manages token buckets for different endpoints.
/// Thread-safe but not cross-process.
pub struct RateLimiter {
    /// Bucket state keyed by endpoint identifier. Guarded by an async
    /// (tokio) mutex so the guard may be held across `.await` points.
    buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
}
25
26impl RateLimiter {
27    pub fn new() -> Self {
28        Self { 
29            buckets: Arc::new(Mutex::new(HashMap::new())),
30        }
31    }
32
33    /// Wait for a token to become available for the given endpoint
34    /// This method will block until a token is available
35    pub async fn wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
36        loop {
37            {
38                let mut buckets = self.buckets.lock().await;
39                let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
40                
41                // Get or create bucket for this endpoint
42                let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
43                    log::debug!("Creating new token bucket for endpoint: {}", identifier);
44                    TokenBucketState {
45                        tokens: burst as f64, // Start with full burst capacity
46                        last_refill: now,
47                        last_response_time: None, // No response received yet
48                        rate: rate,
49                        burst: burst,
50                    }
51                });
52
53                // Update bucket configuration if endpoint configuration changed
54                if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
55                    log::info!("Updating rate limit for {}: rate {} -> {}, burst {} -> {}", 
56                        identifier, bucket.rate, rate, bucket.burst, burst);
57                    bucket.rate = rate;
58                    bucket.burst = burst;
59                }
60
61                // Refill tokens based on time passed
62                let time_passed = now.saturating_sub(bucket.last_refill) as f64;
63                let tokens_to_add = time_passed * bucket.rate; // rate tokens per second
64                bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
65                bucket.last_refill = now;
66
67                // Check if we need to wait based on the minimum interval since last response
68                if let Some(last_response_time) = bucket.last_response_time {
69                    let minimum_interval = 1.0 / bucket.rate; // minimum seconds between requests
70                    let time_since_response = now.saturating_sub(last_response_time) as f64;
71                    
72                    if time_since_response < minimum_interval {
73                        let wait_seconds = minimum_interval - time_since_response;
74                        log::debug!("Enforcing minimum interval for {}: waiting {:.3}s since last response", 
75                            identifier, wait_seconds);
76                        
77                        // Release lock and wait
78                        drop(buckets);
79                        sleep(Duration::from_secs_f64(wait_seconds)).await;
80                        continue; // Retry after waiting
81                    }
82                }
83
84                log::trace!("Endpoint {}: {:.2} tokens available, rate: {}/s, burst: {}", 
85                    identifier, bucket.tokens, bucket.rate, bucket.burst);
86
87                // Check if we have a token available
88                if bucket.tokens >= 1.0 {
89                    bucket.tokens -= 1.0;
90                    
91                    log::debug!("Token consumed for {}, {:.2} tokens remaining", 
92                        identifier, bucket.tokens);
93                    
94                    return Ok(());
95                }
96
97                // Calculate wait time for next token
98                let wait_time = Duration::from_secs_f64(1.0 / bucket.rate);
99                log::debug!("Rate limit reached for {}, waiting {:?}", 
100                    identifier, wait_time);
101            } // Release lock here
102
103            // Sleep outside the lock to allow other tasks to proceed
104            sleep(Duration::from_millis(100)).await; // Check every 100ms
105        }
106    }
107
108    /// Check if a token is available without consuming it
109    pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
110        let mut buckets = self.buckets.lock().await;
111        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
112        
113        if let Some(bucket) = buckets.get_mut(identifier) {
114            // Refill tokens
115            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
116            let tokens_to_add = time_passed * bucket.rate;
117            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
118            bucket.last_refill = now;
119            
120            Ok(bucket.tokens >= 1.0)
121        } else {
122            // No bucket exists, so we can create one with full capacity
123            Ok(true)
124        }
125    }
126
127    /// Get current token status for all endpoints
128    /// Returns (tokens, rate, burst) for each endpoint
129    pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
130        let mut buckets = self.buckets.lock().await;
131        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
132        let mut status = HashMap::new();
133        
134        for (endpoint_key, bucket) in buckets.iter_mut() {
135            // Refill tokens before reporting status
136            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
137            let tokens_to_add = time_passed * bucket.rate;
138            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
139            bucket.last_refill = now;
140
141            status.insert(endpoint_key.clone(), (bucket.tokens, bucket.rate, bucket.burst));
142        }
143        
144        Ok(status)
145    }
146
147    /// Reset all rate limiting state (useful for testing)
148    pub async fn reset(&self) {
149        let mut buckets = self.buckets.lock().await;
150        buckets.clear();
151        log::debug!("Rate limiter state reset");
152    }
153
154    /// Get the number of active buckets
155    pub async fn active_buckets_count(&self) -> usize {
156        let buckets = self.buckets.lock().await;
157        buckets.len()
158    }
159
160    /// Record that a response was received for the given endpoint
161    /// This updates the last_response_time used for enforcing minimum intervals
162    pub async fn record_response(&self, identifier: &str) -> Result<()> {
163        let mut buckets = self.buckets.lock().await;
164        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
165        
166        if let Some(bucket) = buckets.get_mut(identifier) {
167            bucket.last_response_time = Some(now);
168            log::trace!("Recorded response time for {}: {}", identifier, now);
169        } else {
170            log::warn!("Attempted to record response for unknown endpoint: {}", identifier);
171        }
172        
173        Ok(())
174    }
175}
176
177impl Default for RateLimiter {
178    fn default() -> Self {
179        Self::new()
180    }
181}