1use core::fmt;
2use core::marker::PhantomData;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use tokio_stream::Stream;
6
7use super::Sink;
8
/// Sink adapter that expands every value sent into it into a stream of items
/// and forwards those items to an inner sink.
///
/// Each `U` passed to `start_send` is mapped by `f` into a stream `St` that
/// yields `Result<Item, Si::Error>`. Its `Ok` items are sent into `sink`, and
/// the first `Err` is reported as this adapter's error.
#[must_use = "sinks do nothing unless polled"]
pub struct WithFlatMap<Si, Item, U, St, F> {
    // Inner sink receiving the flattened items; projected as pinned whenever `Self` is pinned.
    sink: Si,
    // Turns each value passed to `start_send` into a stream of items.
    f: F,
    // Stream currently being forwarded into `sink` (`None` when idle); projected as pinned.
    stream: Option<St>,
    // One item already pulled from `stream` that `sink` was not yet ready to accept.
    buffer: Option<Item>,
    // `U` only appears as the input of `f`; `fn(U)` records it without storing a `U`.
    _marker: PhantomData<fn(U)>,
}
18
19impl<Si: Unpin, Item, U, St: Unpin, F> Unpin for WithFlatMap<Si, Item, U, St, F> {}
20
21impl<Si, Item, U, St, F> fmt::Debug for WithFlatMap<Si, Item, U, St, F>
22where
23 Si: fmt::Debug,
24 St: fmt::Debug,
25 Item: fmt::Debug,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("WithFlatMap")
29 .field("sink", &self.sink)
30 .field("stream", &self.stream)
31 .field("buffer", &self.buffer)
32 .finish()
33 }
34}
35
36impl<Si, Item, U, St, F> WithFlatMap<Si, Item, U, St, F>
37where
38 Si: Sink<Item>,
39 F: FnMut(U) -> St,
40 St: Stream<Item = Result<Item, Si::Error>>,
41{
42 pub(super) fn new(sink: Si, f: F) -> Self {
43 Self {
44 sink,
45 f,
46 stream: None,
47 buffer: None,
48 _marker: PhantomData,
49 }
50 }
51
52 pub fn get_ref(&self) -> &Si {
54 &self.sink
55 }
56
57 pub fn get_mut(&mut self) -> &mut Si {
62 &mut self.sink
63 }
64
65 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
70 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
71 }
72
73 pub fn into_inner(self) -> Si {
78 self.sink
79 }
80
81 fn try_empty_stream(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>> {
82 let this = unsafe { self.get_unchecked_mut() };
83 let mut sink = unsafe { Pin::new_unchecked(&mut this.sink) };
84
85 if this.buffer.is_some() {
86 match sink.as_mut().poll_ready(cx) {
87 Poll::Ready(Ok(())) => {}
88 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
89 Poll::Pending => return Poll::Pending,
90 }
91 let item = this.buffer.take().unwrap();
92 if let Err(e) = sink.as_mut().start_send(item) {
93 return Poll::Ready(Err(e));
94 }
95 }
96 let stream_pin = unsafe { Pin::new_unchecked(&mut this.stream) };
97 if let Some(mut some_stream) = stream_pin.as_pin_mut() {
98 loop {
99 let item = match some_stream.as_mut().poll_next(cx) {
100 Poll::Ready(Some(Ok(item))) => Some(item),
101 Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
102 Poll::Ready(None) => None,
103 Poll::Pending => return Poll::Pending,
104 };
105
106 if let Some(item) = item {
107 match sink.as_mut().poll_ready(cx) {
108 Poll::Ready(Ok(())) => {
109 if let Err(e) = sink.as_mut().start_send(item) {
110 return Poll::Ready(Err(e));
111 }
112 }
113 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
114 Poll::Pending => {
115 this.buffer = Some(item);
116 return Poll::Pending;
117 }
118 };
119 } else {
120 break;
121 }
122 }
123 }
124 this.stream = None;
125 Poll::Ready(Ok(()))
126 }
127}
128
129impl<S, Item, U, St, F> Stream for WithFlatMap<S, Item, U, St, F>
131where
132 S: Stream + Sink<Item>,
133 F: FnMut(U) -> St,
134 St: Stream<Item = Result<Item, S::Error>>,
135{
136 type Item = S::Item;
137
138 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_next(cx)
140 }
141
142 fn size_hint(&self) -> (usize, Option<usize>) {
143 self.sink.size_hint()
144 }
145}
146
147impl<Si, Item, U, St, F> Sink<U> for WithFlatMap<Si, Item, U, St, F>
148where
149 Si: Sink<Item>,
150 F: FnMut(U) -> St,
151 St: Stream<Item = Result<Item, Si::Error>>,
152{
153 type Error = Si::Error;
154
155 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156 self.try_empty_stream(cx)
157 }
158
159 fn start_send(self: Pin<&mut Self>, item: U) -> Result<(), Self::Error> {
160 let this = unsafe { self.get_unchecked_mut() };
161
162 assert!(this.stream.is_none());
163 this.stream = Some((this.f)(item));
164 Ok(())
165 }
166
167 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 match self.as_mut().try_empty_stream(cx) {
169 Poll::Ready(Ok(())) => {}
170 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
171 Poll::Pending => return Poll::Pending,
172 };
173 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_flush(cx)
174 }
175
176 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 match self.as_mut().try_empty_stream(cx) {
178 Poll::Ready(Ok(())) => {}
179 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
180 Poll::Pending => return Poll::Pending,
181 };
182 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_close(cx)
183 }
184}