Skip to main content

chio_guards/external/
token_bucket.rs

1//! Simple token bucket rate limiter for the external guard adapter.
2//!
3//! Unlike [`crate::velocity::VelocityGuard`] which tracks per-grant buckets
4//! with milli-token precision, this bucket is a single-instance rate limiter
5//! intended to cap the QPS of one [`crate::external::AsyncGuardAdapter`].
6//!
7//! The bucket uses a [`Clock`] abstraction so tests drive time via Tokio's
8//! pausable timer.
9
10use std::sync::Arc;
11use std::sync::Mutex;
12
13use tokio::time::Instant;
14
15use super::cache::{Clock, TokioClock};
16
17/// Single-instance token bucket rate limiter.
18///
19/// Tokens refill continuously at `rate_per_second` up to the `burst`
20/// ceiling. A call to [`TokenBucket::try_acquire`] consumes one token;
21/// returns `true` on success and `false` when the bucket is empty.
22pub struct TokenBucket {
23    inner: Mutex<BucketInner>,
24    rate_per_second: f64,
25    burst: f64,
26    clock: Arc<dyn Clock>,
27}
28
29#[derive(Debug)]
30struct BucketInner {
31    tokens: f64,
32    last_refill: Instant,
33}
34
35impl TokenBucket {
36    /// Create a bucket with `rate_per_second` refill rate and `burst`
37    /// capacity. Starts full. `rate_per_second` is clamped to `>= 0.0`; a
38    /// zero rate means the bucket never refills (only the initial burst
39    /// is available).
40    pub fn new(rate_per_second: f64, burst: u32) -> Self {
41        Self::with_clock(rate_per_second, burst, Arc::new(TokioClock))
42    }
43
44    /// Create a bucket with a custom clock.
45    pub fn with_clock(rate_per_second: f64, burst: u32, clock: Arc<dyn Clock>) -> Self {
46        let rate = rate_per_second.max(0.0);
47        let burst_f = f64::from(burst.max(1));
48        let now = clock.now();
49        Self {
50            inner: Mutex::new(BucketInner {
51                tokens: burst_f,
52                last_refill: now,
53            }),
54            rate_per_second: rate,
55            burst: burst_f,
56            clock,
57        }
58    }
59
60    /// Configured refill rate (tokens per second).
61    pub fn rate_per_second(&self) -> f64 {
62        self.rate_per_second
63    }
64
65    /// Configured burst ceiling.
66    pub fn burst(&self) -> u32 {
67        self.burst as u32
68    }
69
70    /// Attempt to consume one token. Returns `true` if a token was available
71    /// (and consumed); `false` otherwise.
72    pub fn try_acquire(&self) -> bool {
73        self.try_acquire_n(1.0)
74    }
75
76    /// Attempt to consume `n` tokens.
77    pub fn try_acquire_n(&self, n: f64) -> bool {
78        if n <= 0.0 {
79            return true;
80        }
81        let now = self.clock.now();
82        let Ok(mut inner) = self.inner.lock() else {
83            return false;
84        };
85        self.refill(&mut inner, now);
86        if inner.tokens + f64::EPSILON >= n {
87            inner.tokens -= n;
88            if inner.tokens < 0.0 {
89                inner.tokens = 0.0;
90            }
91            true
92        } else {
93            false
94        }
95    }
96
97    /// Current token count. Mainly useful for tests and diagnostics.
98    pub fn available(&self) -> f64 {
99        let now = self.clock.now();
100        let Ok(mut inner) = self.inner.lock() else {
101            return 0.0;
102        };
103        self.refill(&mut inner, now);
104        inner.tokens
105    }
106
107    fn refill(&self, inner: &mut BucketInner, now: Instant) {
108        if self.rate_per_second == 0.0 {
109            inner.last_refill = now;
110            return;
111        }
112        let elapsed = now
113            .saturating_duration_since(inner.last_refill)
114            .as_secs_f64();
115        if elapsed <= 0.0 {
116            return;
117        }
118        let added = elapsed * self.rate_per_second;
119        inner.tokens = (inner.tokens + added).min(self.burst);
120        inner.last_refill = now;
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Tolerance for comparing fractional token counts.
    const TOLERANCE: f64 = 1e-9;

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn starts_full_with_burst_capacity() {
        let bucket = TokenBucket::new(1.0, 3);
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn refills_at_configured_rate() {
        let bucket = TokenBucket::new(2.0, 2);
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
        tokio::time::advance(Duration::from_millis(1100)).await;
        // After 1.1s at 2 tok/s we should have ~2 tokens (capped).
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn zero_rate_only_uses_burst() {
        let bucket = TokenBucket::new(0.0, 1);
        assert!(bucket.try_acquire());
        tokio::time::advance(Duration::from_secs(60)).await;
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn zero_burst_is_clamped_to_one() {
        let bucket = TokenBucket::new(1.0, 0);
        assert_eq!(bucket.burst(), 1);
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn negative_rate_is_clamped_to_zero() {
        let bucket = TokenBucket::new(-5.0, 1);
        assert!(bucket.rate_per_second().abs() < TOLERANCE);
        assert!(bucket.try_acquire());
        tokio::time::advance(Duration::from_secs(60)).await;
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn try_acquire_n_handles_fractional_and_degenerate_amounts() {
        let bucket = TokenBucket::new(1.0, 2);
        // Non-positive requests succeed without consuming anything.
        assert!(bucket.try_acquire_n(0.0));
        assert!(bucket.try_acquire_n(-1.0));
        assert!((bucket.available() - 2.0).abs() < TOLERANCE);
        // A request larger than the burst can never be satisfied.
        assert!(!bucket.try_acquire_n(3.0));
        // NaN never compares as "enough tokens".
        assert!(!bucket.try_acquire_n(f64::NAN));
        // Fractional requests leave a fractional remainder.
        assert!(bucket.try_acquire_n(1.5));
        assert!((bucket.available() - 0.5).abs() < TOLERANCE);
        assert!(!bucket.try_acquire());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn available_includes_refilled_tokens() {
        let bucket = TokenBucket::new(4.0, 4);
        assert!(bucket.try_acquire_n(4.0));
        assert!(bucket.available().abs() < TOLERANCE);
        tokio::time::advance(Duration::from_millis(500)).await;
        // 0.5s at 4 tok/s refills 2 tokens.
        assert!((bucket.available() - 2.0).abs() < TOLERANCE);
    }
}