use core::{
fmt,
pin::Pin,
task::{Context, Poll},
};
use std::io;
use futures_lite::io::{AsyncWrite, AsyncWriteExt};
use pin_project_lite::pin_project;
// Buffer capacity, in bytes, that `BufWriter::new` requests when the caller
// does not pick one via `BufWriter::with_capacity` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
pin_project! {
/// Wraps a writer `W` and collects outgoing bytes in an in-memory buffer
/// before handing them to the wrapped writer.
///
/// The `AsyncWrite` implementation for this type flushes the buffer to
/// `inner` when incoming data would not fit, and on `poll_flush` /
/// `poll_close`.
pub struct BufWriter<W> {
// The wrapped writer. Marked `#[pin]` so it can be polled through a
// pinned projection.
#[pin]
inner: W,
// Bytes accepted from callers that have not yet been drained out of the
// buffer.
buf: Vec<u8>,
// Number of bytes at the front of `buf` that `poll_flush_buf` has already
// passed to `inner` but has not yet removed from `buf`; it is reset to 0
// once those bytes are drained.
written: usize,
}
}
/// Error returned by `BufWriter::into_inner` when flushing the buffer fails.
///
/// It pairs the value that could not be unwrapped (the buffered writer,
/// still holding whatever data was not flushed) with the [`io::Error`] that
/// stopped the flush. The accessors below let callers inspect the error and
/// recover the writer instead of losing both.
#[derive(Debug)]
pub struct IntoInnerError<W>(W, io::Error);

impl<W> IntoInnerError<W> {
    /// Returns a reference to the I/O error that caused the failure.
    pub fn error(&self) -> &io::Error {
        &self.1
    }

    /// Consumes the error and returns the wrapped value so the caller can
    /// retry or otherwise recover.
    pub fn into_inner(self) -> W {
        self.0
    }

    /// Consumes the error and returns only the underlying I/O error,
    /// dropping the wrapped value.
    pub fn into_error(self) -> io::Error {
        self.1
    }

    /// Consumes the error and returns both halves as `(error, value)`.
    pub fn into_parts(self) -> (io::Error, W) {
        (self.1, self.0)
    }
}

impl<W> From<IntoInnerError<W>> for io::Error {
    /// Discards the wrapped value and keeps the I/O error, which allows `?`
    /// in functions that return `io::Result`.
    fn from(err: IntoInnerError<W>) -> io::Error {
        err.1
    }
}

impl<W> fmt::Display for IntoInnerError<W> {
    /// Formats exactly like the underlying I/O error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.1, f)
    }
}

impl<W: fmt::Debug> std::error::Error for IntoInnerError<W> {}
impl<W> BufWriter<W> {
    /// Creates a buffered writer around `inner` whose buffer is allocated
    /// with `DEFAULT_BUF_SIZE` bytes of capacity.
    pub fn new(inner: W) -> BufWriter<W> {
        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
    }

    /// Creates a buffered writer around `inner` whose buffer is allocated
    /// with at least `capacity` bytes of capacity.
    pub fn with_capacity(capacity: usize, inner: W) -> BufWriter<W> {
        let buf = Vec::with_capacity(capacity);
        Self {
            inner,
            buf,
            written: 0,
        }
    }

    /// Borrows the wrapped writer.
    pub fn get_ref(&self) -> &W {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the wrapped writer.
    pub fn get_mut(&mut self) -> &mut W {
        let Self { inner, .. } = self;
        inner
    }

    /// Projects a pinned `BufWriter` to its pinned wrapped writer.
    fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
        let projection = self.project();
        projection.inner
    }

    /// Returns the contents of the internal buffer as a byte slice.
    ///
    /// The slice covers the whole buffer vector; it is not offset by the
    /// `written` counter.
    pub fn buffer(&self) -> &[u8] {
        self.buf.as_slice()
    }
}
impl<W: AsyncWrite> BufWriter<W> {
    /// Flushes the writer and, on success, returns the wrapped writer.
    ///
    /// # Errors
    ///
    /// If the flush fails, the `BufWriter` itself is handed back together
    /// with the I/O error, wrapped in an [`IntoInnerError`].
    pub async fn into_inner(mut self) -> Result<W, IntoInnerError<BufWriter<W>>>
    where
        Self: Unpin,
    {
        if let Err(err) = self.flush().await {
            return Err(IntoInnerError(self, err));
        }
        Ok(self.inner)
    }

    /// Pushes the buffered bytes into the wrapped writer.
    ///
    /// Returns `Poll::Pending` as soon as the wrapped writer does; progress
    /// made so far stays recorded in `written`, so a later call resumes where
    /// this one stopped. Otherwise the bytes already handed over are removed
    /// from the buffer, `written` is reset to 0, and the outcome is returned:
    /// `Ok(())` when everything was written, a `WriteZero` error when the
    /// wrapped writer accepted zero bytes, or the first error whose kind is
    /// not `Interrupted`.
    fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let mut this = self.project();
        // Snapshot of the buffer length; the buffer is not modified in the loop.
        let total = this.buf.len();

        let outcome = loop {
            if *this.written >= total {
                break Ok(());
            }
            let remaining = &this.buf[*this.written..];
            match this.inner.as_mut().poll_write(cx, remaining) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Ok(0)) => {
                    break Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "Failed to write buffered data",
                    ));
                }
                Poll::Ready(Ok(count)) => *this.written += count,
                Poll::Ready(Err(err)) => {
                    // An interrupted write is simply retried on the next pass.
                    if err.kind() != io::ErrorKind::Interrupted {
                        break Err(err);
                    }
                }
            }
        };

        // Drop the prefix that reached the wrapped writer, even on error, so
        // those bytes are not sent twice.
        if *this.written > 0 {
            this.buf.drain(..*this.written);
        }
        *this.written = 0;
        Poll::Ready(outcome)
    }
}
impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
    /// Accepts `buf` either into the internal buffer or, for large writes,
    /// straight into the wrapped writer.
    ///
    /// If `buf` would not fit next to what is already buffered, the buffer is
    /// flushed first; `Pending` or an error from that flush is returned
    /// unchanged. Afterwards, data at least as large as the buffer's capacity
    /// bypasses the buffer and is written directly to the wrapped writer;
    /// anything smaller is handed to the `Vec<u8>` buffer's own `poll_write`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if self.buf.len() + buf.len() > self.buf.capacity() {
            // `core::task::ready!` is the stable way to early-return on
            // `Pending`; the inherent `Poll::ready()` method was an unstable
            // API and cannot be relied on.
            core::task::ready!(self.as_mut().poll_flush_buf(cx))?;
        }
        if buf.len() >= self.buf.capacity() {
            self.get_pin_mut().poll_write(cx, buf)
        } else {
            Pin::new(&mut *self.project().buf).poll_write(cx, buf)
        }
    }

    /// Flushes the internal buffer into the wrapped writer, then flushes the
    /// wrapped writer itself.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        core::task::ready!(self.as_mut().poll_flush_buf(cx))?;
        self.get_pin_mut().poll_flush(cx)
    }

    /// Flushes the internal buffer into the wrapped writer, then closes the
    /// wrapped writer.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        core::task::ready!(self.as_mut().poll_flush_buf(cx))?;
        self.get_pin_mut().poll_close(cx)
    }
}
impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
    /// Formats the writer as a `BufWriter` struct showing the wrapped writer
    /// (as `writer`) and the buffer contents (as `buf`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("BufWriter");
        out.field("writer", &self.inner);
        out.field("buf", &self.buf);
        out.finish()
    }
}