Skip to main content

elfo_utils/
rate_limiter.rs

1use std::{
2    sync::atomic::{AtomicU64, Ordering::Relaxed},
3    time::Duration,
4};
5
6use crate::time;
7
/// A rate limiter implementing [GCRA](https://en.wikipedia.org/wiki/Generic_cell_rate_algorithm).
///
/// All state is kept in atomics and every method takes `&self`, so a single
/// limiter can be shared by reference.
#[derive(Debug)]
pub struct RateLimiter {
    /// Nanoseconds of virtual time consumed by one permit, or one of the
    /// `UNLIMITED` (`0`) / `DISABLED` (`u64::MAX`) sentinels.
    step: AtomicU64,
    /// Length, in nanoseconds, of the window the limit applies to.
    period: AtomicU64,
    /// GCRA virtual ("theoretical arrival") time in nanoseconds; `0` after
    /// creation or `reset()`.
    vtime: AtomicU64,
}
14
15/// Unlimited by default.
16impl Default for RateLimiter {
17    fn default() -> Self {
18        Self::new(RateLimit::Unlimited)
19    }
20}
21
/// A rate limit configuration.
///
/// A limit of `0` forbids every request. A limit of at least one request per
/// nanosecond of the period is treated as unlimited.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimit {
    /// Unlimited rate.
    Unlimited,
    /// At most this many requests per second.
    Rps(u64),
    /// At most `.0` requests per custom period `.1`.
    Custom(u64, Duration),
}
32
33impl RateLimit {
34    fn step_and_period(self) -> (u64, u64) {
35        let (limit, period) = match self {
36            Self::Unlimited => (1, 1),
37            Self::Rps(rps) => (rps, SEC),
38            Self::Custom(limit, period) => (limit, period.as_nanos() as u64),
39        };
40
41        (calculate_step(limit, period), period)
42    }
43}
44
45const SEC: u64 = 1_000_000_000;
46const UNLIMITED: u64 = 0;
47const DISABLED: u64 = u64::MAX;
48
49impl RateLimiter {
50    /// Creates a new limiter.
51    pub fn new(limit: RateLimit) -> Self {
52        let (step, period) = limit.step_and_period();
53
54        Self {
55            step: AtomicU64::new(step),
56            period: AtomicU64::new(period),
57            vtime: AtomicU64::new(0),
58        }
59    }
60
61    /// Reconfigures a limiter.
62    pub fn configure(&self, limit: RateLimit) {
63        let (step, period) = limit.step_and_period();
64
65        self.step.store(step, Relaxed);
66        self.period.store(period, Relaxed);
67    }
68
69    /// Resets a limiter.
70    pub fn reset(&self) {
71        self.vtime.store(0, Relaxed);
72    }
73
74    /// Acquires one permit.
75    /// Returns `true` if an operation is allowed.
76    #[inline]
77    pub fn acquire(&self) -> bool {
78        let step = self.step.load(Relaxed);
79
80        // Handle special cases.
81        if step == UNLIMITED {
82            return true;
83        }
84        if step == DISABLED {
85            return false;
86        }
87
88        let period = self.period.load(Relaxed);
89        let now = time::nanos_since_unknown_epoch();
90        let deadline = now + period;
91
92        // GCRA logic.
93        self.vtime
94            // It seems to be enough to use `Relaxed` here.
95            .fetch_update(Relaxed, Relaxed, |vtime| {
96                if vtime < deadline {
97                    Some(vtime.max(now) + step)
98                } else {
99                    None
100                }
101            })
102            .is_ok()
103    }
104}
105
106fn calculate_step(max_rate: u64, period: u64) -> u64 {
107    if max_rate == 0 {
108        return DISABLED;
109    }
110
111    // Practically unlimited.
112    if max_rate >= period {
113        return UNLIMITED;
114    }
115
116    // round_up(period / max_rate)
117    (period - 1) / max_rate + 1
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Duration` of `nanos` nanoseconds.
    fn ns(nanos: u64) -> Duration {
        Duration::from_nanos(nanos)
    }

    /// Acquires permits until the limiter refuses; returns how many succeeded.
    fn drain(limiter: &RateLimiter) -> u64 {
        let mut taken = 0;
        while limiter.acquire() {
            taken += 1;
        }
        taken
    }

    #[test]
    fn step_calculation() {
        for period in [1, 100, 1000] {
            assert_eq!(calculate_step(0, period), DISABLED);
            assert_eq!(calculate_step(period, period), UNLIMITED);

            for coef in 2..50 {
                let exact = coef * period;
                assert_eq!(calculate_step(period, exact), coef);
                // A single extra nanosecond must round the step up.
                assert_eq!(calculate_step(period, exact + 1), coef + 1);
            }
        }
    }

    #[test]
    fn forbidding() {
        time::with_instant_mock(|mock| {
            let limiter = RateLimiter::new(RateLimit::Rps(0));
            for _second in 0..=5 {
                assert_eq!(drain(&limiter), 0);
                mock.advance(ns(SEC));
            }
        });
    }

    #[test]
    fn unlimited() {
        time::with_instant_mock(|_mock| {
            let limiters = [
                RateLimiter::new(RateLimit::Unlimited),
                RateLimiter::new(RateLimit::Rps(1_000_000_000)),
                RateLimiter::new(RateLimit::Custom(2_000, Duration::from_micros(2))),
            ];
            for _ in 0..=1_000_000 {
                for limiter in &limiters {
                    assert!(limiter.acquire());
                }
            }
        });
    }

    #[test]
    fn limited() {
        for limit in [1, 2, 3, 4, 5, 17, 100, 1_000, 1_013] {
            time::with_instant_mock(|mock| {
                let limiter = RateLimiter::new(RateLimit::Rps(limit));

                for _second in 0..=5 {
                    // Exactly `limit` permits per second, then a refusal.
                    assert_eq!(drain(&limiter), limit);
                    mock.advance(ns(SEC));
                }
            });
        }
    }

    #[test]
    fn keeps_rate() {
        for limit in [1, 5, 25, 50] {
            time::with_instant_mock(|mock| {
                let limiter = RateLimiter::new(RateLimit::Rps(limit));

                // Exhaust the initial burst so only the sustained rate is measured.
                assert_eq!(drain(&limiter), limit);

                let parts = 10;
                let mut acquired = 0;
                for _tick in 0..(10 * parts) {
                    mock.advance(ns(SEC / parts));
                    acquired += drain(&limiter);
                }

                assert_eq!(acquired, 10 * limit, "{limit}");
            });
        }
    }

    #[test]
    fn reset() {
        time::with_instant_mock(|mock| {
            let limit = 10;
            let limiter = RateLimiter::new(RateLimit::Rps(limit));

            for _round in 0..=5 {
                for _ in 0..limit {
                    assert!(limiter.acquire());
                }
                // Resetting restores the full budget within the same second.
                limiter.reset();
                assert_eq!(drain(&limiter), limit);
                mock.advance(ns(SEC));
            }
        });
    }
}