Skip to main content

metrics_lib/
token_bucket.rs

1//! Strict-admission token bucket (v0.9.5).
2//!
3//! `TokenBucket` is the strict-admission counterpart to
4//! `RateMeter::tick_if_under_limit`. It exposes the classic token-bucket
5//! algorithm — capacity, refill rate, atomic acquire — using a single
6//! `compare_exchange_weak` loop on a `u64` that packs `tokens` (millitokens)
7//! and `last_refill` time together to eliminate the TOCTOU window that
8//! `RateMeter` accepts in exchange for hot-path speed.
9//!
10//! # When to choose `TokenBucket` over `RateMeter`
11//!
12//! - **`RateMeter::tick_if_under_limit`** is faster (single
13//!   `fetch_add`-style hot path) but has known TOCTOU semantics: multiple
14//!   threads can each observe the under-limit predicate, then `tick`,
15//!   briefly overshooting the limit by up to `num_threads − 1` events.
16//!   Suitable for observability use cases (dashboards, alerting).
17//! - **`TokenBucket::try_acquire`** uses an atomic CAS on the packed
18//!   `(tokens, time)` state, so the limit is **never** exceeded. Suitable
19//!   for billing, hard-limit admission control, downstream protection.
20//!
21//! # Algorithm
22//!
23//! State is one `AtomicU64` packing:
24//! - **upper 32 bits:** `tokens` in millitokens (10⁻³ tokens). Allows
25//!   refill rates with sub-token-per-tick resolution.
26//! - **lower 32 bits:** milliseconds since `created_at` of the last
27//!   refill computation.
28//!
29//! `try_acquire(n)`:
30//! 1. Load packed state.
31//! 2. Compute fresh tokens since `last_refill` (refill is monotonic and
32//!    capped at `capacity_millitokens`).
33//! 3. If `fresh_tokens >= n * 1000`, subtract `n * 1000` and CAS-write the
34//!    new packed state. On success return `Ok(())`; on CAS conflict retry
35//!    from step 1.
36//! 4. If `fresh_tokens < n * 1000`, return `Err(WouldBlock)`.
37//!
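//! Token storage saturates at `u32::MAX` millitokens (~4.29 M tokens). The
//! millisecond clock is measured from construction and **saturates** (it
//! does not wrap) at `u32::MAX` ms, i.e. after ~49.7 days. Once the clock
//! has saturated, elapsed time stops accumulating and the bucket no longer
//! refills. `reset` tops the bucket up but does not rebase the clock, so a
//! process that depends on refill beyond that horizon must construct a
//! fresh `TokenBucket`; in practice restart cadence usually covers this.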
42//!
43//! # Example
44//!
45//! ```
46//! use metrics_lib::TokenBucket;
47//! use std::time::Duration;
48//!
49//! // 10 tokens/sec sustained, burst up to 50.
50//! let bucket = TokenBucket::new(50, 10.0);
51//!
52//! // Hot path: drop or downsample if no token is available.
53//! if bucket.try_acquire(1).is_ok() {
54//!     // serve the request
55//! }
56//!
57//! // Burst of 5 tokens at once.
58//! let _ = bucket.try_acquire(5);
59//! ```
60
61use crate::{MetricsError, Result};
62use std::sync::atomic::{AtomicU64, Ordering};
63use std::time::Instant;
64
/// Token bucket whose acquisitions are decided by a single atomic
/// compare-and-swap, so concurrent callers can never take more tokens than
/// the bucket holds.
///
/// The type is aligned to 64 bytes (a typical cache-line size) so that two
/// buckets placed next to each other do not share a line and contend when
/// they are hammered from different threads.
#[repr(align(64))]
pub struct TokenBucket {
    /// Both mutable quantities, packed into one word so they change together:
    ///
    /// * bits 63..32 — tokens currently in the bucket, in millitokens
    ///   (whole tokens × 1000);
    /// * bits 31..0 — the time of the last refill computation, in
    ///   milliseconds since `created_at`.
    state: AtomicU64,
    /// Upper bound on the stored token count, in millitokens.
    capacity_millitokens: u64,
    /// Refill rate in micro-tokens (10⁻⁶ tokens) per millisecond, which is
    /// numerically `tokens_per_second × 1000`. Keeping it as an integer
    /// means the hot path needs no floating-point arithmetic.
    refill_micro_per_ms: u64,
    /// Origin of the millisecond clock stored in the low half of `state`.
    created_at: Instant,
}
84
85impl TokenBucket {
86    /// Build a new bucket.
87    ///
88    /// * `capacity` — maximum tokens the bucket can hold (burst size).
89    /// * `refill_per_second` — sustained refill rate, tokens per second.
90    ///
91    /// The bucket starts **full** (`tokens = capacity`). `refill_per_second`
92    /// of `0.0` produces a static-capacity bucket (no refill); negative or
93    /// non-finite values are coerced to `0.0`.
94    pub fn new(capacity: u32, refill_per_second: f64) -> Self {
95        let cap_mt = (capacity as u64).saturating_mul(1000);
96        let rate = if refill_per_second.is_finite() && refill_per_second > 0.0 {
97            // tokens/sec × 1000 ms/sec ⇒ millitokens/sec ⇒ × 1000 ⇒ micro-tokens/ms
98            (refill_per_second * 1_000.0).round() as u64
99        } else {
100            0
101        };
102        Self {
103            state: AtomicU64::new(pack(cap_mt, 0)),
104            capacity_millitokens: cap_mt,
105            refill_micro_per_ms: rate,
106            created_at: Instant::now(),
107        }
108    }
109
110    /// Bucket capacity in whole tokens.
111    #[must_use]
112    #[inline]
113    pub fn capacity(&self) -> u32 {
114        (self.capacity_millitokens / 1000).min(u32::MAX as u64) as u32
115    }
116
117    /// Refill rate in tokens per second (reconstructed from internal
118    /// fixed-point storage; subject to small rounding).
119    #[must_use]
120    #[inline]
121    pub fn refill_per_second(&self) -> f64 {
122        self.refill_micro_per_ms as f64 / 1_000.0
123    }
124
125    /// Current available tokens (approximate snapshot — observation has no
126    /// retry semantics; treat as advisory).
127    #[must_use]
128    pub fn available(&self) -> u32 {
129        let packed = self.state.load(Ordering::Relaxed);
130        let (tokens_mt, last_ms) = unpack(packed);
131        let now_ms = self.now_ms();
132        let mt = self.refilled(tokens_mt, last_ms, now_ms);
133        ((mt / 1000).min(u32::MAX as u64)) as u32
134    }
135
136    /// Attempt to acquire `n` tokens. Returns `Ok(())` on success, or
137    /// `Err(MetricsError::WouldBlock)` when fewer than `n` tokens are
138    /// available even after refill.
139    ///
140    /// `n == 0` always succeeds without modifying state.
141    #[inline]
142    pub fn try_acquire(&self, n: u32) -> Result<()> {
143        if n == 0 {
144            return Ok(());
145        }
146        let needed = (n as u64) * 1000;
147        // CAS loop — single atomic op on the success path under no contention.
148        loop {
149            let packed = self.state.load(Ordering::Relaxed);
150            let (tokens_mt, last_ms) = unpack(packed);
151            let now_ms = self.now_ms();
152            let mt = self.refilled(tokens_mt, last_ms, now_ms);
153            if mt < needed {
154                return Err(MetricsError::WouldBlock);
155            }
156            let new_packed = pack(mt - needed, now_ms);
157            if self
158                .state
159                .compare_exchange_weak(packed, new_packed, Ordering::Relaxed, Ordering::Relaxed)
160                .is_ok()
161            {
162                return Ok(());
163            }
164        }
165    }
166
167    /// Acquire-or-don't variant that swallows the error and returns a `bool`.
168    #[must_use]
169    #[inline]
170    pub fn acquire(&self, n: u32) -> bool {
171        self.try_acquire(n).is_ok()
172    }
173
174    /// Reset the bucket to full capacity.
175    pub fn reset(&self) {
176        let now_ms = self.now_ms();
177        self.state
178            .store(pack(self.capacity_millitokens, now_ms), Ordering::SeqCst);
179    }
180
181    /// Milliseconds elapsed since `created_at`, saturating into the lower
182    /// 32 bits of the packed state.
183    #[inline]
184    fn now_ms(&self) -> u32 {
185        (self.created_at.elapsed().as_millis() as u64).min(u32::MAX as u64) as u32
186    }
187
188    /// Compute the bucket's millitoken count after refilling for the
189    /// interval `last_ms → now_ms`, capped at capacity.
190    #[inline]
191    fn refilled(&self, tokens_mt: u64, last_ms: u32, now_ms: u32) -> u64 {
192        if self.refill_micro_per_ms == 0 {
193            return tokens_mt;
194        }
195        let elapsed_ms = now_ms.saturating_sub(last_ms) as u64;
196        // micro-tokens/ms × ms = micro-tokens. Divide by 1_000 to get millitokens.
197        let added_micro = elapsed_ms.saturating_mul(self.refill_micro_per_ms);
198        let added_mt = added_micro / 1_000;
199        (tokens_mt.saturating_add(added_mt)).min(self.capacity_millitokens)
200    }
201}
202
/// Combine a millitoken count and a millisecond timestamp into one word:
/// tokens in the upper 32 bits, timestamp in the lower 32 bits.
///
/// A `tokens_mt` that does not fit in 32 bits is stored as `u32::MAX`.
#[inline]
fn pack(tokens_mt: u64, last_ms: u32) -> u64 {
    let high = u32::try_from(tokens_mt).unwrap_or(u32::MAX);
    (u64::from(high) << 32) | u64::from(last_ms)
}
208
/// Split a packed state word into `(tokens_mt, last_ms)`: the upper 32 bits
/// as the millitoken count and the lower 32 bits as the millisecond
/// timestamp.
#[inline]
fn unpack(packed: u64) -> (u64, u32) {
    // The truncating cast is intentional: it keeps exactly the low 32 bits.
    let low = packed as u32;
    let high = packed >> 32;
    (high, low)
}
215
216impl std::fmt::Debug for TokenBucket {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        f.debug_struct("TokenBucket")
219            .field("capacity", &self.capacity())
220            .field("available", &self.available())
221            .field("refill_per_second", &self.refill_per_second())
222            .finish()
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// `unpack` must invert `pack` for every in-range token/time pair.
    #[test]
    fn pack_unpack_round_trip() {
        for &mt in &[0_u64, 1, 1000, 50_000, u32::MAX as u64] {
            for &ms in &[0_u32, 1, 1000, u32::MAX] {
                let (tmt, tms) = unpack(pack(mt, ms));
                assert_eq!(tmt, mt.min(u32::MAX as u64));
                assert_eq!(tms, ms);
            }
        }
    }

    #[test]
    fn new_bucket_starts_full() {
        let b = TokenBucket::new(10, 5.0);
        assert_eq!(b.capacity(), 10);
        assert_eq!(b.available(), 10);
    }

    #[test]
    fn try_acquire_zero_is_noop() {
        let b = TokenBucket::new(5, 1.0);
        b.try_acquire(0).unwrap();
        assert_eq!(b.available(), 5);
    }

    #[test]
    fn drains_then_refuses() {
        let b = TokenBucket::new(3, 0.0); // no refill
        assert!(b.acquire(1));
        assert!(b.acquire(1));
        assert!(b.acquire(1));
        assert!(!b.acquire(1));
        assert!(matches!(b.try_acquire(1), Err(MetricsError::WouldBlock)));
    }

    #[test]
    fn refills_over_time() {
        // 200 tokens per second ⇒ one token every 5 ms.
        let b = TokenBucket::new(10, 200.0);
        // Drain. Taking all 10 tokens from a capacity-10 bucket leaves it
        // empty at that instant. We deliberately do not assert
        // `available() == 0` afterwards: at this rate a 5 ms scheduling
        // delay would already show a refilled token and fail spuriously.
        assert!(b.acquire(10));
        // Sleep comfortably longer than the 5 ms needed for one token.
        thread::sleep(Duration::from_millis(50));
        let after_sleep = b.available();
        assert!(
            after_sleep >= 1,
            "expected ≥ 1 token after 50 ms, got {after_sleep}"
        );
        // And we can acquire again.
        assert!(b.acquire(1));
    }

    #[test]
    fn refill_caps_at_capacity() {
        let b = TokenBucket::new(5, 1000.0);
        // Drain partially.
        assert!(b.acquire(3));
        thread::sleep(Duration::from_millis(50));
        // After 50 ms at 1000/s we'd refill 50 tokens, but capacity caps at 5.
        assert_eq!(b.available(), 5);
    }

    #[test]
    fn reset_restores_capacity() {
        // 1 token/s: a refill needs a full second, far longer than this test.
        let b = TokenBucket::new(4, 1.0);
        assert!(b.acquire(4));
        assert_eq!(b.available(), 0);
        b.reset();
        assert_eq!(b.available(), 4);
    }

    #[test]
    fn concurrent_acquire_never_overshoots_capacity() {
        // 100 tokens, no refill. 8 threads each try to acquire 30 tokens.
        // Total demand 240, available 100 — exactly 100 should succeed.
        let b = Arc::new(TokenBucket::new(100, 0.0));
        let threads = 8;
        let per_thread_demand = 30u32;

        let handles: Vec<_> = (0..threads)
            .map(|_| {
                let b = Arc::clone(&b);
                thread::spawn(move || {
                    let mut taken = 0u32;
                    for _ in 0..per_thread_demand {
                        if b.acquire(1) {
                            taken += 1;
                        }
                    }
                    taken
                })
            })
            .collect();

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(total, 100, "atomic-CAS bucket must never exceed capacity");
        assert_eq!(b.available(), 0);
    }

    #[test]
    fn invalid_refill_rate_treated_as_zero() {
        let a = TokenBucket::new(5, f64::NAN);
        assert_eq!(a.refill_per_second(), 0.0);
        let b = TokenBucket::new(5, -1.0);
        assert_eq!(b.refill_per_second(), 0.0);
        let c = TokenBucket::new(5, f64::INFINITY);
        assert_eq!(c.refill_per_second(), 0.0);
    }

    #[test]
    fn debug_impl() {
        let b = TokenBucket::new(7, 2.5);
        let s = format!("{b:?}");
        assert!(s.contains("TokenBucket"));
        assert!(s.contains("capacity: 7"));
    }
}