// h11/_headers.rs

1use std::collections::HashSet;
2
3use crate::{
4    _abnf::{FIELD_NAME, FIELD_VALUE},
5    _events::Request,
6    _util::ProtocolError,
7};
8use lazy_static::lazy_static;
9use regex::bytes::Regex;
10
11lazy_static! {
12    static ref CONTENT_LENGTH_RE: Regex = Regex::new(r"^[0-9]+$").unwrap();
13    static ref FIELD_NAME_RE: Regex = Regex::new(&format!(r"^{}$", FIELD_NAME)).unwrap();
14    static ref FIELD_VALUE_RE: Regex = Regex::new(&format!(r"^{}$", *FIELD_VALUE)).unwrap();
15}
16
/// An ordered list of HTTP headers.
///
/// Each entry is `(raw_name, name, value)`: the name exactly as supplied, the normalized
/// name (ASCII-lowercased when built by `normalize_and_validate`), and the value bytes.
#[derive(Clone, PartialEq, Eq, Hash, Default, PartialOrd, Ord)]
pub struct Headers(Vec<(Vec<u8>, Vec<u8>, Vec<u8>)>);

impl std::fmt::Debug for Headers {
    /// Formats the headers as `Headers { RawName: "value", ... }` in insertion order.
    ///
    /// Names and values are decoded lossily, so bytes that are not valid UTF-8 are shown
    /// as U+FFFD instead of making `{:?}` panic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug_struct = f.debug_struct("Headers");
        for (raw_name, _, value) in &self.0 {
            debug_struct.field(
                &String::from_utf8_lossy(raw_name),
                &String::from_utf8_lossy(value),
            );
        }
        debug_struct.finish()
    }
}

impl Headers {
    /// Iterates over `(name, value)` pairs in order, using the normalized name.
    ///
    /// Each item is an owned copy; use [`Headers::raw_items`] to borrow entries
    /// (including the raw name) without cloning.
    pub fn iter(&self) -> impl Iterator<Item = (Vec<u8>, Vec<u8>)> + '_ {
        self.0
            .iter()
            .map(|(_, name, value)| (name.clone(), value.clone()))
    }

    /// Borrows every `(raw_name, name, value)` entry in order.
    pub fn raw_items(&self) -> Vec<&(Vec<u8>, Vec<u8>, Vec<u8>)> {
        self.0.iter().collect()
    }

    /// Number of header entries; repeated names are counted separately.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` when there are no header entries.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
