1#[cfg(feature = "allocator_api")]
2use std::alloc::Allocator;
3use std::{io, io::ErrorKind};
4
5use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut, Uninit, t_alloc};
6
7use crate::{AsyncRead, AsyncReadAt, IoResult, util::Take};
8
macro_rules! read_scalar {
    // Internal rule: emit a single reader method called `$name` that reads
    // exactly `size_of::<$t>()` bytes and decodes them with `<$t>::$conv`.
    (@method $name:ident, $t:ty, $conv:ident, $endian:literal) => {
        #[doc = concat!("Read a ", $endian, " endian `", stringify!($t), "` from the underlying reader.")]
        async fn $name(&mut self) -> IoResult<$t> {
            use ::compio_buf::{arrayvec::ArrayVec, BufResult};

            const SIZE: usize = ::std::mem::size_of::<$t>();
            let BufResult(outcome, bytes) = self.read_exact(ArrayVec::<u8, SIZE>::new()).await;
            outcome?;
            // SAFETY: `read_exact` only returns `Ok` after filling the buffer to
            // its full capacity, so all `SIZE` slots of the `ArrayVec` are
            // initialized.
            let raw: [u8; SIZE] = unsafe { bytes.into_inner_unchecked() };
            Ok(<$t>::$conv(raw))
        }
    };
    // Public rule: generate `read_<ty>` (big endian, decoded with `$be`) and
    // `read_<ty>_le` (little endian, decoded with `$le`).
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            read_scalar!(@method [< read_ $t >], $t, $be, "big");
            read_scalar!(@method [< read_ $t _le >], $t, $le, "little");
        }
    };
}
37
macro_rules! loop_read_exact {
    // Repeatedly evaluates `$read_expr` (which must move `$buf` in and return a
    // `BufResult` whose buffer converts back into `$buf` via `into_inner`) until
    // `$tracker` reaches `$len`. Expands to statements that `return` from the
    // enclosing function.
    ($buf:ident, $len:expr, $tracker:ident,loop $read_expr:expr) => {
        let mut $tracker = 0;
        let target = $len;

        while $tracker < target {
            let BufResult(res, returned) = $read_expr.await.into_inner();
            // Take the buffer back before inspecting the outcome, so every exit
            // path below can hand it to the caller.
            $buf = returned;
            match res {
                Ok(0) => {
                    // The source ran dry before the requested length was reached.
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::UnexpectedEof,
                            "failed to fill whole buffer",
                        )),
                        $buf,
                    );
                }
                Ok(n) => $tracker += n,
                // Interrupted reads are transient: simply try again.
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
        return BufResult(Ok(()), $buf)
    };
}
68
macro_rules! loop_read_vectored {
    // Walks every buffer of a vectored buffer and runs `$read_expr` (which moves
    // `$iter` in and returns `BufResult<(), _>`) on each one with non-zero
    // capacity. `$tracker` accumulates the capacities that were filled.
    // Returns `Ok(())` with the whole vectored buffer once all buffers are done,
    // or the first error together with the vectored buffer.
    ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            // Nothing to iterate over: trivially complete.
            Err(empty) => return BufResult(Ok(()), empty),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let capacity = $iter.buf_capacity();
            if capacity != 0 {
                let BufResult(outcome, returned) = $read_expr.await;
                $iter = returned;
                if let Err(e) = outcome {
                    return BufResult(Err(e), $iter.into_inner());
                }
                $tracker += capacity as $tracker_ty;
            }

            $iter = match $iter.next() {
                Ok(next) => next,
                // Iterator exhausted: hand back the original vectored buffer.
                Err(whole) => return BufResult(Ok(()), whole),
            };
        }
    }};
    // Runs `$read_expr` (returning `BufResult<usize, _>`) once, on the first
    // buffer with non-zero capacity. Returns `Ok(0)` with the whole vectored
    // buffer if there is no such buffer.
    ($buf:ident, $iter:ident, $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(empty) => return BufResult(Ok(0), empty),
        };

        // Skip over zero-capacity buffers.
        while $iter.buf_capacity() == 0 {
            $iter = match $iter.next() {
                Ok(next) => next,
                Err(whole) => return BufResult(Ok(0), whole),
            };
        }

        return $read_expr.await.into_inner();
    }};
}
118
macro_rules! loop_read_to_end {
    // Evaluates `$read_expr` (which moves `$buf` in and returns
    // `BufResult<usize, _>` whose buffer converts back via `into_inner`) until it
    // reports 0 bytes. Grows `$buf` whenever it is full, retries `Interrupted`,
    // and yields `BufResult(Ok(total_bytes_read), $buf)`. Any other error is
    // returned from the enclosing function together with the buffer.
    ($buf:ident, $tracker:ident : $tracker_ty:ty,loop $read_expr:expr) => {{
        let mut $tracker: $tracker_ty = 0;
        loop {
            // Make sure the next read has spare capacity to write into.
            if $buf.capacity() == $buf.len() {
                $buf.reserve(32);
            }
            let BufResult(res, returned) = $read_expr.await.into_inner();
            $buf = returned;
            match res {
                // End of stream.
                Ok(0) => break,
                Ok(n) => $tracker += n as $tracker_ty,
                Err(e) if e.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(e) => return BufResult(Err(e), $buf),
            }
        }
        BufResult(Ok($tracker as usize), $buf)
    }};
}
144
145#[inline]
146fn after_read_to_string(res: io::Result<usize>, buf: Vec<u8>) -> BufResult<usize, String> {
147 match res {
148 Err(err) => {
149 let buf = String::from_utf8(buf).unwrap_or_else(|err| {
151 let mut buf = err.into_bytes();
152 buf.clear();
153
154 unsafe { String::from_utf8_unchecked(buf) }
156 });
157
158 BufResult(Err(err), buf)
159 }
160 Ok(n) => match String::from_utf8(buf) {
161 Err(err) => BufResult(
162 Err(std::io::Error::new(ErrorKind::InvalidData, err)),
163 String::new(),
164 ),
165 Ok(data) => BufResult(Ok(n), data),
166 },
167 }
168}
169
170pub trait AsyncReadExt: AsyncRead {
174 fn by_ref(&mut self) -> &mut Self
179 where
180 Self: Sized,
181 {
182 self
183 }
184
185 async fn append<T: IoBufMut>(&mut self, buf: T) -> BufResult<usize, T> {
189 self.read(buf.uninit()).await.map_buffer(Uninit::into_inner)
190 }
191
192 async fn read_exact<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<(), T> {
194 loop_read_exact!(buf, buf.buf_capacity(), read, loop self.read(buf.slice(read..)));
195 }
196
197 async fn read_to_string(&mut self, buf: String) -> BufResult<usize, String> {
199 let BufResult(res, buf) = self.read_to_end(buf.into_bytes()).await;
200 after_read_to_string(res, buf)
201 }
202
203 async fn read_to_end<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
205 &mut self,
206 mut buf: t_alloc!(Vec, u8, A),
207 ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
208 loop_read_to_end!(buf, total: usize, loop self.read(buf.slice(total..)))
209 }
210
211 async fn read_vectored_exact<T: IoVectoredBufMut>(&mut self, buf: T) -> BufResult<(), T> {
213 loop_read_vectored!(buf, _total: usize, iter, loop self.read_exact(iter))
214 }
215
216 fn take(self, limit: u64) -> Take<Self>
225 where
226 Self: Sized,
227 {
228 Take::new(self, limit)
229 }
230
231 read_scalar!(u8, from_be_bytes, from_le_bytes);
232 read_scalar!(u16, from_be_bytes, from_le_bytes);
233 read_scalar!(u32, from_be_bytes, from_le_bytes);
234 read_scalar!(u64, from_be_bytes, from_le_bytes);
235 read_scalar!(u128, from_be_bytes, from_le_bytes);
236 read_scalar!(i8, from_be_bytes, from_le_bytes);
237 read_scalar!(i16, from_be_bytes, from_le_bytes);
238 read_scalar!(i32, from_be_bytes, from_le_bytes);
239 read_scalar!(i64, from_be_bytes, from_le_bytes);
240 read_scalar!(i128, from_be_bytes, from_le_bytes);
241 read_scalar!(f32, from_be_bytes, from_le_bytes);
242 read_scalar!(f64, from_be_bytes, from_le_bytes);
243}
244
245impl<A: AsyncRead + ?Sized> AsyncReadExt for A {}
246
247pub trait AsyncReadAtExt: AsyncReadAt {
251 async fn read_exact_at<T: IoBufMut>(&self, mut buf: T, pos: u64) -> BufResult<(), T> {
272 loop_read_exact!(
273 buf,
274 buf.buf_capacity(),
275 read,
276 loop self.read_at(buf.slice(read..), pos + read as u64)
277 );
278 }
279
280 async fn read_to_string_at(&mut self, buf: String, pos: u64) -> BufResult<usize, String> {
283 let BufResult(res, buf) = self.read_to_end_at(buf.into_bytes(), pos).await;
284 after_read_to_string(res, buf)
285 }
286
287 async fn read_to_end_at<#[cfg(feature = "allocator_api")] A: Allocator + 'static>(
297 &self,
298 mut buffer: t_alloc!(Vec, u8, A),
299 pos: u64,
300 ) -> BufResult<usize, t_alloc!(Vec, u8, A)> {
301 loop_read_to_end!(buffer, total: u64, loop self.read_at(buffer.slice(total as usize..), pos + total))
302 }
303
304 async fn read_vectored_exact_at<T: IoVectoredBufMut>(
307 &self,
308 buf: T,
309 pos: u64,
310 ) -> BufResult<(), T> {
311 loop_read_vectored!(buf, total: u64, iter, loop self.read_exact_at(iter, pos + total))
312 }
313}
314
315impl<A: AsyncReadAt + ?Sized> AsyncReadAtExt for A {}