use std::{
alloc::{alloc, Layout},
io,
pin::Pin,
ptr,
ptr::slice_from_raw_parts_mut,
task::Poll,
};
use futures::{io::BufReader, AsyncRead};
use pin_project::pin_project;
/// Extension methods available on every [`AsyncRead`] type.
pub trait AsyncReadUtil: AsyncRead + Sized {
    /// Wraps `self` so that `f` is called with every chunk of bytes read through it.
    fn observe<F>(self, f: F) -> ObservedReader<Self, F>
    where
        F: FnMut(&[u8]);

    /// Wraps `self` so that reads are transformed by `f`.
    ///
    /// `f(src, dst)` returns `(bytes consumed from src, bytes written to dst)`;
    /// see [`MappedReader`].
    fn map_read<F: FnMut(&[u8], &mut [u8]) -> (usize, usize)>(
        self,
        f: F,
    ) -> MappedReader<Self, F>;

    /// Wraps `self` in a [`BufReader`].
    fn buffered(self) -> BufReader<Self>;
}
/// Blanket implementation: every (sized) [`AsyncRead`] type gets the extension
/// methods by delegating to the corresponding adapter constructors.
impl<R> AsyncReadUtil for R
where
    R: AsyncRead,
{
    fn observe<F>(self, f: F) -> ObservedReader<Self, F>
    where
        F: FnMut(&[u8]),
    {
        ObservedReader::new(self, f)
    }

    fn map_read<F: FnMut(&[u8], &mut [u8]) -> (usize, usize)>(
        self,
        f: F,
    ) -> MappedReader<Self, F> {
        MappedReader::new(self, f)
    }

    fn buffered(self) -> BufReader<Self> {
        BufReader::new(self)
    }
}
#[pin_project]
pub struct ObservedReader<R, F> {
#[pin]
inner: R,
f: F,
}
impl<R, F> ObservedReader<R, F>
where
R: AsyncRead,
F: FnMut(&[u8]),
{
pub fn new(inner: R, f: F) -> Self {
Self { inner, f }
}
}
impl<R, F> AsyncRead for ObservedReader<R, F>
where
R: AsyncRead,
F: FnMut(&[u8]),
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.as_mut().project();
let num_read = futures::ready!(this.inner.poll_read(cx, buf))?;
(this.f)(&buf[0..num_read]);
Poll::Ready(Ok(num_read))
}
}
/// An [`AsyncRead`] adapter that transforms the bytes of an inner reader with a
/// mapping function.
///
/// Input from `inner` is staged in an internal buffer. On each read the mapping
/// function `f(src, dst)` receives the not-yet-consumed input (`src`) and the
/// caller's destination buffer (`dst`), and returns `(consumed, produced)`: how
/// many bytes of `src` it consumed and how many bytes it wrote to the front of
/// `dst`.
#[pin_project]
pub struct MappedReader<R, F> {
    /// The wrapped reader; polled through a pinned projection.
    #[pin]
    inner: R,
    /// Mapping function: `(src, dst) -> (bytes consumed from src, bytes written to dst)`.
    f: F,
    /// Staging buffer for input read from `inner`.
    buf: Box<[u8]>,
    /// Start of the unconsumed input in `buf`.
    pos: usize,
    /// End of the valid input in `buf`; `buf[pos..cap]` is pending input.
    cap: usize,
    /// Set once `inner` has reported end of stream (a zero-byte read).
    done: bool,
}

