use std::{future::Future, io};
use crate::{
buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut, IoVecWrapper, Slice},
io::{AsyncBufRead, AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt},
BufResult,
};
/// Wraps a writer and collects its output in an owned, fixed-size byte buffer.
///
/// Written bytes accumulate in `buf` and are handed to `inner` in one batch
/// when room is needed for new data, or on `flush`/`shutdown`.
pub struct BufWriter<W> {
    /// Backing storage. Wrapped in `Option` so the buffer can be moved out
    /// while an owned-buffer write to `inner` is in flight, then put back.
    buf: Option<Box<[u8]>>,
    /// Start of the not-yet-flushed bytes inside `buf`.
    pos: usize,
    /// End (exclusive) of the not-yet-flushed bytes inside `buf`.
    cap: usize,
    /// The wrapped writer.
    inner: W,
}
/// Buffer size in bytes used by `BufWriter::new` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 << 10;
// Constructors and accessors that place no bound on `W`.
impl<W> BufWriter<W> {
    /// Creates a `BufWriter` with a buffer of `DEFAULT_BUF_SIZE` bytes.
    #[inline]
    pub fn new(inner: W) -> Self {
        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
    }

    /// Creates a `BufWriter` whose internal buffer holds `capacity` bytes.
    ///
    /// The buffer is allocated and zero-filled up front; it starts out with
    /// no pending data.
    #[inline]
    pub fn with_capacity(capacity: usize, inner: W) -> Self {
        let storage: Box<[u8]> = vec![0u8; capacity].into_boxed_slice();
        Self {
            buf: Some(storage),
            pos: 0,
            cap: 0,
            inner,
        }
    }

    /// Returns a shared reference to the wrapped writer.
    #[inline]
    pub fn get_ref(&self) -> &W {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped writer.
    ///
    /// Writing to it directly bypasses (and may reorder relative to) any
    /// bytes still sitting in the buffer.
    #[inline]
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.inner
    }

    /// Consumes the `BufWriter` and returns the wrapped writer.
    ///
    /// Bytes still held in the buffer are dropped, not flushed; flush first
    /// if they must reach the writer.
    #[inline]
    pub fn into_inner(self) -> W {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns the bytes that are buffered but not yet flushed.
    ///
    /// # Panics
    /// Panics if the buffer is currently moved out (an in-progress flush).
    #[inline]
    pub fn buffer(&self) -> &[u8] {
        let storage = self.buf.as_deref().expect("unable to take buffer");
        &storage[self.pos..self.cap]
    }

    /// Marks the buffer as empty without touching its contents.
    #[inline]
    fn discard_buffer(&mut self) {
        (self.pos, self.cap) = (0, 0);
    }
}
impl<W: AsyncWriteRent> BufWriter<W> {
    /// Writes every pending byte (`pos..cap`) to the inner writer.
    ///
    /// The buffer is moved out of `self` for the duration of the write and is
    /// put back whether or not the write succeeded. The pending range is
    /// cleared only on success, so a failed flush keeps the data buffered.
    ///
    /// # Errors
    /// Propagates the error from the inner `write_all`.
    ///
    /// # Panics
    /// Panics if the buffer is absent, e.g. because an earlier flush future
    /// was dropped before completing and never returned it.
    async fn flush_buf(&mut self) -> io::Result<()> {
        if self.pos == self.cap {
            // Nothing pending; skip the round-trip to the inner writer.
            return Ok(());
        }
        let storage = self
            .buf
            .take()
            .expect("no buffer available, generated future must be awaited");
        let (result, pending) = self
            .inner
            .write_all(Slice::new(storage, self.pos, self.cap))
            .await;
        // Restore the buffer before inspecting the result so an error
        // does not leave us without storage.
        self.buf = Some(pending.into_inner());
        result?;
        self.discard_buffer();
        Ok(())
    }
}
impl<W: AsyncWriteRent> AsyncWriteRent for BufWriter<W> {
    /// Buffers `buf`, first flushing pending bytes if `buf` would not fit
    /// after them.
    ///
    /// If `buf` is larger than the whole internal buffer, it is written
    /// straight to the inner writer (after the flush above has emptied the
    /// buffer), and the returned count is whatever the inner `write` reports.
    /// Otherwise all of `buf`'s initialized bytes are copied in and
    /// `Ok(buf.bytes_init())` is returned.
    ///
    /// # Errors
    /// If flushing the pending bytes fails, that error is returned together
    /// with `buf`, and nothing from `buf` is buffered.
    ///
    /// # Panics
    /// Panics if the internal buffer is missing (it is moved out while a
    /// flush is in flight).
    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let owned_len = self.buf.as_ref().unwrap().len();
        let amt = buf.bytes_init();

