1#![allow(clippy::new_without_default)]
24
25use bytes::BufMut;
26use http::header::{AsHeaderName, HeaderName, HeaderValue};
27use http::request::Builder as ReqBuilder;
28use http::request::Parts as ReqParts;
29use http::response::Builder as RespBuilder;
30use http::response::Parts as RespParts;
31use http::uri::Uri;
32use pingora_error::{ErrorType::*, OrErr, Result};
33use std::ops::Deref;
34
35pub use http::method::Method;
36pub use http::status::StatusCode;
37pub use http::version::Version;
38pub use http::HeaderMap as HMap;
39
40mod case_header_name;
41use case_header_name::CaseHeaderName;
42pub use case_header_name::IntoCaseHeaderName;
43
44pub mod prelude {
45 pub use crate::RequestHeader;
46}
47
48type CaseMap = HMap<CaseHeaderName>;
58
59#[derive(Debug)]
67pub struct RequestHeader {
68 base: ReqParts,
69 header_name_map: Option<CaseMap>,
70 raw_path_fallback: Vec<u8>, send_end_stream: bool,
74}
75
76impl AsRef<ReqParts> for RequestHeader {
77 fn as_ref(&self) -> &ReqParts {
78 &self.base
79 }
80}
81
82impl Deref for RequestHeader {
83 type Target = ReqParts;
84
85 fn deref(&self) -> &Self::Target {
86 &self.base
87 }
88}
89
90impl RequestHeader {
91 fn new_no_case(size_hint: Option<usize>) -> Self {
92 let mut base = ReqBuilder::new().body(()).unwrap().into_parts().0;
93 base.headers.reserve(http_header_map_upper_bound(size_hint));
94 RequestHeader {
95 base,
96 header_name_map: None,
97 raw_path_fallback: vec![],
98 send_end_stream: true,
99 }
100 }
101
102 pub fn build(
106 method: impl TryInto<Method>,
107 path: &[u8],
108 size_hint: Option<usize>,
109 ) -> Result<Self> {
110 let mut req = Self::build_no_case(method, path, size_hint)?;
111 req.header_name_map = Some(CaseMap::with_capacity(http_header_map_upper_bound(
112 size_hint,
113 )));
114 Ok(req)
115 }
116
117 pub fn build_no_case(
123 method: impl TryInto<Method>,
124 path: &[u8],
125 size_hint: Option<usize>,
126 ) -> Result<Self> {
127 let mut req = Self::new_no_case(size_hint);
128 req.base.method = method
129 .try_into()
130 .explain_err(InvalidHTTPHeader, |_| "invalid method")?;
131 if let Ok(p) = std::str::from_utf8(path) {
132 let uri = Uri::builder()
133 .path_and_query(p)
134 .build()
135 .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", p))?;
136 req.base.uri = uri;
137 } else {
139 let lossy_str = String::from_utf8_lossy(path);
141 let uri = Uri::builder()
142 .path_and_query(lossy_str.as_ref())
143 .build()
144 .explain_err(InvalidHTTPHeader, |_| format!("invalid uri {}", lossy_str))?;
145 req.base.uri = uri;
146 req.raw_path_fallback = path.to_vec();
147 }
148
149 Ok(req)
150 }
151
152 pub fn append_header(
157 &mut self,
158 name: impl IntoCaseHeaderName,
159 value: impl TryInto<HeaderValue>,
160 ) -> Result<bool> {
161 let header_value = value
162 .try_into()
163 .explain_err(InvalidHTTPHeader, |_| "invalid value while append")?;
164 append_header_value(
165 self.header_name_map.as_mut(),
166 &mut self.base.headers,
167 name,
168 header_value,
169 )
170 }
171
172 pub fn insert_header(
177 &mut self,
178 name: impl IntoCaseHeaderName,
179 value: impl TryInto<HeaderValue>,
180 ) -> Result<()> {
181 let header_value = value
182 .try_into()
183 .explain_err(InvalidHTTPHeader, |_| "invalid value while insert")?;
184 insert_header_value(
185 self.header_name_map.as_mut(),
186 &mut self.base.headers,
187 name,
188 header_value,
189 )
190 }
191
192 pub fn remove_header<'a, N: ?Sized>(&mut self, name: &'a N) -> Option<HeaderValue>
194 where
195 &'a N: 'a + AsHeaderName,
196 {
197 remove_header(self.header_name_map.as_mut(), &mut self.base.headers, name)
198 }
199
200 pub fn header_to_h1_wire(&self, buf: &mut impl BufMut) {
204 header_to_h1_wire(self.header_name_map.as_ref(), &self.base.headers, buf)
205 }
206
207 pub fn set_method(&mut self, method: Method) {
209 self.base.method = method;
210 }
211
212 pub fn set_uri(&mut self, uri: http::Uri) {
214 self.base.uri = uri;
215 }
216
217 pub fn set_send_end_stream(&mut self, send_end_stream: bool) {
219 self.send_end_stream = send_end_stream;
220 }
221
222 pub fn send_end_stream(&self) -> Option<bool> {
225 if self.base.version != Version::HTTP_2 {
226 return None;
227 }
228 Some(self.send_end_stream)
229 }
230
231 pub fn raw_path(&self) -> &[u8] {
235 if !self.raw_path_fallback.is_empty() {
236 &self.raw_path_fallback
237 } else {
238 self.base
240 .uri
241 .path_and_query()
242 .as_ref()
243 .unwrap()
244 .as_str()
245 .as_bytes()
246 }
247 }
248
249 pub fn uri_file_extension(&self) -> Option<&str> {
251 let (_, ext) = self
253 .uri
254 .path_and_query()
255 .and_then(|pq| pq.path().rsplit_once('.'))?;
256 Some(ext)
257 }
258
259 pub fn set_version(&mut self, version: Version) {
261 self.base.version = version;
262 }
263
264 pub fn as_owned_parts(&self) -> ReqParts {
266 clone_req_parts(&self.base)
267 }
268}
269
270impl Clone for RequestHeader {
271 fn clone(&self) -> Self {
272 Self {
273 base: self.as_owned_parts(),
274 header_name_map: self.header_name_map.clone(),
275 raw_path_fallback: self.raw_path_fallback.clone(),
276 send_end_stream: self.send_end_stream,
277 }
278 }
279}
280
281impl From<ReqParts> for RequestHeader {
283 fn from(parts: ReqParts) -> RequestHeader {
284 Self {
285 base: parts,
286 header_name_map: None,
287 raw_path_fallback: vec![],
289 send_end_stream: true,
290 }
291 }
292}
293
294impl From<RequestHeader> for ReqParts {
295 fn from(resp: RequestHeader) -> ReqParts {
296 resp.base
297 }
298}
299
300#[derive(Debug)]
306pub struct ResponseHeader {
307 base: RespParts,
308 header_name_map: Option<CaseMap>,
310 reason_phrase: Option<String>,
312}
313
314impl AsRef<RespParts> for ResponseHeader {
315 fn as_ref(&self) -> &RespParts {
316 &self.base
317 }
318}
319
320impl Deref for ResponseHeader {
321 type Target = RespParts;
322
323 fn deref(&self) -> &Self::Target {
324 &self.base
325 }
326}
327
328impl Clone for ResponseHeader {
329 fn clone(&self) -> Self {
330 Self {
331 base: self.as_owned_parts(),
332 header_name_map: self.header_name_map.clone(),
333 reason_phrase: self.reason_phrase.clone(),
334 }
335 }
336}
337
338impl From<RespParts> for ResponseHeader {
340 fn from(parts: RespParts) -> ResponseHeader {
341 Self {
342 base: parts,
343 header_name_map: None,
344 reason_phrase: None,
345 }
346 }
347}
348
349impl From<ResponseHeader> for RespParts {
350 fn from(resp: ResponseHeader) -> RespParts {
351 resp.base
352 }
353}
354
355impl From<Box<ResponseHeader>> for Box<RespParts> {
356 fn from(resp: Box<ResponseHeader>) -> Box<RespParts> {
357 Box::new(resp.base)
358 }
359}
360
361impl ResponseHeader {
362 fn new(size_hint: Option<usize>) -> Self {
363 let mut resp_header = Self::new_no_case(size_hint);
364 resp_header.header_name_map = Some(CaseMap::with_capacity(http_header_map_upper_bound(
365 size_hint,
366 )));
367 resp_header
368 }
369
370 fn new_no_case(size_hint: Option<usize>) -> Self {
371 let mut base = RespBuilder::new().body(()).unwrap().into_parts().0;
372 base.headers.reserve(http_header_map_upper_bound(size_hint));
373 ResponseHeader {
374 base,
375 header_name_map: None,
376 reason_phrase: None,
377 }
378 }
379
380 pub fn build(code: impl TryInto<StatusCode>, size_hint: Option<usize>) -> Result<Self> {
382 let mut resp = Self::new(size_hint);
383 resp.base.status = code
384 .try_into()
385 .explain_err(InvalidHTTPHeader, |_| "invalid status")?;
386 Ok(resp)
387 }
388
389 pub fn build_no_case(code: impl TryInto<StatusCode>, size_hint: Option<usize>) -> Result<Self> {
395 let mut resp = Self::new_no_case(size_hint);
396 resp.base.status = code
397 .try_into()
398 .explain_err(InvalidHTTPHeader, |_| "invalid status")?;
399 Ok(resp)
400 }
401
402 pub fn append_header(
407 &mut self,
408 name: impl IntoCaseHeaderName,
409 value: impl TryInto<HeaderValue>,
410 ) -> Result<bool> {
411 let header_value = value
412 .try_into()
413 .explain_err(InvalidHTTPHeader, |_| "invalid value while append")?;
414 append_header_value(
415 self.header_name_map.as_mut(),
416 &mut self.base.headers,
417 name,
418 header_value,
419 )
420 }
421
422 pub fn insert_header(
427 &mut self,
428 name: impl IntoCaseHeaderName,
429 value: impl TryInto<HeaderValue>,
430 ) -> Result<()> {
431 let header_value = value
432 .try_into()
433 .explain_err(InvalidHTTPHeader, |_| "invalid value while insert")?;
434 insert_header_value(
435 self.header_name_map.as_mut(),
436 &mut self.base.headers,
437 name,
438 header_value,
439 )
440 }
441
442 pub fn remove_header<'a, N: ?Sized>(&mut self, name: &'a N) -> Option<HeaderValue>
444 where
445 &'a N: 'a + AsHeaderName,
446 {
447 remove_header(self.header_name_map.as_mut(), &mut self.base.headers, name)
448 }
449
450 pub fn header_to_h1_wire(&self, buf: &mut impl BufMut) {
454 header_to_h1_wire(self.header_name_map.as_ref(), &self.base.headers, buf)
455 }
456
457 pub fn set_status(&mut self, status: impl TryInto<StatusCode>) -> Result<()> {
459 self.base.status = status
460 .try_into()
461 .explain_err(InvalidHTTPHeader, |_| "invalid status")?;
462 Ok(())
463 }
464
465 pub fn set_version(&mut self, version: Version) {
467 self.base.version = version
468 }
469
470 pub fn set_reason_phrase(&mut self, reason_phrase: Option<&str>) -> Result<()> {
472 if reason_phrase == self.base.status.canonical_reason() {
474 self.reason_phrase = None;
475 return Ok(());
476 }
477
478 self.reason_phrase = reason_phrase.map(str::to_string);
480 Ok(())
481 }
482
483 pub fn get_reason_phrase(&self) -> Option<&str> {
486 self.reason_phrase
487 .as_deref()
488 .or_else(|| self.base.status.canonical_reason())
489 }
490
491 pub fn as_owned_parts(&self) -> RespParts {
493 clone_resp_parts(&self.base)
494 }
495}
496
497fn clone_req_parts(me: &ReqParts) -> ReqParts {
498 let mut parts = ReqBuilder::new()
499 .method(me.method.clone())
500 .uri(me.uri.clone())
501 .version(me.version)
502 .body(())
503 .unwrap()
504 .into_parts()
505 .0;
506 parts.headers = me.headers.clone();
507 parts
508}
509
510fn clone_resp_parts(me: &RespParts) -> RespParts {
511 let mut parts = RespBuilder::new()
512 .status(me.status)
513 .version(me.version)
514 .body(())
515 .unwrap()
516 .into_parts()
517 .0;
518 parts.headers = me.headers.clone();
519 parts
520}
521
/// Number of header slots to pre-allocate for a header map.
///
/// Uses `size_hint` when given (8 otherwise), but never more than 4096, so a
/// large or hostile hint cannot trigger an oversized allocation.
fn http_header_map_upper_bound(size_hint: Option<usize>) -> usize {
    const PINGORA_MAX_HEADER_COUNT: usize = 4096;
    const INIT_HEADER_SIZE: usize = 8;

    let requested = size_hint.unwrap_or(INIT_HEADER_SIZE);
    requested.min(PINGORA_MAX_HEADER_COUNT)
}
544
545#[inline]
546fn append_header_value<T>(
547 name_map: Option<&mut CaseMap>,
548 value_map: &mut HMap<T>,
549 name: impl IntoCaseHeaderName,
550 value: T,
551) -> Result<bool> {
552 let case_header_name = name.into_case_header_name();
553 let header_name: HeaderName = case_header_name
554 .as_slice()
555 .try_into()
556 .or_err(InvalidHTTPHeader, "invalid header name")?;
557 if let Some(name_map) = name_map {
559 name_map.append(header_name.clone(), case_header_name);
560 }
561
562 Ok(value_map.append(header_name, value))
563}
564
565#[inline]
566fn insert_header_value<T>(
567 name_map: Option<&mut CaseMap>,
568 value_map: &mut HMap<T>,
569 name: impl IntoCaseHeaderName,
570 value: T,
571) -> Result<()> {
572 let case_header_name = name.into_case_header_name();
573 let header_name: HeaderName = case_header_name
574 .as_slice()
575 .try_into()
576 .or_err(InvalidHTTPHeader, "invalid header name")?;
577 if let Some(name_map) = name_map {
578 name_map.insert(header_name.clone(), case_header_name);
580 }
581 value_map.insert(header_name, value);
582 Ok(())
583}
584
585#[inline]
587fn remove_header<'a, T, N: ?Sized>(
588 name_map: Option<&mut CaseMap>,
589 value_map: &mut HMap<T>,
590 name: &'a N,
591) -> Option<T>
592where
593 &'a N: 'a + AsHeaderName,
594{
595 let removed = value_map.remove(name);
596 if removed.is_some() {
597 if let Some(name_map) = name_map {
598 name_map.remove(name);
599 }
600 }
601 removed
602}
603
604#[inline]
605fn header_to_h1_wire(key_map: Option<&CaseMap>, value_map: &HMap, buf: &mut impl BufMut) {
606 const CRLF: &[u8; 2] = b"\r\n";
607 const HEADER_KV_DELIMITER: &[u8; 2] = b": ";
608
609 if let Some(key_map) = key_map {
610 let iter = key_map.iter().zip(value_map.iter());
611 for ((header, case_header), (header2, val)) in iter {
612 if header != header2 {
613 panic!("header iter mismatch {}, {}", header, header2)
615 }
616 buf.put_slice(case_header.as_slice());
617 buf.put_slice(HEADER_KV_DELIMITER);
618 buf.put_slice(val.as_ref());
619 buf.put_slice(CRLF);
620 }
621 } else {
622 for (header, value) in value_map {
623 let titled_header =
624 case_header_name::titled_header_name_str(header).unwrap_or(header.as_str());
625 buf.put_slice(titled_header.as_bytes());
626 buf.put_slice(HEADER_KV_DELIMITER);
627 buf.put_slice(value.as_ref());
628 buf.put_slice(CRLF);
629 }
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn header_map_upper_bound() {
        assert_eq!(8, http_header_map_upper_bound(None));
        assert_eq!(16, http_header_map_upper_bound(Some(16)));
        assert_eq!(4096, http_header_map_upper_bound(Some(7777)));
    }

    #[test]
    fn test_single_header() {
        let mut req = RequestHeader::build("GET", b"\\", None).unwrap();
        req.insert_header("foo", "bar").unwrap();
        req.insert_header("FoO", "Bar").unwrap();
        let mut buf: Vec<u8> = vec![];
        req.header_to_h1_wire(&mut buf);
        assert_eq!(buf, b"FoO: Bar\r\n");

        // The second insert must replace both the value and the recorded casing.
        let mut resp = ResponseHeader::new(None);
        resp.insert_header("foo", "bar").unwrap();
        resp.insert_header("FoO", "Bar").unwrap();
        let mut buf: Vec<u8> = vec![];
        resp.header_to_h1_wire(&mut buf);
        assert_eq!(buf, b"FoO: Bar\r\n");
    }

    #[test]
    fn test_single_header_no_case() {
        let mut req = RequestHeader::new_no_case(None);
        req.insert_header("foo", "bar").unwrap();
        req.insert_header("FoO", "Bar").unwrap();
        let mut buf: Vec<u8> = vec![];
        req.header_to_h1_wire(&mut buf);
        assert_eq!(buf, b"foo: Bar\r\n");

        let mut resp = ResponseHeader::new_no_case(None);
        resp.insert_header("foo", "bar").unwrap();
        resp.insert_header("FoO", "Bar").unwrap();
        let mut buf: Vec<u8> = vec![];
        resp.header_to_h1_wire(&mut buf);
        assert_eq!(buf, b"foo: Bar\r\n");
    }

    #[test]
    fn test_multiple_header() {
        let mut req = RequestHeader::build("GET", b"\\", None).unwrap();
        req.append_header("FoO", "Bar").unwrap();
        req.append_header("fOO", "bar").unwrap();
        req.append_header("BAZ", "baR").unwrap();
        req.append_header(http::header::CONTENT_LENGTH, "0").unwrap();
        req.append_header("a", "b").unwrap();
        req.remove_header("a");
        let mut buf: Vec<u8> = vec![];
        req.header_to_h1_wire(&mut buf);
        assert_eq!(
            buf,
            b"FoO: Bar\r\nfOO: bar\r\nBAZ: baR\r\nContent-Length: 0\r\n"
        );

        let mut resp = ResponseHeader::new(None);
        resp.append_header("FoO", "Bar").unwrap();
        resp.append_header("fOO", "bar").unwrap();
        resp.append_header("BAZ", "baR").unwrap();
        resp.append_header(http::header::CONTENT_LENGTH, "0").unwrap();
        resp.append_header("a", "b").unwrap();
        resp.remove_header("a");
        let mut buf: Vec<u8> = vec![];
        resp.header_to_h1_wire(&mut buf);
        assert_eq!(
            buf,
            b"FoO: Bar\r\nfOO: bar\r\nBAZ: baR\r\nContent-Length: 0\r\n"
        );
    }

    #[cfg(feature = "patched_http1")]
    #[test]
    fn test_invalid_path() {
        let raw_path = b"Hello\xF0\x90\x80World";
        let req = RequestHeader::build("GET", &raw_path[..], None).unwrap();
        assert_eq!("Hello�World", req.uri.path_and_query().unwrap());
        assert_eq!(raw_path, req.raw_path());
    }

    #[test]
    fn test_reason_phrase() {
        let mut resp = ResponseHeader::new(None);
        let reason = resp.get_reason_phrase().unwrap();
        assert_eq!(reason, "OK");

        resp.set_reason_phrase(Some("FooBar")).unwrap();
        let reason = resp.get_reason_phrase().unwrap();
        assert_eq!(reason, "FooBar");

        resp.set_reason_phrase(Some("OK")).unwrap();
        let reason = resp.get_reason_phrase().unwrap();
        assert_eq!(reason, "OK");

        resp.set_reason_phrase(None).unwrap();
        let reason = resp.get_reason_phrase().unwrap();
        assert_eq!(reason, "OK");
    }

    #[test]
    fn set_test_send_end_stream() {
        let mut req = RequestHeader::build("GET", b"/", None).unwrap();
        req.set_send_end_stream(true);

        // The flag is only reported for HTTP/2.
        assert!(req.send_end_stream().is_none());

        let mut req = RequestHeader::build("GET", b"/", None).unwrap();
        req.set_version(Version::HTTP_2);

        assert!(req.send_end_stream().unwrap());

        req.set_send_end_stream(false);
        assert!(!req.send_end_stream().unwrap());
    }
}