#![allow(dead_code)]
use std::fmt;
use aws_runtime::{
auth::PayloadSigningOverride,
content_encoding::{header::X_AMZ_TRAILER_SIGNATURE, header_value::AWS_CHUNKED, AwsChunkedBody, AwsChunkedBodyOptions, DeferredSigner},
};
use aws_smithy_runtime_api::{
box_error::BoxError,
client::{
interceptors::{context::BeforeTransmitInterceptorContextMut, Intercept},
runtime_components::RuntimeComponents,
runtime_plugin::RuntimePlugin,
},
http::Request,
};
use aws_smithy_types::{
body::SdkBody,
config_bag::{ConfigBag, FrozenLayer, Layer, Storable, StoreReplace},
error::operation::BuildError,
};
use http_1x::{header, HeaderValue};
use http_body_1x::Body;
/// Name of the request header that is set to the size of the original (un-encoded) body.
const X_AMZ_DECODED_CONTENT_LENGTH: &str = "x-amz-decoded-content-length";
/// Separator counted between the trailer-signature header name and its value when sizing the signed trailer.
const TRAILER_SEPARATOR: &[u8] = b":";
/// Number of bytes budgeted for the trailer-signature value.
/// NOTE(review): presumably a hex-encoded SHA-256 signature (64 characters) — confirm.
const SIGNATURE_VALUE_LENGTH: usize = 64;
/// Lower bound, in bytes, that `create_chunked_body_options` enforces on the chunk size.
const MIN_CHUNK_SIZE_BYTE: usize = 8192;
/// Chunk-size setting for aws-chunked encoded request bodies.
///
/// The value is published through the config bag and read back when the
/// chunked body options for a request are assembled.
#[derive(Debug, Clone, Copy)]
pub(crate) enum ChunkSize {
    /// Encode the body using chunks of this many bytes.
    Configured(usize),
    /// Do not split the body: the chunk size is set to the full body length.
    DisableChunking,
}
// `ChunkSize` is kept in the config bag with `StoreReplace`, so a value stored
// later takes the place of an earlier one instead of accumulating.
impl Storable for ChunkSize {
    type Storer = StoreReplace<ChunkSize>;
}
/// Runtime plugin whose only job is to carry a [`ChunkSize`] into the config bag.
#[derive(Debug)]
pub(crate) struct ChunkSizeRuntimePlugin {
    /// The setting handed out by this plugin's `config()`.
    chunk_size: ChunkSize,
}

impl ChunkSizeRuntimePlugin {
    /// Creates a plugin that will publish `chunk_size`.
    pub(crate) fn new(chunk_size: ChunkSize) -> Self {
        ChunkSizeRuntimePlugin { chunk_size }
    }
}
impl RuntimePlugin for ChunkSizeRuntimePlugin {
    /// Returns a frozen layer named `"chunk_size"` that holds this plugin's
    /// [`ChunkSize`]. Always yields `Some`.
    fn config(&self) -> Option<FrozenLayer> {
        let mut layer = Layer::new("chunk_size");
        layer.store_put(self.chunk_size);
        let frozen = layer.freeze();
        Some(frozen)
    }
}
/// Failures that can occur while preparing an aws-chunked encoded request.
#[derive(Debug)]
enum Error {
    /// The size of the request body could not be determined.
    UnsizedRequestBody,
    /// The chunk size (`actual`) is below the permitted minimum (`min`).
    ChunkSizeTooSmall { min: usize, actual: usize },
}

impl fmt::Display for Error {
    /// Writes the human-readable message for each variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::UnsizedRequestBody => {
                f.write_str("Only request bodies with a known size can be aws-chunked encoded.")
            }
            Self::ChunkSizeTooSmall { min, actual } => {
                write!(f, "Chunk size must be at least {min} bytes, but {actual} was provided.")
            }
        }
    }
}

