congestion_limiter/limits/
windowed.rs

1use std::{ops::RangeInclusive, time::Duration};
2
3use async_trait::async_trait;
4use tokio::{sync::Mutex, time::Instant};
5
6use crate::aggregation::Aggregator;
7
8use super::{defaults::MIN_SAMPLE_LATENCY, LimitAlgorithm, Sample};
9
10/// A wrapper around a [LimitAlgorithm] which aggregates samples within a window, periodically
11/// updating the limit.
12///
13/// The window duration is dynamic, based on latencies seen in the previous window.
14///
15/// Various [aggregators](crate::aggregation) are available to aggregate samples.
16#[derive(Debug)]
17pub struct Windowed<L, S> {
18    window_bounds: RangeInclusive<Duration>,
19    min_samples: usize,
20
21    /// Samples below this threshold will be discarded and not contribute to the current window.
22    ///
23    /// Useful for discarding samples which are not representative of the system we're trying to
24    /// observe. For example, if an error occurs locally on the client machine, it doesn't tell us
25    /// anything about the state of the server we're trying to communicate with.
26    min_latency_threshold: Duration,
27
28    inner: L,
29
30    window: Mutex<Window<S>>,
31}
32
/// Mutable bookkeeping for the window that is currently being filled.
#[derive(Debug)]
struct Window<S> {
    /// When this window opened.
    start: Instant,
    /// How long this window must stay open before it becomes eligible to close.
    duration: Duration,

    /// Accumulates the samples that fall into this window.
    aggregator: S,
    /// Smallest latency seen so far in this window (`Duration::MAX` until a sample arrives).
    ///
    /// Serves as the round-trip-time estimate when sizing the next window.
    min_latency: Duration,
}
44
45impl<L: LimitAlgorithm, S: Aggregator> Windowed<L, S> {
46    const DEFAULT_MIN_SAMPLES: usize = 10;
47
48    #[allow(missing_docs)]
49    pub fn new(inner: L, sampler: S) -> Self {
50        let min_window = Duration::from_micros(1);
51        Self {
52            window_bounds: RangeInclusive::new(min_window, Duration::from_secs(1)),
53            min_samples: Self::DEFAULT_MIN_SAMPLES,
54            min_latency_threshold: MIN_SAMPLE_LATENCY,
55
56            inner,
57
58            window: Mutex::new(Window {
59                duration: min_window,
60                start: Instant::now(),
61
62                aggregator: sampler,
63                min_latency: Duration::MAX,
64            }),
65        }
66    }
67
68    /// At least this many samples need to be aggregated before updating the limit.
69    pub fn with_min_samples(mut self, samples: usize) -> Self {
70        assert!(samples > 0, "at least one sample required per window");
71        self.min_samples = samples;
72        self
73    }
74
75    /// Minimum time to wait before attempting to update the limit.
76    pub fn with_min_window(mut self, min: Duration) -> Self {
77        self.window_bounds = min..=*self.window_bounds.end();
78        self
79    }
80
81    /// Maximum time to wait before attempting to update the limit.
82    ///
83    /// Will wait for longer if not enough samples have been aggregated. See
84    /// [with_min_samples()](Self::with_min_samples()).
85    pub fn with_max_window(mut self, max: Duration) -> Self {
86        self.window_bounds = *self.window_bounds.start()..=max;
87        self
88    }
89}
90
91#[async_trait]
92impl<L, S> LimitAlgorithm for Windowed<L, S>
93where
94    L: LimitAlgorithm + Send + Sync,
95    S: Aggregator + Send + Sync,
96{
97    fn limit(&self) -> usize {
98        self.inner.limit()
99    }
100
101    async fn update(&self, sample: Sample) -> usize {
102        if sample.latency < self.min_latency_threshold {
103            return self.inner.limit();
104        }
105
106        let mut window = self.window.lock().await;
107
108        window.min_latency = window.min_latency.min(sample.latency);
109
110        let agg_sample = window.aggregator.sample(sample);
111
112        if window.aggregator.sample_size() >= self.min_samples
113            && window.start.elapsed() >= window.duration
114        {
115            window.reset(&self.window_bounds);
116
117            self.inner.update(agg_sample).await
118        } else {
119            self.inner.limit()
120        }
121    }
122}
123
124impl<S> Window<S>
125where
126    S: Aggregator,
127{
128    fn reset(&mut self, bounds: &RangeInclusive<Duration>) {
129        self.min_latency = Duration::MAX;
130        self.aggregator.reset();
131
132        self.start = Instant::now();
133
134        // Use a window duration of 2 * RTT (RTT ~= min latency).
135        self.duration = self.min_latency.clamp(*bounds.start(), *bounds.end()) * 2;
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aggregation::Average, limiter::Outcome, limits::Vegas};

    /// Pushes `count` samples built by `make` through `limiter`.
    ///
    /// Returns the limit reported after the last update, or 0 if `count` is 0.
    async fn feed<A: LimitAlgorithm>(limiter: &A, count: usize, make: impl Fn() -> Sample) -> usize {
        let mut latest = 0;
        for _ in 0..count {
            latest = limiter.update(make()).await;
        }
        latest
    }

    #[tokio::test]
    async fn it_works() {
        const SAMPLES: usize = 2;

        // Zero-length windows: only the minimum sample count decides when a window closes.
        let windowed_vegas = Windowed::new(Vegas::new_with_initial_limit(10), Average::default())
            .with_min_samples(SAMPLES)
            .with_min_window(Duration::ZERO)
            .with_max_window(Duration::ZERO);

        let limit = feed(&windowed_vegas, SAMPLES, || Sample {
            in_flight: 1,
            latency: Duration::from_millis(10),
            outcome: Outcome::Success,
        })
        .await;
        assert_eq!(limit, 10, "first window shouldn't change limit for Vegas");

        let limit = feed(&windowed_vegas, SAMPLES, || Sample {
            in_flight: 1,
            latency: Duration::from_millis(100),
            outcome: Outcome::Overload,
        })
        .await;
        assert!(limit < 10, "limit should be reduced");
    }
}