1use std::future::Future;
2use std::io::{ErrorKind, Result};
3use std::marker::PhantomPinned;
4use std::mem::MaybeUninit;
5use std::pin::Pin;
6use std::slice;
7use std::task::{Context, Poll};
8
9use aliasable::AliasableMut;
10use completion_core::CompletionFuture;
11use completion_io::{AsyncRead, AsyncReadWith, ReadBuf};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14
15use super::extend_lifetime_mut;
16
pin_project! {
    /// Future that repeatedly reads from `reader`, appending the bytes to `buf`, until a
    /// successful read fills zero bytes.
    ///
    /// Resolves to the number of bytes appended since the future was created. The struct
    /// is self-referential: the in-flight `fut` is built from a reborrow of `reader` and
    /// of the `ReadBuf` stored in `read_buf`, and that `ReadBuf` views the spare capacity
    /// of `buf`. `_pinned` makes the type `!Unpin` so those borrows stay valid.
    pub struct ReadToEnd<'a, T>
    where
        T: AsyncRead,
        T: ?Sized,
    {
        #[pin]
        // The read currently in flight, if any. Fields are dropped in declaration
        // order, so this is dropped before the `reader` and `read_buf` it borrows from.
        fut: Option<<T as AsyncReadWith<'a>>::ReadFuture>,

        // The source being read. Wrapped in `AliasableMut` because `fut` holds a
        // reborrow of it while this struct still owns the original reference.
        reader: AliasableMut<'a, T>,

        // The `ReadBuf` lent to `fut` for the current read (`None` when no read is in
        // flight). Boxed, presumably so the borrowed `ReadBuf` lives outside the
        // struct's own memory — TODO confirm the exact aliasing rationale.
        read_buf: Box<Option<ReadBuf<'a>>>,

        // Opts out of `Unpin`, since `fut` borrows from sibling fields.
        #[pin]
        _pinned: PhantomPinned,

        // The vector that read bytes are appended to.
        buf: &'a mut Vec<u8>,

        // Index into `buf`'s allocation below which memory is treated as initialized.
        // Always >= `buf.len()`; the excess is passed to the next read via
        // `ReadBuf::assume_init`.
        initialized_to: usize,

        // Length of `buf` when the future was created; subtracted from the final
        // length to compute the returned byte count.
        initial_filled: usize,
    }
}
50
51impl<'a, T: AsyncRead + ?Sized + 'a> ReadToEnd<'a, T> {
52 pub(super) fn new(reader: &'a mut T, buf: &'a mut Vec<u8>) -> Self {
53 let buf_len = buf.len();
54 Self {
55 fut: None,
56 reader: AliasableMut::from_unique(reader),
57 read_buf: Box::new(None),
58 _pinned: PhantomPinned,
59 buf,
60 initialized_to: buf_len,
61 initial_filled: buf_len,
62 }
63 }
64}
65
66impl<'a, T: AsyncRead + ?Sized + 'a> CompletionFuture for ReadToEnd<'a, T> {
67 type Output = Result<usize>;
68
69 unsafe fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
70 let mut this = self.project();
71
72 loop {
73 if let Some(fut) = this.fut.as_mut().as_pin_mut() {
74 let res = ready!(fut.poll(cx));
75 this.fut.set(None);
76
77 let read_buf = this.read_buf.take().unwrap();
80
81 match res {
82 Ok(()) => {
83 let filled = read_buf.filled().len();
84 let initialized = read_buf.initialized().len();
85
86 drop(read_buf);
87
88 if filled == 0 {
90 return Poll::Ready(Ok(this.buf.len() - *this.initial_filled));
91 }
92
93 this.buf.set_len(this.buf.len() + filled);
94 *this.initialized_to = this.buf.len() + initialized;
95 }
96 Err(e) if e.kind() == ErrorKind::Interrupted => {}
97 Err(e) => return Poll::Ready(Err(e)),
98 }
99 }
100
101 this.buf.reserve(32);
102
103 **this.read_buf = Some(ReadBuf::uninit(slice::from_raw_parts_mut(
105 this.buf
106 .as_mut_ptr()
107 .add(this.buf.len())
108 .cast::<MaybeUninit<u8>>(),
109 this.buf.capacity() - this.buf.len(),
110 )));
111 let read_buf = (**this.read_buf).as_mut().unwrap();
112 read_buf.assume_init(*this.initialized_to - this.buf.len());
113
114 let reader = extend_lifetime_mut(&mut **this.reader);
116 let read_buf = extend_lifetime_mut(read_buf);
117 this.fut.as_mut().set(Some(reader.read(read_buf.as_mut())));
118 }
119 }
120 unsafe fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
121 let mut this = self.project();
122
123 if let Some(fut) = this.fut.as_mut().as_pin_mut() {
124 ready!(fut.poll_cancel(cx));
125 this.fut.set(None);
126
127 let filled = this.read_buf.take().unwrap().filled().len();
132 this.buf.set_len(this.buf.len() + filled);
133 }
134
135 Poll::Ready(())
136 }
137}
138impl<'a, T: AsyncRead + ?Sized + 'a> Future for ReadToEnd<'a, T>
139where
140 <T as AsyncReadWith<'a>>::ReadFuture: Future<Output = Result<()>>,
141{
142 type Output = Result<usize>;
143
144 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
145 unsafe { CompletionFuture::poll(self, cx) }
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Cursor, Error};

    use crate::future::{block_on, CompletionFutureExt};

    use super::super::{
        test_utils::{poll_once, YieldingReader},
        AsyncReadExt,
    };

    /// Two successive reads from in-memory cursors append to the same vector.
    #[test]
    fn no_yield() {
        let mut out = Vec::new();

        let mut first = Cursor::new(&[1, 2, 3, 4, 5]);
        let appended = block_on(first.read_to_end(&mut out)).unwrap();
        assert_eq!(appended, 5);
        assert_eq!(out, &[1, 2, 3, 4, 5]);

        let mut second = Cursor::new(&[8; 500]);
        let appended = block_on(second.read_to_end(&mut out)).unwrap();
        assert_eq!(appended, 500);
        assert_eq!(out.len(), 505);
        assert!(out.starts_with(&[1, 2, 3, 4, 5]));
        assert!(out[5..].iter().all(|&byte| byte == 8));
    }

    /// A `YieldingReader` delivering one-byte chunks is read to completion.
    #[test]
    fn yielding() {
        const BYTES: usize = 13;

        let mut out = Vec::new();

        let mut reader = YieldingReader::new((0..BYTES).map(|_| Ok([18_u8])));
        let appended = block_on(reader.read_to_end(&mut out)).unwrap();
        assert_eq!(appended, BYTES);
        assert_eq!(out, [18; BYTES]);
    }

    /// Bytes from reads that completed are already in the vector while the future is
    /// still pending.
    #[test]
    fn partial() {
        let mut out = Vec::new();

        let mut reader = YieldingReader::new((0..10).map(|_| [10, 11]).map(Ok));
        let fut = reader.read_to_end(&mut out);
        futures_lite::pin!(fut);
        for _ in 0..2 {
            assert!(poll_once(fut.as_mut()).is_none());
        }
        assert_eq!(out, [10, 11]);
    }

    /// A non-`Interrupted` error stops the loop; earlier chunks stay appended.
    #[test]
    fn error() {
        let mut out = vec![1, 2, 3];

        let mut reader = YieldingReader::new(vec![
            Ok([4, 5]),
            Ok([6, 7]),
            Err(Error::new(ErrorKind::Other, "Some error")),
            Ok([8, 9]),
        ]);
        let err = block_on(reader.read_to_end(&mut out)).unwrap_err();
        assert_eq!(err.to_string(), "Some error");
        assert_eq!(out, [1, 2, 3, 4, 5, 6, 7]);
    }

    /// `Interrupted` errors are retried transparently.
    #[test]
    fn ignore_interrupted() {
        let mut out = vec![1, 2, 3];

        let mut reader = YieldingReader::new(vec![
            Err(Error::from(ErrorKind::Interrupted)),
            Ok(&[4, 5][..]),
            Err(Error::from(ErrorKind::Interrupted)),
            Err(Error::from(ErrorKind::Interrupted)),
            Ok(&[6]),
            Err(Error::from(ErrorKind::Interrupted)),
            Ok(&[7, 8]),
        ]);
        let appended = block_on(reader.read_to_end(&mut out)).unwrap();
        assert_eq!(appended, 5);
        assert_eq!(out, [1, 2, 3, 4, 5, 6, 7, 8]);
    }

    /// Data produced while a read is being cancelled is kept in the vector.
    #[test]
    fn cancellation_doesnt_lose_data() {
        let mut reader =
            YieldingReader::empty().after_cancellation(vec![&[4, 5, 6][..], &[0, 0]]);

        let mut out = vec![1, 2, 3];
        let outcome = block_on(reader.read_to_end(&mut out).now_or_never());
        assert!(outcome.is_none());
        assert_eq!(out, vec![1, 2, 3, 4, 5, 6]);
    }
}