1use super::Sink;
2use alloc::collections::VecDeque;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use tokio_stream::Stream;
6
/// Sink combinator that queues up to `capacity` items in memory in front of
/// a wrapped sink.
///
/// Queued items are pushed into the wrapped sink when the queue is full (from
/// `poll_ready`) and when the combinator is flushed or closed. A `capacity` of
/// zero turns the combinator into a pass-through to the wrapped sink.
#[derive(Debug)]
#[must_use = "sinks do nothing unless polled"]
pub struct Buffer<Si, Item> {
    // The wrapped sink. Treated as structurally pinned: it is only ever
    // accessed as `Pin<&mut Si>` while the `Buffer` itself is pinned.
    sink: Si,
    // Items accepted by `start_send` but not yet handed to `sink`, oldest first.
    buf: VecDeque<Item>,

    // Maximum number of items to queue. Tracked separately from `buf` because
    // `VecDeque::with_capacity` may reserve more room than requested, so
    // `buf.capacity()` is not a reliable limit.
    capacity: usize,
}
17
// Only `sink` is ever pin-projected; `buf` and `capacity` are reached through
// plain `&mut` access. Moving a `Buffer` is therefore fine whenever `Si` can be
// moved, and `Item` is deliberately left unconstrained.
impl<Si: Unpin, Item> Unpin for Buffer<Si, Item> {}
19
20impl<Si: Sink<Item>, Item> Buffer<Si, Item> {
21 pub(super) fn new(sink: Si, capacity: usize) -> Self {
22 Self {
23 sink,
24 buf: VecDeque::with_capacity(capacity),
25 capacity,
26 }
27 }
28
29 pub fn get_ref(&self) -> &Si {
31 &self.sink
32 }
33
34 pub fn get_mut(&mut self) -> &mut Si {
39 &mut self.sink
40 }
41
42 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
47 unsafe { self.map_unchecked_mut(|this| &mut this.sink) }
48 }
49
50 pub fn into_inner(self) -> Si {
55 self.sink
56 }
57
58 fn try_empty_buffer(
59 mut self: Pin<&mut Self>,
60 cx: &mut Context<'_>,
61 ) -> Poll<Result<(), Si::Error>> {
62 let this = unsafe { self.as_mut().get_unchecked_mut() };
63 let mut sink = unsafe { Pin::new_unchecked(&mut this.sink) };
64
65 match sink.as_mut().poll_ready(cx) {
66 Poll::Ready(Ok(())) => {}
67 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
68 Poll::Pending => return Poll::Pending,
69 }
70
71 while let Some(item) = this.buf.pop_front() {
72 if let Err(e) = sink.as_mut().start_send(item) {
73 return Poll::Ready(Err(e));
74 }
75
76 if !this.buf.is_empty() {
77 match sink.as_mut().poll_ready(cx) {
78 Poll::Ready(Ok(())) => {}
79 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
80 Poll::Pending => return Poll::Pending,
81 }
82 }
83 }
84 Poll::Ready(Ok(()))
85 }
86}
87
88impl<S, Item> Stream for Buffer<S, Item>
90where
91 S: Sink<Item> + Stream,
92{
93 type Item = S::Item;
94
95 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
96 unsafe { self.map_unchecked_mut(|this| &mut this.sink) }.poll_next(cx)
97 }
98
99 fn size_hint(&self) -> (usize, Option<usize>) {
100 self.sink.size_hint()
101 }
102}
103
104impl<Si: Sink<Item>, Item> Sink<Item> for Buffer<Si, Item> {
105 type Error = Si::Error;
106
107 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
108 let this = unsafe { self.as_mut().get_unchecked_mut() };
109 if this.capacity == 0 {
110 return unsafe { Pin::new_unchecked(&mut this.sink) }.poll_ready(cx);
111 }
112
113 if this.buf.len() >= this.capacity {
114 match self.try_empty_buffer(cx) {
115 Poll::Ready(Ok(())) => {}
116 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
117 Poll::Pending => return Poll::Pending,
118 }
119 }
120
121 Poll::Ready(Ok(()))
122 }
123
124 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
125 let this = unsafe { self.get_unchecked_mut() };
126 if this.capacity == 0 {
127 unsafe { Pin::new_unchecked(&mut this.sink) }.start_send(item)
128 } else {
129 this.buf.push_back(item);
130 Ok(())
131 }
132 }
133
134 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135 match self.as_mut().try_empty_buffer(cx) {
136 Poll::Ready(Ok(())) => (),
137 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
138 Poll::Pending => return Poll::Pending,
139 }
140 let this = unsafe { self.get_unchecked_mut() };
141 debug_assert!(this.buf.is_empty());
142 unsafe { Pin::new_unchecked(&mut this.sink) }.poll_flush(cx)
143 }
144
145 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146 match self.as_mut().try_empty_buffer(cx) {
147 Poll::Ready(Ok(())) => (),
148 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
149 Poll::Pending => return Poll::Pending,
150 }
151 let this = unsafe { self.get_unchecked_mut() };
152 debug_assert!(this.buf.is_empty());
153 unsafe { Pin::new_unchecked(&mut this.sink) }.poll_close(cx)
154 }
155}