multipart_write/io/
multi_async_writer.rs1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Initial capacity, in bytes, of the buffer allocated by
/// `MultiAsyncWriter::new` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
9
10pub fn async_writer<W: AsyncWrite + Unpin + Default>(
12 write: W,
13) -> MultiAsyncWriter<W> {
14 MultiAsyncWriter::new(write)
15}
16
17pin_project_lite::pin_project! {
18 #[derive(Debug, Default)]
20 pub struct MultiAsyncWriter<W: AsyncWrite> {
21 #[pin]
22 inner: W,
23 buf: Vec<u8>,
24 written: usize,
25 }
26}
27
28impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
29 pub(super) fn new(inner: W) -> Self {
30 Self { inner, buf: Vec::with_capacity(DEFAULT_BUF_SIZE), written: 0 }
31 }
32
33 fn flush_buf(
34 self: Pin<&mut Self>,
35 cx: &mut Context<'_>,
36 ) -> Poll<std::io::Result<()>> {
37 let mut this = self.project();
38
39 let len = this.buf.len();
40 let mut ret = Ok(());
41 while *this.written < len {
42 match task::ready!(
43 this.inner.as_mut().poll_write(cx, &this.buf[*this.written..])
44 ) {
45 Ok(0) => {
46 ret = Err(std::io::Error::new(
47 std::io::ErrorKind::WriteZero,
48 "failed to write buffered data",
49 ));
50 break;
51 },
52 Ok(n) => *this.written += n,
53 Err(e) => {
54 ret = Err(e);
55 break;
56 },
57 }
58 }
59 if *this.written > 0 {
60 this.buf.drain(..*this.written);
61 }
62 *this.written = 0;
63
64 Poll::Ready(ret)
65 }
66}
67
68impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]>
69 for MultiAsyncWriter<W>
70{
71 type Error = std::io::Error;
72 type Output = W;
73 type Recv = usize;
74
75 fn poll_ready(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 ) -> Poll<Result<(), Self::Error>> {
79 self.flush_buf(cx)
80 }
81
82 fn start_send(
83 self: Pin<&mut Self>,
84 part: &[u8],
85 ) -> Result<Self::Recv, Self::Error> {
86 self.project().buf.extend_from_slice(part);
87 Ok(part.len())
88 }
89
90 fn poll_flush(
91 self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 ) -> Poll<Result<(), Self::Error>> {
94 self.project().inner.poll_flush(cx)
95 }
96
97 fn poll_complete(
98 mut self: Pin<&mut Self>,
99 _cx: &mut Context<'_>,
100 ) -> Poll<Result<Self::Output, Self::Error>> {
101 Poll::Ready(Ok(std::mem::take(&mut self.inner)))
102 }
103}