Skip to main content

aws_runtime/content_encoding/body/
http_body_0_x.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use bytes::Bytes;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use crate::content_encoding::body::{AwsChunkedBody, AwsChunkedBodyError, AwsChunkedBodyState};
11use crate::content_encoding::{CHUNK_TERMINATOR, CRLF, TRAILER_SEPARATOR};
12
13impl<Inner> http_body_04x::Body for AwsChunkedBody<Inner>
14where
15    Inner: http_body_04x::Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
16{
17    type Data = Bytes;
18    type Error = aws_smithy_types::body::Error;
19
20    fn poll_data(
21        self: Pin<&mut Self>,
22        cx: &mut Context<'_>,
23    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
24        tracing::trace!(state = ?self.state, "polling AwsChunkedBody");
25        let mut this = self.project();
26
27        use AwsChunkedBodyState::*;
28        match *this.state {
29            WritingChunk => {
30                if this.options.stream_length == 0 {
31                    // If the stream is empty, we skip to writing trailers after writing the CHUNK_TERMINATOR.
32                    *this.state = WritingTrailers;
33                    tracing::trace!("stream is empty, writing chunk terminator");
34                    Poll::Ready(Some(Ok(Bytes::from([CHUNK_TERMINATOR].concat()))))
35                } else {
36                    *this.state = WritingChunkData;
37                    // A chunk must be prefixed by chunk size in hexadecimal
38                    let chunk_size = format!("{:X?}{CRLF}", this.options.stream_length);
39                    tracing::trace!(%chunk_size, "writing chunk size");
40                    let chunk_size = Bytes::from(chunk_size);
41                    Poll::Ready(Some(Ok(chunk_size)))
42                }
43            }
44            WritingChunkData => match this.inner.poll_data(cx) {
45                Poll::Ready(Some(Ok(data))) => {
46                    tracing::trace!(len = data.len(), "writing chunk data");
47                    *this.inner_body_bytes_read_so_far += data.len();
48                    Poll::Ready(Some(Ok(data)))
49                }
50                Poll::Ready(None) => {
51                    let actual_stream_length = *this.inner_body_bytes_read_so_far as u64;
52                    let expected_stream_length = this.options.stream_length;
53                    if actual_stream_length != expected_stream_length {
54                        let err = Box::new(AwsChunkedBodyError::StreamLengthMismatch {
55                            actual: actual_stream_length,
56                            expected: expected_stream_length,
57                        });
58                        return Poll::Ready(Some(Err(err)));
59                    };
60
61                    tracing::trace!("no more chunk data, writing CRLF and chunk terminator");
62                    *this.state = WritingTrailers;
63                    // Since we wrote chunk data, we end it with a CRLF and since we only write
64                    // a single chunk, we write the CHUNK_TERMINATOR immediately after
65                    Poll::Ready(Some(Ok(Bytes::from([CRLF, CHUNK_TERMINATOR].concat()))))
66                }
67                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
68                Poll::Pending => Poll::Pending,
69            },
70            WritingTrailers => {
71                return match this.inner.poll_trailers(cx) {
72                    Poll::Ready(Ok(trailers)) => {
73                        *this.state = Closed;
74                        let expected_length = total_rendered_length_of_trailers(trailers.as_ref());
75                        let actual_length = this.options.total_trailer_length();
76
77                        if expected_length != actual_length {
78                            let err = AwsChunkedBodyError::ReportedTrailerLengthMismatch {
79                                actual: actual_length,
80                                expected: expected_length,
81                            };
82                            return Poll::Ready(Some(Err(err.into())));
83                        }
84
85                        let mut trailers =
86                            trailers_as_aws_chunked_bytes(trailers, actual_length + 1);
87                        // Insert the final CRLF to close the body
88                        trailers.extend_from_slice(CRLF.as_bytes());
89
90                        Poll::Ready(Some(Ok(trailers.into())))
91                    }
92                    Poll::Pending => Poll::Pending,
93                    Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
94                };
95            }
96            Closed => Poll::Ready(None),
97            ref otherwise => {
98                unreachable!(
99                    "invalid state {otherwise:?} for `poll_data` in http-02x; this is a bug"
100                )
101            }
102        }
103    }
104
105    fn poll_trailers(
106        self: Pin<&mut Self>,
107        _cx: &mut Context<'_>,
108    ) -> Poll<Result<Option<http_02x::HeaderMap<http_02x::HeaderValue>>, Self::Error>> {
109        // Trailers were already appended to the body because of the content encoding scheme
110        Poll::Ready(Ok(None))
111    }
112
113    fn is_end_stream(&self) -> bool {
114        self.state == AwsChunkedBodyState::Closed
115    }
116
117    fn size_hint(&self) -> http_body_04x::SizeHint {
118        http_body_04x::SizeHint::with_exact(self.options.encoded_length())
119    }
120}
121
122/// Writes trailers out into a `string` and then converts that `String` to a `Bytes` before
123/// returning.
124///
125/// - Trailer names are separated by a single colon only, no space.
126/// - Trailer names with multiple values will be written out one line per value, with the name
127///   appearing on each line.
128fn trailers_as_aws_chunked_bytes(
129    trailer_map: Option<http_02x::HeaderMap>,
130    estimated_length: u64,
131) -> bytes::BytesMut {
132    if let Some(trailer_map) = trailer_map {
133        let mut current_header_name = None;
134        let mut trailers =
135            bytes::BytesMut::with_capacity(estimated_length.try_into().unwrap_or_default());
136
137        for (header_name, header_value) in trailer_map.into_iter() {
138            // When a header has multiple values, the name only comes up in iteration the first time
139            // we see it. Therefore, we need to keep track of the last name we saw and fall back to
140            // it when `header_name == None`.
141            current_header_name = header_name.or(current_header_name);
142
143            // In practice, this will always exist, but `if let` is nicer than unwrap
144            if let Some(header_name) = current_header_name.as_ref() {
145                trailers.extend_from_slice(header_name.as_ref());
146                trailers.extend_from_slice(TRAILER_SEPARATOR);
147                trailers.extend_from_slice(header_value.as_bytes());
148                trailers.extend_from_slice(CRLF.as_bytes());
149            }
150        }
151
152        trailers
153    } else {
154        bytes::BytesMut::new()
155    }
156}
157
158/// Given an optional `HeaderMap`, calculate the total number of bytes required to represent the
159/// `HeaderMap`. If no `HeaderMap` is given as input, return 0.
160///
161/// - Trailer names are separated by a single colon only, no space.
162/// - Trailer names with multiple values will be written out one line per value, with the name
163///   appearing on each line.
164fn total_rendered_length_of_trailers(trailer_map: Option<&http_02x::HeaderMap>) -> u64 {
165    match trailer_map {
166        Some(trailer_map) => trailer_map
167            .iter()
168            .map(|(trailer_name, trailer_value)| {
169                trailer_name.as_str().len()
170                    + TRAILER_SEPARATOR.len()
171                    + trailer_value.len()
172                    + CRLF.len()
173            })
174            .sum::<usize>() as u64,
175        None => 0,
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::{total_rendered_length_of_trailers, trailers_as_aws_chunked_bytes};
    use crate::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions, CHUNK_TERMINATOR, CRLF};

    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_02x::{HeaderMap, HeaderValue};
    use http_body_04x::{Body, SizeHint};
    use pin_project_lite::pin_project;

    use std::io::Read;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    // A test body that hands out `parts` one slot per poll. A `Some` slot is returned as data
    // right away; a `None` slot makes the poll return `Pending` and schedules a wake-up after
    // `delay_in_millis` milliseconds.
    pin_project! {
        struct SputteringBody {
            parts: Vec<Option<Bytes>>,
            cursor: usize,
            delay_in_millis: u64,
        }
    }

    impl SputteringBody {
        /// Total number of data bytes held in the `Some` slots.
        fn len(&self) -> usize {
            self.parts.iter().flatten().map(Bytes::len).sum()
        }
    }

    impl Body for SputteringBody {
        type Data = Bytes;
        type Error = aws_smithy_types::body::Error;

        fn poll_data(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            let this = self.project();

            // Every slot has been visited, so the stream is over.
            if *this.cursor == this.parts.len() {
                return Poll::Ready(None);
            }

            let slot = this.parts.get_mut(*this.cursor).unwrap().take();
            *this.cursor += 1;

            match slot {
                Some(data) => Poll::Ready(Some(Ok(data))),
                None => {
                    // Simulate a stall: wake the task up again once the delay has elapsed.
                    let delay = Duration::from_millis(*this.delay_in_millis);
                    let waker = cx.waker().clone();
                    tokio::spawn(async move {
                        tokio::time::sleep(delay).await;
                        waker.wake();
                    });
                    Poll::Pending
                }
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
            Poll::Ready(Ok(None))
        }

        fn is_end_stream(&self) -> bool {
            false
        }

        fn size_hint(&self) -> SizeHint {
            SizeHint::new()
        }
    }

    /// Reads `body` until it stops yielding data and returns everything it produced as a
    /// `String`. Panics if the body yields an error.
    async fn collect_body_output<B>(body: &mut B) -> String
    where
        B: Body<Data = Bytes> + Unpin,
        B::Error: std::fmt::Debug,
    {
        let mut segments = SegmentedBuf::new();
        while let Some(chunk) = body.data().await {
            segments.push(chunk.unwrap());
        }

        let mut collected = String::new();
        segments
            .reader()
            .read_to_string(&mut collected)
            .expect("Doesn't cause IO errors");
        collected
    }

    /// Asserts that polling `body` for trailers succeeds and yields no trailers.
    async fn assert_no_http_trailers<B>(body: &mut B)
    where
        B: Body + Unpin,
        B::Error: std::fmt::Debug,
    {
        let trailers = body
            .trailers()
            .await
            .expect("no errors occurred during trailer polling");
        assert!(
            trailers.is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    /// Asserts that the length computed for `trailers` equals the length of their rendering.
    fn assert_computed_length_matches_rendering(trailers: HeaderMap) {
        let trailers = Some(trailers);
        let computed_length = total_rendered_length_of_trailers(trailers.as_ref());
        let rendered_length =
            trailers_as_aws_chunked_bytes(trailers, computed_length).len() as u64;

        assert_eq!(rendered_length, computed_length);
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding() {
        let scenario = async {
            let payload = "Hello world";
            let options = AwsChunkedBodyOptions::new(payload.len() as u64, Vec::new());
            let mut encoded = AwsChunkedBody::new(SdkBody::from(payload), options);

            let actual_output = collect_body_output(&mut encoded).await;
            let expected_output = "B\r\nHello world\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert_no_http_trailers(&mut encoded).await;

            // You can insert a `tokio::time::sleep` here to verify the timeout works as intended
        };

        let timeout_duration = Duration::from_secs(3);
        let outcome = tokio::time::timeout(timeout_duration, scenario).await;
        assert!(
            outcome.is_ok(),
            "test_aws_chunked_encoding timed out after {timeout_duration:?}"
        );
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_sputtering_body() {
        let scenario = async {
            let stalling_input = SputteringBody {
                parts: vec![
                    Some(Bytes::from_static(b"chunk 1, ")),
                    None,
                    Some(Bytes::from_static(b"chunk 2, ")),
                    Some(Bytes::from_static(b"chunk 3, ")),
                    None,
                    None,
                    Some(Bytes::from_static(b"chunk 4, ")),
                    Some(Bytes::from_static(b"chunk 5, ")),
                    Some(Bytes::from_static(b"chunk 6")),
                ],
                cursor: 0,
                delay_in_millis: 500,
            };
            let options = AwsChunkedBodyOptions::new(stalling_input.len() as u64, Vec::new());
            let mut encoded = AwsChunkedBody::new(stalling_input, options);

            let actual_output = collect_body_output(&mut encoded).await;
            let expected_output =
                "34\r\nchunk 1, chunk 2, chunk 3, chunk 4, chunk 5, chunk 6\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert_no_http_trailers(&mut encoded).await;
        };

        let timeout_duration = Duration::from_secs(3);
        let outcome = tokio::time::timeout(timeout_duration, scenario).await;
        assert!(
            outcome.is_ok(),
            "test_aws_chunked_encoding_sputtering_body timed out after {timeout_duration:?}"
        );
    }

    #[tokio::test]
    #[should_panic = "called `Result::unwrap()` on an `Err` value: ReportedTrailerLengthMismatch { actual: 44, expected: 0 }"]
    async fn test_aws_chunked_encoding_incorrect_trailer_length_panic() {
        let payload = "Hello world";
        // The body carries no trailers, so announcing a trailer length is wrong and reading the
        // body is expected to fail. The error reports 44 rather than 42 because, with
        // aws-chunked encoding, each trailer ends with a CRLF, which is 2 bytes long.
        let wrong_trailer_len = 42;
        let options = AwsChunkedBodyOptions::new(payload.len() as u64, vec![wrong_trailer_len]);
        let mut encoded = AwsChunkedBody::new(SdkBody::from(payload), options);

        // The contents don't matter here, but the body has to be read in full before the
        // trailers can be checked.
        while let Some(chunk) = encoded.data().await {
            let _discarded = chunk.unwrap();
        }

        assert_no_http_trailers(&mut encoded).await;
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_empty_body() {
        let payload = "";
        let options = AwsChunkedBodyOptions::new(payload.len() as u64, Vec::new());
        let mut encoded = AwsChunkedBody::new(SdkBody::from(payload), options);

        let actual_output = collect_body_output(&mut encoded).await;
        let expected_output = [CHUNK_TERMINATOR, CRLF].concat();

        assert_eq!(expected_output, actual_output);
        assert_no_http_trailers(&mut encoded).await;
    }

    #[tokio::test]
    async fn test_total_rendered_length_of_trailers() {
        let mut trailers = HeaderMap::new();

        trailers.insert("empty_value", HeaderValue::from_static(""));

        trailers.insert("single_value", HeaderValue::from_static("value 1"));

        trailers.insert("two_values", HeaderValue::from_static("value 1"));
        trailers.append("two_values", HeaderValue::from_static("value 2"));

        trailers.insert("three_values", HeaderValue::from_static("value 1"));
        trailers.append("three_values", HeaderValue::from_static("value 2"));
        trailers.append("three_values", HeaderValue::from_static("value 3"));

        assert_computed_length_matches_rendering(trailers);
    }

    #[tokio::test]
    async fn test_total_rendered_length_of_empty_trailers() {
        assert_computed_length_matches_rendering(HeaderMap::new());
    }
}