Skip to main content

atomr_streams/
rate.rs

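//! Rate-mediation operators on `Source<T>`.
//!
//! - [`conflate`] (the `Conflate` / `ConflateWithSeed` pattern) and [`expand`]
//!   (the `Expand` / `Extrapolate` pattern) decouple producer and consumer
//!   rates without buffering every element. Conceptually, when downstream is
//!   slow `conflate` collapses upstream values into a running aggregate, and
//!   when upstream is slow `expand` repeatedly emits a derived value until the
//!   next upstream element arrives. Both forward through unbounded channels,
//!   which carry no downstream-demand signal, so they only approximate those
//!   semantics.
//! - [`token_bucket`] and [`token_bucket_keyed`] are wall-clock rate limiters.
//! - [`respect_retry_after`] pauses emission when an element carries a
//!   [`RetryAfter`] back-off signal.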
9
10use std::collections::HashMap;
11use std::time::Duration;
12
13use futures::stream::{BoxStream, StreamExt};
14use tokio::sync::mpsc;
15use tokio::time::Instant;
16
17use crate::source::Source;
18
19/// `conflate(seed, fold)` — when downstream is slower than upstream,
20/// merge consecutive upstream elements into a running aggregate via
21/// `fold`. The aggregate is emitted whenever downstream pulls.
22///
23/// In our buffered-channel model "merge until pulled" is approximated
24/// by folding contiguous bursts inside the upstream task and emitting
25/// at each await point: every output element is the fold of the
26/// upstream burst since the last emission.
27pub fn conflate<T, U, S, F>(src: Source<T>, mut seed: S, mut fold: F) -> Source<U>
28where
29    T: Send + 'static,
30    U: Send + 'static,
31    S: FnMut(T) -> U + Send + 'static,
32    F: FnMut(U, T) -> U + Send + 'static,
33{
34    let (tx, rx) = mpsc::unbounded_channel::<U>();
35    let mut inner = src.into_boxed();
36    tokio::spawn(async move {
37        let mut acc: Option<U> = None;
38        loop {
39            match inner.next().await {
40                Some(item) => {
41                    acc = Some(match acc.take() {
42                        None => seed(item),
43                        Some(prev) => fold(prev, item),
44                    });
45                    // Try to flush the accumulator if downstream is
46                    // ready to receive — best-effort; otherwise keep
47                    // folding.
48                    if let Some(a) = acc.take() {
49                        if tx.send(a).is_err() {
50                            return;
51                        }
52                    }
53                }
54                None => {
55                    if let Some(a) = acc.take() {
56                        let _ = tx.send(a);
57                    }
58                    return;
59                }
60            }
61        }
62    });
63    Source::from_receiver(rx)
64}
65
66/// `expand(extrapolate)` — when upstream is slower than downstream,
67/// repeatedly call `extrapolate(last)` between elements to keep
68/// downstream supplied. After the upstream completes, the iterator
69/// returned by `extrapolate(last)` continues to be drained until it
70/// itself is exhausted.
71///
72/// / `Source.Extrapolate`.
73///
74/// The closure receives the most recent upstream element by reference
75/// and returns an `Iterator<Item = T>` describing the synthetic
76/// values to emit while waiting for the next upstream element.
77pub fn expand<T, F, I>(src: Source<T>, mut extrapolate: F) -> Source<T>
78where
79    T: Clone + Send + 'static,
80    F: FnMut(&T) -> I + Send + 'static,
81    I: Iterator<Item = T> + Send + 'static,
82{
83    let (tx, rx) = mpsc::unbounded_channel::<T>();
84    let mut inner = src.into_boxed();
85    tokio::spawn(async move {
86        let mut last: Option<T> = None;
87        loop {
88            match inner.next().await {
89                Some(item) => {
90                    if tx.send(item.clone()).is_err() {
91                        return;
92                    }
93                    last = Some(item);
94                }
95                None => {
96                    // Upstream done — drain extrapolation iterator
97                    // once, then close.
98                    if let Some(l) = last {
99                        for synth in extrapolate(&l) {
100                            if tx.send(synth).is_err() {
101                                return;
102                            }
103                        }
104                    }
105                    return;
106                }
107            }
108        }
109    });
110    Source::from_receiver(rx)
111}
112
113/// A classic token bucket: it accrues `rate_per_sec` tokens per second up to a
114/// ceiling of `burst` tokens, and each tracked unit consumes one token, waiting
115/// for one to accrue if the bucket is empty.
116///
117/// Refill is computed lazily from elapsed wall time ([`tokio::time::Instant`])
118/// at each call, so there is no background timer.
119struct TokenBucket {
120    /// Tokens accrued per second.
121    rate_per_sec: f64,
122    /// Maximum tokens that can be banked.
123    capacity: f64,
124    /// Current token count (fractional accrual is tracked).
125    tokens: f64,
126    /// Last time tokens were refilled.
127    last: Instant,
128}
129
130impl TokenBucket {
131    fn new(rate_per_sec: f64, burst: u32) -> Self {
132        let capacity = burst as f64;
133        TokenBucket {
134            rate_per_sec: rate_per_sec.max(0.0),
135            capacity,
136            // Start full so the initial burst is permitted immediately.
137            tokens: capacity,
138            last: Instant::now(),
139        }
140    }
141
142    /// Add tokens accrued since `last`, clamped to `capacity`.
143    fn refill(&mut self, now: Instant) {
144        let elapsed = now.saturating_duration_since(self.last).as_secs_f64();
145        if elapsed > 0.0 {
146            self.tokens = (self.tokens + elapsed * self.rate_per_sec).min(self.capacity);
147            self.last = now;
148        }
149    }
150
151    /// How long until at least one token is available, given `now`. `None`
152    /// means a token is available right now.
153    fn delay_until_token(&mut self, now: Instant) -> Option<Duration> {
154        self.refill(now);
155        if self.tokens >= 1.0 {
156            None
157        } else if self.rate_per_sec <= 0.0 {
158            // No refill will ever happen; treat as a very long wait so the
159            // element effectively stalls (degenerate config).
160            Some(Duration::from_secs(u64::MAX / 2))
161        } else {
162            let needed = 1.0 - self.tokens;
163            Some(Duration::from_secs_f64(needed / self.rate_per_sec))
164        }
165    }
166
167    /// Consume one token (caller must have ensured availability).
168    fn consume(&mut self) {
169        self.tokens -= 1.0;
170    }
171}
172
173/// Wait for a token in `bucket`, sleeping as needed, then consume it.
174async fn acquire(bucket: &mut TokenBucket) {
175    loop {
176        let now = Instant::now();
177        match bucket.delay_until_token(now) {
178            None => {
179                bucket.consume();
180                return;
181            }
182            Some(d) => {
183                tokio::time::sleep(d).await;
184                // Loop and re-check: sleep granularity / scheduling may mean we
185                // still need a touch more time.
186            }
187        }
188    }
189}
190
191/// `token_bucket(rate_per_sec, burst)` — a real-time leaky/token-bucket rate
192/// limiter.
193///
194/// The bucket refills at `rate_per_sec` tokens per second up to `burst` tokens
195/// of capacity and starts full, so an initial burst of up to `burst` elements
196/// passes immediately. Thereafter each element waits (via
197/// [`tokio::time::sleep`]) until a token is available, then consumes one.
198///
199/// **Property.** Over any window the number of emitted elements never exceeds
200/// `burst + rate_per_sec * window_seconds` (sustained rate plus the bucket
201/// capacity), modulo timer granularity.
202///
203/// This is a *wall-time* limiter — it intentionally uses real time rather than
204/// the logical [`Clock`](atomr_core::time::Clock); use
205/// [`clock_gated`](crate::clock_gated::clock_gated) for logical-time gating.
206pub fn token_bucket<T>(src: Source<T>, rate_per_sec: f64, burst: u32) -> Source<T>
207where
208    T: Send + 'static,
209{
210    struct State<T> {
211        inner: BoxStream<'static, T>,
212        bucket: TokenBucket,
213    }
214    let state = State { inner: src.into_boxed(), bucket: TokenBucket::new(rate_per_sec, burst) };
215    Source::unfold(state, |mut st| async move {
216        match st.inner.next().await {
217            None => None,
218            Some(item) => {
219                acquire(&mut st.bucket).await;
220                Some((item, st))
221            }
222        }
223    })
224}
225
226/// `token_bucket_keyed(key, rate_per_sec, burst)` — like [`token_bucket`] but
227/// maintains an independent bucket per key.
228///
229/// Each distinct key returned by `key` gets its own `TokenBucket` with the
230/// same `rate_per_sec` / `burst` parameters, so heavy traffic on one key never
231/// starves another. Buckets are created lazily on first sight of a key and held
232/// for the lifetime of the stream.
233pub fn token_bucket_keyed<T, K, F>(src: Source<T>, key: F, rate_per_sec: f64, burst: u32) -> Source<T>
234where
235    T: Send + 'static,
236    K: Eq + std::hash::Hash + Send + 'static,
237    F: Fn(&T) -> K + Send + 'static,
238{
239    struct State<T, K, F> {
240        inner: BoxStream<'static, T>,
241        buckets: HashMap<K, TokenBucket>,
242        key: F,
243        rate_per_sec: f64,
244        burst: u32,
245    }
246    let state = State { inner: src.into_boxed(), buckets: HashMap::new(), key, rate_per_sec, burst };
247    Source::unfold(state, |mut st| async move {
248        match st.inner.next().await {
249            None => None,
250            Some(item) => {
251                let k = (st.key)(&item);
252                let rate = st.rate_per_sec;
253                let burst = st.burst;
254                let bucket = st.buckets.entry(k).or_insert_with(|| TokenBucket::new(rate, burst));
255                acquire(bucket).await;
256                Some((item, st))
257            }
258        }
259    })
260}
261
/// An in-band back-off request: "wait `seconds` before sending anything else".
///
/// Intentionally protocol-agnostic. Turning HTTP `429 Too Many Requests` /
/// `Retry-After` headers into a `RetryAfter` belongs to the future
/// `atomr-streams-io` crate. [`respect_retry_after`] works only with the
/// already-extracted value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RetryAfter {
    /// Whole seconds the producer asked us to back off for.
    pub seconds: u64,
}
273
274/// `respect_retry_after` — honour back-off signals carried in-band.
275///
276/// The input is a stream of `Result<T, RetryAfter>`. Every element is passed
277/// through unchanged (both `Ok` and `Err` variants are forwarded — nothing is
278/// dropped), but whenever an element carries a [`RetryAfter`], emission pauses
279/// for the requested duration *after* forwarding it, so all subsequent elements
280/// are delayed rather than discarded.
281///
282/// Parsing of protocol-level rate-limit responses (e.g. HTTP 429) into
283/// `RetryAfter` is the responsibility of the future `atomr-streams-io` crate;
284/// this operator is the generic, transport-free building block.
285pub fn respect_retry_after<T>(src: Source<Result<T, RetryAfter>>) -> Source<Result<T, RetryAfter>>
286where
287    T: Send + 'static,
288{
289    struct State<T> {
290        inner: BoxStream<'static, Result<T, RetryAfter>>,
291        // A back-off to serve before pulling the next upstream element.
292        pending_backoff: Option<Duration>,
293    }
294    let state = State { inner: src.into_boxed(), pending_backoff: None };
295    Source::unfold(state, |mut st| async move {
296        // Honour a back-off requested by the previously emitted element before
297        // touching upstream again — subsequent elements are delayed, not dropped.
298        if let Some(d) = st.pending_backoff.take() {
299            tokio::time::sleep(d).await;
300        }
301        match st.inner.next().await {
302            None => None,
303            Some(item) => {
304                if let Err(ra) = &item {
305                    if ra.seconds > 0 {
306                        st.pending_backoff = Some(Duration::from_secs(ra.seconds));
307                    }
308                }
309                Some((item, st))
310            }
311        }
312    })
313}
314
// Unit tests for the rate operators. The token-bucket and retry-after tests
// run against real (wall-clock) time, so their bounds include timing slack.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    #[tokio::test]
    async fn conflate_passes_through_when_downstream_keeps_up() {
        let s = Source::from_iter(vec![1u32, 2, 3]);
        let out = Sink::collect(conflate(s, |t| t, |a, b| a + b)).await;
        // With an unbounded channel, every send succeeds immediately, so each
        // element emerges separately (as `seed(element)`) rather than folded.
        assert_eq!(out, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn conflate_seed_initializes_accumulator() {
        let s = Source::from_iter(vec![10u32]);
        let out = Sink::collect(conflate(s, |t| t * 2, |a, b| a + b)).await;
        // A single element goes through `seed` only: 10 * 2.
        assert_eq!(out, vec![20]);
    }

    #[tokio::test]
    async fn expand_emits_extrapolated_values_after_upstream_close() {
        let s = Source::from_iter(vec![5i32]);
        let out = Sink::collect(expand(s, |last| {
            let l = *last;
            (0..3).map(move |i| l + i + 1)
        }))
        .await;
        // upstream emits 5, then extrapolation iterator emits 6, 7, 8.
        assert_eq!(out, vec![5, 6, 7, 8]);
    }

    #[tokio::test]
    async fn expand_no_synthetics_when_iterator_empty() {
        let s = Source::from_iter(vec![1i32, 2, 3]);
        let out = Sink::collect(expand(s, |_last| std::iter::empty::<i32>())).await;
        // An empty extrapolation iterator leaves the upstream untouched.
        assert_eq!(out, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn token_bucket_respects_rate_plus_burst() {
        use std::time::Instant as StdInstant;
        // 50 tokens/sec, burst 5: in a ~120ms window we expect at most
        // 5 (burst) + 50 * 0.12 = ~11 emissions.
        let rate = 50.0;
        let burst = 5u32;
        let src = Source::from_iter(0u32..1000);
        let mut stream = token_bucket(src, rate, burst).into_boxed();

        let start = StdInstant::now();
        let window = Duration::from_millis(120);
        let mut count = 0u64;
        // Poll in 20ms slices so the loop re-checks the window even while an
        // element is waiting for a token.
        // NOTE(review): assumes dropping a timed-out `stream.next()` does not
        // lose the in-flight element (i.e. `Source::unfold` keeps the pending
        // future inside the stream) — confirm.
        while StdInstant::now().duration_since(start) < window {
            match tokio::time::timeout(Duration::from_millis(20), stream.next()).await {
                Ok(Some(_)) => count += 1,
                Ok(None) => break,
                Err(_) => {}
            }
        }
        let elapsed = StdInstant::now().duration_since(start).as_secs_f64();
        let allowed = burst as f64 + rate * elapsed;
        // Generous timing slack for scheduler jitter.
        assert!(count as f64 <= allowed + 4.0, "emitted {count} in {elapsed:.3}s, allowed ~{allowed:.1}",);
        // And the initial burst should at least have come out promptly.
        assert!(count >= burst as u64, "expected at least the burst {burst}, got {count}");
    }

    #[tokio::test]
    async fn token_bucket_keyed_limits_keys_independently() {
        use std::time::Instant as StdInstant;
        // Interleave two keys "a"/"b". Each key: 20/sec, burst 2.
        // The loop samples for ~60ms, in which a single key allows about
        // 2 + 20*0.06 ≈ 3 emissions, so ~6 across both keys. Emission is in
        // upstream order, so a throttled key also holds back the other
        // (head-of-line blocking); only per-key upper bounds and the initial
        // bursts are asserted.
        let items: Vec<&'static str> = (0..200).map(|i| if i % 2 == 0 { "a" } else { "b" }).collect();
        let src = Source::from_iter(items);
        let limited = token_bucket_keyed(src, |s: &&str| *s, 20.0, 2);
        let mut stream = limited.into_boxed();

        let start = StdInstant::now();
        let mut a = 0u64;
        let mut b = 0u64;
        while StdInstant::now().duration_since(start) < Duration::from_millis(60) {
            match tokio::time::timeout(Duration::from_millis(15), stream.next()).await {
                Ok(Some("a")) => a += 1,
                Ok(Some("b")) => b += 1,
                Ok(Some(_)) => {}
                Ok(None) => break,
                Err(_) => {}
            }
        }
        let elapsed = StdInstant::now().duration_since(start).as_secs_f64();
        // Per-key allowance: burst + rate * elapsed, plus 4.0 jitter slack.
        let allowed = 2.0 + 20.0 * elapsed + 4.0;
        // Each key independently limited — neither should blow past its own
        // allowance, and both should make progress (independent buckets).
        assert!(a as f64 <= allowed, "key a emitted {a}, allowed ~{allowed:.1}");
        assert!(b as f64 <= allowed, "key b emitted {b}, allowed ~{allowed:.1}");
        assert!(a >= 2 && b >= 2, "both keys should pass their burst: a={a} b={b}");
    }

    // Takes about one second of real time: it waits out a genuine 1s back-off.
    #[tokio::test]
    async fn respect_retry_after_pauses_then_continues() {
        use std::time::Instant as StdInstant;
        let src: Source<Result<u32, RetryAfter>> =
            Source::from_iter(vec![Ok(1u32), Err(RetryAfter { seconds: 1 }), Ok(2u32)]);
        let start = StdInstant::now();
        let out = Sink::collect(respect_retry_after(src)).await;
        let elapsed = start.elapsed();

        // Nothing dropped: all three elements forwarded in order.
        assert_eq!(out, vec![Ok(1), Err(RetryAfter { seconds: 1 }), Ok(2)]);
        // The 1s back-off after the Err delayed the trailing Ok.
        assert!(elapsed >= Duration::from_millis(950), "expected ~1s pause, got {elapsed:?}");
    }
}
429
// Property tests for `token_bucket`. They run in real time on a fresh
// current-thread runtime per case. The worst case (n = 59, rate = 20,
// burst = 1) takes roughly (59 - 1) / 20 ≈ 2.9s per case, hence the low
// `cases` count.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use std::time::Instant as StdInstant;

    proptest! {
        #![proptest_config(ProptestConfig { cases: 12, ..ProptestConfig::default() })]

        /// Over the elapsed emission time, the count emitted by `token_bucket`
        /// never exceeds `burst + rate * elapsed` (plus timing slack).
        #[test]
        fn token_bucket_count_bounded_by_burst_plus_rate(
            rate in 20.0f64..200.0,
            burst in 1u32..8,
            n in 20usize..60,
        ) {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_time()
                .build()
                .unwrap();
            // `prop_assert!` returns early with an `Err`, so the async block
            // yields a `Result` that `?` propagates back to proptest.
            rt.block_on(async move {
                let src = Source::from_iter(0..n as u32);
                let mut stream = token_bucket(src, rate, burst).into_boxed();
                let start = StdInstant::now();
                let mut count = 0u64;
                // The bound is checked after every single emission, not only
                // at the end, so any early over-emission is caught.
                while stream.next().await.is_some() {
                    count += 1;
                    let elapsed = start.elapsed().as_secs_f64();
                    let allowed = burst as f64 + rate * elapsed + 5.0;
                    prop_assert!(
                        count as f64 <= allowed,
                        "emitted {count} after {elapsed:.4}s, allowed ~{allowed:.2} (rate={rate}, burst={burst})",
                    );
                }
                Ok(())
            })?;
        }
    }
}