multipart_write/io/
multi_async_writer.rs1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Initial capacity, in bytes, of the staging buffer allocated by `MultiAsyncWriter::new`.
const DEFAULT_BUF_SIZE: usize = 8_192;
9
10pub fn async_write<W: AsyncWrite + Unpin + Default>(write: W) -> MultiAsyncWriter<W> {
15 MultiAsyncWriter::new(write)
16}
17
pin_project_lite::pin_project! {
    /// Writer that stages byte parts in an in-memory buffer (`buf`) and
    /// writes them to the wrapped [`AsyncWrite`] value (`inner`).
    #[derive(Debug, Default)]
    pub struct MultiAsyncWriter<W: AsyncWrite> {
        // Destination writer; `#[pin]` makes it structurally pinned so it can
        // be polled through the projection returned by `project()`.
        #[pin]
        inner: W,
        // Bytes appended by `start_send` that have not yet been drained into `inner`.
        buf: Vec<u8>,
        // Count of leading bytes of `buf` already written to `inner` by an
        // in-progress `flush_buf`; it survives a `Pending` return so the next
        // call resumes at the same offset, and is reset to 0 once `flush_buf`
        // returns `Ready`.
        written: usize,
    }
}
31
32impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
33 pub(super) fn new(inner: W) -> Self {
34 Self {
35 inner,
36 buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
37 written: 0,
38 }
39 }
40
41 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
42 let mut this = self.project();
43
44 let len = this.buf.len();
45 let mut ret = Ok(());
46 while *this.written < len {
47 match task::ready!(
48 this.inner
49 .as_mut()
50 .poll_write(cx, &this.buf[*this.written..])
51 ) {
52 Ok(0) => {
53 ret = Err(std::io::Error::new(
54 std::io::ErrorKind::WriteZero,
55 "failed to write buffered data",
56 ));
57 break;
58 }
59 Ok(n) => *this.written += n,
60 Err(e) => {
61 ret = Err(e);
62 break;
63 }
64 }
65 }
66 if *this.written > 0 {
67 this.buf.drain(..*this.written);
68 }
69 *this.written = 0;
70
71 Poll::Ready(ret)
72 }
73}
74
75impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]> for MultiAsyncWriter<W> {
76 type Ret = usize;
77 type Output = W;
78 type Error = std::io::Error;
79
80 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
81 self.flush_buf(cx)
82 }
83
84 fn start_send(self: Pin<&mut Self>, part: &[u8]) -> Result<Self::Ret, Self::Error> {
85 self.project().buf.extend_from_slice(part);
86 Ok(part.len())
87 }
88
89 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90 self.project().inner.poll_flush(cx)
91 }
92
93 fn poll_complete(
94 mut self: Pin<&mut Self>,
95 _cx: &mut Context<'_>,
96 ) -> Poll<Result<Self::Output, Self::Error>> {
97 Poll::Ready(Ok(std::mem::take(&mut self.inner)))
98 }
99}