request_rate_limiter/
limiter.rs1use std::{
4 fmt::Debug,
5 sync::{
6 atomic::{AtomicU64, Ordering},
7 Arc,
8 },
9 time::{Duration, Instant},
10};
11
12use async_trait::async_trait;
13use tokio::time::timeout;
14
15use crate::algorithms::{RateLimitAlgorithm, RequestSample};
16
/// Unit used throughout this module for token counts and per-second rates.
type RequestCount = u64;
18
/// Permit handed out by [`RateLimiter::acquire`] and given back through
/// [`RateLimiter::release`].
#[derive(Debug)]
pub struct Token {
    /// Instant at which the token was issued; `release` reads it to measure
    /// how long the request took.
    start_time: Instant,
}
25
26#[async_trait]
32pub trait RateLimiter: Debug + Sync {
33 async fn acquire(&self) -> Token;
35
36 async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;
38
39 async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
42}
43
44#[derive(Debug)]
48pub struct DefaultRateLimiter<T> {
49 algorithm: T,
50 tokens: Arc<AtomicU64>,
51 last_refill_nanos: Arc<AtomicU64>,
52 requests_per_second: Arc<AtomicU64>,
53 bucket_capacity: RequestCount,
54 refill_interval_nanos: Arc<AtomicU64>,
55}
56
57#[derive(Debug, Clone, Copy)]
61pub struct RateLimiterState {
62 requests_per_second: RequestCount,
64 available_tokens: RequestCount,
66 bucket_capacity: RequestCount,
68}
69
/// How a rate-limited request ended.
///
/// The value is forwarded to the rate-limit algorithm as part of a
/// `RequestSample`; how each variant influences the rate is decided there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request completed successfully.
    Success,
    /// The request failed because of overload.
    // NOTE(review): presumably signalled by the downstream service — confirm
    // against the algorithm implementations.
    Overload,
    /// The request failed because of a client-side error.
    ClientError,
}
82
83impl<T> DefaultRateLimiter<T>
84where
85 T: RateLimitAlgorithm,
86{
87 pub fn new(algorithm: T) -> Self {
89 let initial_rps = algorithm.requests_per_second();
90 let bucket_capacity = initial_rps; assert!(initial_rps >= 1);
93 let now_nanos = std::time::SystemTime::now()
94 .duration_since(std::time::UNIX_EPOCH)
95 .unwrap()
96 .as_nanos() as u64;
97
98 Self {
99 algorithm,
100 tokens: Arc::new(AtomicU64::new(bucket_capacity)),
101 last_refill_nanos: Arc::new(AtomicU64::new(now_nanos)),
102 requests_per_second: Arc::new(AtomicU64::new(initial_rps)),
103 bucket_capacity,
104 refill_interval_nanos: Arc::new(AtomicU64::new(1_000_000_000 / initial_rps)),
105 }
106 }
107
108 #[inline]
109 fn refill_tokens(&self) {
110 let current_tokens = self.tokens.load(Ordering::Relaxed);
111 if current_tokens >= self.bucket_capacity {
112 return; }
114
115 let now_nanos = std::time::SystemTime::now()
116 .duration_since(std::time::UNIX_EPOCH)
117 .unwrap()
118 .as_nanos() as u64;
119
120 let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
121 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
122 let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
123
124 if elapsed_nanos >= refill_interval {
125 let tokens_to_add = elapsed_nanos / refill_interval;
126
127 if tokens_to_add > 0 {
128 let _ = self.last_refill_nanos.compare_exchange_weak(
130 last_refill,
131 now_nanos,
132 Ordering::Release,
133 Ordering::Relaxed,
134 );
135
136 self.tokens
137 .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
138 let new_tokens = (current + tokens_to_add).min(self.bucket_capacity);
139 if new_tokens > current {
140 Some(new_tokens)
141 } else {
142 None
143 }
144 })
145 .ok();
146 }
147 }
148 }
149
150 pub fn state(&self) -> RateLimiterState {
152 self.refill_tokens();
153 RateLimiterState {
154 requests_per_second: self.algorithm.requests_per_second(),
155 available_tokens: self.tokens.load(Ordering::Acquire),
156 bucket_capacity: self.bucket_capacity,
157 }
158 }
159}
160
161#[async_trait]
162impl<T> RateLimiter for DefaultRateLimiter<T>
163where
164 T: RateLimitAlgorithm + Sync + Debug,
165{
166 async fn acquire(&self) -> Token {
167 loop {
168 if self
170 .tokens
171 .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
172 if current > 0 {
173 Some(current - 1)
174 } else {
175 None
176 }
177 })
178 .is_ok()
179 {
180 return Token {
181 start_time: Instant::now(),
182 };
183 }
184
185 self.refill_tokens();
187
188 if self
190 .tokens
191 .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
192 if current > 0 {
193 Some(current - 1)
194 } else {
195 None
196 }
197 })
198 .is_ok()
199 {
200 return Token {
201 start_time: Instant::now(),
202 };
203 }
204
205 let now_nanos = std::time::SystemTime::now()
206 .duration_since(std::time::UNIX_EPOCH)
207 .unwrap()
208 .as_nanos() as u64;
209
210 let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
211 let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
212 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
213
214 if elapsed_nanos < refill_interval {
215 let wait_nanos = refill_interval - elapsed_nanos;
216
217 tokio::time::sleep(Duration::from_nanos(wait_nanos)).await;
218 } else {
219 tokio::task::yield_now().await;
220 }
221 }
222 }
223
224 async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
225 timeout(duration, self.acquire()).await.ok()
226 }
227
228 async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
229 let response_time = token.start_time.elapsed();
230
231 if let Some(outcome) = outcome {
232 let current_rps = self.requests_per_second.load(Ordering::Relaxed);
233 let sample = RequestSample::new(response_time, current_rps, outcome);
234
235 let new_rps = self.algorithm.update(sample).await;
236 self.requests_per_second.store(new_rps, Ordering::Relaxed);
237
238 if new_rps != current_rps && new_rps > 0 {
240 self.refill_interval_nanos
241 .store(1_000_000_000 / new_rps, Ordering::Relaxed);
242 }
243 }
244 }
245}
246
247impl RateLimiterState {
248 pub fn requests_per_second(&self) -> RequestCount {
250 self.requests_per_second
251 }
252 pub fn available_tokens(&self) -> RequestCount {
254 self.available_tokens
255 }
256 pub fn bucket_capacity(&self) -> RequestCount {
258 self.bucket_capacity
259 }
260}
261
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        algorithms::Fixed,
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
    };

    /// Smoke test: a freshly built limiter grants a token and takes it back.
    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        // NOTE(review): presumably `Fixed::new(10)` pins the rate at 10
        // requests per second — confirm against `algorithms::Fixed`.
        let limiter = DefaultRateLimiter::new(Fixed::new(10));

        let granted = limiter.acquire().await;

        limiter
            .release(granted, Some(RequestOutcome::Success))
            .await;
    }

    /// A second `acquire` issued while the first token is outstanding still
    /// completes eventually.
    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        let shared = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));

        let first = shared.acquire().await;

        // Start the second acquire on its own task so it runs in the background.
        let pending = {
            let worker = Arc::clone(&shared);
            tokio::spawn(async move { worker.acquire().await })
        };

        tokio::time::sleep(Duration::from_millis(10)).await;

        shared.release(first, Some(RequestOutcome::Success)).await;

        let second = pending.await.unwrap();
        shared.release(second, Some(RequestOutcome::Success)).await;
    }
}