1use crate::io::{AsyncRead, AsyncReadVectored, Chain, ReadBuf, Take};
4use std::future::Future;
5use std::io::{self, ErrorKind, IoSliceMut};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pub trait AsyncReadExt: AsyncRead {
11 fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadExact<'a, Self>
13 where
14 Self: Unpin,
15 {
16 ReadExact {
17 reader: self,
18 buf,
19 pos: 0,
20 }
21 }
22
23 fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self>
25 where
26 Self: Unpin,
27 {
28 let start_len = buf.len();
29 ReadToEnd {
30 reader: self,
31 buf,
32 start_len,
33 }
34 }
35
36 fn read_to_string<'a>(&'a mut self, buf: &'a mut String) -> ReadToString<'a, Self>
38 where
39 Self: Unpin,
40 {
41 let start_len = buf.len();
42 ReadToString {
43 reader: self,
44 buf,
45 pending_utf8: Vec::new(),
46 read: 0,
47 start_len,
48 }
49 }
50
51 fn read_u8(&mut self) -> ReadU8<'_, Self>
53 where
54 Self: Unpin,
55 {
56 ReadU8 { reader: self }
57 }
58
59 fn chain<R: AsyncRead>(self, next: R) -> Chain<Self, R>
61 where
62 Self: Sized,
63 {
64 Chain::new(self, next)
65 }
66
67 fn take(self, limit: u64) -> Take<Self>
69 where
70 Self: Sized,
71 {
72 Take::new(self, limit)
73 }
74}
75
76impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
77
78pub trait AsyncReadVectoredExt: AsyncReadVectored {
80 fn read_vectored<'a>(&'a mut self, bufs: &'a mut [IoSliceMut<'a>]) -> ReadVectored<'a, Self>
82 where
83 Self: Unpin,
84 {
85 ReadVectored { reader: self, bufs }
86 }
87}
88
89impl<R: AsyncReadVectored + ?Sized> AsyncReadVectoredExt for R {}
90
/// Future returned by `AsyncReadVectoredExt::read_vectored`.
///
/// Resolves after exactly one `poll_read_vectored` call completes.
///
/// NOTE(review): the outer slice borrow and the `IoSliceMut` payloads share the single
/// lifetime `'a`; std's `Read::read_vectored` keeps them independent.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadVectored<'a, R: ?Sized> {
    /// Reader being polled.
    reader: &'a mut R,
    /// Destination buffers for the scatter read.
    bufs: &'a mut [IoSliceMut<'a>],
}
96
97impl<R> Future for ReadVectored<'_, R>
98where
99 R: AsyncReadVectored + Unpin + ?Sized,
100{
101 type Output = io::Result<usize>;
102
103 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104 let this = self.get_mut();
105 Pin::new(&mut *this.reader).poll_read_vectored(cx, this.bufs)
106 }
107}
108
/// Future returned by `AsyncReadExt::read_exact`.
///
/// Keeps reading until `buf` is full; `pos` survives across polls so a `Pending`
/// does not lose progress.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadExact<'a, R: ?Sized> {
    /// Reader being drained.
    reader: &'a mut R,
    /// Destination that must be filled completely.
    buf: &'a mut [u8],
    /// Number of leading bytes of `buf` already filled.
    pos: usize,
}
115
116impl<R> Future for ReadExact<'_, R>
117where
118 R: AsyncRead + Unpin + ?Sized,
119{
120 type Output = io::Result<()>;
121
122 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123 let this = self.get_mut();
124
125 while this.pos < this.buf.len() {
126 let mut read_buf = ReadBuf::new(&mut this.buf[this.pos..]);
127 match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
128 Poll::Pending => return Poll::Pending,
129 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
130 Poll::Ready(Ok(())) => {
131 let n = read_buf.filled().len();
132 if n == 0 {
133 return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof)));
134 }
135 this.pos += n;
136 }
137 }
138 }
139
140 Poll::Ready(Ok(()))
141 }
142}
143
/// Future returned by `AsyncReadExt::read_to_end`.
///
/// Appends to `buf` until the reader reports end-of-stream and resolves to the number
/// of bytes appended since creation.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadToEnd<'a, R: ?Sized> {
    /// Reader being drained.
    reader: &'a mut R,
    /// Destination; received bytes are appended as soon as they arrive.
    buf: &'a mut Vec<u8>,
    /// Length of `buf` when the future was created; the resolved count is measured from here.
    start_len: usize,
}
150
151impl<R> Future for ReadToEnd<'_, R>
152where
153 R: AsyncRead + Unpin + ?Sized,
154{
155 type Output = io::Result<usize>;
156
157 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158 const CHUNK: usize = 1024;
159 let this = self.get_mut();
160
161 loop {
162 let mut local = [0u8; CHUNK];
163 let mut read_buf = ReadBuf::new(&mut local);
164 match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
165 Poll::Pending => return Poll::Pending,
166 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
167 Poll::Ready(Ok(())) => {
168 let n = read_buf.filled().len();
169 if n == 0 {
170 return Poll::Ready(Ok(this.buf.len().saturating_sub(this.start_len)));
171 }
172 this.buf.extend_from_slice(read_buf.filled());
173 }
174 }
175 }
176 }
177}
178
/// Future returned by `AsyncReadExt::read_to_string`.
///
/// Raw bytes are staged in `pending_utf8` and moved into `buf` only once they form
/// complete, valid UTF-8. On any UTF-8 failure `buf` is truncated back to `start_len`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadToString<'a, R: ?Sized> {
    /// Reader being drained.
    reader: &'a mut R,
    /// Destination string; only ever receives validated UTF-8.
    buf: &'a mut String,
    /// Bytes read but not yet appended to `buf`. After each successful
    /// `push_valid_prefix` this holds at most an incomplete trailing sequence.
    pending_utf8: Vec<u8>,
    /// Total bytes read from `reader` so far (the future's resolved value).
    read: usize,
    /// Length of `buf` when the future was created; the rollback target.
    start_len: usize,
}

