Skip to main content

aws_runtime/
request_info.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::service_clock_skew::ServiceClockSkew;
7use aws_smithy_async::time::TimeSource;
8use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
10use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
11use aws_smithy_runtime_api::client::retries::RequestAttempts;
12use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
13use aws_smithy_types::config_bag::ConfigBag;
14use aws_smithy_types::date_time::Format;
15use aws_smithy_types::retry::RetryConfig;
16use aws_smithy_types::timeout::TimeoutConfig;
17use aws_smithy_types::DateTime;
18use http_1x::{HeaderName, HeaderValue};
19use std::borrow::Cow;
20use std::time::Duration;
21
/// Name of the header under which the request metadata pairs built in this module are sent.
#[allow(clippy::declare_interior_mutable_const)] // we will never mutate this
const AMZ_SDK_REQUEST: HeaderName = HeaderName::from_static("amz-sdk-request");
24
/// Interceptor that attaches a header describing the request itself to every outgoing request.
///
/// The header may carry:
///
/// - The point in time at which the client will give up on the request.
/// - The attempt count recorded for the request.
/// - The maximum number of attempts the client is configured to make.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct RequestInfoInterceptor {}

impl RequestInfoInterceptor {
    /// Builds a fresh `RequestInfoInterceptor`.
    pub fn new() -> Self {
        Self::default()
    }
}
41
42impl RequestInfoInterceptor {
43    fn build_attempts_pair(
44        &self,
45        cfg: &ConfigBag,
46    ) -> Option<(Cow<'static, str>, Cow<'static, str>)> {
47        let request_attempts = cfg
48            .load::<RequestAttempts>()
49            .map(|r_a| r_a.attempts())
50            .unwrap_or(0);
51        let request_attempts = request_attempts.to_string();
52        Some((Cow::Borrowed("attempt"), Cow::Owned(request_attempts)))
53    }
54
55    fn build_max_attempts_pair(
56        &self,
57        cfg: &ConfigBag,
58    ) -> Option<(Cow<'static, str>, Cow<'static, str>)> {
59        if let Some(retry_config) = cfg.load::<RetryConfig>() {
60            let max_attempts = retry_config.max_attempts().to_string();
61            Some((Cow::Borrowed("max"), Cow::Owned(max_attempts)))
62        } else {
63            None
64        }
65    }
66
67    fn build_ttl_pair(
68        &self,
69        cfg: &ConfigBag,
70        timesource: impl TimeSource,
71    ) -> Option<(Cow<'static, str>, Cow<'static, str>)> {
72        let timeout_config = cfg.load::<TimeoutConfig>()?;
73        let socket_read = timeout_config.read_timeout()?;
74        let estimated_skew: Duration = cfg.load::<ServiceClockSkew>().cloned()?.into();
75        let current_time = timesource.now();
76        let ttl = current_time.checked_add(socket_read + estimated_skew)?;
77        let mut timestamp = DateTime::from(ttl);
78        // Set subsec_nanos to 0 so that the formatted `DateTime` won't have fractional seconds.
79        timestamp.set_subsec_nanos(0);
80        let mut formatted_timestamp = timestamp
81            .fmt(Format::DateTime)
82            .expect("the resulting DateTime will always be valid");
83
84        // Remove dashes and colons
85        formatted_timestamp = formatted_timestamp
86            .chars()
87            .filter(|&c| c != '-' && c != ':')
88            .collect();
89
90        Some((Cow::Borrowed("ttl"), Cow::Owned(formatted_timestamp)))
91    }
92}
93
94#[dyn_dispatch_hint]
95impl Intercept for RequestInfoInterceptor {
96    fn name(&self) -> &'static str {
97        "RequestInfoInterceptor"
98    }
99
100    fn modify_before_transmit(
101        &self,
102        context: &mut BeforeTransmitInterceptorContextMut<'_>,
103        runtime_components: &RuntimeComponents,
104        cfg: &mut ConfigBag,
105    ) -> Result<(), BoxError> {
106        let mut pairs = RequestPairs::new();
107        if let Some(pair) = self.build_ttl_pair(
108            cfg,
109            runtime_components
110                .time_source()
111                .ok_or("A timesource must be provided")?,
112        ) {
113            pairs = pairs.with_pair(pair);
114        }
115        if let Some(pair) = self.build_attempts_pair(cfg) {
116            pairs = pairs.with_pair(pair);
117        }
118        if let Some(pair) = self.build_max_attempts_pair(cfg) {
119            pairs = pairs.with_pair(pair);
120        }
121
122        let headers = context.request_mut().headers_mut();
123        headers.insert(AMZ_SDK_REQUEST, pairs.try_into_header_value()?);
124
125        Ok(())
126    }
127}
128
129/// A builder for creating a `RequestPairs` header value. `RequestPairs` is used to generate a
130/// retry information header that is sent with every request. The information conveyed by this
131/// header allows services to anticipate whether a client will time out or retry a request.
132#[derive(Default, Debug)]
133struct RequestPairs {
134    inner: Vec<(Cow<'static, str>, Cow<'static, str>)>,
135}
136
137impl RequestPairs {
138    /// Creates a new `RequestPairs` builder.
139    fn new() -> Self {
140        Default::default()
141    }
142
143    /// Adds a pair to the `RequestPairs` builder.
144    /// Only strings that can be converted to header values are considered valid.
145    fn with_pair(
146        mut self,
147        pair: (impl Into<Cow<'static, str>>, impl Into<Cow<'static, str>>),
148    ) -> Self {
149        let pair = (pair.0.into(), pair.1.into());
150        self.inner.push(pair);
151        self
152    }
153
154    /// Converts the `RequestPairs` builder into a `HeaderValue`.
155    fn try_into_header_value(self) -> Result<HeaderValue, BoxError> {
156        self.try_into()
157    }
158}
159
160impl TryFrom<RequestPairs> for HeaderValue {
161    type Error = BoxError;
162
163    fn try_from(value: RequestPairs) -> Result<Self, BoxError> {
164        let mut pairs = String::new();
165        for (key, value) in value.inner {
166            if !pairs.is_empty() {
167                pairs.push_str("; ");
168            }
169
170            pairs.push_str(&key);
171            pairs.push('=');
172            pairs.push_str(&value);
173            continue;
174        }
175        HeaderValue::from_str(&pairs).map_err(Into::into)
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::{RequestInfoInterceptor, RequestPairs};
    use aws_smithy_runtime_api::client::interceptors::context::{Input, InterceptorContext};
    use aws_smithy_runtime_api::client::interceptors::Intercept;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::config_bag::{ConfigBag, Layer};
    use aws_smithy_types::retry::RetryConfig;
    use aws_smithy_types::timeout::TimeoutConfig;

    use http_1x::HeaderValue;
    use std::time::Duration;

    /// Returns the value of `header_name` on the request held by `context`.
    /// Panics when the request or the header is missing.
    fn expect_header<'a>(context: &'a InterceptorContext, header_name: &str) -> &'a str {
        let request = context.request().expect("request is set");
        request.headers().get(header_name).unwrap()
    }

    /// With a retry config and a read timeout but no stored attempt count or clock skew,
    /// the header carries only the attempt and max pairs.
    #[test]
    fn test_request_pairs_for_initial_attempt() {
        // Configuration: standard retries plus a 30 second read timeout.
        let read_timeout = TimeoutConfig::builder()
            .read_timeout(Duration::from_secs(30))
            .build();
        let mut layer = Layer::new("test");
        layer.store_put(RetryConfig::standard());
        layer.store_put(read_timeout);
        let mut config = ConfigBag::of_layers(vec![layer]);

        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();

        // Walk the context through the phases up to "before transmit".
        let mut context = InterceptorContext::new(Input::doesnt_matter());
        context.enter_serialization_phase();
        context.set_request(HttpRequest::empty());
        let _ = context.take_input();
        context.enter_before_transmit_phase();

        let interceptor = RequestInfoInterceptor::new();
        let mut ctx = (&mut context).into();
        interceptor
            .modify_before_transmit(&mut ctx, &rc, &mut config)
            .unwrap();

        let header = expect_header(&context, "amz-sdk-request");
        assert_eq!(header, "attempt=0; max=3");
    }

    /// Every character class permitted in a pair must survive conversion to a header value.
    #[test]
    fn test_header_value_from_request_pairs_supports_all_valid_characters() {
        // The list of valid characters is defined by an internal-only spec.
        let samples = [
            ("allowed-symbols", "!#$&'*+-.^_`|~"),
            ("allowed-digits", "01234567890"),
            (
                "allowed-characters",
                "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
            ),
            ("allowed-whitespace", " \t"),
        ];
        let rp = samples
            .into_iter()
            .fold(RequestPairs::new(), |builder, sample| {
                builder.with_pair(sample)
            });
        let _header_value: HeaderValue = rp
            .try_into()
            .expect("request pairs can be converted into valid header value.");
    }
}
249}