1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{Buf, BufMut, Bytes, BytesMut};
10use smallvec::SmallVec;
11
12use std::fmt;
13use std::io::Cursor;
14
/// Destination buffer used when encoding frames; the `Limit` wrapper bounds how
/// many bytes a single frame may write (see `remaining_mut` in the encoders).
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;

/// A HEADERS frame: request or response headers, or trailers.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// Stream dependency information, present when the PRIORITY flag is set.
    stream_dep: Option<StreamDependency>,

    /// The header block (pseudo-headers and regular fields).
    header_block: HeaderBlock,

    /// The associated frame flags.
    flags: HeadersFlag,
}

/// Raw flags byte of a HEADERS frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
36
/// A PUSH_PROMISE frame announcing a server-initiated stream.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block of the promised request.
    header_block: HeaderBlock,

    /// The associated frame flags.
    flags: PushPromiseFlag,
}

/// Raw flags byte of a PUSH_PROMISE frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54
/// A CONTINUATION frame carrying the part of an encoded header block that did not
/// fit into the preceding HEADERS / PUSH_PROMISE / CONTINUATION frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame.
    stream_id: StreamId,

    /// The remaining HPACK-encoded bytes still to be written.
    header_block: EncodingHeaderBlock,
}
62
/// Pseudo-header values of a header block, plus the order used to encode them.
///
/// Each field is `None` when the corresponding pseudo-header is absent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    /// `:method`
    pub method: Option<Method>,
    /// `:scheme`
    pub scheme: Option<BytesStr>,
    /// `:authority`
    pub authority: Option<BytesStr>,
    /// `:path`
    pub path: Option<BytesStr>,
    /// `:protocol` (used by extended CONNECT requests).
    pub protocol: Option<Protocol>,

    // Response
    /// `:status`
    pub status: Option<StatusCode>,

    /// Order in which the pseudo-headers above are emitted when encoding.
    pub order: PseudoOrder,
}
79
// Identifies a pseudo-header field. `PseudoOrder` stores a sequence of these to
// decide the order in which pseudo-headers are encoded.
//
// NOTE(review): `@U8` and the `=> 0x000N` values are interpreted by
// `define_enum_with_values!`, which is defined elsewhere. The macro presumably also
// generates `mask_id()`, `DEFAULT_IDS` and `DEFAULT_STACK_SIZE` used below — confirm.
// `PseudoOrderBuilder::push` treats `mask_id()` as a bit flag, so it must map each
// variant to a distinct single bit (the raw values 1..=6 overlap as bit masks).
define_enum_with_values! {
    @U8
    pub enum PseudoId {
        Method => 0x0001,
        Scheme => 0x0002,
        Authority => 0x0003,
        Path => 0x0004,
        Protocol => 0x0005,
        Status => 0x0006,
    }
}
96
/// The order in which pseudo-headers are encoded.
///
/// Built either from `PseudoId::DEFAULT_IDS` (via `Default`) or with
/// [`PseudoOrder::builder`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PseudoOrder {
    ids: SmallVec<[PseudoId; PseudoId::DEFAULT_STACK_SIZE]>,
}

