async_sink/ext/
with_flat_map.rs

1use core::fmt;
2use core::marker::PhantomData;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use tokio_stream::Stream;
6
7use super::Sink;
8
9/// Sink for the [`with_flat_map`](super::SinkExt::with_flat_map) method.
10#[must_use = "sinks do nothing unless polled"]
11pub struct WithFlatMap<Si, Item, U, St, F> {
12    sink: Si,
13    f: F,
14    stream: Option<St>,
15    buffer: Option<Item>,
16    _marker: PhantomData<fn(U)>,
17}
18
19impl<Si: Unpin, Item, U, St: Unpin, F> Unpin for WithFlatMap<Si, Item, U, St, F> {}
20
21impl<Si, Item, U, St, F> fmt::Debug for WithFlatMap<Si, Item, U, St, F>
22where
23    Si: fmt::Debug,
24    St: fmt::Debug,
25    Item: fmt::Debug,
26{
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        f.debug_struct("WithFlatMap")
29            .field("sink", &self.sink)
30            .field("stream", &self.stream)
31            .field("buffer", &self.buffer)
32            .finish()
33    }
34}
35
36impl<Si, Item, U, St, F> WithFlatMap<Si, Item, U, St, F>
37where
38    Si: Sink<Item>,
39    F: FnMut(U) -> St,
40    St: Stream<Item = Result<Item, Si::Error>>,
41{
42    pub(super) fn new(sink: Si, f: F) -> Self {
43        Self {
44            sink,
45            f,
46            stream: None,
47            buffer: None,
48            _marker: PhantomData,
49        }
50    }
51
52    /// Acquires a reference to the underlying sink.
53    pub fn get_ref(&self) -> &Si {
54        &self.sink
55    }
56
57    /// Acquires a mutable reference to the underlying sink.
58    ///
59    /// Note that care must be taken to avoid tampering with the state of the
60    /// sink which may otherwise confuse this combinator.
61    pub fn get_mut(&mut self) -> &mut Si {
62        &mut self.sink
63    }
64
65    /// Acquires a pinned mutable reference to the underlying sink.
66    ///
67    /// Note that care must be taken to avoid tampering with the state of the
68    /// sink which may otherwise confuse this combinator.
69    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
70        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
71    }
72
73    /// Consumes this combinator, returning the underlying sink.
74    ///
75    /// Note that this may discard intermediate state of this combinator, so
76    /// care should be taken to avoid losing resources when this is called.
77    pub fn into_inner(self) -> Si {
78        self.sink
79    }
80
81    fn try_empty_stream(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Si::Error>> {
82        let this = unsafe { self.get_unchecked_mut() };
83        let mut sink = unsafe { Pin::new_unchecked(&mut this.sink) };
84
85        if this.buffer.is_some() {
86            match sink.as_mut().poll_ready(cx) {
87                Poll::Ready(Ok(())) => {}
88                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
89                Poll::Pending => return Poll::Pending,
90            }
91            let item = this.buffer.take().unwrap();
92            if let Err(e) = sink.as_mut().start_send(item) {
93                return Poll::Ready(Err(e));
94            }
95        }
96        let stream_pin = unsafe { Pin::new_unchecked(&mut this.stream) };
97        if let Some(mut some_stream) = stream_pin.as_pin_mut() {
98            loop {
99                let item = match some_stream.as_mut().poll_next(cx) {
100                    Poll::Ready(Some(Ok(item))) => Some(item),
101                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
102                    Poll::Ready(None) => None,
103                    Poll::Pending => return Poll::Pending,
104                };
105
106                if let Some(item) = item {
107                    match sink.as_mut().poll_ready(cx) {
108                        Poll::Ready(Ok(())) => {
109                            if let Err(e) = sink.as_mut().start_send(item) {
110                                return Poll::Ready(Err(e));
111                            }
112                        }
113                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
114                        Poll::Pending => {
115                            this.buffer = Some(item);
116                            return Poll::Pending;
117                        }
118                    };
119                } else {
120                    break;
121                }
122            }
123        }
124        this.stream = None;
125        Poll::Ready(Ok(()))
126    }
127}
128
129// Forwarding impl of Stream from the underlying sink
130impl<S, Item, U, St, F> Stream for WithFlatMap<S, Item, U, St, F>
131where
132    S: Stream + Sink<Item>,
133    F: FnMut(U) -> St,
134    St: Stream<Item = Result<Item, S::Error>>,
135{
136    type Item = S::Item;
137
138    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_next(cx)
140    }
141
142    fn size_hint(&self) -> (usize, Option<usize>) {
143        self.sink.size_hint()
144    }
145}
146
147impl<Si, Item, U, St, F> Sink<U> for WithFlatMap<Si, Item, U, St, F>
148where
149    Si: Sink<Item>,
150    F: FnMut(U) -> St,
151    St: Stream<Item = Result<Item, Si::Error>>,
152{
153    type Error = Si::Error;
154
155    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        self.try_empty_stream(cx)
157    }
158
159    fn start_send(self: Pin<&mut Self>, item: U) -> Result<(), Self::Error> {
160        let this = unsafe { self.get_unchecked_mut() };
161
162        assert!(this.stream.is_none());
163        this.stream = Some((this.f)(item));
164        Ok(())
165    }
166
167    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168        match self.as_mut().try_empty_stream(cx) {
169            Poll::Ready(Ok(())) => {}
170            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
171            Poll::Pending => return Poll::Pending,
172        };
173        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_flush(cx)
174    }
175
176    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        match self.as_mut().try_empty_stream(cx) {
178            Poll::Ready(Ok(())) => {}
179            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
180            Poll::Pending => return Poll::Pending,
181        };
182        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }.poll_close(cx)
183    }
184}