1use anyhow::Result;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5use tokio::sync::Mutex;
6use tokio::time::sleep;
7
8#[must_use = "RateLimitGuard must be held until the API response is received"]
10pub struct RateLimitGuard {
11 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
12 identifier: String,
13 auto_record: bool,
14}
15
16impl RateLimitGuard {
17 fn new(
18 rate_limiter: Arc<Mutex<HashMap<String, TokenBucketState>>>,
19 identifier: String,
20 auto_record: bool,
21 ) -> Self {
22 Self {
23 rate_limiter,
24 identifier,
25 auto_record,
26 }
27 }
28
29 pub async fn mark_response(mut self) {
32 if self.auto_record {
33 let mut buckets = self.rate_limiter.lock().await;
34 let now = SystemTime::now()
35 .duration_since(UNIX_EPOCH)
36 .unwrap_or_default()
37 .as_secs();
38
39 if let Some(bucket) = buckets.get_mut(&self.identifier) {
40 bucket.last_response_time = Some(now);
41 log::trace!(
42 "Manually recorded response time for {}: {}",
43 self.identifier,
44 now
45 );
46 }
47 }
48
49 self.auto_record = false;
51 }
52}
53
54impl Drop for RateLimitGuard {
55 fn drop(&mut self) {
56 if !self.auto_record {
57 return;
58 }
59
60 let rate_limiter = self.rate_limiter.clone();
61 let identifier = self.identifier.clone();
62
63 tokio::spawn(async move {
65 let mut buckets = rate_limiter.lock().await;
66 let now = SystemTime::now()
67 .duration_since(UNIX_EPOCH)
68 .unwrap_or_default()
69 .as_secs();
70
71 if let Some(bucket) = buckets.get_mut(&identifier) {
72 bucket.last_response_time = Some(now);
74 log::trace!("Auto-recorded response time for {}: {}", identifier, now);
75 } else {
76 log::warn!(
77 "Attempted to auto-record response for unknown endpoint: {}",
78 identifier
79 );
80 }
81 });
82 }
83}
84
/// Per-endpoint token-bucket bookkeeping shared by `RateLimiter` and
/// `RateLimitGuard`.
///
/// Timestamps are whole seconds since the Unix epoch.
#[derive(Debug, Clone)]
pub struct TokenBucketState {
    /// Tokens currently available; each permitted request consumes one.
    pub tokens: f64,
    /// Time of the most recent refill.
    pub last_refill: u64,
    /// Time the most recent response was recorded, if any.
    pub last_response_time: Option<u64>,
    /// Refill rate in tokens per second.
    pub rate: f64,
    /// Bucket capacity; refills never raise `tokens` above this.
    pub burst: u32,
    /// Set once the initial burst is exhausted. While set, `RateLimiter` also
    /// enforces a minimum interval measured from `last_response_time`.
    pub initial_burst_used: bool,
}
95
96pub struct RateLimiter {
99 buckets: Arc<Mutex<HashMap<String, TokenBucketState>>>,
100 safety_factor: f64,
103}
104
105impl RateLimiter {
106 pub fn new() -> Self {
107 Self {
108 buckets: Arc::new(Mutex::new(HashMap::new())),
109 safety_factor: 1.10, }
111 }
112
113 pub fn new_with_safety_factor(safety_factor: f64) -> Self {
117 let factor = if safety_factor < 1.0 {
118 log::warn!("Safety factor {} is less than 1.0, using 1.0 instead", safety_factor);
119 1.0
120 } else {
121 safety_factor
122 };
123
124 Self {
125 buckets: Arc::new(Mutex::new(HashMap::new())),
126 safety_factor: factor,
127 }
128 }
129
130 pub fn set_safety_factor(&mut self, safety_factor: f64) {
133 if safety_factor < 1.0 {
134 log::warn!("Safety factor {} is less than 1.0, using 1.0 instead", safety_factor);
135 self.safety_factor = 1.0;
136 } else {
137 self.safety_factor = safety_factor;
138 log::info!("Rate limiter safety factor set to {}", safety_factor);
139 }
140 }
141
142 pub fn get_safety_factor(&self) -> f64 {
144 self.safety_factor
145 }
146
147 #[must_use = "The returned guard must be held until the API response is received"]
150 pub async fn wait(&self, identifier: &str, rate: f64, burst: u32) -> Result<RateLimitGuard> {
151 self._wait_for_token(identifier, rate, burst).await?;
152
153 Ok(RateLimitGuard::new(
154 self.buckets.clone(),
155 identifier.to_string(),
156 true, ))
158 }
159
160 async fn _wait_for_token(&self, identifier: &str, rate: f64, burst: u32) -> Result<()> {
161 loop {
162 let mut buckets = self.buckets.lock().await;
163 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
164
165 let bucket = buckets.entry(identifier.to_string()).or_insert_with(|| {
167 log::debug!("Creating new token bucket for endpoint: {}", identifier);
168 TokenBucketState {
169 tokens: burst as f64,
170 last_refill: now,
171 last_response_time: None,
172 rate: rate,
173 burst: burst,
174 initial_burst_used: false,
175 }
176 });
177
178 if (bucket.rate - rate).abs() > f64::EPSILON || bucket.burst != burst {
180 log::info!(
181 "Updating rate limit for {}: rate {} -> {}, burst {} -> {}",
182 identifier,
183 bucket.rate,
184 rate,
185 bucket.burst,
186 burst
187 );
188 bucket.rate = rate;
189 bucket.burst = burst;
190 }
191
192 let refill_from_time = if bucket.initial_burst_used {
194 bucket.last_response_time.unwrap_or(bucket.last_refill)
196 } else {
197 bucket.last_refill
199 };
200
201 let time_passed = now.saturating_sub(refill_from_time) as f64;
202 let tokens_to_add = time_passed * bucket.rate;
203 let new_tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
204
205 if new_tokens >= (bucket.burst as f64 - 0.5) && bucket.initial_burst_used {
207 bucket.initial_burst_used = false;
208 bucket.last_response_time = None;
209 log::debug!(
210 "Bucket {} refilled to near capacity ({:.1}), resetting burst state",
211 identifier,
212 new_tokens
213 );
214 }
215
216 bucket.tokens = new_tokens;
217 bucket.last_refill = now;
218
219 if bucket.initial_burst_used {
221 if let Some(last_response_time) = bucket.last_response_time {
222 let base_minimum_interval = 1.0 / bucket.rate;
223 let minimum_interval = base_minimum_interval * self.safety_factor;
225 let time_since_response = now.saturating_sub(last_response_time) as f64;
226
227 if time_since_response < minimum_interval {
228 let wait_seconds = minimum_interval - time_since_response;
229 log::debug!(
230 "Enforcing minimum interval for {} (safety factor: {:.2}): waiting {:.3}s since last response",
231 identifier,
232 self.safety_factor,
233 wait_seconds
234 );
235
236 drop(buckets);
237 sleep(Duration::from_secs_f64(wait_seconds)).await;
238 continue;
239 }
240 }
241 }
242
243 log::debug!(
244 "Endpoint {}: {:.1} tokens available, rate: {}/s, burst: {}, initial_burst_used: {}",
245 identifier,
246 bucket.tokens,
247 bucket.rate,
248 bucket.burst,
249 bucket.initial_burst_used
250 );
251
252 if bucket.tokens >= 1.0 {
254 bucket.tokens -= 1.0;
255
256 if bucket.burst == 1 || bucket.tokens <= 0.0 {
258 bucket.initial_burst_used = true;
259 log::debug!(
260 "Initial burst capacity exhausted for {} (tokens: {:.1}, burst: {})",
261 identifier,
262 bucket.tokens,
263 bucket.burst
264 );
265 }
266
267 log::debug!(
268 "Token consumed for {}, {:.1} tokens remaining, initial_burst_used: {}",
269 identifier,
270 bucket.tokens,
271 bucket.initial_burst_used
272 );
273
274 return Ok(());
275 }
276
277 let tokens_needed = 1.0 - bucket.tokens;
279 let base_wait_seconds = tokens_needed / bucket.rate;
280 let wait_seconds = base_wait_seconds * self.safety_factor;
282 log::debug!(
283 "Rate limit reached for {}, need to wait {:.2}s (base: {:.2}s, safety factor: {:.2}) for next token",
284 identifier,
285 wait_seconds,
286 base_wait_seconds,
287 self.safety_factor
288 );
289
290 if !bucket.initial_burst_used {
292 bucket.initial_burst_used = true;
293 log::debug!(
294 "Marking initial burst as used for {} due to token shortage",
295 identifier
296 );
297 }
298
299 let wait_duration = Duration::from_secs_f64(wait_seconds.max(0.1));
300
301 drop(buckets);
302 sleep(wait_duration).await;
303 }
304 }
305
306 pub async fn check_token_availability(&self, identifier: &str) -> Result<bool> {
308 let mut buckets = self.buckets.lock().await;
309 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
310
311 if let Some(bucket) = buckets.get_mut(identifier) {
312 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
314 let tokens_to_add = time_passed * bucket.rate;
315 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
316 bucket.last_refill = now;
317
318 Ok(bucket.tokens >= 1.0)
319 } else {
320 Ok(true)
322 }
323 }
324
325 pub async fn get_token_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
328 let mut buckets = self.buckets.lock().await;
329 let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
330 let mut status = HashMap::new();
331
332 for (endpoint_key, bucket) in buckets.iter_mut() {
333 let time_passed = now.saturating_sub(bucket.last_refill) as f64;
335 let tokens_to_add = time_passed * bucket.rate;
336 bucket.tokens = (bucket.tokens + tokens_to_add).min(bucket.burst as f64);
337 bucket.last_refill = now;
338
339 status.insert(
340 endpoint_key.clone(),
341 (bucket.tokens, bucket.rate, bucket.burst),
342 );
343 }
344
345 Ok(status)
346 }
347
348 pub async fn reset(&self) {
350 let mut buckets = self.buckets.lock().await;
351 buckets.clear();
352 log::debug!("Rate limiter state reset");
353 }
354
355 pub async fn active_buckets_count(&self) -> usize {
357 let buckets = self.buckets.lock().await;
358 buckets.len()
359 }
360}
361
362impl Default for RateLimiter {
363 fn default() -> Self {
364 Self::new()
365 }
366}