aws_smithy_runtime/client/
stalled_stream_protection.rs1use crate::client::http::body::minimum_throughput::{
7 options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody,
8 UploadThroughput,
9};
10use aws_smithy_async::rt::sleep::SharedAsyncSleep;
11use aws_smithy_async::time::SharedTimeSource;
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::{
14 BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextMut,
15};
16use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
17use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
18use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::ConfigBag;
21use std::mem;
22
/// Interceptor that adds stalled stream protection to request and/or response bodies.
///
/// The interceptor itself carries no state. Whether uploads and/or downloads are
/// protected is decided per operation from the `StalledStreamProtectionConfig`
/// found in the config bag. Construct it with `Default::default()`.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct StalledStreamProtectionInterceptor;
27
/// Legacy selector for which bodies stalled stream protection applied to.
///
/// Retained only for backwards compatibility: the value is ignored when passed to the
/// deprecated `StalledStreamProtectionInterceptor::new`.
#[deprecated(
    since = "1.2.0",
    note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag."
)]
pub enum StalledStreamProtectionInterceptorKind {
    /// Request (upload) bodies only.
    RequestBody,
    /// Response (download) bodies only.
    ResponseBody,
    /// Both request and response bodies.
    RequestAndResponseBody,
}
42
43impl StalledStreamProtectionInterceptor {
44 #[deprecated(
46 since = "1.2.0",
47 note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default."
48 )]
49 #[allow(deprecated)]
50 pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self {
51 Default::default()
52 }
53}
54
55#[dyn_dispatch_hint]
56impl Intercept for StalledStreamProtectionInterceptor {
57 fn name(&self) -> &'static str {
58 "StalledStreamProtectionInterceptor"
59 }
60
61 fn modify_before_transmit(
62 &self,
63 context: &mut BeforeTransmitInterceptorContextMut<'_>,
64 runtime_components: &RuntimeComponents,
65 cfg: &mut ConfigBag,
66 ) -> Result<(), BoxError> {
67 if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
68 if sspcfg.upload_enabled() {
69 if let Some(0) = context.request().body().content_length() {
70 tracing::trace!(
71 "skipping stalled stream protection for zero length request body"
72 );
73 return Ok(());
74 }
75 let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
76 let now = time_source.now();
77
78 let options: MinimumThroughputBodyOptions = sspcfg.into();
79 let throughput = UploadThroughput::new(options.check_window(), now);
80 cfg.interceptor_state().store_put(throughput.clone());
81
82 tracing::trace!("adding stalled stream protection to request body");
83 let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken());
84 let it = it.map_preserve_contents(move |body| {
85 let time_source = time_source.clone();
86 SdkBody::from_body_1_x(ThroughputReadingBody::new(
87 time_source,
88 throughput.clone(),
89 body,
90 ))
91 });
92 let _ = mem::replace(context.request_mut().body_mut(), it);
93 }
94 }
95
96 Ok(())
97 }
98
99 fn modify_before_deserialization(
100 &self,
101 context: &mut BeforeDeserializationInterceptorContextMut<'_>,
102 runtime_components: &RuntimeComponents,
103 cfg: &mut ConfigBag,
104 ) -> Result<(), BoxError> {
105 if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>() {
106 if sspcfg.download_enabled() {
107 let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
108 tracing::trace!("adding stalled stream protection to response body");
109 let sspcfg = sspcfg.clone();
110 let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken());
111 let it = it.map_preserve_contents(move |body| {
112 let sspcfg = sspcfg.clone();
113 let async_sleep = async_sleep.clone();
114 let time_source = time_source.clone();
115 let mtb = MinimumThroughputDownloadBody::new(
116 time_source,
117 async_sleep,
118 body,
119 sspcfg.into(),
120 );
121 SdkBody::from_body_1_x(mtb)
122 });
123 let _ = mem::replace(context.response_mut().body_mut(), it);
124 }
125 }
126 Ok(())
127 }
128}
129
130fn get_runtime_component_deps(
131 runtime_components: &RuntimeComponents,
132) -> Result<(SharedAsyncSleep, SharedTimeSource), BoxError> {
133 let async_sleep = runtime_components.sleep_impl().ok_or(
134 "An async sleep implementation is required when stalled stream protection is enabled",
135 )?;
136 let time_source = runtime_components
137 .time_source()
138 .ok_or("A time source is required when stalled stream protection is enabled")?;
139 Ok((async_sleep, time_source))
140}