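use core::cmp::max;
use core::pin::Pin;
use futures::ready;
use futures::stream::Stream;
use futures::task::{Context, Poll};
use pin_project::pin_project;

/// Tracks which underlying stream an `OrStream` is drawing from.
#[derive(Debug)]
enum State {
    /// Nothing has been yielded yet; `stream1` is still being probed.
    Initial,
    /// `stream1` produced at least one item, so only `stream1` is polled.
    InSt1,
    /// `stream1` ended without yielding anything, so `stream2` is polled instead.
    InSt2,
}

/// A stream that yields the items of `stream1`, or the items of `stream2`
/// if `stream1` turns out to be empty.
#[pin_project]
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct OrStream<St1, St2> {
    #[pin]
    stream1: St1,
    #[pin]
    stream2: St2,
    state: State,
}

use State::{InSt1, InSt2, Initial};

impl<St1, St2> OrStream<St1, St2>
where
    St1: Stream,
    St2: Stream<Item = St1::Item>,
{
    /// Creates an `OrStream` that falls back to `stream2` when `stream1` is empty.
    pub fn new(stream1: St1, stream2: St2) -> Self {
        Self {
            stream1,
            stream2,
            state: Initial,
        }
    }
}

impl<St1, St2> Stream for OrStream<St1, St2>
where
    St1: Stream,
    St2: Stream<Item = St1::Item>,
{
    type Item = St1::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        match this.state {
            // The first ready result from `stream1` decides which stream is used.
            Initial => match ready!(this.stream1.poll_next(cx)) {
                item @ Some(_) => {
                    *this.state = InSt1;
                    Poll::Ready(item)
                }
                None => {
                    *this.state = InSt2;
                    this.stream2.poll_next(cx)
                }
            },
            InSt1 => this.stream1.poll_next(cx),
            InSt2 => this.stream2.poll_next(cx),
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.state {
            Initial => match self.stream1.size_hint() {
                // `stream1` is known to be empty, so `stream2` determines the length.
                (_, Some(0)) => self.stream2.size_hint(),
                // `stream1` may or may not be empty, so either stream could be used.
                (0, i1_high) => {
                    let (i2_low, i2_high) = self.stream2.size_hint();
                    // A non-empty `stream1` yields at least one item, and an empty
                    // one defers to `stream2`, so the lower bound is min(1, i2_low).
                    let low = if i2_low > 0 { 1 } else { 0 };
                    // The upper bound is known only if both streams report one.
                    let high = i1_high.and_then(|h1| i2_high.map(|h2| max(h1, h2)));
                    (low, high)
                }
                // `stream1` reports at least one item, so `stream2` is never used.
                i1_hint => i1_hint,
            },
            InSt1 => self.stream1.size_hint(),
            InSt2 => self.stream2.size_hint(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::OrStream;
    use futures::executor::block_on;
    use futures::stream::{empty, once, StreamExt};

    #[test]
    fn basic() {
        block_on(async move {
            // Both streams are non-empty: only the first is used.
            let s1 = once(async { 1 });
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![1], v);

            // The first stream is empty: the second is used.
            let s1 = empty();
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);

            // The second stream is empty: the first is still used.
            let s1 = once(async { 3 });
            let s2 = empty();
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);
        });
    }
}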