// async_sink/ext/send_all.rs

1use super::Sink;
2use core::{
3    fmt,
4    future::Future,
5    ops::DerefMut,
6    pin::Pin,
7    task::{Context, Poll},
8};
9use tokio_stream::Stream;
10
11/// Future for the [`send_all`](super::SinkExt::send_all) method.
12#[must_use = "futures do nothing unless you `.await` or poll them"]
13pub struct SendAll<'a, Si, Item, St>
14where
15    Si: ?Sized + Sink<Item>,
16    St: Stream<Item = Result<Item, Si::Error>> + ?Sized,
17{
18    sink: &'a mut Si,
19    stream: &'a mut St,
20    buffered: Option<Item>,
21    stream_done: bool,
22}
23
24impl<Si, Item, St> fmt::Debug for SendAll<'_, Si, Item, St>
25where
26    Si: fmt::Debug + ?Sized + Sink<Item>,
27    Si::Error: core::error::Error,
28    Item: fmt::Debug,
29    St: fmt::Debug + Stream<Item = Result<Item, Si::Error>> + ?Sized,
30{
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("SendAll")
33            .field("sink", &self.sink)
34            .field("stream", &self.stream)
35            .field("buffered", &self.buffered)
36            .field("stream_done", &self.stream_done)
37            .finish()
38    }
39}
40
41impl<'a, Si, Item, St> SendAll<'a, Si, Item, St>
42where
43    Si: Sink<Item> + Unpin + ?Sized,
44    Si::Error: core::error::Error,
45    St: Stream<Item = Result<Item, Si::Error>> + Unpin + ?Sized,
46{
47    pub(super) fn new(sink: &'a mut Si, stream: &'a mut St) -> Self {
48        Self {
49            sink,
50            stream,
51            buffered: None,
52            stream_done: false,
53        }
54    }
55
56    fn try_start_send(
57        self: Pin<&mut Self>,
58        cx: &mut Context<'_>,
59        item: Item,
60    ) -> Poll<Result<(), Si::Error>> {
61        let this = unsafe { Pin::get_unchecked_mut(self) };
62        debug_assert!(this.buffered.is_none());
63        match Pin::new(&mut *this.sink).poll_ready(cx) {
64            Poll::Ready(Ok(())) => Poll::Ready(Pin::new(&mut *this.sink).start_send(item)),
65            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
66            Poll::Pending => {
67                this.buffered = Some(item);
68                Poll::Pending
69            }
70        }
71    }
72}
73
74impl<'a, Si, Item, St> Future for SendAll<'a, Si, Item, St>
75where
76    Si: Sink<Item> + Unpin + ?Sized,
77    St: Stream<Item = Result<Item, Si::Error>> + Unpin + ?Sized,
78{
79    type Output = Result<(), Si::Error>;
80
81    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
82        if let Some(item) = unsafe { self.as_mut().get_unchecked_mut() }.buffered.take() {
83            match self.as_mut().try_start_send(cx, item) {
84                Poll::Ready(Ok(())) => {}
85                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
86                Poll::Pending => return Poll::Pending,
87            }
88        }
89
90        loop {
91            let this = unsafe { self.as_mut().get_unchecked_mut() };
92            if this.stream_done {
93                return Pin::new(&mut *this.sink).poll_flush(cx);
94            }
95
96            match <St as Stream>::poll_next(Pin::new(this.stream.deref_mut()), cx) {
97                Poll::Ready(Some(Ok(item))) => match self.as_mut().try_start_send(cx, item) {
98                    Poll::Ready(Ok(())) => continue,
99                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
100                    Poll::Pending => return Poll::Pending,
101                },
102                Poll::Ready(Some(Err(e))) => {
103                    unsafe { self.as_mut().get_unchecked_mut() }.stream_done = true;
104                    return Poll::Ready(Err(e));
105                }
106                Poll::Ready(None) => {
107                    unsafe { self.as_mut().get_unchecked_mut() }.stream_done = true;
108                }
109                Poll::Pending => {
110                    let this = unsafe { self.as_mut().get_unchecked_mut() };
111                    return match Pin::new(&mut *this.sink).poll_flush(cx) {
112                        Poll::Ready(Ok(())) => Poll::Pending,
113                        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
114                        Poll::Pending => Poll::Pending,
115                    };
116                }
117            }
118        }
119    }
120}