// Source: request_rate_limiter/limiter.rs
//! Rate limiters for controlling request throughput.

use std::{
    fmt::Debug,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use async_trait::async_trait;
use tokio::time::timeout;

use crate::algorithms::{RateLimitAlgorithm, RequestSample};

/// A count of requests or tokens; exposed by [`RateLimiterState`] accessors.
type RequestCount = u64;
18
/// A token representing permission to make a request.
///
/// Obtained from [`RateLimiter::acquire`] and handed back to
/// [`RateLimiter::release`], which derives the request's response time
/// from the stored acquisition instant.
#[derive(Debug)]
pub struct Token {
    // Captured when the token is acquired; `release` reads its `elapsed()`.
    start_time: Instant,
}
24
/// Controls the rate of requests over time.
///
/// Rate limiting is achieved by checking if a request is allowed based on the current
/// rate limit algorithm. The limiter tracks request patterns and adjusts limits dynamically
/// based on observed success/failure rates and response times.
#[async_trait]
pub trait RateLimiter: Debug + Sync {
    /// Acquire permission to make a request. Waits until a token is available.
    async fn acquire(&self) -> Token;

    /// Acquire permission to make a request with a timeout. Returns a token if successful,
    /// or `None` if no token became available within `duration`.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;

    /// Release the token and record the outcome of the request.
    /// The response time is calculated from when the token was acquired.
    /// Passing `None` for `outcome` records nothing (e.g. errors unrelated to load).
    async fn release(&self, token: Token, outcome: Option<RequestOutcome>);
}
42
/// A token bucket based rate limiter.
/// Cheaply cloneable: all mutable state lives behind `Arc`s, so clones share it.
#[derive(Debug, Clone)]
pub struct DefaultRateLimiter<T> {
    // Algorithm that adapts the RPS limit from observed request samples.
    algorithm: T,
    // Tokens currently available; one is consumed per acquired request.
    tokens: Arc<AtomicU64>,
    // Unix-epoch timestamp (nanoseconds) of the last bucket refill.
    last_refill_nanos: Arc<AtomicU64>,
    // Current RPS limit; also serves as the dynamic bucket capacity.
    requests_per_second: Arc<AtomicU64>,
    // Nanoseconds between single-token refills (1_000_000_000 / RPS).
    refill_interval_nanos: Arc<AtomicU64>,
}
53
/// A snapshot of the state of the rate limiter.
///
/// Not guaranteed to be consistent under high concurrency: each field is
/// sampled independently from the live limiter.
#[derive(Debug, Clone, Copy)]
pub struct RateLimiterState {
    /// Current requests per second limit
    requests_per_second: RequestCount,
    /// Available tokens in the bucket
    available_tokens: RequestCount,
    /// Maximum bucket capacity
    bucket_capacity: RequestCount,
}
66
/// Whether a request succeeded or failed, potentially due to overload.
///
/// Errors not considered to be caused by overload should be ignored
/// (pass `None` to [`RateLimiter::release`] instead).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestOutcome {
    /// The request succeeded, or failed in a way unrelated to overload.
    Success,
    /// The request failed because of overload, e.g. it timed out or received a 429/503 response.
    Overload,
    /// The request failed due to client error (4xx) - not related to rate limiting.
    ClientError,
}
79
80impl<T> DefaultRateLimiter<T>
81where
82    T: RateLimitAlgorithm,
83{
84    /// Create a rate limiter with a given rate limiting algorithm.
85    pub fn new(algorithm: T) -> Self {
86        let initial_rps = algorithm.requests_per_second();
87        assert!(initial_rps >= 1);
88        let now_nanos = std::time::SystemTime::now()
89            .duration_since(std::time::UNIX_EPOCH)
90            .unwrap()
91            .as_nanos() as u64;
92
93        Self {
94            algorithm,
95            tokens: Arc::new(AtomicU64::new(initial_rps)),
96            last_refill_nanos: Arc::new(AtomicU64::new(now_nanos)),
97            requests_per_second: Arc::new(AtomicU64::new(initial_rps)),
98            refill_interval_nanos: Arc::new(AtomicU64::new(1_000_000_000 / initial_rps)),
99        }
100    }
101
102    fn refill_tokens(&self) {
103        // ДИНАМИЧЕСКИЙ размер корзины! Всегда равен текущему RPS
104        let bucket_capacity = self.requests_per_second.load(Ordering::Relaxed);
105        let current_tokens = self.tokens.load(Ordering::Relaxed);
106
107        // Урезаем токены, если capacity уменьшилась из-за 429 ошибки
108        if current_tokens > bucket_capacity {
109            self.tokens.store(bucket_capacity, Ordering::Relaxed);
110        }
111
112        let current_tokens = self.tokens.load(Ordering::Relaxed);
113        if current_tokens >= bucket_capacity {
114            return;
115        }
116
117        let now_nanos = std::time::SystemTime::now()
118            .duration_since(std::time::UNIX_EPOCH)
119            .unwrap()
120            .as_nanos() as u64;
121
122        let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
123        let elapsed_nanos = now_nanos.saturating_sub(last_refill);
124        let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
125
126        if elapsed_nanos >= refill_interval {
127            let tokens_to_add = elapsed_nanos / refill_interval;
128
129            if tokens_to_add > 0 {
130                // Продвигаем время ровно на интервалы добавленных токенов, чтобы не терять остаток
131                let actual_elapsed = tokens_to_add * refill_interval;
132
133                if self
134                    .last_refill_nanos
135                    .compare_exchange_weak(
136                        last_refill,
137                        last_refill + actual_elapsed,
138                        Ordering::Release,
139                        Ordering::Relaxed,
140                    )
141                    .is_ok()
142                {
143                    self.tokens
144                        .fetch_update(Ordering::Release, Ordering::Relaxed, |current| {
145                            let new_tokens = (current + tokens_to_add).min(bucket_capacity);
146                            if new_tokens > current {
147                                Some(new_tokens)
148                            } else {
149                                None
150                            }
151                        })
152                        .ok();
153                }
154            }
155        }
156    }
157
158    /// The current state of the rate limiter.
159    pub fn state(&self) -> RateLimiterState {
160        self.refill_tokens();
161        RateLimiterState {
162            requests_per_second: self.algorithm.requests_per_second(),
163            available_tokens: self.tokens.load(Ordering::Acquire),
164            bucket_capacity: self.requests_per_second.load(Ordering::Relaxed),
165        }
166    }
167}
168
169#[async_trait]
170impl<T> RateLimiter for DefaultRateLimiter<T>
171where
172    T: RateLimitAlgorithm + Sync + Debug,
173{
174    async fn acquire(&self) -> Token {
175        loop {
176            if self
177                .tokens
178                .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
179                    if current > 0 {
180                        Some(current - 1)
181                    } else {
182                        None
183                    }
184                })
185                .is_ok()
186            {
187                return Token {
188                    start_time: Instant::now(),
189                };
190            }
191
192            self.refill_tokens();
193
194            if self
195                .tokens
196                .fetch_update(Ordering::Acquire, Ordering::Relaxed, |current| {
197                    if current > 0 {
198                        Some(current - 1)
199                    } else {
200                        None
201                    }
202                })
203                .is_ok()
204            {
205                return Token {
206                    start_time: Instant::now(),
207                };
208            }
209
210            let now_nanos = std::time::SystemTime::now()
211                .duration_since(std::time::UNIX_EPOCH)
212                .unwrap()
213                .as_nanos() as u64;
214
215            let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
216            let refill_interval = self.refill_interval_nanos.load(Ordering::Relaxed);
217            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
218
219            if elapsed_nanos < refill_interval {
220                let wait_nanos = refill_interval - elapsed_nanos;
221                tokio::time::sleep(Duration::from_nanos(wait_nanos)).await;
222            } else {
223                tokio::task::yield_now().await;
224            }
225        }
226    }
227
228    async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
229        timeout(duration, self.acquire()).await.ok()
230    }
231
232    async fn release(&self, token: Token, outcome: Option<RequestOutcome>) {
233        let response_time = token.start_time.elapsed();
234
235        if let Some(outcome) = outcome {
236            let current_rps = self.requests_per_second.load(Ordering::Relaxed);
237            let sample = RequestSample::new(response_time, current_rps, outcome);
238
239            let new_rps = self.algorithm.update(sample).await;
240            self.requests_per_second.store(new_rps, Ordering::Relaxed);
241
242            if new_rps != current_rps && new_rps > 0 {
243                self.refill_interval_nanos
244                    .store(1_000_000_000 / new_rps, Ordering::Relaxed);
245            }
246        }
247    }
248}
249
impl RateLimiterState {
    /// The current requests per second limit, as reported by the algorithm
    /// when the snapshot was taken.
    pub fn requests_per_second(&self) -> RequestCount {
        self.requests_per_second
    }
    /// The number of available tokens in the bucket at snapshot time.
    pub fn available_tokens(&self) -> RequestCount {
        self.available_tokens
    }
    /// The maximum bucket capacity (equal to the limiter's RPS limit).
    pub fn bucket_capacity(&self) -> RequestCount {
        self.bucket_capacity
    }
}
264
#[cfg(test)]
mod tests {
    use crate::{
        algorithms::Fixed,
        limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome},
    };
    use std::time::Duration;

    /// A fresh limiter with spare capacity hands out a token immediately.
    #[tokio::test]
    async fn rate_limiter_allows_requests_within_limit() {
        let limiter = DefaultRateLimiter::new(Fixed::new(10));
        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;
    }

    /// With a 1-RPS limit and the only token held, a second acquirer must
    /// block until a token becomes available again.
    #[tokio::test]
    async fn rate_limiter_waits_for_tokens() {
        use std::sync::Arc;

        let limiter = Arc::new(DefaultRateLimiter::new(Fixed::new(1)));
        let held = limiter.acquire().await;

        // Start a second acquirer on another task; it cannot complete yet.
        let waiter = {
            let limiter = Arc::clone(&limiter);
            tokio::spawn(async move { limiter.acquire().await })
        };

        // Give the waiter time to park, then free capacity.
        tokio::time::sleep(Duration::from_millis(10)).await;
        limiter.release(held, Some(RequestOutcome::Success)).await;

        let second = waiter.await.unwrap();
        limiter.release(second, Some(RequestOutcome::Success)).await;
    }
}