// request_rate_limiter/limiter.rs

//! Rate limiters for controlling request throughput.

use std::{
    fmt::Debug,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use async_trait::async_trait;
use tokio::time::timeout;

use crate::algorithms::{RateLimitAlgorithm, RequestSample};

// Counts of requests and of bucket tokens (one token == one request).
type RequestCount = u64;

/// A token representing permission to make a request.
/// The token tracks when the request was started for timing measurements.
#[derive(Debug)]
pub struct Token {
    // Captured at acquire time; `release` calls `elapsed()` on it to
    // measure the request's response time.
    start_time: Instant,
}

/// Controls the rate of requests over time.
///
/// Rate limiting is achieved by checking if a request is allowed based on the current
/// rate limit algorithm. The limiter tracks request patterns and adjusts limits dynamically
/// based on observed success/failure rates and response times.
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Acquire permission to make a request. Waits until a token is available.
    async fn acquire(&self) -> Token;

    /// Acquire permission to make a request with a timeout. Returns a token if
    /// successful, or `None` if the timeout elapsed before one became available.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;

    /// Record the outcome of the request associated with `token`.
    ///
    /// The response time is calculated from when the token was acquired.
    /// Passing `None` for `outcome` discards the sample without feeding it
    /// to the rate-limit algorithm.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}

/// A token bucket based rate limiter.
///
/// Cheaply cloneable: clones share the bucket and rate state through `Arc`s.
/// Note that `bucket_capacity` is a plain field copied into each clone and is
/// fixed at construction time.
#[derive(Debug, Clone)]
pub struct DefaultRateLimiter<T> {
    // Algorithm that adapts `requests_per_second` from observed request samples.
    algorithm: T,
    // Tokens currently available to hand out.
    tokens: Arc<AtomicU64>,
    // Wall-clock time (nanoseconds since UNIX_EPOCH) of the last refill.
    last_refill_nanos: Arc<AtomicU64>,
    // Current requests-per-second limit, updated by `release`.
    requests_per_second: Arc<AtomicU64>,
    // Maximum number of tokens the bucket can hold.
    bucket_capacity: RequestCount,
    // Nanoseconds between minting consecutive tokens (1e9 / requests_per_second).
    refill_interval_nanos: Arc<AtomicU64>,
}

/// A snapshot of the state of the rate limiter.
///
/// Not guaranteed to be consistent under high concurrency: each field is read
/// independently, so the values may come from slightly different instants.
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    /// Current requests per second limit
    requests_per_second: RequestCount,
    /// Available tokens in the bucket
    available_tokens: RequestCount,
    /// Maximum bucket capacity
    bucket_capacity: RequestCount,
}

/// Whether a request succeeded or failed, potentially due to overload.
///
/// Errors not considered to be caused by overload should be ignored
/// (pass `None` to `RateLimiter::release` instead of an outcome).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request succeeded, or failed in a way unrelated to overload.
    Success,
    /// The request failed because of overload, e.g. it timed out or received a 429/503 response.
    Overload,
    /// The request failed due to client error (4xx) - not related to rate limiting.
    ClientError,
}

