1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use crate::client::http::body::minimum_throughput::{
    options::MinimumThroughputBodyOptions, MinimumThroughputDownloadBody, ThroughputReadingBody,
    UploadThroughput,
};
use aws_smithy_async::rt::sleep::SharedAsyncSleep;
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::{
    BeforeDeserializationInterceptorContextMut, BeforeTransmitInterceptorContextMut,
};
use aws_smithy_runtime_api::client::interceptors::Intercept;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::config_bag::ConfigBag;
use std::mem;

/// Adds stalled stream protection when sending requests and/or receiving responses.
///
/// The interceptor holds no state of its own. On each hook it looks up the
/// `StalledStreamProtectionConfig` stored in the config bag and, depending on
/// whether upload and/or download protection is enabled there, wraps the
/// request body and/or the response body in a throughput-tracking body.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct StalledStreamProtectionInterceptor;

/// Stalled stream protection can be enabled for request bodies, response bodies,
/// or both.
///
/// This enum is retained only so that the deprecated
/// `StalledStreamProtectionInterceptor::new` constructor keeps its signature;
/// that constructor ignores the value it is given.
#[deprecated(
    since = "1.2.0",
    note = "This kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag."
)]
pub enum StalledStreamProtectionInterceptorKind {
    /// Enable stalled stream protection for request bodies.
    RequestBody,
    /// Enable stalled stream protection for response bodies.
    ResponseBody,
    /// Enable stalled stream protection for both request and response bodies.
    RequestAndResponseBody,
}

impl StalledStreamProtectionInterceptor {
    /// Construct a stalled stream protection interceptor.
    ///
    /// The `_kind` argument is accepted for backwards compatibility only and is
    /// ignored: the returned value is the same as the one produced by `Default`.
    #[deprecated(
        since = "1.2.0",
        note = "The kind enum is no longer used. Configuration is stored in StalledStreamProtectionConfig in the config bag. Construct the interceptor using Default."
    )]
    #[allow(deprecated)]
    pub fn new(_kind: StalledStreamProtectionInterceptorKind) -> Self {
        Self::default()
    }
}

impl Intercept for StalledStreamProtectionInterceptor {
    /// Name under which this interceptor identifies itself.
    fn name(&self) -> &'static str {
        "StalledStreamProtectionInterceptor"
    }

    /// Wraps the outgoing request body in a `ThroughputReadingBody` when upload
    /// protection is enabled in the `StalledStreamProtectionConfig`.
    ///
    /// Does nothing when no config is present, when upload protection is
    /// disabled, or when the request body reports a content length of zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime components are missing the async sleep
    /// implementation or the time source.
    fn modify_before_transmit(
        &self,
        context: &mut BeforeTransmitInterceptorContextMut<'_>,
        runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        // An owned copy is needed: `cfg` is mutably borrowed further down.
        let sspcfg = match cfg.load::<StalledStreamProtectionConfig>().cloned() {
            Some(sspcfg) if sspcfg.upload_enabled() => sspcfg,
            _ => return Ok(()),
        };

        if context.request().body().content_length() == Some(0) {
            tracing::trace!("skipping stalled stream protection for zero length request body");
            return Ok(());
        }

        // Only the time source is used in this hook; the sleep implementation
        // is still looked up, so its absence is reported as an error here too.
        let (_sleep, clock) = get_runtime_component_deps(runtime_components)?;
        let started_at = clock.now();

        let options: MinimumThroughputBodyOptions = sspcfg.into();
        let upload_throughput = UploadThroughput::new(options.check_window(), started_at);
        // Keep a handle in the interceptor state; the body wrapper gets its own clone.
        cfg.interceptor_state()
            .store_put(upload_throughput.clone());

        tracing::trace!("adding stalled stream protection to request body");
        // Temporarily swap the body out so it can be moved into the wrapper.
        let original = mem::replace(context.request_mut().body_mut(), SdkBody::taken());
        let wrapped = original.map_preserve_contents(move |inner| {
            SdkBody::from_body_0_4(ThroughputReadingBody::new(
                clock.clone(),
                upload_throughput.clone(),
                inner,
            ))
        });
        *context.request_mut().body_mut() = wrapped;

        Ok(())
    }

    /// Wraps the incoming response body in a `MinimumThroughputDownloadBody`
    /// when download protection is enabled in the
    /// `StalledStreamProtectionConfig`.
    ///
    /// Does nothing when no config is present or download protection is
    /// disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime components are missing the async sleep
    /// implementation or the time source.
    fn modify_before_deserialization(
        &self,
        context: &mut BeforeDeserializationInterceptorContextMut<'_>,
        runtime_components: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), BoxError> {
        let sspcfg = match cfg.load::<StalledStreamProtectionConfig>() {
            Some(sspcfg) if sspcfg.download_enabled() => sspcfg,
            _ => return Ok(()),
        };

        let (sleep, clock) = get_runtime_component_deps(runtime_components)?;
        tracing::trace!("adding stalled stream protection to response body");

        // The closure below must own everything it captures.
        let sspcfg = sspcfg.clone();
        // Temporarily swap the body out so it can be moved into the wrapper.
        let original = mem::replace(context.response_mut().body_mut(), SdkBody::taken());
        let wrapped = original.map_preserve_contents(move |inner| {
            let download_body = MinimumThroughputDownloadBody::new(
                clock.clone(),
                sleep.clone(),
                inner,
                sspcfg.clone().into(),
            );
            SdkBody::from_body_0_4(download_body)
        });
        *context.response_mut().body_mut() = wrapped;

        Ok(())
    }
}

/// Fetches the async sleep implementation and the time source from the runtime
/// components.
///
/// # Errors
///
/// Returns an error if either dependency is missing. The sleep implementation
/// is checked first, so its message wins when both are absent.
fn get_runtime_component_deps(
    runtime_components: &RuntimeComponents,
) -> Result<(SharedAsyncSleep, SharedTimeSource), BoxError> {
    let sleep = match runtime_components.sleep_impl() {
        Some(sleep) => sleep,
        None => {
            return Err(
                "An async sleep implementation is required when stalled stream protection is enabled"
                    .into(),
            )
        }
    };

    let clock = match runtime_components.time_source() {
        Some(clock) => clock,
        None => {
            return Err(
                "A time source is required when stalled stream protection is enabled".into(),
            )
        }
    };

    Ok((sleep, clock))
}