Skip to main content

h2/frame/
headers.rs

1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{Buf, BufMut, Bytes, BytesMut};
10use smallvec::SmallVec;
11
12use std::fmt;
13use std::io::Cursor;
14
/// Destination buffer used while encoding frames.
///
/// A `BytesMut` wrapped in `bytes::buf::Limit`; `remaining_mut()` reports how much
/// more may be written, which is how a header block that does not fit is split
/// into CONTINUATION frames.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
/// Header frame
///
/// This could be a request, a response, or trailers.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}

/// Raw flag byte of a HEADERS frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);

/// A PUSH_PROMISE frame: reserves `promised_id` for a server-initiated stream
/// and carries the header block of the promised request.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}

/// Raw flag byte of a PUSH_PROMISE frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54
/// A CONTINUATION frame carrying the part of an already HPACK-encoded header
/// block that did not fit into the preceding frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// The not-yet-written remainder of the encoded header block.
    header_block: EncodingHeaderBlock,
}

/// HTTP/2 pseudo-header fields of a header block.
///
/// Requests use `method`, `scheme`, `authority`, `path` and, for extended
/// CONNECT, `protocol`; responses use `status`.
// TODO: These fields shouldn't be `pub`
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,
    pub protocol: Option<Protocol>,

    // Response
    pub status: Option<StatusCode>,

    // Pseudo order: the sequence in which the fields above are emitted when
    // encoded (presumably handed to `Iter::pseudo_order` — confirm in
    // `HeaderBlock::into_encoding`).
    pub order: PseudoOrder,
}
79
define_enum_with_values! {
    /// Identifies one HTTP/2 pseudo-header field (`:method`, `:scheme`, ...).
    ///
    /// HTTP/2 pseudo-header fields are a set of predefined header fields that start with ':'.
    /// RFC 7540 §8.1.2.1 requires every pseudo-header field to appear before any regular
    /// header field, but it does not fix an order among the pseudo-headers themselves;
    /// that order is chosen through [`PseudoOrder`].
    // NOTE(review): `DEFAULT_IDS`, `DEFAULT_STACK_SIZE` and `mask_id()` used in this file are
    // presumably generated by `define_enum_with_values!` — confirm the default order there.
    @U8
    pub enum PseudoId {
        Method => 0x0001,
        Scheme => 0x0002,
        Authority => 0x0003,
        Path => 0x0004,
        Protocol => 0x0005,
        Status => 0x0006,
    }
}
96
/// Represents the order of HTTP/2 pseudo-header fields in a header block.
///
/// This structure maintains an ordered list of pseudo-header fields (such as `:method`, `:scheme`, etc.)
/// that controls the order in which they are emitted when a header block is encoded. RFC 7540 only
/// requires pseudo-headers to precede regular fields, so the order among them is a free choice.
///
/// Build one with [`PseudoOrder::builder`]: the builder drops duplicate ids and appends any id from
/// `PseudoId::DEFAULT_IDS` that was not pushed explicitly. `PseudoOrder::default()` is exactly
/// `PseudoId::DEFAULT_IDS`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PseudoOrder {
    ids: SmallVec<[PseudoId; PseudoId::DEFAULT_STACK_SIZE]>,
}

/// A builder for constructing a `PseudoOrder`.
///
/// Ids are kept in the order they are pushed. Pushing an id that is already present is a no-op
/// (it is only traced). `.build()` appends, in default order, every id from `PseudoId::DEFAULT_IDS`
/// that was not pushed, then returns the `PseudoOrder`.
#[derive(Debug)]
pub struct PseudoOrderBuilder {
    /// Ids in push order.
    ids: SmallVec<[PseudoId; PseudoId::DEFAULT_STACK_SIZE]>,
    /// Bit set of the `mask_id()` values already pushed; used to reject duplicates.
    mask: u8,
}
122
123// ===== impl PseudoOrder =====
124
125impl PseudoOrder {
126    pub fn builder() -> PseudoOrderBuilder {
127        PseudoOrderBuilder {
128            ids: SmallVec::new(),
129            mask: 0,
130        }
131    }
132}
133
134impl Default for PseudoOrder {
135    fn default() -> Self {
136        PseudoOrder {
137            ids: SmallVec::from(PseudoId::DEFAULT_IDS),
138        }
139    }
140}
141
142impl<'a> IntoIterator for &'a PseudoOrder {
143    type Item = &'a PseudoId;
144    type IntoIter = std::slice::Iter<'a, PseudoId>;
145
146    fn into_iter(self) -> Self::IntoIter {
147        self.ids.iter()
148    }
149}
150
151// ===== impl PseudoOrderBuilder =====
152
153impl PseudoOrderBuilder {
154    pub fn push(mut self, id: PseudoId) -> Self {
155        let mask_id = id.mask_id();
156        if mask_id != 0 {
157            if self.mask & mask_id == 0 {
158                self.mask |= mask_id;
159                self.ids.push(id);
160            } else {
161                tracing::trace!("duplicate pseudo header: {:?}", id);
162            }
163        }
164        self
165    }
166
167    pub fn extend(mut self, iter: impl IntoIterator<Item = PseudoId>) -> Self {
168        for id in iter {
169            self = self.push(id);
170        }
171        self
172    }
173
174    pub fn build(mut self) -> PseudoOrder {
175        if self.ids.len() != PseudoId::DEFAULT_IDS.len() {
176            self = self.extend(PseudoId::DEFAULT_IDS);
177        }
178        PseudoOrder { ids: self.ids }
179    }
180}
181
/// Iterator over the fields of a header block, in encoding order.
///
/// Pseudo-header fields are yielded first, in the sequence given by
/// `pseudo_order`, followed by the regular header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers still to be yielded; set to `None` once all are consumed.
    pseudo: Option<Pseudo>,

    /// Header fields (sorted by header_order if set)
    fields: std::vec::IntoIter<(Option<HeaderName>, HeaderValue)>,

    /// Pseudo header order
    pseudo_order: PseudoOrder,
}

