multipart_write/io/multi_async_writer.rs

1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Initial capacity, in bytes, reserved for the internal buffer of a writer
/// built by `MultiAsyncWriter::new`: 8 KiB, the same default the standard
/// library uses for its I/O buffers
/// (<https://github.com/rust-lang/rust/blob/ff6dc928c5e33ce8e65c6911a790b9efcb5ef53a/library/std/src/sys/io/mod.rs#L54>).
const DEFAULT_BUF_SIZE: usize = 8 << 10;
9
10/// Constructs a `MultipartWrite` from a `tokio::io::AsyncWrite`.
11pub fn async_writer<W: AsyncWrite + Unpin + Default>(write: W) -> MultiAsyncWriter<W> {
12    MultiAsyncWriter::new(write)
13}
14
pin_project_lite::pin_project! {
    /// The writer returned by [`async_writer`](self::async_writer).
    ///
    /// Parts passed to `start_send` are appended to an internal buffer rather
    /// than written immediately; buffered bytes are written to the wrapped
    /// writer when it is next polled for readiness.
    #[derive(Debug, Default)]
    pub struct MultiAsyncWriter<W: AsyncWrite> {
        // The wrapped writer. Structurally pinned, so methods reach it as
        // `Pin<&mut W>` through `project()`.
        #[pin]
        inner: W,
        // Bytes accepted by `start_send` that `flush_buf` has not yet removed.
        // A `Default`-constructed writer starts with an empty, unallocated
        // buffer, unlike one built by `new`.
        buf: Vec<u8>,
        // How many bytes at the front of `buf` an in-progress `flush_buf` has
        // already written to `inner`; kept across `Poll::Pending` and reset to
        // zero once `flush_buf` returns `Poll::Ready`.
        written: usize,
    }
}
25
26impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
27    pub(super) fn new(inner: W) -> Self {
28        Self {
29            inner,
30            buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
31            written: 0,
32        }
33    }
34
35    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
36        let mut this = self.project();
37
38        let len = this.buf.len();
39        let mut ret = Ok(());
40        while *this.written < len {
41            match task::ready!(
42                this.inner
43                    .as_mut()
44                    .poll_write(cx, &this.buf[*this.written..])
45            ) {
46                Ok(0) => {
47                    ret = Err(std::io::Error::new(
48                        std::io::ErrorKind::WriteZero,
49                        "failed to write buffered data",
50                    ));
51                    break;
52                }
53                Ok(n) => *this.written += n,
54                Err(e) => {
55                    ret = Err(e);
56                    break;
57                }
58            }
59        }
60        if *this.written > 0 {
61            this.buf.drain(..*this.written);
62        }
63        *this.written = 0;
64
65        Poll::Ready(ret)
66    }
67}
68
69impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]> for MultiAsyncWriter<W> {
70    type Ret = usize;
71    type Output = W;
72    type Error = std::io::Error;
73
74    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        self.flush_buf(cx)
76    }
77
78    fn start_send(self: Pin<&mut Self>, part: &[u8]) -> Result<Self::Ret, Self::Error> {
79        self.project().buf.extend_from_slice(part);
80        Ok(part.len())
81    }
82
83    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        self.project().inner.poll_flush(cx)
85    }
86
87    fn poll_complete(
88        mut self: Pin<&mut Self>,
89        _cx: &mut Context<'_>,
90    ) -> Poll<Result<Self::Output, Self::Error>> {
91        Poll::Ready(Ok(std::mem::take(&mut self.inner)))
92    }
93}