1use std::collections::HashMap;
11use std::time::Duration;
12
13use futures::stream::{BoxStream, StreamExt};
14use tokio::sync::mpsc;
15use tokio::time::Instant;
16
17use crate::source::Source;
18
19pub fn conflate<T, U, S, F>(src: Source<T>, mut seed: S, mut fold: F) -> Source<U>
28where
29 T: Send + 'static,
30 U: Send + 'static,
31 S: FnMut(T) -> U + Send + 'static,
32 F: FnMut(U, T) -> U + Send + 'static,
33{
34 let (tx, rx) = mpsc::unbounded_channel::<U>();
35 let mut inner = src.into_boxed();
36 tokio::spawn(async move {
37 let mut acc: Option<U> = None;
38 loop {
39 match inner.next().await {
40 Some(item) => {
41 acc = Some(match acc.take() {
42 None => seed(item),
43 Some(prev) => fold(prev, item),
44 });
45 if let Some(a) = acc.take() {
49 if tx.send(a).is_err() {
50 return;
51 }
52 }
53 }
54 None => {
55 if let Some(a) = acc.take() {
56 let _ = tx.send(a);
57 }
58 return;
59 }
60 }
61 }
62 });
63 Source::from_receiver(rx)
64}
65
66pub fn expand<T, F, I>(src: Source<T>, mut extrapolate: F) -> Source<T>
78where
79 T: Clone + Send + 'static,
80 F: FnMut(&T) -> I + Send + 'static,
81 I: Iterator<Item = T> + Send + 'static,
82{
83 let (tx, rx) = mpsc::unbounded_channel::<T>();
84 let mut inner = src.into_boxed();
85 tokio::spawn(async move {
86 let mut last: Option<T> = None;
87 loop {
88 match inner.next().await {
89 Some(item) => {
90 if tx.send(item.clone()).is_err() {
91 return;
92 }
93 last = Some(item);
94 }
95 None => {
96 if let Some(l) = last {
99 for synth in extrapolate(&l) {
100 if tx.send(synth).is_err() {
101 return;
102 }
103 }
104 }
105 return;
106 }
107 }
108 }
109 });
110 Source::from_receiver(rx)
111}
112
/// Token-bucket limiter state.
///
/// Tokens accrue continuously at `rate_per_sec`, up to `capacity`. Each call
/// to `consume` spends one token.
struct TokenBucket {
    /// Refill rate in tokens per second (never negative).
    rate_per_sec: f64,
    /// Maximum number of tokens the bucket can hold (at least 1).
    capacity: f64,
    /// Tokens currently available.
    tokens: f64,
    /// Instant of the last refill that added tokens.
    last: Instant,
}

impl TokenBucket {
    /// Wait reported when a token will never arrive (rate 0) or when the
    /// computed wait would not fit in a `Duration`.
    const NEVER: Duration = Duration::from_secs(u64::MAX / 2);

    /// Creates a full bucket holding `burst` tokens.
    ///
    /// A capacity of 0 could never hold the single token an element needs,
    /// and the limiter would stall forever. `burst` is therefore raised to at
    /// least 1. Negative (and NaN) rates are treated as 0, meaning no refill.
    fn new(rate_per_sec: f64, burst: u32) -> Self {
        let capacity = f64::from(burst.max(1));
        TokenBucket {
            rate_per_sec: rate_per_sec.max(0.0),
            capacity,
            tokens: capacity,
            last: Instant::now(),
        }
    }

    /// Adds the tokens accrued since `last`, capped at `capacity`.
    fn refill(&mut self, now: Instant) {
        let elapsed = now.saturating_duration_since(self.last).as_secs_f64();
        if elapsed > 0.0 {
            self.tokens = (self.tokens + elapsed * self.rate_per_sec).min(self.capacity);
            self.last = now;
        }
    }

    /// Refills the bucket, then reports how long to wait for a token.
    ///
    /// Returns `None` if a token is available right now. Otherwise it returns
    /// the estimated wait, which is `NEVER` when no refill can produce one.
    fn delay_until_token(&mut self, now: Instant) -> Option<Duration> {
        self.refill(now);
        if self.tokens >= 1.0 {
            return None;
        }
        if self.rate_per_sec <= 0.0 {
            // No refill: the token will never arrive.
            return Some(Self::NEVER);
        }
        let secs = (1.0 - self.tokens) / self.rate_per_sec;
        // Very small rates can yield waits larger than `Duration` can hold,
        // which would make `from_secs_f64` panic. Clamp those to NEVER.
        if secs.is_finite() && secs < Self::NEVER.as_secs_f64() {
            Some(Duration::from_secs_f64(secs))
        } else {
            Some(Self::NEVER)
        }
    }

