Skip to main content

aws_smithy_runtime/client/
stalled_stream_protection.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::client::http::body::minimum_throughput::{
7    options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody,
8    UploadThroughput,
9};
10use aws_smithy_async::rt::sleep::SharedAsyncSleep;
11use aws_smithy_async::time::SharedTimeSource;
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::{
14    BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextMut,
15};
16use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
17use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
18use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::ConfigBag;
21use std::mem;
22
/// Adds stalled stream protection when sending requests and/or receiving responses.
///
/// The interceptor is a stateless unit struct. Whether request (upload) and/or
/// response (download) bodies are protected is decided each time it runs, based on
/// the [`StalledStreamProtectionConfig`] loaded from the config bag.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct StalledStreamProtectionInterceptor;
27
/// Stalled stream protection can be enabled for request bodies, response bodies,
/// or both.
///
/// Retained for backwards compatibility only: the interceptor's constructor accepts
/// a value of this type but ignores it.
#[deprecated(
    since = "1.2.0",
    note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag."
)]
pub enum StalledStreamProtectionInterceptorKind {
    /// Enable stalled stream protection for request bodies.
    RequestBody,
    /// Enable stalled stream protection for response bodies.
    ResponseBody,
    /// Enable stalled stream protection for both request and response bodies.
    RequestAndResponseBody,
}
42
43impl StalledStreamProtectionInterceptor {
44    /// Create a new stalled stream protection interceptor.
45    #[deprecated(
46        since = "1.2.0",
47        note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default."
48    )]
49    #[allow(deprecated)]
50    pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self {
51        Default::default()
52    }
53}
54
55#[dyn_dispatch_hint]
56impl Intercept for StalledStreamProtectionInterceptor {
57    fn name(&self) -> &'static str {
58        "StalledStreamProtectionInterceptor"
59    }
60
61    fn modify_before_transmit(
62        &self,
63        context: &mut BeforeTransmitInterceptorContextMut<'_>,
64        runtime_components: &RuntimeComponents,
65        cfg: &mut ConfigBag,
66    ) -> Result<(), BoxError> {
67        if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>().cloned() {
68            if sspcfg.upload_enabled() {
69                if let Some(0) = context.request().body().content_length() {
70                    tracing::trace!(
71                        "skipping stalled stream protection for zero length request body"
72                    );
73                    return Ok(());
74                }
75                let (_async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
76                let now = time_source.now();
77
78                let options: MinimumThroughputBodyOptions = sspcfg.into();
79                let throughput = UploadThroughput::new(options.check_window(), now);
80                cfg.interceptor_state().store_put(throughput.clone());
81
82                tracing::trace!("adding stalled stream protection to request body");
83                let it = mem::replace(context.request_mut().body_mut(), SdkBody::taken());
84                let it = it.map_preserve_contents(move |body| {
85                    let time_source = time_source.clone();
86                    SdkBody::from_body_1_x(ThroughputReadingBody::new(
87                        time_source,
88                        throughput.clone(),
89                        body,
90                    ))
91                });
92                let _ = mem::replace(context.request_mut().body_mut(), it);
93            }
94        }
95
96        Ok(())
97    }
98
99    fn modify_before_deserialization(
100        &self,
101        context: &mut BeforeDeserializationInterceptorContextMut<'_>,
102        runtime_components: &RuntimeComponents,
103        cfg: &mut ConfigBag,
104    ) -> Result<(), BoxError> {
105        if let Some(sspcfg) = cfg.load::<StalledStreamProtectionConfig>() {
106            if sspcfg.download_enabled() {
107                let (async_sleep, time_source) = get_runtime_component_deps(runtime_components)?;
108                tracing::trace!("adding stalled stream protection to response body");
109                let sspcfg = sspcfg.clone();
110                let it = mem::replace(context.response_mut().body_mut(), SdkBody::taken());
111                let it = it.map_preserve_contents(move |body| {
112                    let sspcfg = sspcfg.clone();
113                    let async_sleep = async_sleep.clone();
114                    let time_source = time_source.clone();
115                    let mtb = MinimumThroughputDownloadBody::new(
116                        time_source,
117                        async_sleep,
118                        body,
119                        sspcfg.into(),
120                    );
121                    SdkBody::from_body_1_x(mtb)
122                });
123                let _ = mem::replace(context.response_mut().body_mut(), it);
124            }
125        }
126        Ok(())
127    }
128}
129
130fn get_runtime_component_deps(
131    runtime_components: &RuntimeComponents,
132) -> Result<(SharedAsyncSleep, SharedTimeSource), BoxError> {
133    let async_sleep = runtime_components.sleep_impl().ok_or(
134        "An async sleep implementation is required when stalled stream protection is enabled",
135    )?;
136    let time_source = runtime_components
137        .time_source()
138        .ok_or("A time source is required when stalled stream protection is enabled")?;
139    Ok((async_sleep, time_source))
140}