// request_rate_limiter/limiter.rs
1//! Rate limiters for controlling request throughput.
2
3use std::{
4    fmt::Debug,
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc,
8    },
9    time::{Duration, Instant},
10};
11
12use async_trait::async_trait;
13use crossbeam_utils::Backoff;
14use tokio::time::timeout;
15
16use crate::algorithms::{RateLimitAlgorithm, RequestSample};
17
/// Count of requests; also used for token counts in the bucket.
type RequestCount = u64;

/// A token representing permission to make a request.
///
/// Holds the instant it was issued so [`RateLimiter::release`] can compute
/// the request's response time via `start_time.elapsed()`.
#[derive(Debug)]
pub struct Token {
    // Set by `acquire` at the moment the token is handed out.
    start_time: Instant,
}
26
/// Controls the rate of requests over time.
///
/// Rate limiting is achieved by checking if a request is allowed based on the current
/// rate limit algorithm. The limiter tracks request patterns and adjusts limits dynamically
/// based on observed success/failure rates and response times.
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Acquire permission to make a request. Waits until a token is available.
    async fn acquire(&self) -> Token;

    /// Acquire permission to make a request, giving up after `duration`.
    /// Returns `Some(Token)` on success, `None` if the timeout elapsed first.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;

    /// Release the token and record the outcome of the request.
    /// The response time is calculated from when the token was acquired.
    /// Passing `None` for `outcome` skips feeding the sample to the algorithm.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}
