1use std::time::Duration;
2
3use crate::Algorithm;
4
/// Concurrency-limit estimator driven by observed round-trip times (RTT).
///
/// All fields are private; in this file instances are produced by
/// `Gradient2Builder::build` and mutated only by `update`.
#[derive(Debug, Clone)]
pub struct Gradient2 {
    /// Current limit estimate, kept as a float so smoothing can move it fractionally.
    estimated_limit: f64,
    /// Lower bound applied whenever the limit is reported or updated.
    min_limit: usize,
    /// Upper bound applied whenever the limit is reported or updated.
    max_limit: usize,
    /// Weight given to a freshly computed limit when blending it with the previous one.
    smoothing: f64,
    /// Multiplier applied to the long-term/latest RTT ratio before it is clamped
    /// into the gradient.
    rtt_tolerance: f64,
    /// Maps the current integer limit to the headroom added on every update.
    queue_size: fn(usize) -> usize,

    /// Long-term RTT average, in nanoseconds.
    long_rtt_ns: f64,
    /// Number of samples folded into the warm-up average so far.
    long_rtt_count: usize,
    /// Number of samples averaged arithmetically before exponential averaging takes over.
    long_rtt_warmup: usize,
    /// Running sum of the warm-up samples, in nanoseconds.
    long_rtt_warmup_sum: f64,
    /// Weight of each new sample once exponential averaging is active.
    long_rtt_factor: f64,

