request_rate_limiter/algorithms/
aimd.rs1use std::{
2 ops::RangeInclusive,
3 sync::atomic::{AtomicU64, Ordering},
4};
5
6use async_trait::async_trait;
7use conv::ConvUtil;
8
9use crate::{limiter::RequestOutcome, algorithms::RequestSample};
10
11use super::RateLimitAlgorithm;
12
/// Additive-increase / multiplicative-decrease (AIMD) rate limiting algorithm.
///
/// Successful requests (and client errors) raise the permitted rate by a fixed amount,
/// while overload signals scale it down by a constant factor. The rate is always kept
/// within `[min_rps, max_rps]`.
#[derive(Debug)]
pub struct Aimd {
    /// Lower bound for the rate; always at least 1.
    min_rps: u64,
    /// Upper bound for the rate; never less than `min_rps`.
    max_rps: u64,
    /// Factor the rate is multiplied by on overload, in `[0.1, 1.0)`.
    decrease_factor: f64,
    /// Amount added to the rate on success or client error; always positive.
    increase_by: u64,

    /// Current rate, updated atomically so the algorithm can be shared by reference.
    requests_per_second: AtomicU64,
}

impl Aimd {
    const DEFAULT_DECREASE_FACTOR: f64 = 0.8;
    const DEFAULT_INCREASE: u64 = 10;
    /// Upper bound of the rate range used by [`Aimd::new_with_initial_rate`].
    const DEFAULT_MAX_RATE: u64 = 10_000;

    /// Creates an AIMD algorithm starting at `initial_rps`, with the rate confined to
    /// `1..=10000`.
    ///
    /// # Panics
    ///
    /// Panics if `initial_rps` is 0 or greater than 10000.
    pub fn new_with_initial_rate(initial_rps: u64) -> Self {
        Self::new(initial_rps, 1..=Self::DEFAULT_MAX_RATE)
    }

    /// Creates an AIMD algorithm starting at `initial_rps`, with the rate confined to
    /// `rate_range`.
    ///
    /// The decrease factor defaults to 0.8 and the additive increase to 10; use
    /// [`Aimd::decrease_factor`] and [`Aimd::increase_by`] to change them.
    ///
    /// # Panics
    ///
    /// Panics if `rate_range` starts below 1 or `initial_rps` lies outside `rate_range`.
    pub fn new(initial_rps: u64, rate_range: RangeInclusive<u64>) -> Self {
        let (min_rps, max_rps) = rate_range.into_inner();
        assert!(min_rps >= 1, "Rate must be at least 1");
        assert!(initial_rps >= min_rps, "Initial rate less than minimum");
        assert!(initial_rps <= max_rps, "Initial rate more than maximum");

        Self {
            min_rps,
            max_rps,
            decrease_factor: Self::DEFAULT_DECREASE_FACTOR,
            increase_by: Self::DEFAULT_INCREASE,

            requests_per_second: AtomicU64::new(initial_rps),
        }
    }

    /// Sets the factor the rate is multiplied by when an overload is observed.
    ///
    /// # Panics
    ///
    /// Panics unless `0.1 <= factor < 1.0`.
    #[must_use]
    pub fn decrease_factor(self, factor: f64) -> Self {
        assert!(
            (0.1..1.0).contains(&factor),
            "Decrease factor must be in [0.1, 1.0)"
        );
        Self {
            decrease_factor: factor,
            ..self
        }
    }

    /// Sets the amount the rate is increased by after a success or client error.
    ///
    /// # Panics
    ///
    /// Panics if `increase` is 0.
    #[must_use]
    pub fn increase_by(self, increase: u64) -> Self {
        assert!(increase > 0, "Increase must be positive");
        Self {
            increase_by: increase,
            ..self
        }
    }

    /// Sets the maximum rate.
    ///
    /// If the current rate is above `max`, it is lowered to `max` so the reported rate
    /// never lies outside the configured range.
    ///
    /// # Panics
    ///
    /// Panics if `max` is below the configured minimum rate (which is at least 1). An
    /// empty range would otherwise make every later rate update panic inside `clamp`.
    #[must_use]
    pub fn with_max_rate(mut self, max: u64) -> Self {
        assert!(max >= self.min_rps, "Maximum rate less than minimum");
        // We own `self`, so the atomic can be adjusted without synchronisation.
        let current = self.requests_per_second.get_mut();
        *current = (*current).min(max);
        self.max_rps = max;
        self
    }
}
90
91#[async_trait]
92impl RateLimitAlgorithm for Aimd {
93 fn requests_per_second(&self) -> u64 {
94 self.requests_per_second.load(Ordering::Acquire)
95 }
96
97 async fn update(&self, sample: RequestSample) -> u64 {
98 use RequestOutcome::*;
99 match sample.outcome {
100 Success | ClientError => {
101 self.requests_per_second
102 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
103 let new_rps = rps + self.increase_by;
104 Some(new_rps.clamp(self.min_rps, self.max_rps))
105 })
106 .expect("we always return Some(rps)");
107 }
108 Overload => {
109 self.requests_per_second
110 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |rps| {
111 let new_rps = multiplicative_decrease(rps, self.decrease_factor);
112 Some(new_rps.clamp(self.min_rps, self.max_rps))
113 })
114 .expect("we always return Some(rps)");
115 }
116 }
117 self.requests_per_second.load(Ordering::SeqCst)
118 }
119}
120
121pub(super) fn multiplicative_decrease(rps: u64, decrease_factor: f64) -> u64 {
122 assert!(decrease_factor <= 1.0, "should not increase the rate");
123
124 let new_rps = rps as f64 * decrease_factor;
125
126 new_rps.floor().approx_as::<u64>().unwrap_or(1).max(1)
129}
130
#[cfg(test)]
mod tests {

    use crate::limiter::{DefaultRateLimiter, RateLimiter, RequestOutcome};

    use super::*;

    // Each test pushes a single request (with or without an outcome) through a real
    // limiter and checks the rate the AIMD algorithm ends up at.

    #[tokio::test]
    async fn should_decrease_rate_on_overload() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Overload)).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            50,
            "overload: should decrease rate"
        );
    }

    #[tokio::test]
    async fn should_increase_rate_on_success() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(50)
                .decrease_factor(0.5)
                .increase_by(20),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, Some(RequestOutcome::Success)).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            70,
            "success: should increase rate"
        );
    }

    #[tokio::test]
    async fn should_not_change_rate_when_no_outcome() {
        let limiter = DefaultRateLimiter::new(
            Aimd::new_with_initial_rate(100)
                .decrease_factor(0.5)
                .increase_by(10),
        );

        let permit = limiter.acquire().await;
        limiter.release(permit, None).await;

        assert_eq!(
            limiter.state().requests_per_second(),
            100,
            "should ignore when no outcome"
        );
    }
}