Skip to main content

multipart_write/io/
multi_async_writer.rs

1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Initial capacity, in bytes, reserved for a writer's staging buffer (8 KiB).
///
/// The value is taken from the standard library's default I/O buffer size:
/// <https://github.com/rust-lang/rust/blob/ff6dc928c5e33ce8e65c6911a790b9efcb5ef53a/library/std/src/sys/io/mod.rs#L54>
const DEFAULT_BUF_SIZE: usize = 8192;
9
10/// Constructs a `MultipartWrite` from a `tokio::io::AsyncWrite`.
11pub fn async_writer<W: AsyncWrite + Unpin + Default>(
12    write: W,
13) -> MultiAsyncWriter<W> {
14    MultiAsyncWriter::new(write)
15}
16
17pin_project_lite::pin_project! {
18    /// The writer returned by [`async_writer`](self::async_writer).
19    #[derive(Debug, Default)]
20    pub struct MultiAsyncWriter<W: AsyncWrite> {
21        #[pin]
22        inner: W,
23        buf: Vec<u8>,
24        written: usize,
25    }
26}
27
28impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
29    pub(super) fn new(inner: W) -> Self {
30        Self { inner, buf: Vec::with_capacity(DEFAULT_BUF_SIZE), written: 0 }
31    }
32
33    fn flush_buf(
34        self: Pin<&mut Self>,
35        cx: &mut Context<'_>,
36    ) -> Poll<std::io::Result<()>> {
37        let mut this = self.project();
38
39        let len = this.buf.len();
40        let mut ret = Ok(());
41        while *this.written < len {
42            match task::ready!(
43                this.inner.as_mut().poll_write(cx, &this.buf[*this.written..])
44            ) {
45                Ok(0) => {
46                    ret = Err(std::io::Error::new(
47                        std::io::ErrorKind::WriteZero,
48                        "failed to write buffered data",
49                    ));
50                    break;
51                },
52                Ok(n) => *this.written += n,
53                Err(e) => {
54                    ret = Err(e);
55                    break;
56                },
57            }
58        }
59        if *this.written > 0 {
60            this.buf.drain(..*this.written);
61        }
62        *this.written = 0;
63
64        Poll::Ready(ret)
65    }
66}
67
68impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]>
69    for MultiAsyncWriter<W>
70{
71    type Error = std::io::Error;
72    type Output = W;
73    type Recv = usize;
74
75    fn poll_ready(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78    ) -> Poll<Result<(), Self::Error>> {
79        self.flush_buf(cx)
80    }
81
82    fn start_send(
83        self: Pin<&mut Self>,
84        part: &[u8],
85    ) -> Result<Self::Recv, Self::Error> {
86        self.project().buf.extend_from_slice(part);
87        Ok(part.len())
88    }
89
90    fn poll_flush(
91        self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93    ) -> Poll<Result<(), Self::Error>> {
94        self.project().inner.poll_flush(cx)
95    }
96
97    fn poll_complete(
98        mut self: Pin<&mut Self>,
99        _cx: &mut Context<'_>,
100    ) -> Poll<Result<Self::Output, Self::Error>> {
101        Poll::Ready(Ok(std::mem::take(&mut self.inner)))
102    }
103}