// ntex_h2/codec/mod.rs

use std::{cell::RefCell, rc::Rc};

use ntex_bytes::{Bytes, BytesMut};
use ntex_codec::{Decoder, Encoder};

mod error;
mod length_delimited;

pub use self::error::EncoderError;

use self::length_delimited::LengthDelimitedCodec;
use crate::{consts, frame, frame::Frame, frame::Kind, hpack};
13
/// Frame type code of PUSH_PROMISE frames.
///
/// The decoder compares it against the raw type byte of every incoming frame
/// header so that push promises are rejected before any further parsing.
const PUSH_PROMISE: u8 = 5;
16
17#[derive(Clone, Debug)]
18pub struct Codec(Rc<RefCell<CodecInner>>);
19
20/// Partially loaded headers frame
21#[derive(Debug)]
22struct Partial {
23    /// Empty frame
24    frame: frame::Headers,
25    /// Partial header payload
26    buf: Bytes,
27    /// Number of continuations
28    count: usize,
29}
30
31#[derive(Debug)]
32struct CodecInner {
33    // encoder state
34    encoder_hpack: hpack::Encoder,
35    encoder_last_data_frame: Option<frame::Data>,
36    encoder_max_frame_size: frame::FrameSize, // Max frame size, this is specified by the peer
37
38    // decoder state
39    decoder: LengthDelimitedCodec,
40    decoder_hpack: hpack::Decoder,
41    decoder_max_header_list_size: usize,
42    decoder_max_header_continuations: usize,
43    partial: Option<Partial>, // Partially loaded headers frame
44}
45
46impl Default for Codec {
47    #[inline]
48    /// Returns a new `Codec` with the default max frame size
49    fn default() -> Self {
50        // Delimit the frames
51        let decoder = self::length_delimited::Builder::new()
52            .length_field_length(3)
53            .length_adjustment(9)
54            .max_frame_length(frame::DEFAULT_MAX_FRAME_SIZE as usize)
55            .num_skip(0) // Don't skip the header
56            .new_codec();
57
58        Codec(Rc::new(RefCell::new(CodecInner {
59            decoder,
60            decoder_hpack: hpack::Decoder::new(frame::DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
61            decoder_max_header_list_size: consts::DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE as usize,
62            decoder_max_header_continuations: consts::DEFAULT_MAX_COUNTINUATIONS,
63            partial: None,
64
65            encoder_hpack: hpack::Encoder::default(),
66            encoder_last_data_frame: None,
67            encoder_max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
68        })))
69    }
70}
71
72impl Codec {
73    /// Updates the max received frame size.
74    ///
75    /// The change takes effect the next time a frame is decoded. In other
76    /// words, if a frame is currently in process of being decoded with a frame
77    /// size greater than `val` but less than the max frame size in effect
78    /// before calling this function, then the frame will be allowed.
79    ///
80    /// # Panics
81    ///
82    /// Panics if size is greater than `16_777_215`.
83    #[inline]
84    pub fn set_recv_frame_size(&self, val: usize) {
85        assert!(
86            frame::DEFAULT_MAX_FRAME_SIZE as usize <= val
87                && val <= frame::MAX_MAX_FRAME_SIZE as usize
88        );
89        self.0.borrow_mut().decoder.set_max_frame_length(val);
90    }
91
92    /// Local max frame size.
93    pub fn recv_frame_size(&self) -> u32 {
94        self.0.borrow_mut().decoder.max_frame_length() as u32
95    }
96
97    /// Set the max header list size that can be received.
98    ///
99    /// By default value is set to 48kb
100    pub fn set_recv_header_list_size(&self, val: usize) {
101        self.0.borrow_mut().decoder_max_header_list_size = val;
102    }
103
104    /// Set the max header continuation frames.
105    ///
106    /// By default value is set to 5
107    pub fn set_max_header_continuations(&self, val: usize) {
108        self.0.borrow_mut().decoder_max_header_continuations = val;
109    }
110
111    /// Set the peer's max frame size.
112    ///
113    /// # Panics
114    ///
115    /// Panics if size is greater than `16_777_215`.
116    pub fn set_send_frame_size(&self, val: usize) {
117        assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
118        self.0.borrow_mut().encoder_max_frame_size = val as frame::FrameSize;
119    }
120
121    /// Set the peer's header table size size.
122    pub fn set_send_header_table_size(&self, val: usize) {
123        self.0.borrow_mut().encoder_hpack.update_max_size(val);
124    }
125
126    /// Remote max frame size.
127    pub fn send_frame_size(&self) -> u32 {
128        self.0.borrow_mut().encoder_max_frame_size
129    }
130}
131
132impl Decoder for Codec {
133    type Item = Frame;
134    type Error = frame::FrameError;
135
136    #[allow(clippy::too_many_lines)]
137    /// Decodes a frame.
138    ///
139    /// This method is intentionally de-generified and outlined because it is very large.
140    fn decode(&self, src: &mut BytesMut) -> Result<Option<Frame>, frame::FrameError> {
141        let mut inner = self.0.borrow_mut();
142        loop {
143            let Some(mut bytes) = inner.decoder.decode(src)? else {
144                return Ok(None);
145            };
146
147            // check push promise, we do not support push
148            if bytes[3] == PUSH_PROMISE {
149                return Err(frame::FrameError::UnexpectedPushPromise);
150            }
151
152            // Parse the head
153            let head = frame::Head::parse(&bytes);
154            let kind = head.kind();
155
156            if inner.partial.is_some() && kind != Kind::Continuation {
157                proto_err!(conn: "expected CONTINUATION, got {:?}", kind);
158                return Err(frame::FrameError::Continuation(
159                    frame::FrameContinuationError::Expected,
160                ));
161            }
162
163            log::trace!("decoding {:?} frame, frame buf len {}", kind, bytes.len());
164
165            let frame = match kind {
166                Kind::Settings => frame::Settings::load(head, &bytes[frame::HEADER_LEN..])
167                    .inspect_err(|e| {
168                        proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
169                    })?
170                    .into(),
171                Kind::Ping => frame::Ping::load(head, &bytes[frame::HEADER_LEN..])
172                    .inspect_err(|e| {
173                        proto_err!(conn: "failed to load PING frame; err={:?}", e);
174                    })?
175                    .into(),
176                Kind::WindowUpdate => frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..])
177                    .inspect_err(|e| {
178                        proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
179                    })?
180                    .into(),
181                Kind::Data => {
182                    bytes.advance_to(frame::HEADER_LEN);
183
184                    frame::Data::load(head, bytes)
185                        // TODO: Should this always be connection level? Probably not...
186                        .inspect_err(|e| {
187                            proto_err!(conn: "failed to load DATA frame; err={:?}", e);
188                        })?
189                        .into()
190                }
191                Kind::Headers => {
192                    // Drop the frame header
193                    bytes.advance_to(frame::HEADER_LEN);
194
195                    // Parse the header frame w/o parsing the payload
196                    let mut frame = match frame::Headers::load(head, &mut bytes) {
197                        Ok(res) => Ok(res),
198                        Err(frame::FrameError::InvalidDependencyId) => {
199                            proto_err!(stream: "invalid HEADERS dependency ID");
200                            // A stream cannot depend on itself. An endpoint MUST
201                            // treat this as a stream error (Section 5.4.2) of type `PROTOCOL_ERROR`.
202                            Err(frame::FrameError::InvalidDependencyId)
203                        }
204                        Err(e) => {
205                            proto_err!(conn: "failed to load frame; err={:?}", e);
206                            Err(e)
207                        }
208                    }?;
209
210                    if frame.is_end_headers() {
211                        // Load the HPACK encoded headers
212                        match frame.load_hpack(&mut bytes, &mut inner.decoder_hpack) {
213                            Ok(()) => {}
214                            Err(frame::FrameError::MalformedMessage) => {
215                                let id = head.stream_id();
216                                proto_err!(stream: "malformed header block; stream={:?}", id);
217                                return Err(frame::FrameError::MalformedMessage);
218                            }
219                            Err(e) => {
220                                proto_err!(conn: "failed HPACK decoding; err={:?}", e);
221                                return Err(e);
222                            }
223                        }
224                        frame.into()
225                    } else {
226                        log::trace!("loaded partial header block");
227                        // Defer returning the frame
228                        inner.partial = Some(Partial {
229                            frame,
230                            buf: bytes,
231                            count: 0,
232                        });
233
234                        continue;
235                    }
236                }
237                Kind::Reset => frame::Reset::load(head, &bytes[frame::HEADER_LEN..])
238                    .inspect_err(|e| {
239                        proto_err!(conn: "failed to load RESET frame; err={:?}", e);
240                    })?
241                    .into(),
242                Kind::GoAway => frame::GoAway::load(&bytes[frame::HEADER_LEN..])
243                    .inspect_err(|e| {
244                        proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
245                    })?
246                    .into(),
247                Kind::Priority => {
248                    if head.stream_id() == 0 {
249                        // Invalid stream identifier
250                        proto_err!(conn: "invalid stream ID 0");
251                        return Err(frame::FrameError::InvalidStreamId);
252                    }
253
254                    match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
255                        Ok(frame) => frame.into(),
256                        Err(frame::FrameError::InvalidDependencyId) => {
257                            // A stream cannot depend on itself. An endpoint MUST
258                            // treat this as a stream error (Section 5.4.2) of type
259                            // `PROTOCOL_ERROR`.
260                            let id = head.stream_id();
261                            proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
262                            return Err(frame::FrameError::InvalidDependencyId);
263                        }
264                        Err(e) => {
265                            proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
266                            return Err(e);
267                        }
268                    }
269                }
270                Kind::Continuation => {
271                    let mut partial = inner.partial.take().ok_or_else(|| {
272                        proto_err!(conn: "received unexpected CONTINUATION frame");
273                        frame::FrameError::Continuation(frame::FrameContinuationError::Unexpected)
274                    })?;
275
276                    // The stream identifiers must match
277                    if partial.frame.stream_id() != head.stream_id() {
278                        proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
279                        return Err(frame::FrameError::Continuation(
280                            frame::FrameContinuationError::UnknownStreamId,
281                        ));
282                    }
283
284                    if inner.decoder_max_header_continuations > 0 {
285                        // Check count of continuation frames
286                        partial.count += 1;
287                        if partial.count > inner.decoder_max_header_continuations {
288                            proto_err!(conn: "received excessive amount of CONTINUATION frames");
289                            return Err(frame::FrameError::Continuation(
290                                frame::FrameContinuationError::MaxContinuations,
291                            ));
292                        }
293                    }
294
295                    // Extend the buf
296                    if partial.buf.is_empty() {
297                        partial.buf = bytes.split_off(frame::HEADER_LEN);
298                    } else {
299                        // If there was left over bytes previously, they may be
300                        // needed to continue decoding, even though we will
301                        // be ignoring this frame. This is done to keep the HPACK
302                        // decoder state up-to-date.
303                        //
304                        // Still, we need to be careful, because if a malicious
305                        // attacker were to try to send a gigantic string, such
306                        // that it fits over multiple header blocks.
307                        //
308                        // Instead, we use a simple heuristic to determine if
309                        // we should continue to ignore decoding, or to tell
310                        // the attacker to go away.
311                        if partial.buf.len() + bytes.len() > inner.decoder_max_header_list_size {
312                            proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
313                            return Err(frame::FrameError::Continuation(
314                                frame::FrameContinuationError::MaxLeftoverSize,
315                            ));
316                        }
317                        let mut buf = BytesMut::with_capacity(
318                            partial.buf.len() + bytes.len() - frame::HEADER_LEN,
319                        );
320                        buf.extend_from_slice(&partial.buf);
321                        buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
322                        partial.buf = buf.into();
323                    }
324
325                    if (head.flag() & 0x4) == 0x4 {
326                        match partial
327                            .frame
328                            .load_hpack(&mut partial.buf, &mut inner.decoder_hpack)
329                        {
330                            Ok(()) => {}
331                            Err(frame::FrameError::MalformedMessage) => {
332                                let id = head.stream_id();
333                                proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
334                                return Err(frame::FrameContinuationError::Malformed.into());
335                            }
336                            Err(e) => {
337                                proto_err!(conn: "failed HPACK decoding; err={:?}", e);
338                                return Err(e);
339                            }
340                        }
341
342                        partial.frame.into()
343                    } else {
344                        inner.partial = Some(partial);
345                        continue;
346                    }
347                }
348                Kind::Unknown => {
349                    // Unknown frames are ignored
350                    continue;
351                }
352            };
353
354            return Ok(Some(frame));
355        }
356    }
357}
358
359impl Encoder for Codec {
360    type Item = Frame;
361    type Error = error::EncoderError;
362
363    fn encode(&self, item: Frame, buf: &mut BytesMut) -> Result<(), error::EncoderError> {
364        // Ensure that we have enough capacity to accept the write.
365        // log::debug!(frame = ?item, "send");
366
367        let mut inner = self.0.borrow_mut();
368
369        match item {
370            Frame::Data(v) => {
371                // Ensure that the payload is not greater than the max frame.
372                let len = v.payload().len();
373                if len > inner.encoder_max_frame_size as usize {
374                    return Err(error::EncoderError::MaxSizeExceeded);
375                }
376                v.encode(buf);
377
378                // Save off the last frame...
379                inner.encoder_last_data_frame = Some(v);
380            }
381            Frame::Headers(v) => {
382                let max_size = inner.encoder_max_frame_size as usize;
383                v.encode(&mut inner.encoder_hpack, buf, max_size);
384            }
385            Frame::Settings(v) => {
386                v.encode(buf);
387            }
388            Frame::GoAway(v) => {
389                v.encode(buf);
390            }
391            Frame::Ping(v) => {
392                v.encode(buf);
393            }
394            Frame::WindowUpdate(v) => {
395                v.encode(buf);
396            }
397
398            Frame::Priority(_) => (),
399            Frame::Reset(v) => {
400                v.encode(buf);
401            }
402        }
403
404        Ok(())
405    }
406}