async_sink/ext/
buffer.rs

1use super::Sink;
2use alloc::collections::VecDeque;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use tokio_stream::Stream;
6use tokio_stream_util::FusedStream;
7
/// Sink for the [`buffer`](super::SinkExt::buffer) method.
///
/// Wraps another sink and queues items in an in-memory [`VecDeque`] before
/// forwarding them to the wrapped sink. A capacity of `0` disables the queue:
/// items are then handed straight to the wrapped sink.
#[derive(Debug)]
#[must_use = "sinks do nothing unless polled"]
pub struct Buffer<Si, Item> {
    /// The wrapped sink that queued items are eventually forwarded to.
    sink: Si,
    /// Items accepted by `start_send` that have not been forwarded yet.
    buf: VecDeque<Item>,

    /// The capacity requested by the caller. Stored on its own because the
    /// `VecDeque` may round its allocation up, so `buf.capacity()` is not a
    /// reliable limit.
    capacity: usize,
}
18
19impl<Si: Unpin, Item> Unpin for Buffer<Si, Item> {}
20
21impl<Si: Sink<Item>, Item> Buffer<Si, Item> {
22    pub(super) fn new(sink: Si, capacity: usize) -> Self {
23        Self {
24            sink,
25            buf: VecDeque::with_capacity(capacity),
26            capacity,
27        }
28    }
29
30    /// Acquires a reference to the underlying sink.
31    pub fn get_ref(&self) -> &Si {
32        &self.sink
33    }
34
35    /// Acquires a mutable reference to the underlying sink.
36    ///
37    /// Note that care must be taken to avoid tampering with the state of the
38    /// sink which may otherwise confuse this combinator.
39    pub fn get_mut(&mut self) -> &mut Si {
40        &mut self.sink
41    }
42
43    /// Acquires a pinned mutable reference to the underlying sink.
44    ///
45    /// Note that care must be taken to avoid tampering with the state of the
46    /// sink which may otherwise confuse this combinator.
47    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
48        unsafe { self.map_unchecked_mut(|this| &mut this.sink) }
49    }
50
51    /// Consumes this combinator, returning the underlying sink.
52    ///
53    /// Note that this may discard intermediate state of this combinator, so
54    /// care should be taken to avoid losing resources when this is called.
55    pub fn into_inner(self) -> Si {
56        self.sink
57    }
58
59    fn try_empty_buffer(
60        mut self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62    ) -> Poll<Result<(), Si::Error>> {
63        let this = unsafe { self.as_mut().get_unchecked_mut() };
64        let mut sink = unsafe { Pin::new_unchecked(&mut this.sink) };
65
66        match sink.as_mut().poll_ready(cx) {
67            Poll::Ready(Ok(())) => {}
68            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
69            Poll::Pending => return Poll::Pending,
70        }
71
72        while let Some(item) = this.buf.pop_front() {
73            if let Err(e) = sink.as_mut().start_send(item) {
74                return Poll::Ready(Err(e));
75            }
76
77            if !this.buf.is_empty() {
78                match sink.as_mut().poll_ready(cx) {
79                    Poll::Ready(Ok(())) => {}
80                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
81                    Poll::Pending => return Poll::Pending,
82                }
83            }
84        }
85        Poll::Ready(Ok(()))
86    }
87}
88
89// Forwarding impl of Stream from the underlying sink
90impl<S, Item> Stream for Buffer<S, Item>
91where
92    S: Sink<Item> + Stream,
93{
94    type Item = S::Item;
95
96    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
97        unsafe { self.map_unchecked_mut(|this| &mut this.sink) }.poll_next(cx)
98    }
99
100    fn size_hint(&self) -> (usize, Option<usize>) {
101        self.sink.size_hint()
102    }
103}
104
105impl<S, Item> FusedStream for Buffer<S, Item>
106where
107    S: Sink<Item> + FusedStream,
108{
109    fn is_terminated(&self) -> bool {
110        self.sink.is_terminated()
111    }
112}
113
114impl<Si: Sink<Item>, Item> Sink<Item> for Buffer<Si, Item> {
115    type Error = Si::Error;
116
117    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118        let this = unsafe { self.as_mut().get_unchecked_mut() };
119        if this.capacity == 0 {
120            return unsafe { Pin::new_unchecked(&mut this.sink) }.poll_ready(cx);
121        }
122
123        if this.buf.len() >= this.capacity {
124            match self.try_empty_buffer(cx) {
125                Poll::Ready(Ok(())) => {}
126                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
127                Poll::Pending => return Poll::Pending,
128            }
129        }
130
131        Poll::Ready(Ok(()))
132    }
133
134    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
135        let this = unsafe { self.get_unchecked_mut() };
136        if this.capacity == 0 {
137            unsafe { Pin::new_unchecked(&mut this.sink) }.start_send(item)
138        } else {
139            this.buf.push_back(item);
140            Ok(())
141        }
142    }
143
144    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145        match self.as_mut().try_empty_buffer(cx) {
146            Poll::Ready(Ok(())) => (),
147            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
148            Poll::Pending => return Poll::Pending,
149        }
150        let this = unsafe { self.get_unchecked_mut() };
151        debug_assert!(this.buf.is_empty());
152        unsafe { Pin::new_unchecked(&mut this.sink) }.poll_flush(cx)
153    }
154
155    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        match self.as_mut().try_empty_buffer(cx) {
157            Poll::Ready(Ok(())) => (),
158            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
159            Poll::Pending => return Poll::Pending,
160        }
161        let this = unsafe { self.get_unchecked_mut() };
162        debug_assert!(this.buf.is_empty());
163        unsafe { Pin::new_unchecked(&mut this.sink) }.poll_close(cx)
164    }
165}