impl<R: ?Sized> ReadToString<'_, R> {
    /// Restores `buf` to its length at creation and discards any staged bytes.
    ///
    /// Used on every UTF-8 failure so callers never see a partially decoded string
    /// alongside an error.
    fn rollback_utf8_error(&mut self) {
        self.buf.truncate(self.start_len);
        self.pending_utf8.clear();
    }

    /// Moves the longest valid UTF-8 prefix of `pending_utf8` into `buf`.
    ///
    /// A trailing *incomplete* sequence (one more input could still complete) is kept
    /// in `pending_utf8` for the next read. `buf` is not modified when this fails.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidData` if `pending_utf8` contains a byte sequence that
    /// can never become valid UTF-8.
    fn push_valid_prefix(&mut self) -> io::Result<()> {
        let valid_up_to = match std::str::from_utf8(&self.pending_utf8) {
            Ok(complete) => {
                self.buf.push_str(complete);
                self.pending_utf8.clear();
                return Ok(());
            }
            // `error_len() == Some(_)`: a genuinely invalid sequence; more bytes cannot fix it.
            Err(err) if err.error_len().is_some() => {
                return Err(io::Error::new(ErrorKind::InvalidData, "invalid utf-8"));
            }
            // `error_len() == None`: the input just ends mid-character; keep the tail.
            Err(err) => err.valid_up_to(),
        };

        if valid_up_to == 0 {
            return Ok(());
        }

        let valid = std::str::from_utf8(&self.pending_utf8[..valid_up_to])
            .map_err(|_| io::Error::new(ErrorKind::InvalidData, "invalid utf-8"))?;
        self.buf.push_str(valid);
        // Shift the (at most 3-byte) incomplete tail to the front in place, keeping the
        // allocation so the next chunk can be appended without reallocating.
        self.pending_utf8.drain(..valid_up_to);
        Ok(())
    }
}
220
221impl<R> Future for ReadToString<'_, R>
222where
223 R: AsyncRead + Unpin + ?Sized,
224{
225 type Output = io::Result<usize>;
226
227 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228 const CHUNK: usize = 1024;
229 let this = self.get_mut();
230
231 loop {
232 let mut local = [0u8; CHUNK];
233 let mut read_buf = ReadBuf::new(&mut local);
234 match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
235 Poll::Pending => return Poll::Pending,
236 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
237 Poll::Ready(Ok(())) => {
238 let n = read_buf.filled().len();
239 if n == 0 {
240 if this.pending_utf8.is_empty() {
241 return Poll::Ready(Ok(this.read));
242 }
243 this.rollback_utf8_error();
244 return Poll::Ready(Err(io::Error::new(
245 ErrorKind::InvalidData,
246 "incomplete utf-8 sequence",
247 )));
248 }
249 this.read += n;
250 this.pending_utf8.extend_from_slice(read_buf.filled());
251 if let Err(err) = this.push_valid_prefix() {
252 this.rollback_utf8_error();
253 return Poll::Ready(Err(err));
254 }
255 }
256 }
257 }
258 }
259}
260
/// Future returned by `AsyncReadExt::read_u8`; resolves to a single byte.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadU8<'a, R: ?Sized> {
    /// Reader the byte is taken from.
    reader: &'a mut R,
}
265
266impl<R> Future for ReadU8<'_, R>
267where
268 R: AsyncRead + Unpin + ?Sized,
269{
270 type Output = io::Result<u8>;
271
272 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273 let this = self.get_mut();
274 let mut one = [0u8; 1];
275 let mut read_buf = ReadBuf::new(&mut one);
276 match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
277 Poll::Pending => Poll::Pending,
278 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
279 Poll::Ready(Ok(())) => {
280 if read_buf.filled().is_empty() {
281 Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof)))
282 } else {
283 Poll::Ready(Ok(read_buf.filled()[0]))
284 }
285 }
286 }
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::IoSliceMut;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Wake, Waker};

    fn init_test(name: &str) {
        crate::test_utils::init_test_logging();
        crate::test_phase!(name);
    }

    /// Waker that ignores wake-ups; tests drive futures by re-polling in a loop.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWaker))
    }

    /// Polls `fut` up to 32 times and returns its output if it became ready.
    fn poll_ready<F: Future>(fut: &mut Pin<&mut F>) -> Option<F::Output> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        for _ in 0..32 {
            if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
                return Some(output);
            }
        }
        None
    }

    #[test]
    fn read_exact_ok() {
        init_test("read_exact_ok");
        let mut reader: &[u8] = b"abcd";
        let mut buf = [0u8; 4];
        let mut fut = reader.read_exact(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let result = poll_ready(&mut fut).expect("future did not resolve");
        crate::assert_with_log!(result.is_ok(), "result ok", true, result.is_ok());
        crate::assert_with_log!(&buf == b"abcd", "buf", b"abcd", buf);
        crate::test_complete!("read_exact_ok");
    }

    #[test]
    fn read_exact_eof() {
        init_test("read_exact_eof");
        let mut reader: &[u8] = b"ab";
        let mut buf = [0u8; 4];
        let mut fut = reader.read_exact(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::UnexpectedEof,
            "error kind",
            io::ErrorKind::UnexpectedEof,
            kind
        );
        crate::test_complete!("read_exact_eof");
    }

    #[test]
    fn read_exact_empty_buffer_is_ready() {
        init_test("read_exact_empty_buffer_is_ready");
        let mut reader: &[u8] = b"";
        let mut buf = [0u8; 0];
        let mut fut = reader.read_exact(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let result = poll_ready(&mut fut).expect("future did not resolve");
        crate::assert_with_log!(result.is_ok(), "result ok", true, result.is_ok());
        crate::test_complete!("read_exact_empty_buffer_is_ready");
    }

    #[test]
    fn read_to_end_reads_all() {
        init_test("read_to_end_reads_all");
        let mut reader: &[u8] = b"hello";
        let mut buf = Vec::new();
        let mut fut = reader.read_to_end(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 5, "bytes read", 5, n);
        crate::assert_with_log!(buf == b"hello", "buf", b"hello", buf);
        crate::test_complete!("read_to_end_reads_all");
    }

    #[test]
    fn read_to_string_reads_all() {
        init_test("read_to_string_reads_all");
        let mut reader: &[u8] = b"hi";
        let mut buf = String::new();
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 2, "bytes read", 2, n);
        crate::assert_with_log!(buf == "hi", "buf", "hi", buf);
        crate::test_complete!("read_to_string_reads_all");
    }

    #[test]
    fn read_to_string_appends_to_existing_contents() {
        init_test("read_to_string_appends_to_existing_contents");
        let mut reader: &[u8] = b"hi";
        let mut buf = String::from("seed");
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 2, "bytes read", 2, n);
        crate::assert_with_log!(buf == "seedhi", "buf", "seedhi", buf);
        crate::test_complete!("read_to_string_appends_to_existing_contents");
    }

    #[test]
    fn read_to_string_invalid_utf8_errors() {
        init_test("read_to_string_invalid_utf8_errors");
        let mut reader: &[u8] = &[0xff, 0xfe];
        let mut buf = String::new();
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        let empty = buf.is_empty();
        crate::assert_with_log!(empty, "buf empty", true, empty);
        crate::test_complete!("read_to_string_invalid_utf8_errors");
    }

    #[test]
    fn read_to_string_incomplete_utf8_errors() {
        init_test("read_to_string_incomplete_utf8_errors");
        // First three bytes of a four-byte sequence; the stream ends before the last one.
        let mut reader: &[u8] = &[0xF0, 0x9F, 0x92];
        let mut buf = String::new();
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        let empty = buf.is_empty();
        crate::assert_with_log!(empty, "buf empty", true, empty);
        crate::test_complete!("read_to_string_incomplete_utf8_errors");
    }

    #[test]
    fn read_to_string_invalid_utf8_rolls_back_after_long_valid_prefix() {
        init_test("read_to_string_invalid_utf8_rolls_back_after_long_valid_prefix");
        let mut input = vec![b'a'; 1024];
        input.push(0xFF);
        let mut reader: &[u8] = &input;
        let mut buf = String::from("seed");
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        crate::assert_with_log!(buf == "seed", "buf rollback", "seed", buf);
        crate::test_complete!("read_to_string_invalid_utf8_rolls_back_after_long_valid_prefix");
    }

    #[test]
    fn read_to_string_incomplete_utf8_rolls_back_after_long_valid_prefix() {
        init_test("read_to_string_incomplete_utf8_rolls_back_after_long_valid_prefix");
        let mut input = vec![b'a'; 1024];
        input.extend_from_slice(&[0xF0, 0x9F, 0x92]);
        let mut reader: &[u8] = &input;
        let mut buf = String::from("seed");
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        crate::assert_with_log!(buf == "seed", "buf rollback", "seed", buf);
        crate::test_complete!("read_to_string_incomplete_utf8_rolls_back_after_long_valid_prefix");
    }

    #[test]
    fn read_u8_reads_byte() {
        init_test("read_u8_reads_byte");
        let mut reader: &[u8] = b"z";
        let mut fut = reader.read_u8();
        let mut fut = Pin::new(&mut fut);
        let byte = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(byte == b'z', "byte", b'z', byte);
        crate::test_complete!("read_u8_reads_byte");
    }

    #[test]
    fn read_u8_eof_errors() {
        init_test("read_u8_eof_errors");
        let mut reader: &[u8] = b"";
        let mut fut = reader.read_u8();
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::UnexpectedEof,
            "error kind",
            io::ErrorKind::UnexpectedEof,
            kind
        );
        crate::test_complete!("read_u8_eof_errors");
    }

    #[test]
    fn read_vectored_reads_prefix() {
        init_test("read_vectored_reads_prefix");
        let mut reader: &[u8] = b"hello";
        let mut a = [0u8; 2];
        let mut b = [0u8; 3];
        let mut bufs = [IoSliceMut::new(&mut a), IoSliceMut::new(&mut b)];

        let mut fut = reader.read_vectored(&mut bufs);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .expect("read_vectored failed");

        let mut got = Vec::new();
        let first = n.min(a.len());
        got.extend_from_slice(&a[..first]);
        if n > a.len() {
            got.extend_from_slice(&b[..n - a.len()]);
        }

        let expected = b"hello";
        crate::assert_with_log!(got == expected[..n], "vectored prefix", &expected[..n], got);
        crate::test_complete!("read_vectored_reads_prefix");
    }

    /// Test reader that hands out one byte per successful read and returns `Pending`
    /// on the call right after each byte, to exercise resumption across polls.
    #[derive(Debug)]
    struct YieldingReader<'a> {
        data: &'a [u8],
        pos: usize,
        yield_next: bool,
    }

    impl<'a> YieldingReader<'a> {
        fn new(data: &'a [u8]) -> Self {
            Self {
                data,
                pos: 0,
                yield_next: false,
            }
        }
    }

    impl AsyncRead for YieldingReader<'_> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            if self.yield_next {
                self.yield_next = false;
                return Poll::Pending;
            }

            if self.pos >= self.data.len() {
                return Poll::Ready(Ok(()));
            }

            if buf.remaining() == 0 {
                return Poll::Ready(Ok(()));
            }

            buf.put_slice(&self.data[self.pos..=self.pos]);
            self.pos += 1;
            self.yield_next = true;

            Poll::Ready(Ok(()))
        }
    }

    #[test]
    fn cancel_safety_read_exact_is_not_cancel_safe() {
        init_test("cancel_safety_read_exact_is_not_cancel_safe");
        let mut reader = YieldingReader::new(b"abc");
        let mut buf = [0u8; 3];
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let poll = {
            let mut fut = reader.read_exact(&mut buf);
            let mut pinned = Pin::new(&mut fut);
            pinned.as_mut().poll(&mut cx)
        };
        let pending = matches!(poll, Poll::Pending);
        crate::assert_with_log!(pending, "pending", true, pending);
        crate::assert_with_log!(buf[0] == b'a', "prefix", b'a', buf[0]);
        crate::test_complete!("cancel_safety_read_exact_is_not_cancel_safe");
    }

    #[test]
    fn cancel_safety_read_to_end_preserves_bytes() {
        init_test("cancel_safety_read_to_end_preserves_bytes");
        let mut reader = YieldingReader::new(b"abc");
        let mut out = Vec::new();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let poll = {
            let mut fut = reader.read_to_end(&mut out);
            let mut pinned = Pin::new(&mut fut);
            pinned.as_mut().poll(&mut cx)
        };
        let pending = matches!(poll, Poll::Pending);
        crate::assert_with_log!(pending, "pending", true, pending);
        crate::assert_with_log!(out == b"a", "out", b"a", out);
        crate::test_complete!("cancel_safety_read_to_end_preserves_bytes");
    }

    #[test]
    fn cancel_safety_read_to_string_preserves_prefix() {
        init_test("cancel_safety_read_to_string_preserves_prefix");
        let mut reader = YieldingReader::new(b"abc");
        let mut out = String::new();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let poll = {
            let mut fut = reader.read_to_string(&mut out);
            let mut pinned = Pin::new(&mut fut);
            pinned.as_mut().poll(&mut cx)
        };
        let pending = matches!(poll, Poll::Pending);
        crate::assert_with_log!(pending, "pending", true, pending);
        crate::assert_with_log!(out == "a", "out", "a", out);
        crate::test_complete!("cancel_safety_read_to_string_preserves_prefix");
    }

    #[test]
    fn read_to_string_reassembles_char_split_across_reads() {
        init_test("read_to_string_reassembles_char_split_across_reads");
        // YieldingReader delivers one byte per read, so the 2-byte and 4-byte
        // characters below each arrive split across several polls.
        let expected = "a\u{e9}\u{1F496}z";
        let mut reader = YieldingReader::new(expected.as_bytes());
        let mut out = String::new();
        let mut fut = reader.read_to_string(&mut out);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == expected.len(), "bytes read", expected.len(), n);
        crate::assert_with_log!(out == expected, "out", expected, out);
        crate::test_complete!("read_to_string_reassembles_char_split_across_reads");
    }
}