1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
8use crate::client::ApiEndpoint;
9
10#[must_use = "RateLimitGuard must be held until the API response is received"]
12pub struct RateLimitGuard {
13 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
14 identifier: String,
15 auto_record: bool,
16}
17
18impl RateLimitGuard {
19 fn new(
20 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
21 identifier: String,
22 auto_record: bool,
23 ) -> Self {
24 Self {
25 rate_limiter,
26 identifier,
27 auto_record,
28 }
29 }
30
31 pub async fn mark_response(mut self) {
34 if self.auto_record {
35 let mut buckets = self.rate_limiter.lock().await;
36 let now = SystemTime::now()
37 .duration_since(UNIX_EPOCH)
38 .unwrap_or_default()
39 .as_secs();
40
41 if let Some(bucket) = buckets.get_mut(&self.identifier) {
42 bucket.last_response_time = Some(now);
43 log::trace!(
44 "Manually recorded response time for {}: {}",
45 self.identifier,
46 now
47 );
48 }
49 }
50
51 self.auto_record = false;
53 }
54}
55
56impl Drop for RateLimitGuard {
57 fn drop(&mut self) {
58 if !self.auto_record {
59 return;
60 }
61
62 let rate_limiter = self.rate_limiter.clone();
63 let identifier = self.identifier.clone();
64
65 tokio::spawn(async move {
67 let mut buckets = rate_limiter.lock().await;
68 let now = SystemTime::now()
69 .duration_since(UNIX_EPOCH)
70 .unwrap_or_default()
71 .as_secs();
72
73 if let Some(bucket) = buckets.get_mut(&identifier) {
74 if bucket.initial_burst_used || bucket.tokens < (bucket.burst as f64 * 0.5) {
76 bucket.last_response_time = Some(now);
77 log::trace!("Auto-recorded response time for {}: {}", identifier, now);
78 } else {
79 log::trace!(
80 "Skipping response time recording for {} (still in initial burst)",
81 identifier
82 );
83 }
84 } else {
85 log::warn!(
86 "Attempted to auto-record response for unknown endpoint: {}",
87 identifier
88 );
89 }
90 });
91 }
92}
93
/// Token-bucket state tracked per identifier by `RateLimiter`.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Tokens currently available; fractional amounts accumulate between refills.
    pub tokens: f64,
    /// Unix timestamp, in whole seconds, of the last refill.
    pub last_refill: u64,
    /// Unix timestamp, in whole seconds, of the last recorded API response, if any.
    pub last_response_time: Option<u64>,
    /// Refill rate in tokens per second.
    pub rate: f64,
    /// Maximum number of tokens the bucket can hold.
    pub burst: u32,
    /// Set once the initial burst capacity is consumed; cleared again when the
    /// bucket refills to near capacity.
    pub initial_burst_used: bool,
}
104
105pub struct RateLimiter {
108 buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
109}
110
111impl RateLimiter {
112 pub fn new() -> Self {
113 Self {
114 buckets: Arc::new(Mutex::new(HashMap::new())),
115 }
116 }
117
118 #[must_use = "The returned guard must be held until the API response is received"]
121 pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
122 self._wait_for_token(identifier, rate, burst).await?;
123
124 Ok(RateLimitGuard::new(
125 self.buckets.clone(),
126 identifier.to_string(),
127 true, ))
129 }
130
131 #[deprecated(since = "0.1.4", note = "Use `wait()` instead.")]
134 pub async fn wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
135 self._wait_for_token(identifier, rate, burst).await
136 }
137
138 async fn _wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
139 loop {
140 let mut buckets = self.buckets.lock().await;
141 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
142
143 let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
145 log::debug!("Creating new token bucket for endpoint: {}", identifier);
146 TokenBucketState {
147 tokens: burst as f64, last_refill: now,
149 last_response_time: None, rate: rate,
151 burst: burst,
152 initial_burst_used: false, }
154 });
155
156 if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
158 log::info!(
159 "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
160 identifier,
161 bucket.rate,
162 rate,
163 bucket.burst,
164 burst
165 );
166 bucket.rate = rate;
167 bucket.burst = burst;
168 }
169
170 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
172 let tokens_to_add = time_passed * bucket.rate;
173 let new_tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
174
175 if new_tokens >= (bucket.burst as f64 - 0.5) && bucket.initial_burst_used {
177 bucket.initial_burst_used = false;
178 bucket.last_response_time = None; log::debug!(
180 "Bucket {} refilled to near capacity ({:.1}), resetting burst state",
181 identifier,
182 new_tokens
183 );
184 }
185
186 bucket.tokens = new_tokens;
187 bucket.last_refill = now;
188
189 if bucket.initial_burst_used {
191 if let Some(last_response_time) = bucket.last_response_time {
192 let minimum_interval = 1.0 / bucket.rate;
193 let time_since_response = now.saturating_sub(last_response_time) as f64;
194
195 if time_since_response < minimum_interval {
196 let wait_seconds = minimum_interval - time_since_response;
197 log::debug!(
198 "Enforcing minimum interval for {}: waiting {:.3}s since last response",
199 identifier,
200 wait_seconds
201 );
202
203 drop(buckets);
204 sleep(Duration::from_secs_f64(wait_seconds)).await;
205 continue;
206 }
207 }
208 }
209
210 log::debug!(
211 "Endpoint {}: {:.1} tokens available, rate: {}/s, burst: {}, initial_burst_used: {}",
212 identifier,
213 bucket.tokens,
214 bucket.rate,
215 bucket.burst,
216 bucket.initial_burst_used
217 );
218
219 if bucket.tokens >= 1.0 {
221 bucket.tokens -= 1.0;
222
223 if bucket.tokens <= 1.0 && !bucket.initial_burst_used {
225 bucket.initial_burst_used = true;
226 log::debug!(
227 "Initial burst capacity exhausted for {} (tokens: {:.1})",
228 identifier,
229 bucket.tokens
230 );
231 }
232
233 log::debug!(
234 "Token consumed for {}, {:.1} tokens remaining, initial_burst_used: {}",
235 identifier,
236 bucket.tokens,
237 bucket.initial_burst_used
238 );
239
240 return Ok(());
241 }
242
243 let tokens_needed = 1.0 - bucket.tokens;
245 let wait_seconds = tokens_needed / bucket.rate;
246 log::debug!(
247 "Rate limit reached for {}, need to wait {:.2}s for next token",
248 identifier,
249 wait_seconds
250 );
251
252 if !bucket.initial_burst_used {
254 bucket.initial_burst_used = true;
255 log::debug!(
256 "Marking initial burst as used for {} due to token shortage",
257 identifier
258 );
259 }
260
261 let wait_duration = Duration::from_secs_f64(wait_seconds.max(0.1));
262
263 drop(buckets);
264 sleep(wait_duration).await;
265 }
266 }
267
268 pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
270 let mut buckets = self.buckets.lock().await;
271 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
272
273 if let Some(bucket) = buckets.get_mut(identifier) {
274 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
276 let tokens_to_add = time_passed * bucket.rate;
277 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
278 bucket.last_refill = now;
279
280 Ok(bucket.tokens >= 1.0)
281 } else {
282 Ok(true)
284 }
285 }
286
287 pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
290 let mut buckets = self.buckets.lock().await;
291 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
292 let mut status = HashMap::new();
293
294 for (endpoint_key, bucket) in buckets.iter_mut() {
295 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
297 let tokens_to_add = time_passed * bucket.rate;
298 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
299 bucket.last_refill = now;
300
301 status.insert(
302 endpoint_key.clone(),
303 (bucket.tokens, bucket.rate, bucket.burst),
304 );
305 }
306
307 Ok(status)
308 }
309
310 pub async fn reset(&self) {
312 let mut buckets = self.buckets.lock().await;
313 buckets.clear();
314 log::debug!("Rate limiter state reset");
315 }
316
317 pub async fn active_buckets_count(&self) -> usize {
319 let buckets = self.buckets.lock().await;
320 buckets.len()
321 }
322
323 #[deprecated(since = "0.1.4", note = "Use `wait()` instead.")]
326 pub async fn record_response(&self, identifier: &str) -> Result<()> {
327 let mut buckets = self.buckets.lock().await;
328 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
329
330 if let Some(bucket) = buckets.get_mut(identifier) {
331 bucket.last_response_time = Some(now);
332 bucket.initial_burst_used = true; log::trace!("Recorded response time for {}: {}", identifier, now);
334 } else {
335 log::warn!(
336 "Attempted to record response for unknown endpoint: {}",
337 identifier
338 );
339 }
340
341 Ok(())
342 }
343}
344
345impl Default for RateLimiter {
346 fn default() -> Self {
347 Self::new()
348 }
349}