amazon_spapi/client/
rate_limiter.rs1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
8use crate::client::ApiEndpoint;
9
10pub struct RateLimitGuard {
12 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
13 identifier: String,
14}
15
16impl RateLimitGuard {
17 fn new(
18 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
19 identifier: String,
20 ) -> Self {
21 Self {
22 rate_limiter,
23 identifier,
24 }
25 }
26}
27
28impl Drop for RateLimitGuard {
29 fn drop(&mut self) {
30 let rate_limiter = self.rate_limiter.clone();
31 let identifier = self.identifier.clone();
32
33 tokio::spawn(async move {
35 let mut buckets = rate_limiter.lock().await;
36 let now = SystemTime::now()
37 .duration_since(UNIX_EPOCH)
38 .unwrap_or_default()
39 .as_secs();
40
41 if let Some(bucket) = buckets.get_mut(&identifier) {
42 bucket.last_response_time = Some(now);
43 log::trace!("Auto-recorded response time for {}: {}", identifier, now);
44 } else {
45 log::warn!(
46 "Attempted to auto-record response for unknown endpoint: {}",
47 identifier
48 );
49 }
50 });
51 }
52}
53
/// Mutable state of one endpoint's token bucket.
///
/// As used by `RateLimiter`, timestamps are whole seconds since the Unix epoch.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Tokens currently available; fractional amounts accumulate between refills.
    pub tokens: f64,
    /// Timestamp of the most recent refill.
    pub last_refill: u64,
    /// Timestamp of the most recently recorded response, if one has been recorded.
    pub last_response_time: Option<u64>,
    /// Refill rate in tokens per second.
    pub rate: f64,
    /// Maximum number of tokens the bucket may hold.
    pub burst: u32,
}
63
64pub struct RateLimiter {
67 buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
68}
69
70impl RateLimiter {
71 pub fn new() -> Self {
72 Self {
73 buckets: Arc::new(Mutex::new(HashMap::new())),
74 }
75 }
76
77 #[must_use = "The returned guard must be held until the API response is received"]
80 pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
81 self.wait_for_token(identifier, rate, burst).await?;
82
83 Ok(RateLimitGuard::new(
84 self.buckets.clone(),
85 identifier.to_string(),
86 ))
87 }
88
89 pub async fn wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
92 loop {
93 {
94 let mut buckets = self.buckets.lock().await;
95 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
96
97 let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
99 log::debug!("Creating new token bucket for endpoint: {}", identifier);
100 TokenBucketState {
101 tokens: burst as f64, last_refill: now,
103 last_response_time: None, rate: rate,
105 burst: burst,
106 }
107 });
108
109 if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
111 log::info!(
112 "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
113 identifier,
114 bucket.rate,
115 rate,
116 bucket.burst,
117 burst
118 );
119 bucket.rate = rate;
120 bucket.burst = burst;
121 }
122
123 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
125 let tokens_to_add = time_passed * bucket.rate; bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
127 bucket.last_refill = now;
128
129 if let Some(last_response_time) = bucket.last_response_time {
131 let minimum_interval = 1.0 / bucket.rate; let time_since_response = now.saturating_sub(last_response_time) as f64;
133
134 if time_since_response < minimum_interval {
135 let wait_seconds = minimum_interval - time_since_response;
136 log::debug!(
137 "Enforcing minimum interval for {}: waiting {:.3}s since last response",
138 identifier,
139 wait_seconds
140 );
141
142 drop(buckets);
144 sleep(Duration::from_secs_f64(wait_seconds)).await;
145 continue; }
147 }
148
149 log::trace!(
150 "Endpoint {}: {:.2} tokens available, rate: {}/s, burst: {}",
151 identifier,
152 bucket.tokens,
153 bucket.rate,
154 bucket.burst
155 );
156
157 if bucket.tokens >= 1.0 {
159 bucket.tokens -= 1.0;
160
161 log::debug!(
162 "Token consumed for {}, {:.2} tokens remaining",
163 identifier,
164 bucket.tokens
165 );
166
167 return Ok(());
168 }
169
170 let wait_time = Duration::from_secs_f64(1.0 / bucket.rate);
172 log::debug!(
173 "Rate limit reached for {}, waiting {:?}",
174 identifier,
175 wait_time
176 );
177 } sleep(Duration::from_millis(100)).await; }
182 }
183
184 pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
186 let mut buckets = self.buckets.lock().await;
187 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
188
189 if let Some(bucket) = buckets.get_mut(identifier) {
190 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
192 let tokens_to_add = time_passed * bucket.rate;
193 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
194 bucket.last_refill = now;
195
196 Ok(bucket.tokens >= 1.0)
197 } else {
198 Ok(true)
200 }
201 }
202
203 pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
206 let mut buckets = self.buckets.lock().await;
207 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
208 let mut status = HashMap::new();
209
210 for (endpoint_key, bucket) in buckets.iter_mut() {
211 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
213 let tokens_to_add = time_passed * bucket.rate;
214 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
215 bucket.last_refill = now;
216
217 status.insert(
218 endpoint_key.clone(),
219 (bucket.tokens, bucket.rate, bucket.burst),
220 );
221 }
222
223 Ok(status)
224 }
225
226 pub async fn reset(&self) {
228 let mut buckets = self.buckets.lock().await;
229 buckets.clear();
230 log::debug!("Rate limiter state reset");
231 }
232
233 pub async fn active_buckets_count(&self) -> usize {
235 let buckets = self.buckets.lock().await;
236 buckets.len()
237 }
238
239 pub async fn record_response(&self, identifier: &str) -> Result<()> {
242 let mut buckets = self.buckets.lock().await;
243 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
244
245 if let Some(bucket) = buckets.get_mut(identifier) {
246 bucket.last_response_time = Some(now);
247 log::trace!("Recorded response time for {}: {}", identifier, now);
248 } else {
249 log::warn!(
250 "Attempted to record response for unknown endpoint: {}",
251 identifier
252 );
253 }
254
255 Ok(())
256 }
257}
258
259impl Default for RateLimiter {
260 fn default() -> Self {
261 Self::new()
262 }
263}