1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_lite::io::{self, AsyncRead as Read};
7use futures_lite::ready;
8use http_types::trailers::{Sender, Trailers};
9
10#[derive(Debug)]
13pub struct ChunkedDecoder<R: Read> {
14 inner: R,
16 state: State,
18 chunk_size: u64,
20 trailer_sender: Option<Sender>,
22}
23
24impl<R: Read> ChunkedDecoder<R> {
25 pub(crate) fn new(inner: R, trailer_sender: Sender) -> Self {
26 ChunkedDecoder {
27 inner,
28 state: State::ChunkSize,
29 chunk_size: 0,
30 trailer_sender: Some(trailer_sender),
31 }
32 }
33}
34
/// States of the chunked-decoding state machine driven by `poll_read`.
enum State {
    /// Accumulating the hex digits of a chunk-size line until CR.
    ChunkSize,
    /// Saw the CR ending a chunk-size line; the next byte must be LF.
    ChunkSizeExpectLf,
    /// Copying chunk payload bytes out to the caller.
    ChunkBody,
    /// The chunk payload is finished; the next byte must be CR.
    ChunkBodyExpectCr,
    /// Saw the CR after a chunk payload; the next byte must be LF.
    ChunkBodyExpectLf,
    /// Buffering the trailer section: bytes filled so far, and the fixed
    /// 8192-byte buffer that holds them.
    Trailers(usize, Box<[u8; 8192]>),
    /// Waiting for the parsed trailers to be delivered.
    TrailerSending(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// Decoding has finished; further reads return 0.
    Done,
}
54
55impl fmt::Debug for State {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 State::ChunkSize => write!(f, "State::ChunkSize"),
59 State::ChunkSizeExpectLf => write!(f, "State::ChunkSizeExpectLf"),
60 State::ChunkBody => write!(f, "State::ChunkBody"),
61 State::ChunkBodyExpectCr => write!(f, "State::ChunkBodyExpectCr"),
62 State::ChunkBodyExpectLf => write!(f, "State::ChunkBodyExpectLf"),
63 State::Trailers(len, _) => write!(f, "State::Trailers({}, _)", len),
64 State::TrailerSending(_) => write!(f, "State::TrailerSending"),
65 State::Done => write!(f, "State::Done"),
66 }
67 }
68}
69
70impl<R: Read + Unpin> ChunkedDecoder<R> {
71 fn poll_read_byte(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<u8>> {
72 let mut byte = [0u8];
73 if ready!(Pin::new(&mut self.inner).poll_read(cx, &mut byte))? == 1 {
74 Poll::Ready(Ok(byte[0]))
75 } else {
76 eof()
77 }
78 }
79
80 fn expect_byte(
81 &mut self,
82 cx: &mut Context<'_>,
83 expected_byte: u8,
84 expected: &'static str,
85 ) -> Poll<io::Result<()>> {
86 let byte = ready!(self.poll_read_byte(cx))?;
87 if byte == expected_byte {
88 Poll::Ready(Ok(()))
89 } else {
90 unexpected(byte, expected)
91 }
92 }
93
94 fn send_trailers(&mut self, trailers: Trailers) {
95 let sender = self
96 .trailer_sender
97 .take()
98 .expect("invalid chunked state, tried sending multiple trailers");
99 let fut = Box::pin(sender.send(trailers));
100 self.state = State::TrailerSending(fut);
101 }
102}
103
/// Returns a ready `UnexpectedEof` error for input that ends before the
/// chunked framing is complete.
fn eof<T>() -> Poll<io::Result<T>> {
    let error = io::Error::new(
        io::ErrorKind::UnexpectedEof,
        "Unexpected EOF when decoding chunked data",
    );
    Poll::Ready(Err(error))
}
110
/// Returns a ready `InvalidData` error for a byte that violates the chunked
/// framing.
///
/// `expected` describes what should have appeared instead. The offending byte
/// is shown both in hex and as an ASCII-escaped character, so letters and
/// control bytes such as LF are readable in the message. A raw decimal value
/// like `71` is not.
fn unexpected<T>(byte: u8, expected: &'static str) -> Poll<io::Result<T>> {
    Poll::Ready(Err(io::Error::new(
        io::ErrorKind::InvalidData,
        format!(
            "Unexpected byte {:#04x} ('{}'); expected {}",
            byte,
            std::ascii::escape_default(byte),
            expected
        ),
    )))
}
117
/// Builds the `InvalidData` error used when a chunk-size line does not fit in
/// a `u64`.
fn overflow() -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, "Chunk size overflowed 64 bits")
}
121
122impl<R: Read + Unpin> Read for ChunkedDecoder<R> {
123 #[allow(missing_doc_code_examples)]
124 fn poll_read(
125 mut self: Pin<&mut Self>,
126 cx: &mut Context<'_>,
127 buf: &mut [u8],
128 ) -> Poll<io::Result<usize>> {
129 let this = &mut *self;
130
131 loop {
132 match this.state {
133 State::ChunkSize => {
134 let byte = ready!(this.poll_read_byte(cx))?;
135 let digit = match byte {
136 b'0'..=b'9' => byte - b'0',
137 b'a'..=b'f' => 10 + byte - b'a',
138 b'A'..=b'F' => 10 + byte - b'A',
139 b'\r' => {
140 this.state = State::ChunkSizeExpectLf;
141 continue;
142 }
143 _ => {
144 return unexpected(byte, "hex digit or CR");
145 }
146 };
147 this.chunk_size = this
148 .chunk_size
149 .checked_mul(16)
150 .ok_or_else(overflow)?
151 .checked_add(digit as u64)
152 .ok_or_else(overflow)?;
153 }
154 State::ChunkSizeExpectLf => {
155 ready!(this.expect_byte(cx, b'\n', "LF"))?;
156 if this.chunk_size == 0 {
157 this.state = State::Trailers(0, Box::new([0u8; 8192]));
158 } else {
159 this.state = State::ChunkBody;
160 }
161 }
162 State::ChunkBody => {
163 let max_bytes = std::cmp::min(
164 buf.len(),
165 std::cmp::min(this.chunk_size, usize::MAX as u64) as usize,
166 );
167 let bytes_read =
168 ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf[..max_bytes]))?;
169 this.chunk_size -= bytes_read as u64;
170 if bytes_read == 0 {
171 return eof();
172 } else if this.chunk_size == 0 {
173 this.state = State::ChunkBodyExpectCr;
174 }
175 return Poll::Ready(Ok(bytes_read));
176 }
177 State::ChunkBodyExpectCr => {
178 ready!(this.expect_byte(cx, b'\r', "CR"))?;
179 this.state = State::ChunkBodyExpectLf;
180 }
181 State::ChunkBodyExpectLf => {
182 ready!(this.expect_byte(cx, b'\n', "LF"))?;
183 this.state = State::ChunkSize;
184 }
185 State::Trailers(ref mut len, ref mut buf) => {
186 let bytes_read =
187 ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf[*len..]))?;
188 *len += bytes_read;
189 let len = *len;
190 if len == 0 {
191 this.send_trailers(Trailers::new());
192 continue;
193 }
194 if bytes_read == 0 {
195 return eof();
196 }
197 let mut headers = [httparse::EMPTY_HEADER; 16];
198 let parse_result = httparse::parse_headers(&buf[..len], &mut headers)
199 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
200 use httparse::Status;
201 match parse_result {
202 Status::Partial => {
203 if len == buf.len() {
204 return eof();
205 } else {
206 return Poll::Pending;
207 }
208 }
209 Status::Complete((offset, headers)) => {
210 if offset != len {
211 return unexpected(buf[offset], "end of trailers");
212 }
213 let mut trailers = Trailers::new();
214 for header in headers {
215 trailers.insert(
216 header.name,
217 String::from_utf8_lossy(header.value).as_ref(),
218 );
219 }
220 this.send_trailers(trailers);
221 }
222 }
223 }
224 State::TrailerSending(ref mut fut) => {
225 ready!(Pin::new(fut).poll(cx));
226 this.state = State::Done;
227 }
228 State::Done => return Poll::Ready(Ok(0)),
229 }
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use async_std::prelude::*;

    /// Runs a `ChunkedDecoder` over `input` until it finishes.
    ///
    /// Returns the decoded body (or the first decoding error) together with the
    /// trailer receiver. The receiver is kept alive for the whole decode and is
    /// returned so tests can inspect any trailers afterwards.
    async fn decode(
        input: impl Into<Vec<u8>>,
    ) -> (io::Result<String>, async_channel::Receiver<Trailers>) {
        let (s, r) = async_channel::bounded(1);
        let reader = async_std::io::Cursor::new(input.into());
        let mut decoder = ChunkedDecoder::new(reader, Sender::new(s));
        let mut output = String::new();
        let result = decoder.read_to_string(&mut output).await;
        (result.map(|_| output), r)
    }

    #[test]
    fn test_chunked_wiki() {
        async_std::task::block_on(async move {
            let (output, _r) = decode(
                "4\r\n\
                 Wiki\r\n\
                 5\r\n\
                 pedia\r\n\
                 E\r\n in\r\n\
                 \r\n\
                 chunks.\r\n\
                 0\r\n\
                 \r\n",
            )
            .await;
            assert_eq!(output.unwrap(), "Wikipedia in\r\n\r\nchunks.");
        });
    }

    #[test]
    fn test_chunked_big() {
        async_std::task::block_on(async move {
            let mut input: Vec<u8> = b"800\r\n".to_vec();
            input.extend(vec![b'X'; 2048]);
            input.extend(b"\r\n1800\r\n");
            input.extend(vec![b'Y'; 6144]);
            input.extend(b"\r\n800\r\n");
            input.extend(vec![b'Z'; 2048]);
            input.extend(b"\r\n0\r\n\r\n");

            let (output, _r) = decode(input).await;
            let output = output.unwrap();

            let mut expected = vec![b'X'; 2048];
            expected.extend(vec![b'Y'; 6144]);
            expected.extend(vec![b'Z'; 2048]);
            assert_eq!(output.len(), 10240);
            assert_eq!(output.as_bytes(), expected.as_slice());
        });
    }

    #[test]
    fn test_chunked_mdn() {
        async_std::task::block_on(async move {
            let (output, r) = decode(
                "7\r\n\
                 Mozilla\r\n\
                 9\r\n\
                 Developer\r\n\
                 7\r\n\
                 Network\r\n\
                 0\r\n\
                 Expires: Wed, 21 Oct 2015 07:28:00 GMT\r\n\
                 \r\n",
            )
            .await;
            assert_eq!(output.unwrap(), "MozillaDeveloperNetwork");

            let trailers = r.recv().await.unwrap();
            assert_eq!(trailers.iter().count(), 1);
            assert_eq!(trailers["Expires"], "Wed, 21 Oct 2015 07:28:00 GMT");
        });
    }

    #[test]
    fn test_ff7() {
        async_std::task::block_on(async move {
            let mut input: Vec<u8> = b"FF7\r\n".to_vec();
            input.extend(vec![b'X'; 0xFF7]);
            input.extend(b"\r\n4\r\n");
            input.extend(vec![b'Y'; 4]);
            input.extend(b"\r\n0\r\n\r\n");

            let (output, _r) = decode(input).await;
            assert_eq!(
                output.unwrap(),
                "X".to_string().repeat(0xFF7) + &"Y".to_string().repeat(4)
            );
        });
    }

    #[test]
    fn test_invalid_chunk_size() {
        async_std::task::block_on(async move {
            let (output, _r) = decode("G\r\n").await;
            assert_eq!(output.unwrap_err().kind(), io::ErrorKind::InvalidData);
        });
    }

    #[test]
    fn test_chunk_size_overflow() {
        async_std::task::block_on(async move {
            // 17 hex digits cannot fit in a u64.
            let (output, _r) = decode(format!("1{}\r\n", "0".repeat(16))).await;
            assert_eq!(output.unwrap_err().kind(), io::ErrorKind::InvalidData);
        });
    }

    #[test]
    fn test_eof_in_chunk_body() {
        async_std::task::block_on(async move {
            let (output, _r) = decode("5\r\nabc").await;
            assert_eq!(output.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    #[test]
    fn test_missing_crlf_after_chunk() {
        async_std::task::block_on(async move {
            let (output, _r) = decode("3\r\nabcX").await;
            assert_eq!(output.unwrap_err().kind(), io::ErrorKind::InvalidData);
        });
    }
}
346}