use crate::MultipartWrite;
use std::pin::Pin;
use std::task::{self, Context, Poll};
use tokio::io::AsyncWrite;
/// Initial capacity, in bytes, reserved for a writer's staging buffer (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8192;
/// Wraps `write` in a [`MultiAsyncWriter`], which stages the bytes of each
/// part in memory and forwards them to `write` when polled.
///
/// Equivalent to calling `MultiAsyncWriter::new(write)`.
pub fn async_writer<W>(write: W) -> MultiAsyncWriter<W>
where
    W: AsyncWrite + Unpin + Default,
{
    MultiAsyncWriter::<W>::new(write)
}
pin_project_lite::pin_project! {
/// Writer adapter that collects the bytes of each part sent through
/// `MultipartWrite` into an in-memory buffer and writes that buffer to
/// `inner` when polled.
#[derive(Debug, Default)]
pub struct MultiAsyncWriter<W: AsyncWrite> {
// Destination writer; structurally pinned so `poll_write`/`poll_flush`
// can be called on it through the projection.
#[pin]
inner: W,
// Bytes accepted by `start_send` that have not yet been removed after
// being written to `inner`.
buf: Vec<u8>,
// Number of bytes at the front of `buf` already accepted by `inner`
// during a flush that has not finished yet (reset to 0 once it does).
written: usize,
}
}
impl<W: AsyncWrite + Unpin> MultiAsyncWriter<W> {
    /// Wraps `inner` with an empty staging buffer that has
    /// `DEFAULT_BUF_SIZE` bytes of capacity reserved up front.
    pub(super) fn new(inner: W) -> Self {
        let buf = Vec::with_capacity(DEFAULT_BUF_SIZE);
        Self { inner, buf, written: 0 }
    }

    /// Writes the staged bytes in `buf` to `inner`.
    ///
    /// Progress is recorded in `written`. If `inner` returns `Pending`, this
    /// returns `Pending` immediately and the next call resumes from the same
    /// offset. The write loop can end in three ways:
    ///
    /// - every byte was accepted;
    /// - `inner` accepted zero bytes, which yields an
    ///   `ErrorKind::WriteZero` error;
    /// - `inner` returned an error.
    ///
    /// When it ends, the accepted prefix is removed from `buf`, `written` is
    /// reset to zero, and the outcome is returned.
    fn flush_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut me = self.project();
        let total = me.buf.len();

        let outcome = loop {
            if *me.written >= total {
                break Ok(());
            }
            let pending = &me.buf[*me.written..];
            // `ready!` returns `Pending` from this function before anything
            // is drained, so partial progress survives until the next poll.
            match task::ready!(me.inner.as_mut().poll_write(cx, pending)) {
                Ok(0) => {
                    break Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "failed to write buffered data",
                    ));
                },
                Ok(n) => *me.written += n,
                Err(e) => break Err(e),
            }
        };

        // Drop whatever `inner` accepted, even on error, so those bytes are
        // never written twice.
        if *me.written > 0 {
            me.buf.drain(..*me.written);
        }
        *me.written = 0;
        Poll::Ready(outcome)
    }
}
impl<W: AsyncWrite + Default + Unpin> MultipartWrite<&[u8]>
for MultiAsyncWriter<W>
{
type Error = std::io::Error;
type Output = W;
type Recv = usize;
fn poll_ready(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.flush_buf(cx)
}
fn start_send(
self: Pin<&mut Self>,
part: &[u8],
) -> Result<Self::Recv, Self::Error> {
self.project().buf.extend_from_slice(part);
Ok(part.len())
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_complete(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Self::Output, Self::Error>> {
Poll::Ready(Ok(std::mem::take(&mut self.inner)))
}
}