// compio_io/read/ext.rs

1#[cfg(feature = "allocator_api")]
2use std::alloc::Allocator;
3use std::{io, io::ErrorKind};
4
5use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, Uninit, t_alloc};
6
7use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take};
8
/// Shared code for reading a scalar value from the underlying reader.
///
/// Invoked inside a trait body that provides `read_exact`, this expands to two
/// methods for the type `$t`:
/// - `read_<$t>` decodes the bytes with `$t::$be` (big endian);
/// - `read_<$t>_le` decodes the bytes with `$t::$le` (little endian).
///
/// Each method fills a fixed-size `ArrayVec` of exactly `size_of::<$t>()`
/// bytes through `read_exact` and propagates its error unchanged.
macro_rules! read_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Read a big endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let BufResult(outcome, bytes) =
                    self.read_exact(ArrayVec::<u8, WIDTH>::new()).await;
                outcome.map(|()| {
                    // Safety: `read_exact` only reports success once all `WIDTH`
                    // slots have been filled, so the array is fully initialized.
                    let raw: [u8; WIDTH] = unsafe { bytes.into_inner_unchecked() };
                    <$t>::$be(raw)
                })
            }

            #[doc = concat!("Read a little endian `", stringify!($t), "` from the underlying reader.")]
            async fn [< read_ $t _le >](&mut self) -> IoResult<$t> {
                use ::compio_buf::{arrayvec::ArrayVec, BufResult};

                const WIDTH: usize = ::std::mem::size_of::<$t>();
                let BufResult(outcome, bytes) =
                    self.read_exact(ArrayVec::<u8, WIDTH>::new()).await;
                outcome.map(|()| {
                    // Safety: `read_exact` only reports success once all `WIDTH`
                    // slots have been filled, so the array is fully initialized.
                    let raw: [u8; WIDTH] = unsafe { bytes.into_inner_unchecked() };
                    <$t>::$le(raw)
                })
            }
        }
    };
}
37
/// Shared code for looping reads until a certain number of bytes is reached.
///
/// `$tracker` names the running count of bytes read so far; `$read_expr` may
/// use it to pick the next sub-slice (and position). The expansion always
/// `return`s from the enclosing function:
/// - `Ok(())` once at least `$len` bytes have been read;
/// - `UnexpectedEof` if a read yields `Ok(0)` before that;
/// - the first error that is not `Interrupted` (interrupted reads are retried).
macro_rules! loop_read_exact {
    ($buf:ident, $len:expr, $tracker:ident,loop $read_expr:expr) => {
        let mut $tracker = 0;
        let target = $len;

        while $tracker < target {
            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            $buf = returned;
            match outcome {
                Ok(0) => {
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::UnexpectedEof,
                            "failed to fill whole buffer",
                        )),
                        $buf,
                    );
                }
                Ok(n) => $tracker += n,
                Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
        return BufResult(Ok(()), $buf)
    };
}
68
/// Shared code for reading into vectored buffers, one chunk at a time.
///
/// The first form (`loop`) fills every chunk of `$buf` in order by evaluating
/// `$read_expr` against the current chunk iterator `$iter`. `$tracker` (of
/// type `$tracker_ty`) counts the capacity of the chunks completed before the
/// current one, so `$read_expr` can derive a position from it. Chunks with
/// zero capacity are skipped. It returns `Ok(())` once all chunks are done, or
/// the first error together with the reassembled buffer.
///
/// The second form issues a single `$read_expr` into the first chunk with
/// non-zero capacity and returns its result, or `Ok(0)` if there is none.
macro_rules! loop_read_vectored {
    ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(empty) => return BufResult(Ok(()), empty),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let chunk_len = $iter.buf_capacity();
            if chunk_len != 0 {
                let BufResult(outcome, returned) = $read_expr.await;
                if let Err(e) = outcome {
                    return BufResult(Err(e), returned.into_inner());
                }
                $iter = returned;
                // Only advance the tracker after the chunk is fully read.
                $tracker += chunk_len as $tracker_ty;
            }

            $iter = match $iter.next() {
                Ok(next) => next,
                Err(whole) => return BufResult(Ok(()), whole),
            };
        }
    }};
    ($buf:ident, $iter:ident, $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(empty) => return BufResult(Ok(0), empty),
        };

        // Skip chunks that cannot hold any data.
        while $iter.buf_capacity() == 0 {
            $iter = match $iter.next() {
                Ok(next) => next,
                Err(whole) => return BufResult(Ok(0), whole),
            };
        }

        return $read_expr.await.into_inner()
    }};
}
118
/// Shared code for reading into a `Vec` until the source reports EOF.
///
/// Before each read, if the vector is full, it reserves at least 32 more bytes.
/// `$tracker` (of type `$tracker_ty`) accumulates the bytes read so that
/// `$read_expr` can use it for the next sub-slice (and position).
/// `Interrupted` errors are retried. Any other error returns early with the
/// buffer. On EOF (`Ok(0)`), the block evaluates to
/// `BufResult(Ok(total), buf)`.
macro_rules! loop_read_to_end {
    ($buf:ident, $tracker:ident : $tracker_ty:ty,loop $read_expr:expr) => {{
        let mut $tracker: $tracker_ty = 0;
        loop {
            if $buf.capacity() == $buf.len() {
                $buf.reserve(32);
            }
            let BufResult(outcome, returned) = $read_expr.await.into_inner();
            $buf = returned;
            match outcome {
                Ok(0) => break,
                Ok(read) => $tracker += read as $tracker_ty,
                Err(ref e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
        BufResult(Ok($tracker as usize), $buf)
    }};
}
144
145#[inline]
146fn after_read_to_string(res: io::Result<usize>, buf: Vec<u8>) -> BufResult<usize, String> {
147    match res {
148        Err(err) => {
149            // we have to clear the read bytes if it is not valid utf8 bytes
150            let buf = String::from_utf8(buf).unwrap_or_else(|err| {
151                let mut buf = err.into_bytes();
152                buf.clear();
153
154                // Safety: the buffer is empty
155                unsafe { String::from_utf8_unchecked(buf) }
156            });
157
158            BufResult(Err(err), buf)
159        }
160        Ok(n) => match String::from_utf8(buf) {
161            Err(err) => BufResult(
162                Err(std::io::Error::new(ErrorKind::InvalidData, err)),
163                String::new(),
164            ),
165            Ok(data) => BufResult(Ok(n), data),
166        },
167    }
168}
169
170/// Implemented as an extension trait, adding utility methods to all
171/// [`AsyncRead`] types. Callers will tend to import this trait instead of
172/// [`AsyncRead`].
173pub trait AsyncReadExt: AsyncRead {
174    /// Creates a "by reference" adaptor for this instance of [`AsyncRead`].
175    ///
176    /// The returned adapter also implements [`AsyncRead`] and will simply
177    /// borrow this current reader.
178    fn by_ref(&mut self) -> &mut Self
179    where
180        Self: Sized,
181    {
182        self
183    }
184
185    /// Same as [`AsyncRead::read`], but it appends data to the end of the
186    /// buffer; in other words, it read to the beginning of the uninitialized
187    /// area.
188    async fn append<T: IoBufMut>(&mut self, buf: T) -> BufResult<usize, T> {
189        self.read(buf.uninit()).await.map_buffer(Uninit::into_inner)
190    }
191
192    /// Read the exact number of bytes required to fill the buf.
193    async fn read_exact<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
194        loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..)));
195    }
196
197    /// Read all bytes as [`String`] until underlying reader reaches `EOF`.
198    async fn read_to_string(&mut self, buf: String) -> BufResult<usize, String> {
199        let BufResult(res, buf) = self.read_to_end(buf.into_bytes()).await;
200        after_read_to_string(res, buf)
201    }
202
203    /// Read all bytes until underlying reader reaches `EOF`.
204    async fn read_to_end<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
205        &mut self,
206        mut buf: t_alloc!(Vec, u8, A),
207    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
208        loop_read_to_end!(buf, total: usize, loop self.read(buf.slice(total..)))
209    }
210
211    /// Read the exact number of bytes required to fill the vectored buf.
212    async fn read_vectored_exact<T: IoVectoredBufMut>(&mut self, buf: T) -> BufResult<(), T> {
213        loop_read_vectored!(buf, _total: usize, iter, loop self.read_exact(iter))
214    }
215
216    /// Creates an adaptor which reads at most `limit` bytes from it.
217    ///
218    /// This function returns a new instance of `AsyncRead` which will read
219    /// at most `limit` bytes, after which it will always return EOF
220    /// (`Ok(0)`). Any read errors will not count towards the number of
221    /// bytes read and future calls to [`read()`] may succeed.
222    ///
223    /// [`read()`]: AsyncRead::read
224    fn take(self, limit: u64) -> Take<Self>
225    where
226        Self: Sized,
227    {
228        Take::new(self, limit)
229    }
230
231    read_scalar!(u8, from_be_bytes, from_le_bytes);
232    read_scalar!(u16, from_be_bytes, from_le_bytes);
233    read_scalar!(u32, from_be_bytes, from_le_bytes);
234    read_scalar!(u64, from_be_bytes, from_le_bytes);
235    read_scalar!(u128, from_be_bytes, from_le_bytes);
236    read_scalar!(i8, from_be_bytes, from_le_bytes);
237    read_scalar!(i16, from_be_bytes, from_le_bytes);
238    read_scalar!(i32, from_be_bytes, from_le_bytes);
239    read_scalar!(i64, from_be_bytes, from_le_bytes);
240    read_scalar!(i128, from_be_bytes, from_le_bytes);
241    read_scalar!(f32, from_be_bytes, from_le_bytes);
242    read_scalar!(f64, from_be_bytes, from_le_bytes);
243}
244
245impl<A: AsyncRead + ?Sized> AsyncReadExt for A {}
246
247/// Implemented as an extension trait, adding utility methods to all
248/// [`AsyncReadAt`] types. Callers will tend to import this trait instead of
249/// [`AsyncReadAt`].
250pub trait AsyncReadAtExt: AsyncReadAt {
251    /// Read the exact number of bytes required to fill `buffer`.
252    ///
253    /// This function reads as many bytes as necessary to completely fill the
254    /// uninitialized space of specified `buffer`.
255    ///
256    /// # Errors
257    ///
258    /// If this function encounters an "end of file" before completely filling
259    /// the buffer, it returns an error of the kind
260    /// [`ErrorKind::UnexpectedEof`]. The contents of `buffer` are unspecified
261    /// in this case.
262    ///
263    /// If any other read error is encountered then this function immediately
264    /// returns. The contents of `buffer` are unspecified in this case.
265    ///
266    /// If this function returns an error, it is unspecified how many bytes it
267    /// has read, but it will never read more than would be necessary to
268    /// completely fill the buffer.
269    ///
270    /// [`ErrorKind::UnexpectedEof`]: std::io::ErrorKind::UnexpectedEof
271    async fn read_exact_at<T: IoBufMut>(&self, mut buf: T, pos: u64) -> BufResult<(), T> {
272        loop_read_exact!(
273            buf,
274            buf.buf_capacity(),
275            read,
276            loop self.read_at(buf.slice(read..), pos + read as u64)
277        );
278    }
279
280    /// Read all bytes as [`String`] until EOF in this source, placing them into
281    /// `buffer`.
282    async fn read_to_string_at(&mut self, buf: String, pos: u64) -> BufResult<usize, String> {
283        let BufResult(res, buf) = self.read_to_end_at(buf.into_bytes(), pos).await;
284        after_read_to_string(res, buf)
285    }
286
287    /// Read all bytes until EOF in this source, placing them into `buffer`.
288    ///
289    /// All bytes read from this source will be appended to the specified buffer
290    /// `buffer`. This function will continuously call [`read_at()`] to append
291    /// more data to `buffer` until [`read_at()`] returns [`Ok(0)`].
292    ///
293    /// If successful, this function will return the total number of bytes read.
294    ///
295    /// [`read_at()`]: AsyncReadAt::read_at
296    async fn read_to_end_at<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
297        &self,
298        mut buffer: t_alloc!(Vec, u8, A),
299        pos: u64,
300    ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
301        loop_read_to_end!(buffer, total: u64, loop self.read_at(buffer.slice(total as usize..), pos + total))
302    }
303
304    /// Like [`AsyncReadExt::read_vectored_exact`], expect that it reads at a
305    /// specified position.
306    async fn read_vectored_exact_at<T: IoVectoredBufMut>(
307        &self,
308        buf: T,
309        pos: u64,
310    ) -> BufResult<(), T> {
311        loop_read_vectored!(buf, total: u64, iter, loop self.read_exact_at(iter, pos + total))
312    }
313}
314
315impl<A: AsyncReadAt + ?Sized> AsyncReadAtExt for A {}