1use std::collections::HashSet;
2
3use crate::{
4 _abnf::{FIELD_NAME, FIELD_VALUE},
5 _events::Request,
6 _util::ProtocolError,
7};
8use lazy_static::lazy_static;
9use regex::bytes::Regex;
10
11lazy_static! {
12 static ref CONTENT_LENGTH_RE: Regex = Regex::new(r"^[0-9]+$").unwrap();
13 static ref FIELD_NAME_RE: Regex = Regex::new(&format!(r"^{}$", FIELD_NAME)).unwrap();
14 static ref FIELD_VALUE_RE: Regex = Regex::new(&format!(r"^{}$", *FIELD_VALUE)).unwrap();
15}
16
/// An ordered list of HTTP header entries.
///
/// Each entry is a `(raw_name, name, value)` triple:
/// * `raw_name` — the header name exactly as it was supplied,
/// * `name` — its ASCII-lowercased form, used for case-insensitive lookups,
/// * `value` — the header value (`normalize_and_validate` stores the single
///   canonical value for `Content-Length` and the lowercased value for
///   `Transfer-Encoding`).
///
/// Entries keep insertion order and a name may appear more than once.
#[derive(Clone, PartialEq, Eq, Hash, Default, PartialOrd, Ord)]
pub struct Headers(Vec<(Vec<u8>, Vec<u8>, Vec<u8>)>);
19
20impl std::fmt::Debug for Headers {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 let mut debug_struct = f.debug_struct("Headers");
23 self.0.iter().for_each(|(raw_name, _, value)| {
24 debug_struct.field(
25 std::str::from_utf8(raw_name).unwrap(),
26 &std::str::from_utf8(value).unwrap(),
27 );
28 });
29 debug_struct.finish()
30 }
31}
32
33impl Headers {
34 pub fn iter(&self) -> impl Iterator<Item = (Vec<u8>, Vec<u8>)> + '_ {
35 self.0
36 .iter()
37 .map(|(_, name, value)| ((*name).clone(), (*value).clone()))
38 }
39
40 pub fn raw_items(&self) -> Vec<&(Vec<u8>, Vec<u8>, Vec<u8>)> {
41 self.0.iter().collect()
42 }
43
44 pub fn len(&self) -> usize {
45 self.0.len()
46 }
47}
48
49impl From<Vec<(Vec<u8>, Vec<u8>)>> for Headers {
50 fn from(value: Vec<(Vec<u8>, Vec<u8>)>) -> Self {
51 normalize_and_validate(value, false).unwrap()
52 }
53}
54
55pub fn normalize_and_validate(
56 headers: Vec<(Vec<u8>, Vec<u8>)>,
57 _parsed: bool,
58) -> Result<Headers, ProtocolError> {
59 let mut new_headers = vec![];
60 let mut seen_content_length = None;
61 let mut saw_transfer_encoding = false;
62 for (name, value) in headers {
63 if !_parsed {
64 if !FIELD_NAME_RE.is_match(&name) {
65 return Err(ProtocolError::LocalProtocolError(
66 format!("Illegal header name {:?}", &name).into(),
67 ));
68 }
69 if !FIELD_VALUE_RE.is_match(&value) {
70 return Err(ProtocolError::LocalProtocolError(
71 format!("Illegal header value {:?}", &value).into(),
72 ));
73 }
74 }
75 let raw_name = name.clone();
76 let name = name.to_ascii_lowercase();
77 if name == b"content-length" {
78 let lengths: HashSet<Vec<u8>> = value
79 .split(|&b| b == b',')
80 .map(|length| {
81 std::str::from_utf8(length)
82 .unwrap()
83 .trim()
84 .as_bytes()
85 .to_vec()
86 })
87 .collect();
88 if lengths.len() != 1 {
89 return Err(ProtocolError::LocalProtocolError(
90 "conflicting Content-Length headers".into(),
91 ));
92 }
93 let value = lengths.iter().next().unwrap();
94 if !CONTENT_LENGTH_RE.is_match(value) {
95 return Err(ProtocolError::LocalProtocolError(
96 "bad Content-Length".into(),
97 ));
98 }
99 if seen_content_length.is_none() {
100 seen_content_length = Some(value.clone());
101 new_headers.push((raw_name, name, value.clone()));
102 } else if seen_content_length != Some(value.clone()) {
103 return Err(ProtocolError::LocalProtocolError(
104 "conflicting Content-Length headers".into(),
105 ));
106 }
107 } else if name == b"transfer-encoding" {
108 if saw_transfer_encoding {
113 return Err(ProtocolError::LocalProtocolError(
114 ("multiple Transfer-Encoding headers", 501).into(),
115 ));
116 }
117 let value = value.to_ascii_lowercase();
120 if value != b"chunked" {
121 return Err(ProtocolError::LocalProtocolError(
122 ("Only Transfer-Encoding: chunked is supported", 501).into(),
123 ));
124 }
125 saw_transfer_encoding = true;
126 new_headers.push((raw_name, name, value));
127 } else {
128 new_headers.push((raw_name, name, value.to_vec()));
129 }
130 }
131
132 Ok(Headers(new_headers))
133}
134
135pub fn get_comma_header(headers: &Headers, name: &[u8]) -> Vec<Vec<u8>> {
136 let mut out: Vec<Vec<u8>> = vec![];
137 let name = name.to_ascii_lowercase();
138 for (found_name, found_value) in headers.iter() {
139 if found_name == name {
140 for found_split_value in found_value.to_ascii_lowercase().split(|&b| b == b',') {
141 let found_split_value = std::str::from_utf8(found_split_value).unwrap().trim();
142 if !found_split_value.is_empty() {
143 out.push(found_split_value.as_bytes().to_vec());
144 }
145 }
146 }
147 }
148 out
149}
150
151pub fn set_comma_header(
152 headers: &Headers,
153 name: &[u8],
154 new_values: Vec<Vec<u8>>,
155) -> Result<Headers, ProtocolError> {
156 let mut new_headers = vec![];
157 for (found_name, found_value) in headers.iter() {
158 if found_name != name {
159 new_headers.push((found_name, found_value));
160 }
161 }
162 for new_value in new_values {
163 new_headers.push((name.to_vec(), new_value));
164 }
165 normalize_and_validate(new_headers, false)
166}
167
168pub fn has_expect_100_continue(request: &Request) -> bool {
169 if request.http_version < b"1.1".to_vec() {
173 return false;
174 }
175 let expect = get_comma_header(&request.headers, b"expect");
176 expect.contains(&b"100-continue".to_vec())
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_and_validate() {
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar".to_vec())], false).unwrap(),
            Headers(vec![(b"foo".to_vec(), b"foo".to_vec(), b"bar".to_vec())])
        );

        // No leading/trailing whitespace in names.
        assert_eq!(
            normalize_and_validate(vec![(b"foo ".to_vec(), b"bar".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("Illegal header name [102, 111, 111, 32]".to_string(), 400).into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b" foo".to_vec(), b"bar".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("Illegal header name [32, 102, 111, 111]".to_string(), 400).into()
            )
        );

        // No whitespace, NUL, 8-bit or control characters inside names.
        assert_eq!(
            normalize_and_validate(vec![(b"foo bar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 32, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo\x00bar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 0, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo\xffbar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 255, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo\x01bar".to_vec(), b"baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header name [102, 111, 111, 1, 98, 97, 114]".to_string(),
                    400
                )
                    .into()
            )
        );

        // No CR, LF or NUL inside values.
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\rbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 13, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\nbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 10, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\x00baz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 0, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );

        // No leading/trailing whitespace in values. The literals must carry two
        // spaces to match the `32, 32` bytes in the expected messages.
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz  ".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 98, 97, 122, 32, 32]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"  barbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [32, 32, 98, 97, 114, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz\t".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [98, 97, 114, 98, 97, 122, 9]".to_string(),
                    400
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(vec![(b"foo".to_vec(), b"\tbarbaz".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Illegal header value [9, 98, 97, 114, 98, 97, 122]".to_string(),
                    400
                )
                    .into()
            )
        );

        // Content-Length handling.
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1".to_vec())], false)
                .unwrap(),
            Headers(vec![(
                b"Content-Length".to_vec(),
                b"content-length".to_vec(),
                b"1".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"asdf".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into())
        );
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1x".to_vec())], false)
                .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into())
        );
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Content-Length".to_vec(), b"1".to_vec()),
                    (b"Content-Length".to_vec(), b"2".to_vec())
                ],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("conflicting Content-Length headers".to_string(), 400).into()
            )
        );
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Content-Length".to_vec(), b"0".to_vec()),
                    (b"Content-Length".to_vec(), b"0".to_vec())
                ],
                false
            )
            .unwrap(),
            Headers(vec![(
                b"Content-Length".to_vec(),
                b"content-length".to_vec(),
                b"0".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"0 , 0".to_vec())], false)
                .unwrap(),
            Headers(vec![(
                b"Content-Length".to_vec(),
                b"content-length".to_vec(),
                b"0".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Content-Length".to_vec(), b"1".to_vec()),
                    (b"Content-Length".to_vec(), b"1".to_vec()),
                    (b"Content-Length".to_vec(), b"2".to_vec())
                ],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("conflicting Content-Length headers".to_string(), 400).into()
            )
        );
        assert_eq!(
            normalize_and_validate(
                vec![(b"Content-Length".to_vec(), b"1 , 1,2".to_vec())],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("conflicting Content-Length headers".to_string(), 400).into()
            )
        );

        // Transfer-Encoding handling.
        assert_eq!(
            normalize_and_validate(
                vec![(b"Transfer-Encoding".to_vec(), b"chunked".to_vec())],
                false
            )
            .unwrap(),
            Headers(vec![(
                b"Transfer-Encoding".to_vec(),
                b"transfer-encoding".to_vec(),
                b"chunked".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(
                vec![(b"Transfer-Encoding".to_vec(), b"cHuNkEd".to_vec())],
                false
            )
            .unwrap(),
            Headers(vec![(
                b"Transfer-Encoding".to_vec(),
                b"transfer-encoding".to_vec(),
                b"chunked".to_vec()
            )])
        );
        assert_eq!(
            normalize_and_validate(
                vec![(b"Transfer-Encoding".to_vec(), b"gzip".to_vec())],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                (
                    "Only Transfer-Encoding: chunked is supported".to_string(),
                    501
                )
                    .into()
            )
        );
        assert_eq!(
            normalize_and_validate(
                vec![
                    (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()),
                    (b"Transfer-Encoding".to_vec(), b"gzip".to_vec())
                ],
                false
            )
            .expect_err("Expect ProtocolError::LocalProtocolError"),
            ProtocolError::LocalProtocolError(
                ("multiple Transfer-Encoding headers".to_string(), 501).into()
            )
        );
    }

    #[test]
    fn test_get_set_comma_header() {
        let headers = normalize_and_validate(
            vec![
                (b"Connection".to_vec(), b"close".to_vec()),
                (b"whatever".to_vec(), b"something".to_vec()),
                (b"connectiON".to_vec(), b"fOo,, , BAR".to_vec()),
            ],
            false,
        )
        .unwrap();

        assert_eq!(
            get_comma_header(&headers, b"connection"),
            vec![b"close".to_vec(), b"foo".to_vec(), b"bar".to_vec()]
        );

        let headers =
            set_comma_header(&headers, b"newthing", vec![b"a".to_vec(), b"b".to_vec()]).unwrap();

        assert_eq!(
            headers,
            Headers(vec![
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"close".to_vec()
                ),
                (
                    b"whatever".to_vec(),
                    b"whatever".to_vec(),
                    b"something".to_vec()
                ),
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"fOo,, , BAR".to_vec()
                ),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()),
            ])
        );

        let headers =
            set_comma_header(&headers, b"whatever", vec![b"different thing".to_vec()]).unwrap();

        assert_eq!(
            headers,
            Headers(vec![
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"close".to_vec()
                ),
                (
                    b"connection".to_vec(),
                    b"connection".to_vec(),
                    b"fOo,, , BAR".to_vec()
                ),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()),
                (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()),
                (
                    b"whatever".to_vec(),
                    b"whatever".to_vec(),
                    b"different thing".to_vec()
                ),
            ])
        );
    }

    #[test]
    fn test_has_100_continue() {
        assert!(has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![
                    (b"Host".to_vec(), b"example.com".to_vec()),
                    (b"Expect".to_vec(), b"100-continue".to_vec())
                ],
                false
            )
            .unwrap(),
            http_version: b"1.1".to_vec(),
        }));
        assert!(!has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![(b"Host".to_vec(), b"example.com".to_vec())],
                false
            )
            .unwrap(),
            http_version: b"1.1".to_vec(),
        }));
        // The Expect token is matched case-insensitively.
        assert!(has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![
                    (b"Host".to_vec(), b"example.com".to_vec()),
                    (b"Expect".to_vec(), b"100-Continue".to_vec())
                ],
                false
            )
            .unwrap(),
            http_version: b"1.1".to_vec(),
        }));
        // HTTP/1.0 requests never qualify, even with the header present.
        assert!(!has_expect_100_continue(&Request {
            method: b"GET".to_vec(),
            target: b"/".to_vec(),
            headers: normalize_and_validate(
                vec![
                    (b"Host".to_vec(), b"example.com".to_vec()),
                    (b"Expect".to_vec(), b"100-continue".to_vec())
                ],
                false
            )
            .unwrap(),
            http_version: b"1.0".to_vec(),
        }));
    }
}