1use crate::Part;
19use bytes::{Buf, Bytes, BytesMut};
20use futures::Stream;
21use http::header::{self, HeaderMap, HeaderName, HeaderValue};
22use httparse;
23use pin_project::pin_project;
24use std::pin::Pin;
25use std::task::{Context, Poll};
26
/// Error produced while parsing a multipart stream.
///
/// It either describes malformed multipart input, or wraps an error yielded by the
/// underlying byte stream. In the latter case the wrapped error is also available via
/// [`std::error::Error::source`].
#[derive(Debug)]
pub struct Error(ErrorInt);

/// Private representation behind [`Error`], kept out of the public API.
#[derive(Debug)]
enum ErrorInt {
    /// The input did not follow the expected layout; holds a human-readable description.
    ParseError(String),

    /// The input stream itself yielded this error.
    Underlying(Box<dyn std::error::Error + Send + Sync>),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.0 {
            // `pad` honours any width/alignment flags the caller supplied.
            ErrorInt::ParseError(msg) => f.pad(msg),
            // Let the wrapped error format itself.
            ErrorInt::Underlying(inner) => inner.fmt(f),
        }
    }
}

impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match &self.0 {
            // Parse errors originate here; there is no deeper cause to report.
            ErrorInt::ParseError(_) => None,
            ErrorInt::Underlying(inner) => Some(&**inner),
        }
    }
}
53
/// Builds an [`Error`] describing malformed input from `format!`-style arguments.
macro_rules! parse_err {
    ($($fmt_args:tt)*) => {
        Error(ErrorInt::ParseError(format!($($fmt_args)*)))
    };
}
60
61#[pin_project]
62pub struct Parser<S, E>
63where
64 S: Stream<Item = Result<Bytes, E>>,
65 E: Into<Box<dyn std::error::Error + Send + Sync>>,
66{
67 #[pin]
68 input: S,
69
70 boundary: Vec<u8>,
72 buf: BytesMut,
73 state: State,
74 max_header_bytes: usize,
75 max_body_bytes: usize,
76}
77
78enum State {
79 Newlines,
81
82 Boundary { pos: usize },
85
86 Headers,
88
89 Body { headers: HeaderMap, body_len: usize },
91
92 Done,
94}
95
96impl State {
97 fn process(
101 &mut self,
102 boundary: &[u8],
103 buf: &mut BytesMut,
104 max_header_bytes: usize,
105 max_body_bytes: usize,
106 ) -> Result<Poll<Option<Part>>, Error> {
107 'outer: loop {
108 match self {
109 State::Newlines => {
110 while buf.len() >= 2 {
111 if &buf[0..2] == b"\r\n" {
112 buf.advance(2);
113 } else {
114 *self = Self::Boundary { pos: 0 };
115 continue 'outer;
116 }
117 }
118 if buf.len() == 1 && buf[0] != b'\r' {
119 *self = Self::Boundary { pos: 0 };
120 } else {
121 return Ok(Poll::Pending);
122 }
123 }
124 State::Boundary { ref mut pos } => {
125 let len = std::cmp::min(boundary.len() - *pos, buf.len());
126 if buf[0..len] != boundary[*pos..*pos + len] {
127 return Err(parse_err!("bad boundary"));
128 }
129 buf.advance(len);
130 *pos += len;
131 if *pos < boundary.len() {
132 return Ok(Poll::Pending);
133 }
134 *self = State::Headers;
135 }
136 State::Headers => {
137 let mut raw = [httparse::EMPTY_HEADER; 16];
138 let headers = httparse::parse_headers(&buf, &mut raw)
139 .map_err(|e| parse_err!("Part headers invalid: {}", e))?;
140 match headers {
141 httparse::Status::Complete((body_pos, raw)) => {
142 let mut headers = HeaderMap::with_capacity(raw.len());
143 for h in raw {
144 headers.append(
145 HeaderName::from_bytes(h.name.as_bytes())
146 .map_err(|_| parse_err!("bad header name"))?,
147 HeaderValue::from_bytes(h.value)
148 .map_err(|_| parse_err!("bad header value"))?,
149 );
150 }
151 buf.advance(body_pos);
152 let body_len: usize = headers
153 .get(header::CONTENT_LENGTH)
154 .ok_or_else(|| parse_err!("Missing part Content-Length"))?
155 .to_str()
156 .map_err(|_| parse_err!("Part Content-Length is not valid string"))?
157 .parse()
158 .map_err(|_| {
159 parse_err!("Part Content-Length is not valid usize")
160 })?;
161 if body_len > max_body_bytes {
162 return Err(parse_err!(
163 "body byte length {} exceeds maximum of {}",
164 body_len,
165 max_body_bytes
166 ));
167 }
168 *self = State::Body { headers, body_len };
169 }
170 httparse::Status::Partial => {
171 if buf.len() >= max_header_bytes {
172 return Err(parse_err!(
173 "incomplete {}-byte header, vs maximum of {} bytes",
174 buf.len(),
175 max_header_bytes
176 ));
177 }
178 return Ok(Poll::Pending);
179 }
180 }
181 }
182 State::Body { headers, body_len } => {
183 if buf.len() >= *body_len {
184 let body = buf.split_to(*body_len).freeze();
185 let headers = std::mem::replace(headers, HeaderMap::new());
186 *self = State::Newlines;
187 return Ok(Poll::Ready(Some(Part { headers, body })));
188 }
189 return Ok(Poll::Pending);
190 }
191 State::Done => return Ok(Poll::Ready(None)),
192 }
193 }
194 }
195}
196
/// Configures and creates a multipart parser.
///
/// `ParserBuilder::new` sets both size limits to `usize::MAX`.
#[derive(Clone, Debug)]
pub struct ParserBuilder {
    /// Limit on the size of each part's header block, in bytes.
    max_header_bytes: usize,
    /// Largest `Content-Length` accepted for a part body.
    max_body_bytes: usize,
}
201
202impl ParserBuilder {
203 pub fn new() -> Self {
204 ParserBuilder {
205 max_header_bytes: usize::MAX,
206 max_body_bytes: usize::MAX,
207 }
208 }
209
210 pub fn max_header_bytes(self, max_header_bytes: usize) -> Self {
214 ParserBuilder {
215 max_header_bytes,
216 ..self
217 }
218 }
219
220 pub fn max_body_bytes(self, max_body_bytes: usize) -> Self {
222 ParserBuilder {
223 max_body_bytes,
224 ..self
225 }
226 }
227
228 pub fn parse<S, E>(self, input: S, boundary: &str) -> impl Stream<Item = Result<Part, Error>>
232 where
233 S: Stream<Item = Result<Bytes, E>>,
234 E: Into<Box<dyn std::error::Error + Send + Sync>>,
235 {
236 let boundary = {
237 let mut line = Vec::with_capacity(boundary.len() + 4);
238 line.extend_from_slice(b"--");
239 line.extend_from_slice(boundary.as_bytes());
240 line.extend_from_slice(b"\r\n");
241 line
242 };
243
244 Parser {
245 input,
246 buf: BytesMut::new(),
247 boundary,
248 state: State::Newlines,
249 max_header_bytes: self.max_header_bytes,
250 max_body_bytes: self.max_body_bytes,
251 }
252 }
253}
254
255pub fn parse<S, E>(input: S, boundary: &str) -> impl Stream<Item = Result<Part, Error>>
261where
262 S: Stream<Item = Result<Bytes, E>>,
263 E: Into<Box<dyn std::error::Error + Send + Sync>>,
264{
265 ParserBuilder::new().parse(input, boundary)
266}
267
268impl<S, E> Stream for Parser<S, E>
269where
270 S: Stream<Item = Result<Bytes, E>>,
271 E: Into<Box<dyn std::error::Error + Send + Sync>>,
272{
273 type Item = Result<Part, Error>;
274
275 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
276 let mut this = self.project();
277 loop {
278 match this.state.process(
279 &this.boundary,
280 this.buf,
281 *this.max_header_bytes,
282 *this.max_body_bytes,
283 ) {
284 Err(e) => {
285 *this.state = State::Done;
286 return Poll::Ready(Some(Err(e.into())));
287 }
288 Ok(Poll::Ready(Some(r))) => return Poll::Ready(Some(Ok(r))),
289 Ok(Poll::Ready(None)) => return Poll::Ready(None),
290 Ok(Poll::Pending) => {}
291 }
292 match this.input.as_mut().poll_next(ctx) {
293 Poll::Pending => return Poll::Pending,
294 Poll::Ready(None) => {
295 if !matches!(*this.state, State::Newlines) {
296 *this.state = State::Done;
297 return Poll::Ready(Some(Err(parse_err!("unexpected mid-part EOF"))));
298 }
299 return Poll::Ready(None);
300 }
301 Poll::Ready(Some(Err(e))) => {
302 *this.state = State::Done;
303 return Poll::Ready(Some(Err(Error(ErrorInt::Underlying(e.into())))));
304 }
305 Poll::Ready(Some(Ok(b))) => {
306 this.buf.extend_from_slice(&b);
307 }
308 };
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::{Error, ParserBuilder, Part};
    use bytes::Bytes;
    use futures::StreamExt;

    /// Runs the parser over `input` split into 1-byte, 2-byte and single-chunk pieces.
    ///
    /// Each run's full output is passed to `verify_parts`, which checks that results do
    /// not depend on how the input is split.
    async fn tester<F>(boundary: &str, input: &'static [u8], verify_parts: F)
    where
        F: Fn(Vec<Result<Part, Error>>),
    {
        let chunk_sizes: [usize; 3] = [1, 2, usize::MAX];
        for &size in chunk_sizes.iter() {
            let chunks: Vec<Result<Bytes, std::convert::Infallible>> = input
                .chunks(size)
                .map(|c: &[u8]| Ok(Bytes::from(c)))
                .collect();
            let parts = ParserBuilder::new().parse(futures::stream::iter(chunks), boundary);
            let collected: Vec<Result<Part, Error>> = parts.collect().await;
            verify_parts(collected);
        }
    }

    #[tokio::test]
    async fn truncated_header() {
        // Input ends in the middle of the header block: expect exactly one error.
        let input = "--boundary\r\nPartial-Header";
        tester("boundary", input.as_bytes(), |parts: Vec<Result<Part, Error>>| {
            assert!(matches!(parts.as_slice(), [Err(_)]));
        })
        .await;
    }

    #[tokio::test]
    async fn truncated_data() {
        // Headers promise 42 body bytes that never arrive: expect exactly one error.
        let input = "--boundary\r\nContent-Length: 42\r\n\r\n";
        tester("boundary", input.as_bytes(), |parts: Vec<Result<Part, Error>>| {
            assert!(matches!(parts.as_slice(), [Err(_)]));
        })
        .await;
    }

    #[tokio::test]
    async fn hikvision_style() {
        let input = concat!(
            "--boundary\r\n",
            "Content-Type: application/xml; charset=\"UTF-8\"\r\n",
            "Content-Length: 480\r\n",
            "\r\n",
            "<EventNotificationAlert version=\"1.0\" ",
            "xmlns=\"http://www.hikvision.com/ver10/XMLSchema\">\r\n",
            "<ipAddress>192.168.5.106</ipAddress>\r\n",
            "<portNo>80</portNo>\r\n",
            "<protocol>HTTP</protocol>\r\n",
            "<macAddress>8c:e7:48:da:94:8f</macAddress>\r\n",
            "<channelID>1</channelID>\r\n",
            "<dateTime>2019-02-20T15:22:34-8:00</dateTime>\r\n",
            "<activePostCount>0</activePostCount>\r\n",
            "<eventType>videoloss</eventType>\r\n",
            "<eventState>inactive</eventState>\r\n",
            "<eventDescription>videoloss alarm</eventDescription>\r\n",
            "</EventNotificationAlert>\r\n",
            "--boundary\r\n",
            "Content-Type: application/xml; charset=\"UTF-8\"\r\n",
            "Content-Length: 480\r\n",
            "\r\n",
            "<EventNotificationAlert version=\"1.0\" ",
            "xmlns=\"http://www.hikvision.com/ver10/XMLSchema\">\r\n",
            "<ipAddress>192.168.5.106</ipAddress>\r\n",
            "<portNo>80</portNo>\r\n",
            "<protocol>HTTP</protocol>\r\n",
            "<macAddress>8c:e7:48:da:94:8f</macAddress>\r\n",
            "<channelID>1</channelID>\r\n",
            "<dateTime>2019-02-20T15:22:34-8:00</dateTime>\r\n",
            "<activePostCount>0</activePostCount>\r\n",
            "<eventType>videoloss</eventType>\r\n",
            "<eventState>inactive</eventState>\r\n",
            "<eventDescription>videoloss alarm</eventDescription>\r\n",
            "</EventNotificationAlert>\r\n"
        );

        let check = |parts: Vec<Result<Part, Error>>| {
            let mut seen = 0;
            for part in parts {
                let part = part.unwrap();
                let content_type = part
                    .headers
                    .get(http::header::CONTENT_TYPE)
                    .unwrap()
                    .to_str()
                    .unwrap();
                assert_eq!(content_type, "application/xml; charset=\"UTF-8\"");
                assert!(part.body.starts_with(b"<EventNotificationAlert"));
                assert!(part.body.ends_with(b"</EventNotificationAlert>\r\n"));
                seen += 1;
            }
            assert_eq!(seen, 2);
        };
        tester("boundary", input.as_bytes(), check).await;
    }

    #[tokio::test]
    async fn dahua_style() {
        let input = concat!(
            "--myboundary\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length:135\r\n",
            "\r\n",
            "Code=TimeChange;action=Pulse;index=0;data={\n",
            "   \"BeforeModifyTime\" : \"2019-02-20 13:49:58\",\n",
            "   \"ModifiedTime\" : \"2019-02-20 13:49:58\"\n",
            "}\n",
            "\r\n",
            "\r\n",
            "--myboundary\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length:137\r\n",
            "\r\n",
            "Code=NTPAdjustTime;action=Pulse;index=0;data={\n",
            "   \"Address\" : \"192.168.5.254\",\n",
            "   \"Before\" : \"2019-02-20 13:49:57\",\n",
            "   \"result\" : true\n",
            "}\n\r\n"
        );
        let check = |parts: Vec<Result<Part, Error>>| {
            let mut seen = 0;
            for (index, part) in parts.into_iter().enumerate() {
                let part = part.unwrap();
                let content_type = part
                    .headers
                    .get(http::header::CONTENT_TYPE)
                    .unwrap()
                    .to_str()
                    .unwrap();
                assert_eq!(content_type, "text/plain");
                let expected_prefix: &[u8] = match index {
                    0 => b"Code=TimeChange",
                    1 => b"Code=NTPAdjustTime",
                    _ => unreachable!(),
                };
                assert!(part.body.starts_with(expected_prefix));
                seen += 1;
            }
            assert_eq!(seen, 2);
        };
        tester("myboundary", input.as_bytes(), check).await;
    }

    #[tokio::test]
    async fn dahua_heartbeat() {
        // A single heartbeat part, preceded by a stray CRLF before the boundary.
        let input = concat!(
            "\r\n--myboundary\r\n",
            "Content-Type: text/plain\r\n",
            "Content-Length:9\r\n\r\n",
            "Heartbeat"
        );
        let check = |parts: Vec<Result<Part, Error>>| {
            let mut seen = 0;
            for part in parts {
                assert_eq!(&part.unwrap().body[..], b"Heartbeat");
                seen += 1;
            }
            assert_eq!(seen, 1);
        };
        tester("myboundary", input.as_bytes(), check).await;
    }
}