// No underlying source error: the default trait methods are sufficient.
impl std::error::Error for Error {}
/// Interceptor that switches streaming request bodies to `aws-chunked` content encoding:
/// it rewrites the length/encoding headers before signing and wraps the body in an
/// `AwsChunkedBody` before transmit. It is a no-op for in-memory bodies and when the
/// `AwsChunkedBodyOptions` in the config bag report chunked encoding as disabled.
#[derive(Debug)]
pub(crate) struct AwsChunkedContentEncodingInterceptor;
impl Intercept for AwsChunkedContentEncodingInterceptor {
fn name(&self) -> &'static str {
"AwsChunkedContentEncodingInterceptor"
}
fn modify_before_signing(
&self,
context: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
if must_not_use_chunked_encoding(context.request(), cfg) {
tracing::debug!("short-circuiting modify_before_signing because chunked encoding must not be used");
return Ok(());
}
let original_body_size = if let Some(size) = context
.request()
.headers()
.get(header::CONTENT_LENGTH)
.and_then(|s| s.parse::<u64>().ok())
.or_else(|| context.request().body().size_hint().exact())
{
size
} else {
return Err(BuildError::other(Error::UnsizedRequestBody))?;
};
let sign_during_encoding = context.request().uri().starts_with("http:");
let chunked_body_options = create_chunked_body_options(sign_during_encoding, original_body_size, cfg).map_err(BuildError::other)?;
let request = context.request_mut();
request.headers_mut().insert(
header::HeaderName::from_static(X_AMZ_DECODED_CONTENT_LENGTH),
HeaderValue::from(original_body_size),
);
request
.headers_mut()
.insert(header::CONTENT_LENGTH, HeaderValue::from(chunked_body_options.encoded_length()));
request.headers_mut().remove(header::TRANSFER_ENCODING);
request.headers_mut().append(
header::CONTENT_ENCODING,
HeaderValue::from_str(AWS_CHUNKED)
.map_err(BuildError::other)
.expect("\"aws-chunked\" will always be a valid HeaderValue"),
);
cfg.interceptor_state().store_put(chunked_body_options);
if sign_during_encoding {
let (signer, sender) = DeferredSigner::new();
cfg.interceptor_state().store_put(signer);
cfg.interceptor_state().store_put(sender);
cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingSignedPayloadTrailer);
} else {
cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
}
Ok(())
}
fn modify_before_transmit(
&self,
ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
_runtime_components: &RuntimeComponents,
cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
if must_not_use_chunked_encoding(ctx.request(), cfg) {
tracing::debug!("short-circuiting modify_before_transmit because chunked encoding must not be used");
return Ok(());
}
let request = ctx.request_mut();
let mut body = {
let body = std::mem::replace(request.body_mut(), SdkBody::taken());
let opt = cfg
.get_mut_from_interceptor_state::<AwsChunkedBodyOptions>()
.ok_or_else(|| BuildError::other("AwsChunkedBodyOptions missing from config bag"))?;
let aws_chunked_body_options = std::mem::take(opt);
let signer = cfg
.get_mut_from_interceptor_state::<DeferredSigner>()
.map(|s| std::mem::replace(s, DeferredSigner::empty()));
body.map(move |body| {
let body = AwsChunkedBody::new(body, aws_chunked_body_options.clone());
let body = if let Some(signer) = &signer {
body.with_signer(signer.clone())
} else {
body
};
SdkBody::from_body_1_x(body)
})
};
std::mem::swap(request.body_mut(), &mut body);
Ok(())
}
}
/// Reports whether aws-chunked encoding has to be skipped for `request`.
///
/// Returns `true` when the body exposes its bytes directly (an in-memory body),
/// or when the `AwsChunkedBodyOptions` loaded from `cfg` are marked as disabled.
/// Returns `false` otherwise, including when no options are stored at all.
fn must_not_use_chunked_encoding(request: &Request, cfg: &ConfigBag) -> bool {
    if request.body().bytes().is_some() {
        return true;
    }
    cfg.load::<AwsChunkedBodyOptions>()
        .is_some_and(|options| options.disabled())
}
/// Builds the `AwsChunkedBodyOptions` used to encode a request body of
/// `original_body_size` bytes.
///
/// Any options already present in the interceptor state are moved out and used as
/// the starting point; otherwise defaults are used. A `ChunkSize` loaded from `cfg`
/// is then applied:
/// * `ChunkSize::Configured(n)` sets the chunk size to `n`.
/// * `ChunkSize::DisableChunking` sets the chunk size to the whole body length so
///   the body is not split. The minimum-chunk-size check is skipped in this case,
///   because the caller did not choose a chunk size and a small body would
///   otherwise be rejected.
///
/// When `sign_during_encoding` is set and the options already carry trailers, the
/// length of the trailer-signature line is added to the trailer lengths.
///
/// # Errors
/// Returns `Error::ChunkSizeTooSmall` when chunking is not disabled and the
/// effective chunk size is below `MIN_CHUNK_SIZE_BYTE`.
fn create_chunked_body_options(sign_during_encoding: bool, original_body_size: u64, cfg: &mut ConfigBag) -> Result<AwsChunkedBodyOptions, Error> {
    let mut chunked_body_options = cfg
        .get_mut_from_interceptor_state::<AwsChunkedBodyOptions>()
        .map(std::mem::take)
        .unwrap_or_default()
        .with_stream_length(original_body_size);

    let user_chunk_size = cfg.load::<ChunkSize>().copied();
    let chunking_disabled = matches!(user_chunk_size, Some(ChunkSize::DisableChunking));
    match user_chunk_size {
        Some(ChunkSize::Configured(size)) => {
            chunked_body_options = chunked_body_options.with_chunk_size(size);
        }
        // An empty body has nothing to split; keep the default chunk size rather
        // than configuring a zero-byte chunk.
        Some(ChunkSize::DisableChunking) if original_body_size > 0 => {
            // Saturate instead of truncating with `as` on targets where `usize`
            // is narrower than `u64`.
            let whole_body = usize::try_from(original_body_size).unwrap_or(usize::MAX);
            chunked_body_options = chunked_body_options.with_chunk_size(whole_body);
        }
        Some(ChunkSize::DisableChunking) | None => {}
    }

    // The minimum only applies when the body is actually split into chunks.
    if !chunking_disabled {
        let chunk_size = chunked_body_options.chunk_size();
        if chunk_size < MIN_CHUNK_SIZE_BYTE {
            return Err(Error::ChunkSizeTooSmall {
                min: MIN_CHUNK_SIZE_BYTE,
                actual: chunk_size,
            });
        }
    }

    let chunked_body_options = chunked_body_options.signed_chunked_encoding(sign_during_encoding);
    if sign_during_encoding && !chunked_body_options.is_trailer_empty() {
        // Account for the `<trailer-signature-name>:<signature>` line appended after the other trailers.
        let signature_trailer_len = X_AMZ_TRAILER_SIGNATURE.len() + TRAILER_SEPARATOR.len() + SIGNATURE_VALUE_LENGTH;
        Ok(chunked_body_options.with_trailer_len(signature_trailer_len as u64))
    } else {
        Ok(chunked_body_options)
    }
}
// Unit tests for the aws-chunked interceptor and its helper functions.
#[cfg(test)]
mod tests {
use super::*;
use aws_smithy_runtime_api::client::interceptors::context::{BeforeTransmitInterceptorContextMut, Input, InterceptorContext};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_types::byte_stream::ByteStream;
use bytes::BytesMut;
use http_body_util::BodyExt;
use std::io::Write;
use tempfile::NamedTempFile;
// After `modify_before_transmit` wraps a file-backed body, the body must still be
// cloneable, and the cloned body's data must end with the last file line followed
// by the terminating `0\r\n\r\n` chunk.
#[tokio::test]
async fn test_aws_chunked_body_is_retryable() {
let mut file = NamedTempFile::new().unwrap();
for i in 0..10000 {
let line = format!("This is a large file created for testing purposes {}", i);
file.as_file_mut().write_all(line.as_bytes()).unwrap();
}
let stream_length = file.as_file().metadata().unwrap().len();
let request = HttpRequest::new(ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner());
// Precondition: the file-backed body is cloneable before it is wrapped.
assert!(request.body().try_clone().is_some());
let interceptor = AwsChunkedContentEncodingInterceptor;
let mut cfg = ConfigBag::base();
cfg.interceptor_state()
.store_put(AwsChunkedBodyOptions::default().with_stream_length(stream_length));
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.enter_serialization_phase();
let _ = ctx.take_input();
ctx.set_request(request);
ctx.enter_before_transmit_phase();
let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();
let mut body = ctx.request().body().try_clone().expect("body is retryable");
// Collect every data frame of the cloned body.
let mut body_data = BytesMut::new();
while let Some(Ok(frame)) = body.frame().await {
if frame.is_data() {
let data = frame.into_data().unwrap();
body_data.extend_from_slice(&data);
}
}
let body_str = std::str::from_utf8(&body_data).unwrap();
let expected = "This is a large file created for testing purposes 9999\r\n0\r\n\r\n";
assert!(body_str.ends_with(expected), "expected '{body_str}' to end with '{expected}'");
}
// For an `http://` URI, `modify_before_signing` must store a `DeferredSigner` and
// select the signed streaming-payload-with-trailer override.
#[tokio::test]
async fn test_deferred_signer_and_payload_override_when_not_over_tls() {
let mut file = NamedTempFile::new().unwrap();
file.as_file_mut().write_all(b"test data").unwrap();
let stream_length = file.as_file().metadata().unwrap().len();
let mut request = HttpRequest::new(streaming_body(&file).await);
*request.uri_mut() = http_1x::Uri::from_static("http://example.com").into();
let interceptor = AwsChunkedContentEncodingInterceptor;
let mut cfg = ConfigBag::base();
cfg.interceptor_state()
.store_put(AwsChunkedBodyOptions::default().with_stream_length(stream_length));
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.enter_serialization_phase();
let _ = ctx.take_input();
ctx.set_request(request);
ctx.enter_before_transmit_phase();
let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
interceptor.modify_before_signing(&mut ctx, &runtime_components, &mut cfg).unwrap();
assert!(cfg.load::<DeferredSigner>().is_some());
assert!(matches!(
cfg.load::<PayloadSigningOverride>(),
Some(&PayloadSigningOverride::StreamingSignedPayloadTrailer)
));
}
// With an in-memory body, `modify_before_signing` must leave the
// `content-encoding` and `x-amz-decoded-content-length` headers unset.
#[tokio::test]
async fn test_short_circuit_modify_before_signing() {
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.enter_serialization_phase();
let _ = ctx.take_input();
let request = HttpRequest::new(SdkBody::from("in-memory body, must not use chunked encoding"));
ctx.set_request(request);
ctx.enter_before_transmit_phase();
let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut cfg = ConfigBag::base();
cfg.interceptor_state().store_put(AwsChunkedBodyOptions::default());
let interceptor = AwsChunkedContentEncodingInterceptor;
interceptor.modify_before_signing(&mut ctx, &runtime_components, &mut cfg).unwrap();
let request = ctx.request();
assert!(request.headers().get(header::CONTENT_ENCODING).is_none());
assert!(request
.headers()
.get(header::HeaderName::from_static(X_AMZ_DECODED_CONTENT_LENGTH))
.is_none());
}
// With an in-memory body, `modify_before_transmit` must leave the body bytes untouched.
#[tokio::test]
async fn test_short_circuit_modify_before_transmit() {
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.enter_serialization_phase();
let _ = ctx.take_input();
let request = HttpRequest::new(SdkBody::from("in-memory body, must not use chunked encoding"));
ctx.set_request(request);
ctx.enter_before_transmit_phase();
let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
let mut cfg = ConfigBag::base();
cfg.interceptor_state().store_put(AwsChunkedBodyOptions::default());
let interceptor = AwsChunkedContentEncodingInterceptor;
interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();
let mut body = ctx.request().body().try_clone().expect("body is retryable");
let mut body_data = BytesMut::new();
while let Some(Ok(frame)) = body.frame().await {
if frame.is_data() {
let data = frame.into_data().unwrap();
body_data.extend_from_slice(&data);
}
}
let body_str = std::str::from_utf8(&body_data).unwrap();
assert_eq!("in-memory body, must not use chunked encoding", body_str);
}
// An in-memory body alone is enough to rule out chunked encoding, even with an empty config bag.
#[test]
fn test_must_not_use_chunked_encoding_with_in_memory_body() {
let request = HttpRequest::new(SdkBody::from("test body"));
let cfg = ConfigBag::base();
assert!(must_not_use_chunked_encoding(&request, &cfg));
}
// Test helper: builds a file-backed `SdkBody` reading from `path`.
async fn streaming_body(path: impl AsRef<std::path::Path>) -> SdkBody {
let file = path.as_ref();
ByteStream::read_from().path(&file).build().await.unwrap().into_inner()
}
// Options created via `disable_chunked_encoding()` rule out chunked encoding for a file-backed body.
#[tokio::test]
async fn test_must_not_use_chunked_encoding_with_disabled_option() {
let file = NamedTempFile::new().unwrap();
let request = HttpRequest::new(streaming_body(&file).await);
let mut cfg = ConfigBag::base();
cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());
assert!(must_not_use_chunked_encoding(&request, &cfg));
}
// A file-backed body with no options in the config bag is eligible for chunked encoding.
#[tokio::test]
async fn test_chunked_encoding_is_used() {
let file = NamedTempFile::new().unwrap();
let request = HttpRequest::new(streaming_body(&file).await);
let cfg = ConfigBag::base();
assert!(!must_not_use_chunked_encoding(&request, &cfg));
}
}