/// A header block: the pseudo-headers plus the regular header fields, used
/// both when decoding (`load`) and before HPACK encoding.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of all of our header fields, for perf reasons
    field_size: usize,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,

    /// Optional ordering for regular headers (for browser fingerprinting).
    /// When set, headers are encoded in this order instead of hash-based order.
    header_order: Option<Vec<HeaderName>>,
}

/// A header block that has already been HPACK-encoded and is waiting to be
/// written into one or more frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes not yet written; `encode` consumes them from the front.
    hpack: Bytes,
}

// Flag bits defined for HEADERS frames (RFC 7540 §6.2). PUSH_PROMISE and
// CONTINUATION frames reuse the END_HEADERS (0x4) value, and PUSH_PROMISE
// also reuses PADDED (0x8).
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
const PRIORITY: u8 = 0x20;
// Every flag bit defined for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
224
225// ===== impl Headers =====
226
impl Headers {
    /// Create a new HEADERS frame
    ///
    /// Only `END_HEADERS` is set (`HeadersFlag::default()`); there is no stream
    /// dependency and no custom header order. The size of `fields` is computed
    /// once here (via `calculate_headermap_size`) and cached in the header block.
    pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
        Headers {
            stream_id,
            stream_dep: None,
            header_block: HeaderBlock {
                field_size: calculate_headermap_size(&fields),
                fields,
                is_over_size: false,
                pseudo,
                header_order: None,
            },
            flags: HeadersFlag::default(),
        }
    }

    /// Create a trailers HEADERS frame.
    ///
    /// Trailers carry no pseudo-headers and close the stream: `END_STREAM` is
    /// set in addition to the default `END_HEADERS`.
    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
        let mut flags = HeadersFlag::default();
        flags.set_end_stream();

        Headers {
            stream_id,
            stream_dep: None,
            header_block: HeaderBlock {
                field_size: calculate_headermap_size(&fields),
                fields,
                is_over_size: false,
                pseudo: Pseudo::default(),
                header_order: None,
            },
            flags,
        }
    }

    /// Loads the header frame but doesn't actually do HPACK decoding.
    ///
    /// HPACK decoding is done in the `load_hpack` step.
    ///
    /// `src` is the frame payload. The optional pad-length byte, the optional
    /// 5-byte stream dependency (PRIORITY flag) and the trailing padding are
    /// stripped. The returned `BytesMut` is the remaining header block fragment,
    /// and the returned frame has an empty header block.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidStreamId` if the frame is on stream 0.
    /// * `Error::MalformedMessage` if the payload is too short for the
    ///   pad-length byte or the stream dependency.
    /// * `Error::InvalidDependencyId` if the stream depends on itself.
    /// * `Error::TooMuchPadding` if the padding is longer than what remains.
    /// * Any error returned by `StreamDependency::load`.
    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
        // NOTE(review): the raw flag bits are kept unmasked here
        // (`HeadersFlag::load` would mask them).
        let flags = HeadersFlag(head.flag());
        let mut pad = 0;

        tracing::trace!("loading headers; flags={:?}", flags);

        // HEADERS frames are never valid on the connection-level stream 0.
        if head.stream_id().is_zero() {
            return Err(Error::InvalidStreamId);
        }

        // Read the padding length
        if flags.is_padded() {
            if src.is_empty() {
                return Err(Error::MalformedMessage);
            }
            pad = src[0] as usize;

            // Drop the pad-length byte; the padding itself is truncated below.
            src.advance(1);
        }

        // Read the stream dependency
        let stream_dep = if flags.is_priority() {
            if src.len() < 5 {
                return Err(Error::MalformedMessage);
            }
            let stream_dep = StreamDependency::load(&src[..5])?;

            // A stream cannot depend on itself.
            if stream_dep.dependency_id() == head.stream_id() {
                return Err(Error::InvalidDependencyId);
            }

            // Drop the next 5 bytes
            src.advance(5);

            Some(stream_dep)
        } else {
            None
        };

        // Strip the trailing padding. Padding equal to the remaining length
        // leaves an empty header block fragment, which is accepted.
        if pad > 0 {
            if pad > src.len() {
                return Err(Error::TooMuchPadding);
            }

            let len = src.len() - pad;
            src.truncate(len);
        }

        let headers = Headers {
            stream_id: head.stream_id(),
            stream_dep,
            header_block: HeaderBlock {
                fields: HeaderMap::new(),
                field_size: 0,
                is_over_size: false,
                pseudo: Pseudo::default(),
                header_order: None,
            },
            flags,
        };

        Ok((headers, src))
    }

    /// HPACK-decodes the header block fragment `src` into this frame's header
    /// block (delegates to `HeaderBlock::load`).
    pub fn load_hpack(
        &mut self,
        src: &mut BytesMut,
        max_header_list_size: usize,
        decoder: &mut hpack::Decoder,
    ) -> Result<(), Error> {
        self.header_block.load(src, max_header_list_size, decoder)
    }

    /// The stream this frame belongs to.
    pub fn stream_id(&self) -> StreamId {
        self.stream_id
    }

    /// Whether `END_HEADERS` is set.
    pub fn is_end_headers(&self) -> bool {
        self.flags.is_end_headers()
    }

    /// Sets `END_HEADERS`.
    pub fn set_end_headers(&mut self) {
        self.flags.set_end_headers();
    }

    /// Whether `END_STREAM` is set.
    pub fn is_end_stream(&self) -> bool {
        self.flags.is_end_stream()
    }

    /// Sets `END_STREAM`.
    pub fn set_end_stream(&mut self) {
        self.flags.set_end_stream()
    }

    /// Sets the PRIORITY flag and the stream dependency written by `encode`.
    pub fn set_priority(&mut self, dependency: StreamDependency) {
        self.flags.set_priority();
        self.stream_dep = Some(dependency);
    }

    /// Sets a custom order for the regular header fields; when present,
    /// headers are encoded in this order instead of hash-based order.
    pub fn set_header_order(&mut self, order: Vec<HeaderName>) {
        self.header_block.header_order = Some(order);
    }

    /// Whether decoding exceeded the max header list size (fields past the
    /// limit were not stored).
    pub fn is_over_size(&self) -> bool {
        self.header_block.is_over_size
    }

    /// Consumes the frame, returning its pseudo-headers and regular fields.
    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
        (self.header_block.pseudo, self.header_block.fields)
    }

    /// Mutable access to the pseudo-header fields.
    #[cfg(feature = "unstable")]
    pub fn pseudo_mut(&mut self) -> &mut Pseudo {
        &mut self.header_block.pseudo
    }

    /// The pseudo-header fields.
    pub(crate) fn pseudo(&self) -> &Pseudo {
        &self.header_block.pseudo
    }

    /// Whether it has status 1xx
    pub(crate) fn is_informational(&self) -> bool {
        self.header_block.pseudo.is_informational()
    }

    /// The regular header fields.
    pub fn fields(&self) -> &HeaderMap {
        &self.header_block.fields
    }

    /// Consumes the frame, returning only its regular header fields.
    pub fn into_fields(self) -> HeaderMap {
        self.header_block.fields
    }

    /// HPACK-encodes this frame into `dst`.
    ///
    /// Returns a `Continuation` holding the rest of the header block when it
    /// does not fit in `dst`; that remainder is written as CONTINUATION frames.
    pub fn encode(
        self,
        encoder: &mut hpack::Encoder,
        dst: &mut EncodeBuf<'_>,
    ) -> Option<Continuation> {
        // At this point, the `is_end_headers` flag should always be set
        debug_assert!(self.flags.is_end_headers());

        // Get the HEADERS frame head
        let head = self.head();

        let stream_dep = self.stream_dep;
        self.header_block
            .into_encoding(encoder)
            .encode(&head, dst, |dst| {
                // Write stream dependency if PRIORITY flag is set
                if let Some(ref dep) = stream_dep {
                    dep.encode(dst);
                }
            })
    }

    /// Frame header for this HEADERS frame; the payload length is patched in
    /// during encoding.
    fn head(&self) -> Head {
        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
    }
}
424
425impl<T> From<Headers> for Frame<T> {
426    fn from(src: Headers) -> Self {
427        Frame::Headers(src)
428    }
429}
430
431impl fmt::Debug for Headers {
432    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
433        let mut builder = f.debug_struct("Headers");
434        builder
435            .field("stream_id", &self.stream_id)
436            .field("flags", &self.flags);
437
438        if let Some(ref protocol) = self.header_block.pseudo.protocol {
439            builder.field("protocol", protocol);
440        }
441
442        if let Some(ref dep) = self.stream_dep {
443            builder.field("stream_dep", dep);
444        }
445
446        // `fields` and `pseudo` purposefully not included
447        builder.finish()
448    }
449}
450
451// ===== util =====
452
/// Error returned by [`parse_u64`] when the input is not a valid decimal `u64`.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses an unsigned decimal integer made only of ASCII digits (no sign, no
/// whitespace), as required for the `content-length` header value.
///
/// Leading zeros are accepted. Every value up to and including `u64::MAX` is
/// accepted.
///
/// # Errors
///
/// Returns [`ParseU64Error`] if `src` is empty, contains a non-digit byte, or
/// encodes a value larger than `u64::MAX`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // `Content-Length = 1*DIGIT`: an empty value is not a number, so it must
    // not silently parse as 0.
    if src.is_empty() {
        return Err(ParseU64Error);
    }

    // Checked arithmetic detects overflow exactly, instead of guessing from
    // the input length.
    src.iter().try_fold(0u64, |acc, &digit| {
        if !digit.is_ascii_digit() {
            return Err(ParseU64Error);
        }
        acc.checked_mul(10)
            .and_then(|acc| acc.checked_add(u64::from(digit - b'0')))
            .ok_or(ParseU64Error)
    })
}
475
476// ===== impl PushPromise =====
477
/// Reasons a request cannot be used as a promised (server-pushed) request, as
/// reported by `PushPromise::validate_request`.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present and is not exactly `0` (or failed
    /// to parse); it carries the parse result.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is not safe and cacheable (only GET and HEAD pass).
    NotSafeAndCacheable,
}
483
impl PushPromise {
    /// Create a new PUSH_PROMISE frame on `stream_id` reserving `promised_id`.
    ///
    /// Only `END_HEADERS` is set (`PushPromiseFlag::default()`). The size of
    /// `fields` is computed once here and cached in the header block.
    pub fn new(
        stream_id: StreamId,
        promised_id: StreamId,
        pseudo: Pseudo,
        fields: HeaderMap,
    ) -> Self {
        PushPromise {
            flags: PushPromiseFlag::default(),
            header_block: HeaderBlock {
                field_size: calculate_headermap_size(&fields),
                fields,
                is_over_size: false,
                pseudo,
                header_order: None,
            },
            promised_id,
            stream_id,
        }
    }