48
49impl From<Vec<(Vec<u8>, Vec<u8>)>> for Headers {
50    fn from(value: Vec<(Vec<u8>, Vec<u8>)>) -> Self {
51        normalize_and_validate(value, false).unwrap()
52    }
53}
54
55pub fn normalize_and_validate(
56    headers: Vec<(Vec<u8>, Vec<u8>)>,
57    _parsed: bool,
58) -> Result<Headers, ProtocolError> {
59    let mut new_headers = vec![];
60    let mut seen_content_length = None;
61    let mut saw_transfer_encoding = false;
62    for (name, value) in headers {
63        if !_parsed {
64            if !FIELD_NAME_RE.is_match(&name) {
65                return Err(ProtocolError::LocalProtocolError(
66                    format!("Illegal header name {:?}", &name).into(),
67                ));
68            }
69            if !FIELD_VALUE_RE.is_match(&value) {
70                return Err(ProtocolError::LocalProtocolError(
71                    format!("Illegal header value {:?}", &value).into(),
72                ));
73            }
74        }
75        let raw_name = name.clone();
76        let name = name.to_ascii_lowercase();
77        if name == b"content-length" {
78            let lengths: HashSet<Vec<u8>> = value
79                .split(|&b| b == b',')
80                .map(|length| {
81                    std::str::from_utf8(length)
82                        .unwrap()
83                        .trim()
84                        .as_bytes()
85                        .to_vec()
86                })
87                .collect();
88            if lengths.len() != 1 {
89                return Err(ProtocolError::LocalProtocolError(
90                    "conflicting Content-Length headers".into(),
91                ));
92            }
93            let value = lengths.iter().next().unwrap();
94            if !CONTENT_LENGTH_RE.is_match(value) {
95                return Err(ProtocolError::LocalProtocolError(
96                    "bad Content-Length".into(),
97                ));
98            }
99            if seen_content_length.is_none() {
100                seen_content_length = Some(value.clone());
101                new_headers.push((raw_name, name, value.clone()));
102            } else if seen_content_length != Some(value.clone()) {
103                return Err(ProtocolError::LocalProtocolError(
104                    "conflicting Content-Length headers".into(),
105                ));
106            }
107        } else if name == b"transfer-encoding" {
108            // "A server that receives a request message with a transfer coding
109            // it does not understand SHOULD respond with 501 (Not
110            // Implemented)."
111            // https://tools.ietf.org/html/rfc7230#section-3.3.1
112            if saw_transfer_encoding {
113                return Err(ProtocolError::LocalProtocolError(
114                    ("multiple Transfer-Encoding headers", 501).into(),
115                ));
116            }
117            // "All transfer-coding names are case-insensitive"
118            // -- https://tools.ietf.org/html/rfc7230#section-4
119            let value = value.to_ascii_lowercase();
120            if value != b"chunked" {
121                return Err(ProtocolError::LocalProtocolError(
122                    ("Only Transfer-Encoding: chunked is supported", 501).into(),
123                ));
124            }
125            saw_transfer_encoding = true;
126            new_headers.push((raw_name, name, value));
127        } else {
128            new_headers.push((raw_name, name, value.to_vec()));
129        }
130    }
131
132    Ok(Headers(new_headers))
133}
134
135pub fn get_comma_header(headers: &Headers, name: &[u8]) -> Vec<Vec<u8>> {
136    let mut out: Vec<Vec<u8>> = vec![];
137    let name = name.to_ascii_lowercase();
138    for (found_name, found_value) in headers.iter() {
139        if found_name == name {
140            for found_split_value in found_value.to_ascii_lowercase().split(|&b| b == b',') {
141                let found_split_value = std::str::from_utf8(found_split_value).unwrap().trim();
142                if !found_split_value.is_empty() {
143                    out.push(found_split_value.as_bytes().to_vec());
144                }
145            }
146        }
147    }
148    out
149}
150
151pub fn set_comma_header(
152    headers: &Headers,
153    name: &[u8],
154    new_values: Vec<Vec<u8>>,
155) -> Result<Headers, ProtocolError> {
156    let mut new_headers = vec![];
157    for (found_name, found_value) in headers.iter() {
158        if found_name != name {
159            new_headers.push((found_name, found_value));
160        }
161    }
162    for new_value in new_values {
163        new_headers.push((name.to_vec(), new_value));
164    }
165    normalize_and_validate(new_headers, false)
166}
167
168pub fn has_expect_100_continue(request: &Request) -> bool {
169    // https://tools.ietf.org/html/rfc7231#section-5.1.1
170    // "A server that receives a 100-continue expectation in an HTTP/1.0 request
171    // MUST ignore that expectation."
172    if request.http_version < b"1.1".to_vec() {
173        return false;
174    }
175    let expect = get_comma_header(&request.headers, b"expect");
176    expect.contains(&b"100-continue".to_vec())
177}
178
#[cfg(test)]
mod tests {
    //! Unit tests for header normalization/validation and the comma-header helpers.
    //!
    //! Expected errors are written as `(message, status)` tuples converted with `.into()`;
    //! the status is 400 except for the Transfer-Encoding cases, which expect 501.
    use super::*;

    /// Covers name/value syntax checks plus Content-Length and Transfer-Encoding handling.
    #[test]
    fn test_normalize_and_validate() {
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar".to_vec())], false).unwrap(),
            Headers(vec![(b"foo".to_vec(), b"foo".to_vec(), b"bar".to_vec())])
        );

