request_rate_limiter/
limiter.rs

1//! Rate limiters for controlling request throughput.
2
3use std::{
4    fmt::Debug,
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc,
8    },
9    time::{Duration, Instant},
10};
11
12use async_trait::async_trait;
13use tokio::time::{sleep, timeout};
14
15use crate::algorithms::{RateLimitAlgorithm, RequestSample};
16
/// Integer type used for token counts and requests-per-second values.
type RequestCount = u64;
/// Atomic counterpart of [`RequestCount`], used for the limiter's shared counters.
type AtomicRequestCount = AtomicU64;
19
/// A token representing permission to make a request.
///
/// The token records the instant at which it was handed out so that the time
/// elapsed until it is released can be measured.
#[derive(Debug)]
pub struct Token {
    /// Instant at which the token was created; read via `elapsed()` on release.
    start_time: Instant
}
26
/// Controls the rate of requests over time.
///
/// Rate limiting is achieved by checking if a request is allowed based on the current
/// rate limit algorithm. Callers obtain a [`Token`] before sending a request and hand it
/// back through [`RateLimiter::release`] together with the observed outcome, which lets
/// the limiter adjust its limit based on success/failure results and response times.
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Acquire permission to make a request. Waits until a token is available.
    async fn acquire(&self) -> Token;

    /// Acquire permission to make a request, waiting at most `duration`.
    ///
    /// Returns `Some(token)` if a token was obtained in time and `None` if the wait
    /// timed out.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;

    /// Release the token and record the outcome of the request.
    ///
    /// The response time is calculated from when the token was acquired. `outcome` is
    /// optional; in [`DefaultRateLimiter`] a `None` outcome records nothing and leaves
    /// the current limit untouched.
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}
44
/// A token bucket based rate limiter.
///
/// Tokens are consumed by `acquire` and replenished over time according to the current
/// requests-per-second limit, up to a fixed bucket capacity.
// NOTE(review): this type used to be documented as "cheaply cloneable", but no `Clone`
// derive or impl is visible in this file — confirm whether one is intended.
#[derive(Debug)]
pub struct DefaultRateLimiter<T> {
    /// Algorithm that supplies the initial rate and computes a new rate on release.
    algorithm: T,
    /// Number of tokens currently available in the bucket.
    tokens: Arc<AtomicRequestCount>,
    /// Reference instant from which elapsed time is measured for the next refill.
    last_refill: Arc<std::sync::Mutex<Instant>>,
    /// Current requests-per-second limit; overwritten with the algorithm's result on release.
    requests_per_second: Arc<AtomicRequestCount>,
    /// Upper bound on `tokens`; set once at construction to the initial rate.
    bucket_capacity: RequestCount,
}
56
/// A snapshot of the state of the rate limiter.
///
/// The fields are read one after another rather than atomically as a group, so the
/// snapshot is not guaranteed to be consistent under high concurrency.
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    /// Current requests per second limit
    requests_per_second: RequestCount,
    /// Available tokens in the bucket at the time of the snapshot
    available_tokens: RequestCount,
    /// Maximum bucket capacity
    bucket_capacity: RequestCount,
}
69
/// Whether a request succeeded or failed, potentially due to overload.
///
/// The outcome is attached to the sample handed to the rate limit algorithm when a token
/// is released. Errors not considered to be caused by overload should be ignored, i.e.
/// they should not be reported as [`RequestOutcome::Overload`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request succeeded, or failed in a way unrelated to overload.
    Success,
    /// The request failed because of overload, e.g. it timed out or received a 429/503 response.
    Overload,
    /// The request failed due to client error (4xx) - not related to rate limiting.
    // NOTE(review): 429 is itself a 4xx status but is listed under `Overload` above, so this
    // variant presumably covers 4xx responses other than 429 — confirm with callers.
    ClientError,
}
82
83impl<T> DefaultRateLimiter<T>
84where
85    T: RateLimitAlgorithm,
86{
87    /// Create a rate limiter with a given rate limiting algorithm.
88    pub fn new(algorithm: T) -> Self {
89        let initial_rps = algorithm.requests_per_second();
90        let bucket_capacity = initial_rps; // Use the same value for bucket capacity
91        
92        assert!(initial_rps >= 1);
93        Self {
94            algorithm,
95            tokens: Arc::new(AtomicRequestCount::new(bucket_capacity)),
96            last_refill: Arc::new(std::sync::Mutex::new(Instant::now())),
97            requests_per_second: Arc::new(AtomicRequestCount::new(initial_rps)),
98            bucket_capacity,
99        }
100    }
101
102    fn refill_tokens(&self) {
103        let now = Instant::now();
104        if let Ok(mut last_refill) = self.last_refill.try_lock() {
105            let elapsed = now.duration_since(*last_refill);
106            let tokens_to_add = (elapsed.as_secs_f64() * self.requests_per_second.load(Ordering::Acquire) as f64) as u64;
107            
108            if tokens_to_add > 0 {
109                let current_tokens = self.tokens.load(Ordering::Acquire);
110                let new_tokens = (current_tokens + tokens_to_add).min(self.bucket_capacity);
111                self.tokens.store(new_tokens, Ordering::SeqCst);
112                *last_refill = now;
113            }
114        }
115    }
116
117    /// The current state of the rate limiter.
118    pub fn state(&self) -> RateLimiterState {
119        self.refill_tokens();
120        RateLimiterState {
121            requests_per_second: self.requests_per_second.load(Ordering::Acquire),
122            available_tokens: self.tokens.load(Ordering::Acquire),
123            bucket_capacity: self.bucket_capacity,
124        }
125    }
126}
127
128#[async_trait]
129impl<T> RateLimiter for DefaultRateLimiter<T>
130where
131    T: RateLimitAlgorithm + Sync + Debug,
132{
133    async fn acquire(&self) -> Token {
134        loop {
135            self.refill_tokens();
136            
137            // Try to consume a token atomically
138            let current_tokens = self.tokens.load(Ordering::Acquire);
139            if current_tokens > 0 {
140                match self.tokens.compare_exchange_weak(
141                    current_tokens,
142                    current_tokens - 1,
143                    Ordering::Release,
144                    Ordering::Relaxed,
145                ) {
146                    Ok(_) => {
147                        return Token {
148                            start_time: Instant::now(),
149                        };
150                    },
151                    Err(_) => continue, // Retry on contention
152                }
153            } else {
154                // No tokens available, wait a bit before retrying
155                sleep(Duration::from_millis(1)).await;
156            }
157        }
158    }
159
160    async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
161        match timeout(duration, self.acquire()).await {
162            Ok(token) => Some(token),
163            Err(_) => None, // Timeout
164        }
165    }
166
167    async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
168        let response_time = token.start_time.elapsed();
169        
170        if let Some(outcome) = outcome {
171            let current_rps = self.requests_per_second.load(Ordering::Acquire);
172            let sample = RequestSample::new(response_time, current_rps, outcome);
173            
174            let new_rps = self.algorithm.update(sample).await;
175            self.requests_per_second.store(new_rps, Ordering::Release);
176        }
177    }
178}
179
180impl RateLimiterState {
181    /// The current requests per second limit.
182    pub fn requests_per_second(&self) -> RequestCount {
183        self.requests_per_second
184    }
185    /// The number of available tokens in the bucket.
186    pub fn available_tokens(&self) -> RequestCount {
187        self.available_tokens
188    }
189    /// The maximum bucket capacity.
190    pub fn bucket_capacity(&self) -> RequestCount {
191        self.bucket_capacity
192    }
193}
194
// Smoke tests for `DefaultRateLimiter` driven by the `Fixed` algorithm.
#[cfg(test)]
mod tests {
    use crate::{
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
        algorithms::Fixed,
    };
    use std::time::Duration;

