1use {
2 core::{cmp::max, pin::Pin},
3 futures::{
4 ready,
5 stream::Stream,
6 task::{Context, Poll},
7 },
8 pin_project::pin_project,
9};
10
11#[derive(Debug)]
12enum State {
13 Initial,
14 St1,
15 St2,
16}
17
18#[pin_project]
19#[derive(Debug)]
20#[must_use = "streams do nothing unless polled"]
21pub struct OrStream<St1, St2> {
22 #[pin]
23 stream1: St1,
24 #[pin]
25 stream2: St2,
26 state: State,
27}
28
29use State::{Initial, St1, St2};
30
31impl<St1, St2> OrStream<St1, St2>
32where
33 St1: Stream,
34 St2: Stream<Item = St1::Item>,
35{
36 pub fn new(stream1: St1, stream2: St2) -> Self {
37 Self {
38 stream1,
39 stream2,
40 state: Initial,
41 }
42 }
43}
44
45impl<St1, St2> Stream for OrStream<St1, St2>
46where
47 St1: Stream,
48 St2: Stream<Item = St1::Item>,
49{
50 type Item = St1::Item;
51
52 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53 let this = self.project();
54
55 match this.state {
56 Initial => {
57 if let item @ Some(_) = ready!(this.stream1.poll_next(cx)) {
58 *this.state = St1;
59
60 Poll::Ready(item)
61 } else {
62 *this.state = St2;
63
64 this.stream2.poll_next(cx)
65 }
66 }
67 St1 => this.stream1.poll_next(cx),
68 St2 => this.stream2.poll_next(cx),
69 }
70 }
71
72 fn size_hint(&self) -> (usize, Option<usize>) {
73 match self.state {
74 Initial => {
75 let (s1_low, s1_high) = self.stream1.size_hint();
76 let (s2_low, s2_high) = self.stream2.size_hint();
77
78 if s1_high == Some(0) {
79 (s2_low, s2_high)
80 } else if s1_low > 0 {
81 (s1_low, s1_high)
82 } else {
83 let low = usize::from(s2_low > 0);
84 let high = match (s1_high, s2_high) {
85 (Some(h1), Some(h2)) => Some(max(h1, h2)),
86 _ => None,
87 };
88 (low, high)
89 }
90 }
91 St1 => self.stream1.size_hint(),
92 St2 => self.stream2.size_hint(),
93 }
94 }
95}
96
#[cfg(test)]
mod tests {
    use {
        super::OrStream,
        futures::{
            Stream,
            executor::block_on,
            pin_mut,
            stream::{StreamExt, empty, iter, once, poll_fn},
        },
        std::task::Poll,
    };

    #[test]
    fn basic() {
        block_on(async move {
            let s1 = once(async { 1 });
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![1], v);

            let s1 = empty();
            let s2 = once(async { 3 });
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);

            let s1 = once(async { 3 });
            let s2 = empty();
            let v = OrStream::new(s1, s2).collect::<Vec<i32>>().await;
            assert_eq!(vec![3], v);
        });
    }

    #[test]
    fn multi_item_streams() {
        block_on(async {
            // Once the first stream yields, all of its items are forwarded.
            let v = OrStream::new(iter([1, 2, 3]), iter([4, 5]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(v, vec![1, 2, 3]);

            // An empty first stream hands over every item of the second stream.
            let v = OrStream::new(empty(), iter([4, 5]))
                .collect::<Vec<i32>>()
                .await;
            assert_eq!(v, vec![4, 5]);

            // Both empty: the combined stream is empty too.
            let v = OrStream::new(empty::<i32>(), empty())
                .collect::<Vec<i32>>()
                .await;
            assert!(v.is_empty());
        });
    }

    #[test]
    fn second_stream_untouched_once_first_yields() {
        block_on(async {
            let s2 = poll_fn(|_| -> Poll<Option<i32>> {
                panic!("second stream must not be polled")
            });
            let v = OrStream::new(iter([1, 2]), s2).collect::<Vec<i32>>().await;
            assert_eq!(v, vec![1, 2]);
        });
    }

    #[test]
    fn pending_first_stream_is_retried() {
        block_on(async {
            // The first stream is pending once (waking itself), then ends empty:
            // the adapter must wait for it before falling back to the second stream.
            let mut pending_once = true;
            let s1 = poll_fn(move |cx| {
                if pending_once {
                    pending_once = false;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                } else {
                    Poll::Ready(None::<i32>)
                }
            });
            let v = OrStream::new(s1, iter([7, 8])).collect::<Vec<i32>>().await;
            assert_eq!(v, vec![7, 8]);
        });
    }

    #[test]
    fn size_hint_initial_branches() {
        let s1 = empty();
        let s2 = once(async { 1 });
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(1)));

        let s1 = once(async { 1 });
        let s2 = empty();
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(1)));

        let s1 = poll_fn(|_| Poll::<Option<i32>>::Pending);
        let s2 = once(async { 1 });
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, None));

        let s1 = poll_fn(|_| Poll::<Option<i32>>::Pending);
        let s2 = empty();
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (0, None));

        let s1 = iter([1, 2, 3]).filter(|_| async { true });
        let s2 = iter([1, 2]);
        let or = OrStream::new(s1, s2);
        assert_eq!(or.size_hint(), (1, Some(3)));

        let or = OrStream::new(empty::<i32>(), empty());
        assert_eq!(or.size_hint(), (0, Some(0)));
    }

    #[test]
    fn size_hint_state_changes() {
        block_on(async {
            let or = OrStream::new(once(async { 1 }), once(async { 2 }));
            pin_mut!(or);
            assert_eq!(or.next().await, Some(1));
            assert_eq!(or.size_hint(), (0, Some(0)));

            let or = OrStream::new(empty(), once(async { 2 }));
            pin_mut!(or);
            assert_eq!(or.next().await, Some(2));
            assert_eq!(or.size_hint(), (0, Some(0)));
        });
    }
}
179}