Skip to main content

asupersync/io/ext/
read_ext.rs

1//! AsyncRead extension methods.
2
3use crate::io::{AsyncRead, AsyncReadVectored, Chain, ReadBuf, Take};
4use std::future::Future;
5use std::io::{self, ErrorKind, IoSliceMut};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9/// Extension trait for `AsyncRead`.
10pub trait AsyncReadExt: AsyncRead {
11    /// Read the exact number of bytes to fill `buf`.
12    fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadExact<'a, Self>
13    where
14        Self: Unpin,
15    {
16        ReadExact {
17            reader: self,
18            buf,
19            pos: 0,
20        }
21    }
22
23    /// Read the entire reader into `buf`.
24    fn read_to_end<'a>(&'a mut self, buf: &'a mut Vec<u8>) -> ReadToEnd<'a, Self>
25    where
26        Self: Unpin,
27    {
28        let start_len = buf.len();
29        ReadToEnd {
30            reader: self,
31            buf,
32            start_len,
33        }
34    }
35
36    /// Read the entire reader into `buf` as UTF-8.
37    fn read_to_string<'a>(&'a mut self, buf: &'a mut String) -> ReadToString<'a, Self>
38    where
39        Self: Unpin,
40    {
41        let start_len = buf.len();
42        ReadToString {
43            reader: self,
44            buf,
45            pending_utf8: Vec::new(),
46            read: 0,
47            start_len,
48        }
49    }
50
51    /// Read a single byte.
52    fn read_u8(&mut self) -> ReadU8<'_, Self>
53    where
54        Self: Unpin,
55    {
56        ReadU8 { reader: self }
57    }
58
59    /// Chain this reader with another.
60    fn chain<R: AsyncRead>(self, next: R) -> Chain<Self, R>
61    where
62        Self: Sized,
63    {
64        Chain::new(self, next)
65    }
66
67    /// Take at most `limit` bytes from this reader.
68    fn take(self, limit: u64) -> Take<Self>
69    where
70        Self: Sized,
71    {
72        Take::new(self, limit)
73    }
74}
75
76impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
77
78/// Extension trait for `AsyncReadVectored`.
79pub trait AsyncReadVectoredExt: AsyncReadVectored {
80    /// Read into multiple buffers (vectored I/O).
81    fn read_vectored<'a>(&'a mut self, bufs: &'a mut [IoSliceMut<'a>]) -> ReadVectored<'a, Self>
82    where
83        Self: Unpin,
84    {
85        ReadVectored { reader: self, bufs }
86    }
87}
88
89impl<R: AsyncReadVectored + ?Sized> AsyncReadVectoredExt for R {}
90
/// Future returned by [`AsyncReadVectoredExt::read_vectored`].
///
/// Each poll forwards directly to the reader's `poll_read_vectored`; the
/// future keeps no progress state of its own.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadVectored<'a, R: ?Sized> {
    /// Reader the vectored read is issued against.
    reader: &'a mut R,
    /// Destination buffers, filled in order by the reader.
    bufs: &'a mut [IoSliceMut<'a>],
}
96
97impl<R> Future for ReadVectored<'_, R>
98where
99    R: AsyncReadVectored + Unpin + ?Sized,
100{
101    type Output = io::Result<usize>;
102
103    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104        let this = self.get_mut();
105        Pin::new(&mut *this.reader).poll_read_vectored(cx, this.bufs)
106    }
107}
108
/// Future returned by [`AsyncReadExt::read_exact`].
///
/// Resolves to `Ok(())` once `buf` is completely filled, or to an
/// [`io::ErrorKind::UnexpectedEof`] error if the reader reports end-of-file
/// first. Bytes already written into `buf` stay there if the future is dropped
/// part-way. The count of those bytes is not exposed, so the read cannot be
/// resumed from the outside.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadExact<'a, R: ?Sized> {
    /// Reader being drained.
    reader: &'a mut R,
    /// Destination buffer that must be filled completely.
    buf: &'a mut [u8],
    /// Number of leading bytes of `buf` that have already been filled.
    pos: usize,
}
115
116impl<R> Future for ReadExact<'_, R>
117where
118    R: AsyncRead + Unpin + ?Sized,
119{
120    type Output = io::Result<()>;
121
122    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123        let this = self.get_mut();
124
125        while this.pos < this.buf.len() {
126            let mut read_buf = ReadBuf::new(&mut this.buf[this.pos..]);
127            match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
128                Poll::Pending => return Poll::Pending,
129                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
130                Poll::Ready(Ok(())) => {
131                    let n = read_buf.filled().len();
132                    if n == 0 {
133                        return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof)));
134                    }
135                    this.pos += n;
136                }
137            }
138        }
139
140        Poll::Ready(Ok(()))
141    }
142}
143
/// Future returned by [`AsyncReadExt::read_to_end`].
///
/// Data is appended to `buf` as soon as each read completes. Anything read
/// before the future is dropped therefore remains in `buf`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadToEnd<'a, R: ?Sized> {
    /// Reader being drained until end-of-file.
    reader: &'a mut R,
    /// Destination vector; new bytes are appended after the existing contents.
    buf: &'a mut Vec<u8>,
    /// Length of `buf` when the future was created, used to compute the result.
    start_len: usize,
}
150
151impl<R> Future for ReadToEnd<'_, R>
152where
153    R: AsyncRead + Unpin + ?Sized,
154{
155    type Output = io::Result<usize>;
156
157    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        const CHUNK: usize = 1024;
159        let this = self.get_mut();
160
161        loop {
162            let mut local = [0u8; CHUNK];
163            let mut read_buf = ReadBuf::new(&mut local);
164            match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
165                Poll::Pending => return Poll::Pending,
166                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
167                Poll::Ready(Ok(())) => {
168                    let n = read_buf.filled().len();
169                    if n == 0 {
170                        return Poll::Ready(Ok(this.buf.len().saturating_sub(this.start_len)));
171                    }
172                    this.buf.extend_from_slice(read_buf.filled());
173                }
174            }
175        }
176    }
177}
178
/// Future returned by [`AsyncReadExt::read_to_string`].
///
/// Bytes read from the reader are staged in `pending_utf8`. After each read,
/// the complete UTF-8 characters are moved into `buf`, so only an incomplete
/// trailing sequence is held back between reads. If invalid UTF-8 is seen, or
/// the stream ends in the middle of a character, `buf` is truncated back to the
/// length it had when the future was created.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadToString<'a, R: ?Sized> {
    /// Reader being drained until end-of-file.
    reader: &'a mut R,
    /// Destination string; decoded text is appended after the existing contents.
    buf: &'a mut String,
    /// Bytes read but not yet appended to `buf` (an incomplete trailing sequence).
    pending_utf8: Vec<u8>,
    /// Total number of bytes read from `reader` so far.
    read: usize,
    /// Length of `buf` when the future was created; the rollback target.
    start_len: usize,
}

