use super::{AsyncRead, AsyncWrite};
use std::{fmt, io};
/// Read-side buffer placed in front of an [`AsyncRead`] so that small reads are
/// served from memory instead of hitting the reader every time.
///
/// Bytes in `buf[pos..cap]` have been read from the source but not yet handed
/// to a caller.
pub struct AsyncBufRead {
    /// Backing storage; its length is the buffer capacity.
    buf: Vec<u8>,
    /// Index of the next byte to hand out.
    pos: usize,
    /// Number of valid bytes currently held in `buf`.
    cap: usize,
}
impl AsyncBufRead {
    /// Creates an empty read buffer able to hold `capacity` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    #[inline]
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity {} must > 0", capacity);
        Self { buf: vec![0u8; capacity], pos: 0, cap: 0 }
    }

    /// Fills `buf` from the internal buffer, consulting `reader` only when the
    /// internal buffer has been drained.
    ///
    /// * Buffered bytes remain: copy as many as fit into `buf`; `reader` is not used.
    /// * Nothing buffered and `buf` is at least as large as the capacity: a single
    ///   `reader.read` goes straight into `buf`, bypassing the internal buffer.
    /// * Otherwise: refill the internal buffer with one `reader.read`, then copy.
    ///
    /// Returns the number of bytes placed in `buf`. This is `0` when the refill
    /// read returns `0`. Note that an empty `buf` with nothing buffered still
    /// performs a refill read.
    ///
    /// # Errors
    ///
    /// Propagates any error from `reader.read`; the internal buffer stays empty.
    #[inline]
    pub async fn read_buffered<T: AsyncRead>(
        &mut self, reader: &mut T, buf: &mut [u8],
    ) -> io::Result<usize> {
        let drained = self.pos >= self.cap;
        if drained {
            if buf.len() >= self.buf.len() {
                // Large request: staging it through the internal buffer would only
                // add an extra copy.
                return reader.read(buf).await;
            }
            let filled = reader.read(&mut self.buf).await?;
            self.cap = filled;
            self.pos = 0;
        }
        let start = self.pos;
        let take = buf.len().min(self.cap - start);
        buf[..take].copy_from_slice(&self.buf[start..start + take]);
        self.pos = start + take;
        Ok(take)
    }
}
/// Write-side buffer that coalesces small writes before handing them to an
/// [`AsyncWrite`].
///
/// Bytes in `buf[start..pos]` have been accepted from callers but not yet
/// delivered to the writer.
pub struct AsyncBufWrite {
    /// Backing storage; its length is the buffer capacity.
    buf: Vec<u8>,
    /// Index of the first buffered byte the writer has not yet accepted.
    start: usize,
    /// One past the last buffered byte.
    pos: usize,
}
impl AsyncBufWrite {
    /// Creates an empty write buffer able to hold `capacity` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    #[inline]
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity {} must > 0", capacity);
        AsyncBufWrite { buf: vec![0; capacity], start: 0, pos: 0 }
    }

    /// Delivers every pending buffered byte to `writer`.
    ///
    /// Progress is recorded after each successful `write`. If this returns an
    /// error, or its future is dropped between writes, the next call resumes
    /// with only the bytes `writer` has not accepted yet. Bytes that already
    /// went out are not resent.
    ///
    /// # Errors
    ///
    /// Returns the first error from `writer.write` other than
    /// `ErrorKind::Interrupted`, which is retried. Returns
    /// `ErrorKind::WriteZero` if `writer` accepts zero bytes.
    #[inline]
    pub async fn flush<W: AsyncWrite>(&mut self, writer: &mut W) -> io::Result<()> {
        while self.start < self.pos {
            match writer.write(&self.buf[self.start..self.pos]).await {
                Ok(0) => {
                    return Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write buffered data",
                    ));
                }
                // Commit progress immediately so that a later error or
                // cancellation cannot cause these bytes to be sent twice.
                Ok(n) => self.start += n,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        self.start = 0;
        self.pos = 0;
        Ok(())
    }

    /// Buffers `buf`, flushing to `writer` first when it would not fit.
    ///
    /// A `buf` at least as large as the whole buffer is not copied. Pending
    /// bytes are flushed and `buf` is passed to a single `writer.write`, whose
    /// (possibly short) count is returned. Otherwise all of `buf` is buffered
    /// and `buf.len()` is returned.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`AsyncBufWrite::flush`], and from `writer.write`
    /// on the direct path. In either case nothing from `buf` has been buffered.
    #[inline]
    pub async fn write_buffered<W: AsyncWrite>(
        &mut self, writer: &mut W, buf: &[u8],
    ) -> io::Result<usize> {
        if buf.len() >= self.buf.len() {
            self.flush(writer).await?;
            return writer.write(buf).await;
        }
        if self.buf.len() - self.pos < buf.len() {
            // After a successful flush the buffer is empty, and `buf` is known
            // to be smaller than the capacity, so it fits.
            self.flush(writer).await?;
        }
        let end = self.pos + buf.len();
        self.buf[self.pos..end].copy_from_slice(buf);
        self.pos = end;
        Ok(buf.len())
    }
}
/// Wraps a bidirectional stream so that both reads and writes go through
/// in-memory buffers of a fixed size.
///
/// Buffered output reaches `inner` only when the write buffer overflows or
/// [`AsyncBufStream::flush`] is called. Call `flush` before dropping the stream
/// or using `inner` directly.
pub struct AsyncBufStream<T: AsyncRead + AsyncWrite> {
    /// Buffer serving reads performed on this stream.
    read_buf: AsyncBufRead,
    /// Buffer collecting writes performed on this stream.
    write_buf: AsyncBufWrite,
    /// The wrapped stream.
    inner: T,
}
impl<T: AsyncRead + AsyncWrite> AsyncBufStream<T> {
    /// Wraps `stream`, giving both the read and the write buffer a capacity of
    /// `buf_size` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `buf_size` is zero.
    #[inline]
    pub fn new(stream: T, buf_size: usize) -> Self {
        let read_buf = AsyncBufRead::new(buf_size);
        let write_buf = AsyncBufWrite::new(buf_size);
        Self { read_buf, write_buf, inner: stream }
    }

