Skip to main content

heliosdb_proxy/rate_limit/
token_bucket.rs

1//! Token Bucket Rate Limiter
2//!
3//! Implements the token bucket algorithm for rate limiting.
4//! Allows burst traffic while enforcing sustained rate limits.
5
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8
9use parking_lot::Mutex;
10
11/// Token bucket rate limiter
12///
13/// The token bucket allows for burst traffic up to the bucket capacity,
14/// while enforcing a sustained rate over time.
15#[derive(Debug)]
16pub struct TokenBucket {
17    /// Maximum tokens (burst capacity)
18    capacity: u32,
19
20    /// Current token count (stored as fixed-point: tokens * 1000)
21    tokens: AtomicU64,
22
23    /// Refill rate (tokens per second)
24    refill_rate: f64,
25
26    /// Last refill timestamp (nanoseconds since epoch)
27    last_refill: AtomicU64,
28
29    /// Epoch for time calculations
30    epoch: Instant,
31
32    /// Lock for atomic operations across multiple fields
33    refill_lock: Mutex<()>,
34}
35
36impl TokenBucket {
37    /// Create a new token bucket
38    ///
39    /// # Arguments
40    /// * `capacity` - Maximum tokens (burst capacity)
41    /// * `refill_rate` - Tokens added per second
42    pub fn new(capacity: u32, refill_rate: f64) -> Self {
43        let epoch = Instant::now();
44        Self {
45            capacity,
46            tokens: AtomicU64::new((capacity as u64) * 1000), // Start full
47            refill_rate,
48            last_refill: AtomicU64::new(0),
49            epoch,
50            refill_lock: Mutex::new(()),
51        }
52    }
53
54    /// Create a token bucket from QPS configuration
55    pub fn from_qps(qps: u32, burst: u32) -> Self {
56        Self::new(burst, qps as f64)
57    }
58
59    /// Try to acquire tokens
60    ///
61    /// Returns Ok(()) if tokens were acquired, Err with retry info if not.
62    pub fn try_acquire(&self, tokens: u32) -> Result<(), TokenBucketExceeded> {
63        self.refill();
64
65        let tokens_needed = (tokens as u64) * 1000;
66        let mut current = self.tokens.load(Ordering::Acquire);
67
68        loop {
69            if current >= tokens_needed {
70                match self.tokens.compare_exchange_weak(
71                    current,
72                    current - tokens_needed,
73                    Ordering::Release,
74                    Ordering::Relaxed,
75                ) {
76                    Ok(_) => return Ok(()),
77                    Err(updated) => current = updated,
78                }
79            } else {
80                return Err(TokenBucketExceeded {
81                    retry_after: self.time_until_available(tokens),
82                    current_tokens: (current / 1000) as u32,
83                    requested_tokens: tokens,
84                });
85            }
86        }
87    }
88
89    /// Acquire tokens, blocking until available (with timeout)
90    pub fn acquire_blocking(&self, tokens: u32, timeout: Duration) -> Result<(), TokenBucketExceeded> {
91        let deadline = Instant::now() + timeout;
92
93        loop {
94            match self.try_acquire(tokens) {
95                Ok(()) => return Ok(()),
96                Err(exceeded) => {
97                    let now = Instant::now();
98                    if now >= deadline {
99                        return Err(exceeded);
100                    }
101
102                    let wait = exceeded.retry_after.min(deadline - now);
103                    std::thread::sleep(wait);
104                }
105            }
106        }
107    }
108
109    /// Return tokens to the bucket (e.g., if operation was cancelled)
110    pub fn return_tokens(&self, tokens: u32) {
111        let tokens_to_add = (tokens as u64) * 1000;
112        let max = (self.capacity as u64) * 1000;
113
114        let mut current = self.tokens.load(Ordering::Acquire);
115        loop {
116            let new_value = (current + tokens_to_add).min(max);
117            match self.tokens.compare_exchange_weak(
118                current,
119                new_value,
120                Ordering::Release,
121                Ordering::Relaxed,
122            ) {
123                Ok(_) => break,
124                Err(updated) => current = updated,
125            }
126        }
127    }
128
129    /// Refill tokens based on elapsed time
130    fn refill(&self) {
131        let _lock = self.refill_lock.lock();
132
133        let now_nanos = self.epoch.elapsed().as_nanos() as u64;
134        let last = self.last_refill.load(Ordering::Acquire);
135
136        if now_nanos <= last {
137            return;
138        }
139
140        let elapsed_secs = (now_nanos - last) as f64 / 1_000_000_000.0;
141        let new_tokens = (elapsed_secs * self.refill_rate * 1000.0) as u64;
142
143        if new_tokens > 0 {
144            let current = self.tokens.load(Ordering::Acquire);
145            let max = (self.capacity as u64) * 1000;
146            let updated = (current + new_tokens).min(max);
147
148            self.tokens.store(updated, Ordering::Release);
149            self.last_refill.store(now_nanos, Ordering::Release);
150        }
151    }
152
153    /// Calculate time until requested tokens are available
154    fn time_until_available(&self, tokens: u32) -> Duration {
155        let current = self.tokens.load(Ordering::Relaxed) / 1000;
156        let needed = (tokens as u64).saturating_sub(current);
157
158        if needed == 0 {
159            Duration::ZERO
160        } else {
161            Duration::from_secs_f64(needed as f64 / self.refill_rate)
162        }
163    }
164
165    /// Get current token count
166    pub fn current_tokens(&self) -> u32 {
167        self.refill();
168        (self.tokens.load(Ordering::Relaxed) / 1000) as u32
169    }
170
171    /// Get capacity
172    pub fn capacity(&self) -> u32 {
173        self.capacity
174    }
175
176    /// Get refill rate (tokens per second)
177    pub fn refill_rate(&self) -> f64 {
178        self.refill_rate
179    }
180
181    /// Check if bucket is empty
182    pub fn is_empty(&self) -> bool {
183        self.current_tokens() == 0
184    }
185
186    /// Check if bucket is full
187    pub fn is_full(&self) -> bool {
188        self.current_tokens() >= self.capacity
189    }
190
191    /// Get fill percentage (0.0 - 1.0)
192    pub fn fill_ratio(&self) -> f64 {
193        self.current_tokens() as f64 / self.capacity as f64
194    }
195
196    /// Reset bucket to full capacity
197    pub fn reset(&self) {
198        self.tokens.store((self.capacity as u64) * 1000, Ordering::Release);
199        self.last_refill.store(self.epoch.elapsed().as_nanos() as u64, Ordering::Release);
200    }
201
202    /// Update capacity (for dynamic limits)
203    pub fn set_capacity(&mut self, capacity: u32) {
204        self.capacity = capacity;
205        // Cap current tokens to new capacity
206        let current = self.tokens.load(Ordering::Acquire);
207        let max = (capacity as u64) * 1000;
208        if current > max {
209            self.tokens.store(max, Ordering::Release);
210        }
211    }
212
213    /// Update refill rate (for dynamic limits)
214    pub fn set_refill_rate(&mut self, rate: f64) {
215        self.refill_rate = rate;
216    }
217}
218
219impl Clone for TokenBucket {
220    fn clone(&self) -> Self {
221        Self {
222            capacity: self.capacity,
223            tokens: AtomicU64::new(self.tokens.load(Ordering::Relaxed)),
224            refill_rate: self.refill_rate,
225            last_refill: AtomicU64::new(self.last_refill.load(Ordering::Relaxed)),
226            epoch: self.epoch,
227            refill_lock: Mutex::new(()),
228        }
229    }
230}
231
/// Failure value produced when a bucket cannot cover a request.
#[derive(Debug, Clone)]
pub struct TokenBucketExceeded {
    /// How long to wait before the requested amount should be available.
    pub retry_after: Duration,

