#![allow(dead_code)]
use crate::presigning::PresigningMarker;
use aws_runtime::content_encoding::AwsChunkedBodyOptions;
use aws_smithy_checksums::body::calculate;
use aws_smithy_checksums::body::ChecksumCache;
use aws_smithy_checksums::http::HttpChecksum;
use aws_smithy_checksums::ChecksumAlgorithm;
use aws_smithy_runtime::client::sdk_feature::SmithySdkFeature;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextMut, BeforeTransmitInterceptorContextMut, Input};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::http::Request;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::checksum_config::RequestChecksumCalculation;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use http_1x::{HeaderMap, HeaderName};
use std::str::FromStr;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::{fmt, mem};
/// Errors raised while attaching request checksums.
#[derive(Debug)]
pub(crate) enum Error {
    /// A checksum header was requested for a body whose bytes are not available in memory.
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                 To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
/// Per-operation state shared between the hooks of `RequestChecksumInterceptor` via the config bag.
#[derive(Debug, Default, Clone)]
struct RequestChecksumInterceptorState {
    /// Name of the checksum algorithm to use, if one has been chosen.
    checksum_algorithm: Option<String>,
    /// Whether the operation reported that a request checksum is required.
    request_checksum_required: bool,
    /// Flag flipped to `true` once the interceptor decides to calculate a checksum.
    calculate_checksum: Arc<AtomicBool>,
    /// Holds previously calculated checksum headers so later attempts can reuse them.
    checksum_cache: ChecksumCache,
}

impl RequestChecksumInterceptorState {
    /// Parses the stored algorithm name; a missing or unrecognised name yields `None`.
    fn checksum_algorithm(&self) -> Option<ChecksumAlgorithm> {
        let name = self.checksum_algorithm.as_deref()?;
        name.parse::<ChecksumAlgorithm>().ok()
    }

    /// Reads the "calculate a checksum" flag.
    fn calculate_checksum(&self) -> bool {
        let flag: &AtomicBool = &self.calculate_checksum;
        flag.load(Ordering::SeqCst)
    }
}

impl Storable for RequestChecksumInterceptorState {
    type Storer = StoreReplace<Self>;
}
/// Boxed callback that receives the originally selected algorithm (if any) plus the config bag
/// and returns the algorithm to use instead.
type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;

/// Config-bag entry holding a callback that can replace the request checksum algorithm
/// (consulted by `incorporate_custom_default`).
pub(crate) struct DefaultRequestChecksumOverride {
    custom_default: CustomDefaultFn,
}

impl fmt::Debug for DefaultRequestChecksumOverride {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The boxed closure cannot be printed, so only the type name is emitted.
        f.write_str("DefaultRequestChecksumOverride")
    }
}

impl Storable for DefaultRequestChecksumOverride {
    type Storer = StoreReplace<Self>;
}

impl DefaultRequestChecksumOverride {
    /// Wraps `custom_default` so it can be stored in the config bag.
    pub(crate) fn new<F>(custom_default: F) -> Self
    where
        F: Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static,
    {
        let boxed: CustomDefaultFn = Box::new(custom_default);
        Self { custom_default: boxed }
    }

    /// Invokes the stored callback with the original algorithm and the config bag.
    pub(crate) fn custom_default(&self, original: Option<ChecksumAlgorithm>, config_bag: &ConfigBag) -> Option<ChecksumAlgorithm> {
        let callback = &self.custom_default;
        callback(original, config_bag)
    }
}
/// Interceptor that attaches a flexible request checksum, either as a header (when the body
/// bytes are in memory) or as a trailer (for streaming bodies).
///
/// * `algorithm_provider` extracts the requested algorithm name and the "checksum required"
///   flag from the operation input.
/// * `checksum_mutator` applies any user-supplied checksum to the request and reports whether
///   it did so.
pub(crate) struct RequestChecksumInterceptor<AP, CM> {
    algorithm_provider: AP,
    checksum_mutator: CM,
}

impl<AP, CM> fmt::Debug for RequestChecksumInterceptor<AP, CM> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The provider/mutator are usually closures without `Debug`; print only the type name.
        f.write_str("RequestChecksumInterceptor")
    }
}

