1use crate::{common::*, shared_stream::Shared, state_stream::StateStream};
2use futures::stream::Zip;
3
4pub type WithState<S, B> = Zip<S, StateStream<B>>;
6
7pub trait StreamExt
9where
10 Self: Stream,
11{
12 fn shared(self) -> Shared<Self>;
37
38 fn with_state<B>(self, init: B) -> WithState<Self, B>
46 where
47 Self: Sized;
48
49 fn wait_until<Fut>(self, fut: Fut) -> WaitUntil<Self, Fut>
85 where
86 Fut: Future<Output = bool>;
87
88 fn reduce<F, Fut>(self, f: F) -> Reduce<Self, F, Fut>;
94
95 fn batching<T, F, Fut>(self, f: F) -> Batching<Self, T, F, Fut>
129 where
130 Self: Sized,
131 F: 'static + Send + FnMut(Self) -> Fut,
132 Fut: 'static + Future<Output = Option<(T, Self)>> + Send,
133 T: 'static + Send;
134
135 fn stateful_batching<T, B, F, Fut>(self, init: B, f: F) -> StatefulBatching<Self, B, T, F, Fut>
173 where
174 Self: Sized + Stream,
175 F: FnMut(B, Self) -> Fut,
176 Fut: Future<Output = Option<(T, B, Self)>>;
177
178 fn stateful_then<T, B, F, Fut>(self, init: B, f: F) -> StatefulThen<Self, B, T, F, Fut>
180 where
181 Self: Stream,
182 F: FnMut(B, Self::Item) -> Fut,
183 Fut: Future<Output = Option<(B, T)>>;
184
185 fn stateful_map<T, B, F>(self, init: B, f: F) -> StatefulMap<Self, B, T, F>
187 where
188 Self: Stream,
189 F: FnMut(B, Self::Item) -> Option<(B, T)>;
190}
191
192impl<S> StreamExt for S
193where
194 S: Stream,
195{
196 fn shared(self) -> Shared<Self> {
197 Shared::new(self)
198 }
199
200 fn with_state<B>(self, init: B) -> WithState<Self, B>
201 where
202 Self: Sized,
203 {
204 self.zip(StateStream::new(init))
205 }
206
207 fn reduce<F, Fut>(self, f: F) -> Reduce<Self, F, Fut> {
208 Reduce {
209 fold: None,
210 f,
211 future: None,
212 stream: self,
213 }
214 }
215
216 fn wait_until<Fut>(self, fut: Fut) -> WaitUntil<Self, Fut>
217 where
218 Fut: Future<Output = bool>,
219 {
220 WaitUntil::new(self, fut)
221 }
222
223 fn batching<T, F, Fut>(self, f: F) -> Batching<Self, T, F, Fut> {
224 Batching {
225 f,
226 future: None,
227 stream: Some(self),
228 _phantom: PhantomData,
229 }
230 }
231
232 fn stateful_batching<T, B, F, Fut>(self, init: B, f: F) -> StatefulBatching<Self, B, T, F, Fut>
233 where
234 Self: Stream,
235 F: FnMut(B, Self) -> Fut,
236 Fut: Future<Output = Option<(T, B, Self)>>,
237 {
238 StatefulBatching {
239 state: Some((init, self)),
240 future: None,
241 f,
242 _phantom: PhantomData,
243 }
244 }
245
246 fn stateful_then<T, B, F, Fut>(self, init: B, f: F) -> StatefulThen<Self, B, T, F, Fut>
247 where
248 Self: Stream,
249 F: FnMut(B, Self::Item) -> Fut,
250 Fut: Future<Output = Option<(B, T)>>,
251 {
252 StatefulThen {
253 stream: self,
254 future: None,
255 state: Some(init),
256 f,
257 _phantom: PhantomData,
258 }
259 }
260
261 fn stateful_map<T, B, F>(self, init: B, f: F) -> StatefulMap<Self, B, T, F>
262 where
263 Self: Stream,
264 F: FnMut(B, Self::Item) -> Option<(B, T)>,
265 {
266 StatefulMap {
267 stream: self,
268 state: Some(init),
269 f,
270 _phantom: PhantomData,
271 }
272 }
273}
274
275pub use batching::*;
276mod batching {
277 use super::*;
278
279 #[pin_project]
281 pub struct Batching<St, T, F, Fut> {
282 pub(super) f: F,
283 #[pin]
284 pub(super) future: Option<Fut>,
285 pub(super) _phantom: PhantomData<T>,
286 pub(super) stream: Option<St>,
287 }
288
289 impl<St, T, F, Fut> Stream for Batching<St, T, F, Fut>
290 where
291 F: FnMut(St) -> Fut,
292 Fut: Future<Output = Option<(T, St)>>,
293 {
294 type Item = T;
295
296 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
297 let mut this = self.project();
298
299 if let Some(stream) = this.stream.take() {
300 let new_future = (this.f)(stream);
301 this.future.set(Some(new_future));
302 }
303
304 Ready({
305 if let Some(mut future) = this.future.as_pin_mut() {
306 match ready!(future.poll_unpin(cx)) {
307 Some((item, stream)) => {
308 let new_future = (this.f)(stream);
309 future.set(new_future);
310 Some(item)
311 }
312 None => None,
313 }
314 } else {
315 None
316 }
317 })
318 }
319 }
320}
321
322pub use stateful_then::*;
323mod stateful_then {
324 use super::*;
325
326 #[pin_project]
328 pub struct StatefulThen<St, B, T, F, Fut>
329 where
330 St: ?Sized,
331 {
332 #[pin]
333 pub(super) future: Option<Fut>,
334 pub(super) state: Option<B>,
335 pub(super) f: F,
336 pub(super) _phantom: PhantomData<T>,
337 #[pin]
338 pub(super) stream: St,
339 }
340
341 impl<St, B, T, F, Fut> Stream for StatefulThen<St, B, T, F, Fut>
342 where
343 St: Stream,
344 F: FnMut(B, St::Item) -> Fut,
345 Fut: Future<Output = Option<(B, T)>>,
346 {
347 type Item = T;
348
349 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
350 let mut this = self.project();
351
352 Poll::Ready(loop {
353 if let Some(fut) = this.future.as_mut().as_pin_mut() {
354 let output = ready!(fut.poll(cx));
355 this.future.set(None);
356
357 if let Some((state, item)) = output {
358 *this.state = Some(state);
359 break Some(item);
360 } else {
361 break None;
362 }
363 } else if let Some(state) = this.state.take() {
364 match this.stream.as_mut().poll_next(cx) {
365 Ready(Some(item)) => {
366 this.future.set(Some((this.f)(state, item)));
367 }
368 Ready(None) => break None,
369 Pending => {
370 *this.state = Some(state);
371 return Pending;
372 }
373 }
374 } else {
375 break None;
376 }
377 })
378 }
379 }
380}
381
382pub use stateful_map::*;
383mod stateful_map {
384 use super::*;
385
386 #[pin_project]
388 pub struct StatefulMap<St, B, T, F>
389 where
390 St: ?Sized,
391 {
392 pub(super) state: Option<B>,
393 pub(super) f: F,
394 pub(super) _phantom: PhantomData<T>,
395 #[pin]
396 pub(super) stream: St,
397 }
398
399 impl<St, B, T, F> Stream for StatefulMap<St, B, T, F>
400 where
401 St: Stream,
402 F: FnMut(B, St::Item) -> Option<(B, T)>,
403 {
404 type Item = T;
405
406 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
407 let mut this = self.project();
408
409 Poll::Ready({
410 if let Some(state) = this.state.take() {
411 match this.stream.as_mut().poll_next(cx) {
412 Ready(Some(in_item)) => {
413 if let Some((state, out_item)) = (this.f)(state, in_item) {
414 *this.state = Some(state);
415 Some(out_item)
416 } else {
417 None
418 }
419 }
420 Ready(None) => None,
421 Pending => {
422 *this.state = Some(state);
423 return Pending;
424 }
425 }
426 } else {
427 None
428 }
429 })
430 }
431 }
432}
433
434pub use stateful_batching::*;
435mod stateful_batching {
436 use super::*;
437
438 #[pin_project]
440 pub struct StatefulBatching<St, B, T, F, Fut> {
441 pub(super) f: F,
442 pub(super) _phantom: PhantomData<T>,
443 #[pin]
444 pub(super) future: Option<Fut>,
445 pub(super) state: Option<(B, St)>,
446 }
447
448 impl<St, B, T, F, Fut> Stream for StatefulBatching<St, B, T, F, Fut>
449 where
450 St: Stream,
451 F: FnMut(B, St) -> Fut,
452 Fut: Future<Output = Option<(T, B, St)>>,
453 {
454 type Item = T;
455
456 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
457 let mut this = self.project();
458
459 Poll::Ready(loop {
460 if let Some(fut) = this.future.as_mut().as_pin_mut() {
461 let output = ready!(fut.poll(cx));
462 this.future.set(None);
463
464 if let Some((item, state, stream)) = output {
465 *this.state = Some((state, stream));
466 break Some(item);
467 } else {
468 break None;
469 }
470 } else if let Some((state, stream)) = this.state.take() {
471 this.future.set(Some((this.f)(state, stream)));
472 } else {
473 break None;
474 }
475 })
476 }
477 }
478}
479
480use reduce::*;
481mod reduce {
482 use super::*;
483
484 #[pin_project]
485 pub struct Reduce<St, F, Fut>
486 where
487 St: ?Sized + Stream,
488 {
489 pub(super) fold: Option<St::Item>,
490 pub(super) f: F,
491 #[pin]
492 pub(super) future: Option<Fut>,
493 #[pin]
494 pub(super) stream: St,
495 }
496
497 impl<St, F, Fut> Future for Reduce<St, F, Fut>
498 where
499 St: Stream,
500 F: FnMut(St::Item, St::Item) -> Fut,
501 Fut: Future<Output = St::Item>,
502 {
503 type Output = Option<St::Item>;
504
505 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
506 let mut this = self.project();
507
508 Ready(loop {
509 if let Some(mut future) = this.future.as_mut().as_pin_mut() {
510 let fold = ready!(future.poll_unpin(cx));
511 this.future.set(None);
512 *this.fold = Some(fold);
513 } else if let Some(item) = ready!(this.stream.poll_next_unpin(cx)) {
514 if let Some(fold) = this.fold.take() {
515 let future = (this.f)(fold, item);
516 this.future.set(Some(future));
517 } else {
518 *this.fold = Some(item);
519 }
520 } else {
521 break this.fold.take();
522 }
523 })
524 }
525 }
526}
527
528use wait_until::*;
529mod wait_until {
530 use super::*;
531
532 #[pin_project]
533 pub struct WaitUntil<St, Fut>
534 where
535 St: ?Sized + Stream,
536 Fut: Future<Output = bool>,
537 {
538 pub(super) is_fused: bool,
539 #[pin]
540 pub(super) future: Option<Fut>,
541 #[pin]
542 pub(super) stream: St,
543 }
544
545 impl<St, Fut> WaitUntil<St, Fut>
546 where
547 St: Stream,
548 Fut: Future<Output = bool>,
549 {
550 pub(super) fn new(stream: St, fut: Fut) -> Self {
551 Self {
552 stream,
553 future: Some(fut),
554 is_fused: false,
555 }
556 }
557 }
558
559 impl<St, Fut> Stream for WaitUntil<St, Fut>
560 where
561 St: Stream,
562 Fut: Future<Output = bool>,
563 {
564 type Item = St::Item;
565
566 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
567 let mut this = self.project();
568
569 Ready(loop {
570 if *this.is_fused {
571 break None;
572 } else if let Some(future) = this.future.as_mut().as_pin_mut() {
573 let ok = ready!(future.poll(cx));
574 this.future.set(None);
575
576 if !ok {
577 *this.is_fused = true;
578 break None;
579 }
580 } else {
581 break ready!(this.stream.poll_next(cx));
582 }
583 })
584 }
585
586 fn size_hint(&self) -> (usize, Option<usize>) {
587 if self.is_fused {
588 (0, Some(0))
590 } else {
591 let (lower, upper) = self.stream.size_hint();
592
593 if self.future.is_some() {
594 (0, upper)
596 } else {
597 (lower, upper)
599 }
600 }
601 }
602 }
603
604 impl<St, Fut> FusedStream for WaitUntil<St, Fut>
605 where
606 St: FusedStream,
607 Fut: Future<Output = bool>,
608 {
609 fn is_terminated(&self) -> bool {
610 self.is_fused || self.stream.is_terminated()
611 }
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rt, utils::async_test};
    use std::time::Instant;

    async_test! {
        async fn stream_wait_until_future_test() {
            let delay = Duration::from_millis(200);

            // Gate resolves to `true`: every item comes through, after the delay.
            {
                let started = Instant::now();
                let items: Vec<_> = stream::iter([3, 1, 4])
                    .wait_until(async move {
                        rt::sleep(delay).await;
                        true
                    })
                    .collect()
                    .await;

                assert!(started.elapsed() >= delay);
                assert_eq!(items, [3, 1, 4]);
            }

            // Gate resolves to `false`: the stream ends without yielding anything.
            {
                let started = Instant::now();
                let items: Vec<_> = stream::iter([3, 1, 4])
                    .wait_until(async move {
                        rt::sleep(delay).await;
                        false
                    })
                    .collect()
                    .await;

                assert!(started.elapsed() >= delay);
                assert!(items.is_empty());
            }
        }

        async fn reduce_test() {
            // Several items are summed pairwise.
            let total = stream::iter(1..=10)
                .reduce(|acc, x| async move { acc + x })
                .await;
            assert_eq!(total, Some(55));

            // A single item is returned as-is.
            let single = future::ready(1)
                .into_stream()
                .reduce(|acc, x| async move { acc + x })
                .await;
            assert_eq!(single, Some(1));

            // An empty stream yields nothing.
            let nothing = stream::empty::<usize>()
                .reduce(|acc, x| async move { acc + x })
                .await;
            assert_eq!(nothing, None);
        }

        async fn stateful_then_test() {
            let counts: Vec<_> = stream::repeat(())
                .stateful_then(0, |n, ()| async move {
                    if n < 10 {
                        Some((n + 1, n))
                    } else {
                        None
                    }
                })
                .collect()
                .await;

            assert_eq!(counts, (0..10).collect::<Vec<_>>());
        }

        async fn stateful_map_test() {
            let counts: Vec<_> = stream::repeat(())
                .stateful_map(0, |n, ()| if n < 10 { Some((n + 1, n)) } else { None })
                .collect()
                .await;

            assert_eq!(counts, (0..10).collect::<Vec<_>>());
        }

        async fn stateful_batching_test() {
            // Groups consecutive values with the same sign and sums each group.
            let runs: Vec<_> = stream::iter([1i32, 1, 1, -1, -1, 1])
                .stateful_batching(None, |mut acc: Option<i32>, mut input| async move {
                    while let Some(val) = input.next().await {
                        match acc {
                            None => acc = Some(val),
                            Some(total) if total.signum() == val.signum() => {
                                acc = Some(total + val)
                            }
                            // Sign changed: emit the run and seed the next one with `val`.
                            Some(total) => return Some((total, Some(val), input)),
                        }
                    }

                    acc.map(|total| (total, None, input))
                })
                .collect()
                .await;

            assert_eq!(runs, [3, -2, 1]);
        }

        async fn batching_test() {
            // Each batch sums items until the sum reaches 10; a short tail is dropped.
            let sums: Vec<_> = stream::iter(0..10)
                .batching(|mut input| async move {
                    let mut acc = 0;

                    while let Some(val) = input.next().await {
                        acc += val;

                        if acc >= 10 {
                            return Some((acc, input));
                        }
                    }

                    None
                })
                .collect()
                .await;

            assert_eq!(sums, [10, 11, 15]);
        }
    }
}