// hreq_h2/codec/framed_read.rs

1use crate::codec::RecvError;
2use crate::frame::{self, Frame, Kind, Reason};
3use crate::frame::{
4    DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
5};
6
7use crate::hpack;
8
9use futures_core::Stream;
10
11use bytes::BytesMut;
12
13use std::io;
14
15use crate::tokio_codec::FramedRead as InnerFramedRead;
16use crate::tokio_codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
17use futures_io::AsyncRead;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
/// Default limit on the size of a decoded header list: 16 MiB, the
/// "sane default" used by golang's http2 implementation.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23
24#[derive(Debug)]
25pub struct FramedRead<T> {
26    inner: InnerFramedRead<T, LengthDelimitedCodec>,
27
28    // hpack decoder state
29    hpack: hpack::Decoder,
30
31    max_header_list_size: usize,
32
33    partial: Option<Partial>,
34}
35
36/// Partially loaded headers frame
37#[derive(Debug)]
38struct Partial {
39    /// Empty frame
40    frame: Continuable,
41
42    /// Partial header payload
43    buf: BytesMut,
44}
45
46#[derive(Debug)]
47enum Continuable {
48    Headers(frame::Headers),
49    PushPromise(frame::PushPromise),
50}
51
52impl<T> FramedRead<T> {
53    pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
54        FramedRead {
55            inner,
56            hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
57            max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
58            partial: None,
59        }
60    }
61
62    fn decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, RecvError> {
63        use self::RecvError::*;
64
65        log::trace!("decoding frame from {}B", bytes.len());
66
67        // Parse the head
68        let head = frame::Head::parse(&bytes);
69
70        if self.partial.is_some() && head.kind() != Kind::Continuation {
71            proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
72            return Err(Connection(Reason::PROTOCOL_ERROR));
73        }
74
75        let kind = head.kind();
76
77        log::trace!("    -> kind={:?}", kind);
78
79        macro_rules! header_block {
80            ($frame:ident, $head:ident, $bytes:ident) => ({
81                // Drop the frame header
82                // TODO: Change to drain: carllerche/bytes#130
83                let _ = $bytes.split_to(frame::HEADER_LEN);
84
85                // Parse the header frame w/o parsing the payload
86                let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
87                    Ok(res) => res,
88                    Err(frame::Error::InvalidDependencyId) => {
89                        proto_err!(stream: "invalid HEADERS dependency ID");
90                        // A stream cannot depend on itself. An endpoint MUST
91                        // treat this as a stream error (Section 5.4.2) of type
92                        // `PROTOCOL_ERROR`.
93                        return Err(Stream {
94                            id: $head.stream_id(),
95                            reason: Reason::PROTOCOL_ERROR,
96                        });
97                    },
98                    Err(e) => {
99                        proto_err!(conn: "failed to load frame; err={:?}", e);
100                        return Err(Connection(Reason::PROTOCOL_ERROR));
101                    }
102                };
103
104                let is_end_headers = frame.is_end_headers();
105
106                // Load the HPACK encoded headers
107                match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) {
108                    Ok(_) => {},
109                    Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
110                    Err(frame::Error::MalformedMessage) => {
111                        let id = $head.stream_id();
112                        proto_err!(stream: "malformed header block; stream={:?}", id);
113                        return Err(Stream {
114                            id,
115                            reason: Reason::PROTOCOL_ERROR,
116                        });
117                    },
118                    Err(e) => {
119                        proto_err!(conn: "failed HPACK decoding; err={:?}", e);
120                        return Err(Connection(Reason::PROTOCOL_ERROR));
121                    }
122                }
123
124                if is_end_headers {
125                    frame.into()
126                } else {
127                    log::trace!("loaded partial header block");
128                    // Defer returning the frame
129                    self.partial = Some(Partial {
130                        frame: Continuable::$frame(frame),
131                        buf: payload,
132                    });
133
134                    return Ok(None);
135                }
136            });
137        }
138
139        let frame = match kind {
140            Kind::Settings => {
141                let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
142
143                res.map_err(|e| {
144                    proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
145                    Connection(Reason::PROTOCOL_ERROR)
146                })?
147                .into()
148            }
149            Kind::Ping => {
150                let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
151
152                res.map_err(|e| {
153                    proto_err!(conn: "failed to load PING frame; err={:?}", e);
154                    Connection(Reason::PROTOCOL_ERROR)
155                })?
156                .into()
157            }
158            Kind::WindowUpdate => {
159                let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
160
161                res.map_err(|e| {
162                    proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
163                    Connection(Reason::PROTOCOL_ERROR)
164                })?
165                .into()
166            }
167            Kind::Data => {
168                let _ = bytes.split_to(frame::HEADER_LEN);
169                let res = frame::Data::load(head, bytes.freeze());
170
171                // TODO: Should this always be connection level? Probably not...
172                res.map_err(|e| {
173                    proto_err!(conn: "failed to load DATA frame; err={:?}", e);
174                    Connection(Reason::PROTOCOL_ERROR)
175                })?
176                .into()
177            }
178            Kind::Headers => header_block!(Headers, head, bytes),
179            Kind::Reset => {
180                let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
181                res.map_err(|e| {
182                    proto_err!(conn: "failed to load RESET frame; err={:?}", e);
183                    Connection(Reason::PROTOCOL_ERROR)
184                })?
185                .into()
186            }
187            Kind::GoAway => {
188                let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
189                res.map_err(|e| {
190                    proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
191                    Connection(Reason::PROTOCOL_ERROR)
192                })?
193                .into()
194            }
195            Kind::PushPromise => header_block!(PushPromise, head, bytes),
196            Kind::Priority => {
197                if head.stream_id() == 0 {
198                    // Invalid stream identifier
199                    proto_err!(conn: "invalid stream ID 0");
200                    return Err(Connection(Reason::PROTOCOL_ERROR));
201                }
202
203                match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
204                    Ok(frame) => frame.into(),
205                    Err(frame::Error::InvalidDependencyId) => {
206                        // A stream cannot depend on itself. An endpoint MUST
207                        // treat this as a stream error (Section 5.4.2) of type
208                        // `PROTOCOL_ERROR`.
209                        let id = head.stream_id();
210                        proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
211                        return Err(Stream {
212                            id,
213                            reason: Reason::PROTOCOL_ERROR,
214                        });
215                    }
216                    Err(e) => {
217                        proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
218                        return Err(Connection(Reason::PROTOCOL_ERROR));
219                    }
220                }
221            }
222            Kind::Continuation => {
223                let is_end_headers = (head.flag() & 0x4) == 0x4;
224
225                let mut partial = match self.partial.take() {
226                    Some(partial) => partial,
227                    None => {
228                        proto_err!(conn: "received unexpected CONTINUATION frame");
229                        return Err(Connection(Reason::PROTOCOL_ERROR));
230                    }
231                };
232
233                // The stream identifiers must match
234                if partial.frame.stream_id() != head.stream_id() {
235                    proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
236                    return Err(Connection(Reason::PROTOCOL_ERROR));
237                }
238
239                // Extend the buf
240                if partial.buf.is_empty() {
241                    partial.buf = bytes.split_off(frame::HEADER_LEN);
242                } else {
243                    if partial.frame.is_over_size() {
244                        // If there was left over bytes previously, they may be
245                        // needed to continue decoding, even though we will
246                        // be ignoring this frame. This is done to keep the HPACK
247                        // decoder state up-to-date.
248                        //
249                        // Still, we need to be careful, because if a malicious
250                        // attacker were to try to send a gigantic string, such
251                        // that it fits over multiple header blocks, we could
252                        // grow memory uncontrollably again, and that'd be a shame.
253                        //
254                        // Instead, we use a simple heuristic to determine if
255                        // we should continue to ignore decoding, or to tell
256                        // the attacker to go away.
257                        if partial.buf.len() + bytes.len() > self.max_header_list_size {
258                            proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
259                            return Err(Connection(Reason::COMPRESSION_ERROR));
260                        }
261                    }
262                    partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
263                }
264
265                match partial.frame.load_hpack(
266                    &mut partial.buf,
267                    self.max_header_list_size,
268                    &mut self.hpack,
269                ) {
270                    Ok(_) => {}
271                    Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_)))
272                        if !is_end_headers => {}
273                    Err(frame::Error::MalformedMessage) => {
274                        let id = head.stream_id();
275                        proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
276                        return Err(Stream {
277                            id,
278                            reason: Reason::PROTOCOL_ERROR,
279                        });
280                    }
281                    Err(e) => {
282                        proto_err!(conn: "failed HPACK decoding; err={:?}", e);
283                        return Err(Connection(Reason::PROTOCOL_ERROR));
284                    }
285                }
286
287                if is_end_headers {
288                    partial.frame.into()
289                } else {
290                    self.partial = Some(partial);
291                    return Ok(None);
292                }
293            }
294            Kind::Unknown => {
295                // Unknown frames are ignored
296                return Ok(None);
297            }
298        };
299
300        Ok(Some(frame))
301    }
302
303    pub fn get_ref(&self) -> &T {
304        self.inner.get_ref()
305    }
306
307    pub fn get_mut(&mut self) -> &mut T {
308        self.inner.get_mut()
309    }
310
311    /// Returns the current max frame size setting
312    #[cfg(feature = "unstable")]
313    #[inline]
314    pub fn max_frame_size(&self) -> usize {
315        self.inner.decoder().max_frame_length()
316    }
317
318    /// Updates the max frame size setting.
319    ///
320    /// Must be within 16,384 and 16,777,215.
321    #[inline]
322    pub fn set_max_frame_size(&mut self, val: usize) {
323        assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
324        self.inner.decoder_mut().set_max_frame_length(val)
325    }
326
327    /// Update the max header list size setting.
328    #[inline]
329    pub fn set_max_header_list_size(&mut self, val: usize) {
330        self.max_header_list_size = val;
331    }
332}
333
334impl<T> Stream for FramedRead<T>
335where
336    T: AsyncRead + Unpin,
337{
338    type Item = Result<Frame, RecvError>;
339
340    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341        loop {
342            log::trace!("poll");
343            let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
344                Some(Ok(bytes)) => bytes,
345                Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
346                None => return Poll::Ready(None),
347            };
348
349            log::trace!("poll; bytes={}B", bytes.len());
350            if let Some(frame) = self.decode_frame(bytes)? {
351                log::debug!("received; frame={:?}", frame);
352                return Poll::Ready(Some(Ok(frame)));
353            }
354        }
355    }
356}
357
358fn map_err(err: io::Error) -> RecvError {
359    if let io::ErrorKind::InvalidData = err.kind() {
360        if let Some(custom) = err.get_ref() {
361            if custom.is::<LengthDelimitedCodecError>() {
362                return RecvError::Connection(Reason::FRAME_SIZE_ERROR);
363            }
364        }
365    }
366    err.into()
367}
368
369// ===== impl Continuable =====
370
371impl Continuable {
372    fn stream_id(&self) -> frame::StreamId {
373        match *self {
374            Continuable::Headers(ref h) => h.stream_id(),
375            Continuable::PushPromise(ref p) => p.stream_id(),
376        }
377    }
378
379    fn is_over_size(&self) -> bool {
380        match *self {
381            Continuable::Headers(ref h) => h.is_over_size(),
382            Continuable::PushPromise(ref p) => p.is_over_size(),
383        }
384    }
385
386    fn load_hpack(
387        &mut self,
388        src: &mut BytesMut,
389        max_header_list_size: usize,
390        decoder: &mut hpack::Decoder,
391    ) -> Result<(), frame::Error> {
392        match *self {
393            Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
394            Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
395        }
396    }
397}
398
399impl<T> From<Continuable> for Frame<T> {
400    fn from(cont: Continuable) -> Self {
401        match cont {
402            Continuable::Headers(mut headers) => {
403                headers.set_end_headers();
404                headers.into()
405            }
406            Continuable::PushPromise(mut push) => {
407                push.set_end_headers();
408                push.into()
409            }
410        }
411    }
412}