Skip to main content

request_rate_limiter/algorithms/
aimd.rs

1use std::{
2    ops::RangeInclusive,
3    sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// Loss-based request rate limiting using AIMD (additive increase, multiplicative decrease).
///
/// Successful samples may raise the allowed rate by a fixed step; overload samples may scale
/// it down by a constant factor. The rate always stays within `min_rps..=max_rps`.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound for the allowed rate; `new` guarantees it is at least 1.
    min_rps: u64,
    /// Upper bound for the allowed rate.
    max_rps: u64,
    /// Multiplier applied to the rate when it is decreased.
    decrease_factor: f64,
    /// Step added to the rate when it is increased.
    increase_by: u64,

    /// Current allowed rate, in requests per second.
    requests_per_second: AtomicU64,
    /// Timestamp (nanoseconds) of the last decrease; 0 until the first decrease.
    last_decrease_nanos: AtomicU64,
    /// Timestamp (nanoseconds) of the last increase; 0 until the first increase.
    last_increase_nanos: AtomicU64,
}

impl Clone for Aimd {
    /// Returns an independent copy of the configuration and current state.
    ///
    /// Each atomic field is read with its own load, so under concurrent updates the copy is
    /// not an atomic snapshot of the whole state.
    fn clone(&self) -> Self {
        let snapshot = |cell: &AtomicU64| AtomicU64::new(cell.load(Ordering::Acquire));
        Self {
            min_rps: self.min_rps,
            max_rps: self.max_rps,
            decrease_factor: self.decrease_factor,
            increase_by: self.increase_by,
            requests_per_second: snapshot(&self.requests_per_second),
            last_decrease_nanos: snapshot(&self.last_decrease_nanos),
            last_increase_nanos: snapshot(&self.last_increase_nanos),
        }
    }
}

impl Aimd {
    const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
    const DEFAULT_INCREASE: u64 = 10;

    /// Creates a limiter starting at `initial_rps` with the default rate range `1..=10000`.
    ///
    /// # Panics
    /// Panics if `initial_rps` lies outside `1..=10000`.
    pub fn new_with_initial_rate(initial_rps: u64) -> Self {
        Self::new(initial_rps, 1..=10000)
    }

    /// Creates a limiter starting at `initial_rps`, bounded by `rate_range`.
    ///
    /// The decrease factor defaults to 0.8 and the increase step to 10; override them with
    /// [`Aimd::decrease_factor`] and [`Aimd::increase_by`].
    ///
    /// # Panics
    /// Panics if `rate_range` starts below 1 or `initial_rps` lies outside `rate_range`.
    pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
        assert!(*rate_range.start() >= 1, "Rate must be at least 1");
        assert!(
            initial_rps >= *rate_range.start(),
            "Initial rate less than minimum"
        );
        assert!(
            initial_rps <= *rate_range.end(),
            "Initial rate more than maximum"
        );

