// multipart_write/io/multi_async_writer.rs

1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Initial capacity, in bytes, of the internal buffer of a newly created writer.
///
/// Mirrors the default buffer size used by the standard library:
/// <https://github.com/rust-lang/rust/blob/ff6dc928c5e33ce8e65c6911a790b9efcb5ef53a/library/std/src/sys/io/mod.rs#L54>
const DEFAULT_BUF_SIZE: usize = 8192;
9
10/// Converts an [`AsyncWrite`] into a [`MultipartWrite`].
11///
12/// [`AsyncWrite`]: tokio::io::AsyncWrite
13/// [`MultipartWrite`]: crate::MultipartWrite
14pub fn async_write<W: AsyncWrite + Unpin + Default>(write: W) -> MultiAsyncWriter<W> {
15    MultiAsyncWriter::new(write)
16}
17
pin_project_lite::pin_project! {
    /// `MultiAsyncWriter` implements [`MultipartWrite`] for an asynchronous
    /// [`tokio::io::AsyncWrite`](tokio::io::AsyncWrite).
    ///
    /// Parts handed to the writer are copied into an internal byte buffer and
    /// written to the wrapped writer when the writer is next polled.
    ///
    /// [`MultipartWrite`]: crate::MultipartWrite
    #[derive(Debug, Default)]
    pub struct MultiAsyncWriter<W: AsyncWrite> {
        // The wrapped writer; structurally pinned so it can be polled through
        // the projection.
        #[pin]
        inner: W,
        // Bytes accepted from callers that have not yet been removed after a
        // flush to `inner`.
        buf: Vec<u8>,
        // Count of bytes at the front of `buf` already written to `inner` during
        // a flush that has not finished yet.
        written: usize,
    }
}
31
32impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
33    pub(super) fn new(inner: W) -> Self {
34        Self {
35            inner,
36            buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
37            written: 0,
38        }
39    }
40
41    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
42        let mut this = self.project();
43
44        let len = this.buf.len();
45        let mut ret = Ok(());
46        while *this.written < len {
47            match task::ready!(
48                this.inner
49                    .as_mut()
50                    .poll_write(cx, &this.buf[*this.written..])
51            ) {
52                Ok(0) => {
53                    ret = Err(std::io::Error::new(
54                        std::io::ErrorKind::WriteZero,
55                        "failed to write buffered data",
56                    ));
57                    break;
58                }
59                Ok(n) => *this.written += n,
60                Err(e) => {
61                    ret = Err(e);
62                    break;
63                }
64            }
65        }
66        if *this.written > 0 {
67            this.buf.drain(..*this.written);
68        }
69        *this.written = 0;
70
71        Poll::Ready(ret)
72    }
73}
74
75impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]> for MultiAsyncWriter<W> {
76    type Ret = usize;
77    type Output = W;
78    type Error = std::io::Error;
79
80    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
81        self.flush_buf(cx)
82    }
83
84    fn start_send(self: Pin<&mut Self>, part: &[u8]) -> Result<Self::Ret, Self::Error> {
85        self.project().buf.extend_from_slice(part);
86        Ok(part.len())
87    }
88
89    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        self.project().inner.poll_flush(cx)
91    }
92
93    fn poll_complete(
94        mut self: Pin<&mut Self>,
95        _cx: &mut Context<'_>,
96    ) -> Poll<Result<Self::Output, Self::Error>> {
97        Poll::Ready(Ok(std::mem::take(&mut self.inner)))
98    }
99}