Skip to main content

reliability_toolkit/
rate_limiter.rs

1//! Token-bucket rate limiter.
2//!
//! Tokens accrue at `rate_per_second` up to `burst`. Acquiring a token waits
//! (asynchronously) until one is available. Bursts are absorbed up to the bucket
//! capacity, which is the right shape for protecting downstreams that can briefly
//! tolerate traffic spikes but reject sustained overload.
7//!
//! The implementation is not lock-free: a single `tokio::sync::Mutex` guards a
//! tiny token count + last-refill timestamp, and it is held only briefly (never
//! across a sleep). For workloads >> 10k rps consider a sharded variant.
11//!
12//! ```
13//! # use std::time::Duration;
14//! # use reliability_toolkit::RateLimiter;
15//! # async fn demo() {
16//! let limiter = RateLimiter::new(10.0, 5); // 10 rps, burst 5
17//! for _ in 0..3 {
18//!     limiter.acquire().await;
19//! }
20//! # }
21//! ```
22
23use std::sync::Arc;
24use std::time::Duration;
25
26use tokio::sync::Mutex;
27use tokio::time::{sleep, Instant};
28
29/// Token-bucket rate limiter. Cheap to `clone`; the inner state is `Arc<Mutex<_>>`.
30#[derive(Clone, Debug)]
31pub struct RateLimiter {
32    inner: Arc<Inner>,
33}
34
/// Shared limiter internals: fixed configuration plus the lock-protected bucket.
#[derive(Debug)]
struct Inner {
    /// Tokens credited per second of elapsed time; validated `> 0.0` by `RateLimiter::new`.
    rate_per_second: f64,
    /// Bucket capacity, kept as `f64` so it can be compared with `State::tokens` directly.
    burst: f64,
    /// Current token count and refill timestamp, guarded by an async mutex.
    state: Mutex<State>,
}
41
/// Mutable bucket state; only reachable through the mutex in `Inner::state`.
#[derive(Debug)]
struct State {
    /// Tokens currently available. Starts at `burst`, is capped at `burst` on
    /// refill, and is only decremented when at least the requested amount is present.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date by a refill.
    last_refill: Instant,
}
47
48impl RateLimiter {
49    /// Build a limiter that produces `rate_per_second` tokens, capped at `burst`.
50    ///
51    /// # Panics
52    ///
53    /// Panics if `rate_per_second` is not positive or `burst` is zero.
54    pub fn new(rate_per_second: f64, burst: u32) -> Self {
55        assert!(rate_per_second > 0.0, "rate_per_second must be positive");
56        assert!(burst > 0, "burst must be non-zero");
57        let burst_f = f64::from(burst);
58        Self {
59            inner: Arc::new(Inner {
60                rate_per_second,
61                burst: burst_f,
62                state: Mutex::new(State {
63                    tokens: burst_f,
64                    last_refill: Instant::now(),
65                }),
66            }),
67        }
68    }
69
70    /// Wait until a token is available, then consume it.
71    pub async fn acquire(&self) {
72        self.acquire_n(1).await;
73    }
74
75    /// Wait until `n` tokens are available, then consume them.
76    ///
77    /// # Panics
78    ///
79    /// Panics if `n` exceeds the configured burst (it could never be granted).
80    pub async fn acquire_n(&self, n: u32) {
81        let needed = f64::from(n);
82        assert!(
83            needed <= self.inner.burst,
84            "requested {n} tokens but burst is {}",
85            self.inner.burst
86        );
87
88        loop {
89            let wait = {
90                let mut state = self.inner.state.lock().await;
91                self.refill(&mut state);
92                if state.tokens >= needed {
93                    state.tokens -= needed;
94                    return;
95                }
96                let deficit = needed - state.tokens;
97                let seconds = deficit / self.inner.rate_per_second;
98                Duration::from_secs_f64(seconds)
99            };
100            sleep(wait).await;
101        }
102    }
103
104    /// Try to consume one token without waiting. Returns `true` on success.
105    pub async fn try_acquire(&self) -> bool {
106        self.try_acquire_n(1).await
107    }
108
109    /// Try to consume `n` tokens without waiting. Returns `true` on success.
110    pub async fn try_acquire_n(&self, n: u32) -> bool {
111        let needed = f64::from(n);
112        let mut state = self.inner.state.lock().await;
113        self.refill(&mut state);
114        if state.tokens >= needed {
115            state.tokens -= needed;
116            true
117        } else {
118            false
119        }
120    }
121
122    /// Current bucket level (mostly useful for tests + telemetry).
123    pub async fn tokens(&self) -> f64 {
124        let mut state = self.inner.state.lock().await;
125        self.refill(&mut state);
126        state.tokens
127    }
128
129    fn refill(&self, state: &mut State) {
130        let now = Instant::now();
131        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
132        if elapsed > 0.0 {
133            state.tokens =
134                (state.tokens + elapsed * self.inner.rate_per_second).min(self.inner.burst);
135            state.last_refill = now;
136        }
137    }
138}