Skip to main content

axon/ingest/
multipart.rs

1//! Minimal `multipart/form-data` parser that streams each part
2//! into a [`crate::buffer::BufferMut`].
3//!
4//! The parser works on discrete chunks (as they arrive from the
5//! network) so ingest never buffers the entire request body in
6//! RAM; each part transitions from `BufferMut` to `ZeroCopyBuffer`
7//! when its boundary is hit.
8//!
9//! This is a pragmatic RFC 7578 subset. It handles:
10//!
11//! - Boundary detection (leading + closing)
12//! - Header parsing (case-insensitive on names)
13//! - `Content-Disposition` → field name + optional file name
14//! - `Content-Type` → informational, plus a best-effort mapping to
15//!   the [`crate::buffer::BufferKind`] tag
16//! - Streaming payload accumulation into `BufferMut`
17//!
18//! Out of scope for 11.b:
19//!
20//! - `Content-Transfer-Encoding` (base64, quoted-printable) — adopters
21//!   that need these decode at the application layer
22//! - Nested multipart (multipart within multipart) — rejected with
23//!   `MultipartError::Nested`
24//!
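//! The API is stepwise: `feed(chunk)` returns the
//! `Result<Vec<MultipartEvent>, MultipartError>` produced by that chunk, and
//! `finalize` is called once at end of stream. The caller (typically a
//! Tokio HTTP handler) drives the parser without holding the full request
//! body in memory.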
28
29use crate::buffer::{BufferKind, BufferMut, ZeroCopyBuffer};
30
31// ── Errors ───────────────────────────────────────────────────────────
32
/// Failure modes reported while parsing a multipart body.
#[derive(Debug)]
pub enum MultipartError {
    /// The `Content-Type` boundary parameter is absent or unusable.
    MissingBoundary,
    /// A part's header section grew past the configured cap.
    HeaderTooLarge { limit: usize },
    /// A part's payload grew past the configured cap.
    PartTooLarge { limit: usize },
    /// The input ended while a part was still incomplete.
    UnexpectedEof,
    /// A part declared a `multipart/*` content type; nesting is unsupported.
    Nested,
    /// A header line had no `:` between name and value.
    MalformedHeader { line: String },
}

impl std::fmt::Display for MultipartError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants carrying data need formatting...
            Self::HeaderTooLarge { limit } => {
                write!(f, "header section exceeded {} bytes", limit)
            }
            Self::PartTooLarge { limit } => {
                write!(f, "part payload exceeded {} bytes", limit)
            }
            Self::MalformedHeader { line } => write!(f, "malformed header: {:?}", line),
            // ...the rest are fixed messages.
            Self::MissingBoundary => f.write_str("missing multipart boundary"),
            Self::UnexpectedEof => f.write_str("stream ended mid-part"),
            Self::Nested => f.write_str("nested multipart is not supported"),
        }
    }
}

impl std::error::Error for MultipartError {}
75
76// ── Events emitted by the parser ─────────────────────────────────────
77
78/// What the caller learns after feeding a chunk. The parser may emit
79/// multiple events per chunk (a boundary can appear mid-chunk).
80#[derive(Debug)]
81pub enum MultipartEvent {
82    /// Start of a new part. Carries metadata extracted from headers.
83    PartStart {
84        field_name: String,
85        file_name: Option<String>,
86        content_type: Option<String>,
87        kind: BufferKind,
88    },
89    /// End of the current part — the accumulated payload is now
90    /// available as an immutable [`ZeroCopyBuffer`].
91    PartEnd {
92        field_name: String,
93        payload: ZeroCopyBuffer,
94    },
95    /// Terminating boundary — no more parts will follow.
96    Complete,
97}
98
99// ── Configuration ────────────────────────────────────────────────────
100
/// Size caps enforced while parsing; both values are byte counts.
#[derive(Debug, Clone)]
pub struct MultipartLimits {
    /// Largest accepted header section for a single part.
    pub max_header_bytes: usize,
    /// Largest accepted payload for a single part.
    pub max_part_bytes: usize,
}

impl Default for MultipartLimits {
    /// 16 KiB of headers and 32 MiB of payload per part.
    fn default() -> Self {
        // Header sections are small in practice (Content-Disposition plus
        // Content-Type). 32 MiB suits typical web uploads; callers with
        // bigger files raise `max_part_bytes` explicitly.
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        Self {
            max_header_bytes: 16 * KIB,
            max_part_bytes: 32 * MIB,
        }
    }
}
118
119// ── Parser ───────────────────────────────────────────────────────────
120
/// Phase of the parser's state machine.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Skipping any preamble that precedes the first boundary.
    Preamble,
    /// Buffering the header block of the current part.
    Headers,
    /// Accumulating the current part's payload.
    Body,
    /// Parsing is over (closing boundary seen or stream finalized).
    Terminated,
}
132
133pub struct MultipartParser {
134    boundary: Vec<u8>,
135    limits: MultipartLimits,
136    state: State,
137    /// Accumulates bytes until we've matched a boundary or header end.
138    buf: Vec<u8>,
139    /// Headers for the current part (cleared on PartStart emission).
140    current_headers: Vec<(String, String)>,
141    /// Payload accumulator for the current part.
142    current_body: Option<BufferMut>,
143    current_field_name: Option<String>,
144}
145
146impl MultipartParser {
147    /// Build a parser from the boundary extracted from the request's
148    /// `Content-Type` header (callers pass the value WITHOUT the
149    /// leading `--`; `parse_boundary_from_content_type` helps).
150    pub fn new(boundary: impl Into<String>, limits: MultipartLimits) -> Self {
151        let boundary = boundary.into().into_bytes();
152        MultipartParser {
153            boundary,
154            limits,
155            state: State::Preamble,
156            buf: Vec::with_capacity(4 * 1024),
157            current_headers: Vec::new(),
158            current_body: None,
159            current_field_name: None,
160        }
161    }
162
163    /// Feed a chunk of bytes. May emit 0, 1 or many events; caller
164    /// consumes them in order.
165    pub fn feed(
166        &mut self,
167        chunk: &[u8],
168    ) -> Result<Vec<MultipartEvent>, MultipartError> {
169        self.buf.extend_from_slice(chunk);
170        let mut out = Vec::new();
171        loop {
172            let progressed = self.step(&mut out)?;
173            if !progressed {
174                break;
175            }
176        }
177        Ok(out)
178    }
179
180    /// Called by the driver when the upstream closes. Emits any
181    /// trailing part or an `UnexpectedEof` error if we're mid-part.
182    pub fn finalize(
183        &mut self,
184        out: &mut Vec<MultipartEvent>,
185    ) -> Result<(), MultipartError> {
186        match self.state {
187            State::Terminated => Ok(()),
188            State::Preamble => Err(MultipartError::UnexpectedEof),
189            State::Headers => Err(MultipartError::UnexpectedEof),
190            State::Body => {
191                // We treat an EOF-without-closing-boundary as the
192                // remaining body — most clients produce a closing
193                // boundary, but a resilient parser accepts the tail.
194                if let Some(body) = self.current_body.take() {
195                    let field = self.current_field_name.take().unwrap_or_default();
196                    out.push(MultipartEvent::PartEnd {
197                        field_name: field,
198                        payload: body.freeze(),
199                    });
200                }
201                self.state = State::Terminated;
202                Ok(())
203            }
204        }
205    }
206
207    // ── Inner step ────────────────────────────────────────────────
208
209    fn step(
210        &mut self,
211        out: &mut Vec<MultipartEvent>,
212    ) -> Result<bool, MultipartError> {
213        match self.state {
214            State::Preamble => self.seek_initial_boundary(out),
215            State::Headers => self.parse_headers(out),
216            State::Body => self.stream_body(out),
217            State::Terminated => Ok(false),
218        }
219    }
220
221    fn seek_initial_boundary(
222        &mut self,
223        _out: &mut Vec<MultipartEvent>,
224    ) -> Result<bool, MultipartError> {
225        // The first boundary is `--<boundary>\r\n`.
226        let marker = self.boundary_marker(/*closing=*/ false);
227        match find_subsequence(&self.buf, &marker) {
228            Some(idx) => {
229                let tail = idx + marker.len();
230                self.buf.drain(..tail);
231                self.state = State::Headers;
232                Ok(true)
233            }
234            None => {
235                // Keep the trailing 4*max(2,marker.len()) bytes so a
236                // split boundary across two feeds still matches on
237                // the next step.
238                let keep = marker.len().saturating_add(4);
239                if self.buf.len() > keep {
240                    self.buf.drain(..self.buf.len() - keep);
241                }
242                Ok(false)
243            }
244        }
245    }
246
247    fn parse_headers(
248        &mut self,
249        out: &mut Vec<MultipartEvent>,
250    ) -> Result<bool, MultipartError> {
251        // Header section terminates at the first blank line (`\r\n\r\n`).
252        let terminator = b"\r\n\r\n";
253        let Some(idx) = find_subsequence(&self.buf, terminator) else {
254            if self.buf.len() > self.limits.max_header_bytes {
255                return Err(MultipartError::HeaderTooLarge {
256                    limit: self.limits.max_header_bytes,
257                });
258            }
259            return Ok(false);
260        };
261
262        // §Fase 12.c — also enforce `max_header_bytes` when the
263        // terminator arrives in the same feed as the (oversized)
264        // header. Without this, a caller that fills the buffer with
265        // one big call bypasses the limit because the "buffer larger
266        // than the cap" branch above only fires while the terminator
267        // is still missing.
268        if idx > self.limits.max_header_bytes {
269            return Err(MultipartError::HeaderTooLarge {
270                limit: self.limits.max_header_bytes,
271            });
272        }
273
274        let header_block = self.buf.drain(..idx + terminator.len()).collect::<Vec<u8>>();
275        // Drop the trailing blank-line terminator from the parse set.
276        let header_text = &header_block[..header_block.len() - terminator.len()];
277        let text = std::str::from_utf8(header_text).unwrap_or("");
278        self.current_headers.clear();
279        for raw in text.split("\r\n") {
280            if raw.is_empty() {
281                continue;
282            }
283            let Some((k, v)) = raw.split_once(':') else {
284                return Err(MultipartError::MalformedHeader {
285                    line: raw.to_string(),
286                });
287            };
288            self.current_headers
289                .push((k.trim().to_ascii_lowercase(), v.trim().to_string()));
290        }
291
292        let (field_name, file_name) = disposition_field_and_file(
293            &self.current_headers,
294        );
295        let content_type = self
296            .current_headers
297            .iter()
298            .find(|(k, _)| k == "content-type")
299            .map(|(_, v)| v.clone());
300
301        if content_type
302            .as_deref()
303            .map(|ct| ct.to_ascii_lowercase().contains("multipart/"))
304            .unwrap_or(false)
305        {
306            return Err(MultipartError::Nested);
307        }
308
309        let kind = kind_for_content_type(content_type.as_deref());
310
311        out.push(MultipartEvent::PartStart {
312            field_name: field_name.clone().unwrap_or_default(),
313            file_name,
314            content_type,
315            kind: kind.clone(),
316        });
317
318        self.current_field_name = field_name;
319        self.current_body = Some(BufferMut::with_capacity(4 * 1024, kind));
320        self.state = State::Body;
321        Ok(true)
322    }
323
324    fn stream_body(
325        &mut self,
326        out: &mut Vec<MultipartEvent>,
327    ) -> Result<bool, MultipartError> {
328        let open_marker = self.boundary_marker(false);
329        let close_marker = self.boundary_marker(true);
330
331        // Find the earliest marker in the buffer.
332        let open_idx = find_subsequence(&self.buf, &open_marker);
333        let close_idx = find_subsequence(&self.buf, &close_marker);
334
335        let (boundary_idx, is_closing) = match (open_idx, close_idx) {
336            (None, None) => (None, false),
337            (Some(o), None) => (Some(o), false),
338            (None, Some(c)) => (Some(c), true),
339            (Some(o), Some(c)) => {
340                if c < o {
341                    (Some(c), true)
342                } else {
343                    (Some(o), false)
344                }
345            }
346        };
347
348        let Some(idx) = boundary_idx else {
349            // No boundary in buffer yet. Flush everything EXCEPT the
350            // trailing tail that might still become either:
351            //   · a boundary marker (up to `close_marker.len()` bytes), or
352            //   · the mandatory `\r\n` that RFC 7578 §4.1 requires
353            //     immediately before the boundary (body_end = idx - 2
354            //     trims those two bytes from the body, but only if
355            //     they are still in the buffer when the marker is
356            //     recognised).
357            //
358            // §Fase 12.c fix — the previous heuristic kept only
359            // `close_marker.len()` (or `open_marker.len()` when the
360            // buffer was smaller). In a byte-at-a-time feed that lost
361            // the `\r\n` preceding the boundary, emitting it as part
362            // of the body and leaving the boundary marker
363            // unrecognisable because the parser had already flushed
364            // the first one or two bytes of its prefix.
365            let keep = close_marker.len().max(open_marker.len()) + 2;
366            if self.buf.len() <= keep {
367                return Ok(false);
368            }
369            let take = self.buf.len() - keep;
370            let body = self
371                .current_body
372                .as_mut()
373                .expect("body builder missing in Body state");
374            if body.len() + take > self.limits.max_part_bytes {
375                return Err(MultipartError::PartTooLarge {
376                    limit: self.limits.max_part_bytes,
377                });
378            }
379            body.extend_from_slice(&self.buf[..take]);
380            self.buf.drain(..take);
381            return Ok(false);
382        };
383
384        // The body ends at `idx - 2` to trim the trailing `\r\n` that
385        // precedes every boundary (RFC 7578 §4.1).
386        let body_end = idx.saturating_sub(2);
387        {
388            let body = self
389                .current_body
390                .as_mut()
391                .expect("body builder missing in Body state");
392            if body.len() + body_end > self.limits.max_part_bytes {
393                return Err(MultipartError::PartTooLarge {
394                    limit: self.limits.max_part_bytes,
395                });
396            }
397            body.extend_from_slice(&self.buf[..body_end]);
398        }
399        let finished = self.current_body.take().unwrap();
400        let field = self.current_field_name.take().unwrap_or_default();
401        out.push(MultipartEvent::PartEnd {
402            field_name: field,
403            payload: finished.freeze(),
404        });
405
406        // Drain through the end of the matched marker.
407        let marker_len = if is_closing {
408            close_marker.len()
409        } else {
410            open_marker.len()
411        };
412        self.buf.drain(..idx + marker_len);
413
414        self.state = if is_closing {
415            out.push(MultipartEvent::Complete);
416            State::Terminated
417        } else {
418            State::Headers
419        };
420        Ok(true)
421    }
422
423    // ── Helpers ────────────────────────────────────────────────────
424
425    fn boundary_marker(&self, closing: bool) -> Vec<u8> {
426        let mut v = Vec::with_capacity(self.boundary.len() + 6);
427        v.extend_from_slice(b"--");
428        v.extend_from_slice(&self.boundary);
429        if closing {
430            v.extend_from_slice(b"--");
431        }
432        v.extend_from_slice(b"\r\n");
433        v
434    }
435}
436
437// ── Helpers shared with tests ────────────────────────────────────────
438
/// Extract the `boundary` parameter from a `Content-Type` header value.
///
/// The parameter name is matched ASCII case-insensitively (MIME
/// parameter names are case-insensitive, RFC 2045 §5.1). Whitespace
/// around the name and value is ignored, and surrounding double quotes
/// are stripped (`"` is never a valid boundary character). Returns the
/// first non-empty `boundary` value, or `None` when there is none.
pub fn parse_boundary_from_content_type(value: &str) -> Option<String> {
    value.split(';').find_map(|param| {
        let (name, raw) = param.split_once('=')?;
        if !name.trim().eq_ignore_ascii_case("boundary") {
            return None;
        }
        let unquoted = raw.trim().trim_matches('"');
        (!unquoted.is_empty()).then(|| unquoted.to_string())
    })
}
454
/// Byte offset of the first occurrence of `needle` in `haystack`.
///
/// An empty needle, or one longer than the haystack, never matches.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    match needle.len() {
        0 => None,
        width if width > haystack.len() => None,
        width => haystack.windows(width).position(|window| window == needle),
    }
}
463
/// Pull the `name` and `filename` parameters out of a part's
/// `Content-Disposition` header.
///
/// `headers` holds `(lower-cased name, value)` pairs. Parameter names are
/// matched ASCII case-insensitively. A `;` inside a double-quoted value
/// does not split parameters, so a file name such as `"a;b.jpg"` survives
/// intact. One pair of enclosing quotes is stripped from each value, and
/// backslashes are not treated as escapes. Returns `(None, None)` when the
/// header is missing. A later duplicate parameter overrides an earlier one.
fn disposition_field_and_file(
    headers: &[(String, String)],
) -> (Option<String>, Option<String>) {
    let Some((_, disp)) = headers
        .iter()
        .find(|(k, _)| k == "content-disposition")
    else {
        return (None, None);
    };

    let mut field_name: Option<String> = None;
    let mut file_name: Option<String> = None;
    let mut segment_start = 0;
    let mut in_quotes = false;
    // A sentinel `;` at the end flushes the final segment, even when a
    // quote was left unbalanced.
    let sentinel = std::iter::once((disp.len(), ';'));
    for (i, c) in disp.char_indices().chain(sentinel) {
        match c {
            '"' => in_quotes = !in_quotes,
            ';' if !in_quotes || i == disp.len() => {
                let segment = &disp[segment_start..i];
                segment_start = i + 1;
                let Some((key, value)) = segment.split_once('=') else {
                    // e.g. the leading `form-data` disposition type.
                    continue;
                };
                let value = value.trim();
                // Tolerate a missing closing quote on a truncated value.
                let value = value
                    .strip_prefix('"')
                    .map(|inner| inner.strip_suffix('"').unwrap_or(inner))
                    .unwrap_or(value);
                let key = key.trim();
                if key.eq_ignore_ascii_case("name") {
                    field_name = Some(value.to_string());
                } else if key.eq_ignore_ascii_case("filename") {
                    file_name = Some(value.to_string());
                }
            }
            _ => {}
        }
    }
    (field_name, file_name)
}
485
486fn kind_for_content_type(ct: Option<&str>) -> BufferKind {
487    let Some(ct) = ct else {
488        return BufferKind::raw();
489    };
490    let ct_low = ct.to_ascii_lowercase();
491    // Cheap prefix + keyword match. Adopters override on the
492    // returned BufferMut if they want a more specific tag.
493    if ct_low.starts_with("image/jpeg") {
494        BufferKind::jpeg()
495    } else if ct_low.starts_with("image/png") {
496        BufferKind::png()
497    } else if ct_low.starts_with("image/webp") {
498        BufferKind::webp()
499    } else if ct_low.starts_with("audio/mpeg") {
500        BufferKind::mp3()
501    } else if ct_low.starts_with("audio/opus") || ct_low.contains("ogg") {
502        BufferKind::opus()
503    } else if ct_low.starts_with("audio/wav") || ct_low.starts_with("audio/x-wav") {
504        BufferKind::wav()
505    } else if ct_low.starts_with("video/mp4") {
506        BufferKind::mp4()
507    } else if ct_low.starts_with("video/webm") {
508        BufferKind::webm()
509    } else if ct_low.starts_with("application/pdf") {
510        BufferKind::pdf()
511    } else if ct_low.contains("json") {
512        BufferKind::json()
513    } else if ct_low.contains("csv") {
514        BufferKind::csv()
515    } else {
516        BufferKind::raw()
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    fn build(body: &[u8]) -> Vec<u8> {
        body.to_vec()
    }

    #[test]
    fn parses_content_type_boundary() {
        assert_eq!(
            parse_boundary_from_content_type(
                "multipart/form-data; boundary=------abc"
            ),
            Some("------abc".to_string())
        );
        assert_eq!(
            parse_boundary_from_content_type(
                "multipart/form-data; boundary=\"quoted-boundary\""
            ),
            Some("quoted-boundary".to_string())
        );
        assert_eq!(
            parse_boundary_from_content_type("text/plain"),
            None
        );
    }

    #[test]
    fn single_text_part_roundtrip() {
        let body = build(b"\
            --abc\r\n\
            Content-Disposition: form-data; name=\"greeting\"\r\n\
            \r\n\
            hello world\r\n\
            --abc--\r\n");

        let mut p = MultipartParser::new("abc", MultipartLimits::default());
        let events = p.feed(&body).expect("parse");
        assert_eq!(events.len(), 3);
        match &events[0] {
            MultipartEvent::PartStart { field_name, kind, .. } => {
                assert_eq!(field_name, "greeting");
                assert_eq!(kind.slug(), "raw");
            }
            other => panic!("expected PartStart, got {other:?}"),
        }
        match &events[1] {
            MultipartEvent::PartEnd { field_name, payload } => {
                assert_eq!(field_name, "greeting");
                assert_eq!(payload.as_slice(), b"hello world");
            }
            other => panic!("expected PartEnd, got {other:?}"),
        }
        // `matches!` alone only evaluates to a bool; assert it.
        assert!(matches!(events[2], MultipartEvent::Complete));
    }

    #[test]
    fn two_parts_with_jpeg_content_type() {
        let body = build(b"\
            --bdy\r\n\
            Content-Disposition: form-data; name=\"field1\"\r\n\
            \r\n\
            value1\r\n\
            --bdy\r\n\
            Content-Disposition: form-data; name=\"image\"; filename=\"a.jpg\"\r\n\
            Content-Type: image/jpeg\r\n\
            \r\n\
            BINARYDATA\r\n\
            --bdy--\r\n");

        let mut p = MultipartParser::new("bdy", MultipartLimits::default());
        let evs = p.feed(&body).unwrap();
        // Two start/end pairs + complete.
        let kinds: Vec<_> = evs
            .iter()
            .filter_map(|e| match e {
                MultipartEvent::PartStart { kind, .. } => Some(kind.clone()),
                _ => None,
            })
            .collect();
        assert_eq!(kinds.len(), 2);
        assert_eq!(kinds[0].slug(), "raw");
        assert_eq!(kinds[1].slug(), "jpeg");

        let payloads: Vec<_> = evs
            .iter()
            .filter_map(|e| match e {
                MultipartEvent::PartEnd { payload, .. } => Some(payload.clone()),
                _ => None,
            })
            .collect();
        assert_eq!(payloads.len(), 2);
        assert_eq!(payloads[0].as_slice(), b"value1");
        assert_eq!(payloads[1].as_slice(), b"BINARYDATA");
        assert!(matches!(evs.last(), Some(MultipartEvent::Complete)));
    }

    #[test]
    fn chunked_feed_works_across_boundary_splits() {
        let body = b"\
            --z\r\n\
            Content-Disposition: form-data; name=\"n\"\r\n\
            \r\n\
            hello world\r\n\
            --z--\r\n";

        let mut p = MultipartParser::new("z", MultipartLimits::default());
        // Feed one byte at a time — the hardest case for streaming
        // parsers.
        let mut events = Vec::new();
        for byte in body {
            events.extend(p.feed(&[*byte]).unwrap());
        }
        let payloads: Vec<_> = events
            .iter()
            .filter_map(|e| match e {
                MultipartEvent::PartEnd { payload, .. } => Some(payload.clone()),
                _ => None,
            })
            .collect();
        assert_eq!(payloads.len(), 1);
        assert_eq!(payloads[0].as_slice(), b"hello world");
        assert!(matches!(events.last(), Some(MultipartEvent::Complete)));
    }

    #[test]
    fn header_too_large_errors() {
        let mut limits = MultipartLimits::default();
        limits.max_header_bytes = 32;
        let big_header = "Content-Disposition: form-data; name=\"".to_string()
            + &"x".repeat(200)
            + "\"";
        let body = format!(
            "--z\r\n{big_header}\r\n\r\nbody\r\n--z--\r\n"
        );
        let mut p = MultipartParser::new("z", limits);
        let err = p.feed(body.as_bytes()).unwrap_err();
        assert!(matches!(err, MultipartError::HeaderTooLarge { limit: 32 }));
    }

    #[test]
    fn part_too_large_errors() {
        let mut limits = MultipartLimits::default();
        limits.max_part_bytes = 16;
        let big = "x".repeat(1024);
        let body = format!(
            "--z\r\nContent-Disposition: form-data; name=\"n\"\r\n\r\n{big}\r\n--z--\r\n"
        );
        let mut p = MultipartParser::new("z", limits);
        let err = p.feed(body.as_bytes()).unwrap_err();
        assert!(matches!(err, MultipartError::PartTooLarge { limit: 16 }));
    }

    #[test]
    fn nested_multipart_rejected() {
        let body = b"\
            --z\r\n\
            Content-Disposition: form-data; name=\"n\"\r\n\
            Content-Type: multipart/mixed; boundary=inner\r\n\
            \r\n\
            data\r\n\
            --z--\r\n";
        let mut p = MultipartParser::new("z", MultipartLimits::default());
        let err = p.feed(body).unwrap_err();
        assert!(matches!(err, MultipartError::Nested));
    }
}