1use dashmap::DashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10use tracing::{debug, trace};
11
12use grapsus_config::TokenRateLimit;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
16pub enum TokenRateLimitResult {
17 Allowed,
19 TokensExceeded {
21 retry_after_ms: u64,
23 },
24 RequestsExceeded {
26 retry_after_ms: u64,
28 },
29}
30
31impl TokenRateLimitResult {
32 pub fn is_allowed(&self) -> bool {
34 matches!(self, Self::Allowed)
35 }
36
37 pub fn retry_after_ms(&self) -> u64 {
39 match self {
40 Self::Allowed => 0,
41 Self::TokensExceeded { retry_after_ms } => *retry_after_ms,
42 Self::RequestsExceeded { retry_after_ms } => *retry_after_ms,
43 }
44 }
45}
46
47struct TokenBucket {
49 tokens: AtomicU64,
51 max_tokens: u64,
53 refill_rate: f64,
55 last_refill: std::sync::Mutex<Instant>,
57}
58
59impl TokenBucket {
60 fn new(tokens_per_minute: u64, burst_tokens: u64) -> Self {
61 let refill_rate = tokens_per_minute as f64 / 60_000.0;
63
64 Self {
65 tokens: AtomicU64::new(burst_tokens),
66 max_tokens: burst_tokens,
67 refill_rate,
68 last_refill: std::sync::Mutex::new(Instant::now()),
69 }
70 }
71
72 fn try_consume(&self, amount: u64) -> Result<(), u64> {
74 self.refill();
76
77 loop {
79 let current = self.tokens.load(Ordering::Acquire);
80 if current < amount {
81 let needed = amount - current;
83 let wait_ms = (needed as f64 / self.refill_rate).ceil() as u64;
84 return Err(wait_ms);
85 }
86
87 if self
89 .tokens
90 .compare_exchange(
91 current,
92 current - amount,
93 Ordering::AcqRel,
94 Ordering::Relaxed,
95 )
96 .is_ok()
97 {
98 return Ok(());
99 }
100 }
102 }
103
104 fn refill(&self) {
106 let mut last = self.last_refill.lock().unwrap();
107 let now = Instant::now();
108 let elapsed = now.duration_since(*last);
109
110 if elapsed.as_millis() > 0 {
111 let refill_amount = (elapsed.as_millis() as f64 * self.refill_rate) as u64;
112 if refill_amount > 0 {
113 let current = self.tokens.load(Ordering::Acquire);
114 let new_tokens = (current + refill_amount).min(self.max_tokens);
115 self.tokens.store(new_tokens, Ordering::Release);
116 *last = now;
117 }
118 }
119 }
120
121 fn current_tokens(&self) -> u64 {
123 self.refill();
124 self.tokens.load(Ordering::Acquire)
125 }
126}
127
128pub struct TokenRateLimiter {
132 token_buckets: DashMap<String, TokenBucket>,
134 request_buckets: Option<DashMap<String, TokenBucket>>,
136 config: TokenRateLimit,
138}
139
140impl TokenRateLimiter {
141 pub fn new(config: TokenRateLimit) -> Self {
143 let request_buckets = config.requests_per_minute.map(|rpm| DashMap::new());
144
145 Self {
146 token_buckets: DashMap::new(),
147 request_buckets,
148 config,
149 }
150 }
151
152 pub fn check(&self, key: &str, estimated_tokens: u64) -> TokenRateLimitResult {
156 let token_bucket = self
158 .token_buckets
159 .entry(key.to_string())
160 .or_insert_with(|| {
161 TokenBucket::new(self.config.tokens_per_minute, self.config.burst_tokens)
162 });
163
164 if let Err(retry_ms) = token_bucket.try_consume(estimated_tokens) {
165 trace!(
166 key = key,
167 estimated_tokens = estimated_tokens,
168 retry_after_ms = retry_ms,
169 "Token rate limit exceeded"
170 );
171 return TokenRateLimitResult::TokensExceeded {
172 retry_after_ms: retry_ms,
173 };
174 }
175
176 if let (Some(rpm), Some(ref request_buckets)) =
178 (self.config.requests_per_minute, &self.request_buckets)
179 {
180 let request_bucket = request_buckets.entry(key.to_string()).or_insert_with(|| {
181 let burst = rpm.max(1) / 6;
183 TokenBucket::new(rpm, burst.max(1))
184 });
185
186 if let Err(retry_ms) = request_bucket.try_consume(1) {
187 trace!(
188 key = key,
189 retry_after_ms = retry_ms,
190 "Request rate limit exceeded"
191 );
192 return TokenRateLimitResult::RequestsExceeded {
193 retry_after_ms: retry_ms,
194 };
195 }
196 }
197
198 trace!(
199 key = key,
200 estimated_tokens = estimated_tokens,
201 "Rate limit check passed"
202 );
203 TokenRateLimitResult::Allowed
204 }
205
206 pub fn record_actual(&self, key: &str, actual_tokens: u64, estimated_tokens: u64) {
212 if let Some(bucket) = self.token_buckets.get(key) {
213 if actual_tokens < estimated_tokens {
214 let refund = estimated_tokens - actual_tokens;
216 let current = bucket.tokens.load(Ordering::Acquire);
217 let new_tokens = (current + refund).min(bucket.max_tokens);
218 bucket.tokens.store(new_tokens, Ordering::Release);
219
220 debug!(
221 key = key,
222 actual = actual_tokens,
223 estimated = estimated_tokens,
224 refund = refund,
225 "Refunded over-estimated tokens"
226 );
227 } else if actual_tokens > estimated_tokens {
228 let extra = actual_tokens - estimated_tokens;
230 let current = bucket.tokens.load(Ordering::Acquire);
231 let to_consume = extra.min(current);
232 if to_consume > 0 {
233 bucket.tokens.fetch_sub(to_consume, Ordering::AcqRel);
234 }
235
236 debug!(
237 key = key,
238 actual = actual_tokens,
239 estimated = estimated_tokens,
240 consumed_extra = to_consume,
241 "Consumed under-estimated tokens"
242 );
243 }
244 }
245 }
246
247 pub fn current_tokens(&self, key: &str) -> Option<u64> {
249 self.token_buckets.get(key).map(|b| b.current_tokens())
250 }
251
252 pub fn stats(&self) -> TokenRateLimiterStats {
254 TokenRateLimiterStats {
255 active_keys: self.token_buckets.len(),
256 tokens_per_minute: self.config.tokens_per_minute,
257 requests_per_minute: self.config.requests_per_minute,
258 }
259 }
260}
261
262#[derive(Debug, Clone)]
264pub struct TokenRateLimiterStats {
265 pub active_keys: usize,
267 pub tokens_per_minute: u64,
269 pub requests_per_minute: Option<u64>,
271}
272
273#[cfg(test)]
274mod tests {
275 use super::*;
276 use grapsus_config::TokenEstimation;
277
278 fn test_config() -> TokenRateLimit {
279 TokenRateLimit {
280 tokens_per_minute: 1000,
281 requests_per_minute: Some(10),
282 burst_tokens: 200,
283 estimation_method: TokenEstimation::Chars,
284 }
285 }
286
287 #[test]
288 fn test_basic_rate_limiting() {
289 let limiter = TokenRateLimiter::new(test_config());
290
291 let result = limiter.check("test-key", 50);
293 assert!(result.is_allowed());
294
295 let current = limiter.current_tokens("test-key").unwrap();
297 assert!(current > 0);
298 }
299
300 #[test]
301 fn test_token_exhaustion() {
302 let limiter = TokenRateLimiter::new(test_config());
303
304 for _ in 0..4 {
306 let _ = limiter.check("test-key", 50);
307 }
308
309 let result = limiter.check("test-key", 50);
311 assert!(!result.is_allowed());
312 assert!(matches!(
313 result,
314 TokenRateLimitResult::TokensExceeded { .. }
315 ));
316 }
317
318 #[test]
319 fn test_actual_token_refund() {
320 let limiter = TokenRateLimiter::new(test_config());
321
322 let _ = limiter.check("test-key", 100);
324 let before = limiter.current_tokens("test-key").unwrap();
325
326 limiter.record_actual("test-key", 50, 100);
328 let after = limiter.current_tokens("test-key").unwrap();
329
330 assert!(after > before);
332 }
333}