async_h1/chunked/
decoder.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_lite::io::{self, AsyncRead as Read};
7use futures_lite::ready;
8use http_types::trailers::{Sender, Trailers};
9
10/// Decodes a chunked body according to
11/// https://tools.ietf.org/html/rfc7230#section-4.1
12#[derive(Debug)]
13pub struct ChunkedDecoder<R: Read> {
14    /// The underlying stream
15    inner: R,
16    /// Current state.
17    state: State,
18    /// Current chunk size (increased while parsing size, decreased while reading chunk)
19    chunk_size: u64,
20    /// Trailer channel sender.
21    trailer_sender: Option<Sender>,
22}
23
24impl<R: Read> ChunkedDecoder<R> {
25    pub(crate) fn new(inner: R, trailer_sender: Sender) -> Self {
26        ChunkedDecoder {
27            inner,
28            state: State::ChunkSize,
29            chunk_size: 0,
30            trailer_sender: Some(trailer_sender),
31        }
32    }
33}
34
/// The states the chunked decoder moves through while consuming a body.
enum State {
    /// Reading the hexadecimal digits of a chunk-size line.
    ChunkSize,
    /// The chunk-size line ended with `\r`; the matching `\n` must follow.
    ChunkSizeExpectLf,
    /// Copying the payload bytes of the current chunk.
    ChunkBody,
    /// The chunk payload is complete; a `\r` must follow.
    ChunkBodyExpectCr,
    /// The `\r` after a chunk payload was seen; a `\n` must follow.
    ChunkBodyExpectLf,
    /// Collecting the trailer section: number of bytes buffered so far and
    /// the fixed-size buffer holding them.
    Trailers(usize, Box<[u8; 8192]>),
    /// The trailers are being pushed into the channel; holds the pending send.
    TrailerSending(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// The whole body, including trailers, has been processed.
    Done,
}
54
55impl fmt::Debug for State {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        match self {
58            State::ChunkSize => write!(f, "State::ChunkSize"),
59            State::ChunkSizeExpectLf => write!(f, "State::ChunkSizeExpectLf"),
60            State::ChunkBody => write!(f, "State::ChunkBody"),
61            State::ChunkBodyExpectCr => write!(f, "State::ChunkBodyExpectCr"),
62            State::ChunkBodyExpectLf => write!(f, "State::ChunkBodyExpectLf"),
63            State::Trailers(len, _) => write!(f, "State::Trailers({}, _)", len),
64            State::TrailerSending(_) => write!(f, "State::TrailerSending"),
65            State::Done => write!(f, "State::Done"),
66        }
67    }
68}
69
70impl<R: Read + Unpin> ChunkedDecoder<R> {
71    fn poll_read_byte(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<u8>> {
72        let mut byte = [0u8];
73        if ready!(Pin::new(&mut self.inner).poll_read(cx, &mut byte))? == 1 {
74            Poll::Ready(Ok(byte[0]))
75        } else {
76            eof()
77        }
78    }
79
80    fn expect_byte(
81        &mut self,
82        cx: &mut Context<'_>,
83        expected_byte: u8,
84        expected: &'static str,
85    ) -> Poll<io::Result<()>> {
86        let byte = ready!(self.poll_read_byte(cx))?;
87        if byte == expected_byte {
88            Poll::Ready(Ok(()))
89        } else {
90            unexpected(byte, expected)
91        }
92    }
93
94    fn send_trailers(&mut self, trailers: Trailers) {
95        let sender = self
96            .trailer_sender
97            .take()
98            .expect("invalid chunked state, tried sending multiple trailers");
99        let fut = Box::pin(sender.send(trailers));
100        self.state = State::TrailerSending(fut);
101    }
102}
103
/// Produces a ready `UnexpectedEof` error for a stream that ended while
/// chunked data was still expected.
fn eof<T>() -> Poll<io::Result<T>> {
    let error = io::Error::new(
        io::ErrorKind::UnexpectedEof,
        "Unexpected EOF when decoding chunked data",
    );
    Poll::Ready(Err(error))
}
110
/// Produces a ready `InvalidData` error naming the offending `byte` (printed
/// as its numeric value) and a description of what was `expected` instead.
fn unexpected<T>(byte: u8, expected: &'static str) -> Poll<io::Result<T>> {
    let message = format!("Unexpected byte {}; expected {}", byte, expected);
    let error = io::Error::new(io::ErrorKind::InvalidData, message);
    Poll::Ready(Err(error))
}
117
/// Builds the `InvalidData` error used when a chunk size does not fit in a `u64`.
fn overflow() -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, "Chunk size overflowed 64 bits")
}
121
122impl<R: Read + Unpin> Read for ChunkedDecoder<R> {
123    #[allow(missing_doc_code_examples)]
124    fn poll_read(
125        mut self: Pin<&mut Self>,
126        cx: &mut Context<'_>,
127        buf: &mut [u8],
128    ) -> Poll<io::Result<usize>> {
129        let this = &mut *self;
130
131        loop {
132            match this.state {
133                State::ChunkSize => {
134                    let byte = ready!(this.poll_read_byte(cx))?;
135                    let digit = match byte {
136                        b'0'..=b'9' => byte - b'0',
137                        b'a'..=b'f' => 10 + byte - b'a',
138                        b'A'..=b'F' => 10 + byte - b'A',
139                        b'\r' => {
140                            this.state = State::ChunkSizeExpectLf;
141                            continue;
142                        }
143                        _ => {
144                            return unexpected(byte, "hex digit or CR");
145                        }
146                    };
147                    this.chunk_size = this
148                        .chunk_size
149                        .checked_mul(16)
150                        .ok_or_else(overflow)?
151                        .checked_add(digit as u64)
152                        .ok_or_else(overflow)?;
153                }
154                State::ChunkSizeExpectLf => {
155                    ready!(this.expect_byte(cx, b'\n', "LF"))?;
156                    if this.chunk_size == 0 {
157                        this.state = State::Trailers(0, Box::new([0u8; 8192]));
158                    } else {
159                        this.state = State::ChunkBody;
160                    }
161                }
162                State::ChunkBody => {
163                    let max_bytes = std::cmp::min(
164                        buf.len(),
165                        std::cmp::min(this.chunk_size, usize::MAX as u64) as usize,
166                    );
167                    let bytes_read =
168                        ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf[..max_bytes]))?;
169                    this.chunk_size -= bytes_read as u64;
170                    if bytes_read == 0 {
171                        return eof();
172                    } else if this.chunk_size == 0 {
173                        this.state = State::ChunkBodyExpectCr;
174                    }
175                    return Poll::Ready(Ok(bytes_read));
176                }
177                State::ChunkBodyExpectCr => {
178                    ready!(this.expect_byte(cx, b'\r', "CR"))?;
179                    this.state = State::ChunkBodyExpectLf;
180                }
181                State::ChunkBodyExpectLf => {
182                    ready!(this.expect_byte(cx, b'\n', "LF"))?;
183                    this.state = State::ChunkSize;
184                }
185                State::Trailers(ref mut len, ref mut buf) => {
186                    let bytes_read =
187                        ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf[*len..]))?;
188                    *len += bytes_read;
189                    let len = *len;
190                    if len == 0 {
191                        this.send_trailers(Trailers::new());
192                        continue;
193                    }
194                    if bytes_read == 0 {
195                        return eof();
196                    }
197                    let mut headers = [httparse::EMPTY_HEADER; 16];
198                    let parse_result = httparse::parse_headers(&buf[..len], &mut headers)
199                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
200                    use httparse::Status;
201                    match parse_result {
202                        Status::Partial => {
203                            if len == buf.len() {
204                                return eof();
205                            } else {
206                                return Poll::Pending;
207                            }
208                        }
209                        Status::Complete((offset, headers)) => {
210                            if offset != len {
211                                return unexpected(buf[offset], "end of trailers");
212                            }
213                            let mut trailers = Trailers::new();
214                            for header in headers {
215                                trailers.insert(
216                                    header.name,
217                                    String::from_utf8_lossy(header.value).as_ref(),
218                                );
219                            }
220                            this.send_trailers(trailers);
221                        }
222                    }
223                }
224                State::TrailerSending(ref mut fut) => {
225                    ready!(Pin::new(fut).poll(cx));
226                    this.state = State::Done;
227                }
228                State::Done => return Poll::Ready(Ok(0)),
229            }
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use async_std::prelude::*;

    /// Feeds `input` through a fresh `ChunkedDecoder` wired to `sender` and
    /// returns everything the decoder produced as a `String`.
    async fn decode_to_string<T>(input: T, sender: Sender) -> String
    where
        T: AsRef<[u8]> + Unpin,
    {
        let mut decoder = ChunkedDecoder::new(async_std::io::Cursor::new(input), sender);
        let mut output = String::new();
        decoder.read_to_string(&mut output).await.unwrap();
        output
    }

    /// The multi-chunk example with a payload that itself contains CRLFs.
    #[test]
    fn test_chunked_wiki() {
        async_std::task::block_on(async move {
            let encoded = "4\r\n\
                           Wiki\r\n\
                           5\r\n\
                           pedia\r\n\
                           E\r\n in\r\n\
                           \r\n\
                           chunks.\r\n\
                           0\r\n\
                           \r\n";

            let (s, _r) = async_channel::bounded(1);
            let output = decode_to_string(encoded.as_bytes(), Sender::new(s)).await;

            let expected = "Wikipedia in\r\n\
                            \r\n\
                            chunks.";
            assert_eq!(output, expected);
        });
    }

    /// Chunks larger than a single read still decode to the full payload.
    #[test]
    fn test_chunked_big() {
        async_std::task::block_on(async move {
            let input: Vec<u8> = [
                b"800\r\n".to_vec(),
                vec![b'X'; 2048],
                b"\r\n1800\r\n".to_vec(),
                vec![b'Y'; 6144],
                b"\r\n800\r\n".to_vec(),
                vec![b'Z'; 2048],
                b"\r\n0\r\n\r\n".to_vec(),
            ]
            .concat();

            let (s, _r) = async_channel::bounded(1);
            let output = decode_to_string(input, Sender::new(s)).await;

            let expected: Vec<u8> =
                [vec![b'X'; 2048], vec![b'Y'; 6144], vec![b'Z'; 2048]].concat();
            assert_eq!(output.len(), 10240);
            assert_eq!(output.as_bytes(), expected.as_slice());
        });
    }

    /// A body followed by a trailer delivers the trailer over the channel.
    #[test]
    fn test_chunked_mdn() {
        async_std::task::block_on(async move {
            let encoded = "7\r\n\
                           Mozilla\r\n\
                           9\r\n\
                           Developer\r\n\
                           7\r\n\
                           Network\r\n\
                           0\r\n\
                           Expires: Wed, 21 Oct 2015 07:28:00 GMT\r\n\
                           \r\n";

            let (s, r) = async_channel::bounded(1);
            let output = decode_to_string(encoded.as_bytes(), Sender::new(s)).await;
            assert_eq!(output, "MozillaDeveloperNetwork");

            let trailers = r.recv().await.unwrap();
            assert_eq!(trailers.iter().count(), 1);
            assert_eq!(trailers["Expires"], "Wed, 21 Oct 2015 07:28:00 GMT");
        });
    }

    /// A chunk size written with upper-case hex digits is parsed correctly.
    #[test]
    fn test_ff7() {
        async_std::task::block_on(async move {
            let input: Vec<u8> = [
                b"FF7\r\n".to_vec(),
                vec![b'X'; 0xFF7],
                b"\r\n4\r\n".to_vec(),
                vec![b'Y'; 4],
                b"\r\n0\r\n\r\n".to_vec(),
            ]
            .concat();

            let (s, _r) = async_channel::bounded(1);
            let output = decode_to_string(input, Sender::new(s)).await;

            let expected = format!("{}{}", "X".repeat(0xFF7), "Y".repeat(4));
            assert_eq!(output, expected);
        });
    }
}
346}