completion/io/read/read_exact.rs

1use std::future::Future;
2use std::io::{Error, ErrorKind, Result};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use aliasable::AliasableMut;
7use completion_core::CompletionFuture;
8use completion_io::{AsyncRead, AsyncReadWith, ReadBuf, ReadBufMut};
9use futures_core::ready;
10use pin_project_lite::pin_project;
11
12use super::extend_lifetime_mut;
13
pin_project! {
    /// Future for [`AsyncReadExt::read_exact`](super::AsyncReadExt::read_exact).
    ///
    /// Issues reads on `reader` until `buf` has no remaining space. `Interrupted`
    /// errors are retried. A read that succeeds without adding any bytes resolves
    /// the future to an `UnexpectedEof` error.
    pub struct ReadExact<'a, T>
    where
        T: AsyncRead,
        T: ?Sized,
    {
        // The in-flight read operation, if any. It is created from borrows of
        // `reader` and `buf` whose lifetimes have been extended to `'a`, which is
        // why those two fields are held as `AliasableMut` rather than `&'a mut`.
        //
        // Struct fields are dropped in declaration order, so declaring `fut` first
        // makes it drop before the `reader`/`buf` it borrows from.
        // NOTE(review): this ordering looks load-bearing; do not reorder the fields.
        #[pin]
        fut: Option<<T as AsyncReadWith<'a>>::ReadFuture>,
        // The reader that reads are issued against.
        reader: AliasableMut<'a, T>,
        // The buffer being filled; the future completes once it has no remaining space.
        buf: AliasableMut<'a, ReadBuf<'a>>,
        // The number of bytes filled at the start of the previous operation.
        previous_filled: usize,
    }
}
29
30impl<'a, T: AsyncRead + ?Sized + 'a> ReadExact<'a, T> {
31    pub(super) fn new(reader: &'a mut T, buf: ReadBufMut<'a>) -> Self {
32        Self {
33            fut: None,
34            reader: AliasableMut::from_unique(reader),
35            buf: AliasableMut::from(unsafe { buf.into_mut() }),
36            previous_filled: 0,
37        }
38    }
39}
40
impl<'a, T: AsyncRead + ?Sized + 'a> CompletionFuture for ReadExact<'a, T> {
    type Output = Result<()>;

    /// Reads into the buffer until it is full.
    ///
    /// Resolves to one of the following:
    /// - `Ok(())` once the buffer has no remaining space.
    /// - An [`ErrorKind::UnexpectedEof`] error if a read succeeds without filling any
    ///   new bytes.
    /// - The first read error whose kind is not [`ErrorKind::Interrupted`].
    ///
    /// `Interrupted` errors are retried.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`CompletionFuture::poll`].
    unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        loop {
            // If a read is in flight, wait for it to finish before touching the buffer.
            if let Some(fut) = this.fut.as_mut().as_pin_mut() {
                let res = ready!(fut.poll(cx));
                // Drop the finished future now. It holds borrows of `reader` and `buf`,
                // and `buf` is accessed directly below.
                this.fut.set(None);

                match res {
                    Ok(()) => {
                        // There is no future, so we can create a mutable reference to `buf`
                        // without aliasing.
                        let read_buf = this.buf.as_mut();

                        // A successful read that added no bytes means the reader has no
                        // more data, even though the buffer still has space.
                        if read_buf.filled().len() == *this.previous_filled {
                            return Poll::Ready(Err(Error::new(
                                ErrorKind::UnexpectedEof,
                                "failed to fill buffer",
                            )));
                        }
                    }
                    // Interrupted reads are retried: fall through and start another read.
                    Err(e) if e.kind() == ErrorKind::Interrupted => {}
                    Err(e) => return Poll::Ready(Err(e)),
                }
            }

            // Reached only when `fut` is `None` (never started, or cleared above).
            let read_buf = &mut **this.buf;

            // The buffer is full, so we are done.
            if read_buf.remaining() == 0 {
                return Poll::Ready(Ok(()));
            }

            // Record the fill level so the next completed read can be checked for progress.
            *this.previous_filled = read_buf.filled().len();

            // NOTE(review): `extend_lifetime_mut` presumably widens these borrows of `self`
            // to `'a` so that the future created from them can be stored in `fut`; confirm
            // its exact contract. The future holding these borrows lives in `fut`, and this
            // function only accesses `reader`/`buf` directly while `fut` is `None`.
            let reader = extend_lifetime_mut(&mut **this.reader);
            let read_buf = extend_lifetime_mut(read_buf);
            this.fut.set(Some(reader.read(read_buf.as_mut())));
        }
    }

    /// Cancels the in-flight read, if there is one.
    ///
    /// With no read in flight there is nothing to cancel, and this completes immediately.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`CompletionFuture::poll_cancel`].
    unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if let Some(fut) = self.project().fut.as_pin_mut() {
            fut.poll_cancel(cx)
        } else {
            Poll::Ready(())
        }
    }
}
91impl<'a, T: AsyncRead + ?Sized + 'a> Future for ReadExact<'a, T>
92where
93    <T as AsyncReadWith<'a>>::ReadFuture: Future<Output = Result<()>>,
94{
95    type Output = Result<()>;
96
97    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
98        unsafe { CompletionFuture::poll(self, cx) }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::Cursor;
    use std::mem::MaybeUninit;

    use crate::future::block_on;

    use super::super::{test_utils::YieldingReader, AsyncReadExt};

    /// A reader that always completes immediately fills the buffer and stops there.
    #[test]
    fn no_yield() {
        let mut cursor = Cursor::new([8; 50]);

        let mut storage = [MaybeUninit::uninit(); 10];
        let mut buffer = ReadBuf::uninit(&mut storage);
        block_on(cursor.read_exact(buffer.as_mut())).unwrap();

        assert_eq!(buffer.into_filled(), [8; 10]);
        assert_eq!(cursor.position(), 10);
    }

    /// Short reads are stitched together until the buffer is full.
    #[test]
    fn yielding() {
        let mut reader = YieldingReader::new((0..100).map(|_| Ok([5, 6, 7])));

        let mut storage = [MaybeUninit::uninit(); 10];
        let mut buffer = ReadBuf::uninit(&mut storage);
        block_on(reader.read_exact(buffer.as_mut())).unwrap();

        assert_eq!(buffer.into_filled(), [5, 6, 7, 5, 6, 7, 5, 6, 7, 5]);
    }

    /// `Interrupted` errors are retried. Running out of data before the buffer is
    /// full is reported as `UnexpectedEof`, and the partial data is kept.
    #[test]
    fn eof() {
        let mut reader = YieldingReader::new(vec![
            Err(Error::from(ErrorKind::Interrupted)),
            Ok(&[0, 1, 2][..]),
            Err(Error::from(ErrorKind::Interrupted)),
            Ok(&[3, 4]),
            Err(Error::from(ErrorKind::Interrupted)),
            Err(Error::from(ErrorKind::Interrupted)),
        ]);

        let mut storage = [MaybeUninit::uninit(); 10];
        let mut buffer = ReadBuf::uninit(&mut storage);
        assert_eq!(
            block_on(reader.read_exact(buffer.as_mut()))
                .unwrap_err()
                .kind(),
            ErrorKind::UnexpectedEof
        );

        assert_eq!(buffer.into_filled(), [0, 1, 2, 3, 4]);
    }

    /// A non-`Interrupted` error is returned as-is and is not retried. Bytes read
    /// before the error stay in the buffer.
    #[test]
    fn error_propagates() {
        let mut reader = YieldingReader::new(vec![
            Ok(&[1, 2][..]),
            Err(Error::from(ErrorKind::Other)),
            // Consuming this item would mean the error was wrongly retried.
            Ok(&[3, 4, 5]),
        ]);

        let mut storage = [MaybeUninit::uninit(); 10];
        let mut buffer = ReadBuf::uninit(&mut storage);
        assert_eq!(
            block_on(reader.read_exact(buffer.as_mut()))
                .unwrap_err()
                .kind(),
            ErrorKind::Other
        );

        assert_eq!(buffer.into_filled(), [1, 2]);
    }

    /// A buffer with no capacity is already full, so no read is issued.
    #[test]
    fn empty_buffer() {
        // Any read would surface this error, so success proves none was made.
        let items: Vec<Result<&[u8]>> = vec![Err(Error::from(ErrorKind::Other))];
        let mut reader = YieldingReader::new(items);

        let mut storage: [MaybeUninit<u8>; 0] = [];
        let mut buffer = ReadBuf::uninit(&mut storage);
        block_on(reader.read_exact(buffer.as_mut())).unwrap();

        assert!(buffer.into_filled().is_empty());
    }
}
158}