    /// Most recently observed RTT, in nanoseconds.
    last_rtt_ns: f64,
}
41
42impl Gradient2 {
43 pub fn builder() -> Gradient2Builder {
45 Gradient2Builder::default()
46 }
47}
48
49impl Default for Gradient2 {
50 fn default() -> Self {
51 Gradient2Builder::default().build()
52 }
53}
54
55impl Algorithm for Gradient2 {
56 fn max_concurrency(&self) -> usize {
57 (self.estimated_limit as usize)
58 .clamp(self.min_limit, self.max_limit)
59 .max(1)
60 }
61
62 fn update(&mut self, rtt: Duration, num_inflight: usize, _is_error: bool, is_canceled: bool) {
63 if is_canceled {
64 return;
65 }
66
67 let rtt_ns = rtt.as_nanos() as f64;
68 if rtt_ns <= 0.0 {
69 return;
70 }
71
72 let limit = self.estimated_limit as usize;
73
74 self.last_rtt_ns = rtt_ns;
76
77 if self.long_rtt_count < self.long_rtt_warmup {
79 self.long_rtt_warmup_sum += rtt_ns;
80 self.long_rtt_count += 1;
81 self.long_rtt_ns = self.long_rtt_warmup_sum / self.long_rtt_count as f64;
82 } else {
83 self.long_rtt_ns =
84 self.long_rtt_ns * (1.0 - self.long_rtt_factor) + rtt_ns * self.long_rtt_factor;
85 }
86
87 if self.long_rtt_ns / self.last_rtt_ns > 2.0 {
90 self.long_rtt_ns *= 0.95;
91 }
92
93 if num_inflight * 2 < limit {
95 return;
96 }
97
98 let gradient = (self.rtt_tolerance * self.long_rtt_ns / self.last_rtt_ns).clamp(0.5, 1.0);
101
102 let queue_size = (self.queue_size)(limit);
103 let new_limit = gradient * self.estimated_limit + queue_size as f64;
104
105 self.estimated_limit = ((1.0 - self.smoothing) * self.estimated_limit
107 + self.smoothing * new_limit)
108 .clamp(self.min_limit as f64, self.max_limit as f64);
109 }
110}
111
/// Default queue-size function: the base-10 logarithm of `limit`, rounded up,
/// with a floor of 1.
///
/// A `limit` of zero yields negative infinity from `log10`; the float-to-integer
/// cast saturates that to 0, which the floor then lifts to 1.
fn log10_queue_size(limit: usize) -> usize {
    let rounded_up = (limit as f64).log10().ceil() as usize;
    rounded_up.max(1)
}
115
/// Builder for [`Gradient2`].
///
/// Each setter consumes and returns the builder so calls can be chained;
/// `build` produces the limiter.
#[derive(Debug, Clone)]
pub struct Gradient2Builder {
    /// Limit the estimator starts from.
    initial_limit: usize,
    /// Lower bound for the limit.
    min_limit: usize,
    /// Upper bound for the limit.
    max_limit: usize,
    /// Weight given to a freshly computed limit when blending it with the previous one.
    smoothing: f64,
    /// Multiplier applied to the long-term/latest RTT ratio; the setter requires `>= 1.0`.
    rtt_tolerance: f64,
    /// Window from which the long-term RTT averaging weight is derived in `build`.
    long_window: usize,
    /// Maps the current limit to the headroom added on every update.
    queue_size: fn(usize) -> usize,
}
140
141impl Default for Gradient2Builder {
142 fn default() -> Self {
143 Self {
144 initial_limit: 20,
145 min_limit: 20,
146 max_limit: 200,
147 smoothing: 0.2,
148 rtt_tolerance: 1.5,
149 long_window: 600,
150 queue_size: log10_queue_size,
151 }
152 }
153}
154
155impl Gradient2Builder {
156 pub fn initial_limit(mut self, limit: usize) -> Self {
158 self.initial_limit = limit;
159 self
160 }
161
162 pub fn min_limit(mut self, limit: usize) -> Self {
164 self.min_limit = limit;
165 self
166 }
167
168 pub fn max_limit(mut self, limit: usize) -> Self {
170 self.max_limit = limit;
171 self
172 }
173
174 pub fn smoothing(mut self, smoothing: f64) -> Self {
179 self.smoothing = smoothing;
180 self
181 }
182
183 pub fn rtt_tolerance(mut self, tolerance: f64) -> Self {
193 assert!(tolerance >= 1.0, "rtt_tolerance must be >= 1.0");
194 self.rtt_tolerance = tolerance;
195 self
196 }
197
198 pub fn long_window(mut self, window: usize) -> Self {
201 self.long_window = window;
202 self
203 }
204
205 pub fn queue_size(mut self, f: fn(usize) -> usize) -> Self {
208 self.queue_size = f;
209 self
210 }
211
212 pub fn build(self) -> Gradient2 {
218 assert!(
219 self.min_limit <= self.max_limit,
220 "min_limit ({}) must be <= max_limit ({})",
221 self.min_limit,
222 self.max_limit,
223 );
224 let long_window = std::cmp::max(1, self.long_window);
225 Gradient2 {
226 estimated_limit: self.initial_limit as f64,
227 min_limit: self.min_limit,
228 max_limit: self.max_limit,
229 smoothing: self.smoothing,
230 rtt_tolerance: self.rtt_tolerance,
231 queue_size: self.queue_size,
232 long_rtt_ns: 0.0,
233 long_rtt_count: 0,
234 long_rtt_warmup: 10,
235 long_rtt_warmup_sum: 0.0,
236 long_rtt_factor: 2.0 / (long_window as f64 + 1.0),
237 last_rtt_ns: 0.0,
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `count` identical successful, non-canceled samples to `limiter`.
    fn feed(limiter: &mut Gradient2, rtt_ms: u64, inflight: usize, count: usize) {
        let rtt = Duration::from_millis(rtt_ms);
        for _ in 0..count {
            limiter.update(rtt, inflight, false, false);
        }
    }

    /// A constant RTT at full utilisation must not shrink the limit.
    #[test]
    fn stable_rtt_allows_growth() {
        let mut limiter = Gradient2::builder().initial_limit(20).build();
        feed(&mut limiter, 50, 20, 10);

        let baseline = limiter.max_concurrency();
        feed(&mut limiter, 50, 20, 20);

        assert!(limiter.max_concurrency() >= baseline);
    }

    /// A tenfold RTT increase after warm-up must lower the limit.
    #[test]
    fn high_rtt_reduces_limit() {
        let mut limiter = Gradient2::builder().initial_limit(100).build();
        feed(&mut limiter, 50, 100, 10);

        let baseline = limiter.max_concurrency();
        feed(&mut limiter, 500, 100, 20);

        assert!(limiter.max_concurrency() < baseline);
    }

    /// The reported limit never exceeds the configured maximum, checked after every sample.
    #[test]
    fn limit_respects_max() {
        let mut limiter = Gradient2::builder()
            .initial_limit(200)
            .max_limit(200)
            .min_limit(1)
            .build();
        feed(&mut limiter, 50, 200, 10);

        for _ in 0..100 {
            feed(&mut limiter, 50, 200, 1);
            assert!(limiter.max_concurrency() <= 200);
        }
    }

    /// A canceled sample leaves the limit where it was.
    #[test]
    fn canceled_requests_are_ignored() {
        let mut limiter = Gradient2::builder().initial_limit(20).build();
        limiter.update(Duration::from_millis(50), 20, false, true);
        assert_eq!(limiter.max_concurrency(), 20);
    }

    /// Sustained high RTT cannot push the limit below the configured minimum.
    #[test]
    fn limit_stays_above_min() {
        let mut limiter = Gradient2::builder().initial_limit(20).min_limit(10).build();
        feed(&mut limiter, 50, 20, 10);
        feed(&mut limiter, 500, 20, 200);

        assert!(limiter.max_concurrency() >= 10);
    }

    /// With a tolerance of 2.0, a 50% RTT increase must not lower the limit.
    #[test]
    fn tolerance_allows_moderate_rtt_increase() {
        let mut limiter = Gradient2::builder()
            .initial_limit(50)
            .rtt_tolerance(2.0)
            .build();
        feed(&mut limiter, 50, 50, 10);

        let baseline = limiter.max_concurrency();
        feed(&mut limiter, 75, 50, 10);

        assert!(limiter.max_concurrency() >= baseline);
    }
}