futures_rx/stream_ext/
switch_map.rs1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures::{
7 stream::{Fuse, FusedStream},
8 Stream, StreamExt,
9};
10use pin_project_lite::pin_project;
11
12pin_project! {
13 #[must_use = "streams do nothing unless polled"]
15 pub struct SwitchMap<S: Stream, St: Stream, F: FnMut(S::Item) -> St> {
16 #[pin]
17 stream: Fuse<S>,
18 #[pin]
19 switch_stream: Option<Fuse<F::Output>>,
20 f: F,
21 }
22}
23
24impl<S: Stream, St: Stream, F: FnMut(S::Item) -> St> SwitchMap<S, St, F> {
25 pub(crate) fn new(stream: S, f: F) -> Self {
26 Self {
27 stream: stream.fuse(),
28 switch_stream: None,
29 f,
30 }
31 }
32}
33
34impl<S: Stream, St: Stream, F: FnMut(S::Item) -> St> FusedStream for SwitchMap<S, St, F> {
35 fn is_terminated(&self) -> bool {
36 if self.stream.is_terminated() {
37 self.switch_stream
38 .as_ref()
39 .map(|it| it.is_terminated())
40 .unwrap_or(false)
41 } else {
42 false
43 }
44 }
45}
46
47impl<S: Stream, St: Stream, F: FnMut(S::Item) -> St> Stream for SwitchMap<S, St, F>
48where
49 F::Output: Stream,
50{
51 type Item = St::Item;
52
53 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54 let mut this = self.project();
55 let is_done = match this.stream.as_mut().poll_next(cx) {
56 Poll::Ready(Some(event)) => {
57 this.switch_stream.set((this.f)(event).fuse().into());
58
59 false
60 }
61 Poll::Ready(None) => true,
62 Poll::Pending => false,
63 };
64
65 this.switch_stream
66 .as_pin_mut()
67 .map(|it| it.poll_next(cx))
68 .unwrap_or_else(|| {
69 if is_done {
70 Poll::Ready(None)
71 } else {
72 Poll::Pending
73 }
74 })
75 }
76
77 fn size_hint(&self) -> (usize, Option<usize>) {
78 if self.stream.is_terminated() {
79 self.switch_stream
80 .as_ref()
81 .map(|it| it.size_hint())
82 .unwrap_or((0, None))
83 } else {
84 let (lower, _) = self.stream.size_hint();
85
86 (lower, None)
87 }
88 }
89}
90
#[cfg(test)]
mod test {
    use futures::{executor::block_on, stream, StreamExt};

    use crate::RxExt;

    /// Each outer item switches to a fresh inner stream; once the outer stream
    /// has ended, the final inner stream is drained to completion.
    #[test]
    fn smoke() {
        let to_powers = |n: usize| stream::iter([n.pow(2), n.pow(3), n.pow(4)]);
        let source = stream::iter(0usize..=3usize);

        let all_events: Vec<usize> = block_on(source.switch_map(to_powers).collect());

        assert_eq!(all_events, [0, 1, 4, 9, 27, 81]);
    }
}