1use std::sync::Mutex;
16use std::time::Duration;
17
18use crate::Result;
19
/// Tuning parameters for an [`AimdController`].
///
/// AIMD (additive-increase / multiplicative-decrease) adjusts a rate once per
/// observation window. When the fraction of throttled requests in the window
/// that just ended exceeds `throttle_threshold`, the rate is multiplied by
/// `decrease_factor` (never dropping below `min_rate`); otherwise
/// `additive_increment` is added (never exceeding `max_rate`, unless
/// `max_rate` is `0.0`, which disables the ceiling).
///
/// Build one from [`Default`] plus the `with_*` methods, and check it with
/// [`AimdConfig::validate`] (which [`AimdController::new`] also calls).
#[derive(Debug, Clone, PartialEq)]
pub struct AimdConfig {
    /// Rate the controller starts at. Must be positive, not below `min_rate`,
    /// and not above `max_rate` when a ceiling is set.
    pub initial_rate: f64,
    /// Floor applied after a multiplicative decrease. Must be positive.
    pub min_rate: f64,
    /// Ceiling applied after an additive increase; `0.0` means no ceiling.
    pub max_rate: f64,
    /// Multiplier applied to the rate after a throttled window; must lie in
    /// the open interval `(0, 1)`.
    pub decrease_factor: f64,
    /// Amount added to the rate after a window that was not throttled.
    pub additive_increment: f64,
    /// Length of one observation window. Must be non-zero.
    pub window_duration: Duration,
    /// Throttle ratio (throttled / total outcomes in a window) above which the
    /// rate is decreased; must lie in `[0.0, 1.0]`. With `0.0`, a single
    /// throttled outcome in a window triggers a decrease.
    pub throttle_threshold: f64,
}
43
44impl Default for AimdConfig {
45 fn default() -> Self {
46 Self {
47 initial_rate: 2000.0,
48 min_rate: 1.0,
49 max_rate: 5000.0,
50 decrease_factor: 0.5,
51 additive_increment: 300.0,
52 window_duration: Duration::from_secs(1),
53 throttle_threshold: 0.0,
54 }
55 }
56}
57
58impl AimdConfig {
59 pub fn with_initial_rate(self, initial_rate: f64) -> Self {
60 Self {
61 initial_rate,
62 ..self
63 }
64 }
65
66 pub fn with_min_rate(self, min_rate: f64) -> Self {
67 Self { min_rate, ..self }
68 }
69
70 pub fn with_max_rate(self, max_rate: f64) -> Self {
71 Self { max_rate, ..self }
72 }
73
74 pub fn with_decrease_factor(self, decrease_factor: f64) -> Self {
75 Self {
76 decrease_factor,
77 ..self
78 }
79 }
80
81 pub fn with_additive_increment(self, additive_increment: f64) -> Self {
82 Self {
83 additive_increment,
84 ..self
85 }
86 }
87
88 pub fn with_window_duration(self, window_duration: Duration) -> Self {
89 Self {
90 window_duration,
91 ..self
92 }
93 }
94
95 pub fn with_throttle_threshold(self, throttle_threshold: f64) -> Self {
96 Self {
97 throttle_threshold,
98 ..self
99 }
100 }
101
102 pub fn validate(&self) -> Result<()> {
104 if self.initial_rate <= 0.0 {
105 return Err(crate::Error::invalid_input(format!(
106 "initial_rate must be positive, got {}",
107 self.initial_rate
108 )));
109 }
110 if self.min_rate <= 0.0 {
111 return Err(crate::Error::invalid_input(format!(
112 "min_rate must be positive, got {}",
113 self.min_rate
114 )));
115 }
116 if self.max_rate < 0.0 {
117 return Err(crate::Error::invalid_input(format!(
118 "max_rate must be non-negative (0.0 = no ceiling), got {}",
119 self.max_rate
120 )));
121 }
122 if self.max_rate > 0.0 && self.min_rate > self.max_rate {
123 return Err(crate::Error::invalid_input(format!(
124 "min_rate ({}) must not exceed max_rate ({})",
125 self.min_rate, self.max_rate
126 )));
127 }
128 if self.decrease_factor <= 0.0 || self.decrease_factor >= 1.0 {
129 return Err(crate::Error::invalid_input(format!(
130 "decrease_factor must be in (0, 1), got {}",
131 self.decrease_factor
132 )));
133 }
134 if self.additive_increment <= 0.0 {
135 return Err(crate::Error::invalid_input(format!(
136 "additive_increment must be positive, got {}",
137 self.additive_increment
138 )));
139 }
140 if self.window_duration.is_zero() {
141 return Err(crate::Error::invalid_input(
142 "window_duration must be non-zero",
143 ));
144 }
145 if !(0.0..=1.0).contains(&self.throttle_threshold) {
146 return Err(crate::Error::invalid_input(format!(
147 "throttle_threshold must be in [0.0, 1.0], got {}",
148 self.throttle_threshold
149 )));
150 }
151 if self.max_rate > 0.0 && self.initial_rate > self.max_rate {
152 return Err(crate::Error::invalid_input(format!(
153 "initial_rate ({}) must not exceed max_rate ({})",
154 self.initial_rate, self.max_rate
155 )));
156 }
157 if self.initial_rate < self.min_rate {
158 return Err(crate::Error::invalid_input(format!(
159 "initial_rate ({}) must not be below min_rate ({})",
160 self.initial_rate, self.min_rate
161 )));
162 }
163 Ok(())
164 }
165}
166
/// Outcome of a single request, as reported to [`AimdController::record_outcome`].
///
/// Outcomes are tallied per observation window; the share of `Throttled`
/// outcomes in a window decides whether the rate goes up or down.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RequestOutcome {
    /// The request went through; counted in the window total, lowering the
    /// throttle ratio.
    Success,
    /// The request was throttled; counted toward the throttle ratio that can
    /// trigger a multiplicative decrease.
    Throttled,
}
176
/// Mutable state of an `AimdController`, kept behind its mutex.
struct AimdState {
    /// Rate currently in effect; only changed when a window is closed.
    rate: f64,
    /// Instant at which the window currently being tallied was opened.
    window_start: std::time::Instant,
    /// `Success` outcomes recorded in the current window.
    success_count: u64,
    /// `Throttled` outcomes recorded in the current window.
    throttle_count: u64,
}
183
184pub struct AimdController {
189 config: AimdConfig,
190 state: Mutex<AimdState>,
191}
192
193impl std::fmt::Debug for AimdController {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("AimdController")
196 .field("config", &self.config)
197 .field("rate", &self.current_rate())
198 .finish()
199 }
200}
201
202impl AimdController {
203 pub fn new(config: AimdConfig) -> Result<Self> {
205 config.validate()?;
206 let rate = config.initial_rate;
207 Ok(Self {
208 config,
209 state: Mutex::new(AimdState {
210 rate,
211 window_start: std::time::Instant::now(),
212 success_count: 0,
213 throttle_count: 0,
214 }),
215 })
216 }
217
218 pub fn record_outcome(&self, outcome: RequestOutcome) -> f64 {
223 let mut state = self.state.lock().unwrap();
224 self.record_outcome_inner(&mut state, outcome, std::time::Instant::now())
225 }
226
227 fn record_outcome_inner(
228 &self,
229 state: &mut AimdState,
230 outcome: RequestOutcome,
231 now: std::time::Instant,
232 ) -> f64 {
233 let elapsed = now.duration_since(state.window_start);
235 if elapsed >= self.config.window_duration {
236 let total = state.success_count + state.throttle_count;
237 if total > 0 {
238 let throttle_ratio = state.throttle_count as f64 / total as f64;
239 if throttle_ratio > self.config.throttle_threshold {
240 state.rate =
242 (state.rate * self.config.decrease_factor).max(self.config.min_rate);
243 } else {
244 state.rate += self.config.additive_increment;
246 if self.config.max_rate > 0.0 {
247 state.rate = state.rate.min(self.config.max_rate);
248 }
249 }
250 }
251 state.window_start = now;
253 state.success_count = 0;
254 state.throttle_count = 0;
255 }
256
257 match outcome {
259 RequestOutcome::Success => state.success_count += 1,
260 RequestOutcome::Throttled => state.throttle_count += 1,
261 }
262
263 state.rate
264 }
265
266 pub fn current_rate(&self) -> f64 {
268 self.state.lock().unwrap().rate
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Instant;

    /// Feeds one outcome to `controller` at the synthetic instant `at`,
    /// holding the state lock only for the duration of the call.
    fn record_at(controller: &AimdController, outcome: RequestOutcome, at: Instant) -> f64 {
        let mut state = controller.state.lock().unwrap();
        controller.record_outcome_inner(&mut state, outcome, at)
    }

    /// `base` shifted forward by `ms` milliseconds.
    fn plus_ms(base: Instant, ms: u64) -> Instant {
        base + Duration::from_millis(ms)
    }

    #[rstest]
    #[case::zero_initial_rate(
        AimdConfig::default().with_initial_rate(0.0),
        "initial_rate must be positive"
    )]
    #[case::negative_min_rate(
        AimdConfig::default().with_min_rate(-1.0),
        "min_rate must be positive"
    )]
    #[case::negative_max_rate(
        AimdConfig::default().with_max_rate(-1.0),
        "max_rate must be non-negative"
    )]
    #[case::min_exceeds_max(
        AimdConfig::default().with_min_rate(100.0).with_max_rate(10.0),
        "min_rate (100) must not exceed max_rate (10)"
    )]
    #[case::decrease_factor_zero(
        AimdConfig::default().with_decrease_factor(0.0),
        "decrease_factor must be in (0, 1)"
    )]
    #[case::decrease_factor_one(
        AimdConfig::default().with_decrease_factor(1.0),
        "decrease_factor must be in (0, 1)"
    )]
    #[case::decrease_factor_over_one(
        AimdConfig::default().with_decrease_factor(1.5),
        "decrease_factor must be in (0, 1)"
    )]
    #[case::zero_additive_increment(
        AimdConfig::default().with_additive_increment(0.0),
        "additive_increment must be positive"
    )]
    #[case::zero_window_duration(
        AimdConfig::default().with_window_duration(Duration::ZERO),
        "window_duration must be non-zero"
    )]
    #[case::threshold_over_one(
        AimdConfig::default().with_throttle_threshold(1.1),
        "throttle_threshold must be in [0.0, 1.0]"
    )]
    #[case::threshold_negative(
        AimdConfig::default().with_throttle_threshold(-0.1),
        "throttle_threshold must be in [0.0, 1.0]"
    )]
    #[case::initial_exceeds_max(
        AimdConfig::default().with_initial_rate(6000.0),
        "initial_rate (6000) must not exceed max_rate (5000)"
    )]
    #[case::initial_below_min(
        AimdConfig::default().with_initial_rate(0.5).with_min_rate(1.0),
        "initial_rate (0.5) must not be below min_rate (1)"
    )]
    fn test_config_validation_rejects_invalid(
        #[case] config: AimdConfig,
        #[case] expected_msg: &str,
    ) {
        let msg = config.validate().unwrap_err().to_string();
        assert!(
            msg.contains(expected_msg),
            "Expected error containing '{}', got: {}",
            expected_msg,
            msg
        );
    }

    #[test]
    fn test_default_config_is_valid() {
        let config = AimdConfig::default();
        config.validate().unwrap();
    }

    #[test]
    fn test_no_ceiling_config_is_valid() {
        let config = AimdConfig::default().with_max_rate(0.0);
        config.validate().unwrap();
    }

    #[test]
    fn test_additive_increase_on_success_window() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_additive_increment(10.0)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        let start = Instant::now();
        record_at(&controller, RequestOutcome::Success, start);
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 150));

        assert_eq!(controller.current_rate(), 110.0);
    }

    #[test]
    fn test_multiplicative_decrease_on_throttle_window() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_decrease_factor(0.5)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        let start = Instant::now();
        record_at(&controller, RequestOutcome::Throttled, start);
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 150));

        assert_eq!(controller.current_rate(), 50.0);
    }

    #[test]
    fn test_floor_enforcement() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(2.0)
                .with_min_rate(1.0)
                .with_decrease_factor(0.5)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        let start = Instant::now();
        record_at(&controller, RequestOutcome::Throttled, start);

        let first_close = plus_ms(start, 150);
        record_at(&controller, RequestOutcome::Throttled, first_close);
        assert_eq!(controller.current_rate(), 1.0);

        record_at(&controller, RequestOutcome::Success, plus_ms(first_close, 150));
        assert_eq!(controller.current_rate(), 1.0);
    }

    #[test]
    fn test_ceiling_enforcement() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(4990.0)
                .with_max_rate(5000.0)
                .with_additive_increment(20.0)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        let start = Instant::now();
        record_at(&controller, RequestOutcome::Success, start);
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 150));

        assert_eq!(controller.current_rate(), 5000.0);
    }

    #[test]
    fn test_no_ceiling_allows_unbounded_growth() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_max_rate(0.0)
                .with_additive_increment(50.0)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        // Six outcomes 150 ms apart: the first opens a window and each of the
        // following five closes one successful window (+50 each).
        let start = Instant::now();
        for step in 0..=5u64 {
            record_at(&controller, RequestOutcome::Success, plus_ms(start, step * 150));
        }

        assert_eq!(controller.current_rate(), 350.0);
    }

    #[test]
    fn test_empty_window_no_adjustment() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        let start = Instant::now();
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 150));

        assert_eq!(controller.current_rate(), 100.0);
    }

    #[test]
    fn test_throttle_threshold_filtering() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_throttle_threshold(0.5)
                .with_additive_increment(10.0)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        // 1 throttled out of 3 (ratio 1/3) stays under the 0.5 threshold.
        let start = Instant::now();
        let first_window = [
            RequestOutcome::Success,
            RequestOutcome::Success,
            RequestOutcome::Throttled,
        ];
        for outcome in first_window.iter().copied() {
            record_at(&controller, outcome, start);
        }
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 150));

        assert_eq!(controller.current_rate(), 110.0);
    }

    #[test]
    fn test_throttle_threshold_triggers_decrease() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_throttle_threshold(0.5)
                .with_decrease_factor(0.5)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        // 2 throttled out of 3 (ratio 2/3) exceeds the 0.5 threshold.
        let start = Instant::now();
        let first_window = [
            RequestOutcome::Success,
            RequestOutcome::Throttled,
            RequestOutcome::Throttled,
        ];
        for outcome in first_window.iter().copied() {
            record_at(&controller, outcome, start);
        }
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 150));

        assert_eq!(controller.current_rate(), 50.0);
    }

    #[test]
    fn test_recovery_after_decrease() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_decrease_factor(0.5)
                .with_additive_increment(10.0)
                .with_window_duration(Duration::from_millis(100)),
        )
        .unwrap();

        // 100 -> 50 (throttled window) -> 60 -> 70 (two clean windows).
        let start = Instant::now();
        record_at(&controller, RequestOutcome::Throttled, start);
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 150));
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 300));
        record_at(&controller, RequestOutcome::Success, plus_ms(start, 450));

        assert_eq!(controller.current_rate(), 70.0);
    }

    #[test]
    fn test_within_window_no_adjustment() {
        let controller = AimdController::new(
            AimdConfig::default()
                .with_initial_rate(100.0)
                .with_window_duration(Duration::from_secs(10)),
        )
        .unwrap();

        (0..100).for_each(|_| {
            controller.record_outcome(RequestOutcome::Throttled);
        });

        assert_eq!(controller.current_rate(), 100.0);
    }
}