multipart_write/io/
multi_async_writer.rs1use crate::MultipartWrite;
2
3use std::pin::Pin;
4use std::task::{self, Context, Poll};
5use tokio::io::AsyncWrite;
6
/// Capacity, in bytes, reserved up front for the buffer of a newly
/// constructed writer (8 KiB).
const DEFAULT_BUF_SIZE: usize = 1 << 13;

10pub fn async_write<W: AsyncWrite + Unpin + Default>(write: W) -> MultiAsyncWriter<W> {
15 MultiAsyncWriter::new(write)
16}
17
18#[derive(Debug, Default)]
23#[pin_project::pin_project]
24pub struct MultiAsyncWriter<W: AsyncWrite> {
25 #[pin]
26 inner: W,
27 buf: Vec<u8>,
28 written: usize,
29}
30
31impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
32 pub(super) fn new(inner: W) -> Self {
33 Self {
34 inner,
35 buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
36 written: 0,
37 }
38 }
39
40 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
41 let mut this = self.project();
42
43 let len = this.buf.len();
44 let mut ret = Ok(());
45 while *this.written < len {
46 match task::ready!(
47 this.inner
48 .as_mut()
49 .poll_write(cx, &this.buf[*this.written..])
50 ) {
51 Ok(0) => {
52 ret = Err(std::io::Error::new(
53 std::io::ErrorKind::WriteZero,
54 "failed to write buffered data",
55 ));
56 break;
57 }
58 Ok(n) => *this.written += n,
59 Err(e) => {
60 ret = Err(e);
61 break;
62 }
63 }
64 }
65 if *this.written > 0 {
66 this.buf.drain(..*this.written);
67 }
68 *this.written = 0;
69
70 Poll::Ready(ret)
71 }
72}
73
74impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]> for MultiAsyncWriter<W> {
75 type Ret = usize;
76 type Output = W;
77 type Error = std::io::Error;
78
79 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80 self.flush_buf(cx)
81 }
82
83 fn start_send(self: Pin<&mut Self>, part: &[u8]) -> Result<Self::Ret, Self::Error> {
84 self.project().buf.extend_from_slice(part);
85 Ok(part.len())
86 }
87
88 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89 self.project().inner.poll_flush(cx)
90 }
91
92 fn poll_complete(
93 mut self: Pin<&mut Self>,
94 _cx: &mut Context<'_>,
95 ) -> Poll<Result<Self::Output, Self::Error>> {
96 Poll::Ready(Ok(std::mem::take(&mut self.inner)))
97 }
98}