// request_rate_limiter/limiter.rs

use std::{
    fmt::Debug,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use async_trait::async_trait;
use tokio::time::timeout;

use crate::algorithms::{RateLimitAlgorithm, RequestSample};
16
/// Number of requests (or bucket tokens), as a 64-bit unsigned count.
type RequestCount = u64;
18
/// Permit handed out by [`RateLimiter::acquire`].
///
/// Records when the request was admitted so that `release` can compute the
/// response time via `start_time.elapsed()`.
#[derive(Debug)]
pub struct Token {
    // Captured at acquisition time; read back in `release`.
    start_time: Instant,
}
24
/// Asynchronous rate limiter that admits one request per [`Token`].
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Waits until a token is available and returns it.
    async fn acquire(&self) -> Token;

    /// Like [`acquire`](Self::acquire), but gives up after `duration` and
    /// returns `None` if no token became available in time.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;

    /// Returns a previously acquired token, optionally reporting how the
    /// request turned out so an adaptive implementation can adjust its rate.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}
42
/// Token-bucket [`RateLimiter`] whose rate is driven by a
/// [`RateLimitAlgorithm`].
///
/// All mutable state sits behind `Arc<AtomicU64>`, so `Clone` yields a handle
/// sharing the same bucket rather than an independent limiter.
#[derive(Debug, Clone)]
pub struct DefaultRateLimiter<T> {
    // Strategy that decides the requests-per-second limit over time.
    algorithm: T,
    // Tokens currently available in the bucket.
    tokens: Arc<AtomicU64>,
    // Wall-clock nanoseconds since UNIX_EPOCH at the last refill.
    // NOTE(review): SystemTime is not monotonic; a backwards clock step
    // stalls refills until real time catches up — confirm this is acceptable,
    // or consider Instant-based bookkeeping.
    last_refill_nanos: Arc<AtomicU64>,
    // Current rate limit; doubles as the bucket capacity.
    requests_per_second: Arc<AtomicU64>,
    // Nanoseconds between single-token refills (1e9 / requests_per_second).
    refill_interval_nanos: Arc<AtomicU64>,
}
53
/// Point-in-time snapshot of a [`DefaultRateLimiter`]'s internals.
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    // Rate reported by the algorithm when the snapshot was taken.
    requests_per_second: RequestCount,
    // Tokens available in the bucket at snapshot time.
    available_tokens: RequestCount,
    // Maximum number of tokens the bucket could hold at snapshot time.
    bucket_capacity: RequestCount,
}
66
/// Caller-reported result of a rate-limited request, fed back to the
/// [`RateLimitAlgorithm`] via [`RateLimiter::release`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request completed successfully.
    Success,
    /// The downstream signalled overload (how the algorithm reacts is
    /// implementation-defined).
    Overload,
    /// The request failed due to a client-side error.
    ClientError,
}
79
80impl<T> DefaultRateLimiter<T>
81where
82 T: RateLimitAlgorithm,
83{
84 pub fn new(algorithm: T) -> Self {
86 let initial_rps = algorithm.requests_per_second();
87 assert!(initial_rps >= 1);
88 let now_nanos = std::time::SystemTime::now()
89 .duration_since(std::time::UNIX_EPOCH)
90 .unwrap()
91 .as_nanos() as u64;
92
93 Self {
94 algorithm,
95 tokens: Arc::new(AtomicU64::new(initial_rps)),
96 last_refill_nanos: Arc::new(AtomicU64::new(now_nanos)),
97 requests_per_second: Arc::new(AtomicU64::new(initial_rps)),
98 refill_interval_nanos: Arc::new(AtomicU64::new(1_000_000_000 / initial_rps)),
99 }
100 }
101
102 fn refill_tokens(&self) {
103 let bucket_capacity = self.requests_per_second.load(Ordering::Relaxed);
105 let current_tokens = self.tokens.load(Ordering::Relaxed);
106
107 if current_tokens > bucket_capacity {
109 self.tokens.store(bucket_capacity, Ordering::Relaxed);
110 }
111
112 let current_tokens = self.tokens.load(Ordering::Relaxed);
113 if current_tokens >= bucket_capacity {
114 return;
115 }
116
117 let now_nanos = std::time::SystemTime::now()
118 .duration_since(std::time::UNIX_EPOCH)
119 .unwrap()
120 .as_nanos() as u64;
121
122 let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
123 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
124 let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
125
126 if elapsed_nanos >= refill_interval {
127 let tokens_to_add = elapsed_nanos / refill_interval;
128
129 if tokens_to_add > 0 {
130 let actual_elapsed = tokens_to_add * refill_interval;
132
133 if self
134 .last_refill_nanos
135 .compare_exchange_weak(
136 last_refill,
137 last_refill + actual_elapsed,
138 Ordering::Release,
139 Ordering::Relaxed,
140 )
141 .is_ok()
142 {
143 self.tokens
144 .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
145 let new_tokens = (current + tokens_to_add).min(bucket_capacity);
146 if new_tokens > current {
147 Some(new_tokens)
148 } else {
149 None
150 }
151 })
152 .ok();
153 }
154 }
155 }
156 }
157
158 pub fn state(&self) -> RateLimiterState {
160 self.refill_tokens();
161 RateLimiterState {
162 requests_per_second: self.algorithm.requests_per_second(),
163 available_tokens: self.tokens.load(Ordering::Acquire),
164 bucket_capacity: self.requests_per_second.load(Ordering::Relaxed),
165 }
166 }
167}
168
169#[async_trait]
170impl<T> RateLimiter for DefaultRateLimiter<T>
171where
172 T: RateLimitAlgorithm + Sync + Debug,
173{
174 async fn acquire(&self) -> Token {
175 loop {
176 if self
177 .tokens
178 .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
179 if current > 0 {
180 Some(current - 1)
181 } else {
182 None
183 }
184 })
185 .is_ok()
186 {
187 return Token {
188 start_time: Instant::now(),
189 };
190 }
191
192 self.refill_tokens();
193
194 if self
195 .tokens
196 .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
197 if current > 0 {
198 Some(current - 1)
199 } else {
200 None
201 }
202 })
203 .is_ok()
204 {
205 return Token {
206 start_time: Instant::now(),
207 };
208 }
209
210 let now_nanos = std::time::SystemTime::now()
211 .duration_since(std::time::UNIX_EPOCH)
212 .unwrap()
213 .as_nanos() as u64;
214
215 let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
216 let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
217 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
218
219 if elapsed_nanos < refill_interval {
220 let wait_nanos = refill_interval - elapsed_nanos;
221 tokio::time::sleep(Duration::from_nanos(wait_nanos)).await;
222 } else {
223 tokio::task::yield_now().await;
224 }
225 }
226 }
227
228 async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
229 timeout(duration, self.acquire()).await.ok()
230 }
231
232 async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
233 let response_time = token.start_time.elapsed();
234
235 if let Some(outcome) = outcome {
236 let current_rps = self.requests_per_second.load(Ordering::Relaxed);
237 let sample = RequestSample::new(response_time, current_rps, outcome);
238
239 let new_rps = self.algorithm.update(sample).await;
240 self.requests_per_second.store(new_rps, Ordering::Relaxed);
241
242 if new_rps != current_rps && new_rps > 0 {
243 self.refill_interval_nanos
244 .store(1_000_000_000 / new_rps, Ordering::Relaxed);
245 }
246 }
247 }
248}
249
impl RateLimiterState {
    /// Rate limit that was in effect when the snapshot was taken.
    pub fn requests_per_second(&self) -> RequestCount {
        self.requests_per_second
    }
    /// Tokens that were available in the bucket at snapshot time.
    pub fn available_tokens(&self) -> RequestCount {
        self.available_tokens
    }
    /// Maximum number of tokens the bucket could hold at snapshot time.
    pub fn bucket_capacity(&self) -> RequestCount {
        self.bucket_capacity
    }
}
264
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
    };
    use std::time::Duration;

    /// A request within the configured budget is admitted immediately.
    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        let limiter = DefaultRateLimiter::new(Fixed::new(10));
        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;
    }

    /// With a 1 rps budget and an empty bucket, a second acquirer blocks
    /// until the time-based refill produces a token.
    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;

        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));
        let first = limiter.acquire().await;

        // Race a competing acquire on another task; it cannot finish yet.
        let pending = {
            let limiter = Arc::clone(&limiter);
            tokio::spawn(async move { limiter.acquire().await })
        };

        tokio::time::sleep(Duration::from_millis(10)).await;
        limiter.release(first, Some(RequestOutcome::Success)).await;

        // The competing acquire completes once the bucket refills.
        let second = pending.await.unwrap();
        limiter.release(second, Some(RequestOutcome::Success)).await;
    }
}