request_rate_limiter/algorithms/
aimd.rs

1use std::{
2    ops::RangeInclusive,
3    sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// Loss-based request rate limiting.
///
/// Additive-increase, multiplicative-decrease (AIMD) for requests per second.
///
/// Each sample whose outcome is a success or a client error raises the rate
/// by a fixed increment, capped at the configured maximum.
///
/// Each sample whose outcome is an overload multiplies the rate by the
/// decrease factor, floored at the configured minimum.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound applied to the rate after a decrease.
    min_rps: u64,
    /// Upper bound applied to the rate after an increase.
    max_rps: u64,
    /// Multiplier applied to the rate when overload is reported.
    decrease_factor: f64,
    /// Amount added to the rate when a request did not report overload.
    increase_by: u64,

    /// Current rate; atomic so it can be adjusted through a shared reference.
    requests_per_second: AtomicU64,
}
31
32impl Clone for Aimd {
33    fn clone(&self) -> Self {
34        Self {
35            min_rps: self.min_rps,
36            max_rps: self.max_rps,
37            decrease_factor: self.decrease_factor,
38            increase_by: self.increase_by,
39            requests_per_second: AtomicU64::new(self.requests_per_second.load(Ordering::Acquire)),
40        }
41    }
42}
43
44impl Aimd {
45    const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
46    const DEFAULT_INCREASE: u64 = 10;
47
48    #[allow(missing_docs)]
49    pub fn new_with_initial_rate(initial_rps: u64) -> Self {
50        Self::new(initial_rps, 1..=10000)
51    }
52
53    #[allow(missing_docs)]
54    pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
55        assert!(*rate_range.start() >= 1, "Rate must be at least 1");
56        assert!(
57            initial_rps >= *rate_range.start(),
58            "Initial rate less than minimum"
59        );
60        assert!(
61            initial_rps <= *rate_range.end(),
62            "Initial rate more than maximum"
63        );
64
65        Self {
66            min_rps: *rate_range.start(),
67            max_rps: *rate_range.end(),
68            decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
69            increase_by: Self::DEFAULT_INCREASE,
70
71            requests_per_second: std::sync::atomic::AtomicU64::new(initial_rps),
72        }
73    }
74
75    /// Set the multiplier which will be applied when decreasing the rate.
76    pub fn decrease_factor(self, factor: f64) -> Self {
77        assert!((0.1..1.0).contains(&factor));
78        Self {
79            decrease_factor: factor,
80            ..self
81        }
82    }
83
84    /// Set the increment which will be applied when increasing the rate.
85    pub fn increase_by(self, increase: u64) -> Self {
86        assert!(increase > 0);
87        Self {
88            increase_by: increase,
89            ..self
90        }
91    }
92
93    #[allow(missing_docs)]
94    pub fn with_max_rate(self, max: u64) -> Self {
95        assert!(max > 0);
96        Self {
97            max_rps: max,
98            ..self
99        }
100    }
101}
102
103#[async_trait]
104impl RateLimitAlgorithm for Aimd {
105    fn requests_per_second(&self) -> u64 {
106        self.requests_per_second.load(Ordering::Acquire)
107    }
108
109    async fn update(&self, sample: RequestSample) -> u64 {
110        use RequestOutcome::*;
111        match sample.outcome {
112            Success | ClientError => self
113                .requests_per_second
114                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
115                    let new_rps = (rps + self.increase_by).min(self.max_rps);
116                    Some(new_rps)
117                })
118                .expect("we always return Some(rps)"),
119            Overload => self
120                .requests_per_second
121                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
122                    let new_rps =
123                        multiplicative_decrease(rps, self.decrease_factor).max(self.min_rps);
124                    Some(new_rps)
125                })
126                .expect("we always return Some(rps)"),
127        }
128    }
129}
130
131pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
132    assert!(decrease_factor <= 1.0, "should not increase the rate");
133
134    let new_rps = rps as f64 * decrease_factor;
135
136    // Floor instead of round, so the rate reduces even with small numbers.
137    // E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
138    new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
139}
140
#[cfg(test)]
mod tests {

    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    use super::*;

    /// One overloaded request at 100 rps with factor 0.5 leaves the rate at 50.
    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        let algorithm = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let limiter = DefaultRateLimiter::new(algorithm);

        let permit = limiter.acquire().await;
        limiter
            .release(permit, Some(RequestOutcome::Overload))
            .await;

        let snapshot = limiter.state();
        assert_eq!(
            snapshot.requests_per_second(),
            50,
            "overload: should decrease rate"
        );
    }

    /// One successful request at 50 rps with increment 20 leaves the rate at 70.
    #[tokio::test]
    async fn should_increase_rate_on_success() {
        let algorithm = Aimd::new_with_initial_rate(50)
            .decrease_factor(0.5)
            .increase_by(20);
        let limiter = DefaultRateLimiter::new(algorithm);

        let permit = limiter.acquire().await;
        limiter
            .release(permit, Some(RequestOutcome::Success))
            .await;

        let snapshot = limiter.state();
        assert_eq!(
            snapshot.requests_per_second(),
            70,
            "success: should increase rate"
        );
    }

    /// Releasing a token without an outcome leaves the rate untouched.
    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let algorithm = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let limiter = DefaultRateLimiter::new(algorithm);

        let permit = limiter.acquire().await;
        limiter.release(permit, None).await;

        let snapshot = limiter.state();
        assert_eq!(
            snapshot.requests_per_second(),
            100,
            "should ignore when no outcome"
        );
    }
}