    /// Writes all buffered output to the wrapped stream.
    ///
    /// # Errors
    ///
    /// Propagates the error from the underlying write.
    #[inline(always)]
    pub async fn flush(&mut self) -> io::Result<()> {
        let Self { write_buf, inner, .. } = self;
        write_buf.flush(inner).await
    }

    /// Returns the wrapped stream.
    ///
    /// I/O performed through this reference bypasses both buffers. Unread
    /// buffered input is not seen by direct reads. Unflushed output is sent
    /// after any direct writes.
    #[inline(always)]
    pub fn get_inner(&mut self) -> &mut T {
        let Self { inner, .. } = self;
        inner
    }
}
/// Debug-formats the stream exactly as the wrapped stream formats itself.
/// Buffer contents and positions are not shown.
impl<T: AsyncRead + AsyncWrite + fmt::Debug> fmt::Debug for AsyncBufStream<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, f)
    }
}
/// Displays the stream exactly as the wrapped stream displays itself.
/// Formatter options such as width are forwarded unchanged.
impl<T: AsyncRead + AsyncWrite + fmt::Display> fmt::Display for AsyncBufStream<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.inner, f)
    }
}
/// Reads from the wrapped stream through the internal read buffer.
impl<T: AsyncRead + AsyncWrite> AsyncRead for AsyncBufStream<T> {
    /// Delegates to [`AsyncBufRead::read_buffered`] with the wrapped stream as
    /// the source.
    #[inline(always)]
    async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let Self { read_buf, inner, .. } = self;
        read_buf.read_buffered(inner, buf).await
    }
}
/// Writes to the wrapped stream through the internal write buffer.
impl<T: AsyncRead + AsyncWrite> AsyncWrite for AsyncBufStream<T> {
    /// Delegates to [`AsyncBufWrite::write_buffered`] with the wrapped stream
    /// as the sink. Small writes are held in memory until the buffer must be
    /// flushed or [`AsyncBufStream::flush`] is called.
    #[inline(always)]
    async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let Self { write_buf, inner, .. } = self;
        write_buf.write_buffered(inner, buf).await
    }
}