// gluesql_utils/or_stream.rs

1use {
2    core::{cmp::max, pin::Pin},
3    futures::{
4        ready,
5        stream::Stream,
6        task::{Context, Poll},
7    },
8    pin_project::pin_project,
9};
10
/// Progress of an [`OrStream`] in choosing which inner stream to follow.
#[derive(Debug)]
enum State {
    /// `stream1` has not produced its first result yet; no stream is chosen.
    Initial,
    /// `stream1` yielded an item, so only `stream1` is polled from now on.
    St1,
    /// `stream1` ended without yielding, so only `stream2` is polled from now on.
    St2,
}
17
18#[pin_project]
19#[derive(Debug)]
20#[must_use = "streams do nothing unless polled"]
21pub struct OrStream<St1, St2> {
22    #[pin]
23    stream1: St1,
24    #[pin]
25    stream2: St2,
26    state: State,
27}
28
29use State::{Initial, St1, St2};
30
31impl<St1, St2> OrStream<St1, St2>
32where
33    St1: Stream,
34    St2: Stream<Item = St1::Item>,
35{
36    pub fn new(stream1: St1, stream2: St2) -> Self {
37        Self {
38            stream1,
39            stream2,
40            state: Initial,
41        }
42    }
43}
44
45impl<St1, St2> Stream for OrStream<St1, St2>
46where
47    St1: Stream,
48    St2: Stream<Item = St1::Item>,
49{
50    type Item = St1::Item;
51
52    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53        let this = self.project();
54
55        match this.state {
56            Initial => {
57                if let item @ Some(_) = ready!(this.stream1.poll_next(cx)) {
58                    *this.state = St1;
59
60                    Poll::Ready(item)
61                } else {
62                    *this.state = St2;
63
64                    this.stream2.poll_next(cx)
65                }
66            }
67            St1 => this.stream1.poll_next(cx),
68            St2 => this.stream2.poll_next(cx),
69        }
70    }
71
72    fn size_hint(&self) -> (usize, Option<usize>) {
73        match self.state {
74            Initial => {
75                let (s1_low, s1_high) = self.stream1.size_hint();
76                let (s2_low, s2_high) = self.stream2.size_hint();
77
78                if s1_high == Some(0) {
79                    (s2_low, s2_high)
80                } else if s1_low > 0 {
81                    (s1_low, s1_high)
82                } else {
83                    let low = usize::from(s2_low > 0);
84                    let high = match (s1_high, s2_high) {
85                        (Some(h1), Some(h2)) => Some(max(h1, h2)),
86                        _ => None,
87                    };
88                    (low, high)
89                }
90            }
91            St1 => self.stream1.size_hint(),
92            St2 => self.stream2.size_hint(),
93        }
94    }
95}
96
#[cfg(test)]
mod tests {
    use {
        super::OrStream,
        futures::{
            Stream,
            executor::block_on,
            pin_mut,
            stream::{StreamExt, empty, iter, once, poll_fn},
            task::{Context, noop_waker_ref},
        },
        std::task::Poll,
    };

    #[test]
    fn basic() {
        block_on(async move {
            let s1 = once(async { 1 });
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![1], v);

            let s1 = empty();
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);

            let s1 = once(async { 3 });
            let s2 = empty();
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);
        });
    }

    /// Multi-item streams: every item of the chosen stream is yielded, and
    /// nothing from the other one.
    #[test]
    fn yields_every_item_of_the_chosen_stream() {
        block_on(async {
            // stream1 is non-empty: all of its items, none of stream2's.
            let v = OrStream::new(iter([1, 2, 3]), iter([4, 5]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(vec![1, 2, 3], v);

            // stream1 is empty: all of stream2's items.
            let v = OrStream::new(empty(), iter([4, 5]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(vec![4, 5], v);

            // Both streams are empty.
            let v = OrStream::new(empty::<i32>(), empty())
                .collect::<Vec<i32>>()
                .await;
            assert!(v.is_empty());
        });
    }

    /// A `Pending` from stream1 in the initial state must be propagated,
    /// not mistaken for end-of-stream (which would switch to stream2).
    #[test]
    fn pending_stream1_does_not_fall_back() {
        let mut cx = Context::from_waker(noop_waker_ref());

        let or = OrStream::new(
            poll_fn(|_| Poll::<Option<i32>>::Pending),
            once(async { 1 }),
        );
        pin_mut!(or);

        assert_eq!(or.as_mut().poll_next(&mut cx), Poll::Pending);
        // Still undecided, so the hint keeps covering both outcomes.
        assert_eq!(or.size_hint(), (1, None));
    }

    #[test]
    fn size_hint_initial_branches() {
        // stream1 high is Some(0)
        let s1 = empty();
        let s2 = once(async { 1 });
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(1)));

        // stream1 low > 0
        let s1 = once(async { 1 });
        let s2 = empty();
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(1)));

        // else branch with s2_low > 0
        let s1 = poll_fn(|_| Poll::<Option<i32>>::Pending);
        let s2 = once(async { 1 });
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, None));

        // else branch with s2_low == 0
        let s1 = poll_fn(|_| Poll::<Option<i32>>::Pending);
        let s2 = empty();
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (0, None));

        // both highs defined triggers max branch
        let s1 = iter([1, 2, 3]).filter(|_| async { true });
        let s2 = iter([1, 2]);
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(3)));
    }

    #[test]
    fn size_hint_state_changes() {
        block_on(async {
            // move to St1 after first item from stream1
            let or = OrStream::new(once(async { 1 }), once(async { 2 }));
            pin_mut!(or);
            assert_eq!(or.next().await, Some(1));
            assert_eq!(or.size_hint(), (0, Some(0)));

            // move to St2 when stream1 is empty
            let or = OrStream::new(empty(), once(async { 2 }));
            pin_mut!(or);
            assert_eq!(or.next().await, Some(2));
            assert_eq!(or.size_hint(), (0, Some(0)));
        });
    }
}