request_rate_limiter/algorithms/
aimd.rs1use std::{
2 ops::RangeInclusive,
3 sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{algorithms::RequestSample, limiter::RequestOutcome};
10
11use super::RateLimitAlgorithm;
12
/// State of an additive-increase / multiplicative-decrease rate limit.
///
/// The current rate is raised by `increase_by` after a successful (or
/// client-error) request and multiplied by `decrease_factor` after an
/// overloaded one. Every adjustment is clamped to `min_rps..=max_rps`.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound applied to the rate after every adjustment.
    min_rps: u64,
    /// Upper bound applied to the rate after every adjustment.
    max_rps: u64,
    /// Multiplier applied to the rate when a request reports overload.
    decrease_factor: f64,
    /// Amount added to the rate when a request does not report overload.
    increase_by: u64,

    /// Current rate in requests per second; atomic so it can be adjusted through `&self`.
    requests_per_second: AtomicU64,
}
31
32impl Aimd {
33 const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
34 const DEFAULT_INCREASE: u64 = 10;
35
36 #[allow(missing_docs)]
37 pub fn new_with_initial_rate(initial_rps: u64) -> Self {
38 Self::new(initial_rps, 1..=10000)
39 }
40
41 #[allow(missing_docs)]
42 pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
43 assert!(*rate_range.start() >= 1, "Rate must be at least 1");
44 assert!(
45 initial_rps >= *rate_range.start(),
46 "Initial rate less than minimum"
47 );
48 assert!(
49 initial_rps <= *rate_range.end(),
50 "Initial rate more than maximum"
51 );
52
53 Self {
54 min_rps: *rate_range.start(),
55 max_rps: *rate_range.end(),
56 decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
57 increase_by: Self::DEFAULT_INCREASE,
58
59 requests_per_second: std::sync::atomic::AtomicU64::new(initial_rps),
60 }
61 }
62
63 pub fn decrease_factor(self, factor: f64) -> Self {
65 assert!((0.1..1.0).contains(&factor));
66 Self {
67 decrease_factor: factor,
68 ..self
69 }
70 }
71
72 pub fn increase_by(self, increase: u64) -> Self {
74 assert!(increase > 0);
75 Self {
76 increase_by: increase,
77 ..self
78 }
79 }
80
81 #[allow(missing_docs)]
82 pub fn with_max_rate(self, max: u64) -> Self {
83 assert!(max > 0);
84 Self {
85 max_rps: max,
86 ..self
87 }
88 }
89}
90
91#[async_trait]
92impl RateLimitAlgorithm for Aimd {
93 fn requests_per_second(&self) -> u64 {
94 self.requests_per_second.load(Ordering::Acquire)
95 }
96
97 async fn update(&self, sample: RequestSample) -> u64 {
98 use RequestOutcome::*;
99 match sample.outcome {
100 Success | ClientError => {
101 self.requests_per_second
102 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
103 let new_rps = rps + self.increase_by;
104 Some(new_rps.clamp(self.min_rps, self.max_rps))
105 })
106 .expect("we always return Some(rps)");
107 }
108 Overload => {
109 self.requests_per_second
110 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
111 let new_rps = multiplicative_decrease(rps, self.decrease_factor);
112 Some(new_rps.clamp(self.min_rps, self.max_rps))
113 })
114 .expect("we always return Some(rps)");
115 }
116 }
117 self.requests_per_second.load(Ordering::SeqCst)
118 }
119}
120
121pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
122 assert!(decrease_factor <= 1.0, "should not increase the rate");
123
124 let new_rps = rps as f64 * decrease_factor;
125
126 new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
129}
130
#[cfg(test)]
mod tests {
    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    use super::*;

    /// An overloaded request scales the rate by the decrease factor (100 * 0.5 = 50).
    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        let algorithm = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let limiter = DefaultRateLimiter::new(algorithm);

        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Overload)).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            50,
            "overload: should decrease rate"
        );
    }

    /// A successful request adds the configured increment (50 + 20 = 70).
    #[tokio::test]
    async fn should_increase_rate_on_success() {
        let algorithm = Aimd::new_with_initial_rate(50)
            .decrease_factor(0.5)
            .increase_by(20);
        let limiter = DefaultRateLimiter::new(algorithm);

        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            70,
            "success: should increase rate"
        );
    }

    /// Releasing without an outcome leaves the rate untouched.
    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let algorithm = Aimd::new_with_initial_rate(100)
            .decrease_factor(0.5)
            .increase_by(10);
        let limiter = DefaultRateLimiter::new(algorithm);

        let permit = limiter.acquire().await;
        limiter.release(permit, None).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            100,
            "should ignore when no outcome"
        );
    }
}