1use mime::Mime;
10
11use std::error::Error;
12use std::io::{self, BufRead, Read};
13use std::{fmt, str};
14
15use std::sync::Arc;
16
17use super::httparse::{self, Error as HttparseError, Header, Status, EMPTY_HEADER};
18
19use self::ReadEntryResult::*;
20
21use super::save::SaveBuilder;
22
/// A header whose name and value are both empty; used to pre-fill fixed-size header arrays.
const EMPTY_STR_HEADER: StrHeader<'static> = StrHeader { name: "", val: "" };

/// Returns early from the enclosing function with
/// `Err(ParseHeaderError::InvalidContDisp(reason, cause.to_string()))`.
macro_rules! invalid_cont_disp {
    ($why: expr, $offending: expr) => {
        return Err(ParseHeaderError::InvalidContDisp(
            $why,
            $offending.to_string(),
        ));
    };
}

/// A single field header whose value has already been validated as UTF-8.
#[derive(Copy, Clone, Debug)]
pub struct StrHeader<'a> {
    name: &'a str,
    val: &'a str,
}

/// Formats a slice of headers as `name: value` lines, each terminated by `\n`.
struct DisplayHeaders<'s, 'a: 's>(&'s [StrHeader<'a>]);

impl<'s, 'a: 's> fmt::Display for DisplayHeaders<'s, 'a> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let DisplayHeaders(headers) = *self;
        headers
            .iter()
            .try_for_each(|header| writeln!(f, "{}: {}", header.name, header.val))
    }
}
52
53fn with_headers<R, F, Ret>(r: &mut R, closure: F) -> Result<Ret, ParseHeaderError>
54where
55 R: BufRead,
56 F: FnOnce(&[StrHeader]) -> Ret,
57{
58 const HEADER_LEN: usize = 4;
59
60 let consume;
61 let ret;
62
63 let mut last_len = 0;
64
65 loop {
66 let buf = r.fill_buf()?;
68
69 if buf.len() == last_len {
71 return Err(ParseHeaderError::TooLarge);
72 }
73
74 let mut raw_headers = [EMPTY_HEADER; HEADER_LEN];
75
76 match httparse::parse_headers(buf, &mut raw_headers)? {
77 Status::Partial => last_len = buf.len(),
79 Status::Complete((consume_, raw_headers)) => {
80 let mut headers = [EMPTY_STR_HEADER; HEADER_LEN];
81 let headers = copy_headers(raw_headers, &mut headers)?;
82 debug!("Parsed headers: {:?}", headers);
83 consume = consume_;
84 ret = closure(headers);
85 break;
86 }
87 }
88 }
89
90 r.consume(consume);
91 Ok(ret)
92}
93
94fn copy_headers<'h, 'b: 'h>(
95 raw: &[Header<'b>],
96 headers: &'h mut [StrHeader<'b>],
97) -> io::Result<&'h [StrHeader<'b>]> {
98 for (raw, header) in raw.iter().zip(&mut *headers) {
99 header.name = raw.name;
100 header.val = io_str_utf8(raw.value)?;
101 }
102
103 Ok(&headers[..raw.len()])
104}
105
106#[derive(Clone, Debug)]
112pub struct FieldHeaders {
113 pub name: Arc<str>,
115
116 pub filename: Option<String>,
119
120 pub content_type: Option<Mime>,
127
128 pub content_range: Option<String>,
130}
131
132impl FieldHeaders {
133 fn read_from<R: BufRead>(r: &mut R) -> Result<Self, ParseHeaderError> {
135 with_headers(r, Self::parse)?
136 }
137
138 fn parse(headers: &[StrHeader]) -> Result<FieldHeaders, ParseHeaderError> {
139 let cont_disp = ContentDisp::parse(headers)?;
140
141 Ok(FieldHeaders {
142 name: cont_disp.field_name.into(),
143 filename: cont_disp.filename,
144 content_type: parse_content_type(headers)?,
145 content_range: parse_content_range(headers),
146 })
147 }
148}
149
150struct ContentDisp {
152 field_name: String,
154 filename: Option<String>,
156}
157
158impl ContentDisp {
159 fn parse(headers: &[StrHeader]) -> Result<ContentDisp, ParseHeaderError> {
160 let header = if let Some(header) = find_header(headers, "Content-Disposition") {
161 header
162 } else {
163 return Ok(Self {
164 field_name: "".into(),
165 filename: None,
166 });
167 };
168
169 let after_disp_type = match split_once(header.val, ';') {
171 Some((disp_type, after_disp_type)) => {
172 if disp_type.trim() != "form-data" {
175 invalid_cont_disp!("unexpected Content-Disposition value", disp_type);
176 }
177 after_disp_type
178 }
179 None => invalid_cont_disp!(
180 "expected additional data after Content-Disposition type",
181 header.val
182 ),
183 };
184
185 let (field_name, filename) = match get_str_after("name=", ';', after_disp_type) {
187 None => invalid_cont_disp!(
188 "expected field name and maybe filename, got",
189 after_disp_type
190 ),
191 Some((field_name, after_field_name)) => {
193 let field_name = trim_quotes(field_name);
194 let filename = get_str_after("filename=", ';', after_field_name)
195 .map(|(filename, _)| trim_quotes(filename).to_owned());
196 (field_name, filename)
197 }
198 };
199
200 Ok(ContentDisp {
201 field_name: field_name.to_owned(),
202 filename,
203 })
204 }
205}
206
207fn parse_content_type(headers: &[StrHeader]) -> Result<Option<Mime>, ParseHeaderError> {
208 if let Some(header) = find_header(headers, "Content-Type") {
209 debug!("Found Content-Type: {:?}", header.val);
211 Ok(Some(header.val.parse::<Mime>().map_err(|_| {
212 ParseHeaderError::MimeError(header.val.into())
213 })?))
214 } else {
215 Ok(None)
216 }
217}
218
219fn parse_content_range(headers: &[StrHeader]) -> Option<String> {
220 if let Some(header) = find_header(headers, "Content-Range") {
221 debug!("Found Content-Range: {:?}", header.val);
223 Some(header.val.to_string())
224 } else {
225 None
226 }
227}
228
229#[derive(Debug)]
231pub struct MultipartField<M: ReadEntry> {
232 pub headers: FieldHeaders,
238
239 pub data: MultipartData<M>,
241}
242
243impl<M: ReadEntry> MultipartField<M> {
244 pub fn is_text(&self) -> bool {
254 self.headers
255 .content_type
256 .as_ref()
257 .map_or(true, |ct| ct.type_() == mime::TEXT)
258 }
259
260 pub fn next_entry(self) -> ReadEntryResult<M> {
262 self.data.into_inner().read_entry()
263 }
264
265 pub fn next_entry_inplace(&mut self) -> io::Result<Option<&mut Self>>
270 where
271 for<'a> &'a mut M: ReadEntry,
272 {
273 let multipart = self.data.take_inner();
274
275 match multipart.read_entry() {
276 Entry(entry) => {
277 *self = entry;
278 Ok(Some(self))
279 }
280 End(multipart) => {
281 self.data.give_inner(multipart);
282 Ok(None)
283 }
284 Error(multipart, err) => {
285 self.data.give_inner(multipart);
286 Err(err)
287 }
288 }
289 }
290}
291
/// Reader over the data of a single multipart field.
///
/// Wraps the multipart reader `M`. `inner` is `None` only between `take_inner` and
/// `give_inner` (or the field being replaced), e.g. inside
/// `MultipartField::next_entry_inplace`. It can also stay `None` if a panic interrupted that
/// window.
#[derive(Debug)]
pub struct MultipartData<M> {
    inner: Option<M>,
}

