//! amazon_spapi/client/rate_limiter.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
8use crate::client::ApiEndpoint;
9
/// RAII guard that automatically records response when dropped.
///
/// Returned by [`RateLimiter::wait`]; hold it until the API response has been
/// received. On drop it stamps the current time as `last_response_time` for
/// `identifier`, which `wait_for_token` uses to enforce minimum intervals.
pub struct RateLimitGuard {
    // Shared per-endpoint bucket map (the same map the owning RateLimiter holds).
    rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
    // Key of the endpoint bucket to update when this guard is dropped.
    identifier: String,
}
15
16impl RateLimitGuard {
17    fn new(
18        rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
19        identifier: String,
20    ) -> Self {
21        Self {
22            rate_limiter,
23            identifier,
24        }
25    }
26}
27
28impl Drop for RateLimitGuard {
29    fn drop(&mut self) {
30        let rate_limiter = self.rate_limiter.clone();
31        let identifier = self.identifier.clone();
32
33        // Spawn a task to record the response since Drop can't be async
34        tokio::spawn(async move {
35            let mut buckets = rate_limiter.lock().await;
36            let now = SystemTime::now()
37                .duration_since(UNIX_EPOCH)
38                .unwrap_or_default()
39                .as_secs();
40
41            if let Some(bucket) = buckets.get_mut(&identifier) {
42                bucket.last_response_time = Some(now);
43                log::trace!("Auto-recorded response time for {}: {}", identifier, now);
44            } else {
45                log::warn!(
46                    "Attempted to auto-record response for unknown endpoint: {}",
47                    identifier
48                );
49            }
50        });
51    }
52}
53
/// State of a token bucket for rate limiting.
///
/// One bucket exists per endpoint identifier; tokens accrue at `rate` per
/// second up to `burst`, and each request consumes one whole token.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    pub tokens: f64,                     // currently available tokens (fractional while refilling)
    pub last_refill: u64,                // Unix timestamp in seconds
    pub last_response_time: Option<u64>, // Unix timestamp in seconds when last response was received
    pub rate: f64,                       // requests per second
    pub burst: u32,                      // maximum burst capacity
}
63
/// In-memory rate limiter that manages token buckets for different endpoints.
///
/// Thread-safe (buckets live behind an async mutex) but not cross-process:
/// state is lost on restart and is not shared between instances.
pub struct RateLimiter {
    // Map of endpoint identifier -> bucket state, shared with RateLimitGuard.
    buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
}
69
70impl RateLimiter {
71    pub fn new() -> Self {
72        Self {
73            buckets: Arc::new(Mutex::new(HashMap::new())),
74        }
75    }
76
77    /// Wait for a token to become available for the given endpoint and return a guard
78    /// When the guard is dropped, record_response will be called automatically
79    #[must_use = "The returned guard must be held until the API response is received"]
80    pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
81        self.wait_for_token(identifier, rate, burst).await?;
82
83        Ok(RateLimitGuard::new(
84            self.buckets.clone(),
85            identifier.to_string(),
86        ))
87    }
88
89    /// Wait for a token to become available for the given endpoint
90    /// This method will block until a token is available
91    pub async fn wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
92        loop {
93            {
94                let mut buckets = self.buckets.lock().await;
95                let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
96
97                // Get or create bucket for this endpoint
98                let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
99                    log::debug!("Creating new token bucket for endpoint: {}", identifier);
100                    TokenBucketState {
101                        tokens: burst as f64, // Start with full burst capacity
102                        last_refill: now,
103                        last_response_time: None, // No response received yet
104                        rate: rate,
105                        burst: burst,
106                    }
107                });
108
109                // Update bucket configuration if endpoint configuration changed
110                if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
111                    log::info!(
112                        "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
113                        identifier,
114                        bucket.rate,
115                        rate,
116                        bucket.burst,
117                        burst
118                    );
119                    bucket.rate = rate;
120                    bucket.burst = burst;
121                }
122
123                // Refill tokens based on time passed
124                let time_passed = now.saturating_sub(bucket.last_refill) as f64;
125                let tokens_to_add = time_passed * bucket.rate; // rate tokens per second
126                bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
127                bucket.last_refill = now;
128
129                // Check if we need to wait based on the minimum interval since last response
130                if let Some(last_response_time) = bucket.last_response_time {
131                    let minimum_interval = 1.0 / bucket.rate; // minimum seconds between requests
132                    let time_since_response = now.saturating_sub(last_response_time) as f64;
133
134                    if time_since_response < minimum_interval {
135                        let wait_seconds = minimum_interval - time_since_response;
136                        log::debug!(
137                            "Enforcing minimum interval for {}: waiting {:.3}s since last response",
138                            identifier,
139                            wait_seconds
140                        );
141
142                        // Release lock and wait
143                        drop(buckets);
144                        sleep(Duration::from_secs_f64(wait_seconds)).await;
145                        continue; // Retry after waiting
146                    }
147                }
148
149                log::trace!(
150                    "Endpoint {}: {:.2} tokens available, rate: {}/s, burst: {}",
151                    identifier,
152                    bucket.tokens,
153                    bucket.rate,
154                    bucket.burst
155                );
156
157                // Check if we have a token available
158                if bucket.tokens >= 1.0 {
159                    bucket.tokens -= 1.0;
160
161                    log::debug!(
162                        "Token consumed for {}, {:.2} tokens remaining",
163                        identifier,
164                        bucket.tokens
165                    );
166
167                    return Ok(());
168                }
169
170                // Calculate wait time for next token
171                let wait_time = Duration::from_secs_f64(1.0 / bucket.rate);
172                log::debug!(
173                    "Rate limit reached for {}, waiting {:?}",
174                    identifier,
175                    wait_time
176                );
177            } // Release lock here
178
179            // Sleep outside the lock to allow other tasks to proceed
180            sleep(Duration::from_millis(100)).await; // Check every 100ms
181        }
182    }
183
184    /// Check if a token is available without consuming it
185    pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
186        let mut buckets = self.buckets.lock().await;
187        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
188
189        if let Some(bucket) = buckets.get_mut(identifier) {
190            // Refill tokens
191            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
192            let tokens_to_add = time_passed * bucket.rate;
193            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
194            bucket.last_refill = now;
195
196            Ok(bucket.tokens >= 1.0)
197        } else {
198            // No bucket exists, so we can create one with full capacity
199            Ok(true)
200        }
201    }
202
203    /// Get current token status for all endpoints
204    /// Returns (tokens, rate, burst) for each endpoint
205    pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
206        let mut buckets = self.buckets.lock().await;
207        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
208        let mut status = HashMap::new();
209
210        for (endpoint_key, bucket) in buckets.iter_mut() {
211            // Refill tokens before reporting status
212            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
213            let tokens_to_add = time_passed * bucket.rate;
214            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
215            bucket.last_refill = now;
216
217            status.insert(
218                endpoint_key.clone(),
219                (bucket.tokens, bucket.rate, bucket.burst),
220            );
221        }
222
223        Ok(status)
224    }
225
226    /// Reset all rate limiting state (useful for testing)
227    pub async fn reset(&self) {
228        let mut buckets = self.buckets.lock().await;
229        buckets.clear();
230        log::debug!("Rate limiter state reset");
231    }
232
233    /// Get the number of active buckets
234    pub async fn active_buckets_count(&self) -> usize {
235        let buckets = self.buckets.lock().await;
236        buckets.len()
237    }
238
239    /// Record that a response was received for the given endpoint
240    /// This updates the last_response_time used for enforcing minimum intervals
241    pub async fn record_response(&self, identifier: &str) -> Result<()> {
242        let mut buckets = self.buckets.lock().await;
243        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
244
245        if let Some(bucket) = buckets.get_mut(identifier) {
246            bucket.last_response_time = Some(now);
247            log::trace!("Recorded response time for {}: {}", identifier, now);
248        } else {
249            log::warn!(
250                "Attempted to record response for unknown endpoint: {}",
251                identifier
252            );
253        }
254
255        Ok(())
256    }
257}
258
259impl Default for RateLimiter {
260    fn default() -> Self {
261        Self::new()
262    }
263}