83impl<T> DefaultRateLimiter<T>
84where
85    T: RateLimitAlgorithm,
86{
87    /// Create a rate limiter with a given rate limiting algorithm.
88    pub fn new(algorithm: T) -> Self {
89        let initial_rps = algorithm.requests_per_second();
90        let bucket_capacity = initial_rps; // Use the same value for bucket capacity
91
92        assert!(initial_rps >= 1);
93        let now_nanos = std::time::SystemTime::now()
94            .duration_since(std::time::UNIX_EPOCH)
95            .unwrap()
96            .as_nanos() as u64;
97
98        Self {
99            algorithm,
100            tokens: Arc::new(AtomicU64::new(bucket_capacity)),
101            last_refill_nanos: Arc::new(AtomicU64::new(now_nanos)),
102            requests_per_second: Arc::new(AtomicU64::new(initial_rps)),
103            bucket_capacity,
104            refill_interval_nanos: Arc::new(AtomicU64::new(1_000_000_000 / initial_rps)),
105        }
106    }
107
108    #[inline]
109    fn refill_tokens(&self) {
110        let current_tokens = self.tokens.load(Ordering::Relaxed);
111        if current_tokens >= self.bucket_capacity {
112            return; // Already at capacity
113        }
114
115        let now_nanos = std::time::SystemTime::now()
116            .duration_since(std::time::UNIX_EPOCH)
117            .unwrap()
118            .as_nanos() as u64;
119
120        let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
121        let elapsed_nanos = now_nanos.saturating_sub(last_refill);
122        let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
123
124        if elapsed_nanos >= refill_interval {
125            let tokens_to_add = elapsed_nanos / refill_interval;
126
127            if tokens_to_add > 0 {
128                // Atomic update of both tokens and last_refill_nanos
129                let _ = self.last_refill_nanos.compare_exchange_weak(
130                    last_refill,
131                    now_nanos,
132                    Ordering::Release,
133                    Ordering::Relaxed,
134                );
135
136                self.tokens
137                    .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
138                        let new_tokens = (current + tokens_to_add).min(self.bucket_capacity);
139                        if new_tokens > current {
140                            Some(new_tokens)
141                        } else {
142                            None
143                        }
144                    })
145                    .ok();
146            }
147        }
148    }
149
150    /// The current state of the rate limiter.
151    pub fn state(&self) -> RateLimiterState {
152        self.refill_tokens();
153        RateLimiterState {
154            requests_per_second: self.algorithm.requests_per_second(),
155            available_tokens: self.tokens.load(Ordering::Acquire),
156            bucket_capacity: self.bucket_capacity,
157        }
158    }
159}
160
161#[async_trait]
162impl<T> RateLimiter for DefaultRateLimiter<T>
163where
164    T: RateLimitAlgorithm + Sync + Debug,
165{
166    async fn acquire(&self) -> Token {
167        loop {
168            // Fast path: try to consume token without refill check
169            if self
170                .tokens
171                .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
172                    if current > 0 {
173                        Some(current - 1)
174                    } else {
175                        None
176                    }
177                })
178                .is_ok()
179            {
180                return Token {
181                    start_time: Instant::now(),
182                };
183            }
184
185            // Slow path: refill and retry
186            self.refill_tokens();
187
188            // Try again after refill
189            if self
190                .tokens
191                .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
192                    if current > 0 {
193                        Some(current - 1)
194                    } else {
195                        None
196                    }
197                })
198                .is_ok()
199            {
200                return Token {
201                    start_time: Instant::now(),
202                };
203            }
204
205            let now_nanos = std::time::SystemTime::now()
206                .duration_since(std::time::UNIX_EPOCH)
207                .unwrap()
208                .as_nanos() as u64;
209
210            let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
211            let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
212            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
213
214            if elapsed_nanos < refill_interval {
215                let wait_nanos = refill_interval - elapsed_nanos;
216
217                tokio::time::sleep(Duration::from_nanos(wait_nanos)).await;
218            } else {
219                tokio::task::yield_now().await;
220            }
221        }
222    }
223
224    async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
225        timeout(duration, self.acquire()).await.ok()
226    }
227
228    async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
229        let response_time = token.start_time.elapsed();
230
231        if let Some(outcome) = outcome {
232            let current_rps = self.requests_per_second.load(Ordering::Relaxed);
233            let sample = RequestSample::new(response_time, current_rps, outcome);
234
235            let new_rps = self.algorithm.update(sample).await;
236            self.requests_per_second.store(new_rps, Ordering::Relaxed);
237
238            // Update refill interval if RPS changed
239            if new_rps != current_rps && new_rps > 0 {
240                self.refill_interval_nanos
241                    .store(1_000_000_000 / new_rps, Ordering::Relaxed);
242            }
243        }
244    }
245}
246
247impl RateLimiterState {
248    /// The current requests per second limit.
249    pub fn requests_per_second(&self) -> RequestCount {
250        self.requests_per_second
251    }
252    /// The number of available tokens in the bucket.
253    pub fn available_tokens(&self) -> RequestCount {
254        self.available_tokens
255    }
256    /// The maximum bucket capacity.
257    pub fn bucket_capacity(&self) -> RequestCount {
258        self.bucket_capacity
259    }
260}
261
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
    };
    use std::time::Duration;

    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        let limiter = DefaultRateLimiter::new(Fixed::new(10));

        // Bucket starts full, so the first request is granted immediately.
        let token = limiter.acquire().await;

        // Release with successful outcome
        limiter.release(token, Some(RequestOutcome::Success)).await;
    }

    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;

        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));

        // Consume the only token
        let token1 = limiter.acquire().await;

        // Start acquiring second token (should wait)
        let limiter_clone = Arc::clone(&limiter);
        let acquire_task = tokio::spawn(async move { limiter_clone.acquire().await });

        // Give it a moment to start waiting
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Record the first request's outcome. NOTE(review): `release` does not
        // put tokens back in the bucket; the pending acquire below completes
        // only once the time-based refill mints a new token (~1 s at 1 RPS),
        // so this test takes about a second of wall-clock time.
        limiter.release(token1, Some(RequestOutcome::Success)).await;

        // The second acquire should now complete
        let token2 = acquire_task.await.unwrap();
        limiter.release(token2, Some(RequestOutcome::Success)).await;
    }
}
304}