/// Panic message used when `MultipartData::inner` is unexpectedly `None`.
// The message previously named a nonexistent `MultipartFile` type.
const DATA_INNER_ERR: &str = "MultipartData::inner taken and not replaced; this is likely \
                              caused by a logic error in `multipart` or by resuming after \
                              a previously caught panic.\nPlease open an issue with the \
                              relevant backtrace and debug logs at \
                              https://github.com/abonander/multipart";
305
306impl<M> MultipartData<M>
307where
308 M: ReadEntry,
309{
310 pub fn save(&mut self) -> SaveBuilder<&mut Self> {
312 SaveBuilder::new(self)
313 }
314
315 pub fn into_inner(self) -> M {
317 self.inner.expect(DATA_INNER_ERR)
318 }
319
320 pub fn set_min_buf_size(&mut self, min_buf_size: usize) {
327 self.inner_mut().set_min_buf_size(min_buf_size)
328 }
329
330 fn inner_mut(&mut self) -> &mut M {
331 self.inner.as_mut().expect(DATA_INNER_ERR)
332 }
333
334 fn take_inner(&mut self) -> M {
335 self.inner.take().expect(DATA_INNER_ERR)
336 }
337
338 fn give_inner(&mut self, inner: M) {
339 self.inner = Some(inner);
340 }
341}
342
343impl<M: ReadEntry> Read for MultipartData<M> {
344 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
345 self.inner_mut().source_mut().read(buf)
346 }
347}
348
349impl<M: ReadEntry> BufRead for MultipartData<M> {
353 fn fill_buf(&mut self) -> io::Result<&[u8]> {
354 self.inner_mut().source_mut().fill_buf()
355 }
356
357 fn consume(&mut self, amt: usize) {
358 self.inner_mut().source_mut().consume(amt)
359 }
360}
361
/// Splits `s` at the first occurrence of `delim`.
///
/// Unlike `str::split_once`, the delimiter is kept as the first character of the second half.
/// Returns `None` if `delim` does not occur in `s`.
fn split_once(s: &str, delim: char) -> Option<(&str, &str)> {
    let idx = s.find(delim)?;
    Some(s.split_at(idx))
}
365
/// Strips every leading and trailing `"` from `s`, not just one matched pair.
fn trim_quotes(s: &str) -> &str {
    s.trim_start_matches('"').trim_end_matches('"')
}
369
/// Finds the parameter `needle` (e.g. `"name="`) in `haystack`, returning its raw value and the
/// rest of `haystack`.
///
/// `needle` is matched ASCII-case-insensitively, and only where a parameter can begin: at the
/// start of `haystack`, or after an unquoted `end_val_delim`, optionally followed by whitespace.
/// This stops `"name="` from matching the tail of `filename=` or text inside a quoted value.
///
/// The value starts right after `needle`. It runs up to, but not including, the first
/// `end_val_delim` outside double quotes, or to the end of `haystack`. Quotes are left in the
/// value. The remainder starts at that delimiter and is empty if there was none.
///
/// Returns `None` if `needle` does not occur at a parameter boundary.
///
/// Backslashes are deliberately not treated as escapes inside quotes: browsers send raw
/// backslashes in filenames (e.g. Windows paths).
fn get_str_after<'a>(
    needle: &str,
    end_val_delim: char,
    haystack: &'a str,
) -> Option<(&'a str, &'a str)> {
    let mut in_quotes = false;
    // The start of `haystack` counts as a parameter boundary.
    let mut at_boundary = true;
    let mut val_start_idx = None;

    for (idx, c) in haystack.char_indices() {
        if in_quotes {
            if c == '"' {
                in_quotes = false;
            }
            continue;
        }

        let is_match = at_boundary
            && haystack
                .get(idx..idx + needle.len())
                .map_or(false, |candidate| candidate.eq_ignore_ascii_case(needle));
        if is_match {
            val_start_idx = Some(idx + needle.len());
            break;
        }

        if c == end_val_delim {
            at_boundary = true;
        } else if c == '"' {
            in_quotes = true;
            at_boundary = false;
        } else if !c.is_whitespace() {
            at_boundary = false;
        }
    }

    let val_start_idx = val_start_idx?;

    // The value ends at the first delimiter that is not inside a quoted string.
    let mut in_quotes = false;
    let val_end_idx = haystack[val_start_idx..]
        .char_indices()
        .find(|&(_, c)| {
            if !in_quotes && c == end_val_delim {
                return true;
            }
            if c == '"' {
                in_quotes = !in_quotes;
            }
            false
        })
        .map_or(haystack.len(), |(end_idx, _)| val_start_idx + end_idx);

    Some((
        &haystack[val_start_idx..val_end_idx],
        &haystack[val_end_idx..],
    ))
}
385
/// Interprets `buf` as UTF-8, mapping a decoding failure to `io::ErrorKind::InvalidData`.
fn io_str_utf8(buf: &[u8]) -> io::Result<&str> {
    let decoded = str::from_utf8(buf);
    decoded.map_err(|utf8_err| io::Error::new(io::ErrorKind::InvalidData, utf8_err))
}
389
390fn find_header<'a, 'b>(headers: &'a [StrHeader<'b>], name: &str) -> Option<&'a StrHeader<'b>> {
391 headers
394 .iter()
395 .find(|header| header.name.eq_ignore_ascii_case(name))
396}
397
398pub trait ReadEntry: PrivReadEntry + Sized {
400 fn read_entry(mut self) -> ReadEntryResult<Self> {
402 self.set_min_buf_size(super::boundary::MIN_BUF_SIZE);
403
404 debug!("ReadEntry::read_entry()");
405
406 if !try_read_entry!(self; self.consume_boundary()) {
407 return End(self);
408 }
409
410 let field_headers: FieldHeaders = try_read_entry!(self; self.read_headers());
411
412 if let Some(ct) = field_headers.content_type.as_ref() {
413 if ct.type_() == mime::MULTIPART {
414 info!(
418 "Found nested multipart field: {:?}:\r\n\
419 Please report this client's User-Agent and any other available details \
420 at https://github.com/abonander/multipart/issues/56",
421 field_headers
422 );
423 }
424 }
425
426 Entry(MultipartField {
427 headers: field_headers,
428 data: MultipartData { inner: Some(self) },
429 })
430 }
431
432 fn read_entry_mut(&mut self) -> ReadEntryResult<&mut Self> {
434 ReadEntry::read_entry(self)
435 }
436}
437
438impl<T> ReadEntry for T where T: PrivReadEntry {}
439
440pub trait PrivReadEntry {
442 type Source: BufRead;
443
444 fn source_mut(&mut self) -> &mut Self::Source;
445
446 fn set_min_buf_size(&mut self, min_buf_size: usize);
447
448 fn consume_boundary(&mut self) -> io::Result<bool>;
451
452 fn read_headers(&mut self) -> Result<FieldHeaders, io::Error> {
453 FieldHeaders::read_from(self.source_mut())
454 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
455 }
456
457 fn read_to_string(&mut self) -> io::Result<String> {
458 let mut buf = String::new();
459
460 match self.source_mut().read_to_string(&mut buf) {
461 Ok(_) => Ok(buf),
462 Err(err) => Err(err),
463 }
464 }
465}
466
467impl<'a, M: ReadEntry> PrivReadEntry for &'a mut M {
468 type Source = M::Source;
469
470 fn source_mut(&mut self) -> &mut M::Source {
471 (**self).source_mut()
472 }
473
474 fn set_min_buf_size(&mut self, min_buf_size: usize) {
475 (**self).set_min_buf_size(min_buf_size)
476 }
477
478 fn consume_boundary(&mut self) -> io::Result<bool> {
479 (**self).consume_boundary()
480 }
481}
482
483pub enum ReadEntryResult<M: ReadEntry, Entry = MultipartField<M>> {
486 Entry(Entry),
488 End(M),
490 Error(M, io::Error),
492}
493
494impl<M: ReadEntry, Entry> ReadEntryResult<M, Entry> {
495 pub fn into_result(self) -> io::Result<Option<Entry>> {
501 match self {
502 ReadEntryResult::Entry(entry) => Ok(Some(entry)),
503 ReadEntryResult::End(_) => Ok(None),
504 ReadEntryResult::Error(_, err) => Err(err),
505 }
506 }
507
508 pub fn unwrap(self) -> Entry {
510 self.expect_alt(
511 "`ReadEntryResult::unwrap()` called on `End` value",
512 "`ReadEntryResult::unwrap()` called on `Error` value: {:?}",
513 )
514 }
515
516 pub fn expect(self, msg: &str) -> Entry {
519 self.expect_alt(msg, msg)
520 }
521
522 pub fn expect_alt(self, end_msg: &str, err_msg: &str) -> Entry {
526 match self {
527 Entry(entry) => entry,
528 End(_) => panic!("{}", end_msg),
529 Error(_, err) => panic!("{}: {:?}", err_msg, err),
530 }
531 }
532
533 pub fn unwrap_opt(self) -> Option<Entry> {
535 self.expect_opt("`ReadEntryResult::unwrap_opt()` called on `Error` value")
536 }
537
538 pub fn expect_opt(self, msg: &str) -> Option<Entry> {
541 match self {
542 Entry(entry) => Some(entry),
543 End(_) => None,
544 Error(_, err) => panic!("{}: {:?}", msg, err),
545 }
546 }
547}
548
549const GENERIC_PARSE_ERR: &str = "an error occurred while parsing field headers";
550
551quick_error! {
552 #[derive(Debug)]
553 enum ParseHeaderError {
554 MissingContentDisposition(headers: String) {
556 display(x) -> ("{}:\n{}", x.description(), headers)
557 description("\"Content-Disposition\" header not found in field headers")
558 }
559 InvalidContDisp(reason: &'static str, cause: String) {
560 display(x) -> ("{}: {}: {}", x.description(), reason, cause)
561 description("invalid \"Content-Disposition\" header")
562 }
563 TokenizeError(err: HttparseError) {
565 description(GENERIC_PARSE_ERR)
566 display(x) -> ("{}: {}", x.description(), err)
567 cause(err)
568 from()
569 }
570 MimeError(cont_type: String) {
571 description("Failed to parse Content-Type")
572 display(this) -> ("{}: {}", this.description(), cont_type)
573 }
574 TooLarge {
575 description("field headers section ridiculously long or missing trailing CRLF-CRLF")
576 }
577 Io(err: io::Error) {
579 description("an io error occurred while parsing the headers")
580 display(x) -> ("{}: {}", x.description(), err)
581 cause(err)
582 from()
583 }
584 }
585}
586
#[test]
fn test_find_header() {
    let headers = [
        StrHeader {
            name: "Content-Type",
            val: "text/plain",
        },
        StrHeader {
            name: "Content-disposition",
            val: "form-data",
        },
        StrHeader {
            name: "content-transfer-encoding",
            val: "binary",
        },
    ];

    // Lookups must ignore ASCII case differences between the query and the stored name.
    let cases = [
        ("Content-Type", "text/plain"),
        ("Content-Disposition", "form-data"),
        ("Content-Transfer-Encoding", "binary"),
    ];

    for &(query, expected) in cases.iter() {
        let found = find_header(&headers, query)
            .unwrap_or_else(|| panic!("header {:?} not found", query));
        assert_eq!(found.val, expected);
    }
}
618}