1use super::{AsyncRead, ReadBuf};
10use crate::stream::Stream;
11use std::io;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
/// Chunk size used by [`ReaderStream::new`]: 8 KiB per read.
const DEFAULT_CHUNK_SIZE: usize = 8192;

/// Adapts an async reader into a stream of owned byte chunks.
///
/// Each successful read yields one `Vec<u8>` holding at most `chunk_size`
/// bytes. The stream ends after the reader reports EOF (a zero-byte read)
/// or after the first error has been yielded.
pub struct ReaderStream<R> {
    /// The wrapped reader.
    reader: R,
    /// Maximum number of bytes per yielded chunk (always at least 1).
    chunk_size: usize,
    /// Set once EOF or an error has been observed; the stream then stays finished.
    done: bool,
    /// Reusable read buffer; each yielded chunk is copied out of it.
    scratch: Vec<u8>,
}

// Hand-written so the `scratch` buffer (8 KiB with the default chunk size) is
// not dumped into debug output. Its contents are only transient read storage.
impl<R: std::fmt::Debug> std::fmt::Debug for ReaderStream<R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReaderStream")
            .field("reader", &self.reader)
            .field("chunk_size", &self.chunk_size)
            .field("done", &self.done)
            .finish_non_exhaustive()
    }
}

26impl<R> ReaderStream<R> {
27 #[must_use]
29 pub fn new(reader: R) -> Self {
30 Self::with_capacity(reader, DEFAULT_CHUNK_SIZE)
31 }
32
33 #[must_use]
35 pub fn with_capacity(reader: R, chunk_size: usize) -> Self {
36 let chunk_size = chunk_size.max(1);
37 Self {
38 reader,
39 chunk_size,
40 done: false,
41 scratch: vec![0; chunk_size],
42 }
43 }
44
45 #[must_use]
47 pub fn get_ref(&self) -> &R {
48 &self.reader
49 }
50
51 pub fn get_mut(&mut self) -> &mut R {
53 &mut self.reader
54 }
55
56 #[must_use]
58 pub fn into_inner(self) -> R {
59 self.reader
60 }
61}
62
63impl<R: AsyncRead + Unpin> Stream for ReaderStream<R> {
64 type Item = io::Result<Vec<u8>>;
65
66 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67 let this = self.get_mut();
68 if this.done {
69 return Poll::Ready(None);
70 }
71
72 if this.scratch.len() != this.chunk_size {
73 this.scratch.resize(this.chunk_size, 0);
74 }
75
76 let mut read_buf = ReadBuf::new(&mut this.scratch);
77 match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
78 Poll::Pending => Poll::Pending,
79 Poll::Ready(Err(err)) => {
80 this.done = true;
81 Poll::Ready(Some(Err(err)))
82 }
83 Poll::Ready(Ok(())) => {
84 let filled = read_buf.filled();
85 if filled.is_empty() {
86 this.done = true;
87 Poll::Ready(None)
88 } else {
89 Poll::Ready(Some(Ok(filled.to_vec())))
90 }
91 }
92 }
93 }
94}
95
/// Adapts a stream of `io::Result<Vec<u8>>` chunks into an async reader.
///
/// Bytes from each chunk are handed out across as many reads as needed.
/// Errors that arrive after some bytes were already copied in a read are
/// deferred to the following read.
#[derive(Debug)]
pub struct StreamReader<S> {
    /// Source of byte chunks.
    stream: S,
    /// Chunk currently being drained (empty when nothing is buffered).
    current: Vec<u8>,
    /// Index of the next unread byte in `current`.
    offset: usize,
    /// Stream error held back because the read that received it had already copied bytes.
    pending_error: Option<io::Error>,
    /// Set once the stream ended or an error was reported; reads then yield no bytes.
    done: bool,
}