    /// Checks that `req` is acceptable as a promised (pushed) request.
    ///
    /// # Errors
    ///
    /// * `InvalidContentLength` if a `content-length` header is present and is
    ///   not exactly `0`, including when it fails to parse.
    /// * `NotSafeAndCacheable` if the method is not GET or HEAD.
    pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
        use PushPromiseHeaderError::*;
        // The spec has some requirements for promised request headers
        // [https://httpwg.org/specs/rfc7540.html#PushRequests]

        // A promised request "that indicates the presence of a request body
        // MUST reset the promised stream with a stream error"
        if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
            let parsed_length = parse_u64(content_length.as_bytes());
            if parsed_length != Ok(0) {
                return Err(InvalidContentLength(parsed_length));
            }
        }
        // "The server MUST include a method in the :method pseudo-header field
        // that is safe and cacheable"
        if !Self::safe_and_cacheable(req.method()) {
            return Err(NotSafeAndCacheable);
        }

        Ok(())
    }

    /// Only GET and HEAD are treated as both safe and cacheable.
    fn safe_and_cacheable(method: &Method) -> bool {
        // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
        // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
        method == Method::GET || method == Method::HEAD
    }

    /// The regular header fields of the promised request.
    pub fn fields(&self) -> &HeaderMap {
        &self.header_block.fields
    }

    /// Consumes the frame, returning only its regular header fields.
    #[cfg(feature = "unstable")]
    pub fn into_fields(self) -> HeaderMap {
        self.header_block.fields
    }

    /// Loads the push promise frame but doesn't actually do HPACK decoding.
    ///
    /// HPACK decoding is done in the `load_hpack` step.
    ///
    /// `src` is the frame payload. The optional pad-length byte, the 4-byte
    /// promised stream id and the trailing padding are stripped. The returned
    /// `BytesMut` is the remaining header block fragment, and the returned
    /// frame has an empty header block.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidStreamId` if the frame is on stream 0.
    /// * `Error::MalformedMessage` if the payload is too short.
    /// * `Error::TooMuchPadding` if the padding is longer than what remains.
    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
        // NOTE(review): the raw flag bits are kept unmasked here.
        let flags = PushPromiseFlag(head.flag());
        let mut pad = 0;

        if head.stream_id().is_zero() {
            return Err(Error::InvalidStreamId);
        }

        // Read the padding length
        if flags.is_padded() {
            if src.is_empty() {
                return Err(Error::MalformedMessage);
            }

            // TODO: Ensure payload is sized correctly
            pad = src[0] as usize;

            // Drop the pad-length byte; the padding itself is truncated below.
            src.advance(1);
        }

        // NOTE(review): this demands at least one byte after the 4-byte
        // promised stream id, so an empty header block fragment is rejected
        // as malformed — confirm that is intended.
        if src.len() < 5 {
            return Err(Error::MalformedMessage);
        }

        // The second value from `StreamId::parse` is discarded; presumably it
        // is the reserved bit — confirm in `StreamId::parse`.
        let (promised_id, _) = StreamId::parse(&src[..4]);
        // Drop promised_id bytes
        src.advance(4);

        // Strip the trailing padding.
        if pad > 0 {
            if pad > src.len() {
                return Err(Error::TooMuchPadding);
            }

            let len = src.len() - pad;
            src.truncate(len);
        }

        let frame = PushPromise {
            flags,
            header_block: HeaderBlock {
                fields: HeaderMap::new(),
                field_size: 0,
                is_over_size: false,
                pseudo: Pseudo::default(),
                header_order: None,
            },
            promised_id,
            stream_id: head.stream_id(),
        };
        Ok((frame, src))
    }

    /// HPACK-decodes the header block fragment `src` into this frame's header
    /// block (delegates to `HeaderBlock::load`).
    pub fn load_hpack(
        &mut self,
        src: &mut BytesMut,
        max_header_list_size: usize,
        decoder: &mut hpack::Decoder,
    ) -> Result<(), Error> {
        self.header_block.load(src, max_header_list_size, decoder)
    }

    /// The stream this frame was sent on.
    pub fn stream_id(&self) -> StreamId {
        self.stream_id
    }

    /// The stream reserved by this frame.
    pub fn promised_id(&self) -> StreamId {
        self.promised_id
    }

    /// Whether `END_HEADERS` is set.
    pub fn is_end_headers(&self) -> bool {
        self.flags.is_end_headers()
    }

    /// Sets `END_HEADERS`.
    pub fn set_end_headers(&mut self) {
        self.flags.set_end_headers();
    }

    /// Whether decoding exceeded the max header list size.
    pub fn is_over_size(&self) -> bool {
        self.header_block.is_over_size
    }

    /// HPACK-encodes this frame into `dst`, writing the promised stream id
    /// before the header block fragment.
    ///
    /// Returns a `Continuation` holding the rest of the header block when it
    /// does not fit in `dst`.
    pub fn encode(
        self,
        encoder: &mut hpack::Encoder,
        dst: &mut EncodeBuf<'_>,
    ) -> Option<Continuation> {
        // At this point, the `is_end_headers` flag should always be set
        debug_assert!(self.flags.is_end_headers());

        let head = self.head();
        let promised_id = self.promised_id;

        self.header_block
            .into_encoding(encoder)
            .encode(&head, dst, |dst| {
                dst.put_u32(promised_id.into());
            })
    }

    /// Frame header for this PUSH_PROMISE frame; the payload length is
    /// patched in during encoding.
    fn head(&self) -> Head {
        Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
    }

    /// Consume `self`, returning the parts of the frame
    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
        (self.header_block.pseudo, self.header_block.fields)
    }
}
654
655impl<T> From<PushPromise> for Frame<T> {
656    fn from(src: PushPromise) -> Self {
657        Frame::PushPromise(src)
658    }
659}
660
661impl fmt::Debug for PushPromise {
662    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
663        f.debug_struct("PushPromise")
664            .field("stream_id", &self.stream_id)
665            .field("promised_id", &self.promised_id)
666            .field("flags", &self.flags)
667            // `fields` and `pseudo` purposefully not included
668            .finish()
669    }
670}
671
672// ===== impl Continuation =====
673
674impl Continuation {
675    fn head(&self) -> Head {
676        Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
677    }
678
679    pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
680        // Get the CONTINUATION frame head
681        let head = self.head();
682
683        self.header_block.encode(&head, dst, |_| {})
684    }
685}
686
687// ===== impl Pseudo =====
688
689impl Pseudo {
690    pub fn request(
691        method: Method,
692        uri: Uri,
693        protocol: Option<Protocol>,
694        order: Option<PseudoOrder>,
695    ) -> Self {
696        let parts = uri::Parts::from(uri);
697
698        let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
699            (None, None)
700        } else {
701            let path = parts
702                .path_and_query
703                .map(|v| BytesStr::from(v.as_str()))
704                .unwrap_or(BytesStr::from_static(""));
705
706            let path = if !path.is_empty() {
707                path
708            } else if method == Method::OPTIONS {
709                BytesStr::from_static("*")
710            } else {
711                BytesStr::from_static("/")
712            };
713
714            (parts.scheme, Some(path))
715        };
716
717        let mut pseudo = Pseudo {
718            method: Some(method),
719            scheme: None,
720            authority: None,
721            path,
722            protocol,
723            status: None,
724            order: order.unwrap_or_default(),
725        };
726
727        // If the URI includes a scheme component, add it to the pseudo headers
728        if let Some(scheme) = scheme {
729            pseudo.set_scheme(scheme);
730        }
731
732        // If the URI includes an authority component, add it to the pseudo
733        // headers
734        if let Some(authority) = parts.authority {
735            pseudo.set_authority(BytesStr::from(authority.as_str()));
736        }
737
738        pseudo
739    }
740
741    pub fn response(status: StatusCode) -> Self {
742        Pseudo {
743            method: None,
744            scheme: None,
745            authority: None,
746            path: None,
747            protocol: None,
748            status: Some(status),
749            order: PseudoOrder::default(),
750        }
751    }
752
753    #[cfg(feature = "unstable")]
754    pub fn set_status(&mut self, value: StatusCode) {
755        self.status = Some(value);
756    }
757
758    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
759        let bytes_str = match scheme.as_str() {
760            "http" => BytesStr::from_static("http"),
761            "https" => BytesStr::from_static("https"),
762            s => BytesStr::from(s),
763        };
764        self.scheme = Some(bytes_str);
765    }
766
767    #[cfg(feature = "unstable")]
768    pub fn set_protocol(&mut self, protocol: Protocol) {
769        self.protocol = Some(protocol);
770    }
771
772    pub fn set_authority(&mut self, authority: BytesStr) {
773        self.authority = Some(authority);
774    }
775
776    /// Whether it has status 1xx
777    pub(crate) fn is_informational(&self) -> bool {
778        self.status.is_some_and(|status| status.is_informational())
779    }
780}
781
782// ===== impl EncodingHeaderBlock =====
783
784impl EncodingHeaderBlock {
785    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
786    where
787        F: FnOnce(&mut EncodeBuf<'_>),
788    {
789        let head_pos = dst.get_ref().len();
790
791        // At this point, we don't know how big the h2 frame will be.
792        // So, we write the head with length 0, then write the body, and
793        // finally write the length once we know the size.
794        head.encode(0, dst);
795
796        let payload_pos = dst.get_ref().len();
797
798        f(dst);
799
800        // Now, encode the header payload
801        let continuation = if self.hpack.len() > dst.remaining_mut() {
802            dst.put((&mut self.hpack).take(dst.remaining_mut()));
803
804            Some(Continuation {
805                stream_id: head.stream_id(),
806                header_block: self,
807            })
808        } else {
809            dst.put_slice(&self.hpack);
810
811            None
812        };
813
814        // Compute the header block length
815        let payload_len = (dst.get_ref().len() - payload_pos) as u64;
816
817        // Write the frame length
818        let payload_len_be = payload_len.to_be_bytes();
819        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
820        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
821
822        if continuation.is_some() {
823            // There will be continuation frames, so the `is_end_headers` flag
824            // must be unset
825            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
826
827            dst.get_mut()[head_pos + 4] -= END_HEADERS;
828        }
829
830        continuation
831    }
832}
833
834// ===== impl Iter =====
835
836impl Iterator for Iter {
837    type Item = hpack::Header<Option<HeaderName>>;
838
839    fn next(&mut self) -> Option<Self::Item> {
840        use crate::hpack::Header::*;
841
842        if let Some(ref mut pseudo) = self.pseudo {
843            // Iterate through the configured order
844            for id in &self.pseudo_order.ids {
845                match id {
846                    PseudoId::Method => {
847                        if let Some(method) = pseudo.method.take() {
848                            return Some(Method(method));
849                        }
850                    }
851                    PseudoId::Scheme => {
852                        if let Some(scheme) = pseudo.scheme.take() {
853                            return Some(Scheme(scheme));
854                        }
855                    }
856                    PseudoId::Authority => {
857                        if let Some(authority) = pseudo.authority.take() {
858                            return Some(Authority(authority));
859                        }
860                    }
861                    PseudoId::Path => {
862                        if let Some(path) = pseudo.path.take() {
863                            return Some(Path(path));
864                        }
865                    }
866                    PseudoId::Protocol => {
867                        if let Some(protocol) = pseudo.protocol.take() {
868                            return Some(Protocol(protocol));
869                        }
870                    }
871                    PseudoId::Status => {
872                        if let Some(status) = pseudo.status.take() {
873                            return Some(Status(status));
874                        }
875                    }
876                }
877            }
878
879            // All pseudo headers consumed
880            self.pseudo = None;
881        }
882
883        self.fields
884            .next()
885            .map(|(name, value)| Field { name, value })
886    }
887}
888
889// ===== impl HeadersFlag =====
890
891impl HeadersFlag {
892    pub fn empty() -> HeadersFlag {
893        HeadersFlag(0)
894    }
895
896    pub fn load(bits: u8) -> HeadersFlag {
897        HeadersFlag(bits & ALL)
898    }
899
900    pub fn is_end_stream(&self) -> bool {
901        self.0 & END_STREAM == END_STREAM
902    }
903
904    pub fn set_end_stream(&mut self) {
905        self.0 |= END_STREAM;
906    }
907
908    pub fn is_end_headers(&self) -> bool {
909        self.0 & END_HEADERS == END_HEADERS
910    }
911
912    pub fn set_end_headers(&mut self) {
913        self.0 |= END_HEADERS;
914    }
915
916    pub fn is_padded(&self) -> bool {
917        self.0 & PADDED == PADDED
918    }
919
920    pub fn is_priority(&self) -> bool {
921        self.0 & PRIORITY == PRIORITY
922    }
923
924    pub fn set_priority(&mut self) {
925        self.0 |= PRIORITY;
926    }
927}
928
929impl Default for HeadersFlag {
930    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
931    fn default() -> Self {
932        HeadersFlag(END_HEADERS)
933    }
934}
935
936impl From<HeadersFlag> for u8 {
937    fn from(src: HeadersFlag) -> u8 {
938        src.0
939    }
940}
941
942impl fmt::Debug for HeadersFlag {
943    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
944        util::debug_flags(fmt, self.0)
945            .flag_if(self.is_end_headers(), "END_HEADERS")
946            .flag_if(self.is_end_stream(), "END_STREAM")
947            .flag_if(self.is_padded(), "PADDED")
948            .flag_if(self.is_priority(), "PRIORITY")
949            .finish()
950    }
951}
952
953// ===== impl PushPromiseFlag =====
954
955impl PushPromiseFlag {
956    pub fn empty() -> PushPromiseFlag {
957        PushPromiseFlag(0)
958    }
959
960    pub fn load(bits: u8) -> PushPromiseFlag {
961        PushPromiseFlag(bits & ALL)
962    }
963
964    pub fn is_end_headers(&self) -> bool {
965        self.0 & END_HEADERS == END_HEADERS
966    }
967
968    pub fn set_end_headers(&mut self) {
969        self.0 |= END_HEADERS;
970    }
971
972    pub fn is_padded(&self) -> bool {
973        self.0 & PADDED == PADDED
974    }
975}
976
977impl Default for PushPromiseFlag {
978    /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
979    fn default() -> Self {
980        PushPromiseFlag(END_HEADERS)
981    }
982}
983
984impl From<PushPromiseFlag> for u8 {
985    fn from(src: PushPromiseFlag) -> u8 {
986        src.0
987    }
988}
989
990impl fmt::Debug for PushPromiseFlag {
991    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
992        util::debug_flags(fmt, self.0)
993            .flag_if(self.is_end_headers(), "END_HEADERS")
994            .flag_if(self.is_padded(), "PADDED")
995            .finish()
996    }
997}
998
999// ===== HeaderBlock =====
1000
1001impl HeaderBlock {
1002    fn load(
1003        &mut self,
1004        src: &mut BytesMut,
1005        max_header_list_size: usize,
1006        decoder: &mut hpack::Decoder,
1007    ) -> Result<(), Error> {
1008        let mut reg = !self.fields.is_empty();
1009        let mut malformed = false;
1010        let mut headers_size = self.calculate_header_list_size();
1011
1012        macro_rules! set_pseudo {
1013            ($field:ident, $val:expr) => {{
1014                if reg {
1015                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
1016                    malformed = true;
1017                } else if self.pseudo.$field.is_some() {
1018                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
1019                    malformed = true;
1020                } else {
1021                    let __val = $val;
1022                    headers_size +=
1023                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
1024                    if headers_size < max_header_list_size {
1025                        self.pseudo.$field = Some(__val);
1026                    } else if !self.is_over_size {
1027                        tracing::trace!("load_hpack; header list size over max");
1028                        self.is_over_size = true;
1029                    }
1030                }
1031            }};
1032        }
1033
1034        let mut cursor = Cursor::new(src);
1035
1036        // If the header frame is malformed, we still have to continue decoding
1037        // the headers. A malformed header frame is a stream level error, but
1038        // the hpack state is connection level. In order to maintain correct
1039        // state for other streams, the hpack decoding process must complete.
1040        let res = decoder.decode(&mut cursor, |header| {
1041            use crate::hpack::Header::*;
1042
1043            match header {
1044                Field { name, value } => {
1045                    // Connection level header fields are not supported and must
1046                    // result in a protocol error.
1047
1048                    if name == header::CONNECTION
1049                        || name == header::TRANSFER_ENCODING
1050                        || name == header::UPGRADE
1051                        || name == "keep-alive"
1052                        || name == "proxy-connection"
1053                    {
1054                        tracing::trace!("load_hpack; connection level header");
1055                        malformed = true;
1056                    } else if name == header::TE && value != "trailers" {
1057                        tracing::trace!(
1058                            "load_hpack; TE header not set to trailers; val={:?}",
1059                            value
1060                        );
1061                        malformed = true;
1062                    } else {
1063                        reg = true;
1064
1065                        headers_size += decoded_header_size(name.as_str().len(), value.len());
1066                        if headers_size < max_header_list_size {
1067                            self.field_size +=
1068                                decoded_header_size(name.as_str().len(), value.len());
1069                            self.fields.append(name, value);
1070                        } else if !self.is_over_size {
1071                            tracing::trace!("load_hpack; header list size over max");
1072                            self.is_over_size = true;
1073                        }
1074                    }
1075                }
1076                Authority(v) => set_pseudo!(authority, v),
1077                Method(v) => set_pseudo!(method, v),
1078                Scheme(v) => set_pseudo!(scheme, v),
1079                Path(v) => set_pseudo!(path, v),
1080                Protocol(v) => set_pseudo!(protocol, v),
1081                Status(v) => set_pseudo!(status, v),
1082            }
1083        });
1084
1085        if let Err(e) = res {
1086            tracing::trace!("hpack decoding error; err={:?}", e);
1087            return Err(e.into());
1088        }
1089
1090        if malformed {
1091            tracing::trace!("malformed message");
1092            return Err(Error::MalformedMessage);
1093        }
1094
1095        Ok(())
1096    }
1097
1098    fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
1099        let mut hpack = BytesMut::new();
1100        let pseudo_order = self.pseudo.order.clone();
1101
1102        // Collect headers from HeaderMap into a Vec for ordered iteration
1103        let mut fields_vec: Vec<(Option<HeaderName>, HeaderValue)> =
1104            self.fields.into_iter().collect();
1105
1106        // Sort by header_order if specified (for browser fingerprinting)
1107        if let Some(ref order) = self.header_order {
1108            let order_map: std::collections::HashMap<&HeaderName, usize> = order
1109                .iter()
1110                .enumerate()
1111                .map(|(i, name)| (name, i))
1112                .collect();
1113            let default_pos = order.len();
1114            fields_vec.sort_by_key(|(name, _)| {
1115                name.as_ref()
1116                    .and_then(|n| order_map.get(n).copied())
1117                    .unwrap_or(default_pos)
1118            });
1119        }
1120
1121        let headers = Iter {
1122            pseudo: Some(self.pseudo),
1123            fields: fields_vec.into_iter(),
1124            pseudo_order,
1125        };
1126
1127        encoder.encode(headers, &mut hpack);
1128
1129        EncodingHeaderBlock {
1130            hpack: hpack.freeze(),
1131        }
1132    }
1133
1134    /// Calculates the size of the currently decoded header list.
1135    ///
1136    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
1137    ///
1138    /// > The value is based on the uncompressed size of header fields,
1139    /// > including the length of the name and value in octets plus an
1140    /// > overhead of 32 octets for each header field.
1141    fn calculate_header_list_size(&self) -> usize {
1142        macro_rules! pseudo_size {
1143            ($name:ident) => {{
1144                self.pseudo
1145                    .$name
1146                    .as_ref()
1147                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
1148                    .unwrap_or(0)
1149            }};
1150        }
1151
1152        pseudo_size!(method)
1153            + pseudo_size!(scheme)
1154            + pseudo_size!(status)
1155            + pseudo_size!(authority)
1156            + pseudo_size!(path)
1157            + self.field_size
1158    }
1159}
1160
1161fn calculate_headermap_size(map: &HeaderMap) -> usize {
1162    map.iter()
1163        .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
1164        .sum::<usize>()
1165}
1166
/// Size of one header field for header-list size accounting: the name and
/// value lengths in octets plus a fixed 32-octet per-field overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    PER_FIELD_OVERHEAD + name + value
}
1170
#[cfg(test)]
mod test {
    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Huffman-decodes `src` into a fresh buffer, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut out = BytesMut::new();
        huffman::decode(src, &mut out).unwrap()
    }

    /// A header block split across HEADERS + CONTINUATION must resume
    /// correctly when the split lands inside the value of a header name that
    /// is repeated (the repeats are yielded without a name).
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let fields: HeaderMap = ["world", "zomg", "sup"]
            .iter()
            .map(|&v| (HeaderName::from_static("hello"), HeaderValue::from_static(v)))
            .collect();
        let headers = Headers::new(StreamId::ZERO, Default::default(), fields);

        // Room for a 9-byte frame header plus only 8 payload bytes.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // HEADERS frame: length 8, type 1, no flags (END_HEADERS not set).
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // Literal with incremental indexing, new name, Huffman name of 4 bytes.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Huffman value of 4 bytes, of which only the first fits in this frame.
        assert_eq!(0x80 | 4, dst[15]);

        let mut world_bytes = dst[16..17].to_owned();

        dst.clear();

        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The remaining 3 bytes of "world" open the CONTINUATION payload.
        world_bytes.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world_bytes));

        // CONTINUATION frame: length 15, type 9, END_HEADERS flag (0x4).
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // The following values are literals without indexing that reuse the
        // name by table index 15 + 47 = 62 (the newest dynamic entry).
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// CONNECT requests MUST NOT include :scheme & :path pseudo-header fields.
    /// See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5
    #[test]
    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
        // (request URI, expected :authority)
        let cases = [
            ("https://example.com:8443", "example.com:8443"),
            ("https://example.com/test", "example.com"),
            ("example.com:8443", "example.com:8443"),
        ];

        for &(uri, authority) in &cases {
            let actual = Pseudo::request(Method::CONNECT, Uri::from_static(uri), None, None);
            let expected = Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static(authority).into(),
                ..Default::default()
            };
            assert_eq!(actual, expected);
        }
    }

    /// On requests that contain the :protocol pseudo-header field, the
    /// :scheme and :path pseudo-header fields of the target URI MUST also be
    /// included.
    /// See: https://datatracker.ietf.org/doc/html/rfc8441#section-4
    #[test]
    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
        const PROTOCOL: &str = "the-bread-protocol";

        // (request URI, expected :authority, :scheme, :path)
        let cases = [
            ("https://example.com:8443", "example.com:8443", "https", "/"),
            ("https://example.com:8443/test", "example.com:8443", "https", "/test"),
            ("http://example.com/a/b/c", "example.com", "http", "/a/b/c"),
        ];

        for &(uri, authority, scheme, path) in &cases {
            let actual = Pseudo::request(
                Method::CONNECT,
                Uri::from_static(uri),
                Protocol::from_static(PROTOCOL).into(),
                None,
            );
            let expected = Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static(authority).into(),
                scheme: BytesStr::from_static(scheme).into(),
                path: BytesStr::from_static(path).into(),
                protocol: Protocol::from_static(PROTOCOL).into(),
                ..Default::default()
            };
            assert_eq!(actual, expected);
        }
    }

    /// An OPTIONS request for an "http" or "https" URI without a path
    /// component MUST carry a ":path" pseudo-header field of '*'.
    /// See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1
    #[test]
    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
        let actual = Pseudo::request(
            Method::OPTIONS,
            Uri::from_static("example.com:8080"),
            None,
            None,
        );
        let expected = Pseudo {
            method: Method::OPTIONS.into(),
            authority: BytesStr::from_static("example.com:8080").into(),
            path: BytesStr::from_static("*").into(),
            ..Default::default()
        };
        assert_eq!(actual, expected);
    }

    /// Ordinary methods always carry :scheme and :path (defaulting to "/").
    #[test]
    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
        let methods = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::PATCH,
            Method::TRACE,
        ];

        // (request URI, expected :authority, :scheme, :path)
        let cases = [
            ("http://example.com:8080", "example.com:8080", "http", "/"),
            ("https://example.com/a/b/c", "example.com", "https", "/a/b/c"),
        ];

        for method in methods.iter() {
            for &(uri, authority, scheme, path) in &cases {
                let actual = Pseudo::request(method.clone(), Uri::from_static(uri), None, None);
                let expected = Pseudo {
                    method: method.clone().into(),
                    authority: BytesStr::from_static(authority).into(),
                    scheme: BytesStr::from_static(scheme).into(),
                    path: BytesStr::from_static(path).into(),
                    ..Default::default()
                };
                assert_eq!(actual, expected);
            }
        }
    }
}