Skip to main content

asupersync/io/
stream_adapters.rs

//! Stream/AsyncRead bridge adapters.
//!
//! These adapters cover the common bridge patterns used by middleware and
//! protocol glue:
//!
//! - [`ReaderStream`]: `AsyncRead` -> `Stream<Item = io::Result<Vec<u8>>>`
//! - [`StreamReader`]: `Stream<Item = io::Result<Vec<u8>>>` -> `AsyncRead`
8
9use super::{AsyncRead, ReadBuf};
10use crate::stream::Stream;
11use std::io;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
/// Chunk size used by `ReaderStream::new`: 8 KiB per read.
const DEFAULT_CHUNK_SIZE: usize = 8_192;
16
/// Adapts an [`AsyncRead`] into a stream of byte chunks.
///
/// Each successful read of up to `chunk_size` bytes is yielded as an owned
/// `Vec<u8>`; a read that fills zero bytes ends the stream.
#[derive(Debug)]
pub struct ReaderStream<R> {
    /// Reader that bytes are pulled from.
    reader: R,
    /// Maximum number of bytes requested per read.
    chunk_size: usize,
    /// Set once end-of-file or an error has been reported; the stream yields
    /// nothing further afterwards.
    done: bool,
    /// Read buffer reused across polls; its contents are copied out into
    /// each yielded chunk.
    scratch: Vec<u8>,
}
25
26impl<R> ReaderStream<R> {
27    /// Creates a new `ReaderStream` with the default chunk size (8 KiB).
28    #[must_use]
29    pub fn new(reader: R) -> Self {
30        Self::with_capacity(reader, DEFAULT_CHUNK_SIZE)
31    }
32
33    /// Creates a new `ReaderStream` with a custom chunk size.
34    #[must_use]
35    pub fn with_capacity(reader: R, chunk_size: usize) -> Self {
36        let chunk_size = chunk_size.max(1);
37        Self {
38            reader,
39            chunk_size,
40            done: false,
41            scratch: vec![0; chunk_size],
42        }
43    }
44
45    /// Returns a reference to the inner reader.
46    #[must_use]
47    pub fn get_ref(&self) -> &R {
48        &self.reader
49    }
50
51    /// Returns a mutable reference to the inner reader.
52    pub fn get_mut(&mut self) -> &mut R {
53        &mut self.reader
54    }
55
56    /// Consumes the adapter and returns the inner reader.
57    #[must_use]
58    pub fn into_inner(self) -> R {
59        self.reader
60    }
61}
62
63impl<R: AsyncRead + Unpin> Stream for ReaderStream<R> {
64    type Item = io::Result<Vec<u8>>;
65
66    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67        let this = self.get_mut();
68        if this.done {
69            return Poll::Ready(None);
70        }
71
72        if this.scratch.len() != this.chunk_size {
73            this.scratch.resize(this.chunk_size, 0);
74        }
75
76        let mut read_buf = ReadBuf::new(&mut this.scratch);
77        match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
78            Poll::Pending => Poll::Pending,
79            Poll::Ready(Err(err)) => {
80                this.done = true;
81                Poll::Ready(Some(Err(err)))
82            }
83            Poll::Ready(Ok(())) => {
84                let filled = read_buf.filled();
85                if filled.is_empty() {
86                    this.done = true;
87                    Poll::Ready(None)
88                } else {
89                    Poll::Ready(Some(Ok(filled.to_vec())))
90                }
91            }
92        }
93    }
94}
95
/// Adapts a byte stream into an [`AsyncRead`] implementation.
///
/// Chunks pulled from the stream are buffered and copied out across as many
/// reads as needed. A stream error that arrives after some bytes were copied
/// in the same read is held back and reported by the next read.
#[derive(Debug)]
pub struct StreamReader<S> {
    /// Source of byte chunks.
    stream: S,
    /// Chunk currently being drained into callers' buffers.
    current: Vec<u8>,
    /// Number of bytes of `current` already handed out.
    offset: usize,
    /// Stream error deferred because it arrived after a partial read.
    pending_error: Option<io::Error>,
    /// Set once the stream has ended or an error has been reported.
    done: bool,
}
105
106impl<S> StreamReader<S> {
107    /// Creates a new `StreamReader`.
108    #[must_use]
109    pub fn new(stream: S) -> Self {
110        Self {
111            stream,
112            current: Vec::new(),
113            offset: 0,
114            pending_error: None,
115            done: false,
116        }
117    }
118
119    /// Returns a reference to the inner stream.
120    #[must_use]
121    pub fn get_ref(&self) -> &S {
122        &self.stream
123    }
124
125    /// Returns a mutable reference to the inner stream.
126    pub fn get_mut(&mut self) -> &mut S {
127        &mut self.stream
128    }
129
130    /// Consumes the adapter and returns the inner stream.
131    #[must_use]
132    pub fn into_inner(self) -> S {
133        self.stream
134    }
135}
136
137impl<S> AsyncRead for StreamReader<S>
138where
139    S: Stream<Item = io::Result<Vec<u8>>> + Unpin,
140{
141    fn poll_read(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        buf: &mut ReadBuf<'_>,
145    ) -> Poll<io::Result<()>> {
146        if buf.remaining() == 0 {
147            return Poll::Ready(Ok(()));
148        }
149
150        let this = self.get_mut();
151        let filled_before = buf.filled().len();
152        let mut steps = 0;
153
154        loop {
155            if steps > 32 {
156                cx.waker().wake_by_ref();
157                if buf.filled().len() == filled_before {
158                    return Poll::Pending;
159                }
160                return Poll::Ready(Ok(()));
161            }
162            steps += 1;
163
164            if this.offset < this.current.len() {
165                if buf.remaining() == 0 {
166                    return Poll::Ready(Ok(()));
167                }
168                let remaining = &this.current[this.offset..];
169                let to_copy = remaining.len().min(buf.remaining());
170                buf.put_slice(&remaining[..to_copy]);
171                this.offset += to_copy;
172                if this.offset == this.current.len() {
173                    this.current.clear();
174                    this.offset = 0;
175                }
176                if buf.remaining() == 0 {
177                    return Poll::Ready(Ok(()));
178                }
179                continue;
180            }
181
182            if let Some(err) = this.pending_error.take() {
183                if buf.filled().len() == filled_before {
184                    this.done = true;
185                    return Poll::Ready(Err(err));
186                }
187                this.pending_error = Some(err);
188                return Poll::Ready(Ok(()));
189            }
190
191            if this.done {
192                return Poll::Ready(Ok(()));
193            }
194
195            match Pin::new(&mut this.stream).poll_next(cx) {
196                Poll::Pending => {
197                    if buf.filled().len() == filled_before {
198                        return Poll::Pending;
199                    }
200                    return Poll::Ready(Ok(()));
201                }
202                Poll::Ready(None) => {
203                    this.done = true;
204                    return Poll::Ready(Ok(()));
205                }
206                Poll::Ready(Some(Ok(chunk))) => {
207                    if chunk.is_empty() {
208                        continue;
209                    }
210                    this.current = chunk;
211                    this.offset = 0;
212                }
213                Poll::Ready(Some(Err(err))) => {
214                    if buf.filled().len() == filled_before {
215                        this.done = true;
216                        return Poll::Ready(Err(err));
217                    }
218                    this.pending_error = Some(err);
219                    return Poll::Ready(Ok(()));
220                }
221            }
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stream;

    use std::task::Waker;

    fn noop_waker() -> Waker {
        std::task::Waker::noop().clone()
    }

    fn init_test(name: &str) {
        crate::test_utils::init_test_logging();
        crate::test_phase!(name);
    }

    /// Polls `reader` once with a no-op waker and reports how many bytes were
    /// written into `out`.
    fn poll_read<R: AsyncRead + Unpin>(reader: &mut R, out: &mut [u8]) -> Poll<io::Result<usize>> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut read_buf = ReadBuf::new(out);
        match Pin::new(reader).poll_read(&mut cx, &mut read_buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        }
    }

    /// Reader whose every read fails with `ConnectionReset`.
    struct FailingReader;

    impl AsyncRead for FailingReader {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            Poll::Ready(Err(io::Error::new(
                io::ErrorKind::ConnectionReset,
                "reader failed",
            )))
        }
    }

    #[test]
    fn reader_stream_yields_chunks() {
        init_test("reader_stream_yields_chunks");
        let input: &[u8] = b"abcdef";
        let mut stream = ReaderStream::with_capacity(input, 2);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let first = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(first, Poll::Ready(Some(Ok(chunk))) if chunk == b"ab");
        crate::assert_with_log!(ok, "first chunk", true, ok);

        let second = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(second, Poll::Ready(Some(Ok(chunk))) if chunk == b"cd");
        crate::assert_with_log!(ok, "second chunk", true, ok);

        let third = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(third, Poll::Ready(Some(Ok(chunk))) if chunk == b"ef");
        crate::assert_with_log!(ok, "third chunk", true, ok);

        let done = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(done, Poll::Ready(None));
        crate::assert_with_log!(ok, "terminal none", true, ok);
        crate::test_complete!("reader_stream_yields_chunks");
    }

    #[test]
    fn reader_stream_clamps_zero_chunk_size() {
        init_test("reader_stream_clamps_zero_chunk_size");
        let input: &[u8] = b"xy";
        let mut stream = ReaderStream::with_capacity(input, 0);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let first = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(first, Poll::Ready(Some(Ok(chunk))) if chunk == b"x");
        crate::assert_with_log!(ok, "first one-byte chunk", true, ok);

        let second = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(second, Poll::Ready(Some(Ok(chunk))) if chunk == b"y");
        crate::assert_with_log!(ok, "second one-byte chunk", true, ok);

        let done = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(done, Poll::Ready(None));
        crate::assert_with_log!(ok, "terminal none", true, ok);
        crate::test_complete!("reader_stream_clamps_zero_chunk_size");
    }

    #[test]
    fn reader_stream_fuses_after_error() {
        init_test("reader_stream_fuses_after_error");
        let mut stream = ReaderStream::new(FailingReader);
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let first = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(
            first,
            Poll::Ready(Some(Err(err))) if err.kind() == io::ErrorKind::ConnectionReset
        );
        crate::assert_with_log!(ok, "error surfaced", true, ok);

        let second = Pin::new(&mut stream).poll_next(&mut cx);
        let ok = matches!(second, Poll::Ready(None));
        crate::assert_with_log!(ok, "stream fused after error", true, ok);
        crate::test_complete!("reader_stream_fuses_after_error");
    }

    #[test]
    fn stream_reader_reads_across_multiple_chunks() {
        init_test("stream_reader_reads_across_multiple_chunks");
        let chunks = vec![Ok(vec![1_u8, 2]), Ok(vec![3]), Ok(vec![4, 5])];
        let stream = stream::iter(chunks);
        let mut reader = StreamReader::new(stream);

        let mut out = [0_u8; 5];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(5)));
        crate::assert_with_log!(ok, "read length", true, ok);
        crate::assert_with_log!(out == [1, 2, 3, 4, 5], "content", [1, 2, 3, 4, 5], out);

        let mut eof = [0_u8; 4];
        let read = poll_read(&mut reader, &mut eof);
        let ok = matches!(read, Poll::Ready(Ok(0)));
        crate::assert_with_log!(ok, "eof", true, ok);
        crate::test_complete!("stream_reader_reads_across_multiple_chunks");
    }

    #[test]
    fn stream_reader_skips_empty_chunks() {
        init_test("stream_reader_skips_empty_chunks");
        let chunks: Vec<io::Result<Vec<u8>>> =
            vec![Ok(vec![]), Ok(vec![1_u8]), Ok(vec![]), Ok(vec![2, 3])];
        let mut reader = StreamReader::new(stream::iter(chunks));

        let mut out = [0_u8; 8];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(3)));
        crate::assert_with_log!(ok, "read length", true, ok);
        crate::assert_with_log!(out[..3] == [1, 2, 3], "content", [1, 2, 3], &out[..3]);
        crate::test_complete!("stream_reader_skips_empty_chunks");
    }

    #[test]
    fn stream_reader_splits_chunk_across_small_reads() {
        init_test("stream_reader_splits_chunk_across_small_reads");
        let chunks: Vec<io::Result<Vec<u8>>> = vec![Ok(vec![1_u8, 2, 3, 4, 5])];
        let mut reader = StreamReader::new(stream::iter(chunks));

        let mut out = [0_u8; 2];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(2))) && out == [1, 2];
        crate::assert_with_log!(ok, "first slice", true, ok);

        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(2))) && out == [3, 4];
        crate::assert_with_log!(ok, "second slice", true, ok);

        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(1))) && out[0] == 5;
        crate::assert_with_log!(ok, "tail slice", true, ok);

        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(0)));
        crate::assert_with_log!(ok, "eof", true, ok);
        crate::test_complete!("stream_reader_splits_chunk_across_small_reads");
    }

    #[test]
    fn stream_reader_defers_error_until_partial_data_consumed() {
        init_test("stream_reader_defers_error_until_partial_data_consumed");
        let chunks = vec![
            Ok(vec![10_u8, 11]),
            Err(io::Error::new(io::ErrorKind::BrokenPipe, "stream failed")),
        ];
        let stream = stream::iter(chunks);
        let mut reader = StreamReader::new(stream);

        let mut out = [0_u8; 8];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(2)));
        crate::assert_with_log!(ok, "partial read before error", true, ok);
        crate::assert_with_log!(out[..2] == [10, 11], "partial content", [10, 11], &out[..2]);

        let mut second = [0_u8; 8];
        let read = poll_read(&mut reader, &mut second);
        let ok = matches!(read, Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::BrokenPipe);
        crate::assert_with_log!(ok, "error surfaced on next read", true, ok);
        crate::test_complete!("stream_reader_defers_error_until_partial_data_consumed");
    }

    #[test]
    fn stream_reader_reports_eof_after_error() {
        init_test("stream_reader_reports_eof_after_error");
        let chunks: Vec<io::Result<Vec<u8>>> =
            vec![Err(io::Error::new(io::ErrorKind::BrokenPipe, "stream failed"))];
        let mut reader = StreamReader::new(stream::iter(chunks));

        let mut out = [0_u8; 4];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::BrokenPipe);
        crate::assert_with_log!(ok, "error surfaced immediately", true, ok);

        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(0)));
        crate::assert_with_log!(ok, "eof after error", true, ok);
        crate::test_complete!("stream_reader_reports_eof_after_error");
    }

    #[test]
    fn stream_reader_empty_buffer_consumes_nothing() {
        init_test("stream_reader_empty_buffer_consumes_nothing");
        let chunks: Vec<io::Result<Vec<u8>>> = vec![Ok(vec![4_u8, 2])];
        let mut reader = StreamReader::new(stream::iter(chunks));

        let mut empty: [u8; 0] = [];
        let read = poll_read(&mut reader, &mut empty);
        let ok = matches!(read, Poll::Ready(Ok(0)));
        crate::assert_with_log!(ok, "zero-length read", true, ok);

        let mut out = [0_u8; 2];
        let read = poll_read(&mut reader, &mut out);
        let ok = matches!(read, Poll::Ready(Ok(2))) && out == [4, 2];
        crate::assert_with_log!(ok, "data still available", true, ok);
        crate::test_complete!("stream_reader_empty_buffer_consumes_nothing");
    }

    #[test]
    fn stream_reader_yields_after_long_run_of_empty_chunks() {
        init_test("stream_reader_yields_after_long_run_of_empty_chunks");
        let mut chunks: Vec<io::Result<Vec<u8>>> = (0..40).map(|_| Ok(Vec::new())).collect();
        chunks.push(Ok(vec![9_u8]));
        let mut reader = StreamReader::new(stream::iter(chunks));

        let mut out = [0_u8; 4];
        let first = poll_read(&mut reader, &mut out);
        let ok = first.is_pending();
        crate::assert_with_log!(ok, "budget exhausted yields pending", true, ok);

        let second = poll_read(&mut reader, &mut out);
        let ok = matches!(second, Poll::Ready(Ok(1))) && out[0] == 9;
        crate::assert_with_log!(ok, "data read after yielding", true, ok);
        crate::test_complete!("stream_reader_yields_after_long_run_of_empty_chunks");
    }

    struct PendingThenDataStream {
        state: u8,
    }

    impl PendingThenDataStream {
        fn new() -> Self {
            Self { state: 0 }
        }
    }

    impl Stream for PendingThenDataStream {
        type Item = io::Result<Vec<u8>>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            match self.state {
                0 => {
                    self.state = 1;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                1 => {
                    self.state = 2;
                    Poll::Ready(Some(Ok(vec![7, 8, 9])))
                }
                _ => Poll::Ready(None),
            }
        }
    }

    #[test]
    fn stream_reader_pending_without_buffered_data() {
        init_test("stream_reader_pending_without_buffered_data");
        let mut reader = StreamReader::new(PendingThenDataStream::new());

        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        let mut out = [0_u8; 3];
        let mut read_buf = ReadBuf::new(&mut out);
        let first = Pin::new(&mut reader).poll_read(&mut cx, &mut read_buf);
        let ok = first.is_pending();
        crate::assert_with_log!(ok, "first poll pending", true, ok);

        let mut out = [0_u8; 3];
        let mut read_buf = ReadBuf::new(&mut out);
        let second = Pin::new(&mut reader).poll_read(&mut cx, &mut read_buf);
        let ok = matches!(second, Poll::Ready(Ok(()))) && read_buf.filled() == [7, 8, 9];
        crate::assert_with_log!(ok, "second poll reads data", true, ok);
        crate::test_complete!("stream_reader_pending_without_buffered_data");
    }
}