// min_http11_parser/parser.rs

1use memchr::memchr_iter;
2use min_http11_core::error::{Error, Result};
3use min_http11_core::hash::hash;
4use min_http11_core::method::Method;
5
6#[cfg(not(feature = "_minimal"))]
7use min_http11_core::request::{
8    ACCEPT, ACCEPT_ENCODING, ACCEPT_ENCODING_HASH, ACCEPT_HASH, ACCEPT_LANGUAGE,
9    ACCEPT_LANGUAGE_HASH, ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_HEADERS_HASH,
10    ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_METHOD_HASH, AUTHORIZATION,
11    AUTHORIZATION_HASH, CONNECTION, CONNECTION_HASH, CONTENT_ENCODING, CONTENT_ENCODING_HASH,
12    CONTENT_TYPE, CONTENT_TYPE_HASH, COOKIE, COOKIE_HASH, EXPECT, EXPECT_HASH, IF_MODIFIED_SINCE,
13    IF_MODIFIED_SINCE_HASH, IF_RANGE, IF_RANGE_HASH, IF_UNMODIFIED_SINCE, IF_UNMODIFIED_SINCE_HASH,
14    ORIGIN, ORIGIN_HASH, RANGE, RANGE_HASH, USER_AGENT, USER_AGENT_HASH, X_CSRF_TOKEN,
15    X_CSRF_TOKEN_HASH, X_FORWARDED_FOR, X_FORWARDED_FOR_HASH, X_FORWARDED_HOST,
16    X_FORWARDED_HOST_HASH, X_REAL_IP, X_REAL_IP_HASH,
17};
18use min_http11_core::request::{
19    CONTENT_LENGTH, CONTENT_LENGTH_HASH, HOST, HOST_HASH, HeaderName, IF_MATCH, IF_MATCH_HASH,
20    IF_NONE_MATCH, IF_NONE_MATCH_HASH, KnownHeaders, TRANSFER_ENCODING, TRANSFER_ENCODING_HASH,
21    X_HUB_SIGNATURE_256, X_HUB_SIGNATURE_256_HASH,
22};
23use min_http11_core::util::AsciiLowercaseTestExt;
24use min_http11_core::version::Version;
25use read_until_slice::AsyncBufReadUntilSliceExt;
26use std::collections::BTreeMap;
27use std::time::Duration;
28use tokio::io::{AsyncBufRead, AsyncReadExt};
29use tokio::time::timeout;
30use tracing::debug;
31
32pub struct Request<'a> {
33    pub method: Method,
34    pub path: &'a [u8],
35    pub known_headers: KnownHeaders<'a>,
36    pub custom_headers: Option<BTreeMap<&'static [u8], &'a [u8]>>,
37    pub body: Option<&'a [u8]>,
38}
39
/// Configuration for reading and parsing an HTTP request line and header block.
///
/// Build one with [`Default`] and adjust it through the `with_*` methods.
#[derive(Debug, Clone)]
pub struct Parser {
    /// Time budget for reading and parsing the request line.
    request_line_read_timeout: Duration,
    /// Time budget for reading and parsing the header block.
    headers_read_timeout: Duration,
    /// Maximum number of bytes read while looking for the end of the request line.
    request_line_max_size: u64,
    /// Maximum number of bytes read while looking for the end of the header block.
    headers_max_size: u64,
    /// Additional header names to collect, keyed by the hash of the name.
    /// `None` until at least one such header has been registered.
    other_headers: Option<BTreeMap<u32, &'static [u8]>>,
}
47
48impl Default for Parser {
49    fn default() -> Self {
50        Self {
51            request_line_read_timeout: Duration::from_millis(200_u64),
52            headers_read_timeout: Duration::from_millis(200_u64),
53            request_line_max_size: 4_096_u64,
54            headers_max_size: 16_384_u64,
55            other_headers: None,
56        }
57    }
58}
59
60impl Parser {
61    pub fn with_header(self, header_name: &'static [u8]) -> Result<Self> {
62        let header_name = HeaderName::try_from_static(header_name)?;
63        Ok(match header_name {
64            HeaderName::Other(key) => {
65                let mut custom_headers = self.other_headers.unwrap_or_default();
66                custom_headers.insert(hash(key), key);
67                Parser {
68                    other_headers: Some(custom_headers),
69                    ..self
70                }
71            }
72            _ => self,
73        })
74    }
75    pub fn with_request_line_read_timeout(self, timeout: Duration) -> Self {
76        Parser {
77            request_line_read_timeout: timeout,
78            ..self
79        }
80    }
81    pub fn with_headers_read_timeout(self, timeout: Duration) -> Self {
82        Parser {
83            headers_read_timeout: timeout,
84            ..self
85        }
86    }
87    pub fn with_request_line_max_size(self, size: u16) -> Self {
88        Parser {
89            request_line_max_size: size as u64,
90            ..self
91        }
92    }
93    pub fn with_headers_max_size(self, size: u16) -> Self {
94        Parser {
95            headers_max_size: size as u64,
96            ..self
97        }
98    }
99}
100
/// Separator between the parts of the request line and after a header's colon.
const SPACE: u8 = b' ';
/// Byte that ends a header name.
const COLON: u8 = b':';
/// Carriage return.
const CR: u8 = b'\r';
/// Line feed.
const LF: u8 = b'\n';
/// Line terminator.
const CRLF: &[u8] = &[CR, LF];
/// Terminator of the header block: the last header line's CRLF plus an empty line.
const CRLF_CRLF: &[u8] = &[CR, LF, CR, LF];
107
108impl Parser {
109    pub async fn parse_request_line<'a, 'b, 'c: 'a>(
110        &'a self,
111        reader: &'b mut (impl AsyncBufRead + Unpin),
112        buffer: &'c mut Vec<u8>,
113    ) -> Result<(Method, &'c [u8])> {
114        let request_line = timeout(self.request_line_read_timeout, async {
115            let mut reader = reader.take(self.request_line_max_size);
116            let n = reader.read_until_slice(CRLF, buffer).await?;
117            if n == 0 {
118                return Err(Error::UnexpectedEndOfFile);
119            }
120            let buffer = &buffer[buffer.len() - n..];
121            if !buffer.ends_with(CRLF) {
122                return Err(Error::UnexpectedEndOfFile);
123            }
124            let request_line = &buffer[..buffer.len() - 2];
125            debug!("{}", request_line.escape_ascii());
126            let mut iter = memchr_iter(SPACE, request_line);
127            let first = iter.next().ok_or(Error::BadRequest)?;
128            let second = iter.next().ok_or(Error::BadRequest)?;
129            if iter.next().is_some() {
130                return Err(Error::BadRequest);
131            }
132            let method = Method::try_from(&request_line[0..first])?;
133            let path = &request_line[first + 1..second];
134            let _ = Version::try_from(&request_line[second + 1..])?;
135            Ok((method, path))
136        })
137        .await;
138        let request_line = match request_line {
139            Err(_) => Err(Error::ReadTimeout)?,
140            Ok(result) => result?,
141        };
142        Ok(request_line)
143    }
144
145    pub async fn parse_headers<'a, 'b, 'c: 'a>(
146        &'a self,
147        reader: &'b mut (impl AsyncBufRead + Unpin),
148        buffer: &'c mut Vec<u8>,
149    ) -> Result<(KnownHeaders<'a>, Option<BTreeMap<&'static [u8], &'a [u8]>>)> {
150        let headers = timeout(self.headers_read_timeout, async {
151            let mut reader = reader.take(self.headers_max_size);
152            let n = reader.read_until_slice(CRLF_CRLF, buffer).await?;
153            if n == 0 {
154                return Err(Error::UnexpectedEndOfFile);
155            }
156            let buffer = &buffer[buffer.len() - n..];
157            if !buffer.ends_with(CRLF_CRLF) {
158                return Err(Error::UnexpectedEndOfFile);
159            }
160            let buffer = &buffer[..buffer.len() - 2];
161            let mut known_headers = KnownHeaders::default();
162            let mut custom_headers = None;
163            if !buffer.is_empty() {
164                let mut iter = memchr_iter(CR, buffer);
165                let mut i = 0;
166                while i < buffer.len() {
167                    let j = loop {
168                        if let Some(i) = iter.next() {
169                            if i + 1 < buffer.len() && buffer[i + 1] == LF {
170                                break i;
171                            }
172                        } else {
173                            return Err(Error::UnexpectedEndOfFile);
174                        }
175                    };
176                    let line = &buffer[i..j];
177                    i = j + 2;
178                    let mut iter = memchr_iter(SPACE, line);
179                    let first = iter.next().ok_or(Error::BadRequest)?;
180                    if first == 0 || line[first - 1] != COLON {
181                        return Err(Error::BadRequest);
182                    }
183                    let key = &line[..first - 1];
184                    let value = &line[first + 1..];
185                    debug!("{}: {}", key.escape_ascii(), value.escape_ascii());
186                    if key.is_ascii_lowercase() {
187                        let h = hash(key);
188                        let res = _with_known_header(known_headers, h, key, value)?;
189                        known_headers = res.0;
190                        if res.1 {
191                            custom_headers = _with_other_header(
192                                custom_headers,
193                                &self.other_headers,
194                                h,
195                                key,
196                                value,
197                            );
198                        }
199                    } else {
200                        let key = key.to_ascii_lowercase();
201                        let h = hash(&key);
202                        let res = _with_known_header(known_headers, h, &key, value)?;
203                        known_headers = res.0;
204                        if res.1 {
205                            custom_headers = _with_other_header(
206                                custom_headers,
207                                &self.other_headers,
208                                h,
209                                &key,
210                                value,
211                            );
212                        }
213                    };
214                }
215            }
216            Ok((known_headers, custom_headers))
217        })
218        .await;
219        let request_line = match headers {
220            Err(_) => Err(Error::ReadTimeout)?,
221            Ok(result) => result?,
222        };
223        Ok(request_line)
224    }
225}
226
227fn _with_known_header<'a>(
228    mut known_headers: KnownHeaders<'a>,
229    hash: u32,
230    lowercase_key: &[u8],
231    value: &'a [u8],
232) -> Result<(KnownHeaders<'a>, bool)> {
233    macro_rules! set_once {
234        ($field:ident) => {
235            if known_headers.$field.is_some() {
236                return Err(Error::BadRequest);
237            }
238            known_headers.$field = Some(value);
239        };
240    }
241    #[cfg(feature = "_minimal")]
242    match hash {
243        CONTENT_LENGTH_HASH if lowercase_key == CONTENT_LENGTH => {
244            if known_headers.transfer_encoding.is_some() {
245                return Err(Error::BadRequest);
246            }
247            set_once!(content_length);
248        }
249        HOST_HASH if lowercase_key == HOST => {
250            set_once!(host);
251        }
252        TRANSFER_ENCODING_HASH if lowercase_key == TRANSFER_ENCODING => {
253            if known_headers.content_length.is_some() {
254                return Err(Error::BadRequest);
255            }
256            set_once!(host);
257        }
258        IF_MATCH_HASH if lowercase_key == IF_MATCH => {
259            set_once!(if_match);
260        }
261        IF_NONE_MATCH_HASH if lowercase_key == IF_NONE_MATCH => {
262            set_once!(if_none_match);
263        }
264        X_HUB_SIGNATURE_256_HASH if lowercase_key == X_HUB_SIGNATURE_256 => {
265            set_once!(x_hub_signature_256_hash);
266        }
267        _ => return Ok((known_headers, true)),
268    }
269    #[cfg(not(feature = "_minimal"))]
270    match hash {
271        ACCEPT_HASH if lowercase_key == ACCEPT => {
272            set_once!(accept);
273        }
274        ACCEPT_ENCODING_HASH if lowercase_key == ACCEPT_ENCODING => {
275            set_once!(accept_encoding);
276        }
277        ACCEPT_LANGUAGE_HASH if lowercase_key == ACCEPT_LANGUAGE => {
278            set_once!(accept_language);
279        }
280        ACCESS_CONTROL_REQUEST_HEADERS_HASH if lowercase_key == ACCESS_CONTROL_REQUEST_HEADERS => {
281            set_once!(access_control_request_headers);
282        }
283        ACCESS_CONTROL_REQUEST_METHOD_HASH if lowercase_key == ACCESS_CONTROL_REQUEST_METHOD => {
284            set_once!(access_control_request_method);
285        }
286        AUTHORIZATION_HASH if lowercase_key == AUTHORIZATION => {
287            set_once!(authorization);
288        }
289        CONNECTION_HASH if lowercase_key == CONNECTION => {
290            set_once!(connection);
291        }
292        CONTENT_ENCODING_HASH if lowercase_key == CONTENT_ENCODING => {
293            set_once!(content_encoding);
294        }
295        CONTENT_LENGTH_HASH if lowercase_key == CONTENT_LENGTH => {
296            if known_headers.transfer_encoding.is_some() {
297                return Err(Error::BadRequest);
298            }
299            set_once!(content_length);
300        }
301        CONTENT_TYPE_HASH if lowercase_key == CONTENT_TYPE => {
302            set_once!(content_type);
303        }
304        COOKIE_HASH if lowercase_key == COOKIE => {
305            set_once!(cookie);
306        }
307        EXPECT_HASH if lowercase_key == EXPECT => {
308            set_once!(expect);
309        }
310        HOST_HASH if lowercase_key == HOST => {
311            set_once!(host);
312        }
313        IF_MATCH_HASH if lowercase_key == IF_MATCH => {
314            set_once!(if_match);
315        }
316        IF_MODIFIED_SINCE_HASH if lowercase_key == IF_MODIFIED_SINCE => {
317            set_once!(if_modified_since);
318        }
319        IF_NONE_MATCH_HASH if lowercase_key == IF_NONE_MATCH => {
320            set_once!(if_none_match);
321        }
322        IF_RANGE_HASH if lowercase_key == IF_RANGE => {
323            set_once!(if_range);
324        }
325        IF_UNMODIFIED_SINCE_HASH if lowercase_key == IF_UNMODIFIED_SINCE => {
326            set_once!(if_unmodified_since);
327        }
328        ORIGIN_HASH if lowercase_key == ORIGIN => {
329            set_once!(origin);
330        }
331        RANGE_HASH if lowercase_key == RANGE => {
332            set_once!(range);
333        }
334        TRANSFER_ENCODING_HASH if lowercase_key == TRANSFER_ENCODING => {
335            if known_headers.content_length.is_some() {
336                return Err(Error::BadRequest);
337            }
338            set_once!(transfer_encoding);
339        }
340        USER_AGENT_HASH if lowercase_key == USER_AGENT => {
341            set_once!(user_agent);
342        }
343        X_CSRF_TOKEN_HASH if lowercase_key == X_CSRF_TOKEN => {
344            set_once!(x_csrf_token);
345        }
346        X_FORWARDED_FOR_HASH if lowercase_key == X_FORWARDED_FOR => {
347            set_once!(x_forwarded_for);
348        }
349        X_FORWARDED_HOST_HASH if lowercase_key == X_FORWARDED_HOST => {
350            set_once!(x_forwarded_host);
351        }
352        X_REAL_IP_HASH if lowercase_key == X_REAL_IP => {
353            if known_headers.x_real_ip.is_some() {
354                return Err(Error::BadRequest);
355            }
356            known_headers.x_real_ip = Some(value);
357        }
358        X_HUB_SIGNATURE_256_HASH if lowercase_key == X_HUB_SIGNATURE_256 => {
359            if known_headers.x_hub_signature_256_hash.is_some() {
360                return Err(Error::BadRequest);
361            }
362            known_headers.x_hub_signature_256_hash = Some(value);
363        }
364        _ => return Ok((known_headers, true)),
365    }
366    Ok((known_headers, false))
367}
368
/// Adds `value` to `custom_headers` if `lowercase_key` is a registered header.
///
/// `other_headers` maps the hash of each registered name to the name itself.
/// `hash` is used for the lookup, and the stored name is then compared with
/// `lowercase_key` so that a different name sharing the hash is not accepted.
///
/// Returns the updated map when the header is registered (creating the map on
/// first use; a repeated header replaces the earlier value). When the header
/// is not registered, `custom_headers` is returned unchanged, so headers
/// collected from earlier lines are kept.
fn _with_other_header<'a>(
    custom_headers: Option<BTreeMap<&'static [u8], &'a [u8]>>,
    other_headers: &Option<BTreeMap<u32, &'static [u8]>>,
    hash: u32,
    lowercase_key: &[u8],
    value: &'a [u8],
) -> Option<BTreeMap<&'static [u8], &'a [u8]>> {
    let registered_name = other_headers
        .as_ref()
        .and_then(|names| names.get(&hash).copied())
        .filter(|name| *name == lowercase_key);
    match registered_name {
        Some(name) => {
            let mut custom_headers = custom_headers.unwrap_or_default();
            custom_headers.insert(name, value);
            Some(custom_headers)
        }
        // Not a registered header: keep whatever has been collected so far.
        None => custom_headers,
    }
}
386
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Cursor;
    use tokio::io::BufReader;
    use tracing::Level;

    /// Wraps raw request bytes in a buffered in-memory reader.
    fn reader_for(bytes: &[u8]) -> BufReader<Cursor<&[u8]>> {
        BufReader::new(Cursor::new(bytes))
    }

    /// Walks the parser through a series of requests: a bare request line, known
    /// headers in mixed case, duplicate and conflicting headers, and a registered
    /// custom header.
    #[tokio::test(flavor = "current_thread")]
    async fn parse_request_line_and_headers() {
        // `try_init` instead of `init`: a global subscriber may already be
        // installed by another test in the same binary, and `init` would panic.
        let _ = tracing_subscriber::fmt()
            .with_max_level(Level::DEBUG)
            .with_ansi(true)
            .compact()
            .try_init();
        let parser = Parser::default();

        // Request line only.
        let bytes = b"\
            GET /test HTTP/1.1\r\n\
            \r\n\
        ";
        let mut reader = reader_for(bytes);
        let mut buffer = vec![];
        let (method, path) = parser
            .parse_request_line(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(method, Method::Get);
        assert_eq!(&path, b"/test");

        // Single known header. The buffer from the previous request is reused on
        // purpose: the parser must only look at the bytes it has just appended.
        let bytes = b"\
            HEAD / HTTP/1.1\r\n\
            Host: example.org\r\n\
            \r\n\
        ";
        let mut reader = reader_for(bytes);
        let (method, path) = parser
            .parse_request_line(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(method, Method::Head);
        assert_eq!(&path, b"/");
        let (known_headers, _) = parser
            .parse_headers(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(known_headers.host, Some(b"example.org".as_slice()));

        // Several known headers, some already lowercase.
        let bytes = b"\
            POST /post HTTP/1.1\r\n\
            Host: example.org\r\n\
            content-type: application/json\r\n\
            content-length: 0\r\n\
            \r\n\
        ";
        let mut reader = reader_for(bytes);
        let mut buffer = vec![];
        let (method, path) = parser
            .parse_request_line(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(method, Method::Post);
        assert_eq!(&path, b"/post");
        let (known_headers, _) = parser
            .parse_headers(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(known_headers.host, Some(b"example.org".as_slice()));
        #[cfg(not(feature = "_minimal"))]
        assert_eq!(
            known_headers.content_type,
            Some(b"application/json".as_slice())
        );
        assert_eq!(known_headers.content_length, Some(b"0".as_slice()));

        // A repeated known header is rejected.
        let bytes = b"\
            GET /test HTTP/1.1\r\n\
            Host: example.org\r\n\
            Host: fake.xyz\r\n\
            \r\n\
        ";
        let mut reader = reader_for(bytes);
        let mut buffer = vec![];
        let (method, path) = parser
            .parse_request_line(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(method, Method::Get);
        assert_eq!(&path, b"/test");
        assert!(
            parser
                .parse_headers(&mut reader, &mut buffer)
                .await
                .is_err()
        );

        // Content-Length together with Transfer-Encoding is rejected.
        let bytes = b"\
            GET /test HTTP/1.1\r\n\
            Content-Length: 100\r\n\
            Transfer-Encoding: chunked\r\n\
            \r\n\
        ";
        let mut reader = reader_for(bytes);
        let mut buffer = vec![];
        let (method, path) = parser
            .parse_request_line(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(method, Method::Get);
        assert_eq!(&path, b"/test");
        assert!(
            parser
                .parse_headers(&mut reader, &mut buffer)
                .await
                .is_err()
        );

        // A registered custom header is collected next to the known ones.
        let bytes = b"\
            GET /test HTTP/1.1\r\n\
            Host: example.org\r\n\
            x-test: 1\r\n\
            \r\n\
        ";
        let mut reader = reader_for(bytes);
        let mut buffer = vec![];
        let parser = parser.with_header(b"x-test").unwrap();
        let (method, path) = parser
            .parse_request_line(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(method, Method::Get);
        assert_eq!(&path, b"/test");
        let (known_headers, custom_headers) = parser
            .parse_headers(&mut reader, &mut buffer)
            .await
            .unwrap();
        assert_eq!(known_headers.host, Some(b"example.org".as_slice()));
        assert!(custom_headers.is_some());
        let custom_headers = custom_headers.unwrap();
        assert_eq!(
            custom_headers.get(b"x-test".as_slice()),
            Some(&b"1".as_slice())
        );
    }
}