async_compression/tokio/write/
buf_writer.rs1use super::AsyncBufWrite;
6use futures_core::ready;
7use pin_project_lite::pin_project;
8use std::{
9 cmp::min,
10 fmt, io,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tokio::io::AsyncWrite;
15
/// Capacity, in bytes, of the buffer allocated by `BufWriter::new` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
17
18pin_project! {
19 pub struct BufWriter<W> {
20 #[pin]
21 inner: W,
22 buf: Box<[u8]>,
23 written: usize,
24 buffered: usize,
25 }
26}
27
28impl<W: AsyncWrite> BufWriter<W> {
29 pub fn new(inner: W) -> Self {
32 Self::with_capacity(DEFAULT_BUF_SIZE, inner)
33 }
34
35 pub fn with_capacity(cap: usize, inner: W) -> Self {
37 Self {
38 inner,
39 buf: vec![0; cap].into(),
40 written: 0,
41 buffered: 0,
42 }
43 }
44
45 fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46 let mut this = self.project();
47
48 let mut ret = Ok(());
49 while *this.written < *this.buffered {
50 match this
51 .inner
52 .as_mut()
53 .poll_write(cx, &this.buf[*this.written..*this.buffered])
54 {
55 Poll::Pending => {
56 break;
57 }
58 Poll::Ready(Ok(0)) => {
59 ret = Err(io::Error::new(
60 io::ErrorKind::WriteZero,
61 "failed to write the buffered data",
62 ));
63 break;
64 }
65 Poll::Ready(Ok(n)) => *this.written += n,
66 Poll::Ready(Err(e)) => {
67 ret = Err(e);
68 break;
69 }
70 }
71 }
72
73 if *this.written > 0 {
74 this.buf.copy_within(*this.written..*this.buffered, 0);
75 *this.buffered -= *this.written;
76 *this.written = 0;
77
78 Poll::Ready(ret)
79 } else if *this.buffered == 0 {
80 Poll::Ready(ret)
81 } else {
82 ret?;
83 Poll::Pending
84 }
85 }
86
87 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88 let mut this = self.project();
89
90 let mut ret = Ok(());
91 while *this.written < *this.buffered {
92 match ready!(this
93 .inner
94 .as_mut()
95 .poll_write(cx, &this.buf[*this.written..*this.buffered]))
96 {
97 Ok(0) => {
98 ret = Err(io::Error::new(
99 io::ErrorKind::WriteZero,
100 "failed to write the buffered data",
101 ));
102 break;
103 }
104 Ok(n) => *this.written += n,
105 Err(e) => {
106 ret = Err(e);
107 break;
108 }
109 }
110 }
111 this.buf.copy_within(*this.written..*this.buffered, 0);
112 *this.buffered -= *this.written;
113 *this.written = 0;
114 Poll::Ready(ret)
115 }
116}
117
118impl<W> BufWriter<W> {
119 pub fn get_ref(&self) -> &W {
121 &self.inner
122 }
123
124 pub fn get_mut(&mut self) -> &mut W {
128 &mut self.inner
129 }
130
131 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
135 self.project().inner
136 }
137
138 pub fn into_inner(self) -> W {
142 self.inner
143 }
144}
145
146impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
147 fn poll_write(
148 mut self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 buf: &[u8],
151 ) -> Poll<io::Result<usize>> {
152 let this = self.as_mut().project();
153 if *this.buffered + buf.len() > this.buf.len() {
154 ready!(self.as_mut().partial_flush_buf(cx))?;
155 }
156
157 let this = self.as_mut().project();
158 if buf.len() >= this.buf.len() {
159 if *this.buffered == 0 {
160 this.inner.poll_write(cx, buf)
161 } else {
162 Poll::Pending
165 }
166 } else {
167 let len = min(this.buf.len() - *this.buffered, buf.len());
168 this.buf[*this.buffered..*this.buffered + len].copy_from_slice(&buf[..len]);
169 *this.buffered += len;
170 Poll::Ready(Ok(len))
171 }
172 }
173
174 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175 ready!(self.as_mut().flush_buf(cx))?;
176 self.project().inner.poll_flush(cx)
177 }
178
179 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180 ready!(self.as_mut().flush_buf(cx))?;
181 self.project().inner.poll_shutdown(cx)
182 }
183}
184
185impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
186 fn poll_partial_flush_buf(
187 mut self: Pin<&mut Self>,
188 cx: &mut Context<'_>,
189 ) -> Poll<io::Result<&mut [u8]>> {
190 ready!(self.as_mut().partial_flush_buf(cx))?;
191 let this = self.project();
192 Poll::Ready(Ok(&mut this.buf[*this.buffered..]))
193 }
194
195 fn produce(self: Pin<&mut Self>, amt: usize) {
196 let this = self.project();
197 debug_assert!(
198 *this.buffered + amt <= this.buf.len(),
199 "produce called with amt exceeding buffer capacity"
200 );
201 *this.buffered += amt;
202 }
203}
204
205impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 f.debug_struct("BufWriter")
208 .field("writer", &self.inner)
209 .field(
210 "buffer",
211 &format_args!("{}/{}", self.buffered, self.buf.len()),
212 )
213 .field("written", &self.written)
214 .finish()
215 }
216}