heliosdb_proxy/rate_limit/
token_bucket.rs1use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8
9use parking_lot::Mutex;
10
/// Token-bucket rate limiter state.
///
/// The available-token count is kept in an atomic counter so callers can
/// acquire and return tokens through a shared reference; refills are
/// computed from the time elapsed since `epoch`.
#[derive(Debug)]
pub struct TokenBucket {
    /// Maximum number of whole tokens the bucket can hold.
    capacity: u32,

    /// Tokens currently available, stored as thousandths of a token
    /// (a full bucket holds `capacity * 1000`).
    tokens: AtomicU64,

    /// Tokens added per second.
    refill_rate: f64,

    /// Time of the most recent applied refill, in nanoseconds since `epoch`.
    last_refill: AtomicU64,

    /// Reference instant that `last_refill` is measured from.
    epoch: Instant,

    /// Lock held for the duration of a refill computation.
    refill_lock: Mutex<()>,
}
35
36impl TokenBucket {
37 pub fn new(capacity: u32, refill_rate: f64) -> Self {
43 let epoch = Instant::now();
44 Self {
45 capacity,
46 tokens: AtomicU64::new((capacity as u64) * 1000), refill_rate,
48 last_refill: AtomicU64::new(0),
49 epoch,
50 refill_lock: Mutex::new(()),
51 }
52 }
53
54 pub fn from_qps(qps: u32, burst: u32) -> Self {
56 Self::new(burst, qps as f64)
57 }
58
59 pub fn try_acquire(&self, tokens: u32) -> Result<(), TokenBucketExceeded> {
63 self.refill();
64
65 let tokens_needed = (tokens as u64) * 1000;
66 let mut current = self.tokens.load(Ordering::Acquire);
67
68 loop {
69 if current >= tokens_needed {
70 match self.tokens.compare_exchange_weak(
71 current,
72 current - tokens_needed,
73 Ordering::Release,
74 Ordering::Relaxed,
75 ) {
76 Ok(_) => return Ok(()),
77 Err(updated) => current = updated,
78 }
79 } else {
80 return Err(TokenBucketExceeded {
81 retry_after: self.time_until_available(tokens),
82 current_tokens: (current / 1000) as u32,
83 requested_tokens: tokens,
84 });
85 }
86 }
87 }
88
89 pub fn acquire_blocking(&self, tokens: u32, timeout: Duration) -> Result<(), TokenBucketExceeded> {
91 let deadline = Instant::now() + timeout;
92
93 loop {
94 match self.try_acquire(tokens) {
95 Ok(()) => return Ok(()),
96 Err(exceeded) => {
97 let now = Instant::now();
98 if now >= deadline {
99 return Err(exceeded);
100 }
101
102 let wait = exceeded.retry_after.min(deadline - now);
103 std::thread::sleep(wait);
104 }
105 }
106 }
107 }
108
109 pub fn return_tokens(&self, tokens: u32) {
111 let tokens_to_add = (tokens as u64) * 1000;
112 let max = (self.capacity as u64) * 1000;
113
114 let mut current = self.tokens.load(Ordering::Acquire);
115 loop {
116 let new_value = (current + tokens_to_add).min(max);
117 match self.tokens.compare_exchange_weak(
118 current,
119 new_value,
120 Ordering::Release,
121 Ordering::Relaxed,
122 ) {
123 Ok(_) => break,
124 Err(updated) => current = updated,
125 }
126 }
127 }
128
129 fn refill(&self) {
131 let _lock = self.refill_lock.lock();
132
133 let now_nanos = self.epoch.elapsed().as_nanos() as u64;
134 let last = self.last_refill.load(Ordering::Acquire);
135
136 if now_nanos <= last {
137 return;
138 }
139
140 let elapsed_secs = (now_nanos - last) as f64 / 1_000_000_000.0;
141 let new_tokens = (elapsed_secs * self.refill_rate * 1000.0) as u64;
142
143 if new_tokens > 0 {
144 let current = self.tokens.load(Ordering::Acquire);
145 let max = (self.capacity as u64) * 1000;
146 let updated = (current + new_tokens).min(max);
147
148 self.tokens.store(updated, Ordering::Release);
149 self.last_refill.store(now_nanos, Ordering::Release);
150 }
151 }
152
153 fn time_until_available(&self, tokens: u32) -> Duration {
155 let current = self.tokens.load(Ordering::Relaxed) / 1000;
156 let needed = (tokens as u64).saturating_sub(current);
157
158 if needed == 0 {
159 Duration::ZERO
160 } else {
161 Duration::from_secs_f64(needed as f64 / self.refill_rate)
162 }
163 }
164
165 pub fn current_tokens(&self) -> u32 {
167 self.refill();
168 (self.tokens.load(Ordering::Relaxed) / 1000) as u32
169 }
170
171 pub fn capacity(&self) -> u32 {
173 self.capacity
174 }
175
176 pub fn refill_rate(&self) -> f64 {
178 self.refill_rate
179 }
180
181 pub fn is_empty(&self) -> bool {
183 self.current_tokens() == 0
184 }
185
186 pub fn is_full(&self) -> bool {
188 self.current_tokens() >= self.capacity
189 }
190
191 pub fn fill_ratio(&self) -> f64 {
193 self.current_tokens() as f64 / self.capacity as f64
194 }
195
196 pub fn reset(&self) {
198 self.tokens.store((self.capacity as u64) * 1000, Ordering::Release);
199 self.last_refill.store(self.epoch.elapsed().as_nanos() as u64, Ordering::Release);
200 }
201
202 pub fn set_capacity(&mut self, capacity: u32) {
204 self.capacity = capacity;
205 let current = self.tokens.load(Ordering::Acquire);
207 let max = (capacity as u64) * 1000;
208 if current > max {
209 self.tokens.store(max, Ordering::Release);
210 }
211 }
212
213 pub fn set_refill_rate(&mut self, rate: f64) {
215 self.refill_rate = rate;
216 }
217}
218
219impl Clone for TokenBucket {
220 fn clone(&self) -> Self {
221 Self {
222 capacity: self.capacity,
223 tokens: AtomicU64::new(self.tokens.load(Ordering::Relaxed)),
224 refill_rate: self.refill_rate,
225 last_refill: AtomicU64::new(self.last_refill.load(Ordering::Relaxed)),
226 epoch: self.epoch,
227 refill_lock: Mutex::new(()),
228 }
229 }
230}
231
/// Error returned when a bucket cannot supply the requested tokens.
#[derive(Debug, Clone)]
pub struct TokenBucketExceeded {
    /// Estimated time to wait before the request could succeed.
    pub retry_after: Duration,

