ntex_h2/frame/
headers.rs

1use std::{cell::RefCell, cmp, fmt, io::Cursor};
2
3use ntex_bytes::{ByteString, BytesMut};
4use ntex_http::{header, uri, HeaderMap, HeaderName, Method, StatusCode, Uri};
5
6use crate::hpack;
7
8use super::priority::StreamDependency;
9use super::{util, Frame, FrameError, Head, Kind, Protocol, StreamId};
10
/// Header frame
///
/// A HEADERS frame. This could be either a request or a response; it is also
/// used for trailers (see `Headers::trailers`).
#[derive(Clone)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The header block fragment: pseudo headers plus regular header fields.
    header_block: HeaderBlock,

    /// The associated flags (END_STREAM / END_HEADERS / PADDED / PRIORITY).
    flags: HeadersFlag,
}
25
/// Flag byte of a HEADERS frame.
///
/// Wraps the raw `u8` from the frame header; the bits of interest are
/// `END_STREAM`, `END_HEADERS`, `PADDED` and `PRIORITY`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
28
/// The HTTP/2 pseudo header fields (`:method`, `:scheme`, ...) of a header
/// block. Every field is optional; `Default` leaves all of them unset.
#[derive(Clone, Debug, Default)]
pub struct PseudoHeaders {
    // Request
    /// `:method`
    pub method: Option<Method>,
    /// `:scheme`
    pub scheme: Option<ByteString>,
    /// `:authority`
    pub authority: Option<ByteString>,
    /// `:path`
    pub path: Option<ByteString>,
    /// `:protocol` (extended CONNECT)
    pub protocol: Option<Protocol>,