        Self {
            min_rps: *rate_range.start(),
            max_rps: *rate_range.end(),
            decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
            increase_by: Self::DEFAULT_INCREASE,

            requests_per_second: AtomicU64::new(initial_rps),
            last_decrease_nanos: AtomicU64::new(0),
            last_increase_nanos: AtomicU64::new(0),
        }
    }

    /// Set the multiplier which will be applied when decreasing the rate.
    ///
    /// # Panics
    /// Panics unless `0.1 <= factor < 1.0`.
    pub fn decrease_factor(self, factor: f64) -> Self {
        assert!(
            (0.1..1.0).contains(&factor),
            "Decrease factor must be in [0.1, 1.0)"
        );
        Self {
            decrease_factor: factor,
            ..self
        }
    }

    /// Set the increment which will be applied when increasing the rate.
    ///
    /// # Panics
    /// Panics if `increase` is 0.
    pub fn increase_by(self, increase: u64) -> Self {
        assert!(increase > 0, "Increase must be positive");
        Self {
            increase_by: increase,
            ..self
        }
    }

    /// Set the maximum rate.
    ///
    /// If the current rate is above `max`, it is lowered to `max` immediately. Without this,
    /// the limiter would keep reporting a rate outside `min_rps..=max_rps` until the next
    /// sample arrived.
    ///
    /// # Panics
    /// Panics if `max` is below the configured minimum rate. This is the same invariant `new`
    /// enforces, and it also rules out `max == 0`.
    pub fn with_max_rate(self, max: u64) -> Self {
        assert!(max >= self.min_rps, "Maximum rate less than minimum");
        let clamped_rps = self.requests_per_second.load(Ordering::Acquire).min(max);
        Self {
            max_rps: max,
            requests_per_second: AtomicU64::new(clamped_rps),
            ..self
        }
    }
}
100
101#[async_trait]
102impl RateLimitAlgorithm for Aimd {
103    fn requests_per_second(&self) -> u64 {
104        self.requests_per_second.load(Ordering::Acquire)
105    }
106
107    async fn update(&self, sample: RequestSample) -> u64 {
108        use RequestOutcome::*;
109
110        let now = std::time::SystemTime::now()
111            .duration_since(std::time::UNIX_EPOCH)
112            .unwrap()
113            .as_nanos() as u64;
114
115        match sample.outcome {
116            Success | ClientError => {
117                let last = self.last_increase_nanos.load(Ordering::Relaxed);
118                // Повышаем лимит не чаще, чем раз в секунду (Cooldown)
119                if now.saturating_sub(last) >= 1_000_000_000 {
120                    if self
121                        .last_increase_nanos
122                        .compare_exchange_weak(last, now, Ordering::Relaxed, Ordering::Relaxed)
123                        .is_ok()
124                    {
125                        let mut updated = 0;
126                        self.requests_per_second
127                            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
128                                updated = (rps + self.increase_by).min(self.max_rps);
129                                Some(updated)
130                            })
131                            .unwrap();
132                        return updated;
133                    }
134                }
135            }
136            Overload => {
137                let last = self.last_decrease_nanos.load(Ordering::Relaxed);
138                // Снижаем лимит не чаще, чем раз в секунду
139                // Защищает от лавины 429 ошибок, которые мгновенно роняют лимит в 0
140                if now.saturating_sub(last) >= 1_000_000_000 {
141                    if self
142                        .last_decrease_nanos
143                        .compare_exchange_weak(last, now, Ordering::Relaxed, Ordering::Relaxed)
144                        .is_ok()
145                    {
146                        let mut updated = 0;
147                        self.requests_per_second
148                            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
149                                updated = multiplicative_decrease(rps, self.decrease_factor)
150                                    .max(self.min_rps);
151                                Some(updated)
152                            })
153                            .unwrap();
154                        return updated;
155                    }
156                }
157            }
158        }
159
160        // Если кулдаун не прошел, возвращаем текущий лимит без изменений
161        self.requests_per_second.load(Ordering::Relaxed)
162    }
163}
164
165pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
166    assert!(decrease_factor <= 1.0, "should not increase the rate");
167
168    let new_rps = rps as f64 * decrease_factor;
169    new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
170}
171
// Behavior tests for `Aimd`. Most tests drive it through `DefaultRateLimiter`; the last two
// call `multiplicative_decrease` directly.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        let aimd = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let limiter = DefaultRateLimiter::new(aimd);
        let token = limiter.acquire().await;
        limiter.release(token, Some(RequestOutcome::Overload)).await;
        let state = limiter.state();
        assert_eq!(state.requests_per_second(), 50);
    }

    #[tokio::test]
    async fn should_increase_rate_on_success() {
        let aimd = Aimd::new_with_initial_rate(50)
            .decrease_factor(0.5)
            .increase_by(20);
        let limiter = DefaultRateLimiter::new(aimd);
        let token = limiter.acquire().await;
        limiter.release(token, Some(RequestOutcome::Success)).await;
        let state = limiter.state();
        assert_eq!(state.requests_per_second(), 70);
    }

    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let aimd = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let limiter = DefaultRateLimiter::new(aimd);
        let token = limiter.acquire().await;
        limiter.release(token, None).await;
        let state = limiter.state();
        assert_eq!(state.requests_per_second(), 100);
    }

    #[tokio::test]
    async fn should_decrease_only_once_within_cooldown() {
        let aimd = Aimd::new_with_initial_rate(100).decrease_factor(0.5);
        let limiter = DefaultRateLimiter::new(aimd);
        for _ in 0..2 {
            let token = limiter.acquire().await;
            limiter.release(token, Some(RequestOutcome::Overload)).await;
        }
        // The second overload lands inside the one-second cooldown and must be ignored.
        let state = limiter.state();
        assert_eq!(state.requests_per_second(), 50);
    }

    #[tokio::test]
    async fn should_increase_only_once_within_cooldown() {
        let aimd = Aimd::new_with_initial_rate(50).increase_by(20);
        let limiter = DefaultRateLimiter::new(aimd);
        for _ in 0..2 {
            let token = limiter.acquire().await;
            limiter.release(token, Some(RequestOutcome::Success)).await;
        }
        // The second success lands inside the one-second cooldown and must be ignored.
        let state = limiter.state();
        assert_eq!(state.requests_per_second(), 70);
    }

    #[tokio::test]
    async fn should_not_decrease_below_min_rate() {
        let aimd = Aimd::new(8, 5..=100).decrease_factor(0.5);
        let limiter = DefaultRateLimiter::new(aimd);
        let token = limiter.acquire().await;
        limiter.release(token, Some(RequestOutcome::Overload)).await;
        // 8 * 0.5 = 4, which is clamped up to the configured minimum of 5.
        let state = limiter.state();
        assert_eq!(state.requests_per_second(), 5);
    }

    #[tokio::test]
    async fn should_not_increase_above_max_rate() {
        let aimd = Aimd::new(95, 1..=100).increase_by(10);
        let limiter = DefaultRateLimiter::new(aimd);
        let token = limiter.acquire().await;
        limiter.release(token, Some(RequestOutcome::Success)).await;
        // 95 + 10 = 105, which is clamped down to the configured maximum of 100.
        let state = limiter.state();
        assert_eq!(state.requests_per_second(), 100);
    }

    #[test]
    fn multiplicative_decrease_floors_and_never_returns_zero() {
        assert_eq!(multiplicative_decrease(100, 0.5), 50);
        assert_eq!(multiplicative_decrease(99, 0.5), 49);
        assert_eq!(multiplicative_decrease(1, 0.5), 1);
        assert_eq!(multiplicative_decrease(0, 0.8), 1);
    }

    #[test]
    #[should_panic(expected = "should not increase the rate")]
    fn multiplicative_decrease_rejects_factor_above_one() {
        multiplicative_decrease(10, 1.5);
    }
}