1use std::collections::HashSet;
2
3use crate::{
4 _abnf::{FIELD_NAME, FIELD_VALUE},
5 _events::Request,
6 _util::ProtocolError,
7};
8use lazy_static::lazy_static;
9use regex::bytes::Regex;
10
11lazy_static! {
12 static ref CONTENT_LENGTH_RE: Regex = Regex::new(r"^[0-9]+$").unwrap();
13 static ref FIELD_NAME_RE: Regex = Regex::new(&format!(r"^{}$", FIELD_NAME)).unwrap();
14 static ref FIELD_VALUE_RE: Regex = Regex::new(&format!(r"^{}$", *FIELD_VALUE)).unwrap();
15}
16
/// Returns `value` with leading and trailing ASCII whitespace removed.
///
/// The result is a sub-slice of `value`; an input made up entirely of
/// whitespace (or an empty input) yields an empty slice.
fn trim_ascii_whitespace(value: &[u8]) -> &[u8] {
    let mut rest = value;
    // Peel whitespace off the front, one byte at a time.
    while let [first, tail @ ..] = rest {
        if !first.is_ascii_whitespace() {
            break;
        }
        rest = tail;
    }
    // Then peel whitespace off the back.
    while let [head @ .., last] = rest {
        if !last.is_ascii_whitespace() {
            break;
        }
        rest = head;
    }
    rest
}
29
/// An ordered list of HTTP header entries.
///
/// Each entry is a `(raw name, lowercased name, value)` triple of byte strings.
#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Headers(Vec<(Vec<u8>, Vec<u8>, Vec<u8>)>);
36
37impl std::fmt::Debug for Headers {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 let mut debug_struct = f.debug_struct("Headers");
40 self.0.iter().for_each(|(raw_name, _, value)| {
41 debug_struct.field(
42 &String::from_utf8_lossy(raw_name),
43 &String::from_utf8_lossy(value),
44 );
45 });
46 debug_struct.finish()
47 }
48}
49
50impl Headers {
51 pub fn iter(&self) -> impl Iterator<Item = (Vec<u8>, Vec<u8>)> + '_ {
55 self.0
56 .iter()
57 .map(|(_, name, value)| ((*name).clone(), (*value).clone()))
58 }
59
60 pub fn raw_items(&self) -> Vec<&(Vec<u8>, Vec<u8>, Vec<u8>)> {
62 self.0.iter().collect()
63 }
64
65 pub fn len(&self) -> usize {
67 self.0.len()
68 }
69
70 pub fn is_empty(&self) -> bool {
72 self.0.is_empty()
73 }
74
75 pub fn new<I, N, V>(headers: I) -> Result<Self, ProtocolError>
81 where
82 I: IntoIterator<Item = (N, V)>,
83 N: AsRef<[u8]>,
84 V: AsRef<[u8]>,
85 {
86 normalize_and_validate(
87 headers
88 .into_iter()
89 .map(|(name, value)| (name.as_ref().to_vec(), value.as_ref().to_vec()))
90 .collect(),
91 false,
92 )
93 }
94}
95
96impl From<Vec<(Vec<u8>, Vec<u8>)>> for Headers {
97 fn from(value: Vec<(Vec<u8>, Vec<u8>)>) -> Self {
102 Headers::new(value)
103 .expect("invalid HTTP header list; use Headers::new for fallible construction")
104 }
105}
106
107pub fn normalize_and_validate(
112 headers: Vec<(Vec<u8>, Vec<u8>)>,
113 _parsed: bool,
114) -> Result<Headers, ProtocolError> {
115 let mut new_headers = vec![];
116 let mut seen_content_length = None;
117 let mut saw_transfer_encoding = false;
118 for (name, value) in headers {
119 if !_parsed {
120 if !FIELD_NAME_RE.is_match(&name) {
121 return Err(ProtocolError::LocalProtocolError(
122 format!("Illegal header name {:?}", &name).into(),
123 ));
124 }
125 if !FIELD_VALUE_RE.is_match(&value) {
126 return Err(ProtocolError::LocalProtocolError(
127 format!("Illegal header value {:?}", &value).into(),
128 ));
129 }
130 }
131 let raw_name = name.clone();
132 let name = name.to_ascii_lowercase();
133 if name == b"content-length" {
134 let lengths: HashSet<Vec<u8>> = value
135 .split(|&b| b == b',')
136 .map(|length| trim_ascii_whitespace(length).to_vec())
137 .collect();
138 if lengths.len() != 1 {
139 return Err(ProtocolError::LocalProtocolError(
140 "conflicting Content-Length headers".into(),
141 ));
142 }
143 let value = lengths.iter().next().unwrap();
144 if !CONTENT_LENGTH_RE.is_match(value) {
145 return Err(ProtocolError::LocalProtocolError(
146 "bad Content-Length".into(),
147 ));
148 }
149 if seen_content_length.is_none() {
150 seen_content_length = Some(value.clone());
151 new_headers.push((raw_name, name, value.clone()));
152 } else if seen_content_length != Some(value.clone()) {
153 return Err(ProtocolError::LocalProtocolError(
154 "conflicting Content-Length headers".into(),
155 ));
156 }
157 } else if name == b"transfer-encoding" {
158 if saw_transfer_encoding {
163 return Err(ProtocolError::LocalProtocolError(
164 ("multiple Transfer-Encoding headers", 501).into(),
165 ));
166 }
167 let value = value.to_ascii_lowercase();
170 if value != b"chunked" {
171 return Err(ProtocolError::LocalProtocolError(
172 ("Only Transfer-Encoding: chunked is supported", 501).into(),
173 ));
174 }
175 saw_transfer_encoding = true;
176 new_headers.push((raw_name, name, value));
177 } else {
178 new_headers.push((raw_name, name, value.to_vec()));
179 }
180 }
181
182 Ok(Headers(new_headers))
183}
184
185pub fn get_comma_header(headers: &Headers, name: &[u8]) -> Vec<Vec<u8>> {
187 let mut out: Vec<Vec<u8>> = vec![];
188 let name = name.to_ascii_lowercase();
189 for (found_name, found_value) in headers.iter() {
190 if found_name == name {
191 for found_split_value in found_value.to_ascii_lowercase().split(|&b| b == b',') {
192 let found_split_value = trim_ascii_whitespace(found_split_value);
193 if !found_split_value.is_empty() {
194 out.push(found_split_value.to_vec());
195 }
196 }
197 }
198 }
199 out
200}
201
202pub fn set_comma_header(
204 headers: &Headers,
205 name: &[u8],
206 new_values: Vec<Vec<u8>>,
207) -> Result<Headers, ProtocolError> {
208 let mut new_headers = vec![];
209 for (found_name, found_value) in headers.iter() {
210 if found_name != name {
211 new_headers.push((found_name, found_value));
212 }
213 }
214 for new_value in new_values {
215 new_headers.push((name.to_vec(), new_value));
216 }
217 normalize_and_validate(new_headers, false)
218}
219
220pub fn has_expect_100_continue(request: &Request) -> bool {
222 if request.http_version < b"1.1".to_vec() {
226 return false;
227 }
228 let expect = get_comma_header(&request.headers, b"expect");
229 expect.contains(&b"100-continue".to_vec())
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one owned `(name, value)` input pair from byte literals.
    fn header(name: &[u8], value: &[u8]) -> (Vec<u8>, Vec<u8>) {
        (name.to_vec(), value.to_vec())
    }

    /// Builds one expected stored entry: `(raw name, lowercased name, value)`.
    fn stored(raw_name: &[u8], name: &[u8], value: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
        (raw_name.to_vec(), name.to_vec(), value.to_vec())
    }

    /// Builds a `GET /` request with the given HTTP version and headers.
    fn request(http_version: &[u8], headers: Vec<(Vec<u8>, Vec<u8>)>) -> Request {
        Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(headers, false).unwrap(),
            http_version: http_version.to_vec(),
        }
    }

    /// Asserts that validating `$headers` fails with a `LocalProtocolError`
    /// carrying exactly `$message` and `$code`.
    macro_rules! assert_rejected {
        ($headers:expr, $message:expr, $code:expr) => {
            assert_eq!(
                normalize_and_validate($headers, false)
                    .expect_err("Expect ProtocolError::LocalProtocolError"),
                ProtocolError::LocalProtocolError(($message.to_string(), $code).into())
            )
        };
    }

    #[test]
    fn test_headers_new_rejects_invalid_input() {
        assert!(Headers::new(vec![header(b"bad header", b"value")]).is_err());
    }

    #[test]
    fn test_non_utf8_comma_headers_do_not_panic() {
        assert_eq!(
            normalize_and_validate(vec![header(b"Content-Length", b"\xff")], true).unwrap_err(),
            ProtocolError::LocalProtocolError("bad Content-Length".into())
        );

        let headers =
            normalize_and_validate(vec![header(b"Connection", b"close, \xff")], true).unwrap();
        assert_eq!(
            get_comma_header(&headers, b"connection"),
            vec![b"close".to_vec(), b"\xff".to_vec()]
        );
    }

    #[test]
    fn test_headers_new_accepts_borrowed_inputs() {
        assert_eq!(
            Headers::new([("Host", "example.com"), ("Accept", "*/*")]).unwrap(),
            Headers(vec![
                stored(b"Host", b"host", b"example.com"),
                stored(b"Accept", b"accept", b"*/*"),
            ])
        );
        assert_eq!(
            Headers::new([(b"Host".as_slice(), b"example.com".as_slice())]).unwrap(),
            Headers(vec![stored(b"Host", b"host", b"example.com")])
        );
    }

    #[test]
    fn test_normalize_and_validate() {
        assert_eq!(
            normalize_and_validate(vec![header(b"foo", b"bar")], false).unwrap(),
            Headers(vec![stored(b"foo", b"foo", b"bar")])
        );

        // Header names must not carry leading or trailing whitespace.
        assert_rejected!(
            vec![header(b"foo ", b"bar")],
            "Illegal header name [102, 111, 111, 32]",
            400
        );
        assert_rejected!(
            vec![header(b" foo", b"bar")],
            "Illegal header name [32, 102, 111, 111]",
            400
        );

        // Header names must not contain spaces, NUL, high bytes or controls.
        assert_rejected!(
            vec![header(b"foo bar", b"baz")],
            "Illegal header name [102, 111, 111, 32, 98, 97, 114]",
            400
        );
        assert_rejected!(
            vec![header(b"foo\x00bar", b"baz")],
            "Illegal header name [102, 111, 111, 0, 98, 97, 114]",
            400
        );
        assert_rejected!(
            vec![header(b"foo\xffbar", b"baz")],
            "Illegal header name [102, 111, 111, 255, 98, 97, 114]",
            400
        );
        assert_rejected!(
            vec![header(b"foo\x01bar", b"baz")],
            "Illegal header name [102, 111, 111, 1, 98, 97, 114]",
            400
        );

        // Header values must not contain CR, LF or NUL.
        assert_rejected!(
            vec![header(b"foo", b"bar\rbaz")],
            "Illegal header value [98, 97, 114, 13, 98, 97, 122]",
            400
        );
        assert_rejected!(
            vec![header(b"foo", b"bar\nbaz")],
            "Illegal header value [98, 97, 114, 10, 98, 97, 122]",
            400
        );
        assert_rejected!(
            vec![header(b"foo", b"bar\x00baz")],
            "Illegal header value [98, 97, 114, 0, 98, 97, 122]",
            400
        );

        // Header values must not carry leading or trailing whitespace. The
        // explicit `\x20` escapes keep both spaces visible so the input always
        // matches the two `32` bytes in the expected message.
        assert_rejected!(
            vec![header(b"foo", b"barbaz\x20\x20")],
            "Illegal header value [98, 97, 114, 98, 97, 122, 32, 32]",
            400
        );
        assert_rejected!(
            vec![header(b"foo", b"\x20\x20barbaz")],
            "Illegal header value [32, 32, 98, 97, 114, 98, 97, 122]",
            400
        );
        assert_rejected!(
            vec![header(b"foo", b"barbaz\t")],
            "Illegal header value [98, 97, 114, 98, 97, 122, 9]",
            400
        );
        assert_rejected!(
            vec![header(b"foo", b"\tbarbaz")],
            "Illegal header value [9, 98, 97, 114, 98, 97, 122]",
            400
        );

        // Content-Length handling.
        assert_eq!(
            normalize_and_validate(vec![header(b"Content-Length", b"1")], false).unwrap(),
            Headers(vec![stored(b"Content-Length", b"content-length", b"1")])
        );
        assert_rejected!(
            vec![header(b"Content-Length", b"asdf")],
            "bad Content-Length",
            400
        );
        assert_rejected!(
            vec![header(b"Content-Length", b"1x")],
            "bad Content-Length",
            400
        );
        assert_rejected!(
            vec![
                header(b"Content-Length", b"1"),
                header(b"Content-Length", b"2"),
            ],
            "conflicting Content-Length headers",
            400
        );
        assert_eq!(
            normalize_and_validate(
                vec![
                    header(b"Content-Length", b"0"),
                    header(b"Content-Length", b"0"),
                ],
                false
            )
            .unwrap(),
            Headers(vec![stored(b"Content-Length", b"content-length", b"0")])
        );
        assert_eq!(
            normalize_and_validate(vec![header(b"Content-Length", b"0 , 0")], false).unwrap(),
            Headers(vec![stored(b"Content-Length", b"content-length", b"0")])
        );
        assert_rejected!(
            vec![
                header(b"Content-Length", b"1"),
                header(b"Content-Length", b"1"),
                header(b"Content-Length", b"2"),
            ],
            "conflicting Content-Length headers",
            400
        );
        assert_rejected!(
            vec![header(b"Content-Length", b"1 , 1,2")],
            "conflicting Content-Length headers",
            400
        );

        // Transfer-Encoding handling.
        assert_eq!(
            normalize_and_validate(vec![header(b"Transfer-Encoding", b"chunked")], false).unwrap(),
            Headers(vec![stored(
                b"Transfer-Encoding",
                b"transfer-encoding",
                b"chunked"
            )])
        );
        assert_eq!(
            normalize_and_validate(vec![header(b"Transfer-Encoding", b"cHuNkEd")], false).unwrap(),
            Headers(vec![stored(
                b"Transfer-Encoding",
                b"transfer-encoding",
                b"chunked"
            )])
        );
        assert_rejected!(
            vec![header(b"Transfer-Encoding", b"gzip")],
            "Only Transfer-Encoding: chunked is supported",
            501
        );
        assert_rejected!(
            vec![
                header(b"Transfer-Encoding", b"chunked"),
                header(b"Transfer-Encoding", b"gzip"),
            ],
            "multiple Transfer-Encoding headers",
            501
        );
    }

    #[test]
    fn test_get_set_comma_header() {
        let headers = normalize_and_validate(
            vec![
                header(b"Connection", b"close"),
                header(b"whatever", b"something"),
                header(b"connectiON", b"fOo,, , BAR"),
            ],
            false,
        )
        .unwrap();

        assert_eq!(
            get_comma_header(&headers, b"connection"),
            vec![b"close".to_vec(), b"foo".to_vec(), b"bar".to_vec()]
        );

        let headers =
            set_comma_header(&headers, b"newthing", vec![b"a".to_vec(), b"b".to_vec()]).unwrap();

        assert_eq!(
            headers,
            Headers(vec![
                stored(b"connection", b"connection", b"close"),
                stored(b"whatever", b"whatever", b"something"),
                stored(b"connection", b"connection", b"fOo,, , BAR"),
                stored(b"newthing", b"newthing", b"a"),
                stored(b"newthing", b"newthing", b"b"),
            ])
        );

        let headers =
            set_comma_header(&headers, b"whatever", vec![b"different thing".to_vec()]).unwrap();

        assert_eq!(
            headers,
            Headers(vec![
                stored(b"connection", b"connection", b"close"),
                stored(b"connection", b"connection", b"fOo,, , BAR"),
                stored(b"newthing", b"newthing", b"a"),
                stored(b"newthing", b"newthing", b"b"),
                stored(b"whatever", b"whatever", b"different thing"),
            ])
        );
    }

    #[test]
    fn test_has_100_continue() {
        assert!(has_expect_100_continue(&request(
            b"1.1",
            vec![
                header(b"Host", b"example.com"),
                header(b"Expect", b"100-continue"),
            ]
        )));
        assert!(!has_expect_100_continue(&request(
            b"1.1",
            vec![header(b"Host", b"example.com")]
        )));
        // The token is matched case-insensitively.
        assert!(has_expect_100_continue(&request(
            b"1.1",
            vec![
                header(b"Host", b"example.com"),
                header(b"Expect", b"100-Continue"),
            ]
        )));
        // HTTP/1.0 requests never ask for 100-continue.
        assert!(!has_expect_100_continue(&request(
            b"1.0",
            vec![
                header(b"Host", b"example.com"),
                header(b"Expect", b"100-continue"),
            ]
        )));
    }
}