1use core::fmt;
2use core::marker::PhantomData;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use tokio_stream::Stream;
6use tokio_stream_util::FusedStream;
7
8use super::Sink;
9
/// Sink adapter that maps each value sent to it through `f` into a stream and
/// forwards every `Ok` item of that stream into the wrapped sink `Si`.
///
/// At most one mapped stream is pending at a time: `start_send` stores the
/// stream returned by `f`, and `poll_ready` / `poll_flush` / `poll_close` drain
/// it into the wrapped sink before reporting completion. An `Err` yielded by the
/// mapped stream is returned to the caller as the sink error.
#[must_use = "sinks do nothing unless polled"]
pub struct WithFlatMap<Si, Item, U, St, F> {
    // The wrapped sink that receives the flattened items. Treated as
    // structurally pinned: it is only accessed through `Pin` when `self` is pinned.
    sink: Si,
    // Maps each value passed to `start_send` to a stream of `Result<Item, _>`.
    // Not structurally pinned; accessed through a plain `&mut`.
    f: F,
    // The mapped stream currently being drained into `sink`, if any. Treated as
    // structurally pinned, like `sink`.
    stream: Option<St>,
    // An item already pulled from `stream` that `sink` was not ready to accept;
    // it is retried before any further item is pulled. Not structurally pinned.
    buffer: Option<Item>,
    // Records the input type `U` without storing or owning a `U`.
    _marker: PhantomData<fn(U)>,
}
19
20impl<Si: Unpin, Item, U, St: Unpin, F> Unpin for WithFlatMap<Si, Item, U, St, F> {}
21
22impl<Si, Item, U, St, F> fmt::Debug for WithFlatMap<Si, Item, U, St, F>
23where
24 Si: fmt::Debug,
25 St: fmt::Debug,
26 Item: fmt::Debug,
27{
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 f.debug_struct("WithFlatMap")
30 .field("sink", &self.sink)
31 .field("stream", &self.stream)
32 .field("buffer", &self.buffer)
33 .finish()
34 }
35}
36
37impl<Si, Item, U, St, F> WithFlatMap<Si, Item, U, St, F>
38where
39 Si: Sink<Item>,
40 F: FnMut(U) -> St,
41 St: Stream<Item = Result<Item, Si::Error>>,
42{
43 pub(super) fn new(sink: Si, f: F) -> Self {
44 Self {
45 sink,
46 f,
47 stream: None,
48 buffer: None,
49 _marker: PhantomData,
50 }
51 }
52
53 pub fn get_ref(&self) -> &Si {
55 &self.sink
56 }
57
58 pub fn get_mut(&mut self) -> &mut Si {
63 &mut self.sink
64 }
65
66 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
71 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
72 }
73
74 pub fn into_inner(self) -> Si {
79 self.sink
80 }
81
82 fn try_empty_stream(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>> {
83 let this = unsafe { self.get_unchecked_mut() };
84 let mut sink = unsafe { Pin::new_unchecked(&mut this.sink) };
85
86 if this.buffer.is_some() {
87 match sink.as_mut().poll_ready(cx) {
88 Poll::Ready(Ok(())) => {}
89 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
90 Poll::Pending => return Poll::Pending,
91 }
92 let item = this.buffer.take().unwrap();
93 if let Err(e) = sink.as_mut().start_send(item) {
94 return Poll::Ready(Err(e));
95 }
96 }
97 let stream_pin = unsafe { Pin::new_unchecked(&mut this.stream) };
98 if let Some(mut some_stream) = stream_pin.as_pin_mut() {
99 loop {
100 let item = match some_stream.as_mut().poll_next(cx) {
101 Poll::Ready(Some(Ok(item))) => Some(item),
102 Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
103 Poll::Ready(None) => None,
104 Poll::Pending => return Poll::Pending,
105 };
106
107 if let Some(item) = item {
108 match sink.as_mut().poll_ready(cx) {
109 Poll::Ready(Ok(())) => {
110 if let Err(e) = sink.as_mut().start_send(item) {
111 return Poll::Ready(Err(e));
112 }
113 }
114 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
115 Poll::Pending => {
116 this.buffer = Some(item);
117 return Poll::Pending;
118 }
119 };
120 } else {
121 break;
122 }
123 }
124 }
125 this.stream = None;
126 Poll::Ready(Ok(()))
127 }
128}
129
130impl<S, Item, U, St, F> Stream for WithFlatMap<S, Item, U, St, F>
132where
133 S: Stream + Sink<Item>,
134 F: FnMut(U) -> St,
135 St: Stream<Item = Result<Item, S::Error>>,
136{
137 type Item = S::Item;
138
139 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_next(cx)
141 }
142
143 fn size_hint(&self) -> (usize, Option<usize>) {
144 self.sink.size_hint()
145 }
146}
147
148impl<S, Item, U, St, F> FusedStream for WithFlatMap<S, Item, U, St, F>
149where
150 S: FusedStream + Sink<Item>,
151 F: FnMut(U) -> St,
152 St: Stream<Item = Result<Item, S::Error>>,
153{
154 fn is_terminated(&self) -> bool {
155 self.sink.is_terminated()
156 }
157}
158
159impl<Si, Item, U, St, F> Sink<U> for WithFlatMap<Si, Item, U, St, F>
160where
161 Si: Sink<Item>,
162 F: FnMut(U) -> St,
163 St: Stream<Item = Result<Item, Si::Error>>,
164{
165 type Error = Si::Error;
166
167 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 self.try_empty_stream(cx)
169 }
170
171 fn start_send(self: Pin<&mut Self>, item: U) -> Result<(), Self::Error> {
172 let this = unsafe { self.get_unchecked_mut() };
173
174 assert!(this.stream.is_none());
175 this.stream = Some((this.f)(item));
176 Ok(())
177 }
178
179 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180 match self.as_mut().try_empty_stream(cx) {
181 Poll::Ready(Ok(())) => {}
182 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
183 Poll::Pending => return Poll::Pending,
184 };
185 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_flush(cx)
186 }
187
188 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 match self.as_mut().try_empty_stream(cx) {
190 Poll::Ready(Ok(())) => {}
191 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
192 Poll::Pending => return Poll::Pending,
193 };
194 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_close(cx)
195 }
196}