multipart_write/io/
multi_async_writer.rs1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Capacity, in bytes, reserved up front for a new writer's internal buffer (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8_192;
9
10pub fn async_writer<W: AsyncWrite + Unpin + Default>(write: W) -> MultiAsyncWriter<W> {
12 MultiAsyncWriter::new(write)
13}
14
15pin_project_lite::pin_project! {
16 #[derive(Debug, Default)]
18 pub struct MultiAsyncWriter<W: AsyncWrite> {
19 #[pin]
20 inner: W,
21 buf: Vec<u8>,
22 written: usize,
23 }
24}
25
26impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
27 pub(super) fn new(inner: W) -> Self {
28 Self {
29 inner,
30 buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
31 written: 0,
32 }
33 }
34
35 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
36 let mut this = self.project();
37
38 let len = this.buf.len();
39 let mut ret = Ok(());
40 while *this.written < len {
41 match task::ready!(
42 this.inner
43 .as_mut()
44 .poll_write(cx, &this.buf[*this.written..])
45 ) {
46 Ok(0) => {
47 ret = Err(std::io::Error::new(
48 std::io::ErrorKind::WriteZero,
49 "failed to write buffered data",
50 ));
51 break;
52 }
53 Ok(n) => *this.written += n,
54 Err(e) => {
55 ret = Err(e);
56 break;
57 }
58 }
59 }
60 if *this.written > 0 {
61 this.buf.drain(..*this.written);
62 }
63 *this.written = 0;
64
65 Poll::Ready(ret)
66 }
67}
68
69impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]> for MultiAsyncWriter<W> {
70 type Ret = usize;
71 type Output = W;
72 type Error = std::io::Error;
73
74 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 self.flush_buf(cx)
76 }
77
78 fn start_send(self: Pin<&mut Self>, part: &[u8]) -> Result<Self::Ret, Self::Error> {
79 self.project().buf.extend_from_slice(part);
80 Ok(part.len())
81 }
82
83 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 self.project().inner.poll_flush(cx)
85 }
86
87 fn poll_complete(
88 mut self: Pin<&mut Self>,
89 _cx: &mut Context<'_>,
90 ) -> Poll<Result<Self::Output, Self::Error>> {
91 Poll::Ready(Ok(std::mem::take(&mut self.inner)))
92 }
93}