async_read_util/
lib.rs

1use std::{
2    alloc::{alloc, Layout},
3    io,
4    pin::Pin,
5    ptr,
6    ptr::slice_from_raw_parts_mut,
7    task::Poll,
8};
9
10use futures::{io::BufReader, AsyncRead};
11use pin_project::pin_project;
12
13/// Convenience trait to apply the utility functions to types implementing
14/// [`futures::AsyncRead`].
15pub trait AsyncReadUtil: AsyncRead + Sized {
16    /// Observe the bytes being read from `self` using the provided closure.
17    /// Refer to [`crate::ObservedReader`] for more info.
18    fn observe<F: FnMut(&[u8])>(self, f: F) -> ObservedReader<Self, F>;
19    /// Map the bytes being read from `self` into a new buffer using the
20    /// provided closure. Refer to [`crate::MappedReader`] for more
21    /// info.
22    fn map_read<F>(self, f: F) -> MappedReader<Self, F>
23    where
24        F: FnMut(&[u8], &mut [u8]) -> (usize, usize);
25    fn buffered(self) -> BufReader<Self>;
26}
27
28impl<R: AsyncRead + Sized> AsyncReadUtil for R {
29    fn observe<F: FnMut(&[u8])>(self, f: F) -> ObservedReader<Self, F> {
30        ObservedReader::new(self, f)
31    }
32
33    fn map_read<F>(self, f: F) -> MappedReader<Self, F>
34    where
35        F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
36    {
37        MappedReader::new(self, f)
38    }
39
40    fn buffered(self) -> BufReader<Self> {
41        BufReader::new(self)
42    }
43}
44
45/// An async reader which allows a closure to observe the bytes being read as
46/// they are ready. This has use cases such as hashing the output of a reader
47/// without interfering with the actual content.
48#[pin_project]
49pub struct ObservedReader<R, F> {
50    #[pin]
51    inner: R,
52    f: F,
53}
54
55impl<R, F> ObservedReader<R, F>
56where
57    R: AsyncRead,
58    F: FnMut(&[u8]),
59{
60    pub fn new(inner: R, f: F) -> Self {
61        Self { inner, f }
62    }
63}
64
65impl<R, F> AsyncRead for ObservedReader<R, F>
66where
67    R: AsyncRead,
68    F: FnMut(&[u8]),
69{
70    fn poll_read(
71        mut self: Pin<&mut Self>,
72        cx: &mut std::task::Context<'_>,
73        buf: &mut [u8],
74    ) -> Poll<io::Result<usize>> {
75        let this = self.as_mut().project();
76        let num_read = futures::ready!(this.inner.poll_read(cx, buf))?;
77        (this.f)(&buf[0..num_read]);
78        Poll::Ready(Ok(num_read))
79    }
80}
81
82/// An async reader which allows a closure to map the output of the inner async
83/// reader into a new buffer. This allows things like compression/encryption to
84/// be layered on top of a normal reader.
85///
86/// NOTE: The closure must consume at least 1 byte for the reader to continue.
87///
88/// SAFETY: This currently creates the equivalent of `Box<[MaybeUninit<u8>]>`,
89/// but does so through use of accessing the allocator directly. Once new_uninit
90/// is available on stable, `Box::new_uninit_slice` will be used. This will
91/// still utilize unsafe. A uninitialized buffer is acceptable because it's
92/// contents are only ever written to before reading only the written section.
93///
94/// ref: https://github.com/rust-lang/rust/issues/63291
95#[pin_project]
96pub struct MappedReader<R, F> {
97    #[pin]
98    inner: R,
99    f: F,
100    buf: Box<[u8]>,
101    pos: usize,
102    cap: usize,
103    done: bool,
104}
105
106impl<R, F> MappedReader<R, F>
107where
108    R: AsyncRead,
109    F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
110{
111    pub fn new(inner: R, f: F) -> Self {
112        Self::with_capacity(8096, inner, f)
113    }
114
115    pub fn with_capacity(capacity: usize, inner: R, f: F) -> Self {
116        let buf = unsafe { uninit_buf(capacity) };
117        Self {
118            inner,
119            f,
120            buf,
121            pos: 0,
122            cap: 0,
123            done: false,
124        }
125    }
126}
127
128impl<R, F> AsyncRead for MappedReader<R, F>
129where
130    R: AsyncRead,
131    F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
132{
133    fn poll_read(
134        mut self: Pin<&mut Self>,
135        cx: &mut std::task::Context<'_>,
136        buf: &mut [u8],
137    ) -> Poll<io::Result<usize>> {
138        let this = self.as_mut().project();
139        if *this.pos == *this.cap {
140            *this.pos = 0;
141            *this.cap = 0;
142        }
143        if !*this.done && *this.cap < this.buf.len() {
144            let nread = futures::ready!(this.inner.poll_read(cx, &mut this.buf[*this.cap..]))?;
145            *this.cap += nread;
146            if nread == 0 {
147                *this.done = true;
148            }
149        }
150        let unprocessed = &this.buf[*this.pos..*this.cap];
151        let (nsrc, ndst) = (this.f)(&this.buf[*this.pos..*this.cap], buf);
152        assert!(
153            ndst <= buf.len(),
154            "mapped reader is reportedly reading more than the destination buffer's capacity"
155        );
156        // Nothing has been consumed and there are unprocessed bytes.
157        if nsrc == 0 && !unprocessed.is_empty() {
158            assert!(unprocessed.len() < this.buf.len());
159            let count = unprocessed.len();
160            // SAFETY: This utilizes `ptr::copy` which per the documentation is
161            // safe to use for overlapping areas. The only invariants we have to
162            // keep track of are:
163            // - `src` is valid data
164            // - `dst` is owned and capable of containing `count` bytes
165            // `src` points to be beginning of the `unprocessed` data and is valid.
166            // `dst` is merely a `*mut` to start of `self.buf`, so it is owned.
167            // `count` must be less than `self.buf.len()` because `unprocessed`
168            // is fully contained within `self.buf`.
169            unsafe {
170                ptr::copy(unprocessed.as_ptr(), this.buf.as_mut().as_mut_ptr(), count);
171            }
172        }
173        *this.pos += nsrc;
174        Poll::Ready(Ok(ndst))
175    }
176}
177
/// Allocates the `size`-byte scratch buffer used by `MappedReader`.
///
/// Despite its name the buffer is zero-initialised and allocated through
/// `Vec`. Going to the raw allocator was unsound: a zero-sized request is
/// undefined behaviour, an allocation failure (null pointer) went unchecked,
/// and uninitialised memory must not be exposed as `[u8]`.
///
/// # Safety
///
/// This function performs no unsafe operations and has no preconditions. It
/// stays an `unsafe fn` only so that existing `unsafe { uninit_buf(..) }`
/// call sites keep compiling unchanged.
unsafe fn uninit_buf(size: usize) -> Box<[u8]> {
    vec![0u8; size].into_boxed_slice()
}