request_rate_limiter/algorithms/
aimd.rs1use std::{
2 ops::RangeInclusive,
3 sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// Additive-increase / multiplicative-decrease (AIMD) rate-limiting algorithm.
///
/// The allowed rate (requests per second) grows by a fixed amount after a
/// successful (or client-error) request and is scaled down by a factor after an
/// overload. The rate always stays within `min_rps..=max_rps`.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound for the rate; guaranteed to be at least 1 by [`Aimd::new`].
    min_rps: u64,
    /// Upper bound for the rate; never below `min_rps`.
    max_rps: u64,
    /// Factor the rate is multiplied by on overload.
    decrease_factor: f64,
    /// Amount added to the rate on success.
    increase_by: u64,

    /// Current rate. Atomic so it can be adjusted through `&self`.
    requests_per_second: AtomicU64,
}

// `AtomicU64` is not `Clone`, so `Clone` is written by hand: the copy gets its
// own atomic, initialised with a snapshot of the current rate.
impl Clone for Aimd {
    fn clone(&self) -> Self {
        Self {
            min_rps: self.min_rps,
            max_rps: self.max_rps,
            decrease_factor: self.decrease_factor,
            increase_by: self.increase_by,
            requests_per_second: AtomicU64::new(
                self.requests_per_second.load(Ordering::Acquire),
            ),
        }
    }
}

impl Aimd {
    const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
    const DEFAULT_INCREASE: u64 = 10;
    /// Lower end of the range used by [`Aimd::new_with_initial_rate`].
    const DEFAULT_MIN_RPS: u64 = 1;
    /// Upper end of the range used by [`Aimd::new_with_initial_rate`].
    const DEFAULT_MAX_RPS: u64 = 10_000;

    /// Creates an AIMD algorithm starting at `initial_rps`.
    ///
    /// The rate may move within the default range `1..=10_000`.
    ///
    /// # Panics
    ///
    /// Panics if `initial_rps` is outside `1..=10_000`.
    pub fn new_with_initial_rate(initial_rps: u64) -> Self {
        Self::new(initial_rps, Self::DEFAULT_MIN_RPS..=Self::DEFAULT_MAX_RPS)
    }

    /// Creates an AIMD algorithm starting at `initial_rps`, bounded by `rate_range`.
    ///
    /// The decrease factor defaults to 0.8 and the increase step to 10. Override
    /// them with [`Aimd::decrease_factor`] and [`Aimd::increase_by`].
    ///
    /// # Panics
    ///
    /// Panics if `rate_range` starts below 1 or does not contain `initial_rps`.
    pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
        let (min_rps, max_rps) = rate_range.into_inner();
        assert!(min_rps >= 1, "Rate must be at least 1");
        assert!(initial_rps >= min_rps, "Initial rate less than minimum");
        assert!(initial_rps <= max_rps, "Initial rate more than maximum");

        Self {
            min_rps,
            max_rps,
            decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
            increase_by: Self::DEFAULT_INCREASE,

            requests_per_second: AtomicU64::new(initial_rps),
        }
    }

    /// Sets the factor the rate is multiplied by after an overload.
    ///
    /// # Panics
    ///
    /// Panics unless `0.1 <= factor < 1.0`.
    #[must_use]
    pub fn decrease_factor(self, factor: f64) -> Self {
        assert!((0.1..1.0).contains(&factor));
        Self {
            decrease_factor: factor,
            ..self
        }
    }

    /// Sets how much the rate grows after each successful request.
    ///
    /// # Panics
    ///
    /// Panics if `increase` is 0.
    #[must_use]
    pub fn increase_by(self, increase: u64) -> Self {
        assert!(increase > 0);
        Self {
            increase_by: increase,
            ..self
        }
    }

    /// Overrides the upper bound of the rate range.
    ///
    /// If the current rate is above `max`, it is lowered to `max`. This keeps the
    /// reported rate inside the configured range.
    ///
    /// # Panics
    ///
    /// Panics if `max` is 0 or is below the configured minimum rate.
    #[must_use]
    pub fn with_max_rate(mut self, max: u64) -> Self {
        assert!(max > 0);
        // `u64::clamp(min, max)` panics when `min > max`, and every rate update
        // clamps to `min_rps..=max_rps`. Reject such a range here, not on the
        // first update.
        assert!(max >= self.min_rps, "Maximum rate less than minimum");

        let current = self.requests_per_second.get_mut();
        *current = (*current).min(max);
        self.max_rps = max;
        self
    }
}
102
103#[async_trait]
104impl RateLimitAlgorithm for Aimd {
105 fn requests_per_second(&self) -> u64 {
106 self.requests_per_second.load(Ordering::Acquire)
107 }
108
109 async fn update(&self, sample: RequestSample) -> u64 {
110 use RequestOutcome::*;
111 match sample.outcome {
112 Success | ClientError => {
113 self.requests_per_second
114 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
115 let new_rps = rps + self.increase_by;
116 Some(new_rps.clamp(self.min_rps, self.max_rps))
117 })
118 .expect("we always return Some(rps)");
119 }
120 Overload => {
121 self.requests_per_second
122 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
123 let new_rps = multiplicative_decrease(rps, self.decrease_factor);
124 Some(new_rps.clamp(self.min_rps, self.max_rps))
125 })
126 .expect("we always return Some(rps)");
127 }
128 }
129 self.requests_per_second.load(Ordering::SeqCst)
130 }
131}
132
133pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
134 assert!(decrease_factor <= 1.0, "should not increase the rate");
135
136 let new_rps = rps as f64 * decrease_factor;
137
138 new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
141}
142
#[cfg(test)]
mod tests {

    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    use super::*;

    /// Wraps `$aimd` in a fresh limiter and sends one request through it,
    /// releasing the token with `$outcome`. Then asserts the limiter's
    /// reported rate equals `$expected`.
    macro_rules! assert_rate_after_one_request {
        ($aimd:expr, $outcome:expr, $expected:expr, $msg:literal) => {{
            let limiter = DefaultRateLimiter::new($aimd);

            let token = limiter.acquire().await;
            limiter.release(token, $outcome).await;

            let state = limiter.state();
            assert_eq!(state.requests_per_second(), $expected, $msg);
        }};
    }

    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        assert_rate_after_one_request!(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
            Some(RequestOutcome::Overload),
            50,
            "overload: should decrease rate"
        );
    }

    #[tokio::test]
    async fn should_increase_rate_on_success() {
        assert_rate_after_one_request!(
            Aimd::new_with_initial_rate(50)
                .decrease_factor(0.5)
                .increase_by(20),
            Some(RequestOutcome::Success),
            70,
            "success: should increase rate"
        );
    }

    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        assert_rate_after_one_request!(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
            None,
            100,
            "should ignore when no outcome"
        );
    }
}