1use std::time::Duration;
7
8use futures::stream::{self, BoxStream, StreamExt};
9
10use crate::source::Source;
11
12pub fn grouped_within<T: Send + 'static>(src: Source<T>, n: usize, duration: Duration) -> Source<Vec<T>> {
16 assert!(n >= 1, "grouped_within: n must be >= 1");
17
18 struct State<T: Send + 'static> {
19 inner: BoxStream<'static, T>,
20 buf: Vec<T>,
21 deadline: Option<tokio::time::Instant>,
22 n: usize,
23 duration: Duration,
24 upstream_done: bool,
25 }
26
27 let state =
28 State { inner: src.into_boxed(), buf: Vec::new(), deadline: None, n, duration, upstream_done: false };
29
30 let stream = stream::unfold(state, |mut s| async move {
31 loop {
32 if s.upstream_done {
33 if s.buf.is_empty() {
34 return None;
35 }
36 let chunk = std::mem::take(&mut s.buf);
37 return Some((chunk, s));
38 }
39 let next_item = match s.deadline {
41 Some(d) => tokio::select! {
42 biased;
43 _ = tokio::time::sleep_until(d) => DeadlineOrItem::Deadline,
44 item = s.inner.next() => DeadlineOrItem::Item(item),
45 },
46 None => DeadlineOrItem::Item(s.inner.next().await),
47 };
48 match next_item {
49 DeadlineOrItem::Deadline => {
50 if !s.buf.is_empty() {
51 let chunk = std::mem::take(&mut s.buf);
52 s.deadline = None;
53 return Some((chunk, s));
54 }
55 s.deadline = None;
56 }
57 DeadlineOrItem::Item(None) => {
58 s.upstream_done = true;
59 if !s.buf.is_empty() {
60 let chunk = std::mem::take(&mut s.buf);
61 return Some((chunk, s));
62 }
63 return None;
64 }
65 DeadlineOrItem::Item(Some(item)) => {
66 if s.buf.is_empty() {
67 s.deadline = Some(tokio::time::Instant::now() + s.duration);
68 }
69 s.buf.push(item);
70 if s.buf.len() >= s.n {
71 let chunk = std::mem::take(&mut s.buf);
72 s.deadline = None;
73 return Some((chunk, s));
74 }
75 }
76 }
77 }
78 });
79
80 Source { inner: stream.boxed() }
81}
82
/// Outcome of one wait step inside `grouped_within`.
enum DeadlineOrItem<T> {
    /// The pending chunk's flush deadline elapsed before upstream produced anything.
    Deadline,
    /// Upstream yielded: `Some(element)`, or `None` once it has completed.
    Item(Option<T>),
}
87
88pub fn keep_alive<T, F>(src: Source<T>, idle: Duration, mut gen: F) -> Source<T>
93where
94 T: Send + 'static,
95 F: FnMut() -> T + Send + 'static,
96{
97 let inner = src.into_boxed();
98 let stream = stream::unfold(inner, move |mut inner| {
99 let kick = gen();
100 async move {
101 match tokio::time::timeout(idle, inner.next()).await {
102 Ok(Some(item)) => Some((item, inner)),
103 Ok(None) => None,
104 Err(_) => Some((kick, inner)),
105 }
106 }
107 });
108 Source { inner: stream.boxed() }
109}
110
111pub fn initial_delay<T: Send + 'static>(src: Source<T>, delay: Duration) -> Source<T> {
115 let inner = src.into_boxed();
116 let stream = stream::unfold((inner, Some(delay)), |(mut inner, pending_delay)| async move {
117 if let Some(d) = pending_delay {
118 tokio::time::sleep(d).await;
119 }
120 let next = inner.next().await?;
121 Some((next, (inner, None)))
122 });
123 Source { inner: stream.boxed() }
124}
125
126pub fn idle_timeout<T: Send + 'static>(src: Source<T>, idle: Duration) -> Source<T> {
130 let inner = src.into_boxed();
131 let stream = stream::unfold(inner, move |mut inner| async move {
132 match tokio::time::timeout(idle, inner.next()).await {
133 Ok(Some(item)) => Some((item, inner)),
134 Ok(None) => None,
135 Err(_) => None, }
137 });
138 Source { inner: stream.boxed() }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    // NOTE(review): the timing tests below use real wall-clock sleeps. Each
    // producer pause is chosen well above the operator's timeout so the
    // expected interleaving holds on a loaded machine.
    // presumably `Source::from_receiver` completes once every sender has been
    // dropped (the spawned task drops `tx` when it finishes); otherwise
    // `Sink::collect` would never return. Verify against `Source`.

    // A finite source with a long window: chunks are cut purely by size, and
    // the trailing partial chunk is flushed when upstream completes.
    #[tokio::test]
    async fn grouped_within_packs_full_chunks() {
        let s = Source::from_iter(vec![1, 2, 3, 4, 5]);
        let out = Sink::collect(grouped_within(s, 2, Duration::from_secs(60))).await;
        assert_eq!(out, vec![vec![1, 2], vec![3, 4], vec![5]]);
    }

    // 1 arrives alone and the producer then pauses 60ms, longer than the 20ms
    // window, so 1 must be flushed on its own before 2 arrives. Which chunk 2
    // ends up in is deliberately left loose.
    #[tokio::test]
    async fn grouped_within_flushes_on_timeout() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<i32>();
        tokio::spawn(async move {
            tx.send(1).unwrap();
            tokio::time::sleep(Duration::from_millis(60)).await;
            tx.send(2).unwrap();
        });
        let s = Source::from_receiver(rx);
        let out = Sink::collect(grouped_within(s, 10, Duration::from_millis(20))).await;
        assert!(out.len() >= 2);
        assert_eq!(out[0], vec![1]);
        assert!(out.iter().any(|c| c.contains(&2)));
    }

    // The 40ms gap between 1 and 2 exceeds the 15ms idle window, so at least
    // one keep-alive (99) is injected between them. The exact count depends on
    // timer scheduling, so only presence is checked.
    #[tokio::test]
    async fn keep_alive_injects_when_idle() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<i32>();
        tokio::spawn(async move {
            tx.send(1).unwrap();
            tokio::time::sleep(Duration::from_millis(40)).await;
            tx.send(2).unwrap();
        });
        let s = Source::from_receiver(rx);
        let out = Sink::collect(keep_alive(s, Duration::from_millis(15), || 99)).await;
        assert_eq!(out[0], 1);
        assert!(out.contains(&99));
        assert!(out.contains(&2));
    }

    // Measures elapsed wall-clock time with a few milliseconds of slack
    // (35ms checked against a 40ms delay) for timer granularity.
    #[tokio::test]
    async fn initial_delay_blocks_first_element() {
        let s = Source::from_iter(vec![1, 2, 3]);
        let start = std::time::Instant::now();
        let out = Sink::collect(initial_delay(s, Duration::from_millis(40))).await;
        assert!(start.elapsed() >= Duration::from_millis(35), "initial_delay did not delay");
        assert_eq!(out, vec![1, 2, 3]);
    }

    // 1 and 2 are sent immediately. The 50ms pause then exceeds the 20ms idle
    // window, so the stream ends before 3 is sent. That send may fail because
    // the receiver can already be gone, hence the ignored result.
    #[tokio::test]
    async fn idle_timeout_terminates_when_silent() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<i32>();
        tokio::spawn(async move {
            tx.send(1).unwrap();
            tx.send(2).unwrap();
            tokio::time::sleep(Duration::from_millis(50)).await;
            let _ = tx.send(3);
        });
        let s = Source::from_receiver(rx);
        let out = Sink::collect(idle_timeout(s, Duration::from_millis(20))).await;
        assert_eq!(out, vec![1, 2]);
    }
}