106impl<S> StreamReader<S> {
107 #[must_use]
109 pub fn new(stream: S) -> Self {
110 Self {
111 stream,
112 current: Vec::new(),
113 offset: 0,
114 pending_error: None,
115 done: false,
116 }
117 }
118
119 #[must_use]
121 pub fn get_ref(&self) -> &S {
122 &self.stream
123 }
124
125 pub fn get_mut(&mut self) -> &mut S {
127 &mut self.stream
128 }
129
130 #[must_use]
132 pub fn into_inner(self) -> S {
133 self.stream
134 }
135}
136
137impl<S> AsyncRead for StreamReader<S>
138where
139 S: Stream<Item = io::Result<Vec<u8>>> + Unpin,
140{
141 fn poll_read(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 buf: &mut ReadBuf<'_>,
145 ) -> Poll<io::Result<()>> {
146 if buf.remaining() == 0 {
147 return Poll::Ready(Ok(()));
148 }
149
150 let this = self.get_mut();
151 let filled_before = buf.filled().len();
152 let mut steps = 0;
153
154 loop {
155 if steps > 32 {
156 cx.waker().wake_by_ref();
157 if buf.filled().len() == filled_before {
158 return Poll::Pending;
159 }
160 return Poll::Ready(Ok(()));
161 }
162 steps += 1;
163
164 if this.offset < this.current.len() {
165 if buf.remaining() == 0 {
166 return Poll::Ready(Ok(()));
167 }
168 let remaining = &this.current[this.offset..];
169 let to_copy = remaining.len().min(buf.remaining());
170 buf.put_slice(&remaining[..to_copy]);
171 this.offset += to_copy;
172 if this.offset == this.current.len() {
173 this.current.clear();
174 this.offset = 0;
175 }
176 if buf.remaining() == 0 {
177 return Poll::Ready(Ok(()));
178 }
179 continue;
180 }
181
182 if let Some(err) = this.pending_error.take() {
183 if buf.filled().len() == filled_before {
184 this.done = true;
185 return Poll::Ready(Err(err));
186 }
187 this.pending_error = Some(err);
188 return Poll::Ready(Ok(()));
189 }
190
191 if this.done {
192 return Poll::Ready(Ok(()));
193 }
194
195 match Pin::new(&mut this.stream).poll_next(cx) {
196 Poll::Pending => {
197 if buf.filled().len() == filled_before {
198 return Poll::Pending;
199 }
200 return Poll::Ready(Ok(()));
201 }
202 Poll::Ready(None) => {
203 this.done = true;
204 return Poll::Ready(Ok(()));
205 }
206 Poll::Ready(Some(Ok(chunk))) => {
207 if chunk.is_empty() {
208 continue;
209 }
210 this.current = chunk;
211 this.offset = 0;
212 }
213 Poll::Ready(Some(Err(err))) => {
214 if buf.filled().len() == filled_before {
215 this.done = true;
216 return Poll::Ready(Err(err));
217 }
218 this.pending_error = Some(err);
219 return Poll::Ready(Ok(()));
220 }
221 }
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream;

    use std::task::Waker;

    // Waker that does nothing; tests drive polling by hand.
    fn noop_waker() -> Waker {
        std::task::Waker::noop().clone()
    }

    fn init_test(name: &str) {
        crate::test_utils::init_test_logging();
        crate::test_phase!(name);
    }

    // Polls `reader` once and reports how many bytes were written into `out`.
    fn poll_read<R: AsyncRead + Unpin>(reader: &mut R, out: &mut [u8]) -> Poll<io::Result<usize>> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut read_buf = ReadBuf::new(out);
        match Pin::new(reader).poll_read(&mut cx, &mut read_buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        }
    }

    #[test]
    fn reader_stream_yields_chunks() {
        init_test("reader_stream_yields_chunks");
        let input: &[u8] = b"abcdef";
        let mut stream = ReaderStream::with_capacity(input, 2);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let first = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(first, Poll::Ready(Some(Ok(chunk))) if chunk == b"ab");
        crate::assert_with_log!(ok, "first chunk", true, ok);

        let second = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(second, Poll::Ready(Some(Ok(chunk))) if chunk == b"cd");
        crate::assert_with_log!(ok, "second chunk", true, ok);

        let third = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(third, Poll::Ready(Some(Ok(chunk))) if chunk == b"ef");
        crate::assert_with_log!(ok, "third chunk", true, ok);

        let done = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(done, Poll::Ready(None));
        crate::assert_with_log!(ok, "terminal none", true, ok);
        crate::test_complete!("reader_stream_yields_chunks");
    }

    #[test]
    fn reader_stream_zero_capacity_clamps_to_one() {
        init_test("reader_stream_zero_capacity_clamps_to_one");
        let input: &[u8] = b"xy";
        let mut stream = ReaderStream::with_capacity(input, 0);
        let ok = stream.chunk_size == 1;
        crate::assert_with_log!(ok, "chunk size clamped", 1_usize, stream.chunk_size);

        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let first = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(first, Poll::Ready(Some(Ok(chunk))) if chunk == b"x");
        crate::assert_with_log!(ok, "single-byte chunk", true, ok);
        crate::test_complete!("reader_stream_zero_capacity_clamps_to_one");
    }

    #[test]
    fn stream_reader_reads_across_multiple_chunks() {
        init_test("stream_reader_reads_across_multiple_chunks");
        let chunks = vec![Ok(vec![1_u8, 2]), Ok(vec![3]), Ok(vec![4, 5])];
        let stream = stream::iter(chunks);
        let mut reader = StreamReader::new(stream);

        let mut out = [0_u8; 5];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(5)));
        crate::assert_with_log!(ok, "read length", true, ok);
        crate::assert_with_log!(out == [1, 2, 3, 4, 5], "content", [1, 2, 3, 4, 5], out);

        let mut eof = [0_u8; 4];
        let read = poll_read(&mut reader, &mut eof);
        let ok = matches!(read, Poll::Ready(Ok(0)));
        crate::assert_with_log!(ok, "eof", true, ok);
        crate::test_complete!("stream_reader_reads_across_multiple_chunks");
    }

    #[test]
    fn stream_reader_skips_empty_chunks() {
        init_test("stream_reader_skips_empty_chunks");
        let chunks = vec![Ok(Vec::new()), Ok(vec![1_u8]), Ok(Vec::new()), Ok(vec![2])];
        let stream = stream::iter(chunks);
        let mut reader = StreamReader::new(stream);

        let mut out = [0_u8; 2];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(2)));
        crate::assert_with_log!(ok, "read length", true, ok);
        crate::assert_with_log!(out == [1, 2], "content", [1, 2], out);

        let mut eof = [0_u8; 2];
        let read = poll_read(&mut reader, &mut eof);
        let ok = matches!(read, Poll::Ready(Ok(0)));
        crate::assert_with_log!(ok, "eof", true, ok);
        crate::test_complete!("stream_reader_skips_empty_chunks");
    }

    #[test]
    fn stream_reader_defers_error_until_partial_data_consumed() {
        init_test("stream_reader_defers_error_until_partial_data_consumed");
        let chunks = vec![
            Ok(vec![10_u8, 11]),
            Err(io::Error::new(io::ErrorKind::BrokenPipe, "stream failed")),
        ];
        let stream = stream::iter(chunks);
        let mut reader = StreamReader::new(stream);

        let mut out = [0_u8; 8];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(2)));
        crate::assert_with_log!(ok, "partial read before error", true, ok);
        crate::assert_with_log!(out[..2] == [10, 11], "partial content", [10, 11], &out[..2]);

        let mut second = [0_u8; 8];
        let read = poll_read(&mut reader, &mut second);
        let ok = matches!(read, Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::BrokenPipe);
        crate::assert_with_log!(ok, "error surfaced on next read", true, ok);
        crate::test_complete!("stream_reader_defers_error_until_partial_data_consumed");
    }

    #[test]
    fn stream_reader_error_without_data_then_eof() {
        init_test("stream_reader_error_without_data_then_eof");
        let chunks: Vec<io::Result<Vec<u8>>> =
            vec![Err(io::Error::new(io::ErrorKind::BrokenPipe, "stream failed"))];
        let stream = stream::iter(chunks);
        let mut reader = StreamReader::new(stream);

        let mut out = [0_u8; 4];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::BrokenPipe);
        crate::assert_with_log!(ok, "error surfaced immediately", true, ok);

        let mut after = [0_u8; 4];
        let read = poll_read(&mut reader, &mut after);
        let ok = matches!(read, Poll::Ready(Ok(0)));
        crate::assert_with_log!(ok, "eof after error", true, ok);
        crate::test_complete!("stream_reader_error_without_data_then_eof");
    }

    // Stream that is pending once (after self-waking), then yields one chunk, then ends.
    struct PendingThenDataStream {
        state: u8,
    }

    impl PendingThenDataStream {
        fn new() -> Self {
            Self { state: 0 }
        }
    }

    impl Stream for PendingThenDataStream {
        type Item = io::Result<Vec<u8>>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            match self.state {
                0 => {
                    self.state = 1;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                1 => {
                    self.state = 2;
                    Poll::Ready(Some(Ok(vec![7, 8, 9])))
                }
                _ => Poll::Ready(None),
            }
        }
    }

    #[test]
    fn stream_reader_pending_without_buffered_data() {
        init_test("stream_reader_pending_without_buffered_data");
        let mut reader = StreamReader::new(PendingThenDataStream::new());

        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut out = [0_u8; 3];
        let mut read_buf = ReadBuf::new(&mut out);
        let first = Pin::new(&mut reader).poll_read(&mut cx, &mut read_buf);
        let ok = first.is_pending();
        crate::assert_with_log!(ok, "first poll pending", true, ok);

        let mut out = [0_u8; 3];
        let mut read_buf = ReadBuf::new(&mut out);
        let second = Pin::new(&mut reader).poll_read(&mut cx, &mut read_buf);
        let ok = matches!(second, Poll::Ready(Ok(()))) && read_buf.filled() == [7, 8, 9];
        crate::assert_with_log!(ok, "second poll reads data", true, ok);
        crate::test_complete!("stream_reader_pending_without_buffered_data");
    }
}