use std::future::Future;
use std::io;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use futures_util::io::{BufReader, Cursor};
use futures_util::{AsyncBufRead, AsyncRead, AsyncSeek};
use crate::{AsyncSkip, SeekSkipAdapter};
pub trait AsyncSkipExt: AsyncSkip {
fn skip(&mut self, amount: u64) -> Skip<'_, Self> {
Skip { amount, inner: self }
}
fn stream_position(&mut self) -> StreamPosition<'_, Self> {
StreamPosition { inner: self }
}
fn stream_len(&mut self) -> StreamLen<'_, Self> {
StreamLen { inner: self }
}
}
pub struct Skip<'a, T: ?Sized> {
amount: u64,
inner: &'a mut T,
}
pub struct StreamPosition<'a, T: ?Sized> {
inner: &'a mut T,
}
pub struct StreamLen<'a, T: ?Sized> {
inner: &'a mut T,
}
impl<T: AsyncSkip + ?Sized> AsyncSkipExt for T {}
impl<T: AsyncSkip + Unpin + ?Sized> Future for Skip<'_, T> {
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let amount = self.amount;
Pin::new(&mut *self.inner).poll_skip(cx, amount)
}
}
impl<T: AsyncSkip + Unpin + ?Sized> Future for StreamPosition<'_, T> {
type Output = io::Result<u64>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut *self.inner).poll_stream_position(cx)
}
}
impl<T: AsyncSkip + Unpin + ?Sized> Future for StreamLen<'_, T> {
type Output = io::Result<u64>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut *self.inner).poll_stream_len(cx)
}
}
impl<T: AsyncRead + Unpin + ?Sized> AsyncRead for SeekSkipAdapter<T> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
impl<R: AsyncSeek + Unpin + ?Sized> AsyncSkip for SeekSkipAdapter<R> {
fn poll_skip(mut self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
match amount.try_into() {
Ok(0) => (),
Ok(amount) => {
let reader = Pin::new(&mut self.get_mut().0);
ready!(reader.poll_seek(cx, io::SeekFrom::Current(amount)))?;
}
Err(_) => {
let stream_pos = ready!(self.as_mut().poll_stream_position(cx))?;
let seek_pos = stream_pos
.checked_add(amount)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "seek past u64::MAX"))?;
let reader = Pin::new(&mut self.get_mut().0);
ready!(reader.poll_seek(cx, io::SeekFrom::Start(seek_pos)))?;
}
}
Ok(()).into()
}
fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let reader = Pin::new(&mut self.get_mut().0);
reader.poll_seek(cx, io::SeekFrom::Current(0))
}
fn poll_stream_len(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let stream_pos = ready!(self.as_mut().poll_stream_position(cx))?;
let mut reader = Pin::new(&mut self.get_mut().0);
let len = ready!(reader.as_mut().poll_seek(cx, io::SeekFrom::End(0)))?;
if stream_pos != len {
ready!(reader.poll_seek(cx, io::SeekFrom::Start(stream_pos)))?;
}
Ok(len).into()
}
}
macro_rules! deref_async_skip {
() => {
fn poll_skip(mut self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_skip(cx, amount)
}
fn poll_stream_position(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
Pin::new(&mut **self).poll_stream_position(cx)
}
fn poll_stream_len(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
Pin::new(&mut **self).poll_stream_len(cx)
}
};
}
impl<R: AsyncSkip + Unpin + ?Sized> AsyncSkip for &mut R {
deref_async_skip!();
}
impl<R: AsyncSkip + Unpin + ?Sized> AsyncSkip for Box<R> {
deref_async_skip!();
}
impl<P: DerefMut + Unpin> AsyncSkip for Pin<P>
where
P::Target: AsyncSkip,
{
fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_skip(cx, amount)
}
fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
self.get_mut().as_mut().poll_stream_position(cx)
}
fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
self.get_mut().as_mut().poll_stream_len(cx)
}
}
macro_rules! async_skip_via_adapter {
() => {
fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
Pin::new(&mut SeekSkipAdapter(self.get_mut())).poll_skip(cx, amount)
}
fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
Pin::new(&mut SeekSkipAdapter(self.get_mut())).poll_stream_position(cx)
}
fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
Pin::new(&mut SeekSkipAdapter(self.get_mut())).poll_stream_len(cx)
}
};
}
impl<T: AsRef<[u8]> + Unpin> AsyncSkip for Cursor<T> {
async_skip_via_adapter!();
}
impl<R: AsyncRead + AsyncSkip> AsyncSkip for BufReader<R> {
fn poll_skip(mut self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
let buf_len = self.buffer().len();
if let Some(skip_amount) = amount.checked_sub(buf_len as u64) {
if skip_amount != 0 {
ready!(self.as_mut().get_pin_mut().poll_skip(cx, skip_amount))?
}
}
self.consume(buf_len.min(amount as usize));
Poll::Ready(Ok(()))
}
fn poll_stream_position(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let stream_pos = ready!(self.as_mut().get_pin_mut().poll_stream_position(cx))?;
Poll::Ready(Ok(stream_pos.saturating_sub(self.buffer().len() as u64)))
}
fn poll_stream_len(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
self.as_mut().get_pin_mut().poll_stream_len(cx)
}
}