/// Builder for a custom [`PseudoOrder`].
///
/// `mask` records which ids have already been pushed (one bit per id, from
/// `PseudoId::mask_id`) so duplicates can be skipped.
#[derive(Debug)]
pub struct PseudoOrderBuilder {
    ids: SmallVec<[PseudoId; PseudoId::DEFAULT_STACK_SIZE]>,
    mask: u8,
}
122
123impl PseudoOrder {
126 pub fn builder() -> PseudoOrderBuilder {
127 PseudoOrderBuilder {
128 ids: SmallVec::new(),
129 mask: 0,
130 }
131 }
132}
133
134impl Default for PseudoOrder {
135 fn default() -> Self {
136 PseudoOrder {
137 ids: SmallVec::from(PseudoId::DEFAULT_IDS),
138 }
139 }
140}
141
142impl<'a> IntoIterator for &'a PseudoOrder {
143 type Item = &'a PseudoId;
144 type IntoIter = std::slice::Iter<'a, PseudoId>;
145
146 fn into_iter(self) -> Self::IntoIter {
147 self.ids.iter()
148 }
149}
150
151impl PseudoOrderBuilder {
154 pub fn push(mut self, id: PseudoId) -> Self {
155 let mask_id = id.mask_id();
156 if mask_id != 0 {
157 if self.mask & mask_id == 0 {
158 self.mask |= mask_id;
159 self.ids.push(id);
160 } else {
161 tracing::trace!("duplicate pseudo header: {:?}", id);
162 }
163 }
164 self
165 }
166
167 pub fn extend(mut self, iter: impl IntoIterator<Item = PseudoId>) -> Self {
168 for id in iter {
169 self = self.push(id);
170 }
171 self
172 }
173
174 pub fn build(mut self) -> PseudoOrder {
175 if self.ids.len() != PseudoId::DEFAULT_IDS.len() {
176 self = self.extend(PseudoId::DEFAULT_IDS);
177 }
178 PseudoOrder { ids: self.ids }
179 }
180}
181
/// Iterator over the pseudo-headers and regular fields of a header block, yielding
/// `hpack::Header` values for the HPACK encoder.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo-headers not yet yielded; set to `None` once all have been emitted.
    pseudo: Option<Pseudo>,

    /// Regular header fields. A `None` name means the value belongs to the same
    /// header name as the previously yielded field.
    fields: std::vec::IntoIter<(Option<HeaderName>, HeaderValue)>,

    /// Order in which pseudo-headers are yielded.
    pseudo_order: PseudoOrder,
}

/// A decoded (or to-be-encoded) header block.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded regular header fields.
    fields: HeaderMap,

    /// Accumulated size of `fields`, each counted as name + value + 32 octets
    /// (see `decoded_header_size`).
    field_size: usize,

    /// Set when decoding hit `max_header_list_size`; headers past the limit were
    /// dropped rather than stored.
    is_over_size: bool,

    /// The pseudo-header values.
    pseudo: Pseudo,

    /// Optional explicit ordering of regular fields used when encoding. Names
    /// listed here are emitted first, in this order; all other fields follow.
    header_order: Option<Vec<HeaderName>>,
}

/// A header block that has already been HPACK-encoded.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes not yet written; `encode` consumes them as frames fill up.
    hpack: Bytes,
}

