//! Stream combinators for debouncing, throttling, and coalescing item flow.

use futures::{Stream, StreamExt};
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

pin_project! {
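    /// Wraps a stream and groups its items into batches.
    ///
    /// Incoming items are buffered; a batch is emitted once
    /// `debounce_interval` has elapsed since the previous batch, when the
    /// stream first goes idle with items buffered, or as a final flush when
    /// the inner stream ends.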
    pub struct DebouncedStream<S>
    where
        S: Stream,
    {
        #[pin]
        inner: S,
        debounce_interval: Duration,
        buffer: Vec<S::Item>,
        last_emit: Option<Instant>,
        finished: bool,
    }
}

impl<S> DebouncedStream<S>
where
    S: Stream,
{
    pub fn new(inner: S, debounce: Duration) -> Self {
        Self {
            inner,
            debounce_interval: debounce,
            buffer: Vec::new(),
            last_emit: None,
            finished: false,
        }
    }

    /// Returns the configured debounce interval.
    pub fn interval(&self) -> Duration {
        self.debounce_interval
    }

    /// Returns how many items are currently buffered.
    pub fn buffer_len(&self) -> usize {
        self.buffer.len()
    }
}

impl<S> Stream for DebouncedStream<S>
where
    S: Stream + Unpin,
    S::Item: Clone,
{
    type Item = Vec<S::Item>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        if *this.finished && this.buffer.is_empty() {
            return Poll::Ready(None);
        }

        loop {
            match this.inner.poll_next_unpin(cx) {
                Poll::Ready(Some(item)) => {
                    this.buffer.push(item);

                    // Emit only once the interval has elapsed since the last
                    // batch; the very first batch waits for the stream to idle.
                    let should_emit = match this.last_emit {
                        Some(last) => last.elapsed() >= *this.debounce_interval,
                        None => false,
                    };

                    if should_emit && !this.buffer.is_empty() {
                        *this.last_emit = Some(Instant::now());
                        let batch = std::mem::take(this.buffer);
                        return Poll::Ready(Some(batch));
                    }
                }
                Poll::Ready(None) => {
                    // Inner stream ended: flush any remaining items, then finish.
                    *this.finished = true;
                    if !this.buffer.is_empty() {
                        let batch = std::mem::take(this.buffer);
                        return Poll::Ready(Some(batch));
                    }
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    let should_emit = match this.last_emit {
                        Some(last) => last.elapsed() >= *this.debounce_interval,
                        None => !this.buffer.is_empty(),
                    };

                    if should_emit && !this.buffer.is_empty() {
                        *this.last_emit = Some(Instant::now());
                        let batch = std::mem::take(this.buffer);
                        return Poll::Ready(Some(batch));
                    }

                    if !this.buffer.is_empty() {
                        // No timer is armed, so request another poll; otherwise
                        // buffered items could sit here until the inner stream
                        // happens to produce again.
                        cx.waker().wake_by_ref();
                    }
                    return Poll::Pending;
                }
            }
        }
    }
}

pin_project! {
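    /// Wraps a stream and enforces a minimum interval between emitted items.
    ///
    /// An item that arrives before `min_interval` has elapsed is parked in
    /// `pending_item` and released on a later poll.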
    pub struct ThrottledStream<S>
    where
        S: Stream,
    {
        #[pin]
        inner: S,
        min_interval: Duration,
        last_emit: Option<Instant>,
        pending_item: Option<S::Item>,
    }
}

impl<S> ThrottledStream<S>
where
    S: Stream,
{
    pub fn new(inner: S, min_interval: Duration) -> Self {
        Self {
            inner,
            min_interval,
            last_emit: None,
            pending_item: None,
        }
    }
}

impl<S> Stream for ThrottledStream<S>
where
    S: Stream + Unpin,
{
    type Item = S::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        // An item held back from an earlier poll takes priority over new input.
        if let Some(item) = this.pending_item.take() {
            let can_emit = match this.last_emit {
                Some(last) => last.elapsed() >= *this.min_interval,
                None => true,
            };

            if can_emit {
                *this.last_emit = Some(Instant::now());
                return Poll::Ready(Some(item));
            } else {
                // Still inside the interval: put the item back and ask to be
                // polled again rather than arming a timer.
                *this.pending_item = Some(item);
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }

        match this.inner.poll_next_unpin(cx) {
            Poll::Ready(Some(item)) => {
                let can_emit = match this.last_emit {
                    Some(last) => last.elapsed() >= *this.min_interval,
                    None => true,
                };

                if can_emit {
                    *this.last_emit = Some(Instant::now());
                    Poll::Ready(Some(item))
                } else {
                    // Too soon: park the item and retry on the next poll.
                    *this.pending_item = Some(item);
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

pin_project! {
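    /// Wraps a stream of `String` fragments and re-emits them as larger chunks.
    ///
    /// Fragments accumulate in a buffer; a chunk is emitted once the buffer
    /// reaches `max_chunk_size` bytes, when the inner stream is idle and at
    /// least `min_chunk_size` bytes are buffered, or when the stream ends with
    /// leftover text.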
    pub struct CoalescedTextStream<S> {
        #[pin]
        inner: S,
        buffer: String,
        min_chunk_size: usize,
        max_chunk_size: usize,
        finished: bool,
    }
}

impl<S> CoalescedTextStream<S>
where
    S: Stream<Item = String>,
{
    pub fn new(inner: S, min_chunk_size: usize, max_chunk_size: usize) -> Self {
        Self {
            inner,
            buffer: String::new(),
            min_chunk_size,
            max_chunk_size,
            finished: false,
        }
    }
}

impl<S> Stream for CoalescedTextStream<S>
where
    S: Stream<Item = String> + Unpin,
{
    type Item = String;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        if *this.finished && this.buffer.is_empty() {
            return Poll::Ready(None);
        }

        loop {
            // If the buffer has already reached the maximum chunk size, flush
            // it before polling for more input.
            if this.buffer.len() >= *this.max_chunk_size {
                let chunk = std::mem::take(this.buffer);
                return Poll::Ready(Some(chunk));
            }

            match this.inner.poll_next_unpin(cx) {
                Poll::Ready(Some(text)) => {
                    this.buffer.push_str(&text);

                    if this.buffer.len() >= *this.max_chunk_size {
                        let chunk = std::mem::take(this.buffer);
                        return Poll::Ready(Some(chunk));
                    }
                }
                Poll::Ready(None) => {
                    *this.finished = true;
                    if !this.buffer.is_empty() {
                        let chunk = std::mem::take(this.buffer);
                        return Poll::Ready(Some(chunk));
                    }
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    // Inner stream is idle: emit early once the minimum chunk
                    // size is reached, but never emit an empty chunk.
                    if !this.buffer.is_empty() && this.buffer.len() >= *this.min_chunk_size {
                        let chunk = std::mem::take(this.buffer);
                        return Poll::Ready(Some(chunk));
                    }
                    return Poll::Pending;
                }
            }
        }
    }
}

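/// Extension trait that adds the `debounce` and `throttle` adapters to any
/// stream via a blanket implementation.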
pub trait StreamDebounceExt: Stream {
    /// Batches items into `Vec`s using [`DebouncedStream`].
    fn debounce(self, duration: Duration) -> DebouncedStream<Self>
    where
        Self: Sized,
    {
        DebouncedStream::new(self, duration)
    }

    /// Spaces items out using [`ThrottledStream`].
    fn throttle(self, min_interval: Duration) -> ThrottledStream<Self>
    where
        Self: Sized,
    {
        ThrottledStream::new(self, min_interval)
    }
}

impl<S: Stream> StreamDebounceExt for S {}

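/// Extension trait that adds the `coalesce` adapter to streams of `String`
/// fragments via a blanket implementation.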
pub trait TextStreamExt: Stream<Item = String> {
    /// Regroups text fragments into larger chunks using [`CoalescedTextStream`].
    fn coalesce(self, min_size: usize, max_size: usize) -> CoalescedTextStream<Self>
    where
        Self: Sized,
    {
        CoalescedTextStream::new(self, min_size, max_size)
    }
}

impl<S: Stream<Item = String>> TextStreamExt for S {}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_debounced_stream() {
        let items = vec![1, 2, 3, 4, 5];
        let inner = stream::iter(items);
        let debounced = DebouncedStream::new(inner, Duration::from_millis(10));

        let batches: Vec<Vec<i32>> = debounced.collect().await;

        // Batch boundaries depend on timing, but no item may be lost.
        let total: i32 = batches.iter().flat_map(|b| b.iter()).sum();
        assert_eq!(total, 15);
    }

    #[tokio::test]
    async fn test_throttled_stream() {
        let items = vec![1, 2, 3];
        let inner = stream::iter(items);
        let throttled = ThrottledStream::new(inner, Duration::from_millis(1));

        let results: Vec<i32> = throttled.collect().await;
        assert_eq!(results, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_coalesced_text_stream() {
        let items = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];
        let inner = stream::iter(items);
        let coalesced = CoalescedTextStream::new(inner, 2, 10);

        let results: Vec<String> = coalesced.collect().await;

        // Chunk boundaries may vary, but the total amount of text is preserved.
        let total_len: usize = results.iter().map(|s| s.len()).sum();
        assert_eq!(total_len, 4);
    }

    #[tokio::test]
    async fn test_extension_traits() {
        let items = vec![1, 2, 3];
        let inner = stream::iter(items);

        let results: Vec<Vec<i32>> = inner.debounce(Duration::from_millis(1)).collect().await;
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_text_extension() {
        let items = vec!["hello".to_string(), " ".to_string(), "world".to_string()];
        let inner = stream::iter(items);

        let results: Vec<String> = inner.coalesce(5, 100).collect().await;
        assert!(!results.is_empty());
    }
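
    // Two additional sketches, not part of the original suite: they exercise
    // the extension traits a little further and assume the same tokio test
    // runtime and `futures::stream` helpers used above.
    #[tokio::test]
    async fn test_throttle_extension_preserves_order() {
        let inner = stream::iter(vec![10, 20, 30]);

        // Throttling delays items but must forward all of them, in order.
        let results: Vec<i32> = inner.throttle(Duration::from_millis(1)).collect().await;
        assert_eq!(results, vec![10, 20, 30]);
    }

    #[tokio::test]
    async fn test_coalesce_preserves_content() {
        let items = vec!["aaaa".to_string(), "bb".to_string(), "c".to_string()];
        let inner = stream::iter(items);

        // Coalescing may move chunk boundaries, but the concatenated output
        // must equal the concatenated input.
        let chunks: Vec<String> = inner.coalesce(1, 3).collect().await;
        assert_eq!(chunks.concat(), "aaaabbc");
    }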
}