congestion_limiter/limits/aimd.rs

1use std::{
2    ops::RangeInclusive,
3    sync::atomic::{AtomicUsize, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvAsUtil;
8
9use crate::{limiter::Outcome, limits::Sample};
10
11use super::{defaults, LimitAlgorithm};
12
/// Loss-based overload avoidance.
///
/// Additive-increase, multiplicative decrease.
///
/// Adds available concurrency when:
/// 1. no load-based errors are observed, and
/// 2. the utilisation of the current limit is high.
///
/// Reduces available concurrency by a factor when load-based errors are detected.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound the limit is clamped to whenever it is changed.
    min_limit: usize,
    /// Upper bound the limit is clamped to whenever it is changed.
    max_limit: usize,
    /// Multiplier applied to the limit when an overload is observed.
    decrease_factor: f64,
    /// Amount added to the limit after a successful, highly-utilised sample.
    increase_by: usize,
    /// Utilisation (`in_flight / limit`) that must be exceeded before the limit is increased.
    min_utilisation_threshold: f64,

    /// The current concurrency limit; updated atomically so `update` can take `&self`.
    limit: AtomicUsize,
}
32
33impl Aimd {
34    const DEFAULT_DECREASE_FACTOR: f64 = 0.9;
35    const DEFAULT_INCREASE: usize = 1;
36    const DEFAULT_INCREASE_MIN_UTILISATION: f64 = 0.8;
37
38    #[allow(missing_docs)]
39    pub fn new_with_initial_limit(initial_limit: usize) -> Self {
40        Self::new(
41            initial_limit,
42            defaults::DEFAULT_MIN_LIMIT..=defaults::DEFAULT_MAX_LIMIT,
43        )
44    }
45
46    #[allow(missing_docs)]
47    pub fn new(initial_limit: usize, limit_range: RangeInclusive<usize>) -> Self {
48        assert!(*limit_range.start() >= 1, "Limits must be at least 1");
49        assert!(
50            initial_limit >= *limit_range.start(),
51            "Initial limit less than minimum"
52        );
53        assert!(
54            initial_limit <= *limit_range.end(),
55            "Initial limit more than maximum"
56        );
57
58        Self {
59            min_limit: *limit_range.start(),
60            max_limit: *limit_range.end(),
61            decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
62            increase_by: Self::DEFAULT_INCREASE,
63            min_utilisation_threshold: Self::DEFAULT_INCREASE_MIN_UTILISATION,
64
65            limit: AtomicUsize::new(initial_limit),
66        }
67    }
68
69    /// Set the multiplier which will be applied when decreasing the limit.
70    pub fn decrease_factor(self, factor: f64) -> Self {
71        assert!((0.5..1.0).contains(&factor));
72        Self {
73            decrease_factor: factor,
74            ..self
75        }
76    }
77
78    /// Set the increment which will be applied when increasing the limit.
79    pub fn increase_by(self, increase: usize) -> Self {
80        assert!(increase > 0);
81        Self {
82            increase_by: increase,
83            ..self
84        }
85    }
86
87    #[allow(missing_docs)]
88    pub fn with_max_limit(self, max: usize) -> Self {
89        assert!(max > 0);
90        Self {
91            max_limit: max,
92            ..self
93        }
94    }
95
96    /// A threshold below which the limit won't be increased. 0.5 = 50%.
97    pub fn with_min_utilisation_threshold(self, min_util: f64) -> Self {
98        assert!(min_util > 0. && min_util < 1.);
99        Self {
100            min_utilisation_threshold: min_util,
101            ..self
102        }
103    }
104}
105
106#[async_trait]
107impl LimitAlgorithm for Aimd {
108    fn limit(&self) -> usize {
109        self.limit.load(Ordering::Acquire)
110    }
111
112    async fn update(&self, sample: Sample) -> usize {
113        use Outcome::*;
114        match sample.outcome {
115            Success => {
116                self.limit
117                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |limit| {
118                        let utilisation = sample.in_flight as f64 / limit as f64;
119
120                        if utilisation > self.min_utilisation_threshold {
121                            let limit = limit + self.increase_by;
122                            Some(limit.clamp(self.min_limit, self.max_limit))
123                        } else {
124                            Some(limit)
125                        }
126                    })
127                    .expect("we always return Some(limit)");
128            }
129            Overload => {
130                self.limit
131                    .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |limit| {
132                        let limit = multiplicative_decrease(limit, self.decrease_factor);
133
134                        Some(limit.clamp(self.min_limit, self.max_limit))
135                    })
136                    .expect("we always return Some(limit)");
137            }
138        }
139        self.limit.load(Ordering::SeqCst)
140    }
141}
142
143pub(super) fn multiplicative_decrease(limit: usize, decrease_factor: f64) -> usize {
144    assert!(decrease_factor <= 1.0, "should not increase the limit");
145
146    let limit = limit as f64 * decrease_factor;
147
148    // Floor instead of round, so the limit reduces even with small numbers.
149    // E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
150    limit.floor().approx().expect("should not have increased")
151}
152
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::Notify;

    use crate::limiter::{DefaultLimiter, Limiter};

    use super::*;

    /// An overloaded job should cut the limit by the configured factor: 10 * 0.5 = 5.
    #[tokio::test]
    async fn should_decrease_limit_on_overload() {
        let algorithm = Aimd::new_with_initial_limit(10)
            .decrease_factor(0.5)
            .increase_by(1);

        let notifier = Arc::new(Notify::new());

        let limiter = DefaultLimiter::new(algorithm).with_release_notifier(Arc::clone(&notifier));

        let permit = limiter.try_acquire().await.unwrap();
        limiter.release(permit, Some(Outcome::Overload)).await;
        // Block until the limiter notifies; presumably this signals the release was processed.
        notifier.notified().await;

        assert_eq!(limiter.limit(), 5, "overload: decrease");
    }

    /// With most of the limit in flight, a success should grow the limit by one.
    #[tokio::test]
    async fn should_increase_limit_on_success_when_using_gt_util_threshold() {
        let limiter = DefaultLimiter::new(
            Aimd::new_with_initial_limit(4)
                .decrease_factor(0.5)
                .increase_by(1)
                .with_min_utilisation_threshold(0.5),
        );

        // Keep three of the four permits acquired so utilisation is high when one is released.
        let released = limiter.try_acquire().await.unwrap();
        let _held_first = limiter.try_acquire().await.unwrap();
        let _held_second = limiter.try_acquire().await.unwrap();

        limiter.release(released, Some(Outcome::Success)).await;

        assert_eq!(limiter.limit(), 5, "success: increase");
    }

    /// With a single job in flight, utilisation stays low and a success leaves the limit alone.
    #[tokio::test]
    async fn should_not_change_limit_on_success_when_using_lt_util_threshold() {
        let limiter = DefaultLimiter::new(
            Aimd::new_with_initial_limit(4)
                .decrease_factor(0.5)
                .increase_by(1)
                .with_min_utilisation_threshold(0.5),
        );

        let only_permit = limiter.try_acquire().await.unwrap();
        limiter.release(only_permit, Some(Outcome::Success)).await;

        assert_eq!(limiter.limit(), 4, "success: ignore when < half limit");
    }

    /// Releasing without an outcome must not feed a sample into the algorithm.
    #[tokio::test]
    async fn should_not_change_limit_when_no_outcome() {
        let limiter = DefaultLimiter::new(
            Aimd::new_with_initial_limit(10)
                .decrease_factor(0.5)
                .increase_by(1),
        );

        let permit = limiter.try_acquire().await.unwrap();
        limiter.release(permit, None).await;

        assert_eq!(limiter.limit(), 10, "ignore");
    }
}