    /// Spends one token. Call only after `delay_until_token` returned `None`.
    fn consume(&mut self) {
        self.tokens -= 1.0;
    }
}
172
173async fn acquire(bucket: &mut TokenBucket) {
175 loop {
176 let now = Instant::now();
177 match bucket.delay_until_token(now) {
178 None => {
179 bucket.consume();
180 return;
181 }
182 Some(d) => {
183 tokio::time::sleep(d).await;
184 }
187 }
188 }
189}
190
191pub fn token_bucket<T>(src: Source<T>, rate_per_sec: f64, burst: u32) -> Source<T>
207where
208 T: Send + 'static,
209{
210 struct State<T> {
211 inner: BoxStream<'static, T>,
212 bucket: TokenBucket,
213 }
214 let state = State { inner: src.into_boxed(), bucket: TokenBucket::new(rate_per_sec, burst) };
215 Source::unfold(state, |mut st| async move {
216 match st.inner.next().await {
217 None => None,
218 Some(item) => {
219 acquire(&mut st.bucket).await;
220 Some((item, st))
221 }
222 }
223 })
224}
225
226pub fn token_bucket_keyed<T, K, F>(src: Source<T>, key: F, rate_per_sec: f64, burst: u32) -> Source<T>
234where
235 T: Send + 'static,
236 K: Eq + std::hash::Hash + Send + 'static,
237 F: Fn(&T) -> K + Send + 'static,
238{
239 struct State<T, K, F> {
240 inner: BoxStream<'static, T>,
241 buckets: HashMap<K, TokenBucket>,
242 key: F,
243 rate_per_sec: f64,
244 burst: u32,
245 }
246 let state = State { inner: src.into_boxed(), buckets: HashMap::new(), key, rate_per_sec, burst };
247 Source::unfold(state, |mut st| async move {
248 match st.inner.next().await {
249 None => None,
250 Some(item) => {
251 let k = (st.key)(&item);
252 let rate = st.rate_per_sec;
253 let burst = st.burst;
254 let bucket = st.buckets.entry(k).or_insert_with(|| TokenBucket::new(rate, burst));
255 acquire(bucket).await;
256 Some((item, st))
257 }
258 }
259 })
260}
261
/// Back-off hint carried in the error channel of a stream.
///
/// `respect_retry_after` reads it and waits this many whole seconds before
/// pulling the next element.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RetryAfter {
    /// Whole seconds to wait before the next pull; `0` requests no pause.
    pub seconds: u64,
}
273
274pub fn respect_retry_after<T>(src: Source<Result<T, RetryAfter>>) -> Source<Result<T, RetryAfter>>
286where
287 T: Send + 'static,
288{
289 struct State<T> {
290 inner: BoxStream<'static, Result<T, RetryAfter>>,
291 pending_backoff: Option<Duration>,
293 }
294 let state = State { inner: src.into_boxed(), pending_backoff: None };
295 Source::unfold(state, |mut st| async move {
296 if let Some(d) = st.pending_backoff.take() {
299 tokio::time::sleep(d).await;
300 }
301 match st.inner.next().await {
302 None => None,
303 Some(item) => {
304 if let Err(ra) = &item {
305 if ra.seconds > 0 {
306 st.pending_backoff = Some(Duration::from_secs(ra.seconds));
307 }
308 }
309 Some((item, st))
310 }
311 }
312 })
313}
314
// Unit tests for the operators above. Several of them measure wall-clock time,
// so their bounds include explicit slack.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    // `conflate` sends each accumulator as soon as it is seeded (its output
    // channel is unbounded), so an identity seed yields the input unchanged.
    #[tokio::test]
    async fn conflate_passes_through_when_downstream_keeps_up() {
        let s = Source::from_iter(vec![1u32, 2, 3]);
        let out = Sink::collect(conflate(s, |t| t, |a, b| a + b)).await;
        assert_eq!(out, vec![1, 2, 3]);
    }

    // A single element goes through `seed` only (10 * 2).
    #[tokio::test]
    async fn conflate_seed_initializes_accumulator() {
        let s = Source::from_iter(vec![10u32]);
        let out = Sink::collect(conflate(s, |t| t * 2, |a, b| a + b)).await;
        assert_eq!(out, vec![20]);
    }

    // The real element is emitted first, then three values extrapolated from it.
    #[tokio::test]
    async fn expand_emits_extrapolated_values_after_upstream_close() {
        let s = Source::from_iter(vec![5i32]);
        let out = Sink::collect(expand(s, |last| {
            let l = *last;
            (0..3).map(move |i| l + i + 1)
        }))
        .await;
        assert_eq!(out, vec![5, 6, 7, 8]);
    }

    // An empty extrapolation leaves the upstream elements untouched.
    #[tokio::test]
    async fn expand_no_synthetics_when_iterator_empty() {
        let s = Source::from_iter(vec![1i32, 2, 3]);
        let out = Sink::collect(expand(s, |_last| std::iter::empty::<i32>())).await;
        assert_eq!(out, vec![1, 2, 3]);
    }

    // Counts elements released within a ~120 ms window and checks that the
    // count lies between `burst` and `burst + rate * elapsed` (+4 slack,
    // presumably to absorb timer/scheduler jitter).
    #[tokio::test]
    async fn token_bucket_respects_rate_plus_burst() {
        use std::time::Instant as StdInstant;
        let rate = 50.0;
        let burst = 5u32;
        let src = Source::from_iter(0u32..1000);
        let mut stream = token_bucket(src, rate, burst).into_boxed();

        let start = StdInstant::now();
        let window = Duration::from_millis(120);
        let mut count = 0u64;
        while StdInstant::now().duration_since(start) < window {
            // Short per-poll timeouts keep re-checking the window while the
            // limiter is waiting for a token.
            match tokio::time::timeout(Duration::from_millis(20), stream.next()).await {
                Ok(Some(_)) => count += 1,
                Ok(None) => break,
                Err(_) => {}
            }
        }
        let elapsed = StdInstant::now().duration_since(start).as_secs_f64();
        let allowed = burst as f64 + rate * elapsed;
        assert!(count as f64 <= allowed + 4.0, "emitted {count} in {elapsed:.3}s, allowed ~{allowed:.1}",);
        assert!(count >= burst as u64, "expected at least the burst {burst}, got {count}");
    }

    // Keys alternate "a"/"b". Each key has its own bucket (burst 2, 20/s), so
    // each must pass at least its burst and stay under its own rate bound.
    #[tokio::test]
    async fn token_bucket_keyed_limits_keys_independently() {
        use std::time::Instant as StdInstant;
        let items: Vec<&'static str> = (0..200).map(|i| if i % 2 == 0 { "a" } else { "b" }).collect();
        let src = Source::from_iter(items);
        let limited = token_bucket_keyed(src, |s: &&str| *s, 20.0, 2);
        let mut stream = limited.into_boxed();

        let start = StdInstant::now();
        let mut a = 0u64;
        let mut b = 0u64;
        while StdInstant::now().duration_since(start) < Duration::from_millis(60) {
            match tokio::time::timeout(Duration::from_millis(15), stream.next()).await {
                Ok(Some("a")) => a += 1,
                Ok(Some("b")) => b += 1,
                Ok(Some(_)) => {}
                Ok(None) => break,
                Err(_) => {}
            }
        }
        let elapsed = StdInstant::now().duration_since(start).as_secs_f64();
        // burst + rate * elapsed, plus 4 elements of slack.
        let allowed = 2.0 + 20.0 * elapsed + 4.0;
        assert!(a as f64 <= allowed, "key a emitted {a}, allowed ~{allowed:.1}");
        assert!(b as f64 <= allowed, "key b emitted {b}, allowed ~{allowed:.1}");
        assert!(a >= 2 && b >= 2, "both keys should pass their burst: a={a} b={b}");
    }

    // The 1 s back-off requested by the error delays the pull of `Ok(2)`; all
    // elements, including the error, are still emitted in order. The 950 ms
    // lower bound leaves a little room for timer imprecision.
    #[tokio::test]
    async fn respect_retry_after_pauses_then_continues() {
        use std::time::Instant as StdInstant;
        let src: Source<Result<u32, RetryAfter>> =
            Source::from_iter(vec![Ok(1u32), Err(RetryAfter { seconds: 1 }), Ok(2u32)]);
        let start = StdInstant::now();
        let out = Sink::collect(respect_retry_after(src)).await;
        let elapsed = start.elapsed();

        assert_eq!(out, vec![Ok(1), Err(RetryAfter { seconds: 1 }), Ok(2)]);
        assert!(elapsed >= Duration::from_millis(950), "expected ~1s pause, got {elapsed:?}");
    }
}
428}
429
// Property tests: randomised rate/burst/length combinations for `token_bucket`.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use std::time::Instant as StdInstant;

    proptest! {
        // Few cases: each case sleeps in real time while the limiter throttles.
        #![proptest_config(ProptestConfig { cases: 12, ..ProptestConfig::default() })]

        // After every emitted element, the running count must not exceed
        // `burst + rate * elapsed` (+5 slack).
        #[test]
        fn token_bucket_count_bounded_by_burst_plus_rate(
            rate in 20.0f64..200.0,
            burst in 1u32..8,
            n in 20usize..60,
        ) {
            // The generated test fn is synchronous, so each case builds its own
            // current-thread runtime with the timer enabled.
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_time()
                .build()
                .unwrap();
            rt.block_on(async move {
                let src = Source::from_iter(0..n as u32);
                let mut stream = token_bucket(src, rate, burst).into_boxed();
                let start = StdInstant::now();
                let mut count = 0u64;
                while stream.next().await.is_some() {
                    count += 1;
                    let elapsed = start.elapsed().as_secs_f64();
                    let allowed = burst as f64 + rate * elapsed + 5.0;
                    prop_assert!(
                        count as f64 <= allowed,
                        "emitted {count} after {elapsed:.4}s, allowed ~{allowed:.2} (rate={rate}, burst={burst})",
                    );
                }
                Ok(())
            })?;
        }
    }
}