impl<R, F> MappedReader<R, F>
where
    R: AsyncRead,
    F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
{
    /// Staging buffer size used by [`MappedReader::new`] (8 KiB).
    const DEFAULT_CAPACITY: usize = 8 * 1024;

    /// Creates a mapped reader with an 8 KiB staging buffer.
    pub fn new(inner: R, f: F) -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY, inner, f)
    }

    /// Creates a mapped reader whose staging buffer holds `capacity` bytes.
    ///
    /// `capacity` bounds how much unconsumed input `f` is shown at once. If `f`
    /// cannot make progress on a completely full buffer, reads fail with
    /// [`io::ErrorKind::InvalidData`].
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero, since such a buffer can never hold input.
    pub fn with_capacity(capacity: usize, inner: R, f: F) -> Self {
        assert!(capacity > 0, "mapped reader requires a non-zero buffer capacity");
        // SAFETY: `capacity` is non-zero (asserted above), so the allocation
        // request made by `uninit_buf` has a valid, non-zero size.
        let buf = unsafe { uninit_buf(capacity) };
        Self {
            inner,
            f,
            buf,
            pos: 0,
            cap: 0,
            done: false,
        }
    }
}

impl<R, F> AsyncRead for MappedReader<R, F>
where
    R: AsyncRead,
    F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
{
    /// Produces mapped output into `buf`.
    ///
    /// Buffered input is offered to the mapping function before more is read
    /// from `inner`, so available output is never held back behind a pending
    /// inner read. When the function needs more input, the unconsumed tail is
    /// moved to the front of the staging buffer and the buffer is refilled.
    /// `Ok(0)` is returned only when `buf` is empty, or after `inner` reached
    /// end of stream and the function produces no more output. Any trailing
    /// input it never consumes is then dropped.
    ///
    /// # Errors
    ///
    /// Propagates errors from `inner`. Returns [`io::ErrorKind::InvalidData`]
    /// when the staging buffer is full and the mapping function consumes nothing.
    ///
    /// # Panics
    ///
    /// Panics if the mapping function reports consuming more input than it was
    /// given, or writing more output than fits in `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // No output can be produced into an empty destination, and returning 0
        // here avoids consuming input on the caller's behalf.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let mut this = self.as_mut().project();
        loop {
            let pending = *this.cap - *this.pos;
            // Map whatever is buffered first. After EOF the function is also
            // called with the (possibly empty) remainder so it can flush state.
            if pending > 0 || *this.done {
                let (nsrc, ndst) = (this.f)(&this.buf[*this.pos..*this.cap], buf);
                assert!(
                    nsrc <= pending,
                    "mapped reader is reportedly consuming more than the buffered input"
                );
                assert!(
                    ndst <= buf.len(),
                    "mapped reader is reportedly reading more than the destination buffer's capacity"
                );
                *this.pos += nsrc;
                if ndst > 0 || (nsrc == 0 && *this.done) {
                    return Poll::Ready(Ok(ndst));
                }
                if nsrc > 0 {
                    // Input was consumed without output; offer the rest.
                    continue;
                }
                // No progress and not at EOF: the function needs more input.
            }

            // Move the unconsumed tail to the front to make room for more input.
            if *this.pos > 0 {
                this.buf.copy_within(*this.pos..*this.cap, 0);
                *this.cap -= *this.pos;
                *this.pos = 0;
            }
            if *this.cap == this.buf.len() {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "mapping function made no progress on a full buffer",
                )));
            }

            let spare = &mut this.buf[*this.cap..];
            let nread = futures::ready!(this.inner.as_mut().poll_read(cx, spare))?;
            *this.cap += nread;
            if nread == 0 {
                *this.done = true;
            }
        }
    }
}
/// Allocates a heap buffer of `size` bytes for `MappedReader`'s staging area.
///
/// Despite its name, the buffer is zero-initialized. Exposing uninitialized
/// memory as `Box<[u8]>` / `&mut [u8]` is undefined behavior. Raw
/// `std::alloc::alloc` is also undefined for a zero-sized layout and may return
/// null on allocation failure. `vec!` handles every size (including 0) and
/// aborts on allocation failure.
///
/// # Safety
///
/// There are no safety preconditions. The function stays `unsafe` only so
/// existing `unsafe { uninit_buf(..) }` call sites keep compiling unchanged.
unsafe fn uninit_buf(size: usize) -> Box<[u8]> {
    vec![0u8; size].into_boxed_slice()
}