    // Response
    /// `:status`
    pub status: Option<StatusCode>,
}
41
/// Iterator over a header block in encoding order: the pseudo headers first,
/// then the regular header fields. It is what `HeaderBlock::encode` hands to the
/// HPACK encoder.
pub struct Iter<'a> {
    /// Pseudo headers still to be yielded; set to `None` once exhausted.
    pseudo: Option<PseudoHeaders>,

    /// Header fields
    fields: header::Iter<'a>,
}
49
/// The (decoded) contents of a header block, kept split into regular fields
/// and pseudo headers.
#[derive(Debug, Clone)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: PseudoHeaders,
}
59
/// HEADERS flag: this frame is the last one the sender will send on the stream.
const END_STREAM: u8 = 0b0000_0001;
/// HEADERS flag: this frame ends the header block (no CONTINUATION follows).
const END_HEADERS: u8 = 0b0000_0100;
/// HEADERS flag: the payload starts with a pad-length byte and ends with padding.
const PADDED: u8 = 0b0000_1000;
/// HEADERS flag: the payload carries a 5-byte stream-dependency/weight block.
const PRIORITY: u8 = 0b0010_0000;
/// Every flag bit that is defined for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
65
66// ===== impl Headers =====
67
68impl Headers {
69    /// Create a new HEADERS frame
70    pub fn new(stream_id: StreamId, pseudo: PseudoHeaders, fields: HeaderMap, eof: bool) -> Self {
71        let mut flags = HeadersFlag::default();
72        if eof {
73            flags.set_end_stream();
74        }
75        Headers {
76            flags,
77            stream_id,
78            header_block: HeaderBlock { fields, pseudo },
79        }
80    }
81
82    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
83        let mut flags = HeadersFlag::default();
84        flags.set_end_stream();
85
86        Headers {
87            stream_id,
88            flags,
89            header_block: HeaderBlock {
90                fields,
91                pseudo: PseudoHeaders::default(),
92            },
93        }
94    }
95
96    /// Loads the header frame but doesn't actually do HPACK decoding.
97    ///
98    /// HPACK decoding is done in the `load_hpack` step.
99    pub fn load(head: Head, src: &mut BytesMut) -> Result<Self, FrameError> {
100        let flags = HeadersFlag(head.flag());
101
102        if head.stream_id().is_zero() {
103            return Err(FrameError::InvalidStreamId);
104        }
105
106        // Read the padding length
107        let pad = if flags.is_padded() {
108            if src.is_empty() {
109                return Err(FrameError::MalformedMessage);
110            }
111            let pad = src[0] as usize;
112
113            // Drop the padding
114            let _ = src.split_to(1);
115            pad
116        } else {
117            0
118        };
119
120        // Read the stream dependency
121        if flags.is_priority() {
122            if src.len() < 5 {
123                return Err(FrameError::MalformedMessage);
124            }
125            let stream_dep = StreamDependency::load(&src[..5])?;
126
127            if stream_dep.dependency_id() == head.stream_id() {
128                return Err(FrameError::InvalidDependencyId);
129            }
130
131            // Drop the next 5 bytes
132            let _ = src.split_to(5);
133        }
134
135        if pad > 0 {
136            if pad > src.len() {
137                return Err(FrameError::TooMuchPadding);
138            }
139            src.truncate(src.len() - pad);
140        }
141
142        Ok(Headers {
143            flags,
144            stream_id: head.stream_id(),
145            header_block: HeaderBlock {
146                fields: HeaderMap::new(),
147                pseudo: PseudoHeaders::default(),
148            },
149        })
150    }
151
152    pub fn load_hpack(
153        &mut self,
154        src: &mut BytesMut,
155        decoder: &mut hpack::Decoder,
156    ) -> Result<(), FrameError> {
157        self.header_block.load(src, decoder)
158    }
159
160    pub fn stream_id(&self) -> StreamId {
161        self.stream_id
162    }
163
164    pub fn is_end_headers(&self) -> bool {
165        self.flags.is_end_headers()
166    }
167
168    pub fn set_end_headers(&mut self) {
169        self.flags.set_end_headers();
170    }
171
172    pub fn is_end_stream(&self) -> bool {
173        self.flags.is_end_stream()
174    }
175
176    pub fn set_end_stream(&mut self) {
177        self.flags.set_end_stream()
178    }
179
180    pub fn into_parts(self) -> (PseudoHeaders, HeaderMap) {
181        (self.header_block.pseudo, self.header_block.fields)
182    }
183
184    pub fn fields(&self) -> &HeaderMap {
185        &self.header_block.fields
186    }
187
188    pub fn pseudo(&self) -> &PseudoHeaders {
189        &self.header_block.pseudo
190    }
191
192    pub fn into_fields(self) -> HeaderMap {
193        self.header_block.fields
194    }
195
196    pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut, max_size: usize) {
197        // At this point, the `is_end_headers` flag should always be set
198        debug_assert!(self.flags.is_end_headers());
199
200        // Get the HEADERS frame head
201        let head = self.head();
202
203        self.header_block.encode(encoder, &head, dst, max_size);
204    }
205
206    fn head(&self) -> Head {
207        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
208    }
209}
210
211impl From<Headers> for Frame {
212    fn from(src: Headers) -> Self {
213        Frame::Headers(src)
214    }
215}
216
217impl fmt::Debug for Headers {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        let mut builder = f.debug_struct("Headers");
220        builder
221            .field("stream_id", &self.stream_id)
222            .field("flags", &self.flags)
223            .field("pseudo", &self.header_block.pseudo);
224
225        if let Some(ref protocol) = self.header_block.pseudo.protocol {
226            builder.field("protocol", protocol);
227        }
228
229        // `fields` and `pseudo` purposefully not included
230        builder.finish()
231    }
232}
233
234// ===== impl Pseudo =====
235
236impl PseudoHeaders {
237    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
238        let parts = uri::Parts::from(uri);
239
240        let mut path = parts
241            .path_and_query
242            .map(|v| ByteString::from(v.as_str()))
243            .unwrap_or(ByteString::from_static(""));
244
245        match method {
246            Method::OPTIONS | Method::CONNECT => {}
247            _ if path.is_empty() => {
248                path = ByteString::from_static("/");
249            }
250            _ => {}
251        }
252
253        let mut pseudo = PseudoHeaders {
254            method: Some(method),
255            scheme: None,
256            authority: None,
257            path: Some(path).filter(|p| !p.is_empty()),
258            protocol,
259            status: None,
260        };
261
262        // If the URI includes a scheme component, add it to the pseudo headers
263        //
264        // TODO: Scheme must be set...
265        if let Some(scheme) = parts.scheme {
266            pseudo.set_scheme(scheme);
267        }
268
269        // If the URI includes an authority component, add it to the pseudo
270        // headers
271        if let Some(authority) = parts.authority {
272            pseudo.set_authority(ByteString::from(authority.as_str()));
273        }
274
275        pseudo
276    }
277
278    pub fn response(status: StatusCode) -> Self {
279        PseudoHeaders {
280            method: None,
281            scheme: None,
282            authority: None,
283            path: None,
284            protocol: None,
285            status: Some(status),
286        }
287    }
288
289    pub fn set_status(&mut self, value: StatusCode) {
290        self.status = Some(value);
291    }
292
293    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
294        self.scheme = Some(match scheme.as_str() {
295            "http" => ByteString::from_static("http"),
296            "https" => ByteString::from_static("https"),
297            s => ByteString::from(s),
298        });
299    }
300
301    pub fn set_protocol(&mut self, protocol: Protocol) {
302        self.protocol = Some(protocol);
303    }
304
305    pub fn set_authority(&mut self, authority: ByteString) {
306        self.authority = Some(authority);
307    }
308}
309
310// ===== impl Iter =====
311
312impl Iterator for Iter<'_> {
313    type Item = hpack::Header<Option<HeaderName>>;
314
315    fn next(&mut self) -> Option<Self::Item> {
316        use crate::hpack::Header::*;
317
318        if let Some(ref mut pseudo) = self.pseudo {
319            if let Some(method) = pseudo.method.take() {
320                return Some(Method(method));
321            }
322
323            if let Some(scheme) = pseudo.scheme.take() {
324                return Some(Scheme(scheme));
325            }
326
327            if let Some(authority) = pseudo.authority.take() {
328                return Some(Authority(authority));
329            }
330
331            if let Some(path) = pseudo.path.take() {
332                return Some(Path(path));
333            }
334
335            if let Some(protocol) = pseudo.protocol.take() {
336                return Some(Protocol(protocol.into()));
337            }
338
339            if let Some(status) = pseudo.status.take() {
340                return Some(Status(status));
341            }
342        }
343
344        self.pseudo = None;
345
346        self.fields.next().map(|(name, value)| Field {
347            name: Some(name.clone()),
348            value: value.clone(),
349        })
350    }
351}
352
353// ===== impl HeadersFlag =====
354
355impl HeadersFlag {
356    pub fn empty() -> HeadersFlag {
357        HeadersFlag(0)
358    }
359
360    pub fn load(bits: u8) -> HeadersFlag {
361        HeadersFlag(bits & ALL)
362    }
363
364    pub fn is_end_stream(&self) -> bool {
365        self.0 & END_STREAM == END_STREAM
366    }
367
368    pub fn set_end_stream(&mut self) {
369        self.0 |= END_STREAM;
370    }
371
372    pub fn is_end_headers(&self) -> bool {
373        self.0 & END_HEADERS == END_HEADERS
374    }
375
376    pub fn set_end_headers(&mut self) {
377        self.0 |= END_HEADERS;
378    }
379
380    pub fn is_padded(&self) -> bool {
381        self.0 & PADDED == PADDED
382    }
383
384    pub fn is_priority(&self) -> bool {
385        self.0 & PRIORITY == PRIORITY
386    }
387}
388
389impl Default for HeadersFlag {
390    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
391    fn default() -> Self {
392        HeadersFlag(END_HEADERS)
393    }
394}
395
396impl From<HeadersFlag> for u8 {
397    fn from(src: HeadersFlag) -> u8 {
398        src.0
399    }
400}
401
402impl fmt::Debug for HeadersFlag {
403    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
404        util::debug_flags(fmt, self.0)
405            .flag_if(self.is_end_headers(), "END_HEADERS")
406            .flag_if(self.is_end_stream(), "END_STREAM")
407            .flag_if(self.is_padded(), "PADDED")
408            .flag_if(self.is_priority(), "PRIORITY")
409            .finish()
410    }
411}
412
413// ===== HeaderBlock =====
414
thread_local! {
    /// Per-thread scratch buffer used by `HeaderBlock::encode` to hold the
    /// HPACK-encoded block before it is split into frames. It is cleared at the
    /// start of each encode and starts with 1024 bytes of capacity.
    static HDRS_BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(1024));
}
418
419impl HeaderBlock {
420    fn load(&mut self, src: &mut BytesMut, decoder: &mut hpack::Decoder) -> Result<(), FrameError> {
421        let mut reg = !self.fields.is_empty();
422        let mut malformed = false;
423
424        macro_rules! set_pseudo {
425            ($field:ident, $val:expr) => {{
426                if reg {
427                    log::trace!("load_hpack; header malformed -- pseudo not at head of block");
428                    malformed = true;
429                } else if self.pseudo.$field.is_some() {
430                    log::trace!("load_hpack; header malformed -- repeated pseudo");
431                    malformed = true;
432                } else {
433                    self.pseudo.$field = Some($val.into());
434                }
435            }};
436        }
437
438        let mut cursor = Cursor::new(src);
439
440        // If the header frame is malformed, we still have to continue decoding
441        // the headers. A malformed header frame is a stream level error, but
442        // the hpack state is connection level. In order to maintain correct
443        // state for other streams, the hpack decoding process must complete.
444        let res = decoder.decode(&mut cursor, |header| {
445            use crate::hpack::Header::*;
446
447            match header {
448                Field { name, value } => {
449                    // Connection level header fields are not supported and must
450                    // result in a protocol error.
451
452                    if name == header::CONNECTION
453                        || name == header::TRANSFER_ENCODING
454                        || name == header::UPGRADE
455                        || name == "keep-alive"
456                        || name == "proxy-connection"
457                    {
458                        log::trace!("load_hpack; connection level header");
459                        malformed = true;
460                    } else if name == header::TE && value != "trailers" {
461                        log::trace!("load_hpack; TE header not set to trailers; val={value:?}");
462                        malformed = true;
463                    } else {
464                        reg = true;
465                        self.fields.append(name, value);
466                    }
467                }
468                Authority(v) => {
469                    set_pseudo!(authority, v)
470                }
471                Method(v) => {
472                    set_pseudo!(method, v)
473                }
474                Scheme(v) => {
475                    set_pseudo!(scheme, v)
476                }
477                Path(v) => {
478                    set_pseudo!(path, v)
479                }
480                Protocol(v) => {
481                    set_pseudo!(protocol, v)
482                }
483                Status(v) => {
484                    set_pseudo!(status, v)
485                }
486            }
487        });
488
489        if let Err(e) = res {
490            log::trace!("hpack decoding error; err={e:?}");
491            return Err(e.into());
492        }
493
494        if malformed {
495            log::trace!("malformed message");
496            return Err(FrameError::MalformedMessage);
497        }
498
499        Ok(())
500    }
501
502    fn encode(
503        self,
504        encoder: &mut hpack::Encoder,
505        head: &Head,
506        dst: &mut BytesMut,
507        max_size: usize,
508    ) {
509        HDRS_BUF.with(|buf| {
510            let mut b = buf.borrow_mut();
511            let hpack = &mut b;
512            hpack.clear();
513
514            // encode hpack
515            let headers = Iter {
516                pseudo: Some(self.pseudo),
517                fields: self.fields.into_iter(),
518            };
519            encoder.encode(headers, hpack);
520
521            let mut head = *head;
522            let mut start = 0;
523            loop {
524                let end = cmp::min(start + max_size, hpack.len());
525
526                // encode the header payload
527                if hpack.len() > end {
528                    Head::new(head.kind(), head.flag() ^ END_HEADERS, head.stream_id())
529                        .encode(max_size, dst);
530                    dst.extend_from_slice(&hpack[start..end]);
531                    head = Head::new(Kind::Continuation, END_HEADERS, head.stream_id());
532                    start = end;
533                } else {
534                    head.encode(end - start, dst);
535                    dst.extend_from_slice(&hpack[start..end]);
536                    break;
537                }
538            }
539        });
540    }
541}
542
#[cfg(test)]
mod test {
    use ntex_http::HeaderValue;