impl<AP, CM> RequestChecksumInterceptor<AP, CM> {
    /// Creates the interceptor from its algorithm provider and checksum mutator.
    pub(crate) fn new(algorithm_provider: AP, checksum_mutator: CM) -> Self {
        RequestChecksumInterceptor {
            checksum_mutator,
            algorithm_provider,
        }
    }
}
/// Hook responsibilities:
///
/// * `modify_before_serialization` records the requested algorithm in the config bag.
/// * `modify_before_retry_loop` decides whether a checksum is calculated at all.
/// * `modify_before_signing` adds the checksum header (in-memory bodies) or the
///   `x-amz-trailer` announcement (streaming bodies).
/// * `modify_before_transmit` wraps streaming bodies in `calculate::ChecksumBody`.
impl<AP, CM> Intercept for RequestChecksumInterceptor<AP, CM>
where
    AP: Fn(&Input) -> (Option<String>, bool) + Send + Sync,
    CM: Fn(&mut Request, &ConfigBag) -> Result<bool, BoxError> + Send + Sync,
{
    fn name(&self) -> &'static str {
        "RequestChecksumInterceptor"
    }

    /// Stores a fresh `RequestChecksumInterceptorState` holding the operation's requested
    /// algorithm and "checksum required" flag, with calculation disabled and an empty cache.
    fn modify_before_serialization(
        &self,
        context: &mut BeforeSerializationInterceptorContextMut<'_>,
        _runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let (checksum_algorithm, request_checksum_required) = (self.algorithm_provider)(context.input());
        cfg.interceptor_state().store_put(RequestChecksumInterceptorState {
            checksum_algorithm,
            request_checksum_required,
            checksum_cache: ChecksumCache::new(),
            calculate_checksum: Arc::new(AtomicBool::new(false)),
        });
        Ok(())
    }

    /// Decides whether a checksum will be calculated for this request.
    ///
    /// If the checksum mutator reports a user-supplied checksum, or the request is presigned,
    /// chunked encoding is disabled and the state is left untouched. Otherwise the configured
    /// `RequestChecksumCalculation` decides. When a checksum will be calculated, the final
    /// algorithm (after any `DefaultRequestChecksumOverride`, falling back to
    /// `ChecksumAlgorithm::default()`) is written back into the state.
    ///
    /// # Errors
    /// Returns the checksum mutator's error, or an error if the stored algorithm name does not parse.
    fn modify_before_retry_loop(
        &self,
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
        _runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        // The mutator is fallible; surface its error to the caller instead of panicking.
        let user_set_checksum_value = (self.checksum_mutator)(context.request_mut(), cfg)?;
        let is_presigned = cfg.load::<PresigningMarker>().is_some();
        if user_set_checksum_value || is_presigned {
            cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());
            return Ok(());
        }
        let state = cfg
            .get_mut_from_interceptor_state::<RequestChecksumInterceptorState>()
            .expect("set in `modify_before_serialization`");
        let checksum_algorithm = state
            .checksum_algorithm
            .as_deref()
            .map(ChecksumAlgorithm::from_str)
            .transpose()?;
        // Move the state out so `cfg` can be borrowed mutably below; it is stored back at the end.
        let mut state = mem::take(state);
        if calculate_checksum(cfg, &state) {
            state.calculate_checksum.store(true, Ordering::SeqCst);
            let checksum_algorithm = incorporate_custom_default(checksum_algorithm, cfg).unwrap_or_default();
            state.checksum_algorithm = Some(checksum_algorithm.as_str().to_owned());
            track_metric_for_selected_checksum_algorithm(cfg, &checksum_algorithm);
        } else {
            cfg.interceptor_state().store_put(AwsChunkedBodyOptions::disable_chunked_encoding());
        }
        cfg.interceptor_state().store_put(state);
        Ok(())
    }

    /// Attaches the checksum when calculation was enabled.
    ///
    /// In-memory bodies get the checksum headers directly. Headers cached by an earlier call are
    /// reused, and a warning is logged if they differ. Streaming bodies instead get an
    /// `x-amz-trailer` header naming the checksum header. The trailer length is stored in
    /// `AwsChunkedBodyOptions`.
    fn modify_before_signing(
        &self,
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
        _runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let state = cfg
            .load::<RequestChecksumInterceptorState>()
            .expect("set in `modify_before_serialization`");
        if !state.calculate_checksum() {
            return Ok(());
        }
        let checksum_algorithm = state.checksum_algorithm().expect("set in `modify_before_retry_loop`");
        let mut checksum = checksum_algorithm.into_impl();
        match context.request().body().bytes() {
            Some(data) => {
                tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
                checksum.update(data);
                for (hdr_name, hdr_value) in get_or_cache_headers(checksum.headers(), &state.checksum_cache).iter() {
                    context.request_mut().headers_mut().insert(hdr_name.clone(), hdr_value.clone());
                }
            }
            None => {
                tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
                context
                    .request_mut()
                    .headers_mut()
                    .insert(HeaderName::from_static("x-amz-trailer"), checksum.header_name());
                let trailer_len = HttpChecksum::size(checksum.as_ref());
                let chunked_body_options = AwsChunkedBodyOptions::default().with_trailer_len(trailer_len);
                cfg.interceptor_state().store_put(chunked_body_options);
            }
        }
        Ok(())
    }

    /// Wraps a streaming body in `calculate::ChecksumBody` when calculation was enabled.
    /// In-memory bodies were already handled in `modify_before_signing`.
    fn modify_before_transmit(
        &self,
        ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
        _runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        if ctx.request().body().bytes().is_some() {
            return Ok(());
        }
        let state = cfg
            .load::<RequestChecksumInterceptorState>()
            .expect("set in `modify_before_serialization`");
        if !state.calculate_checksum() {
            return Ok(());
        }
        let checksum_algorithm = state.checksum_algorithm().expect("set in `modify_before_retry_loop`");
        let checksum_cache = state.checksum_cache.clone();
        let request = ctx.request_mut();
        let original_body = mem::replace(request.body_mut(), SdkBody::taken());
        // `SdkBody::map` applies the closure to every clone of the body, so each copy gets its own
        // checksum state while sharing the cache (see `test_checksum_body_is_retryable`).
        *request.body_mut() = original_body.map(move |body| {
            let checksum = checksum_algorithm.into_impl();
            let body = calculate::ChecksumBody::new(body, checksum).with_cache(checksum_cache.clone());
            SdkBody::from_body_1_x(body)
        });
        Ok(())
    }
}
/// Applies the `DefaultRequestChecksumOverride` callback from the config bag, if one is stored;
/// otherwise returns `checksum` unchanged.
fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
    if let Some(checksum_override) = cfg.load::<DefaultRequestChecksumOverride>() {
        return checksum_override.custom_default(checksum, cfg);
    }
    checksum
}
/// Returns the cached checksum headers when present, logging a warning if they differ from the
/// freshly calculated ones. Otherwise caches `calculated_headers` and returns them.
fn get_or_cache_headers(calculated_headers: HeaderMap, checksum_cache: &ChecksumCache) -> HeaderMap {
    match checksum_cache.get() {
        None => {
            checksum_cache.set(calculated_headers.clone());
            calculated_headers
        }
        Some(cached_headers) => {
            let mismatch = cached_headers != calculated_headers;
            if mismatch {
                tracing::warn!(cached = ?cached_headers, calculated = ?calculated_headers, "calculated checksum differs from cached checksum!");
            }
            cached_headers
        }
    }
}
/// Decides whether a request checksum should be calculated, based on the configured
/// `RequestChecksumCalculation` (default `WhenSupported`). It records the matching
/// user-agent feature in the config bag.
///
/// `WhenRequired` only calculates when the operation requires a checksum. `WhenSupported`
/// always calculates. Any other value is logged as unsupported and treated as "calculate".
fn calculate_checksum(cfg: &mut ConfigBag, state: &RequestChecksumInterceptorState) -> bool {
    let fallback = RequestChecksumCalculation::WhenSupported;
    let (feature, should_calculate) = match cfg.load::<RequestChecksumCalculation>().unwrap_or(&fallback) {
        RequestChecksumCalculation::WhenRequired => (
            Some(SmithySdkFeature::FlexibleChecksumsReqWhenRequired),
            state.request_checksum_required,
        ),
        RequestChecksumCalculation::WhenSupported => (Some(SmithySdkFeature::FlexibleChecksumsReqWhenSupported), true),
        unsupported => {
            tracing::warn!(
                more_info = "Unsupported value of RequestChecksumCalculation when setting user-agent metrics",
                unsupported = ?unsupported
            );
            (None, true)
        }
    };
    if let Some(feature) = feature {
        cfg.interceptor_state().store_append(feature);
    }
    should_calculate
}
/// Records the user-agent feature corresponding to the selected checksum algorithm.
/// MD5 and unrecognised algorithms only emit a warning and record nothing.
fn track_metric_for_selected_checksum_algorithm(cfg: &mut ConfigBag, checksum_algorithm: &ChecksumAlgorithm) {
    let feature = match checksum_algorithm {
        ChecksumAlgorithm::Crc32 => SmithySdkFeature::FlexibleChecksumsReqCrc32,
        ChecksumAlgorithm::Crc32c => SmithySdkFeature::FlexibleChecksumsReqCrc32c,
        ChecksumAlgorithm::Crc64Nvme => SmithySdkFeature::FlexibleChecksumsReqCrc64,
        ChecksumAlgorithm::Sha1 => SmithySdkFeature::FlexibleChecksumsReqSha1,
        ChecksumAlgorithm::Sha256 => SmithySdkFeature::FlexibleChecksumsReqSha256,
        #[allow(deprecated)]
        ChecksumAlgorithm::Md5 => {
            tracing::warn!(more_info = "Unsupported ChecksumAlgorithm MD5 set");
            return;
        }
        unsupported => {
            tracing::warn!(
                more_info = "Unsupported value of ChecksumAlgorithm detected when setting user-agent metrics",
                unsupported = ?unsupported
            );
            return;
        }
    };
    cfg.interceptor_state().store_append(feature);
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::interceptors::context::{BeforeTransmitInterceptorContextMut, InterceptorContext};
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::base64;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body_util::BodyExt;
    use tempfile::NamedTempFile;

    /// Builds an interceptor with a fixed algorithm provider and a mutator that never sets a checksum.
    fn create_test_interceptor() -> RequestChecksumInterceptor<
        impl Fn(&Input) -> (Option<String>, bool) + Send + Sync,
        impl Fn(&mut Request, &ConfigBag) -> Result<bool, BoxError> + Send + Sync,
    > {
        fn algo(_: &Input) -> (Option<String>, bool) {
            (Some("crc32".to_string()), false)
        }
        fn mutator(_: &mut Request, _: &ConfigBag) -> Result<bool, BoxError> {
            Ok(false)
        }
        RequestChecksumInterceptor::new(algo, mutator)
    }

    /// After `modify_before_transmit`, a file-backed streaming body must stay cloneable. Reading
    /// any clone must yield the original data followed by a CRC32C trailer.
    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        use std::io::Write;
        let mut file = NamedTempFile::new().unwrap();
        let algorithm_str = "crc32c";
        let checksum_algorithm: ChecksumAlgorithm = algorithm_str.parse().unwrap();
        let mut crc32c_checksum = checksum_algorithm.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {}", i);
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
            crc32c_checksum.update(line.as_bytes());
        }
        let crc32c_checksum = crc32c_checksum.finalize();
        let request = HttpRequest::new(ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap().into_inner());
        assert!(request.body().try_clone().is_some());
        let interceptor = create_test_interceptor();
        let mut cfg = ConfigBag::base();
        cfg.interceptor_state().store_put(RequestChecksumInterceptorState {
            checksum_algorithm: Some(algorithm_str.to_string()),
            calculate_checksum: Arc::new(AtomicBool::new(true)),
            ..Default::default()
        });
        let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.enter_serialization_phase();
        let _ = ctx.take_input();
        ctx.set_request(request);
        ctx.enter_before_transmit_phase();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
        interceptor.modify_before_transmit(&mut ctx, &runtime_components, &mut cfg).unwrap();

        let mut body = ctx.request().body().try_clone().expect("body is retryable");
        let mut body_data = BytesMut::new();
        let mut header_value = None;
        // Surface read errors instead of silently ending the loop on the first `Err` frame.
        while let Some(frame) = body.frame().await {
            let frame = frame.expect("reading the body should not fail");
            if frame.is_data() {
                let data = frame.into_data().unwrap();
                body_data.extend_from_slice(&data);
            } else {
                let trailers = frame.into_trailers().unwrap();
                if let Some(hv) = trailers.get("x-amz-checksum-crc32c") {
                    header_value = Some(hv.to_str().unwrap().to_owned());
                }
            }
        }
        let body_str = std::str::from_utf8(&body_data).unwrap();
        let expected = "This is a large file created for testing purposes 9999";
        assert!(body_str.ends_with(expected), "expected '{body_str}' to end with '{expected}'");
        let expected_checksum = base64::encode(&crc32c_checksum);
        assert_eq!(
            header_value.as_ref(),
            Some(&expected_checksum),
            "expected checksum '{header_value:?}' to match '{expected_checksum}'"
        );

        // `body` is fully drained at this point, so read an independent clone to check that a
        // retried request reproduces the same trailer. A single check is enough here: the
        // previous `while let` over the unchanging `trailers()` value would never terminate
        // if trailers were present.
        let retry_body = ctx.request().body().try_clone().expect("body is retryable");
        let collected_body = retry_body.collect().await.unwrap();
        let trailers = collected_body.trailers().expect("retried body should end with trailers");
        let retried_value = trailers
            .get("x-amz-checksum-crc32c")
            .expect("trailers should contain the CRC32C checksum")
            .to_str()
            .unwrap();
        assert_eq!(
            retried_value, expected_checksum,
            "expected checksum '{retried_value}' to match '{expected_checksum}'"
        );
    }
}