multipart_stream/
parser.rs

1// Copyright (C) 2021 Scott Lamb <slamb@slamb.org>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
// This implementation is gross (it's hard to read and copies data when not
// necessary), I think due to a combination of the following:
//
// 1.  the current state of Rust async: in particular, that there are no coroutines;
// 2.  my inexperience with Rust async;
// 3.  how quickly I threw this together.
//
// Fortunately the badness is hidden behind a decent interface, and there are decent tests
// of success cases with partial data. In the situations we're using it (small
// bits of metadata rather than video), the inefficiency probably doesn't matter.
// TODO: add more tests of bad inputs.
15
16//! Parses a [Bytes] stream into a [Part] stream.
17
18use crate::Part;
19use bytes::{Buf, Bytes, BytesMut};
20use futures::Stream;
21use http::header::{self, HeaderMap, HeaderName, HeaderValue};
22use httparse;
23use pin_project::pin_project;
24use std::pin::Pin;
25use std::task::{Context, Poll};
26
/// Error yielded by the [`Part`] stream.
///
/// Either the input bytes did not follow the expected multipart framing, or the
/// underlying [`Bytes`] stream itself reported an error.
#[derive(Debug)]
pub struct Error(ErrorInt);

/// Private representation behind [`Error`].
#[derive(Debug)]
enum ErrorInt {
    /// The input stream returned an error; boxed so [`Error`] is not generic over it.
    Underlying(Box<dyn std::error::Error + Send + Sync>),

    /// Malformed multipart input; carries a human-readable description.
    ParseError(String),
}
35
36impl std::fmt::Display for Error {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self.0 {
39            ErrorInt::ParseError(ref s) => f.pad(s),
40            ErrorInt::Underlying(ref e) => e.fmt(f),
41        }
42    }
43}
44
45impl std::error::Error for Error {
46    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
47        match &self.0 {
48            ErrorInt::Underlying(e) => Some(&**e),
49            _ => None,
50        }
51    }
52}
53
/// Builds an [`Error`] describing malformed input.
///
/// Takes exactly the arguments `format!` takes (a format string plus its arguments).
macro_rules! parse_err {
    ($($fmt_args:tt)*) => (
        Error(ErrorInt::ParseError(format!($($fmt_args)*)))
    );
}
60
/// Stream adapter that turns a stream of [`Bytes`] chunks into a stream of [`Part`]s.
///
/// Constructed by [`ParserBuilder::parse`] (and the free function [`parse`]), which return
/// it as an opaque `impl Stream`.
#[pin_project]
pub struct Parser<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    /// Source of raw input chunks. Structurally pinned (via `pin_project`) so `poll_next`
    /// can poll `S` through a `Pin` even if `S` is not `Unpin`.
    #[pin]
    input: S,

    /// The boundary with `--` prefix and `\r\n` suffix.
    boundary: Vec<u8>,

    /// Bytes received from `input` that `State::process` has not consumed yet.
    buf: BytesMut,

    /// Current position within the multipart framing.
    state: State,

    /// Header-size limit handed to `State::process`; see `ParserBuilder::max_header_bytes`.
    max_header_bytes: usize,

    /// Body-size limit handed to `State::process`; see `ParserBuilder::max_body_bytes`.
    max_body_bytes: usize,
}
77
/// Position of the parser within the multipart framing; advanced by `State::process`.
enum State {
    /// Consuming 0 or more `\r\n` pairs, advancing when encountering a byte that doesn't fit
    /// that pattern. This is the initial state and the state after each part's body; it is
    /// also the only state in which end of input is treated as a clean end of stream.
    Newlines,

    /// Waiting for the completion of a boundary.
    /// `pos` is how many bytes of the parser's `boundary` (`--`, the boundary, then `\r\n`)
    /// have matched so far; a boundary may be split across several input chunks.
    Boundary { pos: usize },

    /// Waiting for a full set of headers (terminated by an empty line).
    Headers,

    /// Waiting for a full body.
    /// `headers` are the part's parsed headers; `body_len` is its `Content-Length`, already
    /// checked against the body-size limit.
    Body { headers: HeaderMap, body_len: usize },

