// async_sink/ext/buffer.rs

1use super::Sink;
2use alloc::collections::VecDeque;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use tokio_stream::Stream;
6
/// Sink for the [`buffer`](super::SinkExt::buffer) method.
///
/// Items handed to this sink are held in an in-memory queue and forwarded to
/// the wrapped sink later, when the queue is drained.
#[derive(Debug)]
#[must_use = "sinks do nothing unless polled"]
pub struct Buffer<Si, Item> {
    /// The wrapped sink that queued items are eventually sent to.
    sink: Si,
    /// Items that have been accepted but not yet passed on to `sink`.
    buf: VecDeque<Item>,

    /// The capacity requested at construction time.
    ///
    /// Kept as its own field because the `VecDeque` may round its allocation
    /// up, so the deque's own capacity is not a reliable limit. A value of
    /// zero means items bypass the queue and go straight to `sink`.
    capacity: usize,
}
17
18impl<Si: Unpin, Item> Unpin for Buffer<Si, Item> {}
19
20impl<Si: Sink<Item>, Item> Buffer<Si, Item> {
21    pub(super) fn new(sink: Si, capacity: usize) -> Self {
22        Self {
23            sink,
24            buf: VecDeque::with_capacity(capacity),
25            capacity,
26        }
27    }
28
29    /// Acquires a reference to the underlying sink.
30    pub fn get_ref(&self) -> &Si {
31        &self.sink
32    }
33
34    /// Acquires a mutable reference to the underlying sink.
35    ///
36    /// Note that care must be taken to avoid tampering with the state of the
37    /// sink which may otherwise confuse this combinator.
38    pub fn get_mut(&mut self) -> &mut Si {
39        &mut self.sink
40    }
41
42    /// Acquires a pinned mutable reference to the underlying sink.
43    ///
44    /// Note that care must be taken to avoid tampering with the state of the
45    /// sink which may otherwise confuse this combinator.
46    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
47        unsafe { self.map_unchecked_mut(|this| &mut this.sink) }
48    }
49
50    /// Consumes this combinator, returning the underlying sink.
51    ///
52    /// Note that this may discard intermediate state of this combinator, so
53    /// care should be taken to avoid losing resources when this is called.
54    pub fn into_inner(self) -> Si {
55        self.sink
56    }
57
58    fn try_empty_buffer(
59        mut self: Pin<&mut Self>,
60        cx: &mut Context<'_>,
61    ) -> Poll<Result<(), Si::Error>> {
62        let this = unsafe { self.as_mut().get_unchecked_mut() };
63        let mut sink = unsafe { Pin::new_unchecked(&mut this.sink) };
64
65        match sink.as_mut().poll_ready(cx) {
66            Poll::Ready(Ok(())) => {}
67            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
68            Poll::Pending => return Poll::Pending,
69        }
70
71        while let Some(item) = this.buf.pop_front() {
72            if let Err(e) = sink.as_mut().start_send(item) {
73                return Poll::Ready(Err(e));
74            }
75
76            if !this.buf.is_empty() {
77                match sink.as_mut().poll_ready(cx) {
78                    Poll::Ready(Ok(())) => {}
79                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
80                    Poll::Pending => return Poll::Pending,
81                }
82            }
83        }
84        Poll::Ready(Ok(()))
85    }
86}
87
88// Forwarding impl of Stream from the underlying sink
89impl<S, Item> Stream for Buffer<S, Item>
90where
91    S: Sink<Item> + Stream,
92{
93    type Item = S::Item;
94
95    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
96        unsafe { self.map_unchecked_mut(|this| &mut this.sink) }.poll_next(cx)
97    }
98
99    fn size_hint(&self) -> (usize, Option<usize>) {
100        self.sink.size_hint()
101    }
102}
103
104impl<Si: Sink<Item>, Item> Sink<Item> for Buffer<Si, Item> {
105    type Error = Si::Error;
106
107    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
108        let this = unsafe { self.as_mut().get_unchecked_mut() };
109        if this.capacity == 0 {
110            return unsafe { Pin::new_unchecked(&mut this.sink) }.poll_ready(cx);
111        }
112
113        if this.buf.len() >= this.capacity {
114            match self.try_empty_buffer(cx) {
115                Poll::Ready(Ok(())) => {}
116                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
117                Poll::Pending => return Poll::Pending,
118            }
119        }
120
121        Poll::Ready(Ok(()))
122    }
123
124    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
125        let this = unsafe { self.get_unchecked_mut() };
126        if this.capacity == 0 {
127            unsafe { Pin::new_unchecked(&mut this.sink) }.start_send(item)
128        } else {
129            this.buf.push_back(item);
130            Ok(())
131        }
132    }
133
134    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135        match self.as_mut().try_empty_buffer(cx) {
136            Poll::Ready(Ok(())) => (),
137            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
138            Poll::Pending => return Poll::Pending,
139        }
140        let this = unsafe { self.get_unchecked_mut() };
141        debug_assert!(this.buf.is_empty());
142        unsafe { Pin::new_unchecked(&mut this.sink) }.poll_flush(cx)
143    }
144
145    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146        match self.as_mut().try_empty_buffer(cx) {
147            Poll::Ready(Ok(())) => (),
148            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
149            Poll::Pending => return Poll::Pending,
150        }
151        let this = unsafe { self.get_unchecked_mut() };
152        debug_assert!(this.buf.is_empty());
153        unsafe { Pin::new_unchecked(&mut this.sink) }.poll_close(cx)
154    }
155}