        // New bytes are appended at `self.cap` (the end of the pending data),
        // so the fit check must use `cap`. Checking `pos` instead let the
        // copy below run past the end of the buffer once it held data.
        if self.cap + amt > owned_len {
            if let Err(e) = self.flush_buf().await {
                return (Err(e), buf);
            }
        }

        if amt > owned_len {
            // Can never fit. The check above triggered a successful flush,
            // so bypassing the buffer preserves byte order.
            self.inner.write(buf).await
        } else {
            let owned_buf = self.buf.as_mut().unwrap();
            debug_assert!(self.cap + amt <= owned_buf.len());
            // SAFETY: the destination range `cap..cap + amt` is in bounds:
            // either the fit check passed without flushing, or the flush
            // emptied the buffer (`cap == 0`) and `amt <= owned_len` here.
            // The source is `amt` initialized bytes of `buf` (per
            // `bytes_init`). `buf` is an owned value separate from our boxed
            // buffer, which we access through `&mut self`, so the two
            // regions do not overlap.
            unsafe {
                owned_buf
                    .as_mut_ptr()
                    .add(self.cap)
                    .copy_from_nonoverlapping(buf.read_ptr(), amt);
            }
            self.cap += amt;
            (Ok(amt), buf)
        }
    }

    /// Writes a vectored buffer by wrapping it with `IoVecWrapper` and
    /// forwarding to [`write`](Self::write).
    ///
    /// If `IoVecWrapper::new` hands the buffer back (`Err`), `Ok(0)` is
    /// returned and nothing is written. NOTE(review): presumably this is the
    /// case when the iovec cannot be viewed as one contiguous slice; confirm
    /// against `IoVecWrapper::new`.
    async fn writev<T: IoVecBuf>(&mut self, buf: T) -> BufResult<usize, T> {
        let slice = match IoVecWrapper::new(buf) {
            Ok(slice) => slice,
            Err(buf) => return (Ok(0), buf),
        };
        let (result, slice) = self.write(slice).await;
        (result, slice.into_inner())
    }

    /// Flushes buffered bytes to the inner writer, then flushes the inner
    /// writer itself.
    ///
    /// # Errors
    /// Returns the first error encountered. If draining the buffer fails, the
    /// inner writer's `flush` is not called.
    async fn flush(&mut self) -> std::io::Result<()> {
        self.flush_buf().await?;
        self.inner.flush().await
    }

    /// Flushes buffered bytes to the inner writer, then shuts it down.
    ///
    /// # Errors
    /// Returns the first error encountered. If draining the buffer fails, the
    /// inner writer's `shutdown` is not called.
    async fn shutdown(&mut self) -> std::io::Result<()> {
        self.flush_buf().await?;
        self.inner.shutdown().await
    }
}
impl<W: AsyncWriteRent + AsyncReadRent> AsyncReadRent for BufWriter<W> {
#[inline]
fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
self.inner.read(buf)
}
#[inline]
fn readv<T: IoVecBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
self.inner.readv(buf)
}
}
impl<W: AsyncWriteRent + AsyncBufRead> AsyncBufRead for BufWriter<W> {
#[inline]
fn fill_buf(&mut self) -> impl Future<Output = std::io::Result<&[u8]>> {
self.inner.fill_buf()
}
#[inline]
fn consume(&mut self, amt: usize) {
self.inner.consume(amt)
}
}