// amazon_spapi/client/rate_limiter.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
8use crate::client::ApiEndpoint;
9
/// RAII guard that automatically records response when dropped
///
/// Returned by `RateLimiter::wait`. Hold it until the API response arrives,
/// then either drop it (auto-record) or call `mark_response` explicitly.
#[must_use = "RateLimitGuard must be held until the API response is received"]
pub struct RateLimitGuard {
    // Shared bucket map owned by the RateLimiter that created this guard.
    rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
    // Bucket key (endpoint identifier) this guard records against.
    identifier: String,
    // When true, dropping the guard records the response time automatically;
    // mark_response flips this to false to suppress the Drop-time recording.
    auto_record: bool,
}
17
18impl RateLimitGuard {
19    fn new(
20        rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
21        identifier: String,
22        auto_record: bool,
23    ) -> Self {
24        Self {
25            rate_limiter,
26            identifier,
27            auto_record,
28        }
29    }
30
31    /// Manually mark that the API response was received
32    /// This will record the response time and prevent automatic recording on drop
33    pub async fn mark_response(mut self) {
34        if self.auto_record {
35            let mut buckets = self.rate_limiter.lock().await;
36            let now = SystemTime::now()
37                .duration_since(UNIX_EPOCH)
38                .unwrap_or_default()
39                .as_secs();
40
41            if let Some(bucket) = buckets.get_mut(&self.identifier) {
42                bucket.last_response_time = Some(now);
43                log::trace!(
44                    "Manually recorded response time for {}: {}",
45                    self.identifier,
46                    now
47                );
48            }
49        }
50
51        // Prevent automatic recording on drop
52        self.auto_record = false;
53    }
54}
55
56impl Drop for RateLimitGuard {
57    fn drop(&mut self) {
58        if !self.auto_record {
59            return;
60        }
61
62        let rate_limiter = self.rate_limiter.clone();
63        let identifier = self.identifier.clone();
64
65        // Spawn a task to record the response since Drop can't be async
66        tokio::spawn(async move {
67            let mut buckets = rate_limiter.lock().await;
68            let now = SystemTime::now()
69                .duration_since(UNIX_EPOCH)
70                .unwrap_or_default()
71                .as_secs();
72
73            if let Some(bucket) = buckets.get_mut(&identifier) {
74                // Always record the response time to ensure accurate rate limiting
75                bucket.last_response_time = Some(now);
76                log::trace!("Auto-recorded response time for {}: {}", identifier, now);
77            } else {
78                log::warn!(
79                    "Attempted to auto-record response for unknown endpoint: {}",
80                    identifier
81                );
82            }
83        });
84    }
85}
86
/// Snapshot of one endpoint's token bucket used for rate limiting.
///
/// All timestamps are Unix seconds; `tokens` is fractional so partial
/// refills accumulate between requests.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Tokens currently available to spend (never exceeds `burst`).
    pub tokens: f64,
    /// Unix timestamp (seconds) of the most recent refill.
    pub last_refill: u64,
    /// Unix timestamp (seconds) when the last API response was recorded, if any.
    pub last_response_time: Option<u64>,
    /// Sustained request rate in requests per second.
    pub rate: f64,
    /// Maximum burst capacity of the bucket.
    pub burst: u32,
    /// Whether the initial burst allowance has already been spent.
    pub initial_burst_used: bool,
}
97
/// In-memory rate limiter that manages token buckets for different endpoints
/// Thread-safe but not cross-process
///
/// Buckets are created lazily per endpoint identifier on the first `wait`
/// call; the map is shared (via `Arc`) with the guards handed out by `wait`.
pub struct RateLimiter {
    // Endpoint identifier -> bucket state. A tokio Mutex so it can be
    // locked from async contexts without blocking the executor.
    buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
}
103
104impl RateLimiter {
105    pub fn new() -> Self {
106        Self {
107            buckets: Arc::new(Mutex::new(HashMap::new())),
108        }
109    }
110
111    /// Wait for a token to become available for the given endpoint and return a guard
112    /// When the guard is dropped, record_response will be called automatically
113    #[must_use = "The returned guard must be held until the API response is received"]
114    pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
115        self._wait_for_token(identifier, rate, burst).await?;
116
117        Ok(RateLimitGuard::new(
118            self.buckets.clone(),
119            identifier.to_string(),
120            true, // auto_record enabled
121        ))
122    }
123
124    async fn _wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
125        loop {
126            let mut buckets = self.buckets.lock().await;
127            let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
128
129            // Get or create bucket for this endpoint
130            let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
131                log::debug!("Creating new token bucket for endpoint: {}", identifier);
132                TokenBucketState {
133                    tokens: burst as f64,
134                    last_refill: now,
135                    last_response_time: None,
136                    rate: rate,
137                    burst: burst,
138                    initial_burst_used: false,
139                }
140            });
141
142            // Update bucket configuration if changed
143            if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
144                log::info!(
145                    "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
146                    identifier,
147                    bucket.rate,
148                    rate,
149                    bucket.burst,
150                    burst
151                );
152                bucket.rate = rate;
153                bucket.burst = burst;
154            }
155
156            // Calculate token refill based on appropriate time reference
157            let refill_from_time = if bucket.initial_burst_used {
158                // After initial burst, use response time for consistent interval enforcement
159                bucket.last_response_time.unwrap_or(bucket.last_refill)
160            } else {
161                // During initial burst, use last_refill
162                bucket.last_refill
163            };
164
165            let time_passed = now.saturating_sub(refill_from_time) as f64;
166            let tokens_to_add = time_passed * bucket.rate;
167            let new_tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
168
169            // Reset burst state if bucket refilled to near capacity
170            if new_tokens >= (bucket.burst as f64 - 0.5) && bucket.initial_burst_used {
171                bucket.initial_burst_used = false;
172                bucket.last_response_time = None;
173                log::debug!(
174                    "Bucket {} refilled to near capacity ({:.1}), resetting burst state",
175                    identifier,
176                    new_tokens
177                );
178            }
179
180            bucket.tokens = new_tokens;
181            bucket.last_refill = now;
182
183            // Enforce minimum interval after initial burst
184            if bucket.initial_burst_used {
185                if let Some(last_response_time) = bucket.last_response_time {
186                    let minimum_interval = 1.0 / bucket.rate;
187                    let time_since_response = now.saturating_sub(last_response_time) as f64;
188
189                    if time_since_response < minimum_interval {
190                        let wait_seconds = minimum_interval - time_since_response;
191                        log::debug!(
192                            "Enforcing minimum interval for {}: waiting {:.3}s since last response",
193                            identifier,
194                            wait_seconds
195                        );
196
197                        drop(buckets);
198                        sleep(Duration::from_secs_f64(wait_seconds)).await;
199                        continue;
200                    }
201                }
202            }
203
204            log::debug!(
205            "Endpoint {}: {:.1} tokens available, rate: {}/s, burst: {}, initial_burst_used: {}",
206            identifier,
207            bucket.tokens,
208            bucket.rate,
209            bucket.burst,
210            bucket.initial_burst_used
211        );
212
213            // Check if we have a token available
214            if bucket.tokens >= 1.0 {
215                bucket.tokens -= 1.0;
216
217                // For burst=1, immediately mark initial_burst_used
218                if bucket.burst == 1 || bucket.tokens <= 0.0 {
219                    bucket.initial_burst_used = true;
220                    log::debug!(
221                        "Initial burst capacity exhausted for {} (tokens: {:.1}, burst: {})",
222                        identifier,
223                        bucket.tokens,
224                        bucket.burst
225                    );
226                }
227
228                log::debug!(
229                    "Token consumed for {}, {:.1} tokens remaining, initial_burst_used: {}",
230                    identifier,
231                    bucket.tokens,
232                    bucket.initial_burst_used
233                );
234
235                return Ok(());
236            }
237
238            // Calculate wait time for next token
239            let tokens_needed = 1.0 - bucket.tokens;
240            let wait_seconds = tokens_needed / bucket.rate;
241            log::debug!(
242                "Rate limit reached for {}, need to wait {:.2}s for next token",
243                identifier,
244                wait_seconds
245            );
246
247            // Mark initial burst as exhausted (if not already marked)
248            if !bucket.initial_burst_used {
249                bucket.initial_burst_used = true;
250                log::debug!(
251                    "Marking initial burst as used for {} due to token shortage",
252                    identifier
253                );
254            }
255
256            let wait_duration = Duration::from_secs_f64(wait_seconds.max(0.1));
257
258            drop(buckets);
259            sleep(wait_duration).await;
260        }
261    }
262
263    /// Check if a token is available without consuming it
264    pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
265        let mut buckets = self.buckets.lock().await;
266        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
267
268        if let Some(bucket) = buckets.get_mut(identifier) {
269            // Refill tokens
270            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
271            let tokens_to_add = time_passed * bucket.rate;
272            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
273            bucket.last_refill = now;
274
275            Ok(bucket.tokens >= 1.0)
276        } else {
277            // No bucket exists, so we can create one with full capacity
278            Ok(true)
279        }
280    }
281
282    /// Get current token status for all endpoints
283    /// Returns (tokens, rate, burst) for each endpoint
284    pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
285        let mut buckets = self.buckets.lock().await;
286        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
287        let mut status = HashMap::new();
288
289        for (endpoint_key, bucket) in buckets.iter_mut() {
290            // Refill tokens before reporting status
291            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
292            let tokens_to_add = time_passed * bucket.rate;
293            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
294            bucket.last_refill = now;
295
296            status.insert(
297                endpoint_key.clone(),
298                (bucket.tokens, bucket.rate, bucket.burst),
299            );
300        }
301
302        Ok(status)
303    }
304
305    /// Reset all rate limiting state (useful for testing)
306    pub async fn reset(&self) {
307        let mut buckets = self.buckets.lock().await;
308        buckets.clear();
309        log::debug!("Rate limiter state reset");
310    }
311
312    /// Get the number of active buckets
313    pub async fn active_buckets_count(&self) -> usize {
314        let buckets = self.buckets.lock().await;
315        buckets.len()
316    }
317}
318
319impl Default for RateLimiter {
320    fn default() -> Self {
321        Self::new()
322    }
323}