1#![forbid(unsafe_code)]
2
3mod content_disposition;
4mod error;
5mod field;
6mod header;
7
8pub use self::{error::MultipartError, field::Field};
9
10use core::{future::poll_fn, pin::Pin};
11
12use bytes::{Buf, BytesMut};
13use field::FieldDecoder;
14use futures_core::stream::Stream;
15use http::{header::HeaderMap, Method, Request};
16use memchr::memmem;
17use pin_project_lite::pin_project;
18
19use self::{content_disposition::ContentDisposition, error::PayloadError};
20
21pub fn multipart<Ext, B, T, E>(req: &Request<Ext>, body: B) -> Result<Multipart<B>, MultipartError>
59where
60 B: Stream<Item = Result<T, E>>,
61 T: AsRef<[u8]>,
62 E: Into<PayloadError>,
63{
64 multipart_with_config(req, body, Config::default())
65}
66
67pub fn multipart_with_config<Ext, B, T, E>(
69 req: &Request<Ext>,
70 body: B,
71 config: Config,
72) -> Result<Multipart<B>, MultipartError>
73where
74 B: Stream<Item = Result<T, E>>,
75 T: AsRef<[u8]>,
76 E: Into<PayloadError>,
77{
78 if req.method() != Method::POST {
79 return Err(MultipartError::NoPostMethod);
80 }
81
82 let boundary = header::boundary(req.headers())?;
83
84 Ok(Multipart {
85 stream: body,
86 buf: BytesMut::new(),
87 boundary: boundary.into(),
88 headers: HeaderMap::new(),
89 pending_field: false,
90 config,
91 })
92}
93
/// Tunables for a [`Multipart`] parser.
#[derive(Debug, Copy, Clone)]
pub struct Config {
    /// Maximum number of buffered bytes the parser tolerates while waiting for a complete
    /// boundary line or the end of a field's header section.
    ///
    /// The limit is checked before each read from the body stream, so the buffer can exceed
    /// it by at most one chunk before the parser reports an error.
    pub buf_limit: usize,
}

impl Config {
    /// Default value of [`Config::buf_limit`]: 1 MiB.
    const DEFAULT_BUF_LIMIT: usize = 1024 * 1024;
}

