// request_rate_limiter/limiter.rs
use std::{
    fmt::Debug,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use async_trait::async_trait;
use crossbeam_utils::Backoff;
use tokio::time::timeout;

use crate::algorithms::{RateLimitAlgorithm, RequestSample};

/// Count of requests / bucket tokens; alias for readability.
type RequestCount = u64;
19
/// Permit returned by [`RateLimiter::acquire`].
///
/// Records when the permit was granted so the elapsed response time can be
/// sampled when the token is handed back to [`RateLimiter::release`].
#[derive(Debug)]
pub struct Token {
    // Grant time; `release` measures latency as `start_time.elapsed()`.
    start_time: Instant,
}
26
/// Asynchronous rate-limiter interface.
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Waits (asynchronously) until a token is available and returns it.
    async fn acquire(&self) -> Token;

    /// Like [`RateLimiter::acquire`], but gives up after `duration`,
    /// returning `None` on timeout.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;

    /// Completes the request associated with `token`. When an `outcome` is
    /// supplied it is fed back to the implementation (e.g. so an adaptive
    /// algorithm can adjust its target rate).
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}
44
/// Token-bucket rate limiter whose target rate is driven by a
/// [`RateLimitAlgorithm`].
///
/// Tokens are replenished by elapsed time (see `refill_tokens`), not by
/// `release`; `release` only feeds outcome samples back to the algorithm.
#[derive(Debug)]
pub struct DefaultRateLimiter<T> {
    // Adaptive algorithm consulted on release to update the target rate.
    algorithm: T,
    // Currently available tokens.
    // NOTE(review): the `Arc` wrappers around these atomics look redundant —
    // the limiter itself is shared behind an `Arc` in the tests. Confirm
    // nothing clones the fields individually before simplifying.
    tokens: Arc<AtomicU64>,
    // Wall-clock time (nanoseconds since UNIX epoch) of the last refill.
    last_refill_nanos: Arc<AtomicU64>,
    // Cached target rate, mirrored from the algorithm in `release`.
    requests_per_second: Arc<AtomicU64>,
    // Fixed maximum bucket size (the initial requests-per-second value).
    bucket_capacity: RequestCount,
    // Nanoseconds between single-token refills (1e9 / rps).
    refill_interval_nanos: Arc<AtomicU64>,
}
57
/// Point-in-time snapshot of a [`DefaultRateLimiter`], produced by
/// [`DefaultRateLimiter::state`].
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    // Target rate reported by the algorithm at snapshot time.
    requests_per_second: RequestCount,
    // Tokens available at snapshot time.
    available_tokens: RequestCount,
    // Fixed maximum bucket size.
    bucket_capacity: RequestCount,
}
70
/// Result of a rate-limited request, fed back to the algorithm via
/// [`RateLimiter::release`]. Exact semantics of each variant are defined by
/// the [`RateLimitAlgorithm`] consuming the sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request completed normally.
    Success,
    /// The request failed because the downstream was overloaded.
    Overload,
    /// The request failed due to a client-side error.
    ClientError,
}
83
84impl<T> DefaultRateLimiter<T>
85where
86 T: RateLimitAlgorithm,
87{
88 pub fn new(algorithm: T) -> Self {
90 let initial_rps = algorithm.requests_per_second();
91 let bucket_capacity = initial_rps; assert!(initial_rps >= 1);
94 let now_nanos = std::time::SystemTime::now()
95 .duration_since(std::time::UNIX_EPOCH)
96 .unwrap()
97 .as_nanos() as u64;
98
99 Self {
100 algorithm,
101 tokens: Arc::new(AtomicU64::new(bucket_capacity)),
102 last_refill_nanos: Arc::new(AtomicU64::new(now_nanos)),
103 requests_per_second: Arc::new(AtomicU64::new(initial_rps)),
104 bucket_capacity,
105 refill_interval_nanos: Arc::new(AtomicU64::new(1_000_000_000 / initial_rps)),
106 }
107 }
108
109 #[inline]
110 fn refill_tokens(&self) {
111 let current_tokens = self.tokens.load(Ordering::Relaxed);
112 if current_tokens >= self.bucket_capacity {
113 return; }
115
116 let now_nanos = std::time::SystemTime::now()
117 .duration_since(std::time::UNIX_EPOCH)
118 .unwrap()
119 .as_nanos() as u64;
120
121 let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
122 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
123 let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
124
125 if elapsed_nanos >= refill_interval {
126 let tokens_to_add = elapsed_nanos / refill_interval;
127
128 if tokens_to_add > 0 {
129 let _ = self.last_refill_nanos.compare_exchange_weak(
131 last_refill,
132 now_nanos,
133 Ordering::Release,
134 Ordering::Relaxed,
135 );
136
137 self.tokens
138 .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
139 let new_tokens = (current + tokens_to_add).min(self.bucket_capacity);
140 if new_tokens > current {
141 Some(new_tokens)
142 } else {
143 None
144 }
145 })
146 .ok();
147 }
148 }
149 }
150
151 pub fn state(&self) -> RateLimiterState {
153 self.refill_tokens();
154 RateLimiterState {
155 requests_per_second: self.algorithm.requests_per_second(),
156 available_tokens: self.tokens.load(Ordering::Acquire),
157 bucket_capacity: self.bucket_capacity,
158 }
159 }
160}
161
162#[async_trait]
163impl<T> RateLimiter for DefaultRateLimiter<T>
164where
165 T: RateLimitAlgorithm + Sync + Debug,
166{
167 async fn acquire(&self) -> Token {
168 let backoff = Backoff::new();
169
170 loop {
171 if self.tokens
173 .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
174 if current > 0 {
175 Some(current - 1)
176 } else {
177 None
178 }
179 }).is_ok()
180 {
181 return Token {
182 start_time: Instant::now(),
183 };
184 }
185
186 self.refill_tokens();
188
189 if self.tokens
191 .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
192 if current > 0 {
193 Some(current - 1)
194 } else {
195 None
196 }
197 }).is_ok()
198 {
199 return Token {
200 start_time: Instant::now(),
201 };
202 }
203
204 if backoff.is_completed() {
206 tokio::task::yield_now().await;
207 backoff.reset();
208 } else {
209 backoff.spin();
210 }
211 }
212 }
213
214 async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
215 timeout(duration, self.acquire()).await.ok()
216 }
217
218 async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
219 let response_time = token.start_time.elapsed();
220
221 if let Some(outcome) = outcome {
222 let current_rps = self.requests_per_second.load(Ordering::Relaxed);
223 let sample = RequestSample::new(response_time, current_rps, outcome);
224
225 let new_rps = self.algorithm.update(sample).await;
226 self.requests_per_second.store(new_rps, Ordering::Relaxed);
227
228 if new_rps != current_rps && new_rps > 0 {
230 self.refill_interval_nanos
231 .store(1_000_000_000 / new_rps, Ordering::Relaxed);
232 }
233 }
234 }
235}
236
impl RateLimiterState {
    /// Target request rate at the time of the snapshot.
    pub fn requests_per_second(&self) -> RequestCount {
        self.requests_per_second
    }
    /// Tokens that were available when the snapshot was taken.
    pub fn available_tokens(&self) -> RequestCount {
        self.available_tokens
    }
    /// Maximum number of tokens the bucket can hold.
    pub fn bucket_capacity(&self) -> RequestCount {
        self.bucket_capacity
    }
}
251
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
    };
    use std::time::Duration;

    /// A limiter with spare capacity hands out a token immediately.
    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        let limiter = DefaultRateLimiter::new(Fixed::new(10));

        let token = limiter.acquire().await;

        limiter.release(token, Some(RequestOutcome::Success)).await;
    }

    /// With a capacity-1 bucket drained, a second `acquire` must block until
    /// the time-based refill produces a token (release does not return the
    /// token to the bucket), so at 1 rps this test may take up to ~1 s.
    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;

        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));

        // Drain the single token.
        let token1 = limiter.acquire().await;

        // Run the contending acquire on a separate task; it should not be
        // able to complete while the bucket is empty.
        let limiter_clone = Arc::clone(&limiter);
        let acquire_task = tokio::spawn(async move { limiter_clone.acquire().await });

        // Give the spawned task a chance to start spinning.
        tokio::time::sleep(Duration::from_millis(10)).await;

        limiter.release(token1, Some(RequestOutcome::Success)).await;

        // The blocked task eventually obtains a token via refill.
        let token2 = acquire_task.await.unwrap();
        limiter.release(token2, Some(RequestOutcome::Success)).await;
    }
}