    /// A freshly built limiter hands out a token and accepts it back with an outcome.
    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        let limiter = DefaultRateLimiter::new(Fixed::new(10));

        // The bucket starts full, so the first acquire should return without waiting.
        let token = limiter.acquire().await;

        // Release with successful outcome
        limiter.release(token, Some(RequestOutcome::Success)).await;
    }

    /// A second acquire on an exhausted bucket eventually completes.
    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;

        // Presumably `Fixed::new(1)` reports 1 request per second, giving a bucket of one token.
        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));

        // Consume the only token
        let token1 = limiter.acquire().await;

        // Start acquiring second token (should wait)
        let limiter_clone = Arc::clone(&limiter);
        let acquire_task = tokio::spawn(async move {
            limiter_clone.acquire().await
        });

        // Give it a moment to start waiting
        tokio::time::sleep(Duration::from_millis(10)).await;

        // Report the first request's outcome.
        // NOTE(review): `release` only feeds the algorithm; it does not put the token back
        // into the bucket. The spawned acquire completes once the time-based refill adds a
        // token, not because of this call.
        limiter.release(token1, Some(RequestOutcome::Success)).await;

        // Wait for the second acquire to finish (after the bucket has been refilled).
        let token2 = acquire_task.await.unwrap();
        limiter.release(token2, Some(RequestOutcome::Success)).await;
    }
}