//! amazon_spapi/client/rate_limiter.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
8use crate::client::ApiEndpoint;
9
/// RAII guard that automatically records response when dropped
///
/// Returned by `RateLimiter::wait`. Hold the guard for the duration of the
/// API call; when it is dropped (or consumed via `mark_response`) the
/// response time is written back into the shared bucket table so that
/// minimum-interval enforcement can use it.
#[must_use = "RateLimitGuard must be held until the API response is received"]
pub struct RateLimitGuard {
    // Shared token-bucket table, keyed by endpoint identifier string.
    rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
    // Endpoint key this guard is bound to.
    identifier: String,
    // When true, Drop spawns a task to record the response time;
    // mark_response() flips this to false to take over recording manually.
    auto_record: bool,
}
17
18impl RateLimitGuard {
19    fn new(
20        rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
21        identifier: String,
22        auto_record: bool,
23    ) -> Self {
24        Self {
25            rate_limiter,
26            identifier,
27            auto_record,
28        }
29    }
30
31    /// Manually mark that the API response was received
32    /// This will record the response time and prevent automatic recording on drop
33    pub async fn mark_response(mut self) {
34        if self.auto_record {
35            let mut buckets = self.rate_limiter.lock().await;
36            let now = SystemTime::now()
37                .duration_since(UNIX_EPOCH)
38                .unwrap_or_default()
39                .as_secs();
40
41            if let Some(bucket) = buckets.get_mut(&self.identifier) {
42                bucket.last_response_time = Some(now);
43                log::trace!(
44                    "Manually recorded response time for {}: {}",
45                    self.identifier,
46                    now
47                );
48            }
49        }
50
51        // Prevent automatic recording on drop
52        self.auto_record = false;
53    }
54}
55
56impl Drop for RateLimitGuard {
57    fn drop(&mut self) {
58        if !self.auto_record {
59            return;
60        }
61
62        let rate_limiter = self.rate_limiter.clone();
63        let identifier = self.identifier.clone();
64
65        // Spawn a task to record the response since Drop can't be async
66        tokio::spawn(async move {
67            let mut buckets = rate_limiter.lock().await;
68            let now = SystemTime::now()
69                .duration_since(UNIX_EPOCH)
70                .unwrap_or_default()
71                .as_secs();
72
73            if let Some(bucket) = buckets.get_mut(&identifier) {
74                // Only record response time after initial burst is exhausted
75                if bucket.initial_burst_used || bucket.tokens < (bucket.burst as f64 * 0.5) {
76                    bucket.last_response_time = Some(now);
77                    log::trace!("Auto-recorded response time for {}: {}", identifier, now);
78                } else {
79                    log::trace!(
80                        "Skipping response time recording for {} (still in initial burst)",
81                        identifier
82                    );
83                }
84            } else {
85                log::warn!(
86                    "Attempted to auto-record response for unknown endpoint: {}",
87                    identifier
88                );
89            }
90        });
91    }
92}
93
/// State of a token bucket for rate limiting
///
/// One instance exists per endpoint identifier; all timestamps are whole
/// Unix seconds, so sub-second timing is not tracked.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    pub tokens: f64,                     // Currently available tokens (fractional during refill)
    pub last_refill: u64,                // Unix timestamp in seconds
    pub last_response_time: Option<u64>, // Unix timestamp in seconds when last response was received
    pub rate: f64,                       // requests per second
    pub burst: u32,                      // maximum burst capacity
    pub initial_burst_used: bool,        // Track if initial burst has been used
}
104
/// In-memory rate limiter that manages token buckets for different endpoints
/// Thread-safe but not cross-process
pub struct RateLimiter {
    // Bucket table keyed by endpoint identifier; shared with every
    // RateLimitGuard handed out by `wait()`.
    buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
}
110
111impl RateLimiter {
112    pub fn new() -> Self {
113        Self {
114            buckets: Arc::new(Mutex::new(HashMap::new())),
115        }
116    }
117
118    /// Wait for a token to become available for the given endpoint and return a guard
119    /// When the guard is dropped, record_response will be called automatically
120    #[must_use = "The returned guard must be held until the API response is received"]
121    pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
122        self._wait_for_token(identifier, rate, burst).await?;
123
124        Ok(RateLimitGuard::new(
125            self.buckets.clone(),
126            identifier.to_string(),
127            true, // auto_record enabled
128        ))
129    }
130
131    /// Wait for a token to become available for the given endpoint
132    /// This method will block until a token is available
133    #[deprecated(since = "0.1.4", note = "Use `wait()` instead.")]
134    pub async fn wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
135        self._wait_for_token(identifier, rate, burst).await
136    }
137
138    async fn _wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
139        loop {
140            let mut buckets = self.buckets.lock().await;
141            let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
142
143            // Get or create bucket for this endpoint
144            let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
145                log::debug!("Creating new token bucket for endpoint: {}", identifier);
146                TokenBucketState {
147                    tokens: burst as f64, // Start with full burst capacity
148                    last_refill: now,
149                    last_response_time: None, // No response received yet
150                    rate: rate,
151                    burst: burst,
152                    initial_burst_used: false, // Mark that we haven't used the initial burst yet
153                }
154            });
155
156            // Update bucket configuration if endpoint configuration changed
157            if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
158                log::info!(
159                    "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
160                    identifier,
161                    bucket.rate,
162                    rate,
163                    bucket.burst,
164                    burst
165                );
166                bucket.rate = rate;
167                bucket.burst = burst;
168            }
169
170            // 始终进行token refill计算
171            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
172            let tokens_to_add = time_passed * bucket.rate;
173            let new_tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
174
175            // 如果token恢复到了接近满容量,重置initial_burst_used标志
176            if new_tokens >= (bucket.burst as f64 - 0.5) && bucket.initial_burst_used {
177                bucket.initial_burst_used = false;
178                bucket.last_response_time = None; // 清除响应时间记录
179                log::debug!(
180                    "Bucket {} refilled to near capacity ({:.1}), resetting burst state",
181                    identifier,
182                    new_tokens
183                );
184            }
185
186            bucket.tokens = new_tokens;
187            bucket.last_refill = now;
188
189            // Only check minimum interval if initial burst has been used AND we have recent response time
190            if bucket.initial_burst_used {
191                if let Some(last_response_time) = bucket.last_response_time {
192                    let minimum_interval = 1.0 / bucket.rate;
193                    let time_since_response = now.saturating_sub(last_response_time) as f64;
194
195                    if time_since_response < minimum_interval {
196                        let wait_seconds = minimum_interval - time_since_response;
197                        log::debug!(
198                            "Enforcing minimum interval for {}: waiting {:.3}s since last response",
199                            identifier,
200                            wait_seconds
201                        );
202
203                        drop(buckets);
204                        sleep(Duration::from_secs_f64(wait_seconds)).await;
205                        continue;
206                    }
207                }
208            }
209
210            log::debug!(
211            "Endpoint {}: {:.1} tokens available, rate: {}/s, burst: {}, initial_burst_used: {}",
212            identifier,
213            bucket.tokens,
214            bucket.rate,
215            bucket.burst,
216            bucket.initial_burst_used
217        );
218
219            // Check if we have a token available
220            if bucket.tokens >= 1.0 {
221                bucket.tokens -= 1.0;
222
223                // 只有当token数量降到很低时才标记initial_burst已用完
224                if bucket.tokens <= 1.0 && !bucket.initial_burst_used {
225                    bucket.initial_burst_used = true;
226                    log::debug!(
227                        "Initial burst capacity exhausted for {} (tokens: {:.1})",
228                        identifier,
229                        bucket.tokens
230                    );
231                }
232
233                log::debug!(
234                    "Token consumed for {}, {:.1} tokens remaining, initial_burst_used: {}",
235                    identifier,
236                    bucket.tokens,
237                    bucket.initial_burst_used
238                );
239
240                return Ok(());
241            }
242
243            // Calculate wait time for next token
244            let tokens_needed = 1.0 - bucket.tokens;
245            let wait_seconds = tokens_needed / bucket.rate;
246            log::debug!(
247                "Rate limit reached for {}, need to wait {:.2}s for next token",
248                identifier,
249                wait_seconds
250            );
251
252            // Mark initial burst as exhausted (if not already marked)
253            if !bucket.initial_burst_used {
254                bucket.initial_burst_used = true;
255                log::debug!(
256                    "Marking initial burst as used for {} due to token shortage",
257                    identifier
258                );
259            }
260
261            let wait_duration = Duration::from_secs_f64(wait_seconds.max(0.1));
262
263            drop(buckets);
264            sleep(wait_duration).await;
265        }
266    }
267
268    /// Check if a token is available without consuming it
269    pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
270        let mut buckets = self.buckets.lock().await;
271        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
272
273        if let Some(bucket) = buckets.get_mut(identifier) {
274            // Refill tokens
275            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
276            let tokens_to_add = time_passed * bucket.rate;
277            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
278            bucket.last_refill = now;
279
280            Ok(bucket.tokens >= 1.0)
281        } else {
282            // No bucket exists, so we can create one with full capacity
283            Ok(true)
284        }
285    }
286
287    /// Get current token status for all endpoints
288    /// Returns (tokens, rate, burst) for each endpoint
289    pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
290        let mut buckets = self.buckets.lock().await;
291        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
292        let mut status = HashMap::new();
293
294        for (endpoint_key, bucket) in buckets.iter_mut() {
295            // Refill tokens before reporting status
296            let time_passed = now.saturating_sub(bucket.last_refill) as f64;
297            let tokens_to_add = time_passed * bucket.rate;
298            bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
299            bucket.last_refill = now;
300
301            status.insert(
302                endpoint_key.clone(),
303                (bucket.tokens, bucket.rate, bucket.burst),
304            );
305        }
306
307        Ok(status)
308    }
309
310    /// Reset all rate limiting state (useful for testing)
311    pub async fn reset(&self) {
312        let mut buckets = self.buckets.lock().await;
313        buckets.clear();
314        log::debug!("Rate limiter state reset");
315    }
316
317    /// Get the number of active buckets
318    pub async fn active_buckets_count(&self) -> usize {
319        let buckets = self.buckets.lock().await;
320        buckets.len()
321    }
322
323    /// Record that a response was received for the given endpoint
324    /// This updates the last_response_time used for enforcing minimum intervals
325    #[deprecated(since = "0.1.4", note = "Use `wait()` instead.")]
326    pub async fn record_response(&self, identifier: &str) -> Result<()> {
327        let mut buckets = self.buckets.lock().await;
328        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
329
330        if let Some(bucket) = buckets.get_mut(identifier) {
331            bucket.last_response_time = Some(now);
332            bucket.initial_burst_used = true; // Mark that we've moved past initial burst
333            log::trace!("Recorded response time for {}: {}", identifier, now);
334        } else {
335            log::warn!(
336                "Attempted to record response for unknown endpoint: {}",
337                identifier
338            );
339        }
340
341        Ok(())
342    }
343}
344
345impl Default for RateLimiter {
346    fn default() -> Self {
347        Self::new()
348    }
349}