// multipart_write/io/multi_async_writer.rs

1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Initial capacity, in bytes, reserved for the staging buffer (8 KiB).
///
/// Mirrors the default buffer size used by the Rust standard library:
/// <https://github.com/rust-lang/rust/blob/ff6dc928c5e33ce8e65c6911a790b9efcb5ef53a/library/std/src/sys/io/mod.rs#L54>
const DEFAULT_BUF_SIZE: usize = 8192;
9
10/// Converts an [`AsyncWrite`] into a [`MultipartWrite`].
11///
12/// [`AsyncWrite`]: tokio::io::AsyncWrite
13/// [`MultipartWrite`]: crate::MultipartWrite
14pub fn async_write<W: AsyncWrite + Unpin + Default>(write: W) -> MultiAsyncWriter<W> {
15    MultiAsyncWriter::new(write)
16}
17
18/// `MultiAsyncWriter` implements [`MultipartWrite`] for an asynchronous
19/// [`tokio::io::AsyncWrite`](tokio::io::AsyncWrite).
20///
21/// [`MultipartWrite`]: crate::MultipartWrite
22#[derive(Debug, Default)]
23#[pin_project::pin_project]
24pub struct MultiAsyncWriter<W: AsyncWrite> {
25    #[pin]
26    inner: W,
27    buf: Vec<u8>,
28    written: usize,
29}
30
31impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
32    pub(super) fn new(inner: W) -> Self {
33        Self {
34            inner,
35            buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
36            written: 0,
37        }
38    }
39
40    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
41        let mut this = self.project();
42
43        let len = this.buf.len();
44        let mut ret = Ok(());
45        while *this.written < len {
46            match task::ready!(
47                this.inner
48                    .as_mut()
49                    .poll_write(cx, &this.buf[*this.written..])
50            ) {
51                Ok(0) => {
52                    ret = Err(std::io::Error::new(
53                        std::io::ErrorKind::WriteZero,
54                        "failed to write buffered data",
55                    ));
56                    break;
57                }
58                Ok(n) => *this.written += n,
59                Err(e) => {
60                    ret = Err(e);
61                    break;
62                }
63            }
64        }
65        if *this.written > 0 {
66            this.buf.drain(..*this.written);
67        }
68        *this.written = 0;
69
70        Poll::Ready(ret)
71    }
72}
73
74impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]> for MultiAsyncWriter<W> {
75    type Ret = usize;
76    type Output = W;
77    type Error = std::io::Error;
78
79    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80        self.flush_buf(cx)
81    }
82
83    fn start_send(self: Pin<&mut Self>, part: &[u8]) -> Result<Self::Ret, Self::Error> {
84        self.project().buf.extend_from_slice(part);
85        Ok(part.len())
86    }
87
88    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89        self.project().inner.poll_flush(cx)
90    }
91
92    fn poll_complete(
93        mut self: Pin<&mut Self>,
94        _cx: &mut Context<'_>,
95    ) -> Poll<Result<Self::Output, Self::Error>> {
96        Poll::Ready(Ok(std::mem::take(&mut self.inner)))
97    }
98}