aws_runtime/content_encoding/body/
http_body_0_x.rs1use bytes::Bytes;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use crate::content_encoding::body::{AwsChunkedBody, AwsChunkedBodyError, AwsChunkedBodyState};
11use crate::content_encoding::{CHUNK_TERMINATOR, CRLF, TRAILER_SEPARATOR};
12
13impl<Inner> http_body_04x::Body for AwsChunkedBody<Inner>
14where
15 Inner: http_body_04x::Body<Data = Bytes, Error = aws_smithy_types::body::Error>,
16{
17 type Data = Bytes;
18 type Error = aws_smithy_types::body::Error;
19
20 fn poll_data(
21 self: Pin<&mut Self>,
22 cx: &mut Context<'_>,
23 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
24 tracing::trace!(state = ?self.state, "polling AwsChunkedBody");
25 let mut this = self.project();
26
27 use AwsChunkedBodyState::*;
28 match *this.state {
29 WritingChunk => {
30 if this.options.stream_length == 0 {
31 *this.state = WritingTrailers;
33 tracing::trace!("stream is empty, writing chunk terminator");
34 Poll::Ready(Some(Ok(Bytes::from([CHUNK_TERMINATOR].concat()))))
35 } else {
36 *this.state = WritingChunkData;
37 let chunk_size = format!("{:X?}{CRLF}", this.options.stream_length);
39 tracing::trace!(%chunk_size, "writing chunk size");
40 let chunk_size = Bytes::from(chunk_size);
41 Poll::Ready(Some(Ok(chunk_size)))
42 }
43 }
44 WritingChunkData => match this.inner.poll_data(cx) {
45 Poll::Ready(Some(Ok(data))) => {
46 tracing::trace!(len = data.len(), "writing chunk data");
47 *this.inner_body_bytes_read_so_far += data.len();
48 Poll::Ready(Some(Ok(data)))
49 }
50 Poll::Ready(None) => {
51 let actual_stream_length = *this.inner_body_bytes_read_so_far as u64;
52 let expected_stream_length = this.options.stream_length;
53 if actual_stream_length != expected_stream_length {
54 let err = Box::new(AwsChunkedBodyError::StreamLengthMismatch {
55 actual: actual_stream_length,
56 expected: expected_stream_length,
57 });
58 return Poll::Ready(Some(Err(err)));
59 };
60
61 tracing::trace!("no more chunk data, writing CRLF and chunk terminator");
62 *this.state = WritingTrailers;
63 Poll::Ready(Some(Ok(Bytes::from([CRLF, CHUNK_TERMINATOR].concat()))))
66 }
67 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
68 Poll::Pending => Poll::Pending,
69 },
70 WritingTrailers => {
71 return match this.inner.poll_trailers(cx) {
72 Poll::Ready(Ok(trailers)) => {
73 *this.state = Closed;
74 let expected_length = total_rendered_length_of_trailers(trailers.as_ref());
75 let actual_length = this.options.total_trailer_length();
76
77 if expected_length != actual_length {
78 let err = AwsChunkedBodyError::ReportedTrailerLengthMismatch {
79 actual: actual_length,
80 expected: expected_length,
81 };
82 return Poll::Ready(Some(Err(err.into())));
83 }
84
85 let mut trailers =
86 trailers_as_aws_chunked_bytes(trailers, actual_length + 1);
87 trailers.extend_from_slice(CRLF.as_bytes());
89
90 Poll::Ready(Some(Ok(trailers.into())))
91 }
92 Poll::Pending => Poll::Pending,
93 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
94 };
95 }
96 Closed => Poll::Ready(None),
97 ref otherwise => {
98 unreachable!(
99 "invalid state {otherwise:?} for `poll_data` in http-02x; this is a bug"
100 )
101 }
102 }
103 }
104
105 fn poll_trailers(
106 self: Pin<&mut Self>,
107 _cx: &mut Context<'_>,
108 ) -> Poll<Result<Option<http_02x::HeaderMap<http_02x::HeaderValue>>, Self::Error>> {
109 Poll::Ready(Ok(None))
111 }
112
113 fn is_end_stream(&self) -> bool {
114 self.state == AwsChunkedBodyState::Closed
115 }
116
117 fn size_hint(&self) -> http_body_04x::SizeHint {
118 http_body_04x::SizeHint::with_exact(self.options.encoded_length())
119 }
120}
121
122fn trailers_as_aws_chunked_bytes(
129 trailer_map: Option<http_02x::HeaderMap>,
130 estimated_length: u64,
131) -> bytes::BytesMut {
132 if let Some(trailer_map) = trailer_map {
133 let mut current_header_name = None;
134 let mut trailers =
135 bytes::BytesMut::with_capacity(estimated_length.try_into().unwrap_or_default());
136
137 for (header_name, header_value) in trailer_map.into_iter() {
138 current_header_name = header_name.or(current_header_name);
142
143 if let Some(header_name) = current_header_name.as_ref() {
145 trailers.extend_from_slice(header_name.as_ref());
146 trailers.extend_from_slice(TRAILER_SEPARATOR);
147 trailers.extend_from_slice(header_value.as_bytes());
148 trailers.extend_from_slice(CRLF.as_bytes());
149 }
150 }
151
152 trailers
153 } else {
154 bytes::BytesMut::new()
155 }
156}
157
158fn total_rendered_length_of_trailers(trailer_map: Option<&http_02x::HeaderMap>) -> u64 {
165 match trailer_map {
166 Some(trailer_map) => trailer_map
167 .iter()
168 .map(|(trailer_name, trailer_value)| {
169 trailer_name.as_str().len()
170 + TRAILER_SEPARATOR.len()
171 + trailer_value.len()
172 + CRLF.len()
173 })
174 .sum::<usize>() as u64,
175 None => 0,
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::{total_rendered_length_of_trailers, trailers_as_aws_chunked_bytes};
    use crate::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions, CHUNK_TERMINATOR, CRLF};

    use aws_smithy_types::body::SdkBody;
    use bytes::{Buf, Bytes};
    use bytes_utils::SegmentedBuf;
    use http_02x::{HeaderMap, HeaderValue};
    use http_body_04x::{Body, SizeHint};
    use pin_project_lite::pin_project;

    use std::fmt::Debug;
    use std::io::Read;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::time::Duration;

    // A body that yields `parts` in order. A `None` part makes the body return `Poll::Pending`
    // once and schedules a wake-up after `delay_in_millis`, simulating a slow producer.
    pin_project! {
        struct SputteringBody {
            parts: Vec<Option<Bytes>>,
            cursor: usize,
            delay_in_millis: u64,
        }
    }

    impl SputteringBody {
        /// Total number of data bytes across all parts.
        fn len(&self) -> usize {
            self.parts.iter().flatten().map(|b| b.len()).sum()
        }
    }

    impl Body for SputteringBody {
        type Data = Bytes;
        type Error = aws_smithy_types::body::Error;

        fn poll_data(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            if self.cursor == self.parts.len() {
                return Poll::Ready(None);
            }

            let this = self.project();
            let delay_in_millis = *this.delay_in_millis;
            let next_part = this.parts.get_mut(*this.cursor).unwrap().take();

            match next_part {
                None => {
                    *this.cursor += 1;
                    let waker = cx.waker().clone();
                    tokio::spawn(async move {
                        tokio::time::sleep(Duration::from_millis(delay_in_millis)).await;
                        waker.wake();
                    });
                    Poll::Pending
                }
                Some(data) => {
                    *this.cursor += 1;
                    Poll::Ready(Some(Ok(data)))
                }
            }
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap<HeaderValue>>, Self::Error>> {
            Poll::Ready(Ok(None))
        }

        fn is_end_stream(&self) -> bool {
            false
        }

        fn size_hint(&self) -> SizeHint {
            SizeHint::new()
        }
    }

    /// Polls `body` until it is exhausted and returns all of its data as a string.
    ///
    /// Panics (via `unwrap`) on the first error the body yields.
    async fn collect_body_as_string<B>(body: &mut B) -> String
    where
        B: Body<Data = Bytes> + Unpin,
        B::Error: Debug,
    {
        let mut output = SegmentedBuf::new();
        while let Some(buf) = body.data().await {
            output.push(buf.unwrap());
        }

        let mut rendered = String::new();
        output
            .reader()
            .read_to_string(&mut rendered)
            .expect("Doesn't cause IO errors");
        rendered
    }

    /// Asserts that `body` exposes no HTTP trailers; aws-chunked trailers live inside the body.
    async fn assert_no_http_trailers<B>(body: &mut B)
    where
        B: Body + Unpin,
        B::Error: Debug,
    {
        assert!(
            body.trailers()
                .await
                .expect("no errors occurred during trailer polling")
                .is_none(),
            "aws-chunked encoded bodies don't have normal HTTP trailers"
        );
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding() {
        let test_fut = async {
            let input_str = "Hello world";
            let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

            let actual_output = collect_body_as_string(&mut body).await;
            let expected_output = "B\r\nHello world\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert_no_http_trailers(&mut body).await;
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!("test_aws_chunked_encoding timed out after {timeout_duration:?}");
        }
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_sputtering_body() {
        let test_fut = async {
            let input = SputteringBody {
                parts: vec![
                    Some(Bytes::from_static(b"chunk 1, ")),
                    None,
                    Some(Bytes::from_static(b"chunk 2, ")),
                    Some(Bytes::from_static(b"chunk 3, ")),
                    None,
                    None,
                    Some(Bytes::from_static(b"chunk 4, ")),
                    Some(Bytes::from_static(b"chunk 5, ")),
                    Some(Bytes::from_static(b"chunk 6")),
                ],
                cursor: 0,
                delay_in_millis: 500,
            };
            let opts = AwsChunkedBodyOptions::new(input.len() as u64, Vec::new());
            let mut body = AwsChunkedBody::new(input, opts);

            let actual_output = collect_body_as_string(&mut body).await;
            let expected_output =
                "34\r\nchunk 1, chunk 2, chunk 3, chunk 4, chunk 5, chunk 6\r\n0\r\n\r\n";

            assert_eq!(expected_output, actual_output);
            assert_no_http_trailers(&mut body).await;
        };

        let timeout_duration = Duration::from_secs(3);
        if tokio::time::timeout(timeout_duration, test_fut)
            .await
            .is_err()
        {
            panic!(
                "test_aws_chunked_encoding_sputtering_body timed out after {timeout_duration:?}"
            );
        }
    }

    #[tokio::test]
    #[should_panic = "called `Result::unwrap()` on an `Err` value: ReportedTrailerLengthMismatch { actual: 44, expected: 0 }"]
    async fn test_aws_chunked_encoding_incorrect_trailer_length_panic() {
        let input_str = "Hello world";
        // The inner body has no trailers, so any non-zero reported trailer length is wrong.
        let wrong_trailer_len = 42;
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, vec![wrong_trailer_len]);
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        // Panics when the trailer-length mismatch error is unwrapped.
        collect_body_as_string(&mut body).await;

        assert_no_http_trailers(&mut body).await;
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_stream_length_mismatch() {
        let input_str = "Hello world";
        // Deliberately declare fewer bytes than the inner body actually contains.
        let opts = AwsChunkedBodyOptions::new(5, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let mut output = Vec::new();
        let err = loop {
            match body.data().await {
                Some(Ok(buf)) => output.extend_from_slice(&buf),
                Some(Err(err)) => break err,
                None => panic!("expected a stream length mismatch error"),
            }
        };

        assert_eq!(&b"5\r\nHello world"[..], output.as_slice());
        assert!(
            format!("{err:?}").contains("StreamLengthMismatch"),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_aws_chunked_encoding_empty_body() {
        let input_str = "";
        let opts = AwsChunkedBodyOptions::new(input_str.len() as u64, Vec::new());
        let mut body = AwsChunkedBody::new(SdkBody::from(input_str), opts);

        let actual_output = collect_body_as_string(&mut body).await;
        let expected_output = [CHUNK_TERMINATOR, CRLF].concat();

        assert_eq!(expected_output, actual_output);
        assert_no_http_trailers(&mut body).await;
    }

    // The length helpers are synchronous, so these tests need no async runtime.
    #[test]
    fn test_total_rendered_length_of_trailers() {
        let mut trailers = HeaderMap::new();

        trailers.insert("empty_value", HeaderValue::from_static(""));

        trailers.insert("single_value", HeaderValue::from_static("value 1"));

        trailers.insert("two_values", HeaderValue::from_static("value 1"));
        trailers.append("two_values", HeaderValue::from_static("value 2"));

        trailers.insert("three_values", HeaderValue::from_static("value 1"));
        trailers.append("three_values", HeaderValue::from_static("value 2"));
        trailers.append("three_values", HeaderValue::from_static("value 3"));

        let trailers = Some(trailers);
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = trailers_as_aws_chunked_bytes(trailers, actual_length).len() as u64;

        assert_eq!(expected_length, actual_length);
    }

    #[test]
    fn test_total_rendered_length_of_empty_trailers() {
        let trailers = Some(HeaderMap::new());
        let actual_length = total_rendered_length_of_trailers(trailers.as_ref());
        let expected_length = trailers_as_aws_chunked_bytes(trailers, actual_length).len() as u64;

        assert_eq!(expected_length, actual_length);
    }
}