// gluesql_utils/or_stream.rs

1use {
2    core::{cmp::max, pin::Pin},
3    futures::{
4        ready,
5        stream::Stream,
6        task::{Context, Poll},
7    },
8    pin_project::pin_project,
9};
10
/// Which inner stream `OrStream` is currently forwarding.
///
/// The state starts at `Initial` and moves to `St1` or `St2` exactly once,
/// depending on the first outcome of `stream1`.
#[derive(Debug)]
enum State {
    /// Nothing decided yet: `stream1` has neither yielded an item nor ended.
    Initial,
    /// `stream1` yielded an item, so it alone is polled from now on.
    St1,
    /// `stream1` ended without yielding anything, so `stream2` is polled instead.
    St2,
}
17
18#[pin_project]
19#[derive(Debug)]
20#[must_use = "streams do nothing unless polled"]
21pub struct OrStream<St1, St2> {
22    #[pin]
23    stream1: St1,
24    #[pin]
25    stream2: St2,
26    state: State,
27}
28
29use State::{Initial, St1, St2};
30
31impl<St1, St2> OrStream<St1, St2>
32where
33    St1: Stream,
34    St2: Stream<Item = St1::Item>,
35{
36    pub fn new(stream1: St1, stream2: St2) -> Self {
37        Self {
38            stream1,
39            stream2,
40            state: Initial,
41        }
42    }
43}
44
45impl<St1, St2> Stream for OrStream<St1, St2>
46where
47    St1: Stream,
48    St2: Stream<Item = St1::Item>,
49{
50    type Item = St1::Item;
51
52    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53        let this = self.project();
54
55        match this.state {
56            Initial => match ready!(this.stream1.poll_next(cx)) {
57                item @ Some(_) => {
58                    *this.state = St1;
59
60                    Poll::Ready(item)
61                }
62                None => {
63                    *this.state = St2;
64
65                    this.stream2.poll_next(cx)
66                }
67            },
68            St1 => this.stream1.poll_next(cx),
69            St2 => this.stream2.poll_next(cx),
70        }
71    }
72
73    fn size_hint(&self) -> (usize, Option<usize>) {
74        match self.state {
75            Initial => {
76                let (s1_low, s1_high) = self.stream1.size_hint();
77                let (s2_low, s2_high) = self.stream2.size_hint();
78
79                if s1_high == Some(0) {
80                    (s2_low, s2_high)
81                } else if s1_low > 0 {
82                    (s1_low, s1_high)
83                } else {
84                    let low = if s2_low > 0 { 1 } else { 0 };
85                    let high = match (s1_high, s2_high) {
86                        (Some(h1), Some(h2)) => Some(max(h1, h2)),
87                        _ => None,
88                    };
89                    (low, high)
90                }
91            }
92            St1 => self.stream1.size_hint(),
93            St2 => self.stream2.size_hint(),
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use {
        super::OrStream,
        futures::{
            Stream,
            executor::block_on,
            pin_mut,
            stream::{StreamExt, empty, iter, once, poll_fn},
        },
        std::task::{Context, Poll},
    };

    #[test]
    fn basic() {
        block_on(async move {
            let s1 = once(async { 1 });
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![1], v);

            let s1 = empty();
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);

            let s1 = once(async { 3 });
            let s2 = empty();
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);
        });
    }

    #[test]
    fn multi_item_streams() {
        block_on(async {
            // stream1 wins: every one of its items is forwarded, stream2 is discarded.
            let v = OrStream::new(iter([1, 2, 3]), iter([7, 8]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(vec![1, 2, 3], v);

            // stream1 is empty: every item of stream2 is forwarded.
            let v = OrStream::new(empty(), iter([7, 8]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(vec![7, 8], v);
        });
    }

    #[test]
    fn stream2_untouched_when_stream1_yields() {
        block_on(async {
            // Polling stream2 here would be a bug, so make it panic.
            let s2 = poll_fn(|_: &mut Context<'_>| -> Poll<Option<i32>> {
                panic!("stream2 must not be polled when stream1 yields")
            });
            let v = OrStream::new(iter([1, 2]), s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![1, 2], v);
        });
    }

    #[test]
    fn pending_keeps_choice_open() {
        block_on(async {
            // stream1 is Pending once (waking itself), then ends without items.
            let mut calls = 0;
            let s1 = poll_fn(move |cx: &mut Context<'_>| -> Poll<Option<i32>> {
                calls += 1;
                if calls == 1 {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                } else {
                    Poll::Ready(None)
                }
            });
            let v = OrStream::new(s1, iter([5, 6])).collect::<Vec<i32>>().await;
            assert_eq!(vec![5, 6], v);
        });
    }

    #[test]
    fn size_hint_initial_branches() {
        // stream1 high is Some(0)
        let s1 = empty();
        let s2 = once(async { 1 });
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(1)));

        // stream1 low > 0
        let s1 = once(async { 1 });
        let s2 = empty();
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(1)));

        // else branch with s2_low > 0
        let s1 = poll_fn(|_| Poll::<Option<i32>>::Pending);
        let s2 = once(async { 1 });
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, None));

        // else branch with s2_low == 0
        let s1 = poll_fn(|_| Poll::<Option<i32>>::Pending);
        let s2 = empty();
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (0, None));

        // both highs defined triggers max branch
        let s1 = iter([1, 2, 3]).filter(|_| async { true });
        let s2 = iter([1, 2]);
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(3)));
    }

    #[test]
    fn size_hint_state_changes() {
        block_on(async {
            // move to St1 after first item from stream1
            let or = OrStream::new(once(async { 1 }), once(async { 2 }));
            pin_mut!(or);
            assert_eq!(or.next().await, Some(1));
            assert_eq!(or.size_hint(), (0, Some(0)));

            // move to St2 when stream1 is empty
            let or = OrStream::new(empty(), once(async { 2 }));
            pin_mut!(or);
            assert_eq!(or.next().await, Some(2));
            assert_eq!(or.size_hint(), (0, Some(0)));
        });
    }
}