1use {
2 core::{cmp::max, pin::Pin},
3 futures::{
4 ready,
5 stream::Stream,
6 task::{Context, Poll},
7 },
8 pin_project::pin_project,
9};
10
/// Which inner stream an `OrStream` is currently forwarding polls to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No item has been seen yet; `stream1` is still being probed.
    Initial,
    /// `stream1` yielded an item, so only `stream1` is polled from now on.
    St1,
    /// `stream1` ended without yielding, so only `stream2` is polled from now on.
    St2,
}
17
/// A stream that yields the items of `stream1` if it produces at least one
/// item, and otherwise the items of `stream2`.
///
/// The choice is made on the first poll result of `stream1` that is ready:
/// if it is an item, only `stream1` is polled afterwards and `stream2` is
/// never touched; if `stream1` ends immediately, all further polls go to
/// `stream2`. Created with [`OrStream::new`].
#[pin_project]
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct OrStream<St1, St2> {
    /// Preferred stream; always polled first.
    #[pin]
    stream1: St1,
    /// Fallback stream; only polled once `stream1` has ended without yielding.
    #[pin]
    stream2: St2,
    /// Tracks whether the choice between the two streams has been made yet.
    state: State,
}
28
29use State::{Initial, St1, St2};
30
31impl<St1, St2> OrStream<St1, St2>
32where
33 St1: Stream,
34 St2: Stream<Item = St1::Item>,
35{
36 pub fn new(stream1: St1, stream2: St2) -> Self {
37 Self {
38 stream1,
39 stream2,
40 state: Initial,
41 }
42 }
43}
44
45impl<St1, St2> Stream for OrStream<St1, St2>
46where
47 St1: Stream,
48 St2: Stream<Item = St1::Item>,
49{
50 type Item = St1::Item;
51
52 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53 let this = self.project();
54
55 match this.state {
56 Initial => match ready!(this.stream1.poll_next(cx)) {
57 item @ Some(_) => {
58 *this.state = St1;
59
60 Poll::Ready(item)
61 }
62 None => {
63 *this.state = St2;
64
65 this.stream2.poll_next(cx)
66 }
67 },
68 St1 => this.stream1.poll_next(cx),
69 St2 => this.stream2.poll_next(cx),
70 }
71 }
72
73 fn size_hint(&self) -> (usize, Option<usize>) {
74 match self.state {
75 Initial => {
76 let (s1_low, s1_high) = self.stream1.size_hint();
77 let (s2_low, s2_high) = self.stream2.size_hint();
78
79 if s1_high == Some(0) {
80 (s2_low, s2_high)
81 } else if s1_low > 0 {
82 (s1_low, s1_high)
83 } else {
84 let low = if s2_low > 0 { 1 } else { 0 };
85 let high = match (s1_high, s2_high) {
86 (Some(h1), Some(h2)) => Some(max(h1, h2)),
87 _ => None,
88 };
89 (low, high)
90 }
91 }
92 St1 => self.stream1.size_hint(),
93 St2 => self.stream2.size_hint(),
94 }
95 }
96}
97
#[cfg(test)]
mod tests {
    use {
        super::OrStream,
        futures::{
            Stream,
            executor::block_on,
            pin_mut,
            stream::{StreamExt, empty, iter, once, poll_fn},
        },
        std::task::Poll,
    };

    /// Runs an `OrStream` over `first` and `second` to completion and
    /// returns every item it produced.
    async fn drain<A, B>(first: A, second: B) -> Vec<i32>
    where
        A: Stream<Item = i32>,
        B: Stream<Item = A::Item>,
    {
        OrStream::new(first, second).collect().await
    }

    /// Reports the `size_hint` of a freshly built, never-polled `OrStream`.
    fn initial_hint<A, B>(first: A, second: B) -> (usize, Option<usize>)
    where
        A: Stream,
        B: Stream<Item = A::Item>,
    {
        OrStream::new(first, second).size_hint()
    }

    #[test]
    fn basic() {
        block_on(async {
            // Both non-empty: only the first stream's item comes through.
            assert_eq!(drain(once(async { 1 }), once(async { 3 })).await, vec![1]);
            // First empty: falls back to the second stream.
            assert_eq!(drain(empty(), once(async { 3 })).await, vec![3]);
            // Second empty: the first stream's item comes through.
            assert_eq!(drain(once(async { 3 }), empty()).await, vec![3]);
        });
    }

    #[test]
    fn size_hint_initial_branches() {
        // stream1 is known to be empty: the hint is stream2's.
        assert_eq!(initial_hint(empty(), once(async { 1 })), (1, Some(1)));

        // stream1 is known to yield something: the hint is stream1's.
        assert_eq!(initial_hint(once(async { 1 }), empty()), (1, Some(1)));

        // stream1 is unbounded and may be empty; stream2 yields at least one.
        assert_eq!(
            initial_hint(poll_fn(|_| Poll::<Option<i32>>::Pending), once(async { 1 })),
            (1, None)
        );

        // stream1 is unbounded and may be empty; stream2 may be empty too.
        assert_eq!(
            initial_hint(poll_fn(|_| Poll::<Option<i32>>::Pending), empty()),
            (0, None)
        );

        // stream1 may be empty (filtering drops its lower bound to 0), so the
        // upper bound is the larger of the two.
        assert_eq!(
            initial_hint(iter([1, 2, 3]).filter(|_| async { true }), iter([1, 2])),
            (1, Some(3))
        );
    }

    #[test]
    fn size_hint_state_changes() {
        block_on(async {
            // Committed to stream1 and drained its only item: nothing remains.
            let committed_to_first = OrStream::new(once(async { 1 }), once(async { 2 }));
            pin_mut!(committed_to_first);
            assert_eq!(committed_to_first.next().await, Some(1));
            assert_eq!(committed_to_first.size_hint(), (0, Some(0)));

            // Fell back to stream2 and drained its only item: nothing remains.
            let fell_back = OrStream::new(empty(), once(async { 2 }));
            pin_mut!(fell_back);
            assert_eq!(fell_back.next().await, Some(2));
            assert_eq!(fell_back.size_hint(), (0, Some(0)));
        });
    }
}