    /// The stream is finished (has already returned an error); `process` reports end of
    /// stream from here on.
    Done,
}
95
impl State {
    /// Advances the state machine as far as the bytes in `buf` allow, consuming them.
    ///
    /// Returns:
    /// *   `Ok(Poll::Ready(Some(part)))` when a part has been fully received;
    /// *   `Ok(Poll::Pending)` when more input is needed. No waker is registered here; it
    ///     only tells the caller to fetch another chunk from its input stream;
    /// *   `Ok(Poll::Ready(None))` once in [`State::Done`];
    /// *   `Err(_)` on malformed input or an exceeded limit.
    ///
    /// The `Result` is outermost (the reverse of `Stream`'s `Poll<Option<Result<..>>>`) so
    /// errors can be propagated with `?` and `return Err(parse_err!(..))`. The caller puts it
    /// back into the order expected by `Stream`.
    fn process(
        &mut self,
        boundary: &[u8],
        buf: &mut BytesMut,
        max_header_bytes: usize,
        max_body_bytes: usize,
    ) -> Result<Poll<Option<Part>>, Error> {
        'outer: loop {
            match self {
                State::Newlines => {
                    while buf.len() >= 2 {
                        if &buf[0..2] == b"\r\n" {
                            buf.advance(2);
                        } else {
                            *self = Self::Boundary { pos: 0 };
                            continue 'outer;
                        }
                    }
                    // A single remaining `\r` may be the start of a `\r\n` pair, so wait for
                    // more input; any other single byte must begin the boundary.
                    if buf.len() == 1 && buf[0] != b'\r' {
                        *self = Self::Boundary { pos: 0 };
                    } else {
                        return Ok(Poll::Pending);
                    }
                }
                State::Boundary { ref mut pos } => {
                    // Compare only the bytes available now; the rest of the boundary may
                    // arrive in a later chunk.
                    let len = std::cmp::min(boundary.len() - *pos, buf.len());
                    if buf[0..len] != boundary[*pos..*pos + len] {
                        return Err(parse_err!("bad boundary"));
                    }
                    buf.advance(len);
                    *pos += len;
                    if *pos < boundary.len() {
                        return Ok(Poll::Pending);
                    }
                    *self = State::Headers;
                }
                State::Headers => {
                    // At most 16 headers per part; httparse reports an error (mapped below)
                    // if a part has more.
                    let mut raw = [httparse::EMPTY_HEADER; 16];
                    let headers = httparse::parse_headers(&buf, &mut raw)
                        .map_err(|e| parse_err!("Part headers invalid: {}", e))?;
                    match headers {
                        httparse::Status::Complete((body_pos, raw)) => {
                            let mut headers = HeaderMap::with_capacity(raw.len());
                            for h in raw {
                                headers.append(
                                    HeaderName::from_bytes(h.name.as_bytes())
                                        .map_err(|_| parse_err!("bad header name"))?,
                                    HeaderValue::from_bytes(h.value)
                                        .map_err(|_| parse_err!("bad header value"))?,
                                );
                            }
                            // Drop the header block (including its terminating blank line)
                            // so `buf` now starts at the body.
                            buf.advance(body_pos);
                            let body_len: usize = headers
                                .get(header::CONTENT_LENGTH)
                                .ok_or_else(|| parse_err!("Missing part Content-Length"))?
                                .to_str()
                                .map_err(|_| parse_err!("Part Content-Length is not valid string"))?
                                .parse()
                                .map_err(|_| {
                                    parse_err!("Part Content-Length is not valid usize")
                                })?;
                            if body_len > max_body_bytes {
                                return Err(parse_err!(
                                    "body byte length {} exceeds maximum of {}",
                                    body_len,
                                    max_body_bytes
                                ));
                            }
                            *self = State::Body { headers, body_len };
                        }
                        httparse::Status::Partial => {
                            // The header limit is only enforced here, while waiting for more
                            // data; a header block that completes within `buf` is accepted
                            // regardless of its size.
                            if buf.len() >= max_header_bytes {
                                return Err(parse_err!(
                                    "incomplete {}-byte header, vs maximum of {} bytes",
                                    buf.len(),
                                    max_header_bytes
                                ));
                            }
                            return Ok(Poll::Pending);
                        }
                    }
                }
                State::Body { headers, body_len } => {
                    if buf.len() >= *body_len {
                        let body = buf.split_to(*body_len).freeze();
                        // Move the headers out without cloning; the empty placeholder is
                        // discarded by the state change on the next line.
                        let headers = std::mem::replace(headers, HeaderMap::new());
                        *self = State::Newlines;
                        return Ok(Poll::Ready(Some(Part { headers, body })));
                    }
                    return Ok(Poll::Pending);
                }
                State::Done => return Ok(Poll::Ready(None)),
            }
        }
    }
}
196
/// Configures and creates a parser that turns a multipart byte stream into [`Part`]s.
///
/// Start from [`ParserBuilder::new`] (no limits), optionally set limits, then call
/// [`ParserBuilder::parse`].
#[derive(Debug, Clone)]
pub struct ParserBuilder {
    /// Limit on a part's header block, in bytes; see [`ParserBuilder::max_header_bytes`] for
    /// exactly when it is enforced.
    max_header_bytes: usize,

