Skip to main content

lance_core/utils/
aimd.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! AIMD (Additive Increase / Multiplicative Decrease) rate controller.
5//!
6//! This module provides a reusable AIMD algorithm for dynamically adjusting
7//! request rates. On success windows, the rate increases additively. On
8//! windows with throttle signals, the rate decreases multiplicatively.
9//!
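//! The algorithm operates in discrete time windows. A window is evaluated
//! lazily, when the first outcome after it has expired is recorded; its
//! throttle ratio (throttled / total) is then compared against a threshold:
//! - Above threshold: `rate = max(rate * decrease_factor, min_rate)`
//! - At or below threshold: `rate = min(rate + additive_increment, max_rate)`,
//!   where a `max_rate` of 0.0 means no ceiling is applied
//!
//! A window in which no outcomes were recorded leaves the rate unchanged.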
14
15use std::sync::Mutex;
16use std::time::Duration;
17
18use crate::Result;
19
/// Configuration for the AIMD rate controller.
///
/// Use builder methods to customize. Defaults are tuned for cloud object stores
/// and will start at about 40% of the max rate and, with no throttled windows,
/// require 10 seconds to reach the max rate.
///
/// - initial_rate: 2000 req/s
/// - min_rate: 1 req/s
/// - max_rate: 5000 req/s (0.0 disables ceiling)
/// - decrease_factor: 0.5 (halve on throttle)
/// - additive_increment: 300 req/s per success window
/// - window_duration: 1 second
/// - throttle_threshold: 0.0 (any throttle triggers decrease)
#[derive(Debug, Clone)]
pub struct AimdConfig {
    /// Rate the controller starts at, in requests per second.
    pub initial_rate: f64,
    /// Floor that a multiplicative decrease never goes below, in requests per second.
    pub min_rate: f64,
    /// Ceiling for additive increases, in requests per second; `0.0` disables it.
    pub max_rate: f64,
    /// Factor the rate is multiplied by after a throttled window; must lie in `(0, 1)`.
    pub decrease_factor: f64,
    /// Amount added to the rate after a window at or below the throttle threshold.
    pub additive_increment: f64,
    /// Length of each evaluation window; must be non-zero.
    pub window_duration: Duration,
    /// Throttle ratio (throttled / total) above which a window triggers a
    /// decrease; must lie in `[0.0, 1.0]`.
    pub throttle_threshold: f64,
}

impl Default for AimdConfig {
    /// Returns the cloud-object-store defaults listed on [`AimdConfig`].
    fn default() -> Self {
        Self {
            initial_rate: 2000.0,
            min_rate: 1.0,
            max_rate: 5000.0,
            decrease_factor: 0.5,
            additive_increment: 300.0,
            window_duration: Duration::from_secs(1),
            throttle_threshold: 0.0,
        }
    }
}
57
58impl AimdConfig {
59    pub fn with_initial_rate(self, initial_rate: f64) -> Self {
60        Self {
61            initial_rate,
62            ..self
63        }
64    }
65
66    pub fn with_min_rate(self, min_rate: f64) -> Self {
67        Self { min_rate, ..self }
68    }
69
70    pub fn with_max_rate(self, max_rate: f64) -> Self {
71        Self { max_rate, ..self }
72    }
73
74    pub fn with_decrease_factor(self, decrease_factor: f64) -> Self {
75        Self {
76            decrease_factor,
77            ..self
78        }
79    }
80
81    pub fn with_additive_increment(self, additive_increment: f64) -> Self {
82        Self {
83            additive_increment,
84            ..self
85        }
86    }
87
88    pub fn with_window_duration(self, window_duration: Duration) -> Self {
89        Self {
90            window_duration,
91            ..self
92        }
93    }
94
95    pub fn with_throttle_threshold(self, throttle_threshold: f64) -> Self {
96        Self {
97            throttle_threshold,
98            ..self
99        }
100    }
101
102    /// Validate that the configuration values are sensible.
103    pub fn validate(&self) -> Result<()> {
104        if self.initial_rate <= 0.0 {
105            return Err(crate::Error::invalid_input(format!(
106                "initial_rate must be positive, got {}",
107                self.initial_rate
108            )));
109        }
110        if self.min_rate <= 0.0 {
111            return Err(crate::Error::invalid_input(format!(
112                "min_rate must be positive, got {}",
113                self.min_rate
114            )));
115        }
116        if self.max_rate < 0.0 {
117            return Err(crate::Error::invalid_input(format!(
118                "max_rate must be non-negative (0.0 = no ceiling), got {}",
119                self.max_rate
120            )));
121        }
122        if self.max_rate > 0.0 && self.min_rate > self.max_rate {
123            return Err(crate::Error::invalid_input(format!(
124                "min_rate ({}) must not exceed max_rate ({})",
125                self.min_rate, self.max_rate
126            )));
127        }
128        if self.decrease_factor <= 0.0 || self.decrease_factor >= 1.0 {
129            return Err(crate::Error::invalid_input(format!(
130                "decrease_factor must be in (0, 1), got {}",
131                self.decrease_factor
132            )));
133        }
134        if self.additive_increment <= 0.0 {
135            return Err(crate::Error::invalid_input(format!(
136                "additive_increment must be positive, got {}",
137                self.additive_increment
138            )));
139        }
140        if self.window_duration.is_zero() {
141            return Err(crate::Error::invalid_input(
142                "window_duration must be non-zero",
143            ));
144        }
145        if !(0.0..=1.0).contains(&self.throttle_threshold) {
146            return Err(crate::Error::invalid_input(format!(
147                "throttle_threshold must be in [0.0, 1.0], got {}",
148                self.throttle_threshold
149            )));
150        }
151        if self.max_rate > 0.0 && self.initial_rate > self.max_rate {
152            return Err(crate::Error::invalid_input(format!(
153                "initial_rate ({}) must not exceed max_rate ({})",
154                self.initial_rate, self.max_rate
155            )));
156        }
157        if self.initial_rate < self.min_rate {
158            return Err(crate::Error::invalid_input(format!(
159                "initial_rate ({}) must not be below min_rate ({})",
160                self.initial_rate, self.min_rate
161            )));
162        }
163        Ok(())
164    }
165}
166
/// Classification of one finished request, as reported to the AIMD controller.
///
/// Only throttling counts as [`RequestOutcome::Throttled`]. Any other error
/// (e.g. 404, network timeout) should be reported as `Success`, because it
/// says nothing about the target being over capacity.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestOutcome {
    /// The request completed without a throttle signal.
    Success,
    /// The request was throttled, signalling a capacity problem.
    Throttled,
}
176
/// Mutable bookkeeping for one controller: the current rate plus the
/// outcome tallies for the window that is still open.
struct AimdState {
    /// Rate currently reported to callers.
    rate: f64,
    /// Instant at which the open window began.
    window_start: std::time::Instant,
    /// Non-throttled outcomes counted in the open window.
    success_count: u64,
    /// Throttled outcomes counted in the open window.
    throttle_count: u64,
}
183
184/// AIMD rate controller.
185///
186/// Thread-safe: uses an internal `Mutex` to protect state. The lock is held
187/// only briefly during `record_outcome` and `current_rate`.
188pub struct AimdController {
189    config: AimdConfig,
190    state: Mutex<AimdState>,
191}
192
193impl std::fmt::Debug for AimdController {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        f.debug_struct("AimdController")
196            .field("config", &self.config)
197            .field("rate", &self.current_rate())
198            .finish()
199    }
200}
201
202impl AimdController {
203    /// Create a new AIMD controller with the given configuration.
204    pub fn new(config: AimdConfig) -> Result<Self> {
205        config.validate()?;
206        let rate = config.initial_rate;
207        Ok(Self {
208            config,
209            state: Mutex::new(AimdState {
210                rate,
211                window_start: std::time::Instant::now(),
212                success_count: 0,
213                throttle_count: 0,
214            }),
215        })
216    }
217
218    /// Record a request outcome and return the current rate.
219    ///
220    /// If the current time window has expired, the rate is adjusted before
221    /// recording the new outcome in a fresh window.
222    pub fn record_outcome(&self, outcome: RequestOutcome) -> f64 {
223        let mut state = self.state.lock().unwrap();
224        self.record_outcome_inner(&mut state, outcome, std::time::Instant::now())
225    }
226
227    fn record_outcome_inner(
228        &self,
229        state: &mut AimdState,
230        outcome: RequestOutcome,
231        now: std::time::Instant,
232    ) -> f64 {
233        // Check if the window has expired
234        let elapsed = now.duration_since(state.window_start);
235        if elapsed >= self.config.window_duration {
236            let total = state.success_count + state.throttle_count;
237            if total > 0 {
238                let throttle_ratio = state.throttle_count as f64 / total as f64;
239                if throttle_ratio > self.config.throttle_threshold {
240                    // Multiplicative decrease
241                    state.rate =
242                        (state.rate * self.config.decrease_factor).max(self.config.min_rate);
243                } else {
244                    // Additive increase
245                    state.rate += self.config.additive_increment;
246                    if self.config.max_rate > 0.0 {
247                        state.rate = state.rate.min(self.config.max_rate);
248                    }
249                }
250            }
251            // Reset window
252            state.window_start = now;
253            state.success_count = 0;
254            state.throttle_count = 0;
255        }
256
257        // Record this outcome
258        match outcome {
259            RequestOutcome::Success => state.success_count += 1,
260            RequestOutcome::Throttled => state.throttle_count += 1,
261        }
262
263        state.rate
264    }
265
266    /// Get the current rate without recording an outcome.
267    pub fn current_rate(&self) -> f64 {
268        self.state.lock().unwrap().rate
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Instant;

    /// Return the instant at which `controller`'s open window began.
    ///
    /// Tests derive their timestamps from this rather than from a fresh
    /// `Instant::now()`. The first recorded outcome is then exactly at the
    /// window start, and results do not depend on how long setup took.
    fn window_start(controller: &AimdController) -> Instant {
        controller.state.lock().unwrap().window_start
    }

    /// Record `outcome` as if it had been observed at `now`; returns the rate.
    fn record_at(controller: &AimdController, outcome: RequestOutcome, now: Instant) -> f64 {
        let mut state = controller.state.lock().unwrap();
        controller.record_outcome_inner(&mut state, outcome, now)
    }

    #[rstest]
    #[case::zero_initial_rate(
        AimdConfig::default().with_initial_rate(0.0),
        "initial_rate must be positive"
    )]
    #[case::negative_min_rate(
        AimdConfig::default().with_min_rate(-1.0),
        "min_rate must be positive"
    )]
    #[case::negative_max_rate(
        AimdConfig::default().with_max_rate(-1.0),
        "max_rate must be non-negative"
    )]
    #[case::min_exceeds_max(
        AimdConfig::default().with_min_rate(100.0).with_max_rate(10.0),
        "min_rate (100) must not exceed max_rate (10)"
    )]
    #[case::decrease_factor_zero(
        AimdConfig::default().with_decrease_factor(0.0),
        "decrease_factor must be in (0, 1)"
    )]
    #[case::decrease_factor_one(
        AimdConfig::default().with_decrease_factor(1.0),
        "decrease_factor must be in (0, 1)"
    )]
    #[case::decrease_factor_over_one(
        AimdConfig::default().with_decrease_factor(1.5),
        "decrease_factor must be in (0, 1)"
    )]
    #[case::zero_additive_increment(
        AimdConfig::default().with_additive_increment(0.0),
        "additive_increment must be positive"
    )]
    #[case::zero_window_duration(
        AimdConfig::default().with_window_duration(Duration::ZERO),
        "window_duration must be non-zero"
    )]
    #[case::threshold_over_one(
        AimdConfig::default().with_throttle_threshold(1.1),
        "throttle_threshold must be in [0.0, 1.0]"
    )]
    #[case::threshold_negative(
        AimdConfig::default().with_throttle_threshold(-0.1),
        "throttle_threshold must be in [0.0, 1.0]"
    )]
    #[case::initial_exceeds_max(
        AimdConfig::default().with_initial_rate(6000.0),
        "initial_rate (6000) must not exceed max_rate (5000)"
    )]
    #[case::initial_below_min(
        AimdConfig::default().with_initial_rate(0.5).with_min_rate(1.0),
        "initial_rate (0.5) must not be below min_rate (1)"
    )]
    fn test_config_validation_rejects_invalid(
        #[case] config: AimdConfig,
        #[case] expected_msg: &str,
    ) {
        let err = config.validate().unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains(expected_msg),
            "Expected error containing '{}', got: {}",
            expected_msg,
            msg
        );
    }

    #[test]
    fn test_default_config_is_valid() {
        AimdConfig::default().validate().unwrap();
    }

    #[test]
    fn test_no_ceiling_config_is_valid() {
        AimdConfig::default().with_max_rate(0.0).validate().unwrap();
    }

    #[test]
    fn test_additive_increase_on_success_window() {
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_additive_increment(10.0)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        // A success in the first window...
        record_at(&controller, RequestOutcome::Success, start);
        // ...is evaluated by the first outcome recorded after the boundary.
        let rate = record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150));

        // Rate should have increased by additive_increment
        assert_eq!(rate, 110.0);
        assert_eq!(controller.current_rate(), 110.0);
    }

    #[test]
    fn test_multiplicative_decrease_on_throttle_window() {
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_decrease_factor(0.5)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        record_at(&controller, RequestOutcome::Throttled, start);
        // Advance past the window boundary
        record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150));

        assert_eq!(controller.current_rate(), 50.0);
    }

    #[test]
    fn test_floor_enforcement() {
        let config = AimdConfig::default()
            .with_initial_rate(2.0)
            .with_min_rate(1.0)
            .with_decrease_factor(0.5)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        record_at(&controller, RequestOutcome::Throttled, start);

        // After decrease: 2.0 * 0.5 = 1.0 (at floor)
        let t1 = start + Duration::from_millis(150);
        record_at(&controller, RequestOutcome::Throttled, t1);
        assert_eq!(controller.current_rate(), 1.0);

        // Another decrease (1.0 * 0.5 = 0.5) should be clamped to the floor
        let t2 = t1 + Duration::from_millis(150);
        record_at(&controller, RequestOutcome::Success, t2);
        assert_eq!(controller.current_rate(), 1.0);
    }

    #[test]
    fn test_ceiling_enforcement() {
        let config = AimdConfig::default()
            .with_initial_rate(4990.0)
            .with_max_rate(5000.0)
            .with_additive_increment(20.0)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        record_at(&controller, RequestOutcome::Success, start);
        record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150));

        // 4990 + 20 = 5010, clamped to 5000
        assert_eq!(controller.current_rate(), 5000.0);
    }

    #[test]
    fn test_no_ceiling_allows_unbounded_growth() {
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_max_rate(0.0)
            .with_additive_increment(50.0)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        // One success per window; every record after the first closes the
        // previous (all-success) window, giving five increases in total.
        for i in 0..=5u64 {
            record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150 * i));
        }

        // 100 + 50*5 = 350
        assert_eq!(controller.current_rate(), 350.0);
    }

    #[test]
    fn test_empty_window_no_adjustment() {
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        // Nothing is recorded in the first window; the first outcome arrives
        // only after it has expired.
        record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150));

        // No adjustment because the expired window had 0 total
        assert_eq!(controller.current_rate(), 100.0);
    }

    #[test]
    fn test_throttle_threshold_filtering() {
        // With threshold 0.5, a throttle ratio below 50% should still increase
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_throttle_threshold(0.5)
            .with_additive_increment(10.0)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        // 1 throttle out of 3 = 33% < 50% threshold
        record_at(&controller, RequestOutcome::Success, start);
        record_at(&controller, RequestOutcome::Success, start);
        record_at(&controller, RequestOutcome::Throttled, start);

        // Advance past the window boundary
        record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150));

        // Should have increased because 33% <= 50%
        assert_eq!(controller.current_rate(), 110.0);
    }

    #[test]
    fn test_throttle_ratio_at_threshold_increases() {
        // The decrease condition is strictly "above threshold", so a ratio
        // exactly equal to the threshold must still increase.
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_throttle_threshold(0.5)
            .with_additive_increment(10.0)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        // 1 throttle out of 2 = exactly 50%
        record_at(&controller, RequestOutcome::Success, start);
        record_at(&controller, RequestOutcome::Throttled, start);

        record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150));

        assert_eq!(controller.current_rate(), 110.0);
    }

    #[test]
    fn test_throttle_threshold_triggers_decrease() {
        // With threshold 0.5, a throttle ratio strictly above 50% should decrease
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_throttle_threshold(0.5)
            .with_decrease_factor(0.5)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        // 2 throttles out of 3 = 67% > 50% threshold
        record_at(&controller, RequestOutcome::Success, start);
        record_at(&controller, RequestOutcome::Throttled, start);
        record_at(&controller, RequestOutcome::Throttled, start);

        record_at(&controller, RequestOutcome::Success, start + Duration::from_millis(150));

        assert_eq!(controller.current_rate(), 50.0);
    }

    #[test]
    fn test_recovery_after_decrease() {
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_decrease_factor(0.5)
            .with_additive_increment(10.0)
            .with_window_duration(Duration::from_millis(100));
        let controller = AimdController::new(config).unwrap();
        let start = window_start(&controller);

        // Window 1: throttle (evaluated at t1 → decrease to 50)
        record_at(&controller, RequestOutcome::Throttled, start);
        let t1 = start + Duration::from_millis(150);

        // Window 2: success (evaluated at t2 → increase to 60)
        record_at(&controller, RequestOutcome::Success, t1);
        let t2 = t1 + Duration::from_millis(150);

        // Window 3: success (evaluated at t3 → increase to 70)
        record_at(&controller, RequestOutcome::Success, t2);
        let t3 = t2 + Duration::from_millis(150);

        // Trigger final evaluation
        record_at(&controller, RequestOutcome::Success, t3);

        assert_eq!(controller.current_rate(), 70.0);
    }

    #[test]
    fn test_within_window_no_adjustment() {
        let config = AimdConfig::default()
            .with_initial_rate(100.0)
            .with_window_duration(Duration::from_secs(10));
        let controller = AimdController::new(config).unwrap();

        // Record many outcomes but all within the same window
        for _ in 0..100 {
            controller.record_outcome(RequestOutcome::Throttled);
        }

        // Rate should still be initial since window hasn't expired
        assert_eq!(controller.current_rate(), 100.0);
    }
}