request_rate_limiter/algorithms/aimd.rs

1use std::{
2    ops::RangeInclusive,
3    sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// Loss-based request rate limiting.
///
/// Applies additive-increase / multiplicative-decrease (AIMD) to a
/// requests-per-second budget:
///
/// * a successful request (or a client error) raises the rate by a fixed step;
/// * an overload signal scales the rate down by a constant factor.
///
/// Every adjustment made by `update` is clamped to the configured
/// minimum and maximum rates.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound used when clamping an adjusted rate.
    min_rps: u64,
    /// Upper bound used when clamping an adjusted rate.
    max_rps: u64,
    /// Multiplier applied to the rate when overload is observed.
    decrease_factor: f64,
    /// Step added to the rate after a non-overload outcome.
    increase_by: u64,

    /// Current rate; atomic so it can be adjusted through `&self`.
    requests_per_second: AtomicU64,
}
31
32impl Clone for Aimd {
33    fn clone(&self) -> Self {
34        Self {
35            min_rps: self.min_rps.clone(),
36            max_rps: self.max_rps.clone(),
37            decrease_factor: self.decrease_factor.clone(),
38            increase_by: self.increase_by.clone(),
39            requests_per_second: AtomicU64::new(self.requests_per_second.load(Ordering::Acquire)),
40        }
41    }
42}
43
44impl Aimd {
45    const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
46    const DEFAULT_INCREASE: u64 = 10;
47
48    #[allow(missing_docs)]
49    pub fn new_with_initial_rate(initial_rps: u64) -> Self {
50        Self::new(initial_rps, 1..=10000)
51    }
52
53    #[allow(missing_docs)]
54    pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
55        assert!(*rate_range.start() >= 1, "Rate must be at least 1");
56        assert!(
57            initial_rps >= *rate_range.start(),
58            "Initial rate less than minimum"
59        );
60        assert!(
61            initial_rps <= *rate_range.end(),
62            "Initial rate more than maximum"
63        );
64
65        Self {
66            min_rps: *rate_range.start(),
67            max_rps: *rate_range.end(),
68            decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
69            increase_by: Self::DEFAULT_INCREASE,
70
71            requests_per_second: std::sync::atomic::AtomicU64::new(initial_rps),
72        }
73    }
74
75    /// Set the multiplier which will be applied when decreasing the rate.
76    pub fn decrease_factor(self, factor: f64) -> Self {
77        assert!((0.1..1.0).contains(&factor));
78        Self {
79            decrease_factor: factor,
80            ..self
81        }
82    }
83
84    /// Set the increment which will be applied when increasing the rate.
85    pub fn increase_by(self, increase: u64) -> Self {
86        assert!(increase > 0);
87        Self {
88            increase_by: increase,
89            ..self
90        }
91    }
92
93    #[allow(missing_docs)]
94    pub fn with_max_rate(self, max: u64) -> Self {
95        assert!(max > 0);
96        Self {
97            max_rps: max,
98            ..self
99        }
100    }
101}
102
103#[async_trait]
104impl RateLimitAlgorithm for Aimd {
105    fn requests_per_second(&self) -> u64 {
106        self.requests_per_second.load(Ordering::Acquire)
107    }
108
109    async fn update(&self, sample: RequestSample) -> u64 {
110        use RequestOutcome::*;
111        match sample.outcome {
112            Success | ClientError => {
113                self.requests_per_second
114                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
115                        let new_rps = rps + self.increase_by;
116                        Some(new_rps.clamp(self.min_rps, self.max_rps))
117                    })
118                    .expect("we always return Some(rps)");
119            }
120            Overload => {
121                self.requests_per_second
122                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
123                        let new_rps = multiplicative_decrease(rps, self.decrease_factor);
124                        Some(new_rps.clamp(self.min_rps, self.max_rps))
125                    })
126                    .expect("we always return Some(rps)");
127            }
128        }
129        self.requests_per_second.load(Ordering::SeqCst)
130    }
131}
132
133pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
134    assert!(decrease_factor <= 1.0, "should not increase the rate");
135
136    let new_rps = rps as f64 * decrease_factor;
137
138    // Floor instead of round, so the rate reduces even with small numbers.
139    // E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
140    new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
141}
142
#[cfg(test)]
mod tests {

    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    use super::*;

    /// A single overload outcome halves the rate (factor 0.5): 100 -> 50.
    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Overload)).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            50,
            "overload: should decrease rate"
        );
    }

    /// A single success adds the configured step (20): 50 -> 70.
    #[tokio::test]
    async fn should_increase_rate_on_success() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(50)
                .decrease_factor(0.5)
                .increase_by(20),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            70,
            "success: should increase rate"
        );
    }

    /// Releasing without an outcome leaves the rate untouched.
    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, None).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            100,
            "should ignore when no outcome"
        );
    }
}