use crate::{futures_util::FuturesOps, PartialOp};
use futures::prelude::*;
use pin_project::pin_project;
use std::{
fmt, io,
pin::Pin,
task::{Context, Poll},
};
/// Wrapper that pairs an inner value of type `R` with a [`FuturesOps`] sequence built from
/// [`PartialOp`] values.
///
/// The trait impls in this file route the read-side polls (`poll_read`, `poll_fill_buf`)
/// through `ops`, while write and seek calls are forwarded to `inner` unchanged.
#[pin_project]
pub struct PartialAsyncRead<R> {
// Marked `#[pin]` so that `project()` yields a `Pin<&mut R>` for the trait impls.
#[pin]
inner: R,
// Operation sequence consulted by the read-side polls; replaceable via `set_ops`.
ops: FuturesOps,
}
impl<R> PartialAsyncRead<R> {
    /// Wraps `inner`, building the operation sequence from `iter`.
    pub fn new<I>(inner: R, iter: I) -> Self
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        let ops = FuturesOps::new(iter);
        Self { inner, ops }
    }

    /// Replaces the operation sequence with the one produced by `iter`.
    ///
    /// Returns `self` so that calls can be chained.
    pub fn set_ops<I>(&mut self, iter: I) -> &mut Self
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        let ops = &mut self.ops;
        ops.replace(iter);
        self
    }

    /// Pinned counterpart of [`set_ops`](Self::set_ops): replaces the operation sequence
    /// through a pinned reference and hands that same pinned reference back.
    pub fn pin_set_ops<I>(mut self: Pin<&mut Self>, iter: I) -> Pin<&mut Self>
    where
        I: IntoIterator<Item = PartialOp> + 'static,
        I::IntoIter: Send,
    {
        // Re-borrow the pin so that `self` is still available to return afterwards.
        let projected = self.as_mut().project();
        projected.ops.replace(iter);
        self
    }

    /// Returns a shared reference to the wrapped value.
    pub fn get_ref(&self) -> &R {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a mutable reference to the wrapped value.
    pub fn get_mut(&mut self) -> &mut R {
        let Self { inner, .. } = self;
        inner
    }

    /// Returns a pinned mutable reference to the wrapped value.
    pub fn pin_get_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
        let projected = self.project();
        projected.inner
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }
}
/// `futures` [`AsyncRead`] implementation: every read is routed through the configured
/// operation sequence before it reaches the inner reader.
impl<R> AsyncRead for PartialAsyncRead<R>
where
    R: AsyncRead,
{
    /// Polls the inner reader via `ops.poll_impl`.
    ///
    /// The callback receives an optional length limit: with `Some(n)` only the first `n`
    /// bytes of `buf` are offered to the inner reader, with `None` the whole buffer is.
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let projection = self.project();
        let reader = projection.inner;
        let requested = buf.len();

        projection.ops.poll_impl(
            cx,
            |cx, limit| {
                if let Some(max_len) = limit {
                    reader.poll_read(cx, &mut buf[..max_len])
                } else {
                    reader.poll_read(cx, buf)
                }
            },
            requested,
            "error during poll_read, generated by partial-io",
        )
    }
}
/// `futures` [`AsyncBufRead`] implementation: `poll_fill_buf` is routed through the
/// configured operation sequence, while `consume` is forwarded directly.
impl<R> AsyncBufRead for PartialAsyncRead<R>
where
    R: AsyncBufRead,
{
    /// Polls the inner reader's `poll_fill_buf` via `ops.poll_impl_no_limit`.
    ///
    /// No length limit is applied to the returned slice; only the decision of whether
    /// the call goes through is delegated to `ops`.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
        let this = self.project();
        let inner = this.inner;
        this.ops.poll_impl_no_limit(
            cx,
            |cx| inner.poll_fill_buf(cx),
            // Name the operation that was actually polled, matching the tokio
            // implementation of `poll_fill_buf` in this file.
            "error during poll_fill_buf, generated by partial-io",
        )
    }

    /// Forwards `consume(amt)` to the inner reader unchanged.
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        self.project().inner.consume(amt)
    }
}
/// `futures` [`AsyncWrite`] pass-through: every write-side call is forwarded to the inner
/// value without consulting the operation sequence.
impl<R> AsyncWrite for PartialAsyncRead<R>
where
    R: AsyncWrite,
{
    /// Forwards to the inner writer's `poll_write`.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let writer = self.project().inner;
        writer.poll_write(cx, buf)
    }

    /// Forwards to the inner writer's `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let writer = self.project().inner;
        writer.poll_flush(cx)
    }

    /// Forwards to the inner writer's `poll_close`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let writer = self.project().inner;
        writer.poll_close(cx)
    }
}
/// `futures` [`AsyncSeek`] pass-through: seeking is forwarded to the inner value without
/// consulting the operation sequence.
impl<R> AsyncSeek for PartialAsyncRead<R>
where
    R: AsyncSeek,
{
    /// Forwards to the inner value's `poll_seek`.
    #[inline]
    fn poll_seek(
        self: Pin<&mut Self>,
        cx: &mut Context,
        pos: io::SeekFrom,
    ) -> Poll<io::Result<u64>> {
        let seeker = self.project().inner;
        seeker.poll_seek(cx, pos)
    }
}
#[cfg(feature = "tokio1")]
pub(crate) mod tokio_impl {
use super::PartialAsyncRead;
use std::{
io::{self, SeekFrom},
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
/// `tokio` [`AsyncRead`] implementation: every read is routed through the configured
/// operation sequence before it reaches the inner reader.
impl<R> AsyncRead for PartialAsyncRead<R>
where
    R: AsyncRead,
{
    /// Polls the inner reader via `ops.poll_impl`.
    ///
    /// With a limit of `Some(n)` the inner reader only sees the view produced by
    /// `ReadBufExt::with_limited(n, ..)`; with `None` it receives `buf` directly.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let projection = self.project();
        let reader = projection.inner;
        // NOTE(review): this passes the total capacity, not the unfilled remainder —
        // confirm that is what `poll_impl` expects.
        let total_capacity = buf.capacity();

        projection.ops.poll_impl(
            cx,
            |cx, limit| {
                if let Some(max_len) = limit {
                    buf.with_limited(max_len, |window| reader.poll_read(cx, window))
                } else {
                    reader.poll_read(cx, buf)
                }
            },
            total_capacity,
            "error during poll_read, generated by partial-io",
        )
    }
}
/// Extension trait adding a length-limited view operation to a read buffer.
pub trait ReadBufExt {
/// Invokes `callback` with a `ReadBuf` restricted by `limit` and returns the callback's
/// result.
///
/// How `limit` is applied and how changes are written back is defined by the implementor.
fn with_limited<F, T>(&mut self, limit: usize, callback: F) -> T
where
F: FnOnce(&mut ReadBuf<'_>) -> T;
}
/// [`ReadBufExt`] for tokio's [`ReadBuf`].
impl<'a> ReadBufExt for ReadBuf<'a> {
/// Runs `callback` on a temporary `ReadBuf` that covers at most the first `limit` bytes of
/// this buffer's storage, then writes the resulting filled/initialized lengths back to
/// `self`.
///
/// The temporary buffer starts with the capacity, initialized length and filled length of
/// `self`, each clamped to `limit`. Note that the limit applies to the buffer measured
/// from its start, not to the unfilled remainder. Returns whatever `callback` returns.
fn with_limited<F, T>(&mut self, limit: usize, callback: F) -> T
where
F: FnOnce(&mut ReadBuf<'_>) -> T,
{
// Clamp all three lengths so the temporary view never exceeds `limit` bytes.
let capacity_limit = self.capacity().min(limit);
let old_initialized_len = self.initialized().len().min(limit);
let old_filled_len = self.filled().len().min(limit);
// SAFETY (review note): the storage from `inner_mut` is only re-wrapped, and the length
// passed to `assume_init` is no larger than what `self` already reports as initialized.
// Confirm against the safety contracts of tokio's `ReadBuf::inner_mut` / `assume_init`.
let mut limited_buf = unsafe {
let inner_mut = &mut self.inner_mut()[..capacity_limit];
let mut limited_buf = ReadBuf::uninit(inner_mut);
limited_buf.assume_init(old_initialized_len);
limited_buf
};
// Mirror the (clamped) filled length of `self` in the temporary view.
limited_buf.set_filled(old_filled_len);
let ret = callback(&mut limited_buf);
// Read the lengths the callback left behind before touching `self` again.
let new_initialized_len = limited_buf.initialized().len();
let new_filled_len = limited_buf.filled().len();
// Only propagate initialization growth; the count handed to `assume_init` is computed
// relative to the current filled length of `self`.
if new_initialized_len > old_initialized_len {
// SAFETY (review note): relies on the callback having initialized the bytes that the
// temporary view now reports as initialized — confirm.
unsafe {
self.assume_init(new_initialized_len - self.filled().len());
}
}
// Leave the filled length of `self` untouched unless the callback changed it; this keeps
// a filled length larger than `limit` intact when nothing changed.
if new_filled_len != old_filled_len {
self.set_filled(new_filled_len);
}
ret
}
}
/// `tokio` [`AsyncBufRead`] implementation: `poll_fill_buf` is routed through the configured
/// operation sequence, while `consume` is forwarded directly.
impl<R> AsyncBufRead for PartialAsyncRead<R>
where
    R: AsyncBufRead,
{
    /// Polls the inner reader's `poll_fill_buf` via `ops.poll_impl_no_limit`.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let projection = self.project();
        let reader = projection.inner;

        projection.ops.poll_impl_no_limit(
            cx,
            |cx| reader.poll_fill_buf(cx),
            "error during poll_fill_buf, generated by partial-io",
        )
    }