impl Default for Config {
    fn default() -> Self {
        Config {
            buf_limit: Self::DEFAULT_BUF_LIMIT,
        }
    }
}
108
pin_project! {
    /// Streaming parser over the body of a `multipart` request.
    ///
    /// Created by [`multipart`] or [`multipart_with_config`]. Fields are pulled one at a time
    /// with [`Multipart::try_next`].
    pub struct Multipart<S> {
        // Body stream the bytes are read from (structurally pinned).
        #[pin]
        stream: S,
        // Bytes read from `stream` that have not been consumed yet.
        buf: BytesMut,
        // Boundary value taken from the request headers. `try_next` compares it against the
        // text that follows the leading `--` of a delimiter line.
        boundary: Box<[u8]>,
        // Header map that `parse_field` fills in for the field currently being parsed.
        headers: HeaderMap,
        // Set when `parse_field` hands out a `Field`. While set, the next `try_next` first
        // drains that field through `consume_pending_field`.
        pending_field: bool,
        // Parser limits (see `Config::buf_limit`).
        config: Config
    }
}
120
/// `--`: prefix of every boundary delimiter line and suffix of the closing delimiter.
const DOUBLE_HYPHEN: &[u8; 2] = &[b'-', b'-'];
/// `\n`: the byte searched for when locating the end of a boundary line.
const LF: &[u8; 1] = &[b'\n'];
/// `\r\n\r\n`: the blank line that terminates a field's header section.
const DOUBLE_CR_LF: &[u8; 4] = &[b'\r', b'\n', b'\r', b'\n'];
124
125impl<S, T, E> Multipart<S>
126where
127 S: Stream<Item = Result<T, E>>,
128 T: AsRef<[u8]>,
129 E: Into<PayloadError>,
130{
131 pub async fn try_next<'s>(self: &'s mut Pin<&mut Self>) -> Result<Option<Field<'s, S>>, MultipartError> {
134 let boundary_len = self.boundary.len();
135
136 if self.pending_field {
137 self.as_mut().consume_pending_field().await?;
138 }
139
140 loop {
141 let this = self.as_mut().project();
142 if let Some(idx) = memmem::find(this.buf, LF) {
143 let slice = match idx.checked_sub(1) {
145 Some(idx) => &this.buf[..idx],
146 None => return Err(MultipartError::Boundary),
148 };
149
150 match slice.len() {
151 0 => {
153 this.buf.advance(idx + 1);
155 continue;
156 }
157 len if len < (boundary_len + 2) => {}
159 _ if &slice[..2] != DOUBLE_HYPHEN => return Err(MultipartError::Boundary),
161 _ if this.boundary.as_ref().eq(&slice[2..]) => {
163 this.buf.advance(idx + 1);
165
166 let field = self.as_mut().parse_field().await?;
167 return Ok(Some(field));
168 }
169 len if len == (boundary_len + 4) => {
171 let at = boundary_len + 2;
172 let _ = this.boundary.as_ref().eq(&slice[2..at]) && &slice[at..] == DOUBLE_HYPHEN;
174 return Ok(None);
175 }
176 _ => return Err(MultipartError::Boundary),
178 }
179 }
180
181 if self.buf_overflow() {
182 return Err(MultipartError::BufferOverflow);
183 }
184
185 self.as_mut().try_read_stream_to_buf().await?;
186 }
187 }
188
189 async fn parse_field(mut self: Pin<&mut Self>) -> Result<Field<'_, S>, MultipartError> {
190 loop {
191 let this = self.as_mut().project();
192
193 if let Some(idx) = memmem::find(this.buf, DOUBLE_CR_LF) {
194 let slice = &this.buf[..idx + 4];
195
196 header::parse_headers(this.headers, slice)?;
197 this.buf.advance(slice.len());
198
199 let cp = ContentDisposition::try_from_header(this.headers)?;
200
201 header::check_headers(this.headers)?;
202
203 let length = header::content_length_opt(this.headers)?;
204
205 *this.pending_field = true;
206
207 return Ok(Field::new(length, cp, self));
208 }
209
210 if self.buf_overflow() {
211 return Err(MultipartError::Header(httparse::Error::TooManyHeaders));
212 }
213
214 self.as_mut().try_read_stream_to_buf().await?;
215 }
216 }
217
218 #[cold]
219 #[inline(never)]
220 async fn consume_pending_field(mut self: Pin<&mut Self>) -> Result<(), MultipartError> {
221 let mut field_ty = FieldDecoder::default();
222
223 loop {
224 let this = self.as_mut().project();
225 if let Some(idx) = field_ty.try_find_split_idx(this.buf, this.boundary)? {
226 this.buf.advance(idx);
227 }
228 if matches!(field_ty, FieldDecoder::StreamEnd) {
229 *this.pending_field = false;
230 return Ok(());
231 }
232 self.as_mut().try_read_stream_to_buf().await?;
233 }
234 }
235
236 async fn try_read_stream_to_buf(mut self: Pin<&mut Self>) -> Result<(), MultipartError> {
237 let bytes = self.as_mut().try_read_stream().await?;
238 self.project().buf.extend_from_slice(bytes.as_ref());
239 Ok(())
240 }
241
242 async fn try_read_stream(mut self: Pin<&mut Self>) -> Result<T, MultipartError> {
243 match poll_fn(move |cx| self.as_mut().project().stream.poll_next(cx)).await {
244 Some(Ok(bytes)) => Ok(bytes),
245 Some(Err(e)) => Err(MultipartError::Payload(e.into())),
246 None => Err(MultipartError::UnexpectedEof),
247 }
248 }
249
250 pub(crate) fn buf_overflow(&self) -> bool {
251 self.buf.len() > self.config.buf_limit
252 }
253}
254
#[cfg(test)]
mod test {
    use std::{convert::Infallible, pin::pin};

    use bytes::Bytes;
    use futures_util::FutureExt;
    use http::header::{HeaderValue, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_TYPE};

    use super::*;

    // Body stream that yields the whole payload as a single chunk.
    fn once_body(b: impl Into<Bytes>) -> impl Stream<Item = Result<Bytes, Infallible>> {
        futures_util::stream::once(async { Ok(b.into()) })
    }