    /// Largest accepted part `Content-Length`, in bytes.
    max_body_bytes: usize,
}
201
202impl ParserBuilder {
203    pub fn new() -> Self {
204        ParserBuilder {
205            max_header_bytes: usize::MAX,
206            max_body_bytes: usize::MAX,
207        }
208    }
209
210    /// Causes the parser to return error if the headers exceed this byte length.
211    /// Implementation note: currently this is only checked when about to wait for another chunk.
212    /// If a single chunk contains a complete header, it may be parsed successfully in spite of exceeding this length.
213    pub fn max_header_bytes(self, max_header_bytes: usize) -> Self {
214        ParserBuilder {
215            max_header_bytes,
216            ..self
217        }
218    }
219
220    /// Causes the parser to return error if the body exceeds this byte length.
221    pub fn max_body_bytes(self, max_body_bytes: usize) -> Self {
222        ParserBuilder {
223            max_body_bytes,
224            ..self
225        }
226    }
227
228    /// Parses a [Bytes] stream into a [Part] stream.
229    ///
230    /// `boundary` should be as in the `boundary` parameter of the `Content-Type` header.
231    pub fn parse<S, E>(self, input: S, boundary: &str) -> impl Stream<Item = Result<Part, Error>>
232    where
233        S: Stream<Item = Result<Bytes, E>>,
234        E: Into<Box<dyn std::error::Error + Send + Sync>>,
235    {
236        let boundary = {
237            let mut line = Vec::with_capacity(boundary.len() + 4);
238            line.extend_from_slice(b"--");
239            line.extend_from_slice(boundary.as_bytes());
240            line.extend_from_slice(b"\r\n");
241            line
242        };
243
244        Parser {
245            input,
246            buf: BytesMut::new(),
247            boundary,
248            state: State::Newlines,
249            max_header_bytes: self.max_header_bytes,
250            max_body_bytes: self.max_body_bytes,
251        }
252    }
253}
254
255/// Parses a [Bytes] stream into a [Part] stream.
256///
257/// `boundary` should be as in the `boundary` parameter of the `Content-Type` header.
258///
259/// This doesn't allow customizing the parser; use [ParserBuilder] instead if desired.
260pub fn parse<S, E>(input: S, boundary: &str) -> impl Stream<Item = Result<Part, Error>>
261where
262    S: Stream<Item = Result<Bytes, E>>,
263    E: Into<Box<dyn std::error::Error + Send + Sync>>,
264{
265    ParserBuilder::new().parse(input, boundary)
266}
267
268impl<S, E> Stream for Parser<S, E>
269where
270    S: Stream<Item = Result<Bytes, E>>,
271    E: Into<Box<dyn std::error::Error + Send + Sync>>,
272{
273    type Item = Result<Part, Error>;
274
275    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
276        let mut this = self.project();
277        loop {
278            match this.state.process(
279                &this.boundary,
280                this.buf,
281                *this.max_header_bytes,
282                *this.max_body_bytes,
283            ) {
284                Err(e) => {
285                    *this.state = State::Done;
286                    return Poll::Ready(Some(Err(e.into())));
287                }
288                Ok(Poll::Ready(Some(r))) => return Poll::Ready(Some(Ok(r))),
289                Ok(Poll::Ready(None)) => return Poll::Ready(None),
290                Ok(Poll::Pending) => {}
291            }
292            match this.input.as_mut().poll_next(ctx) {
293                Poll::Pending => return Poll::Pending,
294                Poll::Ready(None) => {
295                    if !matches!(*this.state, State::Newlines) {
296                        *this.state = State::Done;
297                        return Poll::Ready(Some(Err(parse_err!("unexpected mid-part EOF"))));
298                    }
299                    return Poll::Ready(None);
300                }
301                Poll::Ready(Some(Err(e))) => {
302                    *this.state = State::Done;
303                    return Poll::Ready(Some(Err(Error(ErrorInt::Underlying(e.into())))));
304                }
305                Poll::Ready(Some(Ok(b))) => {
306                    this.buf.extend_from_slice(&b);
307                }
308            };
309        }
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::{Error, ParserBuilder, Part};
    use bytes::Bytes;
    use futures::StreamExt;

    /// Tries parsing `input` with a stream that has chunks of different sizes arriving.
    /// This ensures that the "not enough data for the current state", "enough for the current state
    /// exactly", "enough for the current state and some for the next", and "enough for the next
    /// state (and beyond)" cases are exercised.
    async fn tester<F>(boundary: &str, input: &'static [u8], verify_parts: F)
    where
        F: Fn(Vec<Result<Part, Error>>),
    {
        tester_with(ParserBuilder::new, boundary, input, verify_parts).await;
    }

    /// Like `tester`, but builds the parser for each chunk size with `make_builder`, so tests
    /// can exercise non-default limits.
    async fn tester_with<B, F>(
        make_builder: B,
        boundary: &str,
        input: &'static [u8],
        verify_parts: F,
    ) where
        B: Fn() -> ParserBuilder,
        F: Fn(Vec<Result<Part, Error>>),
    {
        for chunk_size in &[1, 2, usize::MAX] {
            let input: Vec<Result<Bytes, std::convert::Infallible>> = input
                .chunks(*chunk_size)
                .map(|c: &[u8]| Ok(Bytes::from(c)))
                .collect();
            let input = futures::stream::iter(input);
            let parts = make_builder().parse(input, boundary);
            let output_stream: Vec<Result<Part, Error>> = parts.collect().await;
            verify_parts(output_stream);
        }
    }

    /// Returns a verifier asserting that parsing yields exactly one error, with message
    /// `expected`, and then ends.
    fn expect_single_error(expected: &'static str) -> impl Fn(Vec<Result<Part, Error>>) {
        move |mut parts: Vec<Result<Part, Error>>| {
            assert_eq!(parts.len(), 1);
            let err = parts.pop().unwrap().unwrap_err();
            assert_eq!(err.to_string(), expected);
        }
    }

    #[tokio::test]
    async fn truncated_header() {
        let input = "--boundary\r\nPartial-Header";
        let verify_parts = |mut parts: Vec<Result<Part, Error>>| {
            assert_eq!(parts.len(), 1);
            parts.pop().unwrap().unwrap_err();
        };
        tester("boundary", input.as_bytes(), verify_parts).await;
    }

    #[tokio::test]
    async fn truncated_data() {
        let input = "--boundary\r\nContent-Length: 42\r\n\r\n";
        let verify_parts = |mut parts: Vec<Result<Part, Error>>| {
            assert_eq!(parts.len(), 1);
            parts.pop().unwrap().unwrap_err();
        };
        tester("boundary", input.as_bytes(), verify_parts).await;
    }

    #[tokio::test]
    async fn bad_boundary() {
        let input = "--wrongboundary\r\nContent-Length: 1\r\n\r\nx";
        tester("boundary", input.as_bytes(), expect_single_error("bad boundary")).await;
    }

    #[tokio::test]
    async fn missing_content_length() {
        let input = "--boundary\r\nContent-Type: text/plain\r\n\r\nbody";
        tester(
            "boundary",
            input.as_bytes(),
            expect_single_error("Missing part Content-Length"),
        )
        .await;
    }

    #[tokio::test]
    async fn invalid_content_length() {
        let input = "--boundary\r\nContent-Length: abc\r\n\r\n";
        tester(
            "boundary",
            input.as_bytes(),
            expect_single_error("Part Content-Length is not valid usize"),
        )
        .await;
    }

    #[tokio::test]
    async fn body_exceeds_max() {
        let input = "--boundary\r\nContent-Length: 5\r\n\r\nhello";
        tester_with(
            || ParserBuilder::new().max_body_bytes(4),
            "boundary",
            input.as_bytes(),
            expect_single_error("body byte length 5 exceeds maximum of 4"),
        )
        .await;
    }

    #[tokio::test]
    async fn body_at_max() {
        let input = "--boundary\r\nContent-Length: 5\r\n\r\nhello";
        let verify_parts = |parts: Vec<Result<Part, Error>>| {
            assert_eq!(parts.len(), 1);
            let part = parts.into_iter().next().unwrap().unwrap();
            assert_eq!(&part.body[..], b"hello");
        };
        tester_with(
            || ParserBuilder::new().max_body_bytes(5),
            "boundary",
            input.as_bytes(),
            verify_parts,
        )
        .await;
    }

    #[tokio::test]
    async fn hikvision_style() {
        let input = concat!(
            "--boundary\r\n",
            "Content-Type: application/xml; charset=\"UTF-8\"\r\n",
            "Content-Length: 480\r\n",
            "\r\n",
            "<EventNotificationAlert version=\"1.0\" ",
            "xmlns=\"http://www.hikvision.com/ver10/XMLSchema\">\r\n",
            "<ipAddress>192.168.5.106</ipAddress>\r\n",
            "<portNo>80</portNo>\r\n",
            "<protocol>HTTP</protocol>\r\n",
            "<macAddress>8c:e7:48:da:94:8f</macAddress>\r\n",
            "<channelID>1</channelID>\r\n",
            "<dateTime>2019-02-20T15:22:34-8:00</dateTime>\r\n",
            "<activePostCount>0</activePostCount>\r\n",
            "<eventType>videoloss</eventType>\r\n",
            "<eventState>inactive</eventState>\r\n",
            "<eventDescription>videoloss alarm</eventDescription>\r\n",
            "</EventNotificationAlert>\r\n",
            "--boundary\r\n",
            "Content-Type: application/xml; charset=\"UTF-8\"\r\n",
            "Content-Length: 480\r\n",
            "\r\n",
            "<EventNotificationAlert version=\"1.0\" ",
            "xmlns=\"http://www.hikvision.com/ver10/XMLSchema\">\r\n",
            "<ipAddress>192.168.5.106</ipAddress>\r\n",
            "<portNo>80</portNo>\r\n",
            "<protocol>HTTP</protocol>\r\n",
            "<macAddress>8c:e7:48:da:94:8f</macAddress>\r\n",
            "<channelID>1</channelID>\r\n",
            "<dateTime>2019-02-20T15:22:34-8:00</dateTime>\r\n",
            "<activePostCount>0</activePostCount>\r\n",
            "<eventType>videoloss</eventType>\r\n",
            "<eventState>inactive</eventState>\r\n",
            "<eventDescription>videoloss alarm</eventDescription>\r\n",
            "</EventNotificationAlert>\r\n"
        );

        let verify_parts = |parts: Vec<Result<Part, Error>>| {
            let mut i = 0;
            for p in parts {
                let p = p.unwrap();
                assert_eq!(
                    p.headers
                        .get(http::header::CONTENT_TYPE)
                        .unwrap()
                        .to_str()
                        .unwrap(),
                    "application/xml; charset=\"UTF-8\""
                );
                assert!(p.body.starts_with(b"<EventNotificationAlert"));
                assert!(p.body.ends_with(b"</EventNotificationAlert>\r\n"));
                i += 1;
            }
            assert_eq!(i, 2);
        };
        tester("boundary", input.as_bytes(), verify_parts).await;
    }

    #[tokio::test]
    async fn dahua_style() {
        let input = concat!(
            "--myboundary\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length:135\r\n",
            "\r\n",
            "Code=TimeChange;action=Pulse;index=0;data={\n",
            "   \"BeforeModifyTime\" : \"2019-02-20 13:49:58\",\n",
            "   \"ModifiedTime\" : \"2019-02-20 13:49:58\"\n",
            "}\n",
            "\r\n",
            "\r\n",
            "--myboundary\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length:137\r\n",
            "\r\n",
            "Code=NTPAdjustTime;action=Pulse;index=0;data={\n",
            "   \"Address\" : \"192.168.5.254\",\n",
            "   \"Before\" : \"2019-02-20 13:49:57\",\n",
            "   \"result\" : true\n",
            "}\n\r\n"
        );
        let verify_parts = |parts: Vec<Result<Part, Error>>| {
            let mut i = 0;
            for p in parts {
                let p = p.unwrap();
                assert_eq!(
                    p.headers
                        .get(http::header::CONTENT_TYPE)
                        .unwrap()
                        .to_str()
                        .unwrap(),
                    "text/plain"
                );
                match i {
                    0 => assert!(p.body.starts_with(b"Code=TimeChange")),
                    1 => assert!(p.body.starts_with(b"Code=NTPAdjustTime")),
                    _ => unreachable!(),
                }
                i += 1;
            }
            assert_eq!(i, 2);
        };
        tester("myboundary", input.as_bytes(), verify_parts).await;
    }

    #[tokio::test]
    async fn dahua_heartbeat() {
        // Dahua event streams have a `heartbeat` parameter which sends messages like the one below.
        // The newlines are before the part, rather than after the part as in other messages.
        // The heartbeat is sometimes the first message in the stream. We need to allow initial
        // newlines to avoid erroring in this case.
        let input = concat!(
            "\r\n--myboundary\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length:9\r\n\r\n",
            "Heartbeat"
        );
        let verify_parts = |parts: Vec<Result<Part, Error>>| {
            let mut i = 0;
            for p in parts {
                let p = p.unwrap();
                assert_eq!(&p.body[..], b"Heartbeat");
                i += 1;
            }
            assert_eq!(i, 1);
        };
        tester("myboundary", input.as_bytes(), verify_parts).await;
    }
}