metrics_lib/token_bucket.rs
1//! Strict-admission token bucket (v0.9.5).
2//!
3//! `TokenBucket` is the strict-admission counterpart to
4//! `RateMeter::tick_if_under_limit`. It exposes the classic token-bucket
5//! algorithm — capacity, refill rate, atomic acquire — using a single
6//! `compare_exchange_weak` loop on a `u64` that packs `tokens` (millitokens)
7//! and `last_refill` time together to eliminate the TOCTOU window that
8//! `RateMeter` accepts in exchange for hot-path speed.
9//!
10//! # When to choose `TokenBucket` over `RateMeter`
11//!
12//! - **`RateMeter::tick_if_under_limit`** is faster (single
13//! `fetch_add`-style hot path) but has known TOCTOU semantics: multiple
14//! threads can each observe the under-limit predicate, then `tick`,
15//! briefly overshooting the limit by up to `num_threads − 1` events.
16//! Suitable for observability use cases (dashboards, alerting).
17//! - **`TokenBucket::try_acquire`** uses an atomic CAS on the packed
18//! `(tokens, time)` state, so the limit is **never** exceeded. Suitable
19//! for billing, hard-limit admission control, downstream protection.
20//!
21//! # Algorithm
22//!
23//! State is one `AtomicU64` packing:
24//! - **upper 32 bits:** `tokens` in millitokens (10⁻³ tokens). Allows
25//! refill rates with sub-token-per-tick resolution.
26//! - **lower 32 bits:** milliseconds since `created_at` of the last
27//! refill computation.
28//!
29//! `try_acquire(n)`:
30//! 1. Load packed state.
31//! 2. Compute fresh tokens since `last_refill` (refill is monotonic and
32//! capped at `capacity_millitokens`).
//! 3. If `fresh_tokens >= n * 1000`, subtract `n * 1000` and CAS-write the
//!    new packed state, stamping the current time as the new `last_refill`.
//!    On success return `Ok(())`; on CAS conflict retry from step 1.
//! 4. If `fresh_tokens < n * 1000`, return `Err(WouldBlock)` without
//!    modifying the state.
37//!
//! The token field saturates at `u32::MAX` millitokens (~4.29 M tokens). The
//! refill timestamp is a 32-bit millisecond offset from construction, so it
//! can only represent ~49.7 days of process uptime; see `now_ms` and
//! `refilled` for how the bucket behaves beyond that range. Note that `reset`
//! refills the tokens but does not move the construction reference point.
42//!
43//! # Example
44//!
45//! ```
46//! use metrics_lib::TokenBucket;
47//! use std::time::Duration;
48//!
49//! // 10 tokens/sec sustained, burst up to 50.
50//! let bucket = TokenBucket::new(50, 10.0);
51//!
52//! // Hot path: drop or downsample if no token is available.
53//! if bucket.try_acquire(1).is_ok() {
54//! // serve the request
55//! }
56//!
57//! // Burst of 5 tokens at once.
58//! let _ = bucket.try_acquire(5);
59//! ```
60
61use crate::{MetricsError, Result};
62use std::sync::atomic::{AtomicU64, Ordering};
63use std::time::Instant;
64
/// Strict-admission token bucket with atomic-CAS acquire semantics.
///
/// `#[repr(align(64))]` places every bucket on its own 64-byte boundary (and
/// pads its size to a multiple of 64), so on hardware with 64-byte cache lines
/// two buckets never share a line — avoiding false sharing in concurrent
/// admission pipelines.
#[repr(align(64))]
pub struct TokenBucket {
    /// Packed state, updated as a single atomic word.
    /// Upper 32 bits: current tokens in millitokens (tokens × 1000).
    /// Lower 32 bits: 32-bit millisecond offset from `created_at` recorded at
    /// the last refill computation.
    state: AtomicU64,
    /// Bucket capacity, in millitokens (tokens × 1000).
    capacity_millitokens: u64,
    /// Refill rate in micro-tokens (10⁻⁶ tokens) per millisecond. This is
    /// numerically equal to millitokens per second, i.e. `tokens/sec × 1000`.
    /// Kept as an integer fixed-point value so the hot path needs no
    /// floating-point arithmetic.
    refill_micro_per_ms: u64,
    /// Creation timestamp; the millisecond offset packed in `state` is
    /// measured relative to this `Instant`.
    created_at: Instant,
}
84
85impl TokenBucket {
86 /// Build a new bucket.
87 ///
88 /// * `capacity` — maximum tokens the bucket can hold (burst size).
89 /// * `refill_per_second` — sustained refill rate, tokens per second.
90 ///
91 /// The bucket starts **full** (`tokens = capacity`). `refill_per_second`
92 /// of `0.0` produces a static-capacity bucket (no refill); negative or
93 /// non-finite values are coerced to `0.0`.
94 pub fn new(capacity: u32, refill_per_second: f64) -> Self {
95 let cap_mt = (capacity as u64).saturating_mul(1000);
96 let rate = if refill_per_second.is_finite() && refill_per_second > 0.0 {
97 // tokens/sec × 1000 ms/sec ⇒ millitokens/sec ⇒ × 1000 ⇒ micro-tokens/ms
98 (refill_per_second * 1_000.0).round() as u64
99 } else {
100 0
101 };
102 Self {
103 state: AtomicU64::new(pack(cap_mt, 0)),
104 capacity_millitokens: cap_mt,
105 refill_micro_per_ms: rate,
106 created_at: Instant::now(),
107 }
108 }
109
110 /// Bucket capacity in whole tokens.
111 #[must_use]
112 #[inline]
113 pub fn capacity(&self) -> u32 {
114 (self.capacity_millitokens / 1000).min(u32::MAX as u64) as u32
115 }
116
117 /// Refill rate in tokens per second (reconstructed from internal
118 /// fixed-point storage; subject to small rounding).
119 #[must_use]
120 #[inline]
121 pub fn refill_per_second(&self) -> f64 {
122 self.refill_micro_per_ms as f64 / 1_000.0
123 }
124
125 /// Current available tokens (approximate snapshot — observation has no
126 /// retry semantics; treat as advisory).
127 #[must_use]
128 pub fn available(&self) -> u32 {
129 let packed = self.state.load(Ordering::Relaxed);
130 let (tokens_mt, last_ms) = unpack(packed);
131 let now_ms = self.now_ms();
132 let mt = self.refilled(tokens_mt, last_ms, now_ms);
133 ((mt / 1000).min(u32::MAX as u64)) as u32
134 }
135
136 /// Attempt to acquire `n` tokens. Returns `Ok(())` on success, or
137 /// `Err(MetricsError::WouldBlock)` when fewer than `n` tokens are
138 /// available even after refill.
139 ///
140 /// `n == 0` always succeeds without modifying state.
141 #[inline]
142 pub fn try_acquire(&self, n: u32) -> Result<()> {
143 if n == 0 {
144 return Ok(());
145 }
146 let needed = (n as u64) * 1000;
147 // CAS loop — single atomic op on the success path under no contention.
148 loop {
149 let packed = self.state.load(Ordering::Relaxed);
150 let (tokens_mt, last_ms) = unpack(packed);
151 let now_ms = self.now_ms();
152 let mt = self.refilled(tokens_mt, last_ms, now_ms);
153 if mt < needed {
154 return Err(MetricsError::WouldBlock);
155 }
156 let new_packed = pack(mt - needed, now_ms);
157 if self
158 .state
159 .compare_exchange_weak(packed, new_packed, Ordering::Relaxed, Ordering::Relaxed)
160 .is_ok()
161 {
162 return Ok(());
163 }
164 }
165 }
166
167 /// Acquire-or-don't variant that swallows the error and returns a `bool`.
168 #[must_use]
169 #[inline]
170 pub fn acquire(&self, n: u32) -> bool {
171 self.try_acquire(n).is_ok()
172 }
173
174 /// Reset the bucket to full capacity.
175 pub fn reset(&self) {
176 let now_ms = self.now_ms();
177 self.state
178 .store(pack(self.capacity_millitokens, now_ms), Ordering::SeqCst);
179 }
180
181 /// Milliseconds elapsed since `created_at`, saturating into the lower
182 /// 32 bits of the packed state.
183 #[inline]
184 fn now_ms(&self) -> u32 {
185 (self.created_at.elapsed().as_millis() as u64).min(u32::MAX as u64) as u32
186 }
187
188 /// Compute the bucket's millitoken count after refilling for the
189 /// interval `last_ms → now_ms`, capped at capacity.
190 #[inline]
191 fn refilled(&self, tokens_mt: u64, last_ms: u32, now_ms: u32) -> u64 {
192 if self.refill_micro_per_ms == 0 {
193 return tokens_mt;
194 }
195 let elapsed_ms = now_ms.saturating_sub(last_ms) as u64;
196 // micro-tokens/ms × ms = micro-tokens. Divide by 1_000 to get millitokens.
197 let added_micro = elapsed_ms.saturating_mul(self.refill_micro_per_ms);
198 let added_mt = added_micro / 1_000;
199 (tokens_mt.saturating_add(added_mt)).min(self.capacity_millitokens)
200 }
201}
202
/// Combine a millitoken count and a millisecond offset into one `u64`:
/// tokens occupy the high 32 bits (saturated at `u32::MAX`), the time
/// offset occupies the low 32 bits.
#[inline]
fn pack(tokens_mt: u64, last_ms: u32) -> u64 {
    let high = u64::from(u32::try_from(tokens_mt).unwrap_or(u32::MAX));
    (high << 32) | u64::from(last_ms)
}
208
/// Split a packed state word back into `(tokens_mt, last_ms)`; the inverse
/// of `pack` for token counts that fit in 32 bits.
#[inline]
fn unpack(packed: u64) -> (u64, u32) {
    // An `as u32` cast keeps exactly the low 32 bits.
    let last_ms = packed as u32;
    (packed >> 32, last_ms)
}
215
216impl std::fmt::Debug for TokenBucket {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("TokenBucket")
219 .field("capacity", &self.capacity())
220 .field("available", &self.available())
221 .field("refill_per_second", &self.refill_per_second())
222 .finish()
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn pack_unpack_round_trip() {
        const TOKENS: [u64; 5] = [0, 1, 1000, 50_000, u32::MAX as u64];
        const TIMES: [u32; 4] = [0, 1, 1000, u32::MAX];
        let cases = TOKENS
            .iter()
            .flat_map(|&mt| TIMES.iter().map(move |&ms| (mt, ms)));
        for (mt, ms) in cases {
            let (got_mt, got_ms) = unpack(pack(mt, ms));
            assert_eq!(got_mt, mt.min(u32::MAX as u64));
            assert_eq!(got_ms, ms);
        }
    }

    #[test]
    fn new_bucket_starts_full() {
        let bucket = TokenBucket::new(10, 5.0);
        assert_eq!(bucket.capacity(), 10);
        assert_eq!(bucket.available(), 10);
    }

    #[test]
    fn try_acquire_zero_is_noop() {
        let bucket = TokenBucket::new(5, 1.0);
        bucket.try_acquire(0).unwrap();
        assert_eq!(bucket.available(), 5);
    }

    #[test]
    fn drains_then_refuses() {
        // Zero refill rate: the initial three tokens are all there is.
        let bucket = TokenBucket::new(3, 0.0);
        for _ in 0..3 {
            assert!(bucket.acquire(1));
        }
        assert!(!bucket.acquire(1));
        assert!(matches!(bucket.try_acquire(1), Err(MetricsError::WouldBlock)));
    }

    #[test]
    fn refills_over_time() {
        // 200 tokens/s is one token every 5 ms.
        let bucket = TokenBucket::new(10, 200.0);
        assert!(bucket.acquire(10));
        assert_eq!(bucket.available(), 0);
        // 50 ms is ten times the 5 ms needed for a single token.
        thread::sleep(Duration::from_millis(50));
        let available = bucket.available();
        assert!(
            available >= 1,
            "expected ≥ 1 token after 50 ms, got {}",
            available
        );
        // The refilled token can actually be taken.
        assert!(bucket.acquire(1));
    }

    #[test]
    fn refill_caps_at_capacity() {
        let bucket = TokenBucket::new(5, 1000.0);
        assert!(bucket.acquire(3));
        // 50 ms at 1000 tokens/s would add 50 tokens; the cap keeps it at 5.
        thread::sleep(Duration::from_millis(50));
        assert_eq!(bucket.available(), 5);
    }

    #[test]
    fn reset_restores_capacity() {
        let bucket = TokenBucket::new(4, 1.0);
        assert!(bucket.acquire(4));
        assert_eq!(bucket.available(), 0);
        bucket.reset();
        assert_eq!(bucket.available(), 4);
    }

    #[test]
    fn concurrent_acquire_never_overshoots_capacity() {
        // 100 tokens, no refill; 8 threads × 30 single-token attempts = 240
        // requests against 100 tokens, so exactly 100 must be granted.
        const THREADS: usize = 8;
        const ATTEMPTS_PER_THREAD: u32 = 30;
        let bucket = Arc::new(TokenBucket::new(100, 0.0));

        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let bucket = Arc::clone(&bucket);
                thread::spawn(move || {
                    (0..ATTEMPTS_PER_THREAD)
                        .filter(|_| bucket.acquire(1))
                        .count() as u32
                })
            })
            .collect();

        let granted: u32 = workers.into_iter().map(|w| w.join().unwrap()).sum();
        assert_eq!(granted, 100, "atomic-CAS bucket must never exceed capacity");
        assert_eq!(bucket.available(), 0);
    }

    #[test]
    fn invalid_refill_rate_treated_as_zero() {
        for rate in [f64::NAN, -1.0, f64::INFINITY] {
            assert_eq!(TokenBucket::new(5, rate).refill_per_second(), 0.0);
        }
    }

    #[test]
    fn debug_impl() {
        let rendered = format!("{:?}", TokenBucket::new(7, 2.5));
        assert!(rendered.contains("TokenBucket"));
        assert!(rendered.contains("capacity: 7"));
    }
}