    // POST request carrying the given static `Content-Type` header value.
    fn post_with_content_type(content_type: &'static str) -> Request<()> {
        let mut req = Request::new(());
        *req.method_mut() = Method::POST;
        req.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(content_type));
        req
    }

    #[test]
    fn method() {
        // `Request::new` defaults to GET, which must be rejected.
        let req = Request::new(());
        let err = multipart(&req, once_body(Bytes::new())).err();
        assert!(matches!(err, Some(MultipartError::NoPostMethod)));
    }

    #[test]
    fn basic() {
        let raw = b"\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"foo.txt\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\n\
            test\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\n\r\n\
            testdata\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\n\r\n\
            testdata\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"bar.txt\"\r\n\
            Content-Type: text/plain\r\nContent-Length: 9\r\n\r\n\
            testdata2\r\n\
            --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n\
            ";

        let req = post_with_content_type(
            "multipart/mixed; boundary=abbc761f78ff4d7cb7573b5a23f96ef0",
        );

        let parser = multipart(&req, once_body(Bytes::copy_from_slice(raw))).unwrap();
        let mut parser = pin!(parser);

        // First field: has an explicit Content-Length and is read to its end.
        {
            let mut field = parser.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"foo.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "foo.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain; charset=utf-8")
            );
            assert_eq!(
                field.headers().get(CONTENT_LENGTH).unwrap(),
                HeaderValue::from_static("4")
            );
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"test"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // Second field: no Content-Length. The headers of the previous field must not leak.
        {
            let mut field = parser.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"bar.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "bar.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain")
            );
            assert!(field.headers().get(CONTENT_LENGTH).is_none());
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"testdata"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // Third field: fetched but dropped unread. The next `try_next` has to drain it.
        parser.try_next().now_or_never().unwrap().unwrap().unwrap();

        // Fourth field: reached after the unread third field was skipped.
        {
            let mut field = parser.try_next().now_or_never().unwrap().unwrap().unwrap();

            assert_eq!(
                field.headers().get(CONTENT_DISPOSITION).unwrap(),
                HeaderValue::from_static("form-data; name=\"file\"; filename=\"bar.txt\"")
            );
            assert_eq!(field.name().unwrap(), "file");
            assert_eq!(field.file_name().unwrap(), "bar.txt");
            assert_eq!(
                field.headers().get(CONTENT_TYPE).unwrap(),
                HeaderValue::from_static("text/plain")
            );
            assert_eq!(
                field.headers().get(CONTENT_LENGTH).unwrap(),
                HeaderValue::from_static("9")
            );
            assert_eq!(
                field.try_next().now_or_never().unwrap().unwrap().unwrap().chunk(),
                b"testdata2"
            );
            assert!(field.try_next().now_or_never().unwrap().unwrap().is_none());
        }

        // Closing delimiter: repeated calls keep reporting the end of the body.
        assert!(parser.try_next().now_or_never().unwrap().unwrap().is_none());
        assert!(parser.try_next().now_or_never().unwrap().unwrap().is_none());
    }

    #[test]
    fn field_header_overflow() {
        // The header section never ends (no blank line) and exceeds the 7 byte buffer limit.
        let raw = b"\
            --12345\r\n\
            Content-Disposition: form-data; name=\"file\"; filename=\"foo.txt\"\r\n\
            Content-Type: text/plain; charset=utf-8\r\nContent-Length: 4";

        let req = post_with_content_type("multipart/mixed; boundary=12345");
        let body = once_body(Bytes::copy_from_slice(raw));

        let parser = multipart_with_config(&req, body, Config { buf_limit: 7 }).unwrap();
        let mut parser = pin!(parser);

        assert!(matches!(
            parser.try_next().now_or_never().unwrap().err().unwrap(),
            MultipartError::Header(httparse::Error::TooManyHeaders)
        ));
    }

    #[test]
    fn boundary_overflow() {
        // No line terminator arrives before the buffer outgrows the 7 byte limit.
        let raw = b"--123456";

        let req = post_with_content_type("multipart/mixed; boundary=12345");
        let body = once_body(Bytes::copy_from_slice(raw));

        let parser = multipart_with_config(&req, body, Config { buf_limit: 7 }).unwrap();
        let mut parser = pin!(parser);

        assert!(matches!(
            parser.try_next().now_or_never().unwrap().err().unwrap(),
            MultipartError::BufferOverflow
        ));
    }
}