    /// Whole tokens that were available when the request was rejected.
    pub current_tokens: u32,

    /// Number of tokens the caller asked for.
    pub requested_tokens: u32,
}
244
245impl std::fmt::Display for TokenBucketExceeded {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 write!(
248 f,
249 "Token bucket exceeded: {} available, {} requested, retry after {}ms",
250 self.current_tokens,
251 self.requested_tokens,
252 self.retry_after.as_millis()
253 )
254 }
255}
256
// Marks `TokenBucketExceeded` as a standard error type; the empty body means
// every `Error` method keeps its default implementation.
impl std::error::Error for TokenBucketExceeded {}
258
/// Unit tests for `TokenBucket`. Several of them rely on real time passing
/// (sleeps and refill rates), so their timing tolerances are deliberate.
#[cfg(test)]
mod tests {
    use super::*;

    /// A new bucket reports its capacity and starts full.
    #[test]
    fn test_bucket_creation() {
        let bucket = TokenBucket::new(100, 10.0);
        assert_eq!(bucket.capacity(), 100);
        assert_eq!(bucket.current_tokens(), 100);
        assert!(bucket.is_full());
    }

    /// `from_qps` maps `burst` to the capacity and `qps` to the refill rate.
    #[test]
    fn test_from_qps() {
        let bucket = TokenBucket::from_qps(100, 200);
        assert_eq!(bucket.capacity(), 200);
        assert_eq!(bucket.refill_rate(), 100.0);
    }

    /// Successful acquisitions reduce the balance by the requested amount.
    #[test]
    fn test_acquire_success() {
        let bucket = TokenBucket::new(100, 10.0);

        assert!(bucket.try_acquire(50).is_ok());
        assert_eq!(bucket.current_tokens(), 50);

        assert!(bucket.try_acquire(50).is_ok());
        assert_eq!(bucket.current_tokens(), 0);
    }

