#![allow(clippy::cast_possible_truncation)]
use std::convert::Infallible;
use std::future::Future;
use std::io;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use aliasable::{boxed::AliasableBox, AliasableMut};
use completion_core::CompletionFuture;
use completion_io::{
AsyncBufRead, AsyncBufReadWith, AsyncRead, AsyncReadWith, ReadBuf, ReadBufMut,
};
use futures_core::ready;
use pin_project_lite::pin_project;
use super::extend_lifetime_mut;
#[derive(Debug)]
pub struct Take<T> {
inner: T,
limit: u64,
}
impl<T> Take<T> {
pub(super) fn new(inner: T, limit: u64) -> Self {
Self { inner, limit }
}
pub fn limit(&self) -> u64 {
self.limit
}
pub fn set_limit(&mut self, limit: u64) {
self.limit = limit;
}
pub fn into_inner(self) -> T {
self.inner
}
pub fn get_ref(&self) -> &T {
&self.inner
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
}
impl<'a, T: AsyncRead> AsyncReadWith<'a> for Take<T> {
type ReadFuture = ReadTake<'a, T>;
fn read(&'a mut self, buf: ReadBufMut<'a>) -> Self::ReadFuture {
let mut buf = AliasableMut::from_unique(unsafe { buf.into_mut() });
if self.limit == 0 {
ReadTake {
fut: None,
short_buf: None,
_pinned: PhantomPinned,
buf,
initial_filled: 0,
limit: &mut self.limit,
}
} else {
let initial_initialized = buf.initialized().len();
let initial_filled = buf.filled().len();
let (short_buf, used_buf) = if buf.remaining() as u64 > self.limit {
let limit = self.limit as usize;
let shortened = &mut unsafe { buf.all_mut() }[..initial_filled + limit];
let mut short_buf = ReadBuf::uninit(unsafe { extend_lifetime_mut(shortened) });
unsafe {
short_buf
.assume_init(std::cmp::min(limit, initial_initialized) - initial_filled)
};
short_buf.set_filled(initial_filled);
let mut short_buf = AliasableBox::from_unique(Box::new(short_buf));
let short_buf_mut = unsafe { extend_lifetime_mut(&mut *short_buf) };
(Some(short_buf), short_buf_mut)
} else {
(None, unsafe { extend_lifetime_mut(&mut *buf) })
};
ReadTake {
fut: Some(self.inner.read(used_buf.as_mut())),
short_buf,
_pinned: PhantomPinned,
buf,
initial_filled,
limit: &mut self.limit,
}
}
}
}
pin_project! {
pub struct ReadTake<'a, T: AsyncRead> {
#[pin]
fut: Option<<T as AsyncReadWith<'a>>::ReadFuture>,
short_buf: Option<AliasableBox<ReadBuf<'a>>>,
#[pin]
_pinned: PhantomPinned,
buf: AliasableMut<'a, ReadBuf<'a>>,
initial_filled: usize,
limit: &'a mut u64,
}
}
impl<'a, T: AsyncRead> ReadTake<'a, T> {
fn poll_with<E, F>(self: Pin<&mut Self>, cx: &mut Context<'_>, f: F) -> Poll<Result<(), E>>
where
F: FnOnce(
Pin<&mut <T as AsyncReadWith<'a>>::ReadFuture>,
&mut Context<'_>,
) -> Poll<Result<(), E>>,
{
let mut this = self.project();
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
ready!(f(fut, cx))?;
this.fut.set(None);
if let Some(short_buf) = this.short_buf.take() {
let initialized = short_buf.initialized().len();
let filled = short_buf.filled().len();
drop(short_buf);
unsafe { this.buf.assume_init(initialized - *this.initial_filled) };
this.buf.set_filled(filled);
}
**this.limit -= (this.buf.filled().len() - *this.initial_filled) as u64;
}
Poll::Ready(Ok(()))
}
}
impl<T: AsyncRead> CompletionFuture for ReadTake<'_, T> {
type Output = io::Result<()>;
unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.poll_with(cx, |fut, cx| fut.poll(cx))
}
unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
self.poll_with(cx, |fut, cx| fut.poll_cancel(cx).map(Ok::<_, Infallible>))
.map(drop)
}
}
impl<'a, T: AsyncRead> Future for ReadTake<'a, T>
where
<T as AsyncReadWith<'a>>::ReadFuture: Future<Output = io::Result<()>>,
{
type Output = io::Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
unsafe { CompletionFuture::poll(self, cx) }
}
}
impl<'a, T: AsyncBufRead> AsyncBufReadWith<'a> for Take<T> {
type FillBufFuture = FillBufTake<'a, T>;
fn fill_buf(&'a mut self) -> Self::FillBufFuture {
if self.limit == 0 {
FillBufTake {
fut: None,
limit: 0,
}
} else {
FillBufTake {
fut: Some(self.inner.fill_buf()),
limit: self.limit,
}
}
}
fn consume(&mut self, amt: usize) {
let amt = std::cmp::min(amt as u64, self.limit) as usize;
self.limit -= amt as u64;
self.inner.consume(amt);
}
}
pin_project! {
pub struct FillBufTake<'a, T: AsyncBufRead> {
#[pin]
fut: Option<<T as AsyncBufReadWith<'a>>::FillBufFuture>,
limit: u64,
}
}
impl<'a, T: AsyncBufRead> CompletionFuture for FillBufTake<'a, T> {
type Output = io::Result<&'a [u8]>;
unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
Poll::Ready(Ok(match this.fut.as_pin_mut() {
Some(fut) => {
let buf = ready!(fut.poll(cx))?;
let cap = std::cmp::min(buf.len() as u64, *this.limit) as usize;
&buf[..cap]
}
None => &[],
}))
}
unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
if let Some(fut) = self.project().fut.as_pin_mut() {
fut.poll_cancel(cx)
} else {
Poll::Ready(())
}
}
}
impl<'a, T: AsyncBufRead> Future for FillBufTake<'a, T>
where
<T as AsyncBufReadWith<'a>>::FillBufFuture: Future<Output = io::Result<&'a [u8]>>,
{
type Output = io::Result<&'a [u8]>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
unsafe { CompletionFuture::poll(self, cx) }
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::mem::MaybeUninit;
use crate::future::{block_on, CompletionFutureExt};
use super::super::{test_utils::YieldingReader, AsyncReadExt};
#[test]
fn reader_is_smaller() {
let mut reader = YieldingReader::new(vec![Ok("Hello "), Ok("World!")]).take(13);
let mut storage = [MaybeUninit::uninit(); 4];
let mut buf = ReadBuf::uninit(&mut storage);
block_on(reader.read(buf.as_mut())).unwrap();
assert_eq!(buf.into_filled(), b"Hell");
let mut storage = [0; 8];
let mut buf = ReadBuf::new(&mut storage);
block_on(reader.read(buf.as_mut())).unwrap();
assert_eq!(buf.filled(), b"o ");
buf.clear();
block_on(reader.read(buf.as_mut())).unwrap();
assert_eq!(buf.filled(), b"World!");
buf.clear();
block_on(reader.read(buf.as_mut())).unwrap();
assert_eq!(buf.filled(), b"");
}
#[test]
fn reader_is_larger() {
let mut reader = YieldingReader::new(vec![Ok("Hello "), Ok("World!")]).take(8);
let mut s = String::new();
block_on(reader.read_to_string(&mut s)).unwrap();
assert_eq!(s, "Hello Wo");
}
#[test]
fn cancellation_doesnt_lose_data() {
let mut reader = YieldingReader::empty()
.after_cancellation(vec![&[1, 2, 3][..], &[0, 0]])
.take(5);
let mut storage = [MaybeUninit::uninit(); 10];
let mut buf = ReadBuf::uninit(&mut storage);
assert!(block_on(reader.read(buf.as_mut()).now_or_never()).is_none());
assert_eq!(buf.into_filled(), &[1, 2, 3]);
assert_eq!(reader.limit(), 2);
}
}