Skip to main content

dome_throttle/
token_bucket.rs

1use tokio::time::Instant;
2
/// A token-bucket rate limiter.
///
/// Tokens refill continuously at `refill_rate` tokens per second, up to a
/// ceiling of `max_tokens`. Each successful [`TokenBucket::try_acquire`]
/// consumes exactly one token.
///
/// The bucket performs no internal synchronization: every mutating method
/// takes `&mut self`, so sharing it across tasks or threads requires an
/// external lock (for example, storing it as a `DashMap` value and mutating
/// it through the guard returned by `get_mut`).
#[derive(Debug, Clone)]
pub struct TokenBucket {
    /// Tokens available as of `last_refill`; fractional values are partial refill.
    pub tokens: f64,
    /// Capacity of the bucket; refilling never raises `tokens` above this.
    pub max_tokens: f64,
    /// Refill speed in tokens per second.
    pub refill_rate: f64,
    /// Timestamp up to which elapsed time has already been credited.
    pub last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full, with its clock anchored at `Instant::now()`.
    ///
    /// `max_tokens` is both the capacity and the initial token count;
    /// `refill_rate` is expressed in tokens per second.
    pub fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self::new_at(max_tokens, refill_rate, Instant::now())
    }

    /// Creates a full bucket whose clock is anchored at `now`.
    ///
    /// Useful for deterministic tests that drive time explicitly through the
    /// `*_at` methods.
    pub fn new_at(max_tokens: f64, refill_rate: f64, now: Instant) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: now,
        }
    }

    /// Refills based on the time elapsed since the last refill, then tries to
    /// consume one token.
    ///
    /// Returns `true` if a token was taken, `false` if less than one whole
    /// token was available (in which case nothing is consumed).
    pub fn try_acquire(&mut self) -> bool {
        self.try_acquire_at(Instant::now())
    }

    /// Same as [`TokenBucket::try_acquire`], but uses `now` instead of the
    /// current time (for deterministic testing).
    pub fn try_acquire_at(&mut self, now: Instant) -> bool {
        self.refill(now);
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Credits the time elapsed between `last_refill` and `now`.
    ///
    /// A `now` at or before `last_refill` is a no-op, so out-of-order
    /// timestamps neither add tokens nor move the clock backwards. Whenever
    /// time does advance, `last_refill` moves to `now` even if nothing was
    /// added (e.g. `refill_rate` is zero). This prevents idle time from being
    /// banked and credited later if `refill_rate` is raised.
    fn refill(&mut self, now: Instant) {
        let elapsed = now.saturating_duration_since(self.last_refill);
        if elapsed.is_zero() {
            return;
        }
        let added = elapsed.as_secs_f64() * self.refill_rate;
        // `added > 0.0` also rejects negative and NaN rates. Those would
        // otherwise drain the bucket, or (via `f64::min` ignoring NaN) fill it.
        if added > 0.0 {
            self.tokens = (self.tokens + added).min(self.max_tokens);
        }
        self.last_refill = now;
    }

    /// Returns the current token count after refilling up to `Instant::now()`.
    ///
    /// Takes `&mut self` because it refills (advancing `last_refill`) before
    /// reading the count.
    pub fn available(&mut self) -> f64 {
        self.available_at(Instant::now())
    }

    /// Same as [`TokenBucket::available`], but refills up to `now` instead of
    /// the current time (for deterministic testing).
    pub fn available_at(&mut self, now: Instant) -> f64 {
        self.refill(now);
        self.tokens
    }
}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Consumes `n` tokens at `at`, asserting that every acquisition succeeds.
    fn drain(bucket: &mut TokenBucket, n: usize, at: Instant) {
        for i in 0..n {
            assert!(bucket.try_acquire_at(at), "acquire {i} should succeed");
        }
    }

    #[tokio::test(start_paused = true)]
    async fn allows_requests_within_limit() {
        let now = Instant::now();
        let mut bucket = TokenBucket::new_at(5.0, 1.0, now);

        // Should allow 5 requests (bucket starts full).
        for i in 0..5 {
            assert!(bucket.try_acquire_at(now), "request {i} should be allowed");
        }
    }

    #[tokio::test(start_paused = true)]
    async fn denies_when_exhausted() {
        let now = Instant::now();
        let mut bucket = TokenBucket::new_at(3.0, 1.0, now);

        drain(&mut bucket, 3, now);
        assert!(!bucket.try_acquire_at(now), "should deny when exhausted");
    }

    #[tokio::test(start_paused = true)]
    async fn refills_over_time() {
        let now = Instant::now();
        let mut bucket = TokenBucket::new_at(3.0, 1.0, now);

        drain(&mut bucket, 3, now);
        assert!(!bucket.try_acquire_at(now));

        // Advance 2 seconds: should refill 2 tokens (rate = 1/sec).
        let later = now + Duration::from_secs(2);
        drain(&mut bucket, 2, later);
        assert!(
            !bucket.try_acquire_at(later),
            "third should fail, only 2 refilled"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn does_not_exceed_max() {
        let now = Instant::now();
        let mut bucket = TokenBucket::new_at(3.0, 10.0, now);

        // Empty the bucket first so the refill has to climb back to the cap.
        drain(&mut bucket, 3, now);

        // Advance 100 seconds with a high refill rate (would be 1000 tokens uncapped).
        let later = now + Duration::from_secs(100);
        bucket.refill(later);

        assert!(bucket.tokens <= bucket.max_tokens);
        assert!((bucket.tokens - 3.0).abs() < f64::EPSILON);
    }

    #[tokio::test(start_paused = true)]
    async fn partial_refill() {
        let now = Instant::now();
        let mut bucket = TokenBucket::new_at(10.0, 2.0, now);

        // Drain 5 tokens, leaving 5.
        drain(&mut bucket, 5, now);

        // Advance 1 second at rate 2/sec => +2 tokens = 7.
        let later = now + Duration::from_millis(1000);
        drain(&mut bucket, 7, later);
        assert!(!bucket.try_acquire_at(later), "bucket should now be empty");
    }

    #[tokio::test(start_paused = true)]
    async fn fractional_refill_does_not_grant_token() {
        let now = Instant::now();
        let mut bucket = TokenBucket::new_at(1.0, 1.0, now);
        drain(&mut bucket, 1, now);

        // Half a second only refills half a token.
        let half = now + Duration::from_millis(500);
        assert!(!bucket.try_acquire_at(half), "0.5 tokens must not be enough");

        // The remaining half arrives by the one-second mark.
        let full = now + Duration::from_millis(1000);
        assert!(bucket.try_acquire_at(full), "partial refills should accumulate");
    }

    #[tokio::test(start_paused = true)]
    async fn zero_rate_never_refills() {
        let now = Instant::now();
        let mut bucket = TokenBucket::new_at(2.0, 0.0, now);
        drain(&mut bucket, 2, now);

        let much_later = now + Duration::from_secs(3600);
        assert!(!bucket.try_acquire_at(much_later));
    }

    #[tokio::test(start_paused = true)]
    async fn earlier_timestamp_is_ignored() {
        let now = Instant::now();
        let start = now + Duration::from_secs(5);
        let mut bucket = TokenBucket::new_at(1.0, 1.0, start);
        drain(&mut bucket, 1, start);

        // A timestamp before `last_refill` must neither refill nor rewind the clock.
        assert!(!bucket.try_acquire_at(now));
        assert_eq!(bucket.last_refill, start);

        assert!(bucket.try_acquire_at(start + Duration::from_secs(1)));
    }

    #[tokio::test(start_paused = true)]
    async fn wall_clock_api_follows_paused_time() {
        let mut bucket = TokenBucket::new(2.0, 1.0);
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
        assert!(bucket.available().abs() < 1e-9);

        tokio::time::advance(Duration::from_secs(1)).await;
        assert!((bucket.available() - 1.0).abs() < 1e-9);
        assert!(bucket.try_acquire());
    }
}