rocket_community/data/data_stream.rs
1use std::io::{self, Cursor};
2use std::path::Path;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures::ready;
7use futures::stream::Stream;
8use hyper::body::{Body, Bytes, Incoming as HyperBody};
9use tokio::fs::File;
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf, Take};
11use tokio_util::io::StreamReader;
12
13use crate::data::transform::Transform;
14use crate::data::{Capped, N};
15use crate::util::Chain;
16
17use super::peekable::Peekable;
18use super::transform::TransformBuf;
19
/// Raw data stream of a request body.
///
/// This stream can only be obtained by calling
/// [`Data::open()`](crate::data::Data::open()) with a data limit. The stream
/// contains all of the data in the body of the request.
///
/// Reading from a `DataStream` is accomplished via the various methods on the
/// structure. In general, methods exist in two variants: those that _check_
/// whether the entire stream was read and those that don't. The former either
/// directly or indirectly (via [`Capped`]) return an [`N`] which allows
/// checking if the stream was read to completion, while the latter do not.
///
/// | Read Into | Method | Notes |
/// |-----------|--------------------------------------|----------------------------------|
/// | `String` | [`DataStream::into_string()`] | Completeness checked. Preferred. |
/// | `String` | [`AsyncReadExt::read_to_string()`] | Unchecked w/existing `String`. |
/// | `Vec<u8>` | [`DataStream::into_bytes()`] | Checked. Preferred. |
/// | `Vec<u8>` | [`DataStream::stream_to(&mut vec)`] | Checked w/existing `Vec`. |
/// | `Vec<u8>` | [`DataStream::stream_precise_to()`] | Unchecked w/existing `Vec`. |
/// | `File` | [`DataStream::into_file()`] | Checked. Preferred. |
/// | `File` | [`DataStream::stream_to(&mut file)`] | Checked w/ existing `File`. |
/// | `File` | [`DataStream::stream_precise_to()`] | Unchecked w/ existing `File`. |
/// | `T` | [`DataStream::stream_to()`] | Checked. Any `T: AsyncWrite`. |
/// | `T` | [`DataStream::stream_precise_to()`] | Unchecked. Any `T: AsyncWrite`. |
///
/// [`DataStream::stream_to(&mut vec)`]: DataStream::stream_to()
/// [`DataStream::stream_to(&mut file)`]: DataStream::stream_to()
// `Base` stores the whole limited reader inline, while `Transform` holds only
// two boxed pointers and a flag, so the variant sizes differ widely.
// NOTE(review): presumably `Base` is left unboxed to avoid an extra heap
// allocation for the common, untransformed case — confirm.
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum DataStream<'r> {
    #[doc(hidden)]
    Base(BaseReader<'r>),
    #[doc(hidden)]
    Transform(TransformReader<'r>),
}
55
/// A data stream that has a `transformer` applied to it.
///
/// Bytes are first read from the wrapped `stream` into the caller's buffer and
/// then passed through `transformer` in place before the read completes.
pub struct TransformReader<'r> {
    /// The transform run over every chunk read from `stream`.
    transformer: Pin<Box<dyn Transform + Send + Sync + 'r>>,
    /// The wrapped stream: either the base reader or another transform layer.
    stream: Pin<Box<DataStream<'r>>>,
    /// Set once a read from `stream` produces no bytes. From then on, reads
    /// are served by the transformer's `poll_finish` instead of `transform`.
    inner_done: bool,
}
62
/// Limited, pre-buffered reader to the underlying data stream.
///
/// Chains an in-memory `Cursor` buffer ahead of the [`RawReader`], and caps the
/// total number of bytes read from the combination with `Take`.
pub type BaseReader<'r> = Take<Chain<Cursor<Vec<u8>>, RawReader<'r>>>;

