1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use core::cmp::max;
use core::pin::Pin;
use futures::ready;
use futures::stream::Stream;
use futures::task::{Context, Poll};
use pin_project::pin_project;

/// Which inner stream an `OrStream` is currently forwarding.
///
/// The state only moves forward: `Initial` -> `InSt1` or `Initial` -> `InSt2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// `stream1` has not yet produced a ready result (item or end-of-stream).
    Initial,
    /// `stream1` yielded an item first, so every item comes from `stream1`.
    InSt1,
    /// `stream1` ended without yielding, so every item comes from `stream2`.
    InSt2,
}

/// Stream combinator that yields the items of `stream1`, unless `stream1`
/// ends without yielding anything, in which case it yields the items of
/// `stream2` instead.
///
/// The decision is made by the first ready result of `stream1`: an item
/// commits the combinator to `stream1`, while an immediate end-of-stream
/// switches it to `stream2` for the rest of its life.
#[pin_project]
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct OrStream<St1, St2> {
    /// Primary stream; always polled first.
    #[pin]
    stream1: St1,
    /// Fallback stream; only polled once `stream1` has ended without yielding.
    #[pin]
    stream2: St2,
    /// Which of the two streams is being forwarded (or `Initial` if undecided).
    state: State,
}

use State::{InSt1, InSt2, Initial};

impl<St1, St2> OrStream<St1, St2>
where
    St1: Stream,
    St2: Stream<Item = St1::Item>,
{
    pub fn new(stream1: St1, stream2: St2) -> Self {
        Self {
            stream1,
            stream2,
            state: Initial,
        }
    }
}

impl<St1, St2> Stream for OrStream<St1, St2>
where
    St1: Stream,
    St2: Stream<Item = St1::Item>,
{
    type Item = St1::Item;

    /// Forwards items from whichever inner stream has been selected.
    ///
    /// While undecided (`Initial`), `stream1` is polled. If it yields an item, the
    /// combinator commits to `stream1`. If it ends immediately, the combinator
    /// commits to `stream2` and polls it within the same call.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        match this.state {
            InSt1 => this.stream1.poll_next(cx),
            InSt2 => this.stream2.poll_next(cx),
            Initial => {
                // `ready!` returns `Poll::Pending` early, leaving the state undecided.
                let first = ready!(this.stream1.poll_next(cx));
                match first {
                    Some(item) => {
                        *this.state = InSt1;
                        Poll::Ready(Some(item))
                    }
                    None => {
                        // stream1 was empty: fall back to stream2 permanently.
                        *this.state = InSt2;
                        this.stream2.poll_next(cx)
                    }
                }
            }
        }
    }

    /// Bounds on the number of remaining items.
    ///
    /// Once a stream has been selected, its hint is forwarded unchanged. While
    /// undecided, the output is either all of `stream1` (if it yields at least one
    /// item) or all of `stream2` (if `stream1` is empty). The bounds therefore
    /// cover both outcomes.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self.state {
            InSt1 => self.stream1.size_hint(),
            InSt2 => self.stream2.size_hint(),
            Initial => {
                let (low1, high1) = self.stream1.size_hint();
                if high1 == Some(0) {
                    // stream1 is known to be empty, so the output is exactly stream2.
                    return self.stream2.size_hint();
                }
                if low1 > 0 {
                    // stream1 is known to be non-empty, so stream2 is never used.
                    return (low1, high1);
                }
                // Either stream may end up being forwarded.
                let (low2, high2) = self.stream2.size_hint();
                // At least one item if stream1 yields, at least `low2` if it doesn't.
                let low = usize::from(low2 > 0);
                let high = high1.zip(high2).map(|(h1, h2)| max(h1, h2));
                (low, high)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::OrStream;
    use futures::executor::block_on;
    use futures::stream::{empty, iter, once, Stream, StreamExt};

    #[test]
    fn basic() {
        block_on(async move {
            let s1 = once(async { 1 });
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![1], v);

            let s1 = empty();
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);

            let s1 = once(async { 3 });
            let s2 = empty();
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);
        });
    }

    /// Every item of a multi-item `stream1` is forwarded and `stream2` is ignored.
    #[test]
    fn multiple_items() {
        block_on(async {
            let v = OrStream::new(iter(vec![1, 2, 3]), iter(vec![4, 5]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(vec![1, 2, 3], v);

            let v = OrStream::new(empty::<i32>(), iter(vec![4, 5]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(vec![4, 5], v);
        });
    }

    /// Two empty streams combine into an empty stream.
    #[test]
    fn both_empty() {
        block_on(async {
            let v = OrStream::new(empty::<i32>(), empty::<i32>())
                .collect::<Vec<i32>>()
                .await;
            assert!(v.is_empty());
        });
    }

    /// Covers each branch of `size_hint` while the source is still undecided.
    #[test]
    fn size_hint_before_first_poll() {
        // stream1 known empty: the hint is stream2's.
        let s = OrStream::new(empty::<i32>(), iter(vec![1, 2]));
        assert_eq!(s.size_hint(), (2, Some(2)));

        // stream1 known non-empty: the hint is stream1's.
        let s = OrStream::new(iter(vec![1, 2, 3]), iter(vec![4]));
        assert_eq!(s.size_hint(), (3, Some(3)));

        // stream1 possibly empty (0..=3) with a non-empty stream2 (exactly 2).
        let s1 = iter(vec![1, 2, 3].into_iter().filter(|x| *x > 1));
        let s = OrStream::new(s1, iter(vec![4, 5]));
        assert_eq!(s.size_hint(), (1, Some(3)));

        // stream1 possibly empty (0..=1) with an empty stream2.
        let s1 = iter(vec![1].into_iter().filter(|x| *x > 1));
        let s = OrStream::new(s1, empty::<i32>());
        assert_eq!(s.size_hint(), (0, Some(1)));
    }

    /// After the first item, the hint tracks only the selected stream.
    #[test]
    fn size_hint_after_first_poll() {
        block_on(async {
            let mut s = OrStream::new(iter(vec![1, 2, 3]), iter(vec![4]));
            assert_eq!(s.next().await, Some(1));
            assert_eq!(s.size_hint(), (2, Some(2)));

            let mut s = OrStream::new(empty::<i32>(), iter(vec![4, 5]));
            assert_eq!(s.next().await, Some(4));
            assert_eq!(s.size_hint(), (1, Some(1)));
        });
    }
}