    /// Whole tokens that were in the bucket when the request was refused.
    pub current_tokens: u32,

    /// Number of tokens the caller asked for.
    pub requested_tokens: u32,
}

impl std::fmt::Display for TokenBucketExceeded {
    /// Renders a one-line summary with the retry delay in whole milliseconds.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            retry_after,
            current_tokens,
            requested_tokens,
        } = self;
        let retry_ms = retry_after.as_millis();

        f.write_fmt(format_args!(
            "Token bucket exceeded: {} available, {} requested, retry after {}ms",
            current_tokens, requested_tokens, retry_ms
        ))
    }
}

impl std::error::Error for TokenBucketExceeded {}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bucket and immediately spends its whole capacity.
    fn drained(capacity: u32, refill_rate: f64) -> TokenBucket {
        let bucket = TokenBucket::new(capacity, refill_rate);
        assert!(bucket.try_acquire(capacity).is_ok());
        bucket
    }

    #[test]
    fn test_bucket_creation() {
        // A new bucket reports its capacity and starts completely full.
        let bucket = TokenBucket::new(100, 10.0);
        assert_eq!(bucket.capacity(), 100);
        assert_eq!(bucket.current_tokens(), 100);
        assert!(bucket.is_full());
    }

    #[test]
    fn test_from_qps() {
        // `burst` maps to capacity, `qps` to the refill rate.
        let bucket = TokenBucket::from_qps(100, 200);
        assert_eq!(bucket.capacity(), 200);
        assert_eq!(bucket.refill_rate(), 100.0);
    }

    #[test]
    fn test_acquire_success() {
        let bucket = TokenBucket::new(100, 10.0);

        // Two withdrawals of half the capacity each leave 50, then 0.
        for remaining in [50, 0] {
            assert!(bucket.try_acquire(50).is_ok());
            assert_eq!(bucket.current_tokens(), remaining);
        }
    }

    #[test]
    fn test_acquire_failure() {
        // Everything is spent up front, so one more token must be refused.
        let bucket = drained(10, 1.0);

        let result = bucket.try_acquire(1);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(err.current_tokens, 0);
        assert_eq!(err.requested_tokens, 1);
    }

    #[test]
    fn test_refill() {
        let bucket = drained(100, 100.0); // 100 tokens/sec
        assert_eq!(bucket.current_tokens(), 0);

        // Give the bucket time to refill.
        std::thread::sleep(Duration::from_millis(50));

        // Roughly 5 tokens are expected; the upper bound leaves room for jitter.
        let tokens = bucket.current_tokens();
        assert!(tokens > 0);
        assert!(tokens <= 10);
    }

    #[test]
    fn test_return_tokens() {
        let bucket = TokenBucket::new(100, 10.0);

        assert!(bucket.try_acquire(50).is_ok());
        assert_eq!(bucket.current_tokens(), 50);

        bucket.return_tokens(30);
        assert_eq!(bucket.current_tokens(), 80);

        // Handing back more than fits is clamped to the capacity.
        bucket.return_tokens(50);
        assert_eq!(bucket.current_tokens(), 100);
    }

    #[test]
    fn test_reset() {
        let bucket = drained(100, 10.0);
        assert!(bucket.is_empty());

        bucket.reset();
        assert!(bucket.is_full());
    }

    #[test]
    fn test_fill_ratio() {
        let bucket = TokenBucket::new(100, 10.0);

        assert!((bucket.fill_ratio() - 1.0).abs() < 0.01);

        // Each withdrawal of 50 lowers the ratio by one half.
        for expected in [0.5, 0.0] {
            assert!(bucket.try_acquire(50).is_ok());
            assert!((bucket.fill_ratio() - expected).abs() < 0.01);
        }
    }

    #[test]
    fn test_time_until_available() {
        let bucket = drained(100, 10.0); // 10 tokens/sec

        // Ten tokens at ten per second should take about one second.
        let result = bucket.try_acquire(10);
        assert!(result.is_err());

        let err = result.unwrap_err();
        // Accept anything within 100ms of one second.
        let retry_ms = err.retry_after.as_millis();
        assert!(retry_ms >= 900);
        assert!(retry_ms <= 1100);
    }

    #[test]
    fn test_acquire_blocking() {
        let bucket = drained(10, 100.0); // 100 tokens/sec

        // Five tokens should be available well within the timeout.
        let result = bucket.acquire_blocking(5, Duration::from_millis(100));
        assert!(result.is_ok());
    }

    #[test]
    fn test_acquire_blocking_timeout() {
        let bucket = drained(10, 1.0); // 1 token/sec

        // Ten tokens need ten seconds, but only 10ms of waiting is allowed.
        let result = bucket.acquire_blocking(10, Duration::from_millis(10));
        assert!(result.is_err());
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let bucket = Arc::new(TokenBucket::new(1000, 1000.0));

        // Ten threads, each making ten attempts to take five tokens.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let bucket = Arc::clone(&bucket);
                thread::spawn(move || {
                    for _ in 0..10 {
                        let _ = bucket.try_acquire(5);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        // Some tokens must be gone; the exact count depends on timing.
        assert!(bucket.current_tokens() < 1000);
    }

    #[test]
    fn test_clone() {
        let bucket1 = TokenBucket::new(100, 10.0);
        assert!(bucket1.try_acquire(50).is_ok());

        // The clone carries over both the capacity and the current level.
        let bucket2 = bucket1.clone();
        assert_eq!(bucket2.capacity(), 100);
        assert_eq!(bucket2.current_tokens(), 50);
    }
}