1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use std::{
    alloc::{alloc, Layout},
    io,
    pin::Pin,
    ptr,
    ptr::slice_from_raw_parts_mut,
    task::Poll,
};

use futures::{io::BufReader, AsyncRead};
use pin_project::pin_project;

/// Convenience trait to apply the utility functions to types implementing
/// [`futures::AsyncRead`].
pub trait AsyncReadUtil: AsyncRead + Sized {
    /// Observe the bytes being read from `self` using the provided closure.
    /// Refer to [`crate::ObservedReader`] for more info.
    fn observe<F: FnMut(&[u8])>(self, f: F) -> ObservedReader<Self, F>;
    /// Map the bytes being read from `self` into a new buffer using the
    /// provided closure. Refer to [`crate::MappedReader`] for more
    /// info.
    fn map_read<F>(self, f: F) -> MappedReader<Self, F>
    where
        F: FnMut(&[u8], &mut [u8]) -> (usize, usize);
    fn buffered(self) -> BufReader<Self>;
}

impl<R: AsyncRead + Sized> AsyncReadUtil for R {
    fn observe<F: FnMut(&[u8])>(self, f: F) -> ObservedReader<Self, F> {
        ObservedReader::new(self, f)
    }

    fn map_read<F>(self, f: F) -> MappedReader<Self, F>
    where
        F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
    {
        MappedReader::new(self, f)
    }

    fn buffered(self) -> BufReader<Self> {
        BufReader::new(self)
    }
}

/// An async reader which allows a closure to observe the bytes being read as
/// they are ready. This has use cases such as hashing the output of a reader
/// without interfering with the actual content.
#[pin_project]
pub struct ObservedReader<R, F> {
    #[pin]
    inner: R,
    f: F,
}

impl<R, F> ObservedReader<R, F>
where
    R: AsyncRead,
    F: FnMut(&[u8]),
{
    pub fn new(inner: R, f: F) -> Self {
        Self { inner, f }
    }
}

impl<R, F> AsyncRead for ObservedReader<R, F>
where
    R: AsyncRead,
    F: FnMut(&[u8]),
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.as_mut().project();
        let num_read = futures::ready!(this.inner.poll_read(cx, buf))?;
        (this.f)(&buf[0..num_read]);
        Poll::Ready(Ok(num_read))
    }
}

/// An async reader which allows a closure to map the output of the inner async
/// reader into a new buffer. This allows things like compression/encryption to
/// be layered on top of a normal reader.
///
/// NOTE: The closure must consume at least 1 byte for the reader to continue.
///
/// SAFETY: This currently creates the equivalent of `Box<[MaybeUninit<u8>]>`,
/// but does so through use of accessing the allocator directly. Once new_uninit
/// is available on stable, `Box::new_uninit_slice` will be used. This will
/// still utilize unsafe. A uninitialized buffer is acceptable because it's
/// contents are only ever written to before reading only the written section.
///
/// ref: https://github.com/rust-lang/rust/issues/63291
#[pin_project]
pub struct MappedReader<R, F> {
    #[pin]
    inner: R,
    f: F,
    buf: Box<[u8]>,
    pos: usize,
    cap: usize,
    done: bool,
}

impl<R, F> MappedReader<R, F>
where
    R: AsyncRead,
    F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
{
    pub fn new(inner: R, f: F) -> Self {
        Self::with_capacity(8096, inner, f)
    }

    pub fn with_capacity(capacity: usize, inner: R, f: F) -> Self {
        let buf = unsafe { uninit_buf(capacity) };
        Self {
            inner,
            f,
            buf,
            pos: 0,
            cap: 0,
            done: false,
        }
    }
}

impl<R, F> AsyncRead for MappedReader<R, F>
where
    R: AsyncRead,
    F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.as_mut().project();
        if *this.pos == *this.cap {
            *this.pos = 0;
            *this.cap = 0;
        }
        if !*this.done && *this.cap < this.buf.len() {
            let nread = futures::ready!(this.inner.poll_read(cx, &mut this.buf[*this.cap..]))?;
            *this.cap += nread;
            if nread == 0 {
                *this.done = true;
            }
        }
        let unprocessed = &this.buf[*this.pos..*this.cap];
        let (nsrc, ndst) = (this.f)(&this.buf[*this.pos..*this.cap], buf);
        assert!(
            ndst <= buf.len(),
            "mapped reader is reportedly reading more than the destination buffer's capacity"
        );
        // Nothing has been consumed and there are unprocessed bytes.
        if nsrc == 0 && !unprocessed.is_empty() {
            assert!(unprocessed.len() < this.buf.len());
            let count = unprocessed.len();
            // SAFETY: This utilizes `ptr::copy` which per the documentation is
            // safe to use for overlapping areas. The only invariants we have to
            // keep track of are:
            // - `src` is valid data
            // - `dst` is owned and capable of containing `count` bytes
            // `src` points to be beginning of the `unprocessed` data and is valid.
            // `dst` is merely a `*mut` to start of `self.buf`, so it is owned.
            // `count` must be less than `self.buf.len()` because `unprocessed`
            // is fully contained within `self.buf`.
            unsafe {
                ptr::copy(unprocessed.as_ptr(), this.buf.as_mut().as_mut_ptr(), count);
            }
        }
        *this.pos += nsrc;
        Poll::Ready(Ok(ndst))
    }
}

unsafe fn uninit_buf(size: usize) -> Box<[u8]> {
    let layout = Layout::array::<u8>(size).unwrap();
    let ptr = slice_from_raw_parts_mut(alloc(layout), size);
    Box::from_raw(ptr)
}