1#![allow(clippy::new_without_default)]
24
25use bytes::BufMut;
26use http::header::{AsHeaderName, HeaderName, HeaderValue};
27use http::request::Builder as ReqBuilder;
28use http::request::Parts as ReqParts;
29use http::response::Builder as RespBuilder;
30use http::response::Parts as RespParts;
31use http::uri::Uri;
32use pingora_error::{ErrorType::*, OrErr, Result};
33use std::ops::Deref;
34
35pub use http::method::Method;
36pub use http::status::StatusCode;
37pub use http::version::Version;
38pub use http::HeaderMap as HMap;
39
40mod case_header_name;
41use case_header_name::CaseHeaderName;
42pub use case_header_name::IntoCaseHeaderName;
43
44pub mod prelude {
45 pub use crate::RequestHeader;
46}
47
48type CaseMap = HMap<CaseHeaderName>;
58
59#[derive(Debug)]
67pub struct RequestHeader {
68 base: ReqParts,
69 header_name_map: Option<CaseMap>,
70 raw_path_fallback: Vec<u8>, send_end_stream: bool,
74}
75
76impl AsRef<ReqParts> for RequestHeader {
77 fn as_ref(&self) -> &ReqParts {
78 &self.base
79 }
80}
81
82impl Deref for RequestHeader {
83 type Target = ReqParts;
84
85 fn deref(&self) -> &Self::Target {
86 &self.base
87 }
88}
89
90impl RequestHeader {
91 fn new_no_case(size_hint: Option<usize>) -> Self {
92 let mut base = ReqBuilder::new().body(()).unwrap().into_parts().0;
93 base.headers.reserve(http_header_map_upper_bound(size_hint));
94 RequestHeader {
95 base,
96 header_name_map: None,
97 raw_path_fallback: vec![],
98 send_end_stream: true,
99 }
100 }
101
102 pub fn build(
106 method: impl TryInto<Method>,
107 path: &[u8],
108 size_hint: Option<usize>,
109 ) -> Result<Self> {
110 let mut req = Self::build_no_case(method, path, size_hint)?;
111 req.header_name_map = Some(CaseMap::with_capacity(http_header_map_upper_bound(
112 size_hint,
113 )));
114 Ok(req)
115 }
116
117 pub fn build_no_case(
123 method: impl TryInto<Method>,
124 path: &[u8],
125 size_hint: Option<usize>,
126 ) -> Result<Self> {
127 let mut req = Self::new_no_case(size_hint);
128 req.base.method = method
129 .try_into()
130 .explain_err(InvalidHTTPHeader, |_| "invalid method")?;
131 req.set_raw_path(path)?;
132 Ok(req)
133 }
134
135 pub fn append_header(
140 &mut self,
141 name: impl IntoCaseHeaderName,
142 value: impl TryInto<HeaderValue>,
143 ) -> Result<bool> {
144 let header_value = value
145 .try_into()
146 .explain_err(InvalidHTTPHeader, |_| "invalid value while append")?;
147 append_header_value(
148 self.header_name_map.as_mut(),
149 &mut self.base.headers,
150 name,
151 header_value,
152 )
153 }
154
155 pub fn insert_header(
160 &mut self,
161 name: impl IntoCaseHeaderName,
162 value: impl TryInto<HeaderValue>,
163 ) -> Result<()> {
164 let header_value = value
165 .try_into()
166 .explain_err(InvalidHTTPHeader, |_| "invalid value while insert")?;
167 insert_header_value(
168 self.header_name_map.as_mut(),
169 &mut self.base.headers,
170 name,
171 header_value,
172 )
173 }
174
175 pub fn remove_header<'a, N: ?Sized>(&mut self, name: &'a N) -> Option<HeaderValue>
177 where
178 &'a N: 'a + AsHeaderName,
179 {
180 remove_header(self.header_name_map.as_mut(), &mut self.base.headers, name)
181 }
182
183 pub fn header_to_h1_wire(&self, buf: &mut impl BufMut) {
187 header_to_h1_wire(self.header_name_map.as_ref(), &self.base.headers, buf)
188 }
189
190 pub fn set_method(&mut self, method: Method) {
192 self.base.method = method;
193 }
194
195 pub fn set_uri(&mut self, uri: http::Uri) {
197 self.base.uri = uri;
198 self.raw_path_fallback = vec![];
200 }
201
202 pub fn set_raw_path(&mut self, path: &[u8]) -> Result<()> {
208 if let Ok(p) = std::str::from_utf8(path) {
209 let uri = Uri::builder()
210 .path_and_query(p)
211 .build()
212 .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", p))?;
213 self.base.uri = uri;
214 } else {
216 let lossy_str = String::from_utf8_lossy(path);
218 let uri = Uri::builder()
219 .path_and_query(lossy_str.as_ref())
220 .build()
221 .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", lossy_str))?;
222 self.base.uri = uri;
223 self.raw_path_fallback = path.to_vec();
224 }
225 Ok(())
226 }
227
228 pub fn set_send_end_stream(&mut self, send_end_stream: bool) {
230 self.send_end_stream = send_end_stream;
231 }
232
233 pub fn send_end_stream(&self) -> Option<bool> {
236 if self.base.version != Version::HTTP_2 {
237 return None;
238 }
239 Some(self.send_end_stream)
240 }
241
242 pub fn raw_path(&self) -> &[u8] {
246 if !self.raw_path_fallback.is_empty() {
247 &self.raw_path_fallback
248 } else {
249 self.base
251 .uri
252 .path_and_query()
253 .as_ref()
254 .unwrap()
255 .as_str()
256 .as_bytes()
257 }
258 }
259
260 pub fn uri_file_extension(&self) -> Option<&str> {
262 let (_, ext) = self
264 .uri
265 .path_and_query()
266 .and_then(|pq| pq.path().rsplit_once('.'))?;
267 Some(ext)
268 }
269
270 pub fn set_version(&mut self, version: Version) {
272 self.base.version = version;
273 }
274
275 pub fn as_owned_parts(&self) -> ReqParts {
277 clone_req_parts(&self.base)
278 }
279}
280
281impl Clone for RequestHeader {
282 fn clone(&self) -> Self {
283 Self {
284 base: self.as_owned_parts(),
285 header_name_map: self.header_name_map.clone(),
286 raw_path_fallback: self.raw_path_fallback.clone(),
287 send_end_stream: self.send_end_stream,
288 }
289 }
290}
291
292impl From<ReqParts> for RequestHeader {
294 fn from(parts: ReqParts) -> RequestHeader {
295 Self {
296 base: parts,
297 header_name_map: None,
298 raw_path_fallback: vec![],
300 send_end_stream: true,
301 }
302 }
303}
304
305impl From<RequestHeader> for ReqParts {
306 fn from(resp: RequestHeader) -> ReqParts {
307 resp.base
308 }
309}
310
311#[derive(Debug)]
317pub struct ResponseHeader {
318 base: RespParts,
319 header_name_map: Option<CaseMap>,
321 reason_phrase: Option<String>,
323}
324
325impl AsRef<RespParts> for ResponseHeader {
326 fn as_ref(&self) -> &RespParts {
327 &self.base
328 }
329}
330
331impl Deref for ResponseHeader {
332 type Target = RespParts;
333
334 fn deref(&self) -> &Self::Target {
335 &self.base
336 }
337}
338
339impl Clone for ResponseHeader {
340 fn clone(&self) -> Self {
341 Self {
342 base: self.as_owned_parts(),
343 header_name_map: self.header_name_map.clone(),
344 reason_phrase: self.reason_phrase.clone(),
345 }
346 }
347}
348
349impl From<RespParts> for ResponseHeader {
351 fn from(parts: RespParts) -> ResponseHeader {
352 Self {
353 base: parts,
354 header_name_map: None,
355 reason_phrase: None,
356 }
357 }
358}
359
360impl From<ResponseHeader> for RespParts {
361 fn from(resp: ResponseHeader) -> RespParts {
362 resp.base
363 }
364}
365
366impl From<Box<ResponseHeader>> for Box<RespParts> {
367 fn from(resp: Box<ResponseHeader>) -> Box<RespParts> {
368 Box::new(resp.base)
369 }
370}
371
372impl ResponseHeader {
373 fn new(size_hint: Option<usize>) -> Self {
374 let mut resp_header = Self::new_no_case(size_hint);
375 resp_header.header_name_map = Some(CaseMap::with_capacity(http_header_map_upper_bound(
376 size_hint,
377 )));
378 resp_header
379 }
380
381 fn new_no_case(size_hint: Option<usize>) -> Self {
382 let mut base = RespBuilder::new().body(()).unwrap().into_parts().0;
383 base.headers.reserve(http_header_map_upper_bound(size_hint));
384 ResponseHeader {
385 base,
386 header_name_map: None,
387 reason_phrase: None,
388 }
389 }
390
391 pub fn build(code: impl TryInto<StatusCode>, size_hint: Option<usize>) -> Result<Self> {
393 let mut resp = Self::new(size_hint);
394 resp.base.status = code
395 .try_into()
396 .explain_err(InvalidHTTPHeader, |_| "invalid status")?;
397 Ok(resp)
398 }
399
400 pub fn build_no_case(code: impl TryInto<StatusCode>, size_hint: Option<usize>) -> Result<Self> {
406 let mut resp = Self::new_no_case(size_hint);
407 resp.base.status = code
408 .try_into()
409 .explain_err(InvalidHTTPHeader, |_| "invalid status")?;
410 Ok(resp)
411 }
412
413 pub fn append_header(
418 &mut self,
419 name: impl IntoCaseHeaderName,
420 value: impl TryInto<HeaderValue>,
421 ) -> Result<bool> {
422 let header_value = value
423 .try_into()
424 .explain_err(InvalidHTTPHeader, |_| "invalid value while append")?;
425 append_header_value(
426 self.header_name_map.as_mut(),
427 &mut self.base.headers,
428 name,
429 header_value,
430 )
431 }
432
433 pub fn insert_header(
438 &mut self,
439 name: impl IntoCaseHeaderName,
440 value: impl TryInto<HeaderValue>,
441 ) -> Result<()> {
442 let header_value = value
443 .try_into()
444 .explain_err(InvalidHTTPHeader, |_| "invalid value while insert")?;
445 insert_header_value(
446 self.header_name_map.as_mut(),
447 &mut self.base.headers,
448 name,
449 header_value,
450 )
451 }
452
453 pub fn remove_header<'a, N: ?Sized>(&mut self, name: &'a N) -> Option<HeaderValue>
455 where
456 &'a N: 'a + AsHeaderName,
457 {
458 remove_header(self.header_name_map.as_mut(), &mut self.base.headers, name)
459 }
460
461 pub fn header_to_h1_wire(&self, buf: &mut impl BufMut) {
465 header_to_h1_wire(self.header_name_map.as_ref(), &self.base.headers, buf)
466 }
467
468 pub fn set_status(&mut self, status: impl TryInto<StatusCode>) -> Result<()> {
470 self.base.status = status
471 .try_into()
472 .explain_err(InvalidHTTPHeader, |_| "invalid status")?;
473 Ok(())
474 }
475
476 pub fn set_version(&mut self, version: Version) {
478 self.base.version = version
479 }
480
481 pub fn set_reason_phrase(&mut self, reason_phrase: Option<&str>) -> Result<()> {
483 if reason_phrase == self.base.status.canonical_reason() {
485 self.reason_phrase = None;
486 return Ok(());
487 }
488
489 self.reason_phrase = reason_phrase.map(str::to_string);
491 Ok(())
492 }
493
494 pub fn get_reason_phrase(&self) -> Option<&str> {
497 self.reason_phrase
498 .as_deref()
499 .or_else(|| self.base.status.canonical_reason())
500 }
501
502 pub fn as_owned_parts(&self) -> RespParts {
504 clone_resp_parts(&self.base)
505 }
506
507 pub fn set_content_length(&mut self, len: usize) -> Result<()> {
509 self.insert_header(http::header::CONTENT_LENGTH, len)
510 }
511}
512
513fn clone_req_parts(me: &ReqParts) -> ReqParts {
514 let mut parts = ReqBuilder::new()
515 .method(me.method.clone())
516 .uri(me.uri.clone())
517 .version(me.version)
518 .body(())
519 .unwrap()
520 .into_parts()
521 .0;
522 parts.headers = me.headers.clone();
523 parts
524}
525
526fn clone_resp_parts(me: &RespParts) -> RespParts {
527 let mut parts = RespBuilder::new()
528 .status(me.status)
529 .version(me.version)
530 .body(())
531 .unwrap()
532 .into_parts()
533 .0;
534 parts.headers = me.headers.clone();
535 parts
536}
537
/// Turns an optional header-count hint into the capacity to reserve.
///
/// A missing hint falls back to a small default, and any hint is capped so
/// that an oversized value cannot trigger a huge up-front reservation.
fn http_header_map_upper_bound(size_hint: Option<usize>) -> usize {
    /// Largest capacity that will ever be reserved up front.
    const PINGORA_MAX_HEADER_COUNT: usize = 4096;
    /// Capacity used when the caller gives no hint.
    const INIT_HEADER_SIZE: usize = 8;

    match size_hint {
        None => INIT_HEADER_SIZE.min(PINGORA_MAX_HEADER_COUNT),
        Some(hint) => hint.min(PINGORA_MAX_HEADER_COUNT),
    }
}
560
561#[inline]
562fn append_header_value<T>(
563 name_map: Option<&mut CaseMap>,
564 value_map: &mut HMap<T>,
565 name: impl IntoCaseHeaderName,
566 value: T,
567) -> Result<bool> {
568 let case_header_name = name.into_case_header_name();
569 let header_name: HeaderName = case_header_name
570 .as_slice()
571 .try_into()
572 .or_err(InvalidHTTPHeader, "invalid header name")?;
573 if let Some(name_map) = name_map {
575 name_map.append(header_name.clone(), case_header_name);
576 }
577
578 Ok(value_map.append(header_name, value))
579}
580
581#[inline]
582fn insert_header_value<T>(
583 name_map: Option<&mut CaseMap>,
584 value_map: &mut HMap<T>,
585 name: impl IntoCaseHeaderName,
586 value: T,
587) -> Result<()> {
588 let case_header_name = name.into_case_header_name();
589 let header_name: HeaderName = case_header_name
590 .as_slice()
591 .try_into()
592 .or_err(InvalidHTTPHeader, "invalid header name")?;
593 if let Some(name_map) = name_map {
594 name_map.insert(header_name.clone(), case_header_name);
596 }
597 value_map.insert(header_name, value);
598 Ok(())
599}
600
601#[inline]
603fn remove_header<'a, T, N: ?Sized>(
604 name_map: Option<&mut CaseMap>,
605 value_map: &mut HMap<T>,
606 name: &'a N,
607) -> Option<T>
608where
609 &'a N: 'a + AsHeaderName,
610{
611 let removed = value_map.remove(name);
612 if removed.is_some() {
613 if let Some(name_map) = name_map {
614 name_map.remove(name);
615 }
616 }
617 removed
618}
619
620#[inline]
621fn header_to_h1_wire(key_map: Option<&CaseMap>, value_map: &HMap, buf: &mut impl BufMut) {
622 const CRLF: &[u8; 2] = b"\r\n";
623 const HEADER_KV_DELIMITER: &[u8; 2] = b": ";
624
625 if let Some(key_map) = key_map {
626 let iter = key_map.iter().zip(value_map.iter());
627 for ((header, case_header), (header2, val)) in iter {
628 if header != header2 {
629 panic!("header iter mismatch {}, {}", header, header2)
631 }
632 buf.put_slice(case_header.as_slice());
633 buf.put_slice(HEADER_KV_DELIMITER);
634 buf.put_slice(val.as_ref());
635 buf.put_slice(CRLF);
636 }
637 } else {
638 for (header, value) in value_map {
639 let titled_header =
640 case_header_name::titled_header_name_str(header).unwrap_or(header.as_str());
641 buf.put_slice(titled_header.as_bytes());
642 buf.put_slice(HEADER_KV_DELIMITER);
643 buf.put_slice(value.as_ref());
644 buf.put_slice(CRLF);
645 }
646 }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the headers of a request into a fresh byte vector.
    fn request_wire(req: &RequestHeader) -> Vec<u8> {
        let mut wire: Vec<u8> = Vec::new();
        req.header_to_h1_wire(&mut wire);
        wire
    }

    /// Serializes the headers of a response into a fresh byte vector.
    fn response_wire(resp: &ResponseHeader) -> Vec<u8> {
        let mut wire: Vec<u8> = Vec::new();
        resp.header_to_h1_wire(&mut wire);
        wire
    }

    // Default, pass-through and capped size hints.
    #[test]
    fn header_map_upper_bound() {
        assert_eq!(8, http_header_map_upper_bound(None));
        assert_eq!(16, http_header_map_upper_bound(Some(16)));
        assert_eq!(4096, http_header_map_upper_bound(Some(7777)));
    }

    // Inserting the same name twice keeps only the last spelling and value.
    #[test]
    fn test_single_header() {
        let mut req = RequestHeader::build("GET", b"\\", None).unwrap();
        req.insert_header("foo", "bar").unwrap();
        req.insert_header("FoO", "Bar").unwrap();
        assert_eq!(request_wire(&req), b"FoO: Bar\r\n");

        let mut resp = ResponseHeader::new(None);
        resp.insert_header("foo", "bar").unwrap();
        resp.insert_header("FoO", "Bar").unwrap();
        assert_eq!(response_wire(&resp), b"FoO: Bar\r\n");
    }

    // Without a case map the stored (lowercase) name is what gets written.
    #[test]
    fn test_single_header_no_case() {
        let mut req = RequestHeader::new_no_case(None);
        req.insert_header("foo", "bar").unwrap();
        req.insert_header("FoO", "Bar").unwrap();
        assert_eq!(request_wire(&req), b"foo: Bar\r\n");

        let mut resp = ResponseHeader::new_no_case(None);
        resp.insert_header("foo", "bar").unwrap();
        resp.insert_header("FoO", "Bar").unwrap();
        assert_eq!(response_wire(&resp), b"foo: Bar\r\n");
    }

    // Appended headers keep their individual spelling; removed ones vanish.
    #[test]
    fn test_multiple_header() {
        let mut req = RequestHeader::build("GET", b"\\", None).unwrap();
        req.append_header("FoO", "Bar").unwrap();
        req.append_header("fOO", "bar").unwrap();
        req.append_header("BAZ", "baR").unwrap();
        req.append_header(http::header::CONTENT_LENGTH, "0")
            .unwrap();
        req.append_header("a", "b").unwrap();
        req.remove_header("a");
        assert_eq!(
            request_wire(&req),
            b"FoO: Bar\r\nfOO: bar\r\nBAZ: baR\r\nContent-Length: 0\r\n"
        );

        let mut resp = ResponseHeader::new(None);
        resp.append_header("FoO", "Bar").unwrap();
        resp.append_header("fOO", "bar").unwrap();
        resp.append_header("BAZ", "baR").unwrap();
        resp.append_header(http::header::CONTENT_LENGTH, "0")
            .unwrap();
        resp.append_header("a", "b").unwrap();
        resp.remove_header("a");
        assert_eq!(
            response_wire(&resp),
            b"FoO: Bar\r\nfOO: bar\r\nBAZ: baR\r\nContent-Length: 0\r\n"
        );
    }

    // A non-UTF-8 path is kept verbatim while the URI holds the lossy form.
    #[cfg(feature = "patched_http1")]
    #[test]
    fn test_invalid_path() {
        let raw_path = b"Hello\xF0\x90\x80World";
        let req = RequestHeader::build("GET", &raw_path[..], None).unwrap();
        assert_eq!("Hello�World", req.uri.path_and_query().unwrap());
        assert_eq!(raw_path, req.raw_path());
    }

    // Setting a URI afterwards replaces the stored raw path.
    #[cfg(feature = "patched_http1")]
    #[test]
    fn test_override_invalid_path() {
        let raw_path = b"Hello\xF0\x90\x80World";
        let mut req = RequestHeader::build("GET", &raw_path[..], None).unwrap();
        assert_eq!("Hello�World", req.uri.path_and_query().unwrap());
        assert_eq!(raw_path, req.raw_path());

        let new_path = "/HelloWorld";
        let new_uri = Uri::builder().path_and_query(new_path).build().unwrap();
        req.set_uri(new_uri);
        assert_eq!(new_path, req.uri.path_and_query().unwrap());
        assert_eq!(new_path.as_bytes(), req.raw_path());
    }

    // Custom phrases are reported; canonical or absent ones fall back to "OK".
    #[test]
    fn test_reason_phrase() {
        let mut resp = ResponseHeader::new(None);
        assert_eq!(resp.get_reason_phrase().unwrap(), "OK");

        resp.set_reason_phrase(Some("FooBar")).unwrap();
        assert_eq!(resp.get_reason_phrase().unwrap(), "FooBar");

        resp.set_reason_phrase(Some("OK")).unwrap();
        assert_eq!(resp.get_reason_phrase().unwrap(), "OK");

        resp.set_reason_phrase(None).unwrap();
        assert_eq!(resp.get_reason_phrase().unwrap(), "OK");
    }

    // The end-of-stream flag is only visible for HTTP/2 requests.
    #[test]
    fn set_test_send_end_stream() {
        let mut h1_req = RequestHeader::build("GET", b"/", None).unwrap();
        h1_req.set_send_end_stream(true);

        assert!(h1_req.send_end_stream().is_none());

        let mut h2_req = RequestHeader::build("GET", b"/", None).unwrap();
        h2_req.set_version(Version::HTTP_2);

        assert!(h2_req.send_end_stream().unwrap());

        h2_req.set_send_end_stream(false);
        assert!(!h2_req.send_end_stream().unwrap());
    }

    // `set_content_length` stores the decimal length in Content-Length.
    #[test]
    fn set_test_set_content_length() {
        let mut resp = ResponseHeader::new(None);
        resp.set_content_length(10).unwrap();

        let stored = resp
            .headers
            .get(http::header::CONTENT_LENGTH)
            .map(|d| d.as_bytes())
            .unwrap();
        assert_eq!(b"10", stored);
    }
}