    /// Forwards `consume(amt)` to the inner reader unchanged.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let reader = self.project().inner;
        reader.consume(amt)
    }
}
/// `tokio` [`AsyncWrite`] pass-through: every write-side call is forwarded to the inner
/// value without consulting the operation sequence.
impl<R> AsyncWrite for PartialAsyncRead<R>
where
    R: AsyncWrite,
{
    /// Forwards to the inner writer's `poll_write`.
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let writer = self.project().inner;
        writer.poll_write(cx, buf)
    }

    /// Forwards to the inner writer's `poll_flush`.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let writer = self.project().inner;
        writer.poll_flush(cx)
    }

    /// Forwards to the inner writer's `poll_shutdown`.
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let writer = self.project().inner;
        writer.poll_shutdown(cx)
    }
}
/// `tokio` [`AsyncSeek`] pass-through: both halves of a seek are forwarded to the inner
/// value without consulting the operation sequence.
impl<R> AsyncSeek for PartialAsyncRead<R>
where
    R: AsyncSeek,
{
    /// Forwards to the inner value's `start_seek`.
    #[inline]
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        let seeker = self.project().inner;
        seeker.start_seek(position)
    }

    /// Forwards to the inner value's `poll_complete`.
    #[inline]
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let seeker = self.project().inner;
        seeker.poll_complete(cx)
    }
}
// Unit tests for `ReadBufExt::with_limited`.
#[cfg(test)]
mod tests {
use super::*;
use itertools::Itertools;
use std::mem::MaybeUninit;
// Exercises `with_limited` over every combination of starting buffer state and limit,
// checking four scenarios: a no-op callback, shrinking the filled region, pushing one byte,
// and initializing the unfilled region.
#[test]
fn test_with_limited() {
const CAPACITY: usize = 256;
// Starting states as (filled, initialized) pairs; each satisfies filled <= initialized.
let inputs = vec![
(256, 256),
(64, 256),
(0, 256),
(128, 128),
(64, 128),
(0, 128),
(0, 0),
];
// Limits below, equal to, and above CAPACITY (384 > 256).
let limits = vec![0, 32, 64, 128, 192, 256, 384];
for ((filled, initialized), limit) in inputs.into_iter().cartesian_product(limits) {
// An array of `MaybeUninit<u8>` is created without initializing its contents.
let mut storage: [MaybeUninit<u8>; CAPACITY] =
unsafe { MaybeUninit::uninit().assume_init() };
let mut buf = ReadBuf::uninit(&mut storage);
buf.initialize_unfilled_to(initialized);
buf.set_filled(filled);
println!("*** limit = {}, original buf = {:?}", limit, buf);
// Scenario 1: a callback that changes nothing must see clamped lengths and must leave
// the outer buffer untouched.
buf.with_limited(limit, |limited_buf| {
println!(" * do-nothing: limited buf = {:?}", limited_buf);
assert!(
limited_buf.capacity() <= limit,
"limit is applied to capacity"
);
assert!(
limited_buf.initialized().len() <= limit,
"limit is applied to initialized len"
);
assert!(
limited_buf.filled().len() <= limit,
"limit is applied to filled len"
);
});
assert_eq!(
buf.filled().len(),
filled,
"do-nothing -> filled is the same as before"
);
assert_eq!(
buf.initialized().len(),
initialized,
"do-nothing -> initialized is the same as before"
);
// Scenario 2: the callback halves the filled length of the limited view and returns the
// new value so it can be compared with the outer buffer afterwards.
let new_filled = buf.with_limited(limit, |limited_buf| {
println!(" * halve-filled: limited buf = {:?}", limited_buf);
let new_filled = limited_buf.filled().len() / 2;
limited_buf.set_filled(new_filled);
println!(" * halve-filled: after = {:?}", limited_buf);
new_filled
});
match new_filled.cmp(&limit) {
std::cmp::Ordering::Less => {
assert_eq!(
buf.filled().len(),
new_filled,
"halve-filled, new filled < limit -> filled is updated"
);
}
// Half of a value <= limit can only equal the limit when the limit is 0; in that case
// the view's filled length did not change, so the outer buffer keeps its own.
std::cmp::Ordering::Equal => {
assert_eq!(limit, 0, "halve-filled, new filled == limit -> limit = 0");
assert_eq!(
buf.filled().len(),
filled,
"halve-filled, new filled == limit -> filled stays the same"
);
}
std::cmp::Ordering::Greater => {
panic!("new_filled {} must be <= limit {}", new_filled, limit);
}
}
assert_eq!(
buf.initialized().len(),
initialized,
"halve-filled -> initialized is same as before"
);
// Scenario 3: only when the limited view has room past the filled region, push a single
// byte. A fresh buffer is used because scenario 2 modified `buf`.
if filled < limit.min(CAPACITY) {
let mut storage: [MaybeUninit<u8>; CAPACITY] =
unsafe { MaybeUninit::uninit().assume_init() };
let mut buf = ReadBuf::uninit(&mut storage);
buf.initialize_unfilled_to(initialized);
buf.set_filled(filled);
buf.with_limited(limit, |limited_buf| {
println!(" * push-one-byte: limited buf = {:?}", limited_buf);
limited_buf.put_slice(&[42]);
println!(" * push-one-byte: after = {:?}", limited_buf);
});
assert_eq!(
buf.filled().len(),
filled + 1,
"push-one-byte, filled incremented by 1"
);
assert_eq!(
buf.filled()[filled],
42,
"push-one-byte, correct byte was pushed"
);
// The initialized length only grows when the pushed byte landed past it.
if filled == initialized {
assert_eq!(
buf.initialized().len(),
initialized + 1,
"push-one-byte, filled == initialized -> initialized incremented by 1"
);
} else {
assert_eq!(
buf.initialized().len(),
initialized,
"push-one-byte, filled < initialized -> initialized stays the same"
);
}
}
// Scenario 4: only when the initialized region fits within the limited view, initialize
// the whole unfilled region of that view. Again uses a fresh buffer.
if initialized <= limit.min(CAPACITY) {
let mut storage: [MaybeUninit<u8>; CAPACITY] =
unsafe { MaybeUninit::uninit().assume_init() };
let mut buf = ReadBuf::uninit(&mut storage);
buf.initialize_unfilled_to(initialized);
buf.set_filled(filled);
buf.with_limited(limit, |limited_buf| {
println!(" * initialize-unfilled: limited buf = {:?}", limited_buf);
limited_buf.initialize_unfilled();
println!(" * initialize-unfilled: after = {:?}", limited_buf);
});
assert_eq!(
buf.filled().len(),
filled,
"initialize-unfilled, filled stays the same"
);
assert_eq!(
buf.initialized().len(),
limit.min(CAPACITY),
"initialize-unfilled, initialized is capped at the limit"
);
assert_eq!(
buf.initialized(),
vec![0; buf.initialized().len()],
"initialize-unfilled, bytes are correct"
);
}
}
}
}
}
/// Debug formatting that shows only the wrapped value; the operation sequence is omitted.
impl<R> fmt::Debug for PartialAsyncRead<R>
where
    R: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("PartialAsyncRead");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::PartialAsyncRead;
    use crate::tests::assert_send;
    use std::fs::File;

    /// Compile-time check that the wrapper is `Send` when the wrapped value is.
    #[test]
    fn test_sendable() {
        type Wrapped = PartialAsyncRead<File>;
        assert_send::<Wrapped>();
    }
}