request_rate_limiter/algorithms/
aimd.rs

1use std::{
2    ops::RangeInclusive,
3    sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// Loss-based request rate limiting.
///
/// Additive-increase, multiplicative-decrease (AIMD) for requests per second.
///
/// The rate is raised by a fixed increment when a request finishes with a
/// success or client-error outcome, and is multiplied by a factor below one
/// when overload is reported. Every adjustment is clamped to the configured
/// minimum and maximum rate.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound for the rate; at least 1.
    min_rps: u64,
    /// Upper bound for the rate; never below `min_rps`.
    max_rps: u64,
    /// Multiplier applied to the rate when overload is reported.
    decrease_factor: f64,
    /// Amount added to the rate on a non-overload outcome.
    increase_by: u64,

    /// Current rate. Atomic so it can be adjusted through a shared reference.
    requests_per_second: AtomicU64,
}

impl Aimd {
    const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
    const DEFAULT_INCREASE: u64 = 10;

    /// Creates a limiter starting at `initial_rps`, bounded to `1..=10000`
    /// requests per second.
    ///
    /// # Panics
    ///
    /// Panics if `initial_rps` lies outside `1..=10000`.
    pub fn new_with_initial_rate(initial_rps: u64) -> Self {
        Self::new(initial_rps, 1..=10000)
    }

    /// Creates a limiter starting at `initial_rps` whose rate stays within
    /// `rate_range`, using the default decrease factor and increment.
    ///
    /// # Panics
    ///
    /// Panics if the lower bound of `rate_range` is zero, or if `initial_rps`
    /// lies outside `rate_range`.
    pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
        let (min_rps, max_rps) = (*rate_range.start(), *rate_range.end());

        assert!(min_rps >= 1, "Rate must be at least 1");
        assert!(initial_rps >= min_rps, "Initial rate less than minimum");
        assert!(initial_rps <= max_rps, "Initial rate more than maximum");

        Self {
            min_rps,
            max_rps,
            decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
            increase_by: Self::DEFAULT_INCREASE,

            requests_per_second: AtomicU64::new(initial_rps),
        }
    }

    /// Set the multiplier which will be applied when decreasing the rate.
    ///
    /// # Panics
    ///
    /// Panics unless `0.1 <= factor < 1.0`.
    pub fn decrease_factor(self, factor: f64) -> Self {
        assert!((0.1..1.0).contains(&factor));
        Self {
            decrease_factor: factor,
            ..self
        }
    }

    /// Set the increment which will be applied when increasing the rate.
    ///
    /// # Panics
    ///
    /// Panics if `increase` is zero.
    pub fn increase_by(self, increase: u64) -> Self {
        assert!(increase > 0);
        Self {
            increase_by: increase,
            ..self
        }
    }

    /// Set the maximum rate, lowering the current rate to `max` if it is
    /// above the new limit.
    ///
    /// # Panics
    ///
    /// Panics if `max` is zero or below the configured minimum rate.
    pub fn with_max_rate(mut self, max: u64) -> Self {
        assert!(max > 0);
        // Rate updates clamp to `min_rps..=max_rps`, and `clamp` panics when
        // the bounds are inverted. Reject such a configuration here instead
        // of failing on the first request.
        assert!(max >= self.min_rps, "Maximum rate less than minimum");

        self.max_rps = max;

        // We own `self`, so the atomic can be adjusted without synchronisation.
        let current = self.requests_per_second.get_mut();
        *current = (*current).min(max);

        self
    }
}
90
91#[async_trait]
92impl RateLimitAlgorithm for Aimd {
93    fn requests_per_second(&self) -> u64 {
94        self.requests_per_second.load(Ordering::Acquire)
95    }
96
97    async fn update(&self, sample: RequestSample) -> u64 {
98        use RequestOutcome::*;
99        match sample.outcome {
100            Success | ClientError => {
101                self.requests_per_second
102                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
103                        let new_rps = rps + self.increase_by;
104                        Some(new_rps.clamp(self.min_rps, self.max_rps))
105                    })
106                    .expect("we always return Some(rps)");
107            }
108            Overload => {
109                self.requests_per_second
110                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
111                        let new_rps = multiplicative_decrease(rps, self.decrease_factor);
112                        Some(new_rps.clamp(self.min_rps, self.max_rps))
113                    })
114                    .expect("we always return Some(rps)");
115            }
116        }
117        self.requests_per_second.load(Ordering::SeqCst)
118    }
119}
120
121pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
122    assert!(decrease_factor <= 1.0, "should not increase the rate");
123
124    let new_rps = rps as f64 * decrease_factor;
125
126    // Floor instead of round, so the rate reduces even with small numbers.
127    // E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
128    new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
129}
130
#[cfg(test)]
mod tests {

    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    use super::*;

    /// Releasing a token with an overload outcome halves a rate of 100 when
    /// the decrease factor is 0.5.
    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        let algorithm = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let rate_limiter = DefaultRateLimiter::new(algorithm);

        let permit = rate_limiter.acquire().await;
        let outcome = Some(RequestOutcome::Overload);
        rate_limiter.release(permit, outcome).await;

        let snapshot = rate_limiter.state();
        let observed = snapshot.requests_per_second();
        assert_eq!(observed, 50, "overload: should decrease rate");
    }

    /// Releasing a token with a success outcome adds the increment of 20 to a
    /// rate of 50.
    #[tokio::test]
    async fn should_increase_rate_on_success() {
        let algorithm = Aimd::new_with_initial_rate(50)
            .decrease_factor(0.5)
            .increase_by(20);
        let rate_limiter = DefaultRateLimiter::new(algorithm);

        let permit = rate_limiter.acquire().await;
        let outcome = Some(RequestOutcome::Success);
        rate_limiter.release(permit, outcome).await;

        let snapshot = rate_limiter.state();
        let observed = snapshot.requests_per_second();
        assert_eq!(observed, 70, "success: should increase rate");
    }

    /// Releasing a token without an outcome leaves the rate untouched.
    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let algorithm = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let rate_limiter = DefaultRateLimiter::new(algorithm);

        let permit = rate_limiter.acquire().await;
        rate_limiter.release(permit, None).await;

        let snapshot = rate_limiter.state();
        let observed = snapshot.requests_per_second();
        assert_eq!(observed, 100, "should ignore when no outcome");
    }
}