elfo_utils/
rate_limiter.rs1use std::{
2 sync::atomic::{AtomicU64, Ordering::Relaxed},
3 time::Duration,
4};
5
6use crate::time;
7
/// A rate limiter whose state lives entirely in atomics, so every operation
/// works through a shared reference.
///
/// Capacity is tracked as a virtual clock (`vtime`): each admitted
/// acquisition moves it forward by `step` nanoseconds, and acquisitions are
/// admitted only while it is less than one `period` ahead of the current time.
#[derive(Debug)]
pub struct RateLimiter {
    /// Virtual nanoseconds charged per acquisition. `0` means unlimited and
    /// `u64::MAX` means every acquisition is rejected.
    step: AtomicU64,
    /// Length of the accounting window, in nanoseconds.
    period: AtomicU64,
    /// Virtual clock on the `time::nanos_since_unknown_epoch` timeline;
    /// `0` means no capacity has been consumed yet.
    vtime: AtomicU64,
}
14
15impl Default for RateLimiter {
17 fn default() -> Self {
18 Self::new(RateLimit::Unlimited)
19 }
20}
21
/// A limit accepted by [`RateLimiter::new`] and [`RateLimiter::configure`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimit {
    /// Every acquisition succeeds.
    Unlimited,
    /// At most this many acquisitions per second; `Rps(0)` rejects everything.
    Rps(u64),
    /// At most `limit` acquisitions per the given period. A zero `limit`
    /// rejects everything. A non-zero `limit` of at least one acquisition per
    /// nanosecond of the period (including any zero period) is treated as
    /// unlimited.
    Custom(u64, Duration),
}
32
33impl RateLimit {
34 fn step_and_period(self) -> (u64, u64) {
35 let (limit, period) = match self {
36 Self::Unlimited => (1, 1),
37 Self::Rps(rps) => (rps, SEC),
38 Self::Custom(limit, period) => (limit, period.as_nanos() as u64),
39 };
40
41 (calculate_step(limit, period), period)
42 }
43}
44
45const SEC: u64 = 1_000_000_000;
46const UNLIMITED: u64 = 0;
47const DISABLED: u64 = u64::MAX;
48
49impl RateLimiter {
50 pub fn new(limit: RateLimit) -> Self {
52 let (step, period) = limit.step_and_period();
53
54 Self {
55 step: AtomicU64::new(step),
56 period: AtomicU64::new(period),
57 vtime: AtomicU64::new(0),
58 }
59 }
60
61 pub fn configure(&self, limit: RateLimit) {
63 let (step, period) = limit.step_and_period();
64
65 self.step.store(step, Relaxed);
66 self.period.store(period, Relaxed);
67 }
68
69 pub fn reset(&self) {
71 self.vtime.store(0, Relaxed);
72 }
73
74 #[inline]
77 pub fn acquire(&self) -> bool {
78 let step = self.step.load(Relaxed);
79
80 if step == UNLIMITED {
82 return true;
83 }
84 if step == DISABLED {
85 return false;
86 }
87
88 let period = self.period.load(Relaxed);
89 let now = time::nanos_since_unknown_epoch();
90 let deadline = now + period;
91
92 self.vtime
94 .fetch_update(Relaxed, Relaxed, |vtime| {
96 if vtime < deadline {
97 Some(vtime.max(now) + step)
98 } else {
99 None
100 }
101 })
102 .is_ok()
103 }
104}
105
106fn calculate_step(max_rate: u64, period: u64) -> u64 {
107 if max_rate == 0 {
108 return DISABLED;
109 }
110
111 if max_rate >= period {
113 return UNLIMITED;
114 }
115
116 (period - 1) / max_rate + 1
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    fn ns(ns: u64) -> Duration {
        Duration::from_nanos(ns)
    }

    #[test]
    fn step_calculation() {
        for period in [1, 100, 1000] {
            assert_eq!(calculate_step(0, period), DISABLED);
            assert_eq!(calculate_step(period, period), UNLIMITED);

            for coef in 2..50 {
                assert_eq!(calculate_step(period, coef * period), coef);
                assert_eq!(calculate_step(period, coef * period + 1), coef + 1);
            }
        }

        // A zero-length period is either disabled (zero rate) or unlimited.
        assert_eq!(calculate_step(0, 0), DISABLED);
        assert_eq!(calculate_step(1, 0), UNLIMITED);
    }

    #[test]
    fn forbidding() {
        time::with_instant_mock(|mock| {
            let limiter = RateLimiter::new(RateLimit::Rps(0));
            for _ in 0..=5 {
                assert!(!limiter.acquire());
                mock.advance(ns(SEC));
            }
        });
    }

    #[test]
    fn unlimited() {
        time::with_instant_mock(|_mock| {
            let limiter = RateLimiter::new(RateLimit::Unlimited);
            let limiter2 = RateLimiter::new(RateLimit::Rps(1_000_000_000));
            let limiter3 = RateLimiter::new(RateLimit::Custom(2_000, Duration::from_micros(2)));
            for _ in 0..=1_000_000 {
                assert!(limiter.acquire());
                assert!(limiter2.acquire());
                assert!(limiter3.acquire());
            }
        });
    }

    #[test]
    fn limited() {
        for limit in [1, 2, 3, 4, 5, 17, 100, 1_000, 1_013] {
            time::with_instant_mock(|mock| {
                let limiter = RateLimiter::new(RateLimit::Rps(limit));

                for _ in 0..=5 {
                    for _ in 0..limit {
                        assert!(limiter.acquire());
                    }
                    assert!(!limiter.acquire());
                    mock.advance(ns(SEC));
                }
            });
        }
    }

    #[test]
    fn keeps_rate() {
        for limit in [1, 5, 25, 50] {
            time::with_instant_mock(|mock| {
                let limiter = RateLimiter::new(RateLimit::Rps(limit));

                // Drain the initial burst first.
                for _ in 0..limit {
                    assert!(limiter.acquire());
                }
                assert!(!limiter.acquire());

                let parts = 10;
                let mut counter = 0;

                for _ in 0..(10 * parts) {
                    mock.advance(ns(SEC / parts));
                    while limiter.acquire() {
                        counter += 1;
                    }
                }

                assert_eq!(counter, 10 * limit, "{limit}");
            });
        }
    }

    #[test]
    fn reset() {
        time::with_instant_mock(|mock| {
            let limit = 10;
            let limiter = RateLimiter::new(RateLimit::Rps(limit));

            for _ in 0..=5 {
                for _ in 0..limit {
                    assert!(limiter.acquire());
                }
                limiter.reset();
                for _ in 0..limit {
                    assert!(limiter.acquire());
                }
                assert!(!limiter.acquire());
                mock.advance(ns(SEC));
            }
        });
    }

    // `Custom` with a zero period never limits; with a zero limit it always rejects.
    #[test]
    fn custom_edge_cases() {
        time::with_instant_mock(|mock| {
            let zero_period = RateLimiter::new(RateLimit::Custom(5, Duration::ZERO));
            let zero_limit = RateLimiter::new(RateLimit::Custom(0, Duration::from_secs(1)));

            for _ in 0..=5 {
                for _ in 0..100 {
                    assert!(zero_period.acquire());
                }
                assert!(!zero_limit.acquire());
                mock.advance(ns(SEC));
            }
        });
    }

    // `configure` switches the active limit in place.
    #[test]
    fn reconfigure() {
        time::with_instant_mock(|mock| {
            let limiter = RateLimiter::default();
            // Unlimited acquisitions never touch the virtual clock.
            for _ in 0..100 {
                assert!(limiter.acquire());
            }

            limiter.configure(RateLimit::Rps(3));
            for _ in 0..3 {
                assert!(limiter.acquire());
            }
            assert!(!limiter.acquire());

            limiter.configure(RateLimit::Rps(0));
            mock.advance(ns(SEC));
            assert!(!limiter.acquire());

            limiter.configure(RateLimit::Unlimited);
            assert!(limiter.acquire());
        });
    }
}