1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
8use crate::client::ApiEndpoint;
9
10#[must_use = "RateLimitGuard must be held until the API response is received"]
12pub struct RateLimitGuard {
13 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
14 identifier: String,
15 auto_record: bool,
16}
17
18impl RateLimitGuard {
19 fn new(
20 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
21 identifier: String,
22 auto_record: bool,
23 ) -> Self {
24 Self {
25 rate_limiter,
26 identifier,
27 auto_record,
28 }
29 }
30
31 pub async fn mark_response(mut self) {
34 if self.auto_record {
35 let mut buckets = self.rate_limiter.lock().await;
36 let now = SystemTime::now()
37 .duration_since(UNIX_EPOCH)
38 .unwrap_or_default()
39 .as_secs();
40
41 if let Some(bucket) = buckets.get_mut(&self.identifier) {
42 bucket.last_response_time = Some(now);
43 log::trace!(
44 "Manually recorded response time for {}: {}",
45 self.identifier,
46 now
47 );
48 }
49 }
50
51 self.auto_record = false;
53 }
54}
55
56impl Drop for RateLimitGuard {
57 fn drop(&mut self) {
58 if !self.auto_record {
59 return;
60 }
61
62 let rate_limiter = self.rate_limiter.clone();
63 let identifier = self.identifier.clone();
64
65 tokio::spawn(async move {
67 let mut buckets = rate_limiter.lock().await;
68 let now = SystemTime::now()
69 .duration_since(UNIX_EPOCH)
70 .unwrap_or_default()
71 .as_secs();
72
73 if let Some(bucket) = buckets.get_mut(&identifier) {
74 bucket.last_response_time = Some(now);
76 log::trace!("Auto-recorded response time for {}: {}", identifier, now);
77 } else {
78 log::warn!(
79 "Attempted to auto-record response for unknown endpoint: {}",
80 identifier
81 );
82 }
83 });
84 }
85}
86
/// Per-endpoint token-bucket bookkeeping maintained by [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Tokens currently available; fractional amounts accumulate between refills.
    pub tokens: f64,
    /// Unix time, in whole seconds, of the most recent refill.
    pub last_refill: u64,
    /// Unix time, in whole seconds, of the last recorded response, if any.
    pub last_response_time: Option<u64>,
    /// Refill rate in tokens per second.
    pub rate: f64,
    /// Bucket capacity; refills never raise `tokens` above this value.
    pub burst: u32,
    /// Set once the initial burst has been spent and the bucket is rate-paced.
    pub initial_burst_used: bool,
}
97
98pub struct RateLimiter {
101 buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
102}
103
104impl RateLimiter {
105 pub fn new() -> Self {
106 Self {
107 buckets: Arc::new(Mutex::new(HashMap::new())),
108 }
109 }
110
111 #[must_use = "The returned guard must be held until the API response is received"]
114 pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
115 self._wait_for_token(identifier, rate, burst).await?;
116
117 Ok(RateLimitGuard::new(
118 self.buckets.clone(),
119 identifier.to_string(),
120 true, ))
122 }
123
124 async fn _wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
125 loop {
126 let mut buckets = self.buckets.lock().await;
127 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
128
129 let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
131 log::debug!("Creating new token bucket for endpoint: {}", identifier);
132 TokenBucketState {
133 tokens: burst as f64,
134 last_refill: now,
135 last_response_time: None,
136 rate: rate,
137 burst: burst,
138 initial_burst_used: false,
139 }
140 });
141
142 if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
144 log::info!(
145 "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
146 identifier,
147 bucket.rate,
148 rate,
149 bucket.burst,
150 burst
151 );
152 bucket.rate = rate;
153 bucket.burst = burst;
154 }
155
156 let refill_from_time = if bucket.initial_burst_used {
158 bucket.last_response_time.unwrap_or(bucket.last_refill)
160 } else {
161 bucket.last_refill
163 };
164
165 let time_passed = now.saturating_sub(refill_from_time) as f64;
166 let tokens_to_add = time_passed * bucket.rate;
167 let new_tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
168
169 if new_tokens >= (bucket.burst as f64 - 0.5) && bucket.initial_burst_used {
171 bucket.initial_burst_used = false;
172 bucket.last_response_time = None;
173 log::debug!(
174 "Bucket {} refilled to near capacity ({:.1}), resetting burst state",
175 identifier,
176 new_tokens
177 );
178 }
179
180 bucket.tokens = new_tokens;
181 bucket.last_refill = now;
182
183 if bucket.initial_burst_used {
185 if let Some(last_response_time) = bucket.last_response_time {
186 let minimum_interval = 1.0 / bucket.rate;
187 let time_since_response = now.saturating_sub(last_response_time) as f64;
188
189 if time_since_response < minimum_interval {
190 let wait_seconds = minimum_interval - time_since_response;
191 log::debug!(
192 "Enforcing minimum interval for {}: waiting {:.3}s since last response",
193 identifier,
194 wait_seconds
195 );
196
197 drop(buckets);
198 sleep(Duration::from_secs_f64(wait_seconds)).await;
199 continue;
200 }
201 }
202 }
203
204 log::debug!(
205 "Endpoint {}: {:.1} tokens available, rate: {}/s, burst: {}, initial_burst_used: {}",
206 identifier,
207 bucket.tokens,
208 bucket.rate,
209 bucket.burst,
210 bucket.initial_burst_used
211 );
212
213 if bucket.tokens >= 1.0 {
215 bucket.tokens -= 1.0;
216
217 if bucket.burst == 1 || bucket.tokens <= 0.0 {
219 bucket.initial_burst_used = true;
220 log::debug!(
221 "Initial burst capacity exhausted for {} (tokens: {:.1}, burst: {})",
222 identifier,
223 bucket.tokens,
224 bucket.burst
225 );
226 }
227
228 log::debug!(
229 "Token consumed for {}, {:.1} tokens remaining, initial_burst_used: {}",
230 identifier,
231 bucket.tokens,
232 bucket.initial_burst_used
233 );
234
235 return Ok(());
236 }
237
238 let tokens_needed = 1.0 - bucket.tokens;
240 let wait_seconds = tokens_needed / bucket.rate;
241 log::debug!(
242 "Rate limit reached for {}, need to wait {:.2}s for next token",
243 identifier,
244 wait_seconds
245 );
246
247 if !bucket.initial_burst_used {
249 bucket.initial_burst_used = true;
250 log::debug!(
251 "Marking initial burst as used for {} due to token shortage",
252 identifier
253 );
254 }
255
256 let wait_duration = Duration::from_secs_f64(wait_seconds.max(0.1));
257
258 drop(buckets);
259 sleep(wait_duration).await;
260 }
261 }
262
263 pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
265 let mut buckets = self.buckets.lock().await;
266 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
267
268 if let Some(bucket) = buckets.get_mut(identifier) {
269 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
271 let tokens_to_add = time_passed * bucket.rate;
272 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
273 bucket.last_refill = now;
274
275 Ok(bucket.tokens >= 1.0)
276 } else {
277 Ok(true)
279 }
280 }
281
282 pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
285 let mut buckets = self.buckets.lock().await;
286 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
287 let mut status = HashMap::new();
288
289 for (endpoint_key, bucket) in buckets.iter_mut() {
290 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
292 let tokens_to_add = time_passed * bucket.rate;
293 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
294 bucket.last_refill = now;
295
296 status.insert(
297 endpoint_key.clone(),
298 (bucket.tokens, bucket.rate, bucket.burst),
299 );
300 }
301
302 Ok(status)
303 }
304
305 pub async fn reset(&self) {
307 let mut buckets = self.buckets.lock().await;
308 buckets.clear();
309 log::debug!("Rate limiter state reset");
310 }
311
312 pub async fn active_buckets_count(&self) -> usize {
314 let buckets = self.buckets.lock().await;
315 buckets.len()
316 }
317}
318
319impl Default for RateLimiter {
320 fn default() -> Self {
321 Self::new()
322 }
323}