44
/// A token bucket based rate limiter.
///
/// NOTE(review): the previous doc claimed this type was "cheaply cloneable",
/// but no `Clone` impl is derived or written, and `algorithm` is not behind
/// an `Arc`, so a derived clone would not share algorithm state with the
/// original. The `Arc`-wrapped counters suggest a shared/cloneable design
/// was intended — confirm before adding `#[derive(Clone)]`.
#[derive(Debug)]
pub struct DefaultRateLimiter<T> {
    /// The algorithm that decides the requests-per-second limit.
    algorithm: T,
    /// Tokens currently available for `acquire` to consume.
    tokens: Arc<AtomicU64>,
    /// Wall-clock nanoseconds (Unix epoch) of the last refill.
    last_refill_nanos: Arc<AtomicU64>,
    /// Cached copy of the algorithm's current rate, read by `release`.
    requests_per_second: Arc<AtomicU64>,
    /// Maximum tokens the bucket can hold; fixed at construction.
    bucket_capacity: RequestCount,
    /// Nanoseconds per token; recomputed when the rate changes.
    refill_interval_nanos: Arc<AtomicU64>,
}
57
/// A snapshot of the state of the rate limiter.
///
/// Produced by [`DefaultRateLimiter::state`]. The fields are read with
/// separate atomic loads, so the snapshot is not guaranteed to be
/// consistent under high concurrency.
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    /// Current requests per second limit
    requests_per_second: RequestCount,
    /// Tokens available in the bucket at snapshot time
    available_tokens: RequestCount,
    /// Maximum bucket capacity (fixed at construction)
    bucket_capacity: RequestCount,
}
70
/// Whether a request succeeded or failed, potentially due to overload.
///
/// Errors not considered to be caused by overload should be ignored
/// (pass `None` to [`RateLimiter::release`] instead of an outcome).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request succeeded, or failed in a way unrelated to overload.
    Success,
    /// The request failed because of overload, e.g. it timed out or received a 429/503 response.
    Overload,
    /// The request failed due to client error (4xx) - not related to rate limiting.
    ClientError,
}
83
84impl<T> DefaultRateLimiter<T>
85where
86    T: RateLimitAlgorithm,
87{
88    /// Create a rate limiter with a given rate limiting algorithm.
89    pub fn new(algorithm: T) -> Self {
90        let initial_rps = algorithm.requests_per_second();
91        let bucket_capacity = initial_rps; // Use the same value for bucket capacity
92
93        assert!(initial_rps >= 1);
94        let now_nanos = std::time::SystemTime::now()
95            .duration_since(std::time::UNIX_EPOCH)
96            .unwrap()
97            .as_nanos() as u64;
98
99        Self {
100            algorithm,
101            tokens: Arc::new(AtomicU64::new(bucket_capacity)),
102            last_refill_nanos: Arc::new(AtomicU64::new(now_nanos)),
103            requests_per_second: Arc::new(AtomicU64::new(initial_rps)),
104            bucket_capacity,
105            refill_interval_nanos: Arc::new(AtomicU64::new(1_000_000_000 / initial_rps)),
106        }
107    }
108
109    #[inline]
110    fn refill_tokens(&self) {
111        let current_tokens = self.tokens.load(Ordering::Relaxed);
112        if current_tokens >= self.bucket_capacity {
113            return; // Already at capacity
114        }
115
116        let now_nanos = std::time::SystemTime::now()
117            .duration_since(std::time::UNIX_EPOCH)
118            .unwrap()
119            .as_nanos() as u64;
120
121        let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
122        let elapsed_nanos = now_nanos.saturating_sub(last_refill);
123        let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
124
125        if elapsed_nanos >= refill_interval {
126            let tokens_to_add = elapsed_nanos / refill_interval;
127
128            if tokens_to_add > 0 {
129                // Atomic update of both tokens and last_refill_nanos
130                let _ = self.last_refill_nanos.compare_exchange_weak(
131                    last_refill,
132                    now_nanos,
133                    Ordering::Release,
134                    Ordering::Relaxed,
135                );
136
137                self.tokens
138                    .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
139                        let new_tokens = (current + tokens_to_add).min(self.bucket_capacity);
140                        if new_tokens > current {
141                            Some(new_tokens)
142                        } else {
143                            None
144                        }
145                    })
146                    .ok();
147            }
148        }
149    }
150
151    /// The current state of the rate limiter.
152    pub fn state(&self) -> RateLimiterState {
153        self.refill_tokens();
154        RateLimiterState {
155            requests_per_second: self.algorithm.requests_per_second(),
156            available_tokens: self.tokens.load(Ordering::Acquire),
157            bucket_capacity: self.bucket_capacity,
158        }
159    }
160}
161
162#[async_trait]
163impl<T> RateLimiter for DefaultRateLimiter<T>
164where
165    T: RateLimitAlgorithm + Sync + Debug,
166{
167    async fn acquire(&self) -> Token {
168        let backoff = Backoff::new();
169
170        loop {
171            // Fast path: try to consume token without refill check
172            if self.tokens
173                    .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
174                        if current > 0 {
175                            Some(current - 1)
176                        } else {
177                            None
178                        }
179                    }).is_ok()
180            {
181                return Token {
182                    start_time: Instant::now(),
183                };
184            }
185
186            // Slow path: refill and retry
187            self.refill_tokens();
188
189            // Try again after refill
190            if self.tokens
191                    .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
192                        if current > 0 {
193                            Some(current - 1)
194                        } else {
195                            None
196                        }
197                    }).is_ok()
198            {
199                return Token {
200                    start_time: Instant::now(),
201                };
202            }
203
204            // Adaptive backoff
205            if backoff.is_completed() {
206                tokio::task::yield_now().await;
207                backoff.reset();
208            } else {
209                backoff.spin();
210            }
211        }
212    }
213
214    async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
215        timeout(duration, self.acquire()).await.ok()
216    }
217
218    async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
219        let response_time = token.start_time.elapsed();
220
221        if let Some(outcome) = outcome {
222            let current_rps = self.requests_per_second.load(Ordering::Relaxed);
223            let sample = RequestSample::new(response_time, current_rps, outcome);
224
225            let new_rps = self.algorithm.update(sample).await;
226            self.requests_per_second.store(new_rps, Ordering::Relaxed);
227
228            // Update refill interval if RPS changed
229            if new_rps != current_rps && new_rps > 0 {
230                self.refill_interval_nanos
231                    .store(1_000_000_000 / new_rps, Ordering::Relaxed);
232            }
233        }
234    }
235}
236
237impl RateLimiterState {
238    /// The current requests per second limit.
239    pub fn requests_per_second(&self) -> RequestCount {
240        self.requests_per_second
241    }
242    /// The number of available tokens in the bucket.
243    pub fn available_tokens(&self) -> RequestCount {
244        self.available_tokens
245    }
246    /// The maximum bucket capacity.
247    pub fn bucket_capacity(&self) -> RequestCount {
248        self.bucket_capacity
249    }
250}
251
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
    };
    use std::time::Duration;

    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        let limiter = DefaultRateLimiter::new(Fixed::new(10));

        // The bucket starts full, so the first acquire returns immediately.
        let token = limiter.acquire().await;

        // Release with successful outcome (feeds a sample to the algorithm).
        limiter.release(token, Some(RequestOutcome::Success)).await;
    }

    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;

        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));

        // Consume the only token
        let token1 = limiter.acquire().await;

        // Start acquiring second token (should wait)
        let limiter_clone = Arc::clone(&limiter);
        let acquire_task = tokio::spawn(async move { limiter_clone.acquire().await });

        // Give it a moment to start waiting
        tokio::time::sleep(Duration::from_millis(10)).await;

        // NOTE(review): `release` records the outcome but does NOT put a
        // token back in the bucket — the second acquire actually completes
        // via the time-based refill (~1s at 1 rps), not because of this call,
        // so this test takes about a second.
        limiter.release(token1, Some(RequestOutcome::Success)).await;

        // The second acquire should now complete
        let token2 = acquire_task.await.unwrap();
        limiter.release(token2, Some(RequestOutcome::Success)).await;
    }
}
294}