Skip to main content

multra/
field.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use bytes::{Bytes, BytesMut};
6use encoding_rs::{Encoding, UTF_8};
7use futures_util::stream::{Stream, TryStreamExt};
8use http::header::HeaderMap;
9#[cfg(feature = "json")]
10use serde::de::DeserializeOwned;
11use spin::mutex::spin::SpinMutex as Mutex;
12
13use crate::content_disposition::ContentDisposition;
14use crate::multipart::{MultipartState, StreamingStage};
15use crate::{Error, helpers};
16
17/// A single field in a multipart stream.
18///
19/// Its content can be accessed via the [`Stream`] API or the methods defined in
20/// this type.
21///
22/// # Lifetime
23///
24/// The lifetime of the stream `'r` corresponds to the lifetime of the
25/// underlying `Stream`. If the underlying stream holds no references directly
26/// or transitively, then the lifetime can be `'static`.
27///
28/// # Examples
29///
30/// ```
31/// use std::convert::Infallible;
32///
33/// use bytes::Bytes;
34/// use futures_util::stream::once;
35/// use multra::Multipart;
36///
37/// # async fn run() {
38/// let data = "--X-BOUNDARY\r\nContent-Disposition: form-data; \
39///     name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
40///
41/// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
42/// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
43///
44/// while let Some(field) = multipart.next_field().await.unwrap() {
45///     let content = field.text().await.unwrap();
46///     assert_eq!(content, "abcd");
47/// }
48/// # }
49/// # tokio::runtime::Runtime::new().unwrap().block_on(run());
50/// ```
51///
52/// [`Multipart`]: crate::Multipart
53#[derive(Debug)]
54pub struct Field<'r> {
55    state: Arc<Mutex<MultipartState<'r>>>,
56    done: bool,
57    headers: HeaderMap,
58    content_disposition: ContentDisposition,
59    content_type: Option<mime::Mime>,
60    idx: usize,
61}
62
63impl<'r> Field<'r> {
64    pub(crate) fn new(
65        state: Arc<Mutex<MultipartState<'r>>>,
66        headers: HeaderMap,
67        idx: usize,
68        content_disposition: ContentDisposition,
69    ) -> Self {
70        let content_type = helpers::parse_content_type(&headers);
71        Field {
72            state,
73            headers,
74            content_disposition,
75            content_type,
76            idx,
77            done: false,
78        }
79    }
80
81    /// The field name found in the [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header.
82    pub fn name(&self) -> Option<&str> {
83        self.content_disposition.field_name.as_deref()
84    }
85
86    /// The file name found in the [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header.
87    pub fn file_name(&self) -> Option<&str> {
88        self.content_disposition.file_name.as_deref()
89    }
90
91    /// Get the content type of the field.
92    pub fn content_type(&self) -> Option<&mime::Mime> {
93        self.content_type.as_ref()
94    }
95
96    /// Get a map of headers as [`HeaderMap`].
97    pub fn headers(&self) -> &HeaderMap {
98        &self.headers
99    }
100
101    /// Get the full data of the field as [`Bytes`].
102    ///
103    /// # Examples
104    ///
105    /// ```
106    /// use std::convert::Infallible;
107    ///
108    /// use bytes::Bytes;
109    /// use futures_util::stream::once;
110    /// use multra::Multipart;
111    ///
112    /// # async fn run() {
113    /// let data =
114    ///     "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
115    /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
116    /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
117    ///
118    /// while let Some(field) = multipart.next_field().await.unwrap() {
119    ///     let bytes = field.bytes().await.unwrap();
120    ///     assert_eq!(bytes.len(), 4);
121    /// }
122    /// # }
123    /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
124    /// ```
125    pub async fn bytes(self) -> crate::Result<Bytes> {
126        let mut buf = BytesMut::new();
127
128        let mut this = self;
129        while let Some(bytes) = this.chunk().await? {
130            buf.extend_from_slice(&bytes);
131        }
132
133        Ok(buf.freeze())
134    }
135
136    /// Stream a chunk of the field data.
137    ///
138    /// When the field data has been exhausted, this will return [`None`].
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// use std::convert::Infallible;
144    ///
145    /// use bytes::Bytes;
146    /// use futures_util::stream::once;
147    /// use multra::Multipart;
148    ///
149    /// # async fn run() {
150    /// let data =
151    ///     "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
152    /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
153    /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
154    ///
155    /// while let Some(mut field) = multipart.next_field().await.unwrap() {
156    ///     while let Some(chunk) = field.chunk().await.unwrap() {
157    ///         println!("Chunk: {:?}", chunk);
158    ///     }
159    /// }
160    /// # }
161    /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
162    /// ```
163    pub async fn chunk(&mut self) -> crate::Result<Option<Bytes>> {
164        self.try_next().await
165    }
166
167    /// Try to deserialize the field data as JSON.
168    ///
169    /// # Optional
170    ///
171    /// This requires the optional `json` feature to be enabled.
172    ///
173    /// # Examples
174    ///
175    /// ```
176    /// use multra::Multipart;
177    /// use bytes::Bytes;
178    /// use std::convert::Infallible;
179    /// use futures_util::stream::once;
180    /// use serde::Deserialize;
181    ///
182    /// // This `derive` requires the `serde` dependency.
183    /// #[derive(Deserialize)]
184    /// struct User {
185    ///     name: String
186    /// }
187    ///
188    /// # async fn run() {
189    /// let data = "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\n{ \"name\": \"Alice\" }\r\n--X-BOUNDARY--\r\n";
190    /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
191    /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
192    ///
193    /// while let Some(field) = multipart.next_field().await.unwrap() {
194    ///     let user = field.json::<User>().await.unwrap();
195    ///     println!("User Name: {}", user.name);
196    /// }
197    /// # }
198    /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
199    /// ```
200    ///
201    /// # Errors
202    ///
203    /// This method fails if the field data is not in JSON format
204    /// or it cannot be properly deserialized to target type `T`. For more
205    /// details please see [`serde_json::from_slice`].
206    #[cfg(feature = "json")]
207    pub async fn json<T: DeserializeOwned>(self) -> crate::Result<T> {
208        serde_json::from_slice(&self.bytes().await?).map_err(Error::DecodeJson)
209    }
210
211    /// Get the full field data as text.
212    ///
213    /// This method decodes the field data with `BOM sniffing` and with
214    /// malformed sequences replaced with the `REPLACEMENT CHARACTER`.
215    /// Encoding is determined from the `charset` parameter of `Content-Type`
216    /// header, and defaults to `utf-8` if not presented.
217    ///
218    /// # Examples
219    ///
220    /// ```
221    /// use std::convert::Infallible;
222    ///
223    /// use bytes::Bytes;
224    /// use futures_util::stream::once;
225    /// use multra::Multipart;
226    ///
227    /// # async fn run() {
228    /// let data =
229    ///     "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
230    /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
231    /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
232    ///
233    /// while let Some(field) = multipart.next_field().await.unwrap() {
234    ///     let content = field.text().await.unwrap();
235    ///     assert_eq!(content, "abcd");
236    /// }
237    /// # }
238    /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
239    /// ```
240    pub async fn text(self) -> crate::Result<String> {
241        self.text_with_charset("utf-8").await
242    }
243
244    /// Get the full field data as text given a specific encoding.
245    ///
246    /// This method decodes the field data with `BOM sniffing` and with
247    /// malformed sequences replaced with the `REPLACEMENT CHARACTER`.
248    /// You can provide a default encoding for decoding the raw message, while
249    /// the `charset` parameter of `Content-Type` header is still prioritized.
250    /// For more information about the possible encoding name, please go to
251    /// [encoding_rs] docs.
252    ///
253    /// # Examples
254    ///
255    /// ```
256    /// use std::convert::Infallible;
257    ///
258    /// use bytes::Bytes;
259    /// use futures_util::stream::once;
260    /// use multra::Multipart;
261    ///
262    /// # async fn run() {
263    /// let data =
264    ///     "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
265    /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
266    /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
267    ///
268    /// while let Some(field) = multipart.next_field().await.unwrap() {
269    ///     let content = field.text_with_charset("utf-8").await.unwrap();
270    ///     assert_eq!(content, "abcd");
271    /// }
272    /// # }
273    /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
274    /// ```
275    pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result<String> {
276        let encoding_name = self
277            .content_type()
278            .and_then(|mime| mime.get_param(mime::CHARSET))
279            .map(|charset| charset.as_str())
280            .unwrap_or(default_encoding);
281
282        let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8);
283        let bytes = self.bytes().await?;
284        Ok(encoding.decode(&bytes).0.into_owned())
285    }
286
287    /// Get the index of this field in order they appeared in the stream.
288    ///
289    /// # Examples
290    ///
291    /// ```
292    /// use std::convert::Infallible;
293    ///
294    /// use bytes::Bytes;
295    /// use futures_util::stream::once;
296    /// use multra::Multipart;
297    ///
298    /// # async fn run() {
299    /// let data =
300    ///     "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
301    /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
302    /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
303    ///
304    /// while let Some(field) = multipart.next_field().await.unwrap() {
305    ///     let idx = field.index();
306    ///     println!("Field index: {}", idx);
307    /// }
308    /// # }
309    /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
310    /// ```
311    pub fn index(&self) -> usize {
312        self.idx
313    }
314}
315
316impl Stream for Field<'_> {
317    type Item = Result<Bytes, Error>;
318
319    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320        if self.done {
321            return Poll::Ready(None);
322        }
323
324        debug_assert!(self.state.try_lock().is_some(), "expected exclusive lock");
325        let state = self.state.clone();
326        let mut lock = match state.try_lock() {
327            Some(lock) => lock,
328            None => return Poll::Ready(Some(Err(Error::LockFailure))),
329        };
330
331        let state = &mut *lock;
332        if let Err(err) = state.buffer.poll_stream(cx) {
333            return Poll::Ready(Some(Err(err)));
334        }
335
336        match state.buffer.read_field_data(
337            &state.field_boundary_bytes,
338            state.curr_field_name.as_deref(),
339        ) {
340            Ok(Some((done, bytes))) => {
341                state.curr_field_size_counter += bytes.len() as u64;
342
343                if state.curr_field_size_counter > state.curr_field_size_limit {
344                    return Poll::Ready(Some(Err(Error::FieldSizeExceeded {
345                        limit: state.curr_field_size_limit,
346                        field_name: state.curr_field_name.clone(),
347                    })));
348                }
349
350                if done {
351                    state.stage = StreamingStage::ReadingBoundary;
352                    self.done = true;
353                }
354
355                Poll::Ready(Some(Ok(bytes)))
356            }
357            Ok(None) => Poll::Pending,
358            Err(err) => Poll::Ready(Some(Err(err))),
359        }
360    }
361}