amazon_spapi/client/
rate_limiter.rs1use anyhow::Result;
2use std::collections::HashMap;
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4use tokio::time::sleep;
5use std::sync::Arc;
6use tokio::sync::Mutex;
7
8use crate::client::ApiEndpoint;
9
/// State of a single token bucket, tracked per endpoint identifier.
///
/// All timestamps are whole seconds since the Unix epoch.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Tokens currently available; fractional because refills are `elapsed_seconds * rate`.
    pub tokens: f64,
    /// Unix time (seconds) at which `tokens` was last replenished.
    pub last_refill: u64,
    /// Unix time (seconds) of the most recently recorded response, or `None` if no
    /// response has been recorded for this bucket yet.
    pub last_response_time: Option<u64>,
    /// Refill rate in tokens per second.
    pub rate: f64,
    /// Upper bound on `tokens`.
    pub burst: u32,
}
19
20pub struct RateLimiter {
23 buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
24}
25
26impl RateLimiter {
27 pub fn new() -> Self {
28 Self {
29 buckets: Arc::new(Mutex::new(HashMap::new())),
30 }
31 }
32
33 pub async fn wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
36 loop {
37 {
38 let mut buckets = self.buckets.lock().await;
39 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
40
41 let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
43 log::debug!("Creating new token bucket for endpoint: {}", identifier);
44 TokenBucketState {
45 tokens: burst as f64, last_refill: now,
47 last_response_time: None, rate: rate,
49 burst: burst,
50 }
51 });
52
53 if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
55 log::info!("Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
56 identifier, bucket.rate, rate, bucket.burst, burst);
57 bucket.rate = rate;
58 bucket.burst = burst;
59 }
60
61 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
63 let tokens_to_add = time_passed * bucket.rate; bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
65 bucket.last_refill = now;
66
67 if let Some(last_response_time) = bucket.last_response_time {
69 let minimum_interval = 1.0 / bucket.rate; let time_since_response = now.saturating_sub(last_response_time) as f64;
71
72 if time_since_response < minimum_interval {
73 let wait_seconds = minimum_interval - time_since_response;
74 log::debug!("Enforcing minimum interval for {}: waiting {:.3}s since last response",
75 identifier, wait_seconds);
76
77 drop(buckets);
79 sleep(Duration::from_secs_f64(wait_seconds)).await;
80 continue; }
82 }
83
84 log::trace!("Endpoint {}: {:.2} tokens available, rate: {}/s, burst: {}",
85 identifier, bucket.tokens, bucket.rate, bucket.burst);
86
87 if bucket.tokens >= 1.0 {
89 bucket.tokens -= 1.0;
90
91 log::debug!("Token consumed for {}, {:.2} tokens remaining",
92 identifier, bucket.tokens);
93
94 return Ok(());
95 }
96
97 let wait_time = Duration::from_secs_f64(1.0 / bucket.rate);
99 log::debug!("Rate limit reached for {}, waiting {:?}",
100 identifier, wait_time);
101 } sleep(Duration::from_millis(100)).await; }
106 }
107
108 pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
110 let mut buckets = self.buckets.lock().await;
111 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
112
113 if let Some(bucket) = buckets.get_mut(identifier) {
114 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
116 let tokens_to_add = time_passed * bucket.rate;
117 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
118 bucket.last_refill = now;
119
120 Ok(bucket.tokens >= 1.0)
121 } else {
122 Ok(true)
124 }
125 }
126
127 pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
130 let mut buckets = self.buckets.lock().await;
131 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
132 let mut status = HashMap::new();
133
134 for (endpoint_key, bucket) in buckets.iter_mut() {
135 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
137 let tokens_to_add = time_passed * bucket.rate;
138 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
139 bucket.last_refill = now;
140
141 status.insert(endpoint_key.clone(), (bucket.tokens, bucket.rate, bucket.burst));
142 }
143
144 Ok(status)
145 }
146
147 pub async fn reset(&self) {
149 let mut buckets = self.buckets.lock().await;
150 buckets.clear();
151 log::debug!("Rate limiter state reset");
152 }
153
154 pub async fn active_buckets_count(&self) -> usize {
156 let buckets = self.buckets.lock().await;
157 buckets.len()
158 }
159
160 pub async fn record_response(&self, identifier: &str) -> Result<()> {
163 let mut buckets = self.buckets.lock().await;
164 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
165
166 if let Some(bucket) = buckets.get_mut(identifier) {
167 bucket.last_response_time = Some(now);
168 log::trace!("Recorded response time for {}: {}", identifier, now);
169 } else {
170 log::warn!("Attempted to record response for unknown endpoint: {}", identifier);
171 }
172
173 Ok(())
174 }
175}
176
177impl Default for RateLimiter {
178 fn default() -> Self {
179 Self::new()
180 }
181}