use std::io;
use monoio::{
buf::IoBufMut,
io::{AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt, Split},
BufResult,
};
use crate::{box_future::MaybeArmedBoxFuture, buf::Buf};
/// Adapter exposing a monoio rent-style stream (`AsyncReadRent` /
/// `AsyncWriteRent`) through tokio's `AsyncRead` / `AsyncWrite` traits.
///
/// Owned buffers are moved into the in-flight monoio futures and handed back
/// to the wrapper when those futures complete.
pub struct StreamWrapper<T> {
    // The in-flight futures are declared before `stream` on purpose: struct
    // fields are dropped in declaration order, and every armed future holds a
    // `&mut T` pointing into `stream`, so they must be dropped first.
    /// In-flight read on `stream`; owns `read_buf` while armed.
    read_fut: MaybeArmedBoxFuture<BufResult<usize, Buf>>,
    /// In-flight `write_all` on `stream`; owns `write_buf` while armed.
    write_fut: MaybeArmedBoxFuture<BufResult<usize, Buf>>,
    /// In-flight flush of `stream`.
    flush_fut: MaybeArmedBoxFuture<io::Result<()>>,
    /// In-flight shutdown of `stream`.
    shutdown_fut: MaybeArmedBoxFuture<io::Result<()>>,
    /// Data read from `stream` but not yet handed to the caller; `None` while
    /// the buffer has been moved into `read_fut`.
    read_buf: Option<Buf>,
    /// Staging buffer for outgoing data; `None` while it has been moved into
    /// `write_fut`.
    write_buf: Option<Buf>,
    /// The wrapped monoio stream.
    stream: T,
}
// SAFETY: defers to `T`'s own `Split` guarantee. The wrapper keeps separate
// buffers and futures for the read direction (`read_buf`, `read_fut`) and the
// write direction (`write_buf`, `write_fut`, `flush_fut`, `shutdown_fut`), so
// the only state the two halves share is `stream` itself.
// NOTE(review): confirm this matches monoio's documented contract for `Split`.
unsafe impl<T> Split for StreamWrapper<T> where T: Split {}
impl<T> StreamWrapper<T> {
    /// Consumes the wrapper and returns the wrapped stream.
    ///
    /// Any armed read/write/flush/shutdown future is dropped *before* the
    /// stream is moved out. Each of those futures holds a `&mut T` into
    /// `self.stream`, so none of them may be dropped after the stream has been
    /// moved. Buffered read data is discarded. Data that `poll_write` already
    /// reported as written may be lost if its write was still in flight, so
    /// flush first if that matters.
    pub fn into_inner(mut self) -> T {
        // Overwriting each field drops the previously armed future (if any).
        self.read_fut = Default::default();
        self.write_fut = Default::default();
        self.flush_fut = Default::default();
        self.shutdown_fut = Default::default();
        self.stream
    }

    /// Wraps `stream`, creating the read buffer with `Buf::new(read_buffer)`
    /// and the write staging buffer with `Buf::new(write_buffer)`.
    pub fn new_with_buffer_size(stream: T, read_buffer: usize, write_buffer: usize) -> Self {
        let r_buf = Buf::new(read_buffer);
        let w_buf = Buf::new(write_buffer);
        Self {
            stream,
            read_buf: Some(r_buf),
            write_buf: Some(w_buf),
            read_fut: Default::default(),
            write_fut: Default::default(),
            flush_fut: Default::default(),
            shutdown_fut: Default::default(),
        }
    }