        // no leading/trailing whitespace in names
        assert_eq!(
            normalize_and_validate(vec![(b"foo ".to_vec(), b"bar".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("Illegal header name [102, 111, 111, 32]".to_string(), 400).into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b" foo".to_vec(), b"bar".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("Illegal header name [32, 102, 111, 111]".to_string(), 400).into()
            )
        );

        // no weird characters in names
        assert_eq!(
            normalize_and_validate(vec![(b"foo bar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 32, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo\x00bar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 0, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );
        // Not even 8-bit characters:
        assert_eq!(
            normalize_and_validate(vec![(b"foo\xffbar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 255, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );
        // And not even the control characters we allow in values:
        assert_eq!(
            normalize_and_validate(vec![(b"foo\x01bar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 1, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );

        // no return or NUL characters in values
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\rbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 13, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\nbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 10, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\x00baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 0, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );
        // no leading/trailing whitespace
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz  ".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 98, 97, 122, 32, 32]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"  barbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [32, 32, 98, 97, 114, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz\t".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 98, 97, 122, 9]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"\tbarbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [9, 98, 97, 114, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );

        // content-length
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1".to_vec())], false)
                .unwrap(),
            Headers(vec![(
                b"Content-Length".to_vec(),
                b"content-length".to_vec(),
                b"1".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"asdf".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into())
        );
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1x".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into())
        );
        // Two Content-Length headers that disagree are rejected.
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Content-Length".to_vec(), b"1".to_vec()),
                    (b"Content-Length".to_vec(), b"2".to_vec())
                ],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("conflicting Content-Length headers".to_string(), 400).into()
            )
        );
        // Identical repeated Content-Length headers collapse to the first one.
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Content-Length".to_vec(), b"0".to_vec()),
                    (b"Content-Length".to_vec(), b"0".to_vec())
                ],
                false
            )
            .unwrap(),
            Headers(vec![(
                b"Content-Length".to_vec(),
                b"content-length".to_vec(),
                b"0".to_vec()
            )])
        );
        // A comma-separated list of identical lengths collapses to a single value.
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"0 , 0".to_vec())], false)
                .unwrap(),
            Headers(vec![(
                b"Content-Length".to_vec(),
                b"content-length".to_vec(),
                b"0".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Content-Length".to_vec(), b"1".to_vec()),
                    (b"Content-Length".to_vec(), b"1".to_vec()),
                    (b"Content-Length".to_vec(), b"2".to_vec())
                ],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("conflicting Content-Length headers".to_string(), 400).into()
            )
        );
        assert_eq!(
            normalize_and_validate(
                vec![(b"Content-Length".to_vec(), b"1 , 1,2".to_vec())],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("conflicting Content-Length headers".to_string(), 400).into()
            )
        );

        // transfer-encoding
        assert_eq!(
            normalize_and_validate(
                vec![(b"Transfer-Encoding".to_vec(), b"chunked".to_vec())],
                false
            )
            .unwrap(),
            Headers(vec![(
                b"Transfer-Encoding".to_vec(),
                b"transfer-encoding".to_vec(),
                b"chunked".to_vec()
            )])
        );
        // The coding name is matched case-insensitively and stored lowercased.
        assert_eq!(
            normalize_and_validate(
                vec![(b"Transfer-Encoding".to_vec(), b"cHuNkEd".to_vec())],
                false
            )
            .unwrap(),
            Headers(vec![(
                b"Transfer-Encoding".to_vec(),
                b"transfer-encoding".to_vec(),
                b"chunked".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(
                vec![(b"Transfer-Encoding".to_vec(), b"gzip".to_vec())],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Only Transfer-Encoding: chunked is supported".to_string(),
                    501
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()),
                    (b"Transfer-Encoding".to_vec(), b"gzip".to_vec())
                ],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("multiple Transfer-Encoding headers".to_string(), 501).into()
            )
        );
    }

    /// Checks token extraction from repeated comma headers and header replacement.
    #[test]
    fn test_get_set_comma_header() {
        let headers = normalize_and_validate(
            vec![
                (b"Connection".to_vec(), b"close".to_vec()),
                (b"whatever".to_vec(), b"something".to_vec()),
                (b"connectiON".to_vec(), b"fOo,, , BAR".to_vec()),
            ],
            false,
        )
        .unwrap();

        // Tokens from both Connection headers, lowercased, with empty pieces dropped.
        assert_eq!(
            get_comma_header(&headers, b"connection"),
            vec![b"close".to_vec(), b"foo".to_vec(), b"bar".to_vec()]
        );

        let headers =
            set_comma_header(&headers, b"newthing", vec![b"a".to_vec(), b"b".to_vec()]).unwrap();

        // Existing headers are re-emitted under their lowercased names; new ones are appended.
        assert_eq!(
            headers,
            Headers(vec![
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"close".to_vec()
                ),
                (
                    b"whatever".to_vec(),
                    b"whatever".to_vec(),
                    b"something".to_vec()
                ),
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"fOo,, , BAR".to_vec()
                ),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()),
            ])
        );

        let headers =
            set_comma_header(&headers, b"whatever", vec![b"different thing".to_vec()]).unwrap();

        // The replaced header moves to the end; the others keep their relative order.
        assert_eq!(
            headers,
            Headers(vec![
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"close".to_vec()
                ),
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"fOo,, , BAR".to_vec()
                ),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()),
                (
                    b"whatever".to_vec(),
                    b"whatever".to_vec(),
                    b"different thing".to_vec()
                ),
            ])
        );
    }

    /// Checks `Expect: 100-continue` detection, including case-insensitivity and HTTP/1.0.
    #[test]
    fn test_has_100_continue() {
        assert!(has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![
                    (b"Host".to_vec(), b"example.com".to_vec()),
                    (b"Expect".to_vec(), b"100-continue".to_vec())
                ],
                false
            )
            .unwrap(),
            http_version: b"1.1".to_vec(),
        }));
        assert!(!has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![(b"Host".to_vec(), b"example.com".to_vec())],
                false
            )
            .unwrap(),
            http_version: b"1.1".to_vec(),
        }));
        // Case insensitive
        assert!(has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![
                    (b"Host".to_vec(), b"example.com".to_vec()),
                    (b"Expect".to_vec(), b"100-Continue".to_vec())
                ],
                false
            )
            .unwrap(),
            http_version: b"1.1".to_vec(),
        }));
        // Doesn't work in HTTP/1.0
        assert!(!has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![
                    (b"Host".to_vec(), b"example.com".to_vec()),
                    (b"Expect".to_vec(), b"100-continue".to_vec())
                ],
                false
            )
            .unwrap(),
            http_version: b"1.0".to_vec(),
        }));
    }
}