request_rate_limiter/algorithms/
aimd.rs1use std::{
2 ops::RangeInclusive,
3 sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// Additive-increase / multiplicative-decrease (AIMD) rate-limit algorithm.
///
/// Successful outcomes grow the rate by a fixed step (`increase_by`). Overload outcomes
/// scale it by `decrease_factor`. The `update` implementation clamps each adjustment to
/// `min_rps..=max_rps`.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound for the rate, in requests per second (`new` asserts it is at least 1).
    min_rps: u64,
    /// Upper bound for the rate, in requests per second.
    max_rps: u64,
    /// Multiplier applied to the rate on an overload (0.8 unless set via `decrease_factor`).
    decrease_factor: f64,
    /// Amount added to the rate on a successful outcome (10 unless set via `increase_by`).
    increase_by: u64,

    /// Current rate limit. It is atomic so `update` can change it through `&self`.
    requests_per_second: AtomicU64,
    /// Wall-clock time of the last decrease, in nanoseconds since the UNIX epoch (0 = never).
    last_decrease_nanos: AtomicU64,
    /// Wall-clock time of the last increase, in nanoseconds since the UNIX epoch (0 = never).
    last_increase_nanos: AtomicU64,
}
25
26impl Clone for Aimd {
27 fn clone(&self) -> Self {
28 Self {
29 min_rps: self.min_rps,
30 max_rps: self.max_rps,
31 decrease_factor: self.decrease_factor,
32 increase_by: self.increase_by,
33 requests_per_second: AtomicU64::new(self.requests_per_second.load(Ordering::Acquire)),
34 last_decrease_nanos: AtomicU64::new(self.last_decrease_nanos.load(Ordering::Acquire)),
35 last_increase_nanos: AtomicU64::new(self.last_increase_nanos.load(Ordering::Acquire)),
36 }
37 }
38}
39
40impl Aimd {
41 const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
42 const DEFAULT_INCREASE: u64 = 10;
43
44 #[allow(missing_docs)]
45 pub fn new_with_initial_rate(initial_rps: u64) -> Self {
46 Self::new(initial_rps, 1..=10000)
47 }
48
49 #[allow(missing_docs)]
50 pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
51 assert!(*rate_range.start() >= 1, "Rate must be at least 1");
52 assert!(
53 initial_rps >= *rate_range.start(),
54 "Initial rate less than minimum"
55 );
56 assert!(
57 initial_rps <= *rate_range.end(),
58 "Initial rate more than maximum"
59 );
60
61 Self {
62 min_rps: *rate_range.start(),
63 max_rps: *rate_range.end(),
64 decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
65 increase_by: Self::DEFAULT_INCREASE,
66
67 requests_per_second: AtomicU64::new(initial_rps),
68 last_decrease_nanos: AtomicU64::new(0),
69 last_increase_nanos: AtomicU64::new(0),
70 }
71 }
72
73 pub fn decrease_factor(self, factor: f64) -> Self {
75 assert!((0.1..1.0).contains(&factor));
76 Self {
77 decrease_factor: factor,
78 ..self
79 }
80 }
81
82 pub fn increase_by(self, increase: u64) -> Self {
84 assert!(increase > 0);
85 Self {
86 increase_by: increase,
87 ..self
88 }
89 }
90
91 #[allow(missing_docs)]
92 pub fn with_max_rate(self, max: u64) -> Self {
93 assert!(max > 0);
94 Self {
95 max_rps: max,
96 ..self
97 }
98 }
99}
100
101#[async_trait]
102impl RateLimitAlgorithm for Aimd {
103 fn requests_per_second(&self) -> u64 {
104 self.requests_per_second.load(Ordering::Acquire)
105 }
106
107 async fn update(&self, sample: RequestSample) -> u64 {
108 use RequestOutcome::*;
109
110 let now = std::time::SystemTime::now()
111 .duration_since(std::time::UNIX_EPOCH)
112 .unwrap()
113 .as_nanos() as u64;
114
115 match sample.outcome {
116 Success | ClientError => {
117 let last = self.last_increase_nanos.load(Ordering::Relaxed);
118 if now.saturating_sub(last) >= 1_000_000_000 {
120 if self
121 .last_increase_nanos
122 .compare_exchange_weak(last, now, Ordering::Relaxed, Ordering::Relaxed)
123 .is_ok()
124 {
125 let mut updated = 0;
126 self.requests_per_second
127 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
128 updated = (rps + self.increase_by).min(self.max_rps);
129 Some(updated)
130 })
131 .unwrap();
132 return updated;
133 }
134 }
135 }
136 Overload => {
137 let last = self.last_decrease_nanos.load(Ordering::Relaxed);
138 if now.saturating_sub(last) >= 1_000_000_000 {
141 if self
142 .last_decrease_nanos
143 .compare_exchange_weak(last, now, Ordering::Relaxed, Ordering::Relaxed)
144 .is_ok()
145 {
146 let mut updated = 0;
147 self.requests_per_second
148 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |rps| {
149 updated = multiplicative_decrease(rps, self.decrease_factor)
150 .max(self.min_rps);
151 Some(updated)
152 })
153 .unwrap();
154 return updated;
155 }
156 }
157 }
158 }
159
160 self.requests_per_second.load(Ordering::Relaxed)
162 }
163}
164
165pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
166 assert!(decrease_factor <= 1.0, "should not increase the rate");
167
168 let new_rps = rps as f64 * decrease_factor;
169 new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    /// Builds an AIMD algorithm and wraps it in a `DefaultRateLimiter`.
    ///
    /// The algorithm starts at `initial_rps` and uses decrease factor 0.5 and the given
    /// increase step. One token is acquired and released with `outcome`, and the
    /// limiter's rate afterwards is returned.
    async fn rate_after_single_release(
        initial_rps: u64,
        increase: u64,
        outcome: Option<RequestOutcome>,
    ) -> u64 {
        let algorithm = Aimd::new_with_initial_rate(initial_rps)
            .decrease_factor(0.5)
            .increase_by(increase);
        let limiter = DefaultRateLimiter::new(algorithm);

        let token = limiter.acquire().await;
        limiter.release(token, outcome).await;

        let state = limiter.state();
        state.requests_per_second()
    }

    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        // 100 halved by the 0.5 decrease factor.
        let rate = rate_after_single_release(100, 10, Some(RequestOutcome::Overload)).await;
        assert_eq!(rate, 50);
    }

    #[tokio::test]
    async fn should_increase_rate_on_success() {
        // 50 plus the 20 increase step.
        let rate = rate_after_single_release(50, 20, Some(RequestOutcome::Success)).await;
        assert_eq!(rate, 70);
    }

    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let rate = rate_after_single_release(100, 10, None).await;
        assert_eq!(rate, 100);
    }
}