    /// Acquiring from a drained bucket fails and reports the shortfall.
    #[test]
    fn test_acquire_failure() {
        let bucket = TokenBucket::new(10, 1.0);

        assert!(bucket.try_acquire(10).is_ok());

        let result = bucket.try_acquire(1);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(err.current_tokens, 0);
        assert_eq!(err.requested_tokens, 1);
    }

    /// A drained bucket regains tokens as time passes, but never more than
    /// the refill rate allows for the time that actually elapsed.
    #[test]
    fn test_refill() {
        // Taken before the bucket exists, so the measured span is at least as
        // long as the bucket's own refill window.
        let started = Instant::now();

        let bucket = TokenBucket::new(100, 100.0);
        assert!(bucket.try_acquire(100).is_ok());
        assert_eq!(bucket.current_tokens(), 0);

        std::thread::sleep(Duration::from_millis(50));

        let tokens = bucket.current_tokens();
        let elapsed = started.elapsed();
        assert!(tokens > 0);

        // `sleep` may overshoot on a busy machine, so derive the upper bound
        // from the measured time (100 tokens/s) instead of a fixed constant.
        let upper_bound = (elapsed.as_secs_f64() * 100.0).ceil() as u32;
        assert!(tokens <= upper_bound);
    }

    /// Returned tokens are added back and clamped at the capacity.
    #[test]
    fn test_return_tokens() {
        let bucket = TokenBucket::new(100, 10.0);

        assert!(bucket.try_acquire(50).is_ok());
        assert_eq!(bucket.current_tokens(), 50);

        bucket.return_tokens(30);
        assert_eq!(bucket.current_tokens(), 80);

        bucket.return_tokens(50);
        assert_eq!(bucket.current_tokens(), 100);
    }

    /// `reset` refills a drained bucket to capacity.
    #[test]
    fn test_reset() {
        let bucket = TokenBucket::new(100, 10.0);

        assert!(bucket.try_acquire(100).is_ok());
        assert!(bucket.is_empty());

        bucket.reset();
        assert!(bucket.is_full());
    }

    /// `fill_ratio` tracks the balance as a fraction of the capacity.
    #[test]
    fn test_fill_ratio() {
        let bucket = TokenBucket::new(100, 10.0);

        assert!((bucket.fill_ratio() - 1.0).abs() < 0.01);

        assert!(bucket.try_acquire(50).is_ok());
        assert!((bucket.fill_ratio() - 0.5).abs() < 0.01);

        assert!(bucket.try_acquire(50).is_ok());
        assert!((bucket.fill_ratio() - 0.0).abs() < 0.01);
    }

    /// At 10 tokens/s, a drained bucket needs about one second for 10 tokens.
    #[test]
    fn test_time_until_available() {
        let bucket = TokenBucket::new(100, 10.0);
        assert!(bucket.try_acquire(100).is_ok());

        let result = bucket.try_acquire(10);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(err.retry_after.as_millis() >= 900);
        assert!(err.retry_after.as_millis() <= 1100);
    }

    /// A blocking acquire succeeds once enough tokens have been refilled.
    #[test]
    fn test_acquire_blocking() {
        let bucket = TokenBucket::new(10, 100.0);
        assert!(bucket.try_acquire(10).is_ok());

        let result = bucket.acquire_blocking(5, Duration::from_millis(100));
        assert!(result.is_ok());
    }

    /// A blocking acquire gives up when the timeout is shorter than the refill.
    #[test]
    fn test_acquire_blocking_timeout() {
        let bucket = TokenBucket::new(10, 1.0);
        assert!(bucket.try_acquire(10).is_ok());

        let result = bucket.acquire_blocking(10, Duration::from_millis(10));
        assert!(result.is_err());
    }

    /// Concurrent acquisitions from several threads drain a shared bucket.
    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let bucket = Arc::new(TokenBucket::new(1000, 1000.0));
        let mut handles = vec![];

        for _ in 0..10 {
            let bucket = Arc::clone(&bucket);
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let _ = bucket.try_acquire(5);
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(bucket.current_tokens() < 1000);
    }

    /// A clone starts from the same capacity and balance as its source.
    #[test]
    fn test_clone() {
        let bucket1 = TokenBucket::new(100, 10.0);
        assert!(bucket1.try_acquire(50).is_ok());

        let bucket2 = bucket1.clone();
        assert_eq!(bucket2.capacity(), 100);
        assert_eq!(bucket2.current_tokens(), 50);
    }
}