Skip to main content

ntex_h2/frame/
headers.rs

1use std::{cell::RefCell, cmp, fmt, io::Cursor};
2
3use ntex_bytes::{ByteString, Bytes, BytesMut};
4use ntex_http::{HeaderMap, HeaderName, Method, StatusCode, Uri, header, uri};
5
6use crate::hpack;
7
8use super::priority::StreamDependency;
9use super::{Frame, FrameError, Head, Kind, Protocol, StreamId, util};
10
11/// Header frame
12///
13/// This could be either a request or a response.
14#[derive(Clone)]
15pub struct Headers {
16    /// The ID of the stream with which this frame is associated.
17    stream_id: StreamId,
18
19    /// The header block fragment
20    header_block: HeaderBlock,
21
22    /// The associated flags
23    flags: HeadersFlag,
24}
25
/// Flag byte of a `HEADERS` frame, kept as the raw bit set.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
28
29#[derive(Clone, Debug, Default)]
30pub struct PseudoHeaders {
31    // Request
32    pub method: Option<Method>,
33    pub scheme: Option<ByteString>,
34    pub authority: Option<ByteString>,
35    pub path: Option<ByteString>,
36    pub protocol: Option<Protocol>,
37
38    // Response
39    pub status: Option<StatusCode>,
40}
41
42pub(super) struct Iter<'a> {
43    /// Pseudo headers
44    pseudo: Option<PseudoHeaders>,
45
46    /// Header fields
47    fields: header::Iter<'a>,
48}
49
50#[derive(Debug, Clone)]
51struct HeaderBlock {
52    /// The decoded header fields
53    fields: HeaderMap,
54
55    /// Pseudo headers, these are broken out as they must be sent as part of the
56    /// headers frame.
57    pseudo: PseudoHeaders,
58}
59
/// Flag bit: this frame ends the stream.
const END_STREAM: u8 = 0b0000_0001;
/// Flag bit: this frame ends the header block.
const END_HEADERS: u8 = 0b0000_0100;
/// Flag bit: the payload carries padding.
const PADDED: u8 = 0b0000_1000;
/// Flag bit: the payload carries stream dependency information.
const PRIORITY: u8 = 0b0010_0000;
/// Union of every flag bit recognised on a `HEADERS` frame.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
65
66// ===== impl Headers =====
67
68impl Headers {
69    /// Create a new HEADERS frame
70    pub fn new(stream_id: StreamId, pseudo: PseudoHeaders, fields: HeaderMap, eof: bool) -> Self {
71        let mut flags = HeadersFlag::default();
72        if eof {
73            flags.set_end_stream();
74        }
75        Headers {
76            flags,
77            stream_id,
78            header_block: HeaderBlock { fields, pseudo },
79        }
80    }
81
82    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
83        let mut flags = HeadersFlag::default();
84        flags.set_end_stream();
85
86        Headers {
87            stream_id,
88            flags,
89            header_block: HeaderBlock {
90                fields,
91                pseudo: PseudoHeaders::default(),
92            },
93        }
94    }
95
96    /// Loads the header frame but doesn't actually do HPACK decoding.
97    ///
98    /// HPACK decoding is done in the `load_hpack` step.
99    pub fn load(head: Head, src: &mut Bytes) -> Result<Self, FrameError> {
100        let flags = HeadersFlag(head.flag());
101
102        if head.stream_id().is_zero() {
103            return Err(FrameError::InvalidStreamId);
104        }
105
106        // Read the padding length
107        let pad = if flags.is_padded() {
108            if src.is_empty() {
109                return Err(FrameError::MalformedMessage);
110            }
111            let pad = src[0] as usize;
112
113            // Drop the padding
114            src.advance_to(1);
115            pad
116        } else {
117            0
118        };
119
120        // Read the stream dependency
121        if flags.is_priority() {
122            if src.len() < 5 {
123                return Err(FrameError::MalformedMessage);
124            }
125            let stream_dep = StreamDependency::load(&src[..5])?;
126
127            if stream_dep.dependency_id() == head.stream_id() {
128                return Err(FrameError::InvalidDependencyId);
129            }
130
131            // Drop the next 5 bytes
132            src.advance_to(5);
133        }
134
135        if pad > 0 {
136            if pad > src.len() {
137                return Err(FrameError::TooMuchPadding);
138            }
139            src.truncate(src.len() - pad);
140        }
141
142        Ok(Headers {
143            flags,
144            stream_id: head.stream_id(),
145            header_block: HeaderBlock {
146                fields: HeaderMap::new(),
147                pseudo: PseudoHeaders::default(),
148            },
149        })
150    }
151
152    pub fn load_hpack(
153        &mut self,
154        src: &mut Bytes,
155        decoder: &mut hpack::Decoder,
156    ) -> Result<(), FrameError> {
157        self.header_block.load(src, decoder)
158    }
159
160    pub fn stream_id(&self) -> StreamId {
161        self.stream_id
162    }
163
164    pub fn is_end_headers(&self) -> bool {
165        self.flags.is_end_headers()
166    }
167
168    pub fn set_end_headers(&mut self) {
169        self.flags.set_end_headers();
170    }
171
172    pub fn is_end_stream(&self) -> bool {
173        self.flags.is_end_stream()
174    }
175
176    pub fn set_end_stream(&mut self) {
177        self.flags.set_end_stream();
178    }
179
180    pub fn into_parts(self) -> (PseudoHeaders, HeaderMap) {
181        (self.header_block.pseudo, self.header_block.fields)
182    }
183
184    pub fn fields(&self) -> &HeaderMap {
185        &self.header_block.fields
186    }
187
188    pub fn pseudo(&self) -> &PseudoHeaders {
189        &self.header_block.pseudo
190    }
191
192    pub fn into_fields(self) -> HeaderMap {
193        self.header_block.fields
194    }
195
196    pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut, max_size: usize) {
197        // At this point, the `is_end_headers` flag should always be set
198        debug_assert!(self.flags.is_end_headers());
199
200        // Get the HEADERS frame head
201        let head = self.head();
202
203        self.header_block.encode(encoder, head, dst, max_size);
204    }
205
206    fn head(&self) -> Head {
207        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
208    }
209}
210
211impl From<Headers> for Frame {
212    fn from(src: Headers) -> Self {
213        Frame::Headers(src)
214    }
215}
216
217impl fmt::Debug for Headers {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        let mut builder = f.debug_struct("Headers");
220        builder
221            .field("stream_id", &self.stream_id)
222            .field("flags", &self.flags)
223            .field("pseudo", &self.header_block.pseudo);
224
225        if let Some(ref protocol) = self.header_block.pseudo.protocol {
226            builder.field("protocol", protocol);
227        }
228
229        // `fields` and `pseudo` purposefully not included
230        builder.finish()
231    }
232}
233
234// ===== impl Pseudo =====
235
236impl PseudoHeaders {
237    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
238        let parts = uri::Parts::from(uri);
239
240        let mut path = parts
241            .path_and_query
242            .map_or(ByteString::from_static(""), |v| {
243                ByteString::from(v.as_str())
244            });
245
246        match method {
247            Method::OPTIONS | Method::CONNECT => {}
248            _ if path.is_empty() => {
249                path = ByteString::from_static("/");
250            }
251            _ => {}
252        }
253
254        let mut pseudo = PseudoHeaders {
255            method: Some(method),
256            scheme: None,
257            authority: None,
258            path: Some(path).filter(|p| !p.is_empty()),
259            protocol,
260            status: None,
261        };
262
263        // If the URI includes a scheme component, add it to the pseudo headers
264        //
265        // TODO: Scheme must be set...
266        if let Some(ref scheme) = parts.scheme {
267            pseudo.set_scheme(scheme);
268        }
269
270        // If the URI includes an authority component, add it to the pseudo
271        // headers
272        if let Some(authority) = parts.authority {
273            pseudo.set_authority(ByteString::from(authority.as_str()));
274        }
275
276        pseudo
277    }
278
279    pub fn response(status: StatusCode) -> Self {
280        PseudoHeaders {
281            method: None,
282            scheme: None,
283            authority: None,
284            path: None,
285            protocol: None,
286            status: Some(status),
287        }
288    }
289
290    pub fn set_status(&mut self, value: StatusCode) {
291        self.status = Some(value);
292    }
293
294    pub fn set_scheme(&mut self, scheme: &uri::Scheme) {
295        self.scheme = Some(match scheme.as_str() {
296            "http" => ByteString::from_static("http"),
297            "https" => ByteString::from_static("https"),
298            s => ByteString::from(s),
299        });
300    }
301
302    pub fn set_protocol(&mut self, protocol: Protocol) {
303        self.protocol = Some(protocol);
304    }
305
306    pub fn set_authority(&mut self, authority: ByteString) {
307        self.authority = Some(authority);
308    }
309}
310
311// ===== impl Iter =====
312
313impl Iterator for Iter<'_> {
314    type Item = hpack::Header<Option<HeaderName>>;
315
316    fn next(&mut self) -> Option<Self::Item> {
317        use crate::hpack::Header;
318
319        if let Some(ref mut pseudo) = self.pseudo {
320            if let Some(method) = pseudo.method.take() {
321                return Some(Header::Method(method));
322            }
323
324            if let Some(scheme) = pseudo.scheme.take() {
325                return Some(Header::Scheme(scheme));
326            }
327
328            if let Some(authority) = pseudo.authority.take() {
329                return Some(Header::Authority(authority));
330            }
331
332            if let Some(path) = pseudo.path.take() {
333                return Some(Header::Path(path));
334            }
335
336            if let Some(protocol) = pseudo.protocol.take() {
337                return Some(Header::Protocol(protocol.into()));
338            }
339
340            if let Some(status) = pseudo.status.take() {
341                return Some(Header::Status(status));
342            }
343        }
344
345        self.pseudo = None;
346
347        self.fields.next().map(|(name, value)| Header::Field {
348            name: Some(name.clone()),
349            value: value.clone(),
350        })
351    }
352}
353
354// ===== impl HeadersFlag =====
355
356impl HeadersFlag {
357    pub fn empty() -> HeadersFlag {
358        HeadersFlag(0)
359    }
360
361    pub fn load(bits: u8) -> HeadersFlag {
362        HeadersFlag(bits & ALL)
363    }
364
365    pub fn is_end_stream(self) -> bool {
366        self.0 & END_STREAM == END_STREAM
367    }
368
369    pub fn set_end_stream(&mut self) {
370        self.0 |= END_STREAM;
371    }
372
373    pub fn is_end_headers(self) -> bool {
374        self.0 & END_HEADERS == END_HEADERS
375    }
376
377    pub fn set_end_headers(&mut self) {
378        self.0 |= END_HEADERS;
379    }
380
381    pub fn is_padded(self) -> bool {
382        self.0 & PADDED == PADDED
383    }
384
385    pub fn is_priority(self) -> bool {
386        self.0 & PRIORITY == PRIORITY
387    }
388}
389
390impl Default for HeadersFlag {
391    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
392    fn default() -> Self {
393        HeadersFlag(END_HEADERS)
394    }
395}
396
397impl From<HeadersFlag> for u8 {
398    fn from(src: HeadersFlag) -> u8 {
399        src.0
400    }
401}
402
403impl fmt::Debug for HeadersFlag {
404    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
405        util::debug_flags(fmt, self.0)
406            .flag_if(self.is_end_headers(), "END_HEADERS")
407            .flag_if(self.is_end_stream(), "END_STREAM")
408            .flag_if(self.is_padded(), "PADDED")
409            .flag_if(self.is_priority(), "PRIORITY")
410            .finish()
411    }
412}
413
414// ===== HeaderBlock =====
415
416thread_local! {
417    static HDRS_BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(1024));
418}
419
420impl HeaderBlock {
421    fn load(&mut self, src: &mut Bytes, decoder: &mut hpack::Decoder) -> Result<(), FrameError> {
422        let mut reg = !self.fields.is_empty();
423        let mut malformed = false;
424
425        macro_rules! set_pseudo {
426            ($field:ident, $val:expr) => {{
427                if reg {
428                    log::trace!("load_hpack; header malformed -- pseudo not at head of block");
429                    malformed = true;
430                } else if self.pseudo.$field.is_some() {
431                    log::trace!("load_hpack; header malformed -- repeated pseudo");
432                    malformed = true;
433                } else {
434                    self.pseudo.$field = Some($val.into());
435                }
436            }};
437        }
438
439        let mut cursor = Cursor::new(src);
440
441        // If the header frame is malformed, we still have to continue decoding
442        // the headers. A malformed header frame is a stream level error, but
443        // the hpack state is connection level. In order to maintain correct
444        // state for other streams, the hpack decoding process must complete.
445        let res = decoder.decode(&mut cursor, |header| {
446            use crate::hpack::Header;
447
448            match header {
449                Header::Field { name, value } => {
450                    // Connection level header fields are not supported and must
451                    // result in a protocol error.
452
453                    if name == header::CONNECTION
454                        || name == header::TRANSFER_ENCODING
455                        || name == header::UPGRADE
456                        || name == "keep-alive"
457                        || name == "proxy-connection"
458                    {
459                        log::trace!("load_hpack; connection level header");
460                        malformed = true;
461                    } else if name == header::TE && value != "trailers" {
462                        log::trace!("load_hpack; TE header not set to trailers; val={value:?}");
463                        malformed = true;
464                    } else {
465                        reg = true;
466                        self.fields.append(name, value);
467                    }
468                }
469                Header::Authority(v) => {
470                    set_pseudo!(authority, v);
471                }
472                Header::Method(v) => {
473                    set_pseudo!(method, v);
474                }
475                Header::Scheme(v) => {
476                    set_pseudo!(scheme, v);
477                }
478                Header::Path(v) => {
479                    set_pseudo!(path, v);
480                }
481                Header::Protocol(v) => {
482                    set_pseudo!(protocol, v);
483                }
484                Header::Status(v) => {
485                    set_pseudo!(status, v);
486                }
487            }
488        });
489
490        if let Err(e) = res {
491            log::trace!("hpack decoding error; err={e:?}");
492            return Err(e.into());
493        }
494
495        if malformed {
496            log::trace!("malformed message");
497            return Err(FrameError::MalformedMessage);
498        }
499
500        Ok(())
501    }
502
503    fn encode(self, encoder: &mut hpack::Encoder, head: Head, dst: &mut BytesMut, max_size: usize) {
504        HDRS_BUF.with(|buf| {
505            let mut b = buf.borrow_mut();
506            let hpack = &mut b;
507            hpack.clear();
508
509            // encode hpack
510            let headers = Iter {
511                pseudo: Some(self.pseudo),
512                fields: self.fields.into_iter(),
513            };
514            encoder.encode(headers, hpack);
515
516            let mut head = head;
517            let mut start = 0;
518            loop {
519                let end = cmp::min(start + max_size, hpack.len());
520
521                // encode the header payload
522                if hpack.len() > end {
523                    Head::new(head.kind(), head.flag() ^ END_HEADERS, head.stream_id())
524                        .encode(max_size, dst);
525                    dst.extend_from_slice(&hpack[start..end]);
526                    head = Head::new(Kind::Continuation, END_HEADERS, head.stream_id());
527                    start = end;
528                } else {
529                    head.encode(end - start, dst);
530                    dst.extend_from_slice(&hpack[start..end]);
531                    break;
532                }
533            }
534        });
535    }
536}
537
#[cfg(test)]
mod test {
    use ntex_http::HeaderValue;

