request_rate_limiter/algorithms/
aimd.rs1use std::{
2 ops::RangeInclusive,
3 sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// Rate-limit algorithm that adds a fixed amount to the rate on success and
/// multiplies it by a factor below one on overload (additive increase,
/// multiplicative decrease).
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound applied to the rate after a multiplicative decrease.
    min_rps: u64,
    /// Upper bound applied to the rate after an additive increase.
    max_rps: u64,
    /// Multiplier applied to the current rate when a request reports overload.
    decrease_factor: f64,
    /// Amount added to the current rate when a request succeeds.
    increase_by: u64,

    /// Current rate in requests per second; atomic so it can be updated through `&self`.
    requests_per_second: AtomicU64,
}
31
32impl Clone for Aimd {
33 fn clone(&self) -> Self {
34 Self {
35 min_rps: self.min_rps,
36 max_rps: self.max_rps,
37 decrease_factor: self.decrease_factor,
38 increase_by: self.increase_by,
39 requests_per_second: AtomicU64::new(self.requests_per_second.load(Ordering::Acquire)),
40 }
41 }
42}
43
44impl Aimd {
45 const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
46 const DEFAULT_INCREASE: u64 = 10;
47
48 #[allow(missing_docs)]
49 pub fn new_with_initial_rate(initial_rps: u64) -> Self {
50 Self::new(initial_rps, 1..=10000)
51 }
52
53 #[allow(missing_docs)]
54 pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
55 assert!(*rate_range.start() >= 1, "Rate must be at least 1");
56 assert!(
57 initial_rps >= *rate_range.start(),
58 "Initial rate less than minimum"
59 );
60 assert!(
61 initial_rps <= *rate_range.end(),
62 "Initial rate more than maximum"
63 );
64
65 Self {
66 min_rps: *rate_range.start(),
67 max_rps: *rate_range.end(),
68 decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
69 increase_by: Self::DEFAULT_INCREASE,
70
71 requests_per_second: std::sync::atomic::AtomicU64::new(initial_rps),
72 }
73 }
74
75 pub fn decrease_factor(self, factor: f64) -> Self {
77 assert!((0.1..1.0).contains(&factor));
78 Self {
79 decrease_factor: factor,
80 ..self
81 }
82 }
83
84 pub fn increase_by(self, increase: u64) -> Self {
86 assert!(increase > 0);
87 Self {
88 increase_by: increase,
89 ..self
90 }
91 }
92
93 #[allow(missing_docs)]
94 pub fn with_max_rate(self, max: u64) -> Self {
95 assert!(max > 0);
96 Self {
97 max_rps: max,
98 ..self
99 }
100 }
101}
102
103#[async_trait]
104impl RateLimitAlgorithm for Aimd {
105 fn requests_per_second(&self) -> u64 {
106 self.requests_per_second.load(Ordering::Acquire)
107 }
108
109 async fn update(&self, sample: RequestSample) -> u64 {
110 use RequestOutcome::*;
111 match sample.outcome {
112 Success | ClientError => self
113 .requests_per_second
114 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
115 let new_rps = (rps + self.increase_by).min(self.max_rps);
116 Some(new_rps)
117 })
118 .expect("we always return Some(rps)"),
119 Overload => self
120 .requests_per_second
121 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
122 let new_rps =
123 multiplicative_decrease(rps, self.decrease_factor).max(self.min_rps);
124 Some(new_rps)
125 })
126 .expect("we always return Some(rps)"),
127 }
128 }
129}
130
131pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
132 assert!(decrease_factor <= 1.0, "should not increase the rate");
133
134 let new_rps = rps as f64 * decrease_factor;
135
136 new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
139}
140
#[cfg(test)]
mod tests {

    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    use super::*;

    /// An overload outcome multiplies the rate by the decrease factor.
    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
        );

        let permit = limiter.acquire().await;
        limiter
            .release(permit, Some(RequestOutcome::Overload))
            .await;

        assert_eq!(
            limiter.state().requests_per_second(),
            50,
            "overload: should decrease rate"
        );
    }

    /// A success outcome adds the configured increment to the rate.
    #[tokio::test]
    async fn should_increase_rate_on_success() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(50)
                .decrease_factor(0.5)
                .increase_by(20),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            70,
            "success: should increase rate"
        );
    }

    /// Releasing a token without an outcome leaves the rate untouched.
    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, None).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            100,
            "should ignore when no outcome"
        );
    }
}