// Frame flag bits (RFC 7540 §6.2 / §6.6).
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
const PRIORITY: u8 = 0x20;
/// Every flag defined for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
224
225impl Headers {
228 pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
230 Headers {
231 stream_id,
232 stream_dep: None,
233 header_block: HeaderBlock {
234 field_size: calculate_headermap_size(&fields),
235 fields,
236 is_over_size: false,
237 pseudo,
238 header_order: None,
239 },
240 flags: HeadersFlag::default(),
241 }
242 }
243
244 pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
245 let mut flags = HeadersFlag::default();
246 flags.set_end_stream();
247
248 Headers {
249 stream_id,
250 stream_dep: None,
251 header_block: HeaderBlock {
252 field_size: calculate_headermap_size(&fields),
253 fields,
254 is_over_size: false,
255 pseudo: Pseudo::default(),
256 header_order: None,
257 },
258 flags,
259 }
260 }
261
262 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
266 let flags = HeadersFlag(head.flag());
267 let mut pad = 0;
268
269 tracing::trace!("loading headers; flags={:?}", flags);
270
271 if head.stream_id().is_zero() {
272 return Err(Error::InvalidStreamId);
273 }
274
275 if flags.is_padded() {
277 if src.is_empty() {
278 return Err(Error::MalformedMessage);
279 }
280 pad = src[0] as usize;
281
282 src.advance(1);
284 }
285
286 let stream_dep = if flags.is_priority() {
288 if src.len() < 5 {
289 return Err(Error::MalformedMessage);
290 }
291 let stream_dep = StreamDependency::load(&src[..5])?;
292
293 if stream_dep.dependency_id() == head.stream_id() {
294 return Err(Error::InvalidDependencyId);
295 }
296
297 src.advance(5);
299
300 Some(stream_dep)
301 } else {
302 None
303 };
304
305 if pad > 0 {
306 if pad > src.len() {
307 return Err(Error::TooMuchPadding);
308 }
309
310 let len = src.len() - pad;
311 src.truncate(len);
312 }
313
314 let headers = Headers {
315 stream_id: head.stream_id(),
316 stream_dep,
317 header_block: HeaderBlock {
318 fields: HeaderMap::new(),
319 field_size: 0,
320 is_over_size: false,
321 pseudo: Pseudo::default(),
322 header_order: None,
323 },
324 flags,
325 };
326
327 Ok((headers, src))
328 }
329
330 pub fn load_hpack(
331 &mut self,
332 src: &mut BytesMut,
333 max_header_list_size: usize,
334 decoder: &mut hpack::Decoder,
335 ) -> Result<(), Error> {
336 self.header_block.load(src, max_header_list_size, decoder)
337 }
338
339 pub fn stream_id(&self) -> StreamId {
340 self.stream_id
341 }
342
343 pub fn is_end_headers(&self) -> bool {
344 self.flags.is_end_headers()
345 }
346
347 pub fn set_end_headers(&mut self) {
348 self.flags.set_end_headers();
349 }
350
351 pub fn is_end_stream(&self) -> bool {
352 self.flags.is_end_stream()
353 }
354
355 pub fn set_end_stream(&mut self) {
356 self.flags.set_end_stream()
357 }
358
359 pub fn set_priority(&mut self, dependency: StreamDependency) {
360 self.flags.set_priority();
361 self.stream_dep = Some(dependency);
362 }
363
364 pub fn set_header_order(&mut self, order: Vec<HeaderName>) {
365 self.header_block.header_order = Some(order);
366 }
367
368 pub fn is_over_size(&self) -> bool {
369 self.header_block.is_over_size
370 }
371
372 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
373 (self.header_block.pseudo, self.header_block.fields)
374 }
375
376 #[cfg(feature = "unstable")]
377 pub fn pseudo_mut(&mut self) -> &mut Pseudo {
378 &mut self.header_block.pseudo
379 }
380
381 pub(crate) fn pseudo(&self) -> &Pseudo {
382 &self.header_block.pseudo
383 }
384
385 pub(crate) fn is_informational(&self) -> bool {
387 self.header_block.pseudo.is_informational()
388 }
389
390 pub fn fields(&self) -> &HeaderMap {
391 &self.header_block.fields
392 }
393
394 pub fn into_fields(self) -> HeaderMap {
395 self.header_block.fields
396 }
397
398 pub fn encode(
399 self,
400 encoder: &mut hpack::Encoder,
401 dst: &mut EncodeBuf<'_>,
402 ) -> Option<Continuation> {
403 debug_assert!(self.flags.is_end_headers());
405
406 let head = self.head();
408
409 let stream_dep = self.stream_dep;
410 self.header_block
411 .into_encoding(encoder)
412 .encode(&head, dst, |dst| {
413 if let Some(ref dep) = stream_dep {
415 dep.encode(dst);
416 }
417 })
418 }
419
420 fn head(&self) -> Head {
421 Head::new(Kind::Headers, self.flags.into(), self.stream_id)
422 }
423}
424
425impl<T> From<Headers> for Frame<T> {
426 fn from(src: Headers) -> Self {
427 Frame::Headers(src)
428 }
429}
430
431impl fmt::Debug for Headers {
432 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
433 let mut builder = f.debug_struct("Headers");
434 builder
435 .field("stream_id", &self.stream_id)
436 .field("flags", &self.flags);
437
438 if let Some(ref protocol) = self.header_block.pseudo.protocol {
439 builder.field("protocol", protocol);
440 }
441
442 if let Some(ref dep) = self.stream_dep {
443 builder.field("stream_dep", dep);
444 }
445
446 builder.finish()
448 }
449}
450
/// Error returned by [`parse_u64`] for input that is not a decimal number of at
/// most 19 digits.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses a non-negative decimal integer from ASCII digits.
///
/// Only bytes `b'0'..=b'9'` are accepted: no sign, whitespace or separators.
/// Inputs longer than 19 bytes are rejected, which guarantees the value fits in a
/// `u64` without overflow. An empty slice parses as `0`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    if src.len() > 19 {
        return Err(ParseU64Error);
    }

    src.iter().try_fold(0u64, |acc, &byte| {
        if byte.is_ascii_digit() {
            Ok(acc * 10 + u64::from(byte - b'0'))
        } else {
            Err(ParseU64Error)
        }
    })
}
475
/// Reasons a request is rejected by `PushPromise::validate_request`.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header was present but did not parse as exactly `0`;
    /// carries the result of `parse_u64` on its value.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is neither GET nor HEAD.
    NotSafeAndCacheable,
}
483
484impl PushPromise {
485 pub fn new(
486 stream_id: StreamId,
487 promised_id: StreamId,
488 pseudo: Pseudo,
489 fields: HeaderMap,
490 ) -> Self {
491 PushPromise {
492 flags: PushPromiseFlag::default(),
493 header_block: HeaderBlock {
494 field_size: calculate_headermap_size(&fields),
495 fields,
496 is_over_size: false,
497 pseudo,
498 header_order: None,
499 },
500 promised_id,
501 stream_id,
502 }
503 }
504
505 pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
506 use PushPromiseHeaderError::*;
507 if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
513 let parsed_length = parse_u64(content_length.as_bytes());
514 if parsed_length != Ok(0) {
515 return Err(InvalidContentLength(parsed_length));
516 }
517 }
518 if !Self::safe_and_cacheable(req.method()) {
521 return Err(NotSafeAndCacheable);
522 }
523
524 Ok(())
525 }
526
527 fn safe_and_cacheable(method: &Method) -> bool {
528 method == Method::GET || method == Method::HEAD
531 }
532
533 pub fn fields(&self) -> &HeaderMap {
534 &self.header_block.fields
535 }
536
537 #[cfg(feature = "unstable")]
538 pub fn into_fields(self) -> HeaderMap {
539 self.header_block.fields
540 }
541
542 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
546 let flags = PushPromiseFlag(head.flag());
547 let mut pad = 0;
548
549 if head.stream_id().is_zero() {
550 return Err(Error::InvalidStreamId);
551 }
552
553 if flags.is_padded() {
555 if src.is_empty() {
556 return Err(Error::MalformedMessage);
557 }
558
559 pad = src[0] as usize;
561
562 src.advance(1);
564 }
565
566 if src.len() < 5 {
567 return Err(Error::MalformedMessage);
568 }
569
570 let (promised_id, _) = StreamId::parse(&src[..4]);
571 src.advance(4);
573
574 if pad > 0 {
575 if pad > src.len() {
576 return Err(Error::TooMuchPadding);
577 }
578
579 let len = src.len() - pad;
580 src.truncate(len);
581 }
582
583 let frame = PushPromise {
584 flags,
585 header_block: HeaderBlock {
586 fields: HeaderMap::new(),
587 field_size: 0,
588 is_over_size: false,
589 pseudo: Pseudo::default(),
590 header_order: None,
591 },
592 promised_id,
593 stream_id: head.stream_id(),
594 };
595 Ok((frame, src))
596 }
597
598 pub fn load_hpack(
599 &mut self,
600 src: &mut BytesMut,
601 max_header_list_size: usize,
602 decoder: &mut hpack::Decoder,
603 ) -> Result<(), Error> {
604 self.header_block.load(src, max_header_list_size, decoder)
605 }
606
607 pub fn stream_id(&self) -> StreamId {
608 self.stream_id
609 }
610
611 pub fn promised_id(&self) -> StreamId {
612 self.promised_id
613 }
614
615 pub fn is_end_headers(&self) -> bool {
616 self.flags.is_end_headers()
617 }
618
619 pub fn set_end_headers(&mut self) {
620 self.flags.set_end_headers();
621 }
622
623 pub fn is_over_size(&self) -> bool {
624 self.header_block.is_over_size
625 }
626
627 pub fn encode(
628 self,
629 encoder: &mut hpack::Encoder,
630 dst: &mut EncodeBuf<'_>,
631 ) -> Option<Continuation> {
632 debug_assert!(self.flags.is_end_headers());
634
635 let head = self.head();
636 let promised_id = self.promised_id;
637
638 self.header_block
639 .into_encoding(encoder)
640 .encode(&head, dst, |dst| {
641 dst.put_u32(promised_id.into());
642 })
643 }
644
645 fn head(&self) -> Head {
646 Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
647 }
648
649 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
651 (self.header_block.pseudo, self.header_block.fields)
652 }
653}
654
655impl<T> From<PushPromise> for Frame<T> {
656 fn from(src: PushPromise) -> Self {
657 Frame::PushPromise(src)
658 }
659}
660
661impl fmt::Debug for PushPromise {
662 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
663 f.debug_struct("PushPromise")
664 .field("stream_id", &self.stream_id)
665 .field("promised_id", &self.promised_id)
666 .field("flags", &self.flags)
667 .finish()
669 }
670}
671
672impl Continuation {
675 fn head(&self) -> Head {
676 Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
677 }
678
679 pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
680 let head = self.head();
682
683 self.header_block.encode(&head, dst, |_| {})
684 }
685}
686
687impl Pseudo {
690 pub fn request(
691 method: Method,
692 uri: Uri,
693 protocol: Option<Protocol>,
694 order: Option<PseudoOrder>,
695 ) -> Self {
696 let parts = uri::Parts::from(uri);
697
698 let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
699 (None, None)
700 } else {
701 let path = parts
702 .path_and_query
703 .map(|v| BytesStr::from(v.as_str()))
704 .unwrap_or(BytesStr::from_static(""));
705
706 let path = if !path.is_empty() {
707 path
708 } else if method == Method::OPTIONS {
709 BytesStr::from_static("*")
710 } else {
711 BytesStr::from_static("/")
712 };
713
714 (parts.scheme, Some(path))
715 };
716
717 let mut pseudo = Pseudo {
718 method: Some(method),
719 scheme: None,
720 authority: None,
721 path,
722 protocol,
723 status: None,
724 order: order.unwrap_or_default(),
725 };
726
727 if let Some(scheme) = scheme {
729 pseudo.set_scheme(scheme);
730 }
731
732 if let Some(authority) = parts.authority {
735 pseudo.set_authority(BytesStr::from(authority.as_str()));
736 }
737
738 pseudo
739 }
740
741 pub fn response(status: StatusCode) -> Self {
742 Pseudo {
743 method: None,
744 scheme: None,
745 authority: None,
746 path: None,
747 protocol: None,
748 status: Some(status),
749 order: PseudoOrder::default(),
750 }
751 }
752
753 #[cfg(feature = "unstable")]
754 pub fn set_status(&mut self, value: StatusCode) {
755 self.status = Some(value);
756 }
757
758 pub fn set_scheme(&mut self, scheme: uri::Scheme) {
759 let bytes_str = match scheme.as_str() {
760 "http" => BytesStr::from_static("http"),
761 "https" => BytesStr::from_static("https"),
762 s => BytesStr::from(s),
763 };
764 self.scheme = Some(bytes_str);
765 }
766
767 #[cfg(feature = "unstable")]
768 pub fn set_protocol(&mut self, protocol: Protocol) {
769 self.protocol = Some(protocol);
770 }
771
772 pub fn set_authority(&mut self, authority: BytesStr) {
773 self.authority = Some(authority);
774 }
775
776 pub(crate) fn is_informational(&self) -> bool {
778 self.status.is_some_and(|status| status.is_informational())
779 }
780}
781
impl EncodingHeaderBlock {
    /// Writes one frame (HEADERS, PUSH_PROMISE or CONTINUATION, as given by `head`)
    /// carrying as much of the encoded header block as fits in `dst`.
    ///
    /// The frame head is written with a zero length placeholder. Then `f` writes any
    /// frame-specific prefix (e.g. stream dependency or promised stream ID), followed
    /// by as many HPACK bytes as `dst` has room for. The 3-byte length field is
    /// patched afterwards to cover everything written after the head.
    ///
    /// Returns `None` when the whole block was written. Otherwise it clears
    /// END_HEADERS in the frame just written and returns a `Continuation` holding the
    /// remaining bytes.
    ///
    /// # Panics
    ///
    /// If the payload length does not fit in the 24-bit frame length field.
    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
    where
        F: FnOnce(&mut EncodeBuf<'_>),
    {
        let head_pos = dst.get_ref().len();

        // At this point, we don't know how big the h2 frame will be.
        // So, we write the head with length 0, then write the body, and
        // finally write the length once we know the size.
        head.encode(0, dst);

        let payload_pos = dst.get_ref().len();

        f(dst);

        // Now, encode the header payload
        let continuation = if self.hpack.len() > dst.remaining_mut() {
            // Only part of the block fits. Writing through `&mut self.hpack` advances
            // it, so `self` keeps just the unwritten remainder.
            dst.put((&mut self.hpack).take(dst.remaining_mut()));

            Some(Continuation {
                stream_id: head.stream_id(),
                header_block: self,
            })
        } else {
            dst.put_slice(&self.hpack);

            None
        };

        // Compute the header block length
        let payload_len = (dst.get_ref().len() - payload_pos) as u64;

        // Write the frame length: the low 3 bytes of the big-endian u64 go into the
        // first 3 bytes of the head; the upper 5 bytes must be zero.
        let payload_len_be = payload_len.to_be_bytes();
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);

        if continuation.is_some() {
            // There will be continuation frames, so the `is_end_headers` flag
            // must be unset (the flags byte sits at offset 4 of the frame head)
            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);

            dst.get_mut()[head_pos + 4] -= END_HEADERS;
        }