    use super::*;
    use crate::hpack::{Encoder, huffman};

    /// Encodes three `hello` headers with an 8 byte frame limit and checks
    /// the layout of the first frame and of the continuation that follows.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let mut hdrs = HeaderMap::default();
        for value in ["world", "zomg", "sup"] {
            hdrs.append(
                HeaderName::from_static("hello"),
                HeaderValue::from_static(value),
            );
        }

        let mut headers = Headers::new(StreamId::CON, Default::default(), hdrs, false);
        headers.set_end_headers();
        headers.encode(&mut encoder, &mut dst, 8);

        // Total output size and the head of the first (HEADERS) frame.
        assert_eq!(48, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);

        // Start of the first field: the huffman-coded name "hello".
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // The value bytes straddle the frame boundary; stitch them together.
        let mut world = BytesMut::from(&dst[16..17]);
        world.extend_from_slice(&dst[26..29]);
        // assert_eq!("world", huff_decode(&world));

        // Head of the CONTINUATION frame that follows the first 8 bytes.
        assert_eq!([0, 0, 8, 9, 0, 0, 0, 0, 0], &dst[17..26]);

        // // Next is not indexed
        //assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        //assert_eq!("zomg", huff_decode(&dst[15..18]));
        //assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        //assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src`, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> Bytes {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}