amazon_spapi/client/
rate_limiter.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
/// RAII guard that automatically records response when dropped
#[must_use = "RateLimitGuard must be held until the API response is received"]
pub struct RateLimitGuard {
    // Shared handle to the limiter's bucket map so the guard can stamp
    // `last_response_time` on the matching bucket.
    rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
    // Endpoint key identifying which bucket this guard belongs to.
    identifier: String,
    // When true, dropping the guard records the response time; cleared by
    // `mark_response` so the response is only recorded once.
    auto_record: bool,
}
15
16impl RateLimitGuard {
17    fn new(
18        rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
19        identifier: String,
20        auto_record: bool,
21    ) -> Self {
22        Self {
23            rate_limiter,
24            identifier,
25            auto_record,
26        }
27    }
28
29    /// Manually mark that the API response was received
30    /// This will record the response time and prevent automatic recording on drop
31    pub async fn mark_response(mut self) {
32        if self.auto_record {
33            let mut buckets = self.rate_limiter.lock().await;
34            let now = SystemTime::now()
35                .duration_since(UNIX_EPOCH)
36                .unwrap_or_default()
37                .as_secs();
38
39            if let Some(bucket) = buckets.get_mut(&self.identifier) {
40                bucket.last_response_time = Some(now);
41                log::trace!(
42                    "Manually recorded response time for {}: {}",
43                    self.identifier,
44                    now
45                );
46            }
47        }
48
49        // Prevent automatic recording on drop
50        self.auto_record = false;
51    }
52}
53
54impl Drop for RateLimitGuard {
55    fn drop(&mut self) {
56        if !self.auto_record {
57            return;
58        }
59
60        let rate_limiter = self.rate_limiter.clone();
61        let identifier = self.identifier.clone();
62
63        // Spawn a task to record the response since Drop can't be async
64        tokio::spawn(async move {
65            let mut buckets = rate_limiter.lock().await;
66            let now = SystemTime::now()
67                .duration_since(UNIX_EPOCH)
68                .unwrap_or_default()
69                .as_secs();
70
71            if let Some(bucket) = buckets.get_mut(&identifier) {
72                // Always record the response time to ensure accurate rate limiting
73                bucket.last_response_time = Some(now);
74                log::trace!("Auto-recorded response time for {}: {}", identifier, now);
75            } else {
76                log::warn!(
77                    "Attempted to auto-record response for unknown endpoint: {}",
78                    identifier
79                );
80            }
81        });
82    }
83}
84
/// State of a token bucket for rate limiting
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Currently available tokens; fractional while refilling.
    pub tokens: f64,
    pub last_refill: u64,                // Unix timestamp in seconds
    pub last_response_time: Option<u64>, // Unix timestamp in seconds when last response was received
    pub rate: f64,                       // requests per second
    pub burst: u32,                      // maximum burst capacity
    pub initial_burst_used: bool,        // Track if initial burst has been used
}
95
/// In-memory rate limiter that manages token buckets for different endpoints
/// Thread-safe but not cross-process
pub struct RateLimiter {
    /// Per-endpoint token buckets, keyed by endpoint identifier.
    buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
    /// Safety factor to add buffer time between requests (1.0 = no buffer, 1.1 = 10% buffer)
    /// Default is 1.10 (10% buffer) to avoid 429 errors due to timing inconsistencies
    safety_factor: f64,
}
104
105impl RateLimiter {
106    pub fn new() -> Self {
107        Self {
108            buckets: Arc::new(Mutex::new(HashMap::new())),
109            safety_factor: 1.10, // Default 10% safety buffer
110        }
111    }
112
113    /// Create a new rate limiter with a custom safety factor
114    /// Safety factor > 1.0 adds buffer time between requests
115    /// For example, 1.1 means 10% longer wait times
116    pub fn new_with_safety_factor(safety_factor: f64) -> Self {
117        let factor = if safety_factor < 1.0 {
118            log::warn!("Safety factor {} is less than 1.0, using 1.0 instead", safety_factor);
119            1.0
120        } else {
121            safety_factor
122        };
123        
124        Self {
125            buckets: Arc::new(Mutex::new(HashMap::new())),
126            safety_factor: factor,
127        }
128    }
129
130    /// Set the safety factor for rate limiting
131    /// This adds a buffer to prevent 429 errors due to timing inconsistencies
132    pub fn set_safety_factor(&mut self, safety_factor: f64) {
133        if safety_factor < 1.0 {
134            log::warn!("Safety factor {} is less than 1.0, using 1.0 instead", safety_factor);
135            self.safety_factor = 1.0;
136        } else {
137            self.safety_factor = safety_factor;
138            log::info!("Rate limiter safety factor set to {}", safety_factor);
139        }
140    }
141
142    /// Get the current safety factor
143    pub fn get_safety_factor(&self) -> f64 {
144        self.safety_factor
145    }
146
147    /// Wait for a token to become available for the given endpoint and return a guard
148    /// When the guard is dropped, record_response will be called automatically
149    #[must_use = "The returned guard must be held until the API response is received"]
150    pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
151        self._wait_for_token(identifier, rate, burst).await?;
152
153        Ok(RateLimitGuard::new(
154            self.buckets.clone(),
155            identifier.to_string(),
156            true, // auto_record enabled
157        ))
158    }
159
160    async fn _wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
161        loop {
162            let mut buckets = self.buckets.lock().await;
163            let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
164
165            // Get or create bucket for this endpoint
166            let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
167                log::debug!("Creating new token bucket for endpoint: {}", identifier);
168                TokenBucketState {
169                    tokens: burst as f64,
170                    last_refill: now,
171                    last_response_time: None,
172                    rate: rate,
173                    burst: burst,
174                    initial_burst_used: false,
175                }
176            });
177
178            // Update bucket configuration if changed
179            if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
180                log::info!(
181                    "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
182                    identifier,
183                    bucket.rate,
184                    rate,
185                    bucket.burst,
186                    burst
187                );
188                bucket.rate = rate;
189                bucket.burst = burst;
190            }
191
192            // Calculate token refill based on appropriate time reference
193            let refill_from_time = if bucket.initial_burst_used {
194                // After initial burst, use response time for consistent interval enforcement
195                bucket.last_response_time.unwrap_or(bucket.last_refill)
196            } else {
197                // During initial burst, use last_refill
198                bucket.last_refill
199            };
200
201            let time_passed = now.saturating_sub(refill_from_time) as f64;
202            let tokens_to_add = time_passed * bucket.rate;
203            let new_tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
204
205            // Reset burst state if bucket refilled to near capacity
206            if new_tokens >= (bucket.burst as f64 - 0.5) && bucket.initial_burst_used {
207                bucket.initial_burst_used = false;
208                bucket.last_response_time = None;
209                log::debug!(
210                    "Bucket {} refilled to near capacity ({:.1}), resetting burst state",
211                    identifier,
212                    new_tokens
213                );
214            }
215
216            bucket.tokens = new_tokens;
217            bucket.last_refill = now;
218
219            // Enforce minimum interval after initial burst
220            if bucket.initial_burst_used {
221                if let Some(last_response_time) = bucket.last_response_time {
222                    let base_minimum_interval = 1.0 / bucket.rate;
223                    // Apply safety factor to add buffer time
224                    let minimum_interval = base_minimum_interval * self.safety_factor;
225                    let time_since_response = now.saturating_sub(last_response_time) as f64;
226
227                    if time_since_response < minimum_interval {
228                        let wait_seconds = minimum_interval - time_since_response;
229                        log::debug!(
230                            "Enforcing minimum interval for {} (safety factor: {:.2}): waiting {:.3}s since last response",
231                            identifier,
232                            self.safety_factor,
233                            wait_seconds
234                        );
235
236                        drop(buckets);
237                        sleep(Duration::from_secs_f64(wait_seconds)).await;
238                        continue;
239                    }
240                }
241            }
242
243            log::debug!(
244            "Endpoint {}: {:.1} tokens available, rate: {}/s, burst: {}, initial_burst_used: {}",
245            identifier,
246            bucket.tokens,
247            bucket.rate,
248            bucket.burst,
249            bucket.initial_burst_used
250        );
251
252            // Check if we have a token available
253            if bucket.tokens >= 1.0 {
254                bucket.tokens -= 1.0;
255
256                // For burst=1, immediately mark initial_burst_used
257                if bucket.burst == 1 || bucket.tokens <= 0.0 {
258                    bucket.initial_burst_used = true;
259                    log::debug!(
260                        "Initial burst capacity exhausted for {} (tokens: {:.1}, burst: {})",
261                        identifier,
262                        bucket.tokens,
263                        bucket.burst
264                    );
265                }
266
267                log::debug!(
268                    "Token consumed for {}, {:.1} tokens remaining, initial_burst_used: {}",
269                    identifier,
270                    bucket.tokens,
271                    bucket.initial_burst_used
272                );
273
274                return Ok(());
275            }
276
277            // Calculate wait time for next token
278            let tokens_needed = 1.0 - bucket.tokens;
279            let base_wait_seconds = tokens_needed / bucket.rate;
280            // Apply safety factor to wait time
281            let wait_seconds = base_wait_seconds * self.safety_factor;
282            log::debug!(
283                "Rate limit reached for {}, need to wait {:.2}s (base: {:.2}s, safety factor: {:.2}) for next token",
284                identifier,
285                wait_seconds,
286                base_wait_seconds,
287                self.safety_factor
288            );
289
290            // Mark initial burst as exhausted (if not already marked)
291            if !bucket.initial_burst_used {
292                bucket.initial_burst_used = true;
293                log::debug!(
294                    "Marking initial burst as used for {} due to token shortage",
295                    identifier
296                );
297            }
298
299            let wait_duration = Duration::from_secs_f64(wait_seconds.max(0.1));
300
301            drop(buckets);
302            sleep(wait_duration).await;
303        }
304    }
305
306    /// Check if a token is available without consuming it
307    pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
308        let mut buckets = self.buckets.lock().await;
309        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
310
311        if let Some(bucket) = buckets.get_mut(identifier) {
312            // Refill tokens
313            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
314            let tokens_to_add = time_passed * bucket.rate;
315            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
316            bucket.last_refill = now;
317
318            Ok(bucket.tokens >= 1.0)
319        } else {
320            // No bucket exists, so we can create one with full capacity
321            Ok(true)
322        }
323    }
324
325    /// Get current token status for all endpoints
326    /// Returns (tokens, rate, burst) for each endpoint
327    pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
328        let mut buckets = self.buckets.lock().await;
329        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
330        let mut status = HashMap::new();
331
332        for (endpoint_key, bucket) in buckets.iter_mut() {
333            // Refill tokens before reporting status
334            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
335            let tokens_to_add = time_passed * bucket.rate;
336            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
337            bucket.last_refill = now;
338
339            status.insert(
340                endpoint_key.clone(),
341                (bucket.tokens, bucket.rate, bucket.burst),
342            );
343        }
344
345        Ok(status)
346    }
347
348    /// Reset all rate limiting state (useful for testing)
349    pub async fn reset(&self) {
350        let mut buckets = self.buckets.lock().await;
351        buckets.clear();
352        log::debug!("Rate limiter state reset");
353    }
354
355    /// Get the number of active buckets
356    pub async fn active_buckets_count(&self) -> usize {
357        let buckets = self.buckets.lock().await;
358        buckets.len()
359    }
360}
361
362impl Default for RateLimiter {
363    fn default() -> Self {
364        Self::new()
365    }
366}