        continuation
    }
}
833
834impl Iterator for Iter {
837 type Item = hpack::Header<Option<HeaderName>>;
838
839 fn next(&mut self) -> Option<Self::Item> {
840 use crate::hpack::Header::*;
841
842 if let Some(ref mut pseudo) = self.pseudo {
843 for id in &self.pseudo_order.ids {
845 match id {
846 PseudoId::Method => {
847 if let Some(method) = pseudo.method.take() {
848 return Some(Method(method));
849 }
850 }
851 PseudoId::Scheme => {
852 if let Some(scheme) = pseudo.scheme.take() {
853 return Some(Scheme(scheme));
854 }
855 }
856 PseudoId::Authority => {
857 if let Some(authority) = pseudo.authority.take() {
858 return Some(Authority(authority));
859 }
860 }
861 PseudoId::Path => {
862 if let Some(path) = pseudo.path.take() {
863 return Some(Path(path));
864 }
865 }
866 PseudoId::Protocol => {
867 if let Some(protocol) = pseudo.protocol.take() {
868 return Some(Protocol(protocol));
869 }
870 }
871 PseudoId::Status => {
872 if let Some(status) = pseudo.status.take() {
873 return Some(Status(status));
874 }
875 }
876 }
877 }
878
879 self.pseudo = None;
881 }
882
883 self.fields
884 .next()
885 .map(|(name, value)| Field { name, value })
886 }
887}
888
889impl HeadersFlag {
892 pub fn empty() -> HeadersFlag {
893 HeadersFlag(0)
894 }
895
896 pub fn load(bits: u8) -> HeadersFlag {
897 HeadersFlag(bits & ALL)
898 }
899
900 pub fn is_end_stream(&self) -> bool {
901 self.0 & END_STREAM == END_STREAM
902 }
903
904 pub fn set_end_stream(&mut self) {
905 self.0 |= END_STREAM;
906 }
907
908 pub fn is_end_headers(&self) -> bool {
909 self.0 & END_HEADERS == END_HEADERS
910 }
911
912 pub fn set_end_headers(&mut self) {
913 self.0 |= END_HEADERS;
914 }
915
916 pub fn is_padded(&self) -> bool {
917 self.0 & PADDED == PADDED
918 }
919
920 pub fn is_priority(&self) -> bool {
921 self.0 & PRIORITY == PRIORITY
922 }
923
924 pub fn set_priority(&mut self) {
925 self.0 |= PRIORITY;
926 }
927}
928
929impl Default for HeadersFlag {
930 fn default() -> Self {
932 HeadersFlag(END_HEADERS)
933 }
934}
935
936impl From<HeadersFlag> for u8 {
937 fn from(src: HeadersFlag) -> u8 {
938 src.0
939 }
940}
941
942impl fmt::Debug for HeadersFlag {
943 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
944 util::debug_flags(fmt, self.0)
945 .flag_if(self.is_end_headers(), "END_HEADERS")
946 .flag_if(self.is_end_stream(), "END_STREAM")
947 .flag_if(self.is_padded(), "PADDED")
948 .flag_if(self.is_priority(), "PRIORITY")
949 .finish()
950 }
951}
952
953impl PushPromiseFlag {
956 pub fn empty() -> PushPromiseFlag {
957 PushPromiseFlag(0)
958 }
959
960 pub fn load(bits: u8) -> PushPromiseFlag {
961 PushPromiseFlag(bits & ALL)
962 }
963
964 pub fn is_end_headers(&self) -> bool {
965 self.0 & END_HEADERS == END_HEADERS
966 }
967
968 pub fn set_end_headers(&mut self) {
969 self.0 |= END_HEADERS;
970 }
971
972 pub fn is_padded(&self) -> bool {
973 self.0 & PADDED == PADDED
974 }
975}
976
977impl Default for PushPromiseFlag {
978 fn default() -> Self {
980 PushPromiseFlag(END_HEADERS)
981 }
982}
983
984impl From<PushPromiseFlag> for u8 {
985 fn from(src: PushPromiseFlag) -> u8 {
986 src.0
987 }
988}
989
990impl fmt::Debug for PushPromiseFlag {
991 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
992 util::debug_flags(fmt, self.0)
993 .flag_if(self.is_end_headers(), "END_HEADERS")
994 .flag_if(self.is_padded(), "PADDED")
995 .finish()
996 }
997}
998
999impl HeaderBlock {
1002 fn load(
1003 &mut self,
1004 src: &mut BytesMut,
1005 max_header_list_size: usize,
1006 decoder: &mut hpack::Decoder,
1007 ) -> Result<(), Error> {
1008 let mut reg = !self.fields.is_empty();
1009 let mut malformed = false;
1010 let mut headers_size = self.calculate_header_list_size();
1011
1012 macro_rules! set_pseudo {
1013 ($field:ident, $val:expr) => {{
1014 if reg {
1015 tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
1016 malformed = true;
1017 } else if self.pseudo.$field.is_some() {
1018 tracing::trace!("load_hpack; header malformed -- repeated pseudo");
1019 malformed = true;
1020 } else {
1021 let __val = $val;
1022 headers_size +=
1023 decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
1024 if headers_size < max_header_list_size {
1025 self.pseudo.$field = Some(__val);
1026 } else if !self.is_over_size {
1027 tracing::trace!("load_hpack; header list size over max");
1028 self.is_over_size = true;
1029 }
1030 }
1031 }};
1032 }
1033
1034 let mut cursor = Cursor::new(src);
1035
1036 let res = decoder.decode(&mut cursor, |header| {
1041 use crate::hpack::Header::*;
1042
1043 match header {
1044 Field { name, value } => {
1045 if name == header::CONNECTION
1049 || name == header::TRANSFER_ENCODING
1050 || name == header::UPGRADE
1051 || name == "keep-alive"
1052 || name == "proxy-connection"
1053 {
1054 tracing::trace!("load_hpack; connection level header");
1055 malformed = true;
1056 } else if name == header::TE && value != "trailers" {
1057 tracing::trace!(
1058 "load_hpack; TE header not set to trailers; val={:?}",
1059 value
1060 );
1061 malformed = true;
1062 } else {
1063 reg = true;
1064
1065 headers_size += decoded_header_size(name.as_str().len(), value.len());
1066 if headers_size < max_header_list_size {
1067 self.field_size +=
1068 decoded_header_size(name.as_str().len(), value.len());
1069 self.fields.append(name, value);
1070 } else if !self.is_over_size {
1071 tracing::trace!("load_hpack; header list size over max");
1072 self.is_over_size = true;
1073 }
1074 }
1075 }
1076 Authority(v) => set_pseudo!(authority, v),
1077 Method(v) => set_pseudo!(method, v),
1078 Scheme(v) => set_pseudo!(scheme, v),
1079 Path(v) => set_pseudo!(path, v),
1080 Protocol(v) => set_pseudo!(protocol, v),
1081 Status(v) => set_pseudo!(status, v),
1082 }
1083 });
1084
1085 if let Err(e) = res {
1086 tracing::trace!("hpack decoding error; err={:?}", e);
1087 return Err(e.into());
1088 }
1089
1090 if malformed {
1091 tracing::trace!("malformed message");
1092 return Err(Error::MalformedMessage);
1093 }
1094
1095 Ok(())
1096 }
1097
1098 fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
1099 let mut hpack = BytesMut::new();
1100 let pseudo_order = self.pseudo.order.clone();
1101
1102 let mut fields_vec: Vec<(Option<HeaderName>, HeaderValue)> =
1104 self.fields.into_iter().collect();
1105
1106 if let Some(ref order) = self.header_order {
1108 let order_map: std::collections::HashMap<&HeaderName, usize> = order
1109 .iter()
1110 .enumerate()
1111 .map(|(i, name)| (name, i))
1112 .collect();
1113 let default_pos = order.len();
1114 fields_vec.sort_by_key(|(name, _)| {
1115 name.as_ref()
1116 .and_then(|n| order_map.get(n).copied())
1117 .unwrap_or(default_pos)
1118 });
1119 }
1120
1121 let headers = Iter {
1122 pseudo: Some(self.pseudo),
1123 fields: fields_vec.into_iter(),
1124 pseudo_order,
1125 };
1126
1127 encoder.encode(headers, &mut hpack);
1128
1129 EncodingHeaderBlock {
1130 hpack: hpack.freeze(),
1131 }
1132 }
1133
1134 fn calculate_header_list_size(&self) -> usize {
1142 macro_rules! pseudo_size {
1143 ($name:ident) => {{
1144 self.pseudo
1145 .$name
1146 .as_ref()
1147 .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
1148 .unwrap_or(0)
1149 }};
1150 }
1151
1152 pseudo_size!(method)
1153 + pseudo_size!(scheme)
1154 + pseudo_size!(status)
1155 + pseudo_size!(authority)
1156 + pseudo_size!(path)
1157 + self.field_size
1158 }
1159}
1160
1161fn calculate_headermap_size(map: &HeaderMap) -> usize {
1162 map.iter()
1163 .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
1164 .sum::<usize>()
1165}
1166
/// Size of one header field for header-list accounting: name length plus value
/// length plus a fixed 32-octet per-entry overhead (RFC 7541 §4.1).
fn decoded_header_size(name: usize, value: usize) -> usize {
    const ENTRY_OVERHEAD: usize = 32;
    ENTRY_OVERHEAD + name + value
}
1170
#[cfg(test)]
mod test {
    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three values of one header into a frame too small to hold them all.
    /// It then checks that the CONTINUATION resumes mid-value and still encodes the
    /// later, nameless values against the right header name.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Allow only 8 payload bytes in the HEADERS frame.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // HEADERS frame: length 8, type 1, flags 0 (END_HEADERS cleared).
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // 0x40: literal with incremental indexing, new name; 0x80 | 4: Huffman, 4 bytes.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the Huffman-encoded "world" fit.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The CONTINUATION starts with the remaining 3 bytes of "world".
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // CONTINUATION frame: length 15, type 9, flags 4 (END_HEADERS).
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // Next values are literals without indexing that reference the name by
        // index: 4-bit prefix 15 plus 47, i.e. 62, the dynamic-table entry for "hello".
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src` into a fresh buffer, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }

    #[test]
    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
        // CONNECT requests MUST NOT include :scheme & :path pseudo-header fields
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                None,
                None,
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com/test"),
                None,
                None,
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("example.com:8443"),
                None,
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
        // On requests that contain the :protocol pseudo-header field, the
        // :scheme and :path pseudo-header fields of the target URI MUST
        // also be included.
        // See: https://datatracker.ietf.org/doc/html/rfc8441#section-4

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                Protocol::from_static("the-bread-protocol").into(),
                None,
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443/test"),
                Protocol::from_static("the-bread-protocol").into(),
                None,
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/test").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("http://example.com/a/b/c"),
                Protocol::from_static("the-bread-protocol").into(),
                None,
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                scheme: BytesStr::from_static("http").into(),
                path: BytesStr::from_static("/a/b/c").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
        // an OPTIONS request for an "http" or "https" URI that does not include a path component;
        // these MUST include a ":path" pseudo-header field with a value of '*'.
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1

        assert_eq!(
            Pseudo::request(
                Method::OPTIONS,
                Uri::from_static("example.com:8080"),
                None,
                None
            ),
            Pseudo {
                method: Method::OPTIONS.into(),
                authority: BytesStr::from_static("example.com:8080").into(),
                path: BytesStr::from_static("*").into(),
                ..Default::default()
            }
        );
    }

    /// For ordinary methods, both an empty path (mapped to "/") and an explicit path
    /// produce :scheme and :path alongside :method and :authority.
    #[test]
    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
        let methods = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::PATCH,
            Method::TRACE,
        ];

        for method in methods {
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("http://example.com:8080"),
                    None,
                    None,
                ),
                Pseudo {
                    method: method.clone().into(),
                    authority: BytesStr::from_static("example.com:8080").into(),
                    scheme: BytesStr::from_static("http").into(),
                    path: BytesStr::from_static("/").into(),
                    ..Default::default()
                }
            );
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("https://example.com/a/b/c"),
                    None,
                    None,
                ),
                Pseudo {
                    method: method.into(),
                    authority: BytesStr::from_static("example.com").into(),
                    scheme: BytesStr::from_static("https").into(),
                    path: BytesStr::from_static("/a/b/c").into(),
                    ..Default::default()
                }
            );
        }
    }
}