mediasan_common/
async_skip.rs

1//! Utility functions for the [`AsyncSkip`] trait.
2
3use std::future::Future;
4use std::io;
5use std::ops::DerefMut;
6use std::pin::Pin;
7use std::task::{ready, Context, Poll};
8
9use futures_util::io::{BufReader, Cursor};
10use futures_util::{AsyncBufRead, AsyncRead, AsyncSeek};
11
12use crate::{AsyncSkip, SeekSkipAdapter};
13
14//
15// public types
16//
17
18/// An extension trait which adds utility methods to [`AsyncSkip`] types.
19pub trait AsyncSkipExt: AsyncSkip {
20    /// Skip an amount of bytes in a stream.
21    ///
22    /// A skip beyond the end of a stream is allowed, but behavior is defined by the implementation.
23    fn skip(&mut self, amount: u64) -> Skip<'_, Self> {
24        Skip { amount, inner: self }
25    }
26
27    /// Returns the current position of the cursor from the start of the stream.
28    fn stream_position(&mut self) -> StreamPosition<'_, Self> {
29        StreamPosition { inner: self }
30    }
31
32    /// Returns the length of this stream, in bytes.
33    fn stream_len(&mut self) -> StreamLen<'_, Self> {
34        StreamLen { inner: self }
35    }
36}
37
/// Future for the [`skip`](AsyncSkipExt::skip) method.
///
/// Polling it asks the borrowed stream to skip `amount` bytes; like every future it does
/// nothing until polled, hence `#[must_use]`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Skip<'a, T: ?Sized> {
    /// Number of bytes to skip.
    amount: u64,
    /// The stream being skipped.
    inner: &'a mut T,
}
43
/// Future for the [`stream_position`](AsyncSkipExt::stream_position) method.
///
/// Resolves to the stream's current position; does nothing until polled, hence `#[must_use]`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct StreamPosition<'a, T: ?Sized> {
    /// The stream whose position is queried.
    inner: &'a mut T,
}
48
/// Future for the [`stream_len`](AsyncSkipExt::stream_len) method.
///
/// Resolves to the stream's length in bytes; does nothing until polled, hence `#[must_use]`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct StreamLen<'a, T: ?Sized> {
    /// The stream whose length is queried.
    inner: &'a mut T,
}
53
54//
55// AsyncSkipExt impls
56//
57
58impl<T: AsyncSkip + ?Sized> AsyncSkipExt for T {}
59
60//
61// Skip impls
62//
63
64impl<T: AsyncSkip + Unpin + ?Sized> Future for Skip<'_, T> {
65    type Output = io::Result<()>;
66
67    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68        let amount = self.amount;
69        Pin::new(&mut *self.inner).poll_skip(cx, amount)
70    }
71}
72
73//
74// StreamPosition impls
75//
76
77impl<T: AsyncSkip + Unpin + ?Sized> Future for StreamPosition<'_, T> {
78    type Output = io::Result<u64>;
79
80    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81        Pin::new(&mut *self.inner).poll_stream_position(cx)
82    }
83}
84
85//
86// StreamLen impls
87//
88
89impl<T: AsyncSkip + Unpin + ?Sized> Future for StreamLen<'_, T> {
90    type Output = io::Result<u64>;
91
92    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
93        Pin::new(&mut *self.inner).poll_stream_len(cx)
94    }
95}
96
97//
98// SeekSkipAdapter impls
99//
100
101impl<T: AsyncRead + Unpin + ?Sized> AsyncRead for SeekSkipAdapter<T> {
102    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
103        Pin::new(&mut self.0).poll_read(cx, buf)
104    }
105}
106
107impl<R: AsyncSeek + Unpin + ?Sized> AsyncSkip for SeekSkipAdapter<R> {
108    fn poll_skip(mut self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
109        match amount.try_into() {
110            Ok(0) => (),
111            Ok(amount) => {
112                let reader = Pin::new(&mut self.get_mut().0);
113                ready!(reader.poll_seek(cx, io::SeekFrom::Current(amount)))?;
114            }
115            Err(_) => {
116                let stream_pos = ready!(self.as_mut().poll_stream_position(cx))?;
117                let seek_pos = stream_pos
118                    .checked_add(amount)
119                    .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "seek past u64::MAX"))?;
120                let reader = Pin::new(&mut self.get_mut().0);
121                ready!(reader.poll_seek(cx, io::SeekFrom::Start(seek_pos)))?;
122            }
123        }
124        Ok(()).into()
125    }
126
127    fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
128        let reader = Pin::new(&mut self.get_mut().0);
129        reader.poll_seek(cx, io::SeekFrom::Current(0))
130    }
131
132    fn poll_stream_len(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
133        // This is the unstable Seek::stream_len
134        let stream_pos = ready!(self.as_mut().poll_stream_position(cx))?;
135        let mut reader = Pin::new(&mut self.get_mut().0);
136        let len = ready!(reader.as_mut().poll_seek(cx, io::SeekFrom::End(0)))?;
137
138        if stream_pos != len {
139            ready!(reader.poll_seek(cx, io::SeekFrom::Start(stream_pos)))?;
140        }
141
142        Ok(len).into()
143    }
144}
145
146//
147// AsyncSkip impls
148//
149
// Generates the `AsyncSkip` methods for a pointer-like `Self` (`&mut R`, `Box<R>`) by
// forwarding every call to the pointee. The pointee must be `Unpin` because it is re-pinned
// with `Pin::new`, and `Self` itself must be `Unpin` for `get_mut`.
macro_rules! deref_async_skip {
    () => {
        fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
            let pointee = &mut **self.get_mut();
            Pin::new(pointee).poll_skip(cx, amount)
        }

        fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let pointee = &mut **self.get_mut();
            Pin::new(pointee).poll_stream_position(cx)
        }

        fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let pointee = &mut **self.get_mut();
            Pin::new(pointee).poll_stream_len(cx)
        }
    };
}
165
166impl<R: AsyncSkip + Unpin + ?Sized> AsyncSkip for &mut R {
167    deref_async_skip!();
168}
169
170impl<R: AsyncSkip + Unpin + ?Sized> AsyncSkip for Box<R> {
171    deref_async_skip!();
172}
173
174impl<P: DerefMut + Unpin> AsyncSkip for Pin<P>
175where
176    P::Target: AsyncSkip,
177{
178    fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
179        self.get_mut().as_mut().poll_skip(cx, amount)
180    }
181
182    fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
183        self.get_mut().as_mut().poll_stream_position(cx)
184    }
185
186    fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
187        self.get_mut().as_mut().poll_stream_len(cx)
188    }
189}
190
// Generates the `AsyncSkip` methods for a seekable `Self` by wrapping a temporary mutable
// borrow of it in `SeekSkipAdapter` and delegating to the adapter's implementation.
// `Self` must be `Unpin` (for `get_mut`), and `&mut Self` must satisfy the adapter's bounds.
macro_rules! async_skip_via_adapter {
    () => {
        fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
            let mut adapter = SeekSkipAdapter(self.get_mut());
            Pin::new(&mut adapter).poll_skip(cx, amount)
        }

        fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let mut adapter = SeekSkipAdapter(self.get_mut());
            Pin::new(&mut adapter).poll_stream_position(cx)
        }

        fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let mut adapter = SeekSkipAdapter(self.get_mut());
            Pin::new(&mut adapter).poll_stream_len(cx)
        }
    };
}
206
207impl<T: AsRef<[u8]> + Unpin> AsyncSkip for Cursor<T> {
208    async_skip_via_adapter!();
209}
210
211impl<R: AsyncRead + AsyncSkip> AsyncSkip for BufReader<R> {
212    /// Poll skipping `amount` bytes in a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
213    fn poll_skip(mut self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
214        let buf_len = self.buffer().len();
215        if let Some(skip_amount) = amount.checked_sub(buf_len as u64) {
216            if skip_amount != 0 {
217                ready!(self.as_mut().get_pin_mut().poll_skip(cx, skip_amount))?
218            }
219        }
220        self.consume(buf_len.min(amount as usize));
221        Poll::Ready(Ok(()))
222    }
223
224    /// Poll the stream position for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
225    fn poll_stream_position(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
226        let stream_pos = ready!(self.as_mut().get_pin_mut().poll_stream_position(cx))?;
227        Poll::Ready(Ok(stream_pos.saturating_sub(self.buffer().len() as u64)))
228    }
229
230    /// Poll the stream length for a [`BufReader`] implementing [`AsyncRead`] + [`AsyncSkip`].
231    fn poll_stream_len(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
232        self.as_mut().get_pin_mut().poll_stream_len(cx)
233    }
234}