/// Direct reader to the underlying data stream. Not limited in any manner.
///
/// Adapts the chunked [`RawStream`] into an `AsyncRead` via `StreamReader`.
pub type RawReader<'r> = StreamReader<RawStream<'r>, Bytes>;
68
/// Raw underlying data stream.
///
/// Yields chunks of [`Bytes`] through its `Stream` implementation; wrapped by
/// [`RawReader`] to be read as an `AsyncRead`.
// `Empty` carries no data while the other variants hold their full stream
// state inline, so the variant sizes differ widely.
#[allow(clippy::large_enum_variant)]
pub enum RawStream<'r> {
    /// A stream with no data; it ends immediately.
    Empty,
    /// A hyper request body. Non-data frames (e.g. trailers) are yielded as
    /// empty chunks.
    Body(HyperBody),
    /// An HTTP/3 request body read from a QUIC stream.
    #[cfg(feature = "http3-preview")]
    H3Body(crate::listener::Cancellable<crate::listener::quic::QuicRx>),
    /// A single field of a multipart form.
    Multipart(multer::Field<'r>),
}
78
79impl<'r> TransformReader<'r> {
80 /// Returns the underlying `BaseReader`.
81 fn base_mut(&mut self) -> &mut BaseReader<'r> {
82 match self.stream.as_mut().get_mut() {
83 DataStream::Base(base) => base,
84 DataStream::Transform(inner) => inner.base_mut(),
85 }
86 }
87
88 /// Returns the underlying `BaseReader`.
89 fn base(&self) -> &BaseReader<'r> {
90 match self.stream.as_ref().get_ref() {
91 DataStream::Base(base) => base,
92 DataStream::Transform(inner) => inner.base(),
93 }
94 }
95}
96
97impl<'r> DataStream<'r> {
98 pub(crate) fn new(
99 transformers: Vec<Pin<Box<dyn Transform + Send + Sync + 'r>>>,
100 Peekable { buffer, reader, .. }: Peekable<512, RawReader<'r>>,
101 limit: u64,
102 ) -> Self {
103 let mut stream = DataStream::Base(Chain::new(Cursor::new(buffer), reader).take(limit));
104 for transformer in transformers {
105 stream = DataStream::Transform(TransformReader {
106 transformer,
107 stream: Box::pin(stream),
108 inner_done: false,
109 });
110 }
111
112 stream
113 }
114
115 /// Returns the underlying `BaseReader`.
116 fn base_mut(&mut self) -> &mut BaseReader<'r> {
117 match self {
118 DataStream::Base(base) => base,
119 DataStream::Transform(transform) => transform.base_mut(),
120 }
121 }
122
123 /// Returns the underlying `BaseReader`.
124 fn base(&self) -> &BaseReader<'r> {
125 match self {
126 DataStream::Base(base) => base,
127 DataStream::Transform(transform) => transform.base(),
128 }
129 }
130
131 /// Whether a previous read exhausted the set limit _and then some_.
132 async fn limit_exceeded(&mut self) -> io::Result<bool> {
133 let base = self.base_mut();
134
135 #[cold]
136 async fn _limit_exceeded(base: &mut BaseReader<'_>) -> io::Result<bool> {
137 // Read one more byte after reaching limit to see if we cut early.
138 base.set_limit(1);
139 let mut buf = [0u8; 1];
140 let exceeded = base.read(&mut buf).await? != 0;
141 base.set_limit(0);
142 Ok(exceeded)
143 }
144
145 Ok(base.limit() == 0 && _limit_exceeded(base).await?)
146 }
147
148 /// Number of bytes a full read from `self` will _definitely_ read.
149 ///
150 /// # Example
151 ///
152 /// ```rust
153 /// # extern crate rocket_community as rocket;
154 /// use rocket::data::{Data, ToByteUnit};
155 ///
156 /// async fn f(data: Data<'_>) {
157 /// let definitely_have_n_bytes = data.open(1.kibibytes()).hint();
158 /// }
159 /// ```
160 pub fn hint(&self) -> usize {
161 let base = self.base();
162 if let (Some(cursor), _) = base.get_ref().get_ref() {
163 let len = cursor.get_ref().len() as u64;
164 let position = cursor.position().min(len);
165 let remaining = len - position;
166 remaining.min(base.limit()) as usize
167 } else {
168 0
169 }
170 }
171
172 /// A helper method to write the body of the request to any `AsyncWrite`
173 /// type. Returns an [`N`] which indicates how many bytes were written and
174 /// whether the entire stream was read. An additional read from `self` may
175 /// be required to check if all of the stream has been read. If that
176 /// information is not needed, use [`DataStream::stream_precise_to()`].
177 ///
178 /// This method is identical to `tokio::io::copy(&mut self, &mut writer)`
179 /// except in that it returns an `N` to check for completeness.
180 ///
181 /// # Example
182 ///
183 /// ```rust
184 /// # extern crate rocket_community as rocket;
185 /// use std::io;
186 /// use rocket::data::{Data, ToByteUnit};
187 ///
188 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
189 /// // write all of the data to stdout
190 /// let written = data.open(512.kibibytes())
191 /// .stream_to(tokio::io::stdout()).await?;
192 ///
193 /// Ok(format!("Wrote {} bytes.", written))
194 /// }
195 /// ```
196 #[inline(always)]
197 pub async fn stream_to<W>(mut self, mut writer: W) -> io::Result<N>
198 where
199 W: AsyncWrite + Unpin,
200 {
201 let written = tokio::io::copy(&mut self, &mut writer).await?;
202 Ok(N {
203 written,
204 complete: !self.limit_exceeded().await?,
205 })
206 }
207
208 /// Like [`DataStream::stream_to()`] except that no end-of-stream check is
209 /// conducted and thus read/write completeness is unknown.
210 ///
211 /// # Example
212 ///
213 /// ```rust
214 /// # extern crate rocket_community as rocket;
215 /// use std::io;
216 /// use rocket::data::{Data, ToByteUnit};
217 ///
218 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
219 /// // write all of the data to stdout
220 /// let written = data.open(512.kibibytes())
221 /// .stream_precise_to(tokio::io::stdout()).await?;
222 ///
223 /// Ok(format!("Wrote {} bytes.", written))
224 /// }
225 /// ```
226 #[inline(always)]
227 pub async fn stream_precise_to<W>(mut self, mut writer: W) -> io::Result<u64>
228 where
229 W: AsyncWrite + Unpin,
230 {
231 tokio::io::copy(&mut self, &mut writer).await
232 }
233
234 /// A helper method to write the body of the request to a `Vec<u8>`.
235 ///
236 /// # Example
237 ///
238 /// ```rust
239 /// # extern crate rocket_community as rocket;
240 /// use std::io;
241 /// use rocket::data::{Data, ToByteUnit};
242 ///
243 /// async fn data_guard(data: Data<'_>) -> io::Result<Vec<u8>> {
244 /// let bytes = data.open(4.kibibytes()).into_bytes().await?;
245 /// if !bytes.is_complete() {
246 /// println!("there are bytes remaining in the stream");
247 /// }
248 ///
249 /// Ok(bytes.into_inner())
250 /// }
251 /// ```
252 pub async fn into_bytes(self) -> io::Result<Capped<Vec<u8>>> {
253 let mut vec = Vec::with_capacity(self.hint());
254 let n = self.stream_to(&mut vec).await?;
255 Ok(Capped { value: vec, n })
256 }
257
258 /// A helper method to write the body of the request to a `String`.
259 ///
260 /// # Example
261 ///
262 /// ```rust
263 /// # extern crate rocket_community as rocket;
264 /// use std::io;
265 /// use rocket::data::{Data, ToByteUnit};
266 ///
267 /// async fn data_guard(data: Data<'_>) -> io::Result<String> {
268 /// let string = data.open(10.bytes()).into_string().await?;
269 /// if !string.is_complete() {
270 /// println!("there are bytes remaining in the stream");
271 /// }
272 ///
273 /// Ok(string.into_inner())
274 /// }
275 /// ```
276 pub async fn into_string(mut self) -> io::Result<Capped<String>> {
277 let mut string = String::with_capacity(self.hint());
278 let written = self.read_to_string(&mut string).await?;
279 let n = N {
280 written: written as u64,
281 complete: !self.limit_exceeded().await?,
282 };
283 Ok(Capped { value: string, n })
284 }
285
286 /// A helper method to write the body of the request to a file at the path
287 /// determined by `path`. If a file at the path already exists, it is
288 /// overwritten. The opened file is returned.
289 ///
290 /// # Example
291 ///
292 /// ```rust
293 /// # extern crate rocket_community as rocket;
294 /// use std::io;
295 /// use rocket::data::{Data, ToByteUnit};
296 ///
297 /// async fn data_guard(mut data: Data<'_>) -> io::Result<String> {
298 /// let file = data.open(1.megabytes()).into_file("/static/file").await?;
299 /// if !file.is_complete() {
300 /// println!("there are bytes remaining in the stream");
301 /// }
302 ///
303 /// Ok(format!("Wrote {} bytes to /static/file", file.n))
304 /// }
305 /// ```
306 pub async fn into_file<P: AsRef<Path>>(self, path: P) -> io::Result<Capped<File>> {
307 let mut file = File::create(path).await?;
308 let n = self
309 .stream_to(&mut tokio::io::BufWriter::new(&mut file))
310 .await?;
311 Ok(Capped { value: file, n })
312 }
313}
314
315impl AsyncRead for DataStream<'_> {
316 fn poll_read(
317 self: Pin<&mut Self>,
318 cx: &mut Context<'_>,
319 buf: &mut ReadBuf<'_>,
320 ) -> Poll<io::Result<()>> {
321 match self.get_mut() {
322 DataStream::Base(inner) => Pin::new(inner).poll_read(cx, buf),
323 DataStream::Transform(inner) => Pin::new(inner).poll_read(cx, buf),
324 }
325 }
326}
327
328impl AsyncRead for TransformReader<'_> {
329 fn poll_read(
330 mut self: Pin<&mut Self>,
331 cx: &mut Context<'_>,
332 buf: &mut ReadBuf<'_>,
333 ) -> Poll<io::Result<()>> {
334 let init_fill = buf.filled().len();
335 if !self.inner_done {
336 ready!(Pin::new(&mut self.stream).poll_read(cx, buf))?;
337 self.inner_done = init_fill == buf.filled().len();
338 }
339
340 if self.inner_done {
341 return self.transformer.as_mut().poll_finish(cx, buf);
342 }
343
344 let mut tbuf = TransformBuf {
345 buf,
346 cursor: init_fill,
347 };
348 self.transformer.as_mut().transform(&mut tbuf)?;
349 if buf.filled().len() == init_fill {
350 cx.waker().wake_by_ref();
351 return Poll::Pending;
352 }
353
354 Poll::Ready(Ok(()))
355 }
356}
357
358impl Stream for RawStream<'_> {
359 type Item = io::Result<Bytes>;
360
361 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362 match self.get_mut() {
363 // TODO: Expose trailer headers, somehow.
364 RawStream::Body(body) => Pin::new(body)
365 .poll_frame(cx)
366 .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new()))
367 .map_err(io::Error::other),
368 #[cfg(feature = "http3-preview")]
369 RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx),
370 RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other),
371 RawStream::Empty => Poll::Ready(None),
372 }
373 }
374
375 fn size_hint(&self) -> (usize, Option<usize>) {
376 match self {
377 RawStream::Body(body) => {
378 let hint = body.size_hint();
379 let (lower, upper) = (hint.lower(), hint.upper());
380 (lower as usize, upper.map(|x| x as usize))
381 }
382 #[cfg(feature = "http3-preview")]
383 RawStream::H3Body(_) => (0, Some(0)),
384 RawStream::Multipart(mp) => mp.size_hint(),
385 RawStream::Empty => (0, Some(0)),
386 }
387 }
388}
389
390impl std::fmt::Display for RawStream<'_> {
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 match self {
393 RawStream::Empty => f.write_str("empty stream"),
394 RawStream::Body(_) => f.write_str("request body"),
395 #[cfg(feature = "http3-preview")]
396 RawStream::H3Body(_) => f.write_str("http3 quic stream"),
397 RawStream::Multipart(_) => f.write_str("multipart form field"),
398 }
399 }
400}
401
402impl<'r> From<HyperBody> for RawStream<'r> {
403 fn from(value: HyperBody) -> Self {
404 Self::Body(value)
405 }
406}
407
/// Wraps a cancellable QUIC receive stream as a [`RawStream::H3Body`].
#[cfg(feature = "http3-preview")]
impl From<crate::listener::Cancellable<crate::listener::quic::QuicRx>> for RawStream<'_> {
    fn from(rx: crate::listener::Cancellable<crate::listener::quic::QuicRx>) -> Self {
        RawStream::H3Body(rx)
    }
}
414
415impl<'r> From<multer::Field<'r>> for RawStream<'r> {
416 fn from(value: multer::Field<'r>) -> Self {
417 Self::Multipart(value)
418 }
419}