rocket_community/data/
data_stream.rs

1use std::io::{self, Cursor};
2use std::path::Path;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::ready;
7use futures::stream::Stream;
8use hyper::body::{Body, Bytes, Incoming as HyperBody};
9use tokio::fs::File;
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf, Take};
11use tokio_util::io::StreamReader;
12
13use crate::data::transform::Transform;
14use crate::data::{Capped, N};
15use crate::util::Chain;
16
17use super::peekable::Peekable;
18use super::transform::TransformBuf;
19
/// Raw data stream of a request body.
///
/// This stream can only be obtained by calling
/// [`Data::open()`](crate::data::Data::open()) with a data limit. The stream
/// contains all of the data in the body of the request.
///
/// Reading from a `DataStream` is accomplished via the various methods on the
/// structure. In general, methods exist in two variants: those that _check_
/// whether the entire stream was read and those that don't. The former either
/// directly or indirectly (via [`Capped`]) return an [`N`] which allows
/// checking if the stream was read to completion while the latter do not.
///
/// | Read Into | Method                               | Notes                            |
/// |-----------|--------------------------------------|----------------------------------|
/// | `String`  | [`DataStream::into_string()`]        | Completeness checked. Preferred. |
/// | `String`  | [`AsyncReadExt::read_to_string()`]   | Unchecked w/existing `String`.   |
/// | `Vec<u8>` | [`DataStream::into_bytes()`]         | Checked. Preferred.              |
/// | `Vec<u8>` | [`DataStream::stream_to(&mut vec)`]  | Checked w/existing `Vec`.        |
/// | `Vec<u8>` | [`DataStream::stream_precise_to()`]  | Unchecked w/existing `Vec`.      |
/// | `File`    | [`DataStream::into_file()`]          | Checked. Preferred.              |
/// | `File`    | [`DataStream::stream_to(&mut file)`] | Checked w/ existing `File`.      |
/// | `File`    | [`DataStream::stream_precise_to()`]  | Unchecked w/ existing `File`.    |
/// | `T`       | [`DataStream::stream_to()`]          | Checked. Any `T: AsyncWrite`.    |
/// | `T`       | [`DataStream::stream_precise_to()`]  | Unchecked. Any `T: AsyncWrite`.  |
///
/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum DataStream<'r> {
    // The limited, pre-buffered reader itself, with no transformer applied.
    #[doc(hidden)]
    Base(BaseReader<'r>),
    // Another `DataStream` wrapped with a transformer; see `TransformReader`.
    #[doc(hidden)]
    Transform(TransformReader<'r>),
}
55
/// A data stream that has a `transformer` applied to it.
pub struct TransformReader<'r> {
    // The transform invoked on data read from `stream`.
    transformer: Pin<Box<dyn Transform + Send + Sync + 'r>>,
    // The wrapped stream; it may itself be another `DataStream::Transform`.
    stream: Pin<Box<DataStream<'r>>>,
    // Set once a completed read from `stream` added no bytes to the buffer.
    inner_done: bool,
}
62
/// Limited, pre-buffered reader to the underlying data stream.
///
/// Structurally, this is a [`Take`] (the limit) around a [`Chain`] of an
/// in-memory [`Cursor`] (the pre-buffered bytes) followed by a [`RawReader`].
pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;
65
/// Direct reader to the underlying data stream. Not limited in any manner.
///
/// This is a [`StreamReader`] adapting a [`RawStream`] of [`Bytes`] chunks.
pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
68
/// Raw underlying data stream.
///
/// Each variant is a different source of body bytes. The [`Stream`]
/// implementation yields the data as `io::Result<Bytes>` chunks.
#[allow(clippy::large_enum_variant)]
pub enum RawStream<'r> {
    /// A stream with no data; polling it immediately yields end-of-stream.
    Empty,
    /// An incoming `hyper` request body.
    Body(HyperBody),
    /// An HTTP/3 QUIC receive stream (only with the `http3-preview` feature).
    #[cfg(feature = "http3-preview")]
    H3Body(crate::listener::Cancellable<crate::listener::quic::QuicRx>),
    /// A single field of a multipart form.
    Multipart(multer::Field<'r>),
}
78
79impl<'r> TransformReader<'r> {
80    /// Returns the underlying `BaseReader`.
81    fn base_mut(&mut self) -> &mut BaseReader<'r> {
82        match self.stream.as_mut().get_mut() {
83            DataStream::Base(base) => base,
84            DataStream::Transform(inner) => inner.base_mut(),
85        }
86    }
87
88    /// Returns the underlying `BaseReader`.
89    fn base(&self) -> &BaseReader<'r> {
90        match self.stream.as_ref().get_ref() {
91            DataStream::Base(base) => base,
92            DataStream::Transform(inner) => inner.base(),
93        }
94    }
95}
96
97impl<'r> DataStream<'r> {
98    pub(crate) fn new(
99        transformers: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
100        Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
101        limit: u64,
102    ) -> Self {
103        let mut stream = DataStream::Base(Chain::new(Cursor::new(buffer), reader).take(limit));
104        for transformer in transformers {
105            stream = DataStream::Transform(TransformReader {
106                transformer,
107                stream: Box::pin(stream),
108                inner_done: false,
109            });
110        }
111
112        stream
113    }
114
115    /// Returns the underlying `BaseReader`.
116    fn base_mut(&mut self) -> &mut BaseReader<'r> {
117        match self {
118            DataStream::Base(base) => base,
119            DataStream::Transform(transform) => transform.base_mut(),
120        }
121    }
122
123    /// Returns the underlying `BaseReader`.
124    fn base(&self) -> &BaseReader<'r> {
125        match self {
126            DataStream::Base(base) => base,
127            DataStream::Transform(transform) => transform.base(),
128        }
129    }
130
131    /// Whether a previous read exhausted the set limit _and then some_.
132    async fn limit_exceeded(&mut self) -> io::Result<bool> {
133        let base = self.base_mut();
134
135        #[cold]
136        async fn _limit_exceeded(base: &mut BaseReader<'_>) -> io::Result<bool> {
137            // Read one more byte after reaching limit to see if we cut early.
138            base.set_limit(1);
139            let mut buf = [0u8; 1];
140            let exceeded = base.read(&mut buf).await? != 0;
141            base.set_limit(0);
142            Ok(exceeded)
143        }
144
145        Ok(base.limit() == 0 && _limit_exceeded(base).await?)
146    }
147
148    /// Number of bytes a full read from `self` will _definitely_ read.
149    ///
150    /// # Example
151    ///
152    /// ```rust
153    /// # extern crate rocket_community as rocket;
154    /// use rocket::data::{Data, ToByteUnit};
155    ///
156    /// async fn f(data: Data<'_>) {
157    ///     let definitely_have_n_bytes = data.open(1.kibibytes()).hint();
158    /// }
159    /// ```
160    pub fn hint(&self) -> usize {
161        let base = self.base();
162        if let (Some(cursor), _) = base.get_ref().get_ref() {
163            let len = cursor.get_ref().len() as u64;
164            let position = cursor.position().min(len);
165            let remaining = len - position;
166            remaining.min(base.limit()) as usize
167        } else {
168            0
169        }
170    }
171
172    /// A helper method to write the body of the request to any `AsyncWrite`
173    /// type. Returns an [`N`] which indicates how many bytes were written and
174    /// whether the entire stream was read. An additional read from `self` may
175    /// be required to check if all of the stream has been read. If that
176    /// information is not needed, use [`DataStream::stream_precise_to()`].
177    ///
178    /// This method is identical to `tokio::io::copy(&mut self, &mut writer)`
179    /// except in that it returns an `N` to check for completeness.
180    ///
181    /// # Example
182    ///
183    /// ```rust
184    /// # extern crate rocket_community as rocket;
185    /// use std::io;
186    /// use rocket::data::{Data, ToByteUnit};
187    ///
188    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
189    ///     // write all of the data to stdout
190    ///     let written = data.open(512.kibibytes())
191    ///         .stream_to(tokio::io::stdout()).await?;
192    ///
193    ///     Ok(format!("Wrote {} bytes.", written))
194    /// }
195    /// ```
196    #[inline(always)]
197    pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
198    where
199        W: AsyncWrite + Unpin,
200    {
201        let written = tokio::io::copy(&mut self, &mut writer).await?;
202        Ok(N {
203            written,
204            complete: !self.limit_exceeded().await?,
205        })
206    }
207
208    /// Like [`DataStream::stream_to()`] except that no end-of-stream check is
209    /// conducted and thus read/write completeness is unknown.
210    ///
211    /// # Example
212    ///
213    /// ```rust
214    /// # extern crate rocket_community as rocket;
215    /// use std::io;
216    /// use rocket::data::{Data, ToByteUnit};
217    ///
218    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
219    ///     // write all of the data to stdout
220    ///     let written = data.open(512.kibibytes())
221    ///         .stream_precise_to(tokio::io::stdout()).await?;
222    ///
223    ///     Ok(format!("Wrote {} bytes.", written))
224    /// }
225    /// ```
226    #[inline(always)]
227    pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
228    where
229        W: AsyncWrite + Unpin,
230    {
231        tokio::io::copy(&mut self, &mut writer).await
232    }
233
234    /// A helper method to write the body of the request to a `Vec<u8>`.
235    ///
236    /// # Example
237    ///
238    /// ```rust
239    /// # extern crate rocket_community as rocket;
240    /// use std::io;
241    /// use rocket::data::{Data, ToByteUnit};
242    ///
243    /// async fn data_guard(data: Data<'_>) -> io::Result<Vec<u8>> {
244    ///     let bytes = data.open(4.kibibytes()).into_bytes().await?;
245    ///     if !bytes.is_complete() {
246    ///         println!("there are bytes remaining in the stream");
247    ///     }
248    ///
249    ///     Ok(bytes.into_inner())
250    /// }
251    /// ```
252    pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
253        let mut vec = Vec::with_capacity(self.hint());
254        let n = self.stream_to(&mut vec).await?;
255        Ok(Capped { value: vec, n })
256    }
257
258    /// A helper method to write the body of the request to a `String`.
259    ///
260    /// # Example
261    ///
262    /// ```rust
263    /// # extern crate rocket_community as rocket;
264    /// use std::io;
265    /// use rocket::data::{Data, ToByteUnit};
266    ///
267    /// async fn data_guard(data: Data<'_>) -> io::Result<String> {
268    ///     let string = data.open(10.bytes()).into_string().await?;
269    ///     if !string.is_complete() {
270    ///         println!("there are bytes remaining in the stream");
271    ///     }
272    ///
273    ///     Ok(string.into_inner())
274    /// }
275    /// ```
276    pub async fn into_string(mut self) -> io::Result<Capped<String>> {
277        let mut string = String::with_capacity(self.hint());
278        let written = self.read_to_string(&mut string).await?;
279        let n = N {
280            written: written as u64,
281            complete: !self.limit_exceeded().await?,
282        };
283        Ok(Capped { value: string, n })
284    }
285
286    /// A helper method to write the body of the request to a file at the path
287    /// determined by `path`. If a file at the path already exists, it is
288    /// overwritten. The opened file is returned.
289    ///
290    /// # Example
291    ///
292    /// ```rust
293    /// # extern crate rocket_community as rocket;
294    /// use std::io;
295    /// use rocket::data::{Data, ToByteUnit};
296    ///
297    /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
298    ///     let file = data.open(1.megabytes()).into_file("/static/file").await?;
299    ///     if !file.is_complete() {
300    ///         println!("there are bytes remaining in the stream");
301    ///     }
302    ///
303    ///     Ok(format!("Wrote {} bytes to /static/file", file.n))
304    /// }
305    /// ```
306    pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
307        let mut file = File::create(path).await?;
308        let n = self
309            .stream_to(&mut tokio::io::BufWriter::new(&mut file))
310            .await?;
311        Ok(Capped { value: file, n })
312    }
313}
314
315impl AsyncRead for DataStream<'_> {
316    fn poll_read(
317        self: Pin<&mut Self>,
318        cx: &mut Context<'_>,
319        buf: &mut ReadBuf<'_>,
320    ) -> Poll<io::Result<()>> {
321        match self.get_mut() {
322            DataStream::Base(inner) => Pin::new(inner).poll_read(cx, buf),
323            DataStream::Transform(inner) => Pin::new(inner).poll_read(cx, buf),
324        }
325    }
326}
327
impl AsyncRead for TransformReader<'_> {
    /// Reads from the wrapped stream into `buf` and passes the newly read
    /// region to the transformer.
    ///
    /// Once the wrapped stream reports no more data, every call is delegated
    /// to the transformer's `poll_finish()` instead.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Fill level on entry: anything past this offset is new in this call.
        let init_fill = buf.filled().len();
        if !self.inner_done {
            ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
            // A completed read that added no bytes is taken as end-of-stream.
            // NOTE(review): a `buf` with no remaining capacity would look the
            // same here — confirm callers never pass one.
            self.inner_done = init_fill == buf.filled().len();
        }

        if self.inner_done {
            // No more inner data: let the transformer emit any final output.
            return self.transformer.as_mut().poll_finish(cx, buf);
        }

        // Give the transformer a view of `buf` whose cursor marks where the
        // data read in this call begins.
        let mut tbuf = TransformBuf {
            buf,
            cursor: init_fill,
        };
        self.transformer.as_mut().transform(&mut tbuf)?;
        if buf.filled().len() == init_fill {
            // The transform left no new bytes in `buf` (presumably it consumed
            // or withheld them). Returning `Ready` with nothing added would
            // read as EOF to the caller, so schedule an immediate re-poll.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        Poll::Ready(Ok(()))
    }
}
357
358impl Stream for RawStream<'_> {
359    type Item = io::Result<Bytes>;
360
361    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362        match self.get_mut() {
363            // TODO: Expose trailer headers, somehow.
364            RawStream::Body(body) => Pin::new(body)
365                .poll_frame(cx)
366                .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
367                .map_err(io::Error::other),
368            #[cfg(feature = "http3-preview")]
369            RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx),
370            RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
371            RawStream::Empty => Poll::Ready(None),
372        }
373    }
374
375    fn size_hint(&self) -> (usize, Option<usize>) {
376        match self {
377            RawStream::Body(body) => {
378                let hint = body.size_hint();
379                let (lower, upper) = (hint.lower(), hint.upper());
380                (lower as usize, upper.map(|x| x as usize))
381            }
382            #[cfg(feature = "http3-preview")]
383            RawStream::H3Body(_) => (0, Some(0)),
384            RawStream::Multipart(mp) => mp.size_hint(),
385            RawStream::Empty => (0, Some(0)),
386        }
387    }
388}
389
390impl std::fmt::Display for RawStream<'_> {
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        match self {
393            RawStream::Empty => f.write_str("empty stream"),
394            RawStream::Body(_) => f.write_str("request body"),
395            #[cfg(feature = "http3-preview")]
396            RawStream::H3Body(_) => f.write_str("http3 quic stream"),
397            RawStream::Multipart(_) => f.write_str("multipart form field"),
398        }
399    }
400}
401
402impl<'r> From<HyperBody> for RawStream<'r> {
403    fn from(value: HyperBody) -> Self {
404        Self::Body(value)
405    }
406}
407
#[cfg(feature = "http3-preview")]
impl From<crate::listener::Cancellable<crate::listener::quic::QuicRx>> for RawStream<'_> {
    /// Wraps a cancellable QUIC receive stream as a [`RawStream::H3Body`].
    fn from(value: crate::listener::Cancellable<crate::listener::quic::QuicRx>) -> Self {
        RawStream::H3Body(value)
    }
}
414
415impl<'r> From<multer::Field<'r>> for RawStream<'r> {
416    fn from(value: multer::Field<'r>) -> Self {
417        Self::Multipart(value)
418    }
419}