    /// Wraps `stream` using `8 * 1024` for both the read and write buffer sizes.
    pub fn new(stream: T) -> Self {
        const DEFAULT_READ_BUFFER: usize = 8 * 1024;
        const DEFAULT_WRITE_BUFFER: usize = 8 * 1024;
        Self::new_with_buffer_size(stream, DEFAULT_READ_BUFFER, DEFAULT_WRITE_BUFFER)
    }
}
impl<T: AsyncReadRent + Unpin + 'static> tokio::io::AsyncRead for StreamWrapper<T> {
    /// Fills `buf` from the internal read buffer, refilling that buffer from
    /// the wrapped monoio stream whenever it is empty.
    ///
    /// A completed inner read of 0 bytes is reported as `Ready(Ok(()))` with
    /// nothing written to `buf`, which is tokio's EOF signal. An inner read
    /// error is propagated after the buffer has been returned to `read_buf`.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        loop {
            // An unarmed `read_fut` means `read_buf` is back in our hands and
            // may still hold data from a previous read.
            if !this.read_fut.armed() {
                // SAFETY: `read_buf` is only `None` between being moved into
                // `read_fut` below and that future completing, i.e. only while
                // `read_fut` is armed.
                let read_buf_mut = unsafe { this.read_buf.as_mut().unwrap_unchecked() };
                if !read_buf_mut.is_empty() {
                    let our_buf = read_buf_mut.buf_to_read(buf.remaining());
                    let our_buf_len = our_buf.len();
                    buf.put_slice(our_buf);
                    // SAFETY: advances past exactly the bytes copied out above.
                    unsafe { read_buf_mut.advance_offset(our_buf_len) };
                    return std::task::Poll::Ready(Ok(()));
                }
                // Buffer drained: hand it to a fresh read on the inner stream.
                // SAFETY: same invariant as above; `read_fut` is not armed.
                let owned_buf = unsafe { this.read_buf.take().unwrap_unchecked() };
                // SAFETY: the raw pointer is derived from `&mut this.stream`
                // (not cast from a `&T`), so writing through it is permitted.
                // The reference escapes this call and is stored in `read_fut`.
                // NOTE(review): soundness relies on the wrapper not being moved
                // or dropped while the future is armed. That is not enforced by
                // the type system here.
                let stream = unsafe { &mut *(&mut this.stream as *mut T) };
                this.read_fut.arm_future(AsyncReadRent::read(stream, owned_buf));
            }
            let (ret, owned_buf) = match this.read_fut.poll(cx) {
                std::task::Poll::Ready(out) => out,
                std::task::Poll::Pending => return std::task::Poll::Pending,
            };
            this.read_buf = Some(owned_buf);
            if ret? == 0 {
                return std::task::Poll::Ready(Ok(()));
            }
            // Data arrived: loop back and copy it out. This presumably relies
            // on `poll` disarming `read_fut` once it yields a result.
        }
    }
}
impl<T> StreamWrapper<T> {
    /// Drives the in-flight `write_all` future, if one is armed.
    ///
    /// Returns `Ready(Ok(()))` immediately when no write is armed. When the
    /// armed write completes, its buffer is reset with `set_init(0)` and put
    /// back into `write_buf`, and the write's error (if any) is returned.
    fn poll_pending_write(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        if !self.write_fut.armed() {
            return std::task::Poll::Ready(Ok(()));
        }
        let (ret, mut owned_buf) = match self.write_fut.poll(cx) {
            std::task::Poll::Ready(out) => out,
            std::task::Poll::Pending => return std::task::Poll::Pending,
        };
        // SAFETY: the write has completed, so nothing else uses the buffer;
        // presumably `set_init(0)` marks it empty for reuse.
        unsafe { owned_buf.set_init(0) };
        self.write_buf = Some(owned_buf);
        std::task::Poll::Ready(ret.map(|_| ()))
    }
}

impl<T: AsyncWriteRent + Unpin + 'static> tokio::io::AsyncWrite for StreamWrapper<T> {
    /// Copies as much of `buf` as fits into the staging buffer and starts a
    /// `write_all` of it on the inner stream.
    ///
    /// At most one write is in flight: a previously armed write is driven to
    /// completion first, and its error (if any) is returned. The copied length
    /// is reported as written even if the new write is still pending (it is
    /// completed by a later `poll_write`, `poll_flush` or `poll_shutdown`),
    /// unless its first poll already fails.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }
        let this = self.get_mut();
        match this.poll_pending_write(cx) {
            std::task::Poll::Ready(Ok(())) => {}
            std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }
        // SAFETY: `write_buf` is only `None` while `write_fut` is armed, and
        // any previously armed write has just completed.
        let mut owned_buf = unsafe { this.write_buf.take().unwrap_unchecked() };
        let owned_buf_mut = owned_buf.buf_to_write();
        let len = buf.len().min(owned_buf_mut.len());
        // SAFETY: `len` fits in both regions. `buf` is a shared borrow and
        // `owned_buf_mut` a unique one, so they cannot overlap.
        unsafe { std::ptr::copy_nonoverlapping(buf.as_ptr(), owned_buf_mut.as_mut_ptr(), len) };
        // SAFETY: the first `len` bytes were just initialized by the copy.
        unsafe { owned_buf.set_init(len) };
        // SAFETY: pointer derived from `&mut this.stream`, so writes through
        // it are permitted. The reference escapes into `write_fut`.
        // NOTE(review): relies on the wrapper not moving while it is armed.
        let stream = unsafe { &mut *(&mut this.stream as *mut T) };
        this.write_fut
            .arm_future(AsyncWriteRentExt::write_all(stream, owned_buf));
        // Poll once to start the write and surface immediate failures.
        if let std::task::Poll::Ready(Err(e)) = this.poll_pending_write(cx) {
            return std::task::Poll::Ready(Err(e));
        }
        std::task::Poll::Ready(Ok(len))
    }

    /// Completes any pending write, then flushes the inner stream.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        match this.poll_pending_write(cx) {
            std::task::Poll::Ready(Ok(())) => {}
            std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }
        if !this.flush_fut.armed() {
            // SAFETY: pointer derived from `&mut this.stream`. The reference
            // escapes into `flush_fut`.
            // NOTE(review): relies on the wrapper not moving while it is armed.
            let stream = unsafe { &mut *(&mut this.stream as *mut T) };
            this.flush_fut.arm_future(stream.flush());
        }
        this.flush_fut.poll(cx)
    }

    /// Completes any pending write, then shuts down the inner stream.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        match this.poll_pending_write(cx) {
            std::task::Poll::Ready(Ok(())) => {}
            std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(e)),
            std::task::Poll::Pending => return std::task::Poll::Pending,
        }
        if !this.shutdown_fut.armed() {
            // SAFETY: pointer derived from `&mut this.stream`. The reference
            // escapes into `shutdown_fut`.
            // NOTE(review): relies on the wrapper not moving while it is armed.
            let stream = unsafe { &mut *(&mut this.stream as *mut T) };
            this.shutdown_fut.arm_future(stream.shutdown());
        }
        this.shutdown_fut.poll(cx)
    }
}