// request_rate_limiter/limiter.rs

use std::{
    fmt::Debug,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use async_trait::async_trait;
use tokio::time::{sleep, timeout};

use crate::algorithms::{RateLimitAlgorithm, RequestSample};

type RequestCount = u64;
type AtomicRequestCount = AtomicU64;

/// Permit handed out by [`RateLimiter::acquire`].
///
/// Records the moment the request was admitted so that
/// [`RateLimiter::release`] can measure the response time.
#[derive(Debug)]
pub struct Token {
    // Set when the token is granted; `release` reads `start_time.elapsed()`.
    start_time: Instant
}
26
/// Interface for asynchronous request rate limiters.
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Waits until the limiter admits a request, returning a [`Token`]
    /// that should be handed back via [`release`](Self::release).
    async fn acquire(&self) -> Token;

    /// Like [`acquire`](Self::acquire), but gives up after `duration`,
    /// returning `None` on timeout.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;

    /// Returns a token after the request completes. When an `outcome` is
    /// supplied, implementations may adapt their rate based on it.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}
44
/// Token-bucket rate limiter whose refill rate is adjusted by a pluggable
/// [`RateLimitAlgorithm`].
#[derive(Debug)]
pub struct DefaultRateLimiter<T> {
    // Strategy that recomputes the target rate from completed-request samples.
    algorithm: T,
    // Tokens currently available in the bucket.
    tokens: Arc<AtomicRequestCount>,
    // Instant of the last refill; the mutex ensures only one caller
    // refills at a time (others use `try_lock` and skip).
    last_refill: Arc<std::sync::Mutex<Instant>>,
    // Current refill rate in tokens per second; updated on `release`.
    requests_per_second: Arc<AtomicRequestCount>,
    // Upper bound on stored tokens; fixed at construction.
    bucket_capacity: RequestCount,
}
56
/// Point-in-time snapshot of a [`DefaultRateLimiter`]'s internals.
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    // Refill rate at the time of the snapshot.
    requests_per_second: RequestCount,
    // Tokens that were available when the snapshot was taken.
    available_tokens: RequestCount,
    // Maximum number of tokens the bucket can hold.
    bucket_capacity: RequestCount,
}
69
/// How a rate-limited request finished; fed back to the algorithm on
/// [`RateLimiter::release`] so it can adjust the rate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request completed successfully.
    Success,
    /// The request failed because the target was overloaded.
    Overload,
    /// The request failed due to a client-side error.
    ClientError,
}
82
83impl<T> DefaultRateLimiter<T>
84where
85 T: RateLimitAlgorithm,
86{
87 pub fn new(algorithm: T) -> Self {
89 let initial_rps = algorithm.requests_per_second();
90 let bucket_capacity = initial_rps; assert!(initial_rps >= 1);
93 Self {
94 algorithm,
95 tokens: Arc::new(AtomicRequestCount::new(bucket_capacity)),
96 last_refill: Arc::new(std::sync::Mutex::new(Instant::now())),
97 requests_per_second: Arc::new(AtomicRequestCount::new(initial_rps)),
98 bucket_capacity,
99 }
100 }
101
102 fn refill_tokens(&self) {
103 let now = Instant::now();
104 if let Ok(mut last_refill) = self.last_refill.try_lock() {
105 let elapsed = now.duration_since(*last_refill);
106 let tokens_to_add = (elapsed.as_secs_f64() * self.requests_per_second.load(Ordering::Acquire) as f64) as u64;
107
108 if tokens_to_add > 0 {
109 let current_tokens = self.tokens.load(Ordering::Acquire);
110 let new_tokens = (current_tokens + tokens_to_add).min(self.bucket_capacity);
111 self.tokens.store(new_tokens, Ordering::SeqCst);
112 *last_refill = now;
113 }
114 }
115 }
116
117 pub fn state(&self) -> RateLimiterState {
119 self.refill_tokens();
120 RateLimiterState {
121 requests_per_second: self.requests_per_second.load(Ordering::Acquire),
122 available_tokens: self.tokens.load(Ordering::Acquire),
123 bucket_capacity: self.bucket_capacity,
124 }
125 }
126}
127
#[async_trait]
impl<T> RateLimiter for DefaultRateLimiter<T>
where
    T: RateLimitAlgorithm + Sync + Debug,
{
    /// Waits until a token is available, then takes one.
    ///
    /// Each iteration refills the bucket, then attempts a lock-free
    /// decrement; when the bucket is empty it sleeps briefly so the loop
    /// yields to the runtime instead of spinning.
    async fn acquire(&self) -> Token {
        loop {
            self.refill_tokens();

            let current_tokens = self.tokens.load(Ordering::Acquire);
            if current_tokens > 0 {
                // `compare_exchange_weak` may fail spuriously or because a
                // concurrent caller changed the count; either way the outer
                // loop re-reads and retries, so failure is harmless.
                match self.tokens.compare_exchange_weak(
                    current_tokens,
                    current_tokens - 1,
                    Ordering::Release,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // Token secured: stamp the start time so `release`
                        // can measure the response time.
                        return Token {
                            start_time: Instant::now(),
                        };
                    },
                    Err(_) => continue, }
            } else {
                // Bucket empty: back off briefly before polling again.
                sleep(Duration::from_millis(1)).await;
            }
        }
    }

    /// Bounded `acquire`: returns `None` if no token was obtained within
    /// `duration` (the timeout cancels the pending `acquire` future).
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
        match timeout(duration, self.acquire()).await {
            Ok(token) => Some(token),
            Err(_) => None, }
    }

    /// Completes a request. When an outcome is given, feeds a sample
    /// (response time, current rate, outcome) to the algorithm and stores
    /// the rate it returns; tokens themselves are replenished by time, not
    /// by release.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
        let response_time = token.start_time.elapsed();

        if let Some(outcome) = outcome {
            let current_rps = self.requests_per_second.load(Ordering::Acquire);
            let sample = RequestSample::new(response_time, current_rps, outcome);

            let new_rps = self.algorithm.update(sample).await;
            self.requests_per_second.store(new_rps, Ordering::Release);
        }
    }
}
179
180impl RateLimiterState {
181 pub fn requests_per_second(&self) -> RequestCount {
183 self.requests_per_second
184 }
185 pub fn available_tokens(&self) -> RequestCount {
187 self.available_tokens
188 }
189 pub fn bucket_capacity(&self) -> RequestCount {
191 self.bucket_capacity
192 }
193}
194
#[cfg(test)]
mod tests {
    use crate::{
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
        algorithms::Fixed,
    };
    use std::time::Duration;

    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        // A fresh limiter starts with a full bucket, so the first acquire
        // succeeds immediately.
        let limiter = DefaultRateLimiter::new(Fixed::new(10));
        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;
    }

    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;

        // Capacity of one: a second acquire must block until a refill.
        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));
        let first = limiter.acquire().await;

        let waiter = {
            let limiter = Arc::clone(&limiter);
            tokio::spawn(async move { limiter.acquire().await })
        };

        // Give the spawned task a moment to start waiting on acquire.
        tokio::time::sleep(Duration::from_millis(10)).await;
        limiter.release(first, Some(RequestOutcome::Success)).await;

        let second = waiter.await.unwrap();
        limiter.release(second, Some(RequestOutcome::Success)).await;
    }
}
239}