    use super::*;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three `hello` fields with an 8-byte payload limit. This forces
    /// the header block to be split into a HEADERS frame plus CONTINUATION
    /// frames, and checks that a split can fall in the middle of a field value.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let mut hdrs = HeaderMap::default();
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("world"),
        );
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("zomg"),
        );
        hdrs.append(
            HeaderName::from_static("hello"),
            HeaderValue::from_static("sup"),
        );

        let mut headers = Headers::new(StreamId::CON, Default::default(), hdrs, false);
        headers.set_end_headers();
        headers.encode(&mut encoder, &mut dst, 8);
        assert_eq!(48, dst.len());
        // First frame header (9 bytes): length 8, type 1 (HEADERS), flags 0
        // (END_HEADERS cleared because more frames follow), stream id 0.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // Literal with incremental indexing and a new name (0x40), followed by
        // a Huffman-coded name of length 4 (0x80 | 4).
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Huffman-coded value of length 4; only its first byte fits in frame 1.
        assert_eq!(0x80 | 4, dst[15]);

        // Reassemble the split value: last byte of frame 1's payload plus the
        // first 3 bytes of the CONTINUATION payload (after its 9-byte header).
        let mut world = BytesMut::from(&dst[16..17]);
        world.extend_from_slice(&dst[26..29]);
        // NOTE(review): this check is disabled, so `world` is built but never
        // verified. Re-enable it, or drop `world`, once confirmed.
        // assert_eq!("world", huff_decode(&world));

        // Second frame header: length 8, type 9 (CONTINUATION), flags 0.
        assert_eq!([0, 0, 8, 9, 0, 0, 0, 0, 0], &dst[17..26]);

        // // Next is not indexed
        // NOTE(review): the offsets below predate the single-call encode and no
        // longer match the frame layout above.
        //assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        //assert_eq!("zomg", huff_decode(&dst[15..18]));
        //assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        //assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src`, panicking if the input is not valid.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}