impl<R: ?Sized> ReadToString<'_, R> {
    /// Discards everything this future appended to `buf`, plus any staged bytes.
    fn rollback_utf8_error(&mut self) {
        self.buf.truncate(self.start_len);
        self.pending_utf8.clear();
    }

    /// Moves the longest valid UTF-8 prefix of `pending_utf8` into `buf`.
    ///
    /// An incomplete multi-byte sequence at the end is kept in `pending_utf8`
    /// so the next read can complete it.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if `pending_utf8` contains a byte sequence that can
    /// never become valid UTF-8, whatever bytes follow. In that case neither
    /// `buf` nor `pending_utf8` is modified; rolling back is the caller's job.
    fn push_valid_prefix(&mut self) -> io::Result<()> {
        let valid_up_to = match std::str::from_utf8(&self.pending_utf8) {
            Ok(text) => {
                // Fast path: everything staged is already valid.
                self.buf.push_str(text);
                self.pending_utf8.clear();
                return Ok(());
            }
            Err(err) if err.error_len().is_some() => {
                return Err(io::Error::new(ErrorKind::InvalidData, "invalid utf-8"));
            }
            // `error_len() == None`: the bytes merely stop mid-character.
            Err(err) => err.valid_up_to(),
        };

        if valid_up_to == 0 {
            return Ok(());
        }

        let valid_str = std::str::from_utf8(&self.pending_utf8[..valid_up_to])
            .map_err(|_| io::Error::new(ErrorKind::InvalidData, "invalid utf-8"))?;
        self.buf.push_str(valid_str);
        // Shift the incomplete tail to the front in place instead of
        // allocating a new vector for it on every chunk.
        self.pending_utf8.drain(..valid_up_to);
        Ok(())
    }
}
220
221impl<R> Future for ReadToString<'_, R>
222where
223    R: AsyncRead + Unpin + ?Sized,
224{
225    type Output = io::Result<usize>;
226
227    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228        const CHUNK: usize = 1024;
229        let this = self.get_mut();
230
231        loop {
232            let mut local = [0u8; CHUNK];
233            let mut read_buf = ReadBuf::new(&mut local);
234            match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
235                Poll::Pending => return Poll::Pending,
236                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
237                Poll::Ready(Ok(())) => {
238                    let n = read_buf.filled().len();
239                    if n == 0 {
240                        if this.pending_utf8.is_empty() {
241                            return Poll::Ready(Ok(this.read));
242                        }
243                        this.rollback_utf8_error();
244                        return Poll::Ready(Err(io::Error::new(
245                            ErrorKind::InvalidData,
246                            "incomplete utf-8 sequence",
247                        )));
248                    }
249                    this.read += n;
250                    this.pending_utf8.extend_from_slice(read_buf.filled());
251                    if let Err(err) = this.push_valid_prefix() {
252                        this.rollback_utf8_error();
253                        return Poll::Ready(Err(err));
254                    }
255                }
256            }
257        }
258    }
259}
260
/// Future returned by [`AsyncReadExt::read_u8`].
///
/// Each poll issues a fresh one-byte read against the reader.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadU8<'a, R: ?Sized> {
    /// Reader the byte is taken from.
    reader: &'a mut R,
}
265
266impl<R> Future for ReadU8<'_, R>
267where
268    R: AsyncRead + Unpin + ?Sized,
269{
270    type Output = io::Result<u8>;
271
272    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273        let this = self.get_mut();
274        let mut one = [0u8; 1];
275        let mut read_buf = ReadBuf::new(&mut one);
276        match Pin::new(&mut *this.reader).poll_read(cx, &mut read_buf) {
277            Poll::Pending => Poll::Pending,
278            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
279            Poll::Ready(Ok(())) => {
280                if read_buf.filled().is_empty() {
281                    Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof)))
282                } else {
283                    Poll::Ready(Ok(read_buf.filled()[0]))
284                }
285            }
286        }
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::IoSliceMut;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Wake, Waker};

    // Per-test setup: turns on the crate's test logging and records the test
    // phase under `name`.
    fn init_test(name: &str) {
        crate::test_utils::init_test_logging();
        crate::test_phase!(name);
    }

    // Waker whose `wake` does nothing; the helpers below re-poll manually
    // instead of waiting for a wake-up.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Waker::from(Arc::new(NoopWaker))
    }

    // Polls `fut` with a no-op waker up to 32 times and returns its output as
    // soon as it is ready; `None` means it was still pending after 32 polls.
    fn poll_ready<F: Future>(fut: &mut Pin<&mut F>) -> Option<F::Output> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        for _ in 0..32 {
            if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
                return Some(output);
            }
        }
        None
    }

    // read_exact fills the whole buffer when the reader has enough data.
    #[test]
    fn read_exact_ok() {
        init_test("read_exact_ok");
        let mut reader: &[u8] = b"abcd";
        let mut buf = [0u8; 4];
        let mut fut = reader.read_exact(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let result = poll_ready(&mut fut).expect("future did not resolve");
        crate::assert_with_log!(result.is_ok(), "result ok", true, result.is_ok());
        crate::assert_with_log!(&buf == b"abcd", "buf", b"abcd", buf);
        crate::test_complete!("read_exact_ok");
    }

    // read_exact reports UnexpectedEof when the reader runs out (2 of 4 bytes).
    #[test]
    fn read_exact_eof() {
        init_test("read_exact_eof");
        let mut reader: &[u8] = b"ab";
        let mut buf = [0u8; 4];
        let mut fut = reader.read_exact(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::UnexpectedEof,
            "error kind",
            io::ErrorKind::UnexpectedEof,
            kind
        );
        crate::test_complete!("read_exact_eof");
    }

    // read_to_end appends everything and returns the number of bytes appended.
    #[test]
    fn read_to_end_reads_all() {
        init_test("read_to_end_reads_all");
        let mut reader: &[u8] = b"hello";
        let mut buf = Vec::new();
        let mut fut = reader.read_to_end(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 5, "bytes read", 5, n);
        crate::assert_with_log!(buf == b"hello", "buf", b"hello", buf);
        crate::test_complete!("read_to_end_reads_all");
    }

    // read_to_string decodes valid UTF-8 and returns the byte count.
    #[test]
    fn read_to_string_reads_all() {
        init_test("read_to_string_reads_all");
        let mut reader: &[u8] = b"hi";
        let mut buf = String::new();
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(n == 2, "bytes read", 2, n);
        crate::assert_with_log!(buf == "hi", "buf", "hi", buf);
        crate::test_complete!("read_to_string_reads_all");
    }

    // Bytes that can never be valid UTF-8 yield InvalidData and leave `buf` empty.
    #[test]
    fn read_to_string_invalid_utf8_errors() {
        init_test("read_to_string_invalid_utf8_errors");
        let mut reader: &[u8] = &[0xff, 0xfe];
        let mut buf = String::new();
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        let empty = buf.is_empty();
        crate::assert_with_log!(empty, "buf empty", true, empty);
        crate::test_complete!("read_to_string_invalid_utf8_errors");
    }

    // A multi-byte character cut off by EOF is also InvalidData.
    #[test]
    fn read_to_string_incomplete_utf8_errors() {
        init_test("read_to_string_incomplete_utf8_errors");
        // 4-byte UTF-8 sequence, missing the final byte.
        let mut reader: &[u8] = &[0xF0, 0x9F, 0x92];
        let mut buf = String::new();
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        let empty = buf.is_empty();
        crate::assert_with_log!(empty, "buf empty", true, empty);
        crate::test_complete!("read_to_string_incomplete_utf8_errors");
    }

    // 1024 valid bytes match read_to_string's internal chunk size. Assuming the
    // `&[u8]` reader fills a whole chunk per read, they are appended to `buf`
    // before the invalid byte is seen; the rollback must still restore "seed".
    #[test]
    fn read_to_string_invalid_utf8_rolls_back_after_long_valid_prefix() {
        init_test("read_to_string_invalid_utf8_rolls_back_after_long_valid_prefix");
        let mut input = vec![b'a'; 1024];
        input.push(0xFF);
        let mut reader: &[u8] = &input;
        let mut buf = String::from("seed");
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        crate::assert_with_log!(buf == "seed", "buf rollback", "seed", buf);
        crate::test_complete!("read_to_string_invalid_utf8_rolls_back_after_long_valid_prefix");
    }

    // Same as above, but the trailing bytes are a truncated character that is
    // only detected at EOF.
    #[test]
    fn read_to_string_incomplete_utf8_rolls_back_after_long_valid_prefix() {
        init_test("read_to_string_incomplete_utf8_rolls_back_after_long_valid_prefix");
        let mut input = vec![b'a'; 1024];
        input.extend_from_slice(&[0xF0, 0x9F, 0x92]);
        let mut reader: &[u8] = &input;
        let mut buf = String::from("seed");
        let mut fut = reader.read_to_string(&mut buf);
        let mut fut = Pin::new(&mut fut);
        let err = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap_err();
        let kind = err.kind();
        crate::assert_with_log!(
            kind == io::ErrorKind::InvalidData,
            "error kind",
            io::ErrorKind::InvalidData,
            kind
        );
        crate::assert_with_log!(buf == "seed", "buf rollback", "seed", buf);
        crate::test_complete!("read_to_string_incomplete_utf8_rolls_back_after_long_valid_prefix");
    }

    // read_u8 returns the first byte of the stream.
    #[test]
    fn read_u8_reads_byte() {
        init_test("read_u8_reads_byte");
        let mut reader: &[u8] = b"z";
        let mut fut = reader.read_u8();
        let mut fut = Pin::new(&mut fut);
        let byte = poll_ready(&mut fut)
            .expect("future did not resolve")
            .unwrap();
        crate::assert_with_log!(byte == b'z', "byte", b'z', byte);
        crate::test_complete!("read_u8_reads_byte");
    }

    // A vectored read may fill fewer bytes than requested, so this test
    // reassembles whatever was read from `a` then `b` and checks it equals the
    // same-length prefix of the input.
    #[test]
    fn read_vectored_reads_prefix() {
        init_test("read_vectored_reads_prefix");
        let mut reader: &[u8] = b"hello";
        let mut a = [0u8; 2];
        let mut b = [0u8; 3];
        let mut bufs = [IoSliceMut::new(&mut a), IoSliceMut::new(&mut b)];

        let mut fut = reader.read_vectored(&mut bufs);
        let mut fut = Pin::new(&mut fut);
        let n = poll_ready(&mut fut)
            .expect("future did not resolve")
            .expect("read_vectored failed");

        let mut got = Vec::new();
        let first = n.min(a.len());
        got.extend_from_slice(&a[..first]);
        if n > a.len() {
            got.extend_from_slice(&b[..n - a.len()]);
        }

        let expected = b"hello";
        crate::assert_with_log!(got == expected[..n], "vectored prefix", &expected[..n], got);
        crate::test_complete!("read_vectored_reads_prefix");
    }

    // Test reader that hands out one byte per successful poll. It returns
    // `Pending` on the poll right after each byte (without registering a
    // wake-up) and reports EOF with an empty read once `data` is exhausted.
    #[derive(Debug)]
    struct YieldingReader<'a> {
        data: &'a [u8],
        pos: usize,
        yield_next: bool,
    }

    impl<'a> YieldingReader<'a> {
        fn new(data: &'a [u8]) -> Self {
            Self {
                data,
                pos: 0,
                yield_next: false,
            }
        }
    }

    impl AsyncRead for YieldingReader<'_> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            // Alternate: after delivering a byte, the next poll is Pending.
            if self.yield_next {
                self.yield_next = false;
                return Poll::Pending;
            }

            if self.pos >= self.data.len() {
                return Poll::Ready(Ok(()));
            }

            if buf.remaining() == 0 {
                return Poll::Ready(Ok(()));
            }

            buf.put_slice(&self.data[self.pos..=self.pos]);
            self.pos += 1;
            self.yield_next = true;

            Poll::Ready(Ok(()))
        }
    }

    // One poll reads 'a' and then hits Pending; the future is dropped at the
    // end of the block (simulated cancellation). The byte is left in `buf`,
    // but the caller has no way to learn how many bytes were consumed.
    #[test]
    fn cancel_safety_read_exact_is_not_cancel_safe() {
        init_test("cancel_safety_read_exact_is_not_cancel_safe");
        let mut reader = YieldingReader::new(b"abc");
        let mut buf = [0u8; 3];
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let poll = {
            let mut fut = reader.read_exact(&mut buf);
            let mut pinned = Pin::new(&mut fut);
            pinned.as_mut().poll(&mut cx)
        };
        let pending = matches!(poll, Poll::Pending);
        crate::assert_with_log!(pending, "pending", true, pending);
        crate::assert_with_log!(buf[0] == b'a', "prefix", b'a', buf[0]);
        crate::test_complete!("cancel_safety_read_exact_is_not_cancel_safe");
    }

    // Bytes read before cancellation stay appended to the output Vec.
    #[test]
    fn cancel_safety_read_to_end_preserves_bytes() {
        init_test("cancel_safety_read_to_end_preserves_bytes");
        let mut reader = YieldingReader::new(b"abc");
        let mut out = Vec::new();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let poll = {
            let mut fut = reader.read_to_end(&mut out);
            let mut pinned = Pin::new(&mut fut);
            pinned.as_mut().poll(&mut cx)
        };
        let pending = matches!(poll, Poll::Pending);
        crate::assert_with_log!(pending, "pending", true, pending);
        crate::assert_with_log!(out == b"a", "out", b"a", out);
        crate::test_complete!("cancel_safety_read_to_end_preserves_bytes");
    }

    // Complete characters decoded before cancellation stay in the output String.
    #[test]
    fn cancel_safety_read_to_string_preserves_prefix() {
        init_test("cancel_safety_read_to_string_preserves_prefix");
        let mut reader = YieldingReader::new(b"abc");
        let mut out = String::new();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let poll = {
            let mut fut = reader.read_to_string(&mut out);
            let mut pinned = Pin::new(&mut fut);
            pinned.as_mut().poll(&mut cx)
        };
        let pending = matches!(poll, Poll::Pending);
        crate::assert_with_log!(pending, "pending", true, pending);
        crate::assert_with_log!(out == "a", "out", "a", out);
        crate::test_complete!("cancel_safety_read_to_string_preserves_prefix");
    }
}