1use std::{
2 alloc::{alloc, Layout},
3 io,
4 pin::Pin,
5 ptr,
6 ptr::slice_from_raw_parts_mut,
7 task::Poll,
8};
9
10use futures::{io::BufReader, AsyncRead};
11use pin_project::pin_project;
12
13pub trait AsyncReadUtil: AsyncRead + Sized {
16 fn observe<F: FnMut(&[u8])>(self, f: F) -> ObservedReader<Self, F>;
19 fn map_read<F>(self, f: F) -> MappedReader<Self, F>
23 where
24 F: FnMut(&[u8], &mut [u8]) -> (usize, usize);
25 fn buffered(self) -> BufReader<Self>;
26}
27
28impl<R: AsyncRead + Sized> AsyncReadUtil for R {
29 fn observe<F: FnMut(&[u8])>(self, f: F) -> ObservedReader<Self, F> {
30 ObservedReader::new(self, f)
31 }
32
33 fn map_read<F>(self, f: F) -> MappedReader<Self, F>
34 where
35 F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
36 {
37 MappedReader::new(self, f)
38 }
39
40 fn buffered(self) -> BufReader<Self> {
41 BufReader::new(self)
42 }
43}
44
/// Reader adapter that passes each chunk returned by `poll_read` on `inner` to the
/// callback `f` before handing it to the caller.
#[pin_project]
pub struct ObservedReader<R, F> {
    /// The wrapped reader; `#[pin]` makes it structurally pinned for projection.
    #[pin]
    inner: R,
    /// Callback invoked with each chunk of bytes successfully read from `inner`.
    f: F,
}
54
55impl<R, F> ObservedReader<R, F>
56where
57 R: AsyncRead,
58 F: FnMut(&[u8]),
59{
60 pub fn new(inner: R, f: F) -> Self {
61 Self { inner, f }
62 }
63}
64
65impl<R, F> AsyncRead for ObservedReader<R, F>
66where
67 R: AsyncRead,
68 F: FnMut(&[u8]),
69{
70 fn poll_read(
71 mut self: Pin<&mut Self>,
72 cx: &mut std::task::Context<'_>,
73 buf: &mut [u8],
74 ) -> Poll<io::Result<usize>> {
75 let this = self.as_mut().project();
76 let num_read = futures::ready!(this.inner.poll_read(cx, buf))?;
77 (this.f)(&buf[0..num_read]);
78 Poll::Ready(Ok(num_read))
79 }
80}
81
/// Reader adapter that transforms the bytes read from `inner` with the mapping function `f`.
///
/// Input read from `inner` is staged in `buf`; the bytes `buf[pos..cap]` have been read
/// but not yet consumed by `f`.
#[pin_project]
pub struct MappedReader<R, F> {
    /// The wrapped reader; `#[pin]` makes it structurally pinned for projection.
    #[pin]
    inner: R,
    /// Mapping function: `(staged input, output) -> (input consumed, output written)`.
    f: F,
    /// Staging buffer for input read from `inner`.
    buf: Box<[u8]>,
    /// Start of the not-yet-consumed input within `buf`.
    pos: usize,
    /// End of the valid (read) input within `buf`.
    cap: usize,
    /// Set once `inner` has reported end of file with a zero-length read.
    done: bool,
}
105
106impl<R, F> MappedReader<R, F>
107where
108 R: AsyncRead,
109 F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
110{
111 pub fn new(inner: R, f: F) -> Self {
112 Self::with_capacity(8096, inner, f)
113 }
114
115 pub fn with_capacity(capacity: usize, inner: R, f: F) -> Self {
116 let buf = unsafe { uninit_buf(capacity) };
117 Self {
118 inner,
119 f,
120 buf,
121 pos: 0,
122 cap: 0,
123 done: false,
124 }
125 }
126}
127
128impl<R, F> AsyncRead for MappedReader<R, F>
129where
130 R: AsyncRead,
131 F: FnMut(&[u8], &mut [u8]) -> (usize, usize),
132{
133 fn poll_read(
134 mut self: Pin<&mut Self>,
135 cx: &mut std::task::Context<'_>,
136 buf: &mut [u8],
137 ) -> Poll<io::Result<usize>> {
138 let this = self.as_mut().project();
139 if *this.pos == *this.cap {
140 *this.pos = 0;
141 *this.cap = 0;
142 }
143 if !*this.done && *this.cap < this.buf.len() {
144 let nread = futures::ready!(this.inner.poll_read(cx, &mut this.buf[*this.cap..]))?;
145 *this.cap += nread;
146 if nread == 0 {
147 *this.done = true;
148 }
149 }
150 let unprocessed = &this.buf[*this.pos..*this.cap];
151 let (nsrc, ndst) = (this.f)(&this.buf[*this.pos..*this.cap], buf);
152 assert!(
153 ndst <= buf.len(),
154 "mapped reader is reportedly reading more than the destination buffer's capacity"
155 );
156 if nsrc == 0 && !unprocessed.is_empty() {
158 assert!(unprocessed.len() < this.buf.len());
159 let count = unprocessed.len();
160 unsafe {
170 ptr::copy(unprocessed.as_ptr(), this.buf.as_mut().as_mut_ptr(), count);
171 }
172 }
173 *this.pos += nsrc;
174 Poll::Ready(Ok(ndst))
175 }
176}
177
/// Allocates a heap buffer of `size` bytes for a reader's internal storage.
///
/// Despite the name, the buffer is zero-initialized. Handing uninitialized bytes out as
/// `&mut [u8]`/`&[u8]` is undefined behavior, and so is calling the raw allocator with a
/// zero-sized layout; both were possible in the previous raw-`alloc` version. Zeroing a
/// byte buffer with `vec!` is typically cheap.
///
/// A `size` of zero yields an empty slice without allocating. Allocation failure goes
/// through the global allocation error handler instead of producing a `Box` from a null pointer.
///
/// # Safety
///
/// This function has no preconditions. It stays `unsafe` only so that existing call sites,
/// which wrap it in an `unsafe` block, keep compiling unchanged.
unsafe fn uninit_buf(size: usize) -> Box<[u8]> {
    vec![0u8; size].into_boxed_slice()
}