use std::{
io::{self, BufRead, Read, Write},
mem::MaybeUninit,
};
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut};
use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE};
/// Wraps a stream so it can be driven through the synchronous `std::io`
/// traits (`Read`, `BufRead`, `Write`) using internal read and write buffers.
///
/// The synchronous trait methods only operate on the buffers. Reading from an
/// empty read buffer before EOF returns an `io::ErrorKind::WouldBlock` error.
/// So does writing when the write buffer cannot accept data. The async methods
/// `fill_read_buf` and `flush_write_buf` move data between the buffers and the
/// wrapped stream.
#[derive(Debug)]
pub struct SyncStream<S> {
/// The wrapped stream.
inner: S,
/// Data read from `inner` that has not yet been consumed via `Read`/`BufRead`.
read_buf: Buffer,
/// Data accepted by `Write::write` that has not yet been flushed to `inner`.
write_buf: Buffer,
/// Set once a read issued by `fill_read_buf` returned 0 bytes; after that no
/// further reads are issued.
eof: bool,
/// Initial capacity of each buffer. It is also passed to `compact_to` and is
/// used when sizing the free space reserved before a read.
base_capacity: usize,
/// Upper bound on buffered bytes. `write` never buffers more than this, and
/// `fill_read_buf` errors once the read buffer already holds this much.
max_buffer_size: usize,
}
impl<S> SyncStream<S> {
    /// Default upper bound on the number of bytes held by each buffer (64 MiB).
    const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;

    /// Wraps `stream` with the default base capacity (`DEFAULT_BUF_SIZE`) and
    /// the default 64 MiB buffer limit.
    pub fn new(stream: S) -> Self {
        Self::with_capacity(DEFAULT_BUF_SIZE, stream)
    }

    /// Wraps `stream`, allocating both buffers with `base_capacity` bytes and
    /// using the default 64 MiB buffer limit.
    pub fn with_capacity(base_capacity: usize, stream: S) -> Self {
        Self::with_limits(base_capacity, Self::DEFAULT_MAX_BUFFER, stream)
    }

    /// Wraps `stream`, allocating both buffers with `base_capacity` bytes and
    /// capping buffered data at `max_buffer_size` bytes.
    pub fn with_limits(base_capacity: usize, max_buffer_size: usize, stream: S) -> Self {
        Self {
            inner: stream,
            read_buf: Buffer::with_capacity(base_capacity),
            write_buf: Buffer::with_capacity(base_capacity),
            eof: false,
            base_capacity,
            max_buffer_size,
        }
    }

    /// Returns a shared reference to the wrapped stream.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped stream.
    ///
    /// Reading from or writing to it directly bypasses the internal buffers.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes the wrapper and returns the wrapped stream.
    ///
    /// Any data still held in the read or write buffers is dropped.
    pub fn into_inner(self) -> S {
        self.inner
    }

    /// Returns `true` once `fill_read_buf` has observed a 0-byte read (EOF).
    pub fn is_eof(&self) -> bool {
        self.eof
    }

    /// The bytes currently held in the read buffer that have not been consumed.
    fn available_read(&self) -> &[u8] {
        self.read_buf.buffer()
    }

    /// Marks `amt` bytes of the read buffer as consumed.
    ///
    /// When `advance` returns `true` (presumably: nothing left to consume),
    /// the buffer is compacted back toward `base_capacity`.
    fn consume_read(&mut self, amt: usize) {
        let all_done = self.read_buf.advance(amt);
        if all_done {
            self.read_buf
                .compact_to(self.base_capacity, self.max_buffer_size);
        }
    }

    /// Like `Read::read`, but copies into a possibly uninitialised buffer.
    ///
    /// On success, the first `n` elements of `buf` are initialised, where `n`
    /// is the returned count. Those bytes are consumed from the read buffer.
    ///
    /// # Errors
    ///
    /// Propagates the error from `BufRead::fill_buf`. This is
    /// `io::ErrorKind::WouldBlock` when the read buffer is empty before EOF.
    pub fn read_buf_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
        let available = self.fill_buf()?;
        let to_read = available.len().min(buf.len());
        // Initialise element-wise through the safe `MaybeUninit::new`. This
        // replaces a raw-pointer cast of `&[u8]` to `&[MaybeUninit<u8>]` and
        // removes the need for `unsafe`. `zip` stops after `to_read` elements.
        for (dst, &src) in buf.iter_mut().zip(available) {
            *dst = MaybeUninit::new(src);
        }
        self.consume(to_read);
        Ok(to_read)
    }

    /// Returns `true` while the write buffer still holds data that has not
    /// been flushed to the wrapped stream.
    pub(crate) fn has_pending_write(&self) -> bool {
        !self.write_buf.is_empty()
    }
}
impl<S> Read for SyncStream<S> {
    /// Copies buffered bytes into `buf` and consumes them.
    ///
    /// This never touches the underlying stream. It fails with
    /// `io::ErrorKind::WouldBlock` when the read buffer is empty before EOF,
    /// and returns `Ok(0)` once EOF was seen and the buffer is drained.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let copied = {
            let mut pending = self.fill_buf()?;
            pending.read(buf)?
        };
        self.consume(copied);
        Ok(copied)
    }

    /// Cursor-based variant of `read`, with the same buffering behaviour.
    #[cfg(feature = "read_buf")]
    fn read_buf(&mut self, mut buf: io::BorrowedCursor<'_>) -> io::Result<()> {
        let filled = {
            let mut pending = self.fill_buf()?;
            let before = buf.written();
            pending.read_buf(buf.reborrow())?;
            buf.written() - before
        };
        self.consume(filled);
        Ok(())
    }
}
impl<S> BufRead for SyncStream<S> {
    /// Returns the buffered bytes that have not yet been consumed.
    ///
    /// An empty slice is only returned after EOF has been observed. An empty
    /// buffer before EOF yields an `io::ErrorKind::WouldBlock` error, signalling
    /// that `fill_read_buf` must run first.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        let at_eof = self.eof;
        match self.available_read() {
            pending if pending.is_empty() && !at_eof => {
                Err(would_block("need to fill read buffer"))
            }
            pending => Ok(pending),
        }
    }

    /// Marks `amt` buffered bytes as consumed.
    fn consume(&mut self, amt: usize) {
        Self::consume_read(self, amt)
    }
}
impl<S> Write for SyncStream<S> {
    /// Appends as much of `buf` as fits into the internal write buffer and
    /// returns how many bytes were accepted.
    ///
    /// Nothing is sent to the underlying stream here; `flush_write_buf` drains
    /// the buffer.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::WouldBlock` in two cases. The first is when the
    /// buffer reports it needs flushing while still holding data. The second is
    /// when it already holds `max_buffer_size` bytes and `buf` does not fit.
    /// Errors from appending to the buffer are propagated.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let must_flush_first = self.write_buf.need_flush() && !self.write_buf.is_empty();
        if must_flush_first {
            return Err(would_block("need to flush write buffer"));
        }
        let limit = self.max_buffer_size;
        let accepted = self.write_buf.with_sync(|mut inner| {
            let outcome = (|| {
                let pending = inner.buf_len();
                let chunk = if pending + buf.len() <= limit {
                    buf
                } else {
                    // Only a prefix fits; refuse outright when no room remains.
                    let room = limit - pending;
                    if room == 0 {
                        return Err(would_block("write buffer full, need to flush"));
                    }
                    &buf[..room]
                };
                inner.extend_from_slice(chunk)?;
                Ok(chunk.len())
            })();
            BufResult(outcome, inner)
        })?;
        Ok(accepted)
    }

    /// Always succeeds without touching the underlying stream; buffered data
    /// is only written out by `flush_write_buf`.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Builds an `io::ErrorKind::WouldBlock` error carrying `msg` as its message.
