mediasan_common/
async_skip.rs1use std::future::Future;
4use std::io;
5use std::ops::DerefMut;
6use std::pin::Pin;
7use std::task::{ready, Context, Poll};
8
9use futures_util::io::{BufReader, Cursor};
10use futures_util::{AsyncBufRead, AsyncRead, AsyncSeek};
11
12use crate::{AsyncSkip, SeekSkipAdapter};
13
14pub trait AsyncSkipExt: AsyncSkip {
20 fn skip(&mut self, amount: u64) -> Skip<'_, Self> {
24 Skip { amount, inner: self }
25 }
26
27 fn stream_position(&mut self) -> StreamPosition<'_, Self> {
29 StreamPosition { inner: self }
30 }
31
32 fn stream_len(&mut self) -> StreamLen<'_, Self> {
34 StreamLen { inner: self }
35 }
36}
37
/// Future returned by [`AsyncSkipExt::skip`].
///
/// Each poll forwards the stored `amount` to the borrowed reader's `poll_skip`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Skip<'a, T: ?Sized> {
    /// Number of bytes to skip; passed unchanged to `poll_skip` on every poll.
    amount: u64,
    /// The reader being advanced, borrowed for the lifetime of the future.
    inner: &'a mut T,
}
43
/// Future returned by [`AsyncSkipExt::stream_position`].
///
/// Each poll forwards to the borrowed reader's `poll_stream_position`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct StreamPosition<'a, T: ?Sized> {
    /// The reader being queried, borrowed for the lifetime of the future.
    inner: &'a mut T,
}
48
/// Future returned by [`AsyncSkipExt::stream_len`].
///
/// Each poll forwards to the borrowed reader's `poll_stream_len`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct StreamLen<'a, T: ?Sized> {
    /// The reader being queried, borrowed for the lifetime of the future.
    inner: &'a mut T,
}
53
54impl<T: AsyncSkip + ?Sized> AsyncSkipExt for T {}
59
60impl<T: AsyncSkip + Unpin + ?Sized> Future for Skip<'_, T> {
65 type Output = io::Result<()>;
66
67 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68 let amount = self.amount;
69 Pin::new(&mut *self.inner).poll_skip(cx, amount)
70 }
71}
72
73impl<T: AsyncSkip + Unpin + ?Sized> Future for StreamPosition<'_, T> {
78 type Output = io::Result<u64>;
79
80 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81 Pin::new(&mut *self.inner).poll_stream_position(cx)
82 }
83}
84
85impl<T: AsyncSkip + Unpin + ?Sized> Future for StreamLen<'_, T> {
90 type Output = io::Result<u64>;
91
92 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
93 Pin::new(&mut *self.inner).poll_stream_len(cx)
94 }
95}
96
97impl<T: AsyncRead + Unpin + ?Sized> AsyncRead for SeekSkipAdapter<T> {
102 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
103 Pin::new(&mut self.0).poll_read(cx, buf)
104 }
105}
106
107impl<R: AsyncSeek + Unpin + ?Sized> AsyncSkip for SeekSkipAdapter<R> {
108 fn poll_skip(mut self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
109 match amount.try_into() {
110 Ok(0) => (),
111 Ok(amount) => {
112 let reader = Pin::new(&mut self.get_mut().0);
113 ready!(reader.poll_seek(cx, io::SeekFrom::Current(amount)))?;
114 }
115 Err(_) => {
116 let stream_pos = ready!(self.as_mut().poll_stream_position(cx))?;
117 let seek_pos = stream_pos
118 .checked_add(amount)
119 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "seek past u64::MAX"))?;
120 let reader = Pin::new(&mut self.get_mut().0);
121 ready!(reader.poll_seek(cx, io::SeekFrom::Start(seek_pos)))?;
122 }
123 }
124 Ok(()).into()
125 }
126
127 fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
128 let reader = Pin::new(&mut self.get_mut().0);
129 reader.poll_seek(cx, io::SeekFrom::Current(0))
130 }
131
132 fn poll_stream_len(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
133 let stream_pos = ready!(self.as_mut().poll_stream_position(cx))?;
135 let mut reader = Pin::new(&mut self.get_mut().0);
136 let len = ready!(reader.as_mut().poll_seek(cx, io::SeekFrom::End(0)))?;
137
138 if stream_pos != len {
139 ready!(reader.poll_seek(cx, io::SeekFrom::Start(stream_pos)))?;
140 }
141
142 Ok(len).into()
143 }
144}
145
// Expands to the three `AsyncSkip` methods for a pointer-like `Self` (used below for `&mut R` and
// `Box<R>`). Each method unwraps the pin, dereferences to the pointee, and forwards the call.
macro_rules! deref_async_skip {
    () => {
        fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
            let target = &mut **self.get_mut();
            Pin::new(target).poll_skip(cx, amount)
        }

        fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let target = &mut **self.get_mut();
            Pin::new(target).poll_stream_position(cx)
        }

        fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let target = &mut **self.get_mut();
            Pin::new(target).poll_stream_len(cx)
        }
    };
}
165
166impl<R: AsyncSkip + Unpin + ?Sized> AsyncSkip for &mut R {
167 deref_async_skip!();
168}
169
170impl<R: AsyncSkip + Unpin + ?Sized> AsyncSkip for Box<R> {
171 deref_async_skip!();
172}
173
174impl<P: DerefMut + Unpin> AsyncSkip for Pin<P>
175where
176 P::Target: AsyncSkip,
177{
178 fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
179 self.get_mut().as_mut().poll_skip(cx, amount)
180 }
181
182 fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
183 self.get_mut().as_mut().poll_stream_position(cx)
184 }
185
186 fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
187 self.get_mut().as_mut().poll_stream_len(cx)
188 }
189}
190
// Expands to the three `AsyncSkip` methods for a seekable `Self`. Each call wraps `&mut Self` in a
// temporary `SeekSkipAdapter` and delegates to the adapter's implementation.
macro_rules! async_skip_via_adapter {
    () => {
        fn poll_skip(self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
            let mut adapter = SeekSkipAdapter(self.get_mut());
            Pin::new(&mut adapter).poll_skip(cx, amount)
        }

        fn poll_stream_position(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let mut adapter = SeekSkipAdapter(self.get_mut());
            Pin::new(&mut adapter).poll_stream_position(cx)
        }

        fn poll_stream_len(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let mut adapter = SeekSkipAdapter(self.get_mut());
            Pin::new(&mut adapter).poll_stream_len(cx)
        }
    };
}
206
207impl<T: AsRef<[u8]> + Unpin> AsyncSkip for Cursor<T> {
208 async_skip_via_adapter!();
209}
210
211impl<R: AsyncRead + AsyncSkip> AsyncSkip for BufReader<R> {
212 fn poll_skip(mut self: Pin<&mut Self>, cx: &mut Context<'_>, amount: u64) -> Poll<io::Result<()>> {
214 let buf_len = self.buffer().len();
215 if let Some(skip_amount) = amount.checked_sub(buf_len as u64) {
216 if skip_amount != 0 {
217 ready!(self.as_mut().get_pin_mut().poll_skip(cx, skip_amount))?
218 }
219 }
220 self.consume(buf_len.min(amount as usize));
221 Poll::Ready(Ok(()))
222 }
223
224 fn poll_stream_position(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
226 let stream_pos = ready!(self.as_mut().get_pin_mut().poll_stream_position(cx))?;
227 Poll::Ready(Ok(stream_pos.saturating_sub(self.buffer().len() as u64)))
228 }
229
230 fn poll_stream_len(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
232 self.as_mut().get_pin_mut().poll_stream_len(cx)
233 }
234}