congestion_limiter/limits/
aimd.rs1use std::{
2 ops::RangeInclusive,
3 sync::atomic::{AtomicUsize, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvAsUtil;
8
9use crate::{limiter::Outcome, limits::Sample};
10
11use super::{defaults, LimitAlgorithm};
12
/// Additive-increase / multiplicative-decrease (AIMD) concurrency limit.
///
/// A successful, sufficiently utilised sample grows the limit by a fixed step; an overload
/// scales it down by a factor. Updates clamp the result into `[min_limit, max_limit]`.
#[derive(Debug)]
pub struct Aimd {
    /// Smallest value an update may clamp the limit down to.
    min_limit: usize,
    /// Largest value an update may clamp the limit up to.
    max_limit: usize,
    /// Multiplier applied to the limit when an overload is reported.
    decrease_factor: f64,
    /// Step added to the limit after a success above the utilisation threshold.
    increase_by: usize,
    /// Utilisation (`in_flight / limit`) that a success must exceed before the limit grows.
    min_utilisation_threshold: f64,

    /// The current limit; atomic so it can be adjusted through a shared `&self`.
    limit: AtomicUsize,
}
32
33impl Aimd {
34 const DEFAULT_DECREASE_FACTOR: f64 = 0.9;
35 const DEFAULT_INCREASE: usize = 1;
36 const DEFAULT_INCREASE_MIN_UTILISATION: f64 = 0.8;
37
38 #[allow(missing_docs)]
39 pub fn new_with_initial_limit(initial_limit: usize) -> Self {
40 Self::new(
41 initial_limit,
42 defaults::DEFAULT_MIN_LIMIT..=defaults::DEFAULT_MAX_LIMIT,
43 )
44 }
45
46 #[allow(missing_docs)]
47 pub fn new(initial_limit: usize, limit_range: RangeInclusive<usize>) -> Self {
48 assert!(*limit_range.start() >= 1, "Limits must be at least 1");
49 assert!(
50 initial_limit >= *limit_range.start(),
51 "Initial limit less than minimum"
52 );
53 assert!(
54 initial_limit <= *limit_range.end(),
55 "Initial limit more than maximum"
56 );
57
58 Self {
59 min_limit: *limit_range.start(),
60 max_limit: *limit_range.end(),
61 decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
62 increase_by: Self::DEFAULT_INCREASE,
63 min_utilisation_threshold: Self::DEFAULT_INCREASE_MIN_UTILISATION,
64
65 limit: AtomicUsize::new(initial_limit),
66 }
67 }
68
69 pub fn decrease_factor(self, factor: f64) -> Self {
71 assert!((0.5..1.0).contains(&factor));
72 Self {
73 decrease_factor: factor,
74 ..self
75 }
76 }
77
78 pub fn increase_by(self, increase: usize) -> Self {
80 assert!(increase > 0);
81 Self {
82 increase_by: increase,
83 ..self
84 }
85 }
86
87 #[allow(missing_docs)]
88 pub fn with_max_limit(self, max: usize) -> Self {
89 assert!(max > 0);
90 Self {
91 max_limit: max,
92 ..self
93 }
94 }
95
96 pub fn with_min_utilisation_threshold(self, min_util: f64) -> Self {
98 assert!(min_util > 0. && min_util < 1.);
99 Self {
100 min_utilisation_threshold: min_util,
101 ..self
102 }
103 }
104}
105
106#[async_trait]
107impl LimitAlgorithm for Aimd {
108 fn limit(&self) -> usize {
109 self.limit.load(Ordering::Acquire)
110 }
111
112 async fn update(&self, sample: Sample) -> usize {
113 use Outcome::*;
114 match sample.outcome {
115 Success => {
116 self.limit
117 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |limit| {
118 let utilisation = sample.in_flight as f64 / limit as f64;
119
120 if utilisation > self.min_utilisation_threshold {
121 let limit = limit + self.increase_by;
122 Some(limit.clamp(self.min_limit, self.max_limit))
123 } else {
124 Some(limit)
125 }
126 })
127 .expect("we always return Some(limit)");
128 }
129 Overload => {
130 self.limit
131 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |limit| {
132 let limit = multiplicative_decrease(limit, self.decrease_factor);
133
134 Some(limit.clamp(self.min_limit, self.max_limit))
135 })
136 .expect("we always return Some(limit)");
137 }
138 }
139 self.limit.load(Ordering::SeqCst)
140 }
141}
142
143pub(super) fn multiplicative_decrease(limit: usize, decrease_factor: f64) -> usize {
144 assert!(decrease_factor <= 1.0, "should not increase the limit");
145
146 let limit = limit as f64 * decrease_factor;
147
148 limit.floor().approx().expect("should not have increased")
151}
152
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::Notify;

    use crate::limiter::{DefaultLimiter, Limiter};

    use super::*;

    /// AIMD configuration shared by these tests: halve on overload, grow by one on success.
    fn halving_aimd(initial_limit: usize) -> Aimd {
        Aimd::new_with_initial_limit(initial_limit)
            .decrease_factor(0.5)
            .increase_by(1)
    }

    #[tokio::test]
    async fn should_decrease_limit_on_overload() {
        let notifier = Arc::new(Notify::new());
        let limiter =
            DefaultLimiter::new(halving_aimd(10)).with_release_notifier(Arc::clone(&notifier));

        let permit = limiter.try_acquire().await.unwrap();
        limiter.release(permit, Some(Outcome::Overload)).await;
        // NOTE(review): presumably the limiter notifies once the release has been processed;
        // the other tests assert without waiting on a notifier — confirm that is not racy.
        notifier.notified().await;

        assert_eq!(limiter.limit(), 5, "overload: decrease");
    }

    #[tokio::test]
    async fn should_increase_limit_on_success_when_using_gt_util_threshold() {
        let limiter = DefaultLimiter::new(halving_aimd(4).with_min_utilisation_threshold(0.5));

        let released = limiter.try_acquire().await.unwrap();
        // Kept alive (never released) while the first token's outcome is recorded.
        let _held_first = limiter.try_acquire().await.unwrap();
        let _held_second = limiter.try_acquire().await.unwrap();

        limiter.release(released, Some(Outcome::Success)).await;

        assert_eq!(limiter.limit(), 5, "success: increase");
    }

    #[tokio::test]
    async fn should_not_change_limit_on_success_when_using_lt_util_threshold() {
        let limiter = DefaultLimiter::new(halving_aimd(4).with_min_utilisation_threshold(0.5));

        let permit = limiter.try_acquire().await.unwrap();
        limiter.release(permit, Some(Outcome::Success)).await;

        assert_eq!(limiter.limit(), 4, "success: ignore when < half limit");
    }

    #[tokio::test]
    async fn should_not_change_limit_when_no_outcome() {
        let limiter = DefaultLimiter::new(halving_aimd(10));

        let permit = limiter.try_acquire().await.unwrap();
        limiter.release(permit, None).await;

        assert_eq!(limiter.limit(), 10, "ignore");
    }
}