fn would_block(msg: &str) -> io::Error {
    let kind = io::ErrorKind::WouldBlock;
    io::Error::new(kind, msg.to_owned())
}
impl<S: crate::AsyncRead> SyncStream<S> {
    /// Reads more data from the underlying stream into the internal read
    /// buffer, making it available to the synchronous `Read`/`BufRead` methods.
    ///
    /// Returns the number of bytes read. When a read returns 0 bytes the stream
    /// is marked as EOF, and every later call returns `Ok(0)` without reading.
    ///
    /// # Errors
    ///
    /// - `io::ErrorKind::OutOfMemory` if the buffered data already reaches
    ///   `max_buffer_size`. The same error is returned if no free space could
    ///   be made available for the read, because a zero-length read would
    ///   otherwise be mistaken for EOF.
    /// - Any error returned by the underlying stream's `read`.
    pub async fn fill_read_buf(&mut self) -> io::Result<usize> {
        if self.eof {
            return Ok(0);
        }
        self.read_buf
            .compact_to(self.base_capacity, self.max_buffer_size);
        let read = self
            .read_buf
            .with(|mut inner| async {
                let current_len = inner.buf_len();
                if current_len >= self.max_buffer_size {
                    return BufResult(
                        Err(io::Error::new(
                            io::ErrorKind::OutOfMemory,
                            format!("read buffer size limit ({}) exceeded", self.max_buffer_size),
                        )),
                        inner,
                    );
                }
                // Aim for `base_capacity` bytes of free space, with two bounds.
                // It must be at least one byte, since a zero base capacity
                // would never grow the buffer. It must not exceed what keeps
                // the total within `max_buffer_size`. Here
                // `current_len < max_buffer_size`, so the subtraction cannot
                // underflow.
                let target_space = self
                    .base_capacity
                    .max(1)
                    .min(self.max_buffer_size - current_len);
                if inner.buf_capacity() - current_len < target_space {
                    // Request `target_space` bytes past the current length
                    // (the `Vec::reserve_exact` convention).
                    // NOTE(review): confirm compio-buf's `reserve_exact` uses
                    // the same convention; requesting `target_space` yields
                    // enough room under either reading.
                    // A failed reservation is tolerated as long as some room
                    // remains; the zero-room case is rejected just below.
                    let _ = inner.reserve_exact(target_space);
                }
                if inner.buf_capacity() <= current_len {
                    // A read into an empty slice would typically report 0
                    // bytes. That is indistinguishable from EOF and would
                    // latch `eof` permanently.
                    return BufResult(
                        Err(io::Error::new(
                            io::ErrorKind::OutOfMemory,
                            "no free space available in the read buffer",
                        )),
                        inner,
                    );
                }
                let read_slice = inner.slice(current_len..);
                self.inner.read(read_slice).await.into_inner()
            })
            .await?;
        if read == 0 {
            self.eof = true;
        }
        Ok(read)
    }
}
impl<S: crate::AsyncWrite> SyncStream<S> {
    /// Writes the buffered data out to the underlying stream, compacts the
    /// write buffer, and then flushes the stream itself.
    ///
    /// Returns the byte count reported by the buffer's `flush_to`.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing the buffered data or from flushing
    /// the underlying stream.
    pub async fn flush_write_buf(&mut self) -> io::Result<usize> {
        let Self {
            inner,
            write_buf,
            base_capacity,
            max_buffer_size,
            ..
        } = self;
        let flushed = write_buf.flush_to(&mut *inner).await?;
        write_buf.compact_to(*base_capacity, *max_buffer_size);
        inner.flush().await?;
        Ok(flushed)
    }
}