Skip to main content

bytesbuf_io/
read_futures.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use std::any::type_name;
5use std::fmt::{self, Debug};
6use std::marker::PhantomPinned;
7use std::pin::Pin;
8use std::task::Poll::{Pending, Ready};
9use std::{mem, task};
10
11use bytesbuf::{BytesBuf, BytesView};
12
13use crate::Read;
14
15/// Adapts a [`Read`] implementation to the `futures::Stream` API.
16///
17/// Each item in the `futures::Stream` is a [`BytesView`] containing some bytes read
18/// from the underlying [`Read`].
19///
20/// # Security
21///
22/// **This adapter is insecure if the side producing the bytes is not trusted**. An attacker
23/// may trickle data byte-by-byte, consuming a large amount of resources.
24///
25/// Robust code working with untrusted sources should take precautions such as only processing
26/// read data when either a time or length threshold is reached and reusing buffers that
27/// have remaining capacity, appending additional data to existing buffers using
28/// [`read_more_into()`][crate::Read::read_more_into] instead of reserving new buffers
29/// for each read operation.
30pub struct ReadAsFuturesStream<S>
31where
32    S: Read + Debug,
33{
34    // References `inner`. Must be defined above `inner` to ensure it gets dropped first.`
35    #[expect(clippy::type_complexity, reason = "never needs to be named, good enough")]
36    active_read: Option<Pin<Box<dyn Future<Output = Result<BytesBuf, S::Error>>>>>,
37
38    // Safety invariant: we can only touch this field if `active_read` is `None`.
39    inner: S,
40
41    // This struct must remain pinned because `active_read` references `inner`.
42    _require_pin: PhantomPinned,
43}
44
45impl<S> ReadAsFuturesStream<S>
46where
47    S: Read + Debug,
48{
49    pub(crate) fn new(inner: S) -> Pin<Box<Self>> {
50        Box::pin(Self {
51            active_read: None,
52            inner,
53            _require_pin: PhantomPinned,
54        })
55    }
56
57    /// Abandons any ongoing read operation and returns the source.
58    #[must_use]
59    pub fn into_inner(self: Pin<Box<Self>>) -> S {
60        // SAFETY: We are going to unpin `self` by first dropping `active_read`, which is the thing
61        // that references `inner` and requires pinning. Once `active_read` has been dropped,
62        // no more pinning requirements exist.
63        let mut unpinned = unsafe { Pin::into_inner_unchecked(self) };
64
65        unpinned.active_read = None;
66        unpinned.inner
67    }
68}
69
70impl<S> futures_core::Stream for ReadAsFuturesStream<S>
71where
72    S: Read + Debug,
73{
74    type Item = Result<BytesView, S::Error>;
75
76    fn poll_next<'a>(self: Pin<&'a mut Self>, cx: &'a mut task::Context) -> task::Poll<Option<Self::Item>> {
77        // SAFETY: We are not moving `inner`, which is the field that must remain pinned.
78        let this = unsafe { self.get_unchecked_mut() };
79
80        let mut active_read = if let Some(active_read) = this.active_read.take() {
81            active_read
82        } else {
83            let inner = &mut this.inner;
84            let future = async move { inner.read_any().await };
85            let boxed_future = Box::pin(future);
86
87            // SAFETY: We overwrite the lifetime of the future to 'static because in reality we
88            // have a lifetime bounded to the lifetime of the struct itself but this cannot be
89            // meaningfully expressed in Rust, so we have to expand it to 'static. For safety, we
90            // have to ensure that the future does not outlive either the struct instance itself
91            // or `inner`, and that we do not touch `inner` while the future exists.
92            unsafe {
93                mem::transmute::<
94                    Pin<Box<dyn Future<Output = Result<BytesBuf, S::Error>> + 'a>>,
95                    Pin<Box<dyn Future<Output = Result<BytesBuf, S::Error>>>>,
96                >(boxed_future)
97            }
98        };
99
100        let result = active_read.as_mut().poll(cx);
101
102        match result {
103            Ready(Ok(mut buf)) => {
104                let data = buf.consume_all();
105
106                if data.is_empty() {
107                    // We have reached the end of the stream.
108                    return Ready(None);
109                }
110
111                Ready(Some(Ok(data)))
112            }
113            Ready(Err(e)) => Ready(Some(Err(e))),
114            Pending => {
115                this.active_read = Some(active_read);
116                Pending
117            }
118        }
119    }
120}
121
122impl<S> Debug for ReadAsFuturesStream<S>
123where
124    S: Read + Debug,
125{
126    #[cfg_attr(coverage_nightly, coverage(off))] // No API contract to test.
127    #[cfg_attr(test, mutants::skip)] // We have no contract to test.
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        f.debug_struct(type_name::<Self>())
130            .field("inner", &self.inner)
131            .field("active_read.is_some()", &self.active_read.is_some())
132            .finish_non_exhaustive()
133    }
134}
135
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::convert::Infallible;
    use std::pin::pin;
    use std::task::Waker;

    use bytesbuf::mem::testing::TransparentMemory;
    use bytesbuf::mem::{GlobalPool, HasMemory, Memory, MemoryShared};
    use futures::{Stream, StreamExt};
    use new_zealand::nz;
    use testing_aids::{YieldFuture, async_test};

    use super::*;
    use crate::ReadExt;
    use crate::testing::{FakeRead, Pending};

    #[test]
    fn smoke_test() {
        async_test(async || {
            let memory = GlobalPool::new();
            let contents = BytesView::copied_from_slice(b"Hello, w", &memory);
            let inner = FakeRead::builder().contents(contents).max_read_size(nz!(2)).build();

            let mut futures_stream = inner.into_futures_stream();

            // It can be read from.
            let mut payload1 = futures_stream.next().await.unwrap().unwrap();
            assert_eq!(payload1.len(), 2);
            assert_eq!(payload1.get_byte(), b'H');
            assert_eq!(payload1.get_byte(), b'e');

            let mut payload2 = futures_stream.next().await.unwrap().unwrap();
            assert_eq!(payload2.len(), 2);
            assert_eq!(payload2.get_byte(), b'l');
            assert_eq!(payload2.get_byte(), b'l');

            // We can get back the original.
            let mut original = futures_stream.into_inner();

            let mut payload3 = original.read_exactly(2).await.unwrap();
            assert_eq!(payload3.len(), 2);
            assert_eq!(payload3.get_byte(), b'o');
            assert_eq!(payload3.get_byte(), b',');

            // Back to the futures::Stream!
            let mut futures_stream = original.into_futures_stream();

            let mut payload4 = futures_stream.next().await.unwrap().unwrap();
            assert_eq!(payload4.len(), 2);
            assert_eq!(payload4.get_byte(), b' ');
            assert_eq!(payload4.get_byte(), b'w');

            // And once we hit end of stream, it needs to return `None`.
            assert!(futures_stream.next().await.is_none());
        });
    }

    #[test]
    fn pending_read_cancelled_on_into_inner() {
        let inner = Pending::new();

        let mut futures_stream = inner.into_futures_stream();

        // We can cancel a pending read. This must be tested directly against the `Stream` impl,
        // because the `futures::Stream` extension methods are lazy and would not start the read.
        let mut cx = task::Context::from_waker(Waker::noop());
        assert!(matches!(futures_stream.as_mut().poll_next(&mut cx), task::Poll::Pending));

        // The inner source can never complete a read, so there is little to observe here.
        // The test passes as long as there is no panic and Miri reports no UB.
        let mut inner = futures_stream.into_inner();

        let read_future = pin!(inner.read_any());
        assert!(read_future.poll(&mut cx).is_pending());
    }

    /// A Read implementation that yields on first poll then returns data.
    /// This is used to test that `ReadAsFuturesStream` correctly handles `Poll::Pending`.
    #[derive(Debug)]
    struct YieldThenRead {
        inner: FakeRead,
    }

    impl Memory for YieldThenRead {
        fn reserve(&self, min_bytes: usize) -> BytesBuf {
            self.inner.reserve(min_bytes)
        }
    }

    impl HasMemory for YieldThenRead {
        fn memory(&self) -> impl MemoryShared {
            self.inner.memory()
        }
    }

    impl crate::Read for YieldThenRead {
        type Error = Infallible;

        async fn read_at_most_into(&mut self, len: usize, into: BytesBuf) -> Result<(usize, BytesBuf), Self::Error> {
            YieldFuture::default().await;
            self.inner.read_at_most_into(len, into).await
        }

        async fn read_more_into(&mut self, into: BytesBuf) -> Result<(usize, BytesBuf), Self::Error> {
            YieldFuture::default().await;
            self.inner.read_more_into(into).await
        }

        async fn read_any(&mut self) -> Result<BytesBuf, Self::Error> {
            YieldFuture::default().await;
            self.inner.read_any().await
        }
    }

    #[test]
    fn pending_on_first_poll_then_returns_result() {
        async_test(async || {
            let memory = GlobalPool::new();
            let contents = BytesView::copied_from_slice(b"Hello", &memory);
            let inner = YieldThenRead {
                inner: FakeRead::builder().contents(contents).build(),
            };

            let mut futures_stream = ReadAsFuturesStream::new(inner);

            // First poll should be Pending due to YieldFuture
            let waker = Waker::noop();
            let mut cx = task::Context::from_waker(waker);
            let poll_result = futures_stream.as_mut().poll_next(&mut cx);
            assert!(matches!(poll_result, task::Poll::Pending));

            // Second poll should return the actual data
            let poll_result = futures_stream.as_mut().poll_next(&mut cx);
            let task::Poll::Ready(Some(Ok(mut data))) = poll_result else {
                panic!("Expected Ready(Some(Ok(_)))");
            };

            assert_eq!(data.len(), 5);
            assert_eq!(data.get_byte(), b'H');
            assert_eq!(data.get_byte(), b'e');
            assert_eq!(data.get_byte(), b'l');
            assert_eq!(data.get_byte(), b'l');
            assert_eq!(data.get_byte(), b'o');
        });
    }

    /// A Read implementation that always returns an error.
    #[derive(Debug)]
    struct ErroringRead {
        memory: TransparentMemory,
    }

    impl Default for ErroringRead {
        fn default() -> Self {
            Self {
                memory: TransparentMemory::new(),
            }
        }
    }

    #[derive(Debug)]
    struct TestError(String);

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(&self.0)
        }
    }

    impl std::error::Error for TestError {}

    impl Memory for ErroringRead {
        fn reserve(&self, min_bytes: usize) -> BytesBuf {
            self.memory.reserve(min_bytes)
        }
    }

    impl HasMemory for ErroringRead {
        fn memory(&self) -> impl MemoryShared {
            self.memory.clone()
        }
    }

    impl crate::Read for ErroringRead {
        type Error = TestError;

        async fn read_at_most_into(&mut self, _len: usize, _into: BytesBuf) -> Result<(usize, BytesBuf), Self::Error> {
            Err(TestError("read_at_most_into error".to_string()))
        }

        async fn read_more_into(&mut self, _into: BytesBuf) -> Result<(usize, BytesBuf), Self::Error> {
            Err(TestError("read_more_into error".to_string()))
        }

        async fn read_any(&mut self) -> Result<BytesBuf, Self::Error> {
            Err(TestError("read_any error".to_string()))
        }
    }

    #[test]
    fn passes_through_error_from_inner() {
        async_test(async || {
            let inner = ErroringRead::default();
            let mut futures_stream = ReadAsFuturesStream::new(inner);

            let Some(Err(TestError(msg))) = futures_stream.next().await else {
                panic!("Expected Some(Err(TestError(_)))");
            };

            assert_eq!(msg, "read_any error");
        });
    }
}