multra/field.rs
1use std::pin::Pin;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use bytes::{Bytes, BytesMut};
6use encoding_rs::{Encoding, UTF_8};
7use futures_util::stream::{Stream, TryStreamExt};
8use http::header::HeaderMap;
9#[cfg(feature = "json")]
10use serde::de::DeserializeOwned;
11use spin::mutex::spin::SpinMutex as Mutex;
12
13use crate::content_disposition::ContentDisposition;
14use crate::multipart::{MultipartState, StreamingStage};
15use crate::{Error, helpers};
16
17/// A single field in a multipart stream.
18///
19/// Its content can be accessed via the [`Stream`] API or the methods defined in
20/// this type.
21///
22/// # Lifetime
23///
24/// The lifetime of the stream `'r` corresponds to the lifetime of the
25/// underlying `Stream`. If the underlying stream holds no references directly
26/// or transitively, then the lifetime can be `'static`.
27///
28/// # Examples
29///
30/// ```
31/// use std::convert::Infallible;
32///
33/// use bytes::Bytes;
34/// use futures_util::stream::once;
35/// use multra::Multipart;
36///
37/// # async fn run() {
38/// let data = "--X-BOUNDARY\r\nContent-Disposition: form-data; \
39/// name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
40///
41/// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
42/// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
43///
44/// while let Some(field) = multipart.next_field().await.unwrap() {
45/// let content = field.text().await.unwrap();
46/// assert_eq!(content, "abcd");
47/// }
48/// # }
49/// # tokio::runtime::Runtime::new().unwrap().block_on(run());
50/// ```
51///
52/// [`Multipart`]: crate::Multipart
53#[derive(Debug)]
54pub struct Field<'r> {
55 state: Arc<Mutex<MultipartState<'r>>>,
56 done: bool,
57 headers: HeaderMap,
58 content_disposition: ContentDisposition,
59 content_type: Option<mime::Mime>,
60 idx: usize,
61}
62
63impl<'r> Field<'r> {
64 pub(crate) fn new(
65 state: Arc<Mutex<MultipartState<'r>>>,
66 headers: HeaderMap,
67 idx: usize,
68 content_disposition: ContentDisposition,
69 ) -> Self {
70 let content_type = helpers::parse_content_type(&headers);
71 Field {
72 state,
73 headers,
74 content_disposition,
75 content_type,
76 idx,
77 done: false,
78 }
79 }
80
81 /// The field name found in the [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header.
82 pub fn name(&self) -> Option<&str> {
83 self.content_disposition.field_name.as_deref()
84 }
85
86 /// The file name found in the [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header.
87 pub fn file_name(&self) -> Option<&str> {
88 self.content_disposition.file_name.as_deref()
89 }
90
91 /// Get the content type of the field.
92 pub fn content_type(&self) -> Option<&mime::Mime> {
93 self.content_type.as_ref()
94 }
95
96 /// Get a map of headers as [`HeaderMap`].
97 pub fn headers(&self) -> &HeaderMap {
98 &self.headers
99 }
100
101 /// Get the full data of the field as [`Bytes`].
102 ///
103 /// # Examples
104 ///
105 /// ```
106 /// use std::convert::Infallible;
107 ///
108 /// use bytes::Bytes;
109 /// use futures_util::stream::once;
110 /// use multra::Multipart;
111 ///
112 /// # async fn run() {
113 /// let data =
114 /// "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
115 /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
116 /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
117 ///
118 /// while let Some(field) = multipart.next_field().await.unwrap() {
119 /// let bytes = field.bytes().await.unwrap();
120 /// assert_eq!(bytes.len(), 4);
121 /// }
122 /// # }
123 /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
124 /// ```
125 pub async fn bytes(self) -> crate::Result<Bytes> {
126 let mut buf = BytesMut::new();
127
128 let mut this = self;
129 while let Some(bytes) = this.chunk().await? {
130 buf.extend_from_slice(&bytes);
131 }
132
133 Ok(buf.freeze())
134 }
135
136 /// Stream a chunk of the field data.
137 ///
138 /// When the field data has been exhausted, this will return [`None`].
139 ///
140 /// # Examples
141 ///
142 /// ```
143 /// use std::convert::Infallible;
144 ///
145 /// use bytes::Bytes;
146 /// use futures_util::stream::once;
147 /// use multra::Multipart;
148 ///
149 /// # async fn run() {
150 /// let data =
151 /// "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
152 /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
153 /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
154 ///
155 /// while let Some(mut field) = multipart.next_field().await.unwrap() {
156 /// while let Some(chunk) = field.chunk().await.unwrap() {
157 /// println!("Chunk: {:?}", chunk);
158 /// }
159 /// }
160 /// # }
161 /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
162 /// ```
163 pub async fn chunk(&mut self) -> crate::Result<Option<Bytes>> {
164 self.try_next().await
165 }
166
167 /// Try to deserialize the field data as JSON.
168 ///
169 /// # Optional
170 ///
171 /// This requires the optional `json` feature to be enabled.
172 ///
173 /// # Examples
174 ///
175 /// ```
176 /// use multra::Multipart;
177 /// use bytes::Bytes;
178 /// use std::convert::Infallible;
179 /// use futures_util::stream::once;
180 /// use serde::Deserialize;
181 ///
182 /// // This `derive` requires the `serde` dependency.
183 /// #[derive(Deserialize)]
184 /// struct User {
185 /// name: String
186 /// }
187 ///
188 /// # async fn run() {
189 /// let data = "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\n{ \"name\": \"Alice\" }\r\n--X-BOUNDARY--\r\n";
190 /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
191 /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
192 ///
193 /// while let Some(field) = multipart.next_field().await.unwrap() {
194 /// let user = field.json::<User>().await.unwrap();
195 /// println!("User Name: {}", user.name);
196 /// }
197 /// # }
198 /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
199 /// ```
200 ///
201 /// # Errors
202 ///
203 /// This method fails if the field data is not in JSON format
204 /// or it cannot be properly deserialized to target type `T`. For more
205 /// details please see [`serde_json::from_slice`].
206 #[cfg(feature = "json")]
207 pub async fn json<T: DeserializeOwned>(self) -> crate::Result<T> {
208 serde_json::from_slice(&self.bytes().await?).map_err(Error::DecodeJson)
209 }
210
211 /// Get the full field data as text.
212 ///
213 /// This method decodes the field data with `BOM sniffing` and with
214 /// malformed sequences replaced with the `REPLACEMENT CHARACTER`.
215 /// Encoding is determined from the `charset` parameter of `Content-Type`
216 /// header, and defaults to `utf-8` if not presented.
217 ///
218 /// # Examples
219 ///
220 /// ```
221 /// use std::convert::Infallible;
222 ///
223 /// use bytes::Bytes;
224 /// use futures_util::stream::once;
225 /// use multra::Multipart;
226 ///
227 /// # async fn run() {
228 /// let data =
229 /// "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
230 /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
231 /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
232 ///
233 /// while let Some(field) = multipart.next_field().await.unwrap() {
234 /// let content = field.text().await.unwrap();
235 /// assert_eq!(content, "abcd");
236 /// }
237 /// # }
238 /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
239 /// ```
240 pub async fn text(self) -> crate::Result<String> {
241 self.text_with_charset("utf-8").await
242 }
243
244 /// Get the full field data as text given a specific encoding.
245 ///
246 /// This method decodes the field data with `BOM sniffing` and with
247 /// malformed sequences replaced with the `REPLACEMENT CHARACTER`.
248 /// You can provide a default encoding for decoding the raw message, while
249 /// the `charset` parameter of `Content-Type` header is still prioritized.
250 /// For more information about the possible encoding name, please go to
251 /// [encoding_rs] docs.
252 ///
253 /// # Examples
254 ///
255 /// ```
256 /// use std::convert::Infallible;
257 ///
258 /// use bytes::Bytes;
259 /// use futures_util::stream::once;
260 /// use multra::Multipart;
261 ///
262 /// # async fn run() {
263 /// let data =
264 /// "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
265 /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
266 /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
267 ///
268 /// while let Some(field) = multipart.next_field().await.unwrap() {
269 /// let content = field.text_with_charset("utf-8").await.unwrap();
270 /// assert_eq!(content, "abcd");
271 /// }
272 /// # }
273 /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
274 /// ```
275 pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result<String> {
276 let encoding_name = self
277 .content_type()
278 .and_then(|mime| mime.get_param(mime::CHARSET))
279 .map(|charset| charset.as_str())
280 .unwrap_or(default_encoding);
281
282 let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8);
283 let bytes = self.bytes().await?;
284 Ok(encoding.decode(&bytes).0.into_owned())
285 }
286
287 /// Get the index of this field in order they appeared in the stream.
288 ///
289 /// # Examples
290 ///
291 /// ```
292 /// use std::convert::Infallible;
293 ///
294 /// use bytes::Bytes;
295 /// use futures_util::stream::once;
296 /// use multra::Multipart;
297 ///
298 /// # async fn run() {
299 /// let data =
300 /// "--X-BOUNDARY\r\nContent-Disposition: form-data; name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--\r\n";
301 /// let stream = once(async move { Result::<Bytes, Infallible>::Ok(Bytes::from(data)) });
302 /// let mut multipart = Multipart::new(stream, "X-BOUNDARY");
303 ///
304 /// while let Some(field) = multipart.next_field().await.unwrap() {
305 /// let idx = field.index();
306 /// println!("Field index: {}", idx);
307 /// }
308 /// # }
309 /// # tokio::runtime::Runtime::new().unwrap().block_on(run());
310 /// ```
311 pub fn index(&self) -> usize {
312 self.idx
313 }
314}
315
316impl Stream for Field<'_> {
317 type Item = Result<Bytes, Error>;
318
319 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320 if self.done {
321 return Poll::Ready(None);
322 }
323
324 debug_assert!(self.state.try_lock().is_some(), "expected exclusive lock");
325 let state = self.state.clone();
326 let mut lock = match state.try_lock() {
327 Some(lock) => lock,
328 None => return Poll::Ready(Some(Err(Error::LockFailure))),
329 };
330
331 let state = &mut *lock;
332 if let Err(err) = state.buffer.poll_stream(cx) {
333 return Poll::Ready(Some(Err(err)));
334 }
335
336 match state.buffer.read_field_data(
337 &state.field_boundary_bytes,
338 state.curr_field_name.as_deref(),
339 ) {
340 Ok(Some((done, bytes))) => {
341 state.curr_field_size_counter += bytes.len() as u64;
342
343 if state.curr_field_size_counter > state.curr_field_size_limit {
344 return Poll::Ready(Some(Err(Error::FieldSizeExceeded {
345 limit: state.curr_field_size_limit,
346 field_name: state.curr_field_name.clone(),
347 })));
348 }
349
350 if done {
351 state.stage = StreamingStage::ReadingBoundary;
352 self.done = true;
353 }
354
355 Poll::Ready(Some(Ok(bytes)))
356 }
357 Ok(None) => Poll::Pending,
358 Err(err) => Poll::Ready(Some(Err(err))),
359 }
360 }
361}