// aws_runtime/recursion_detection.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_types::config_bag::ConfigBag;
use aws_types::os_shim_internal::Env;
use http_1x::HeaderValue;
use percent_encoding::{percent_encode, CONTROLS};
use std::borrow::Cow;

/// Name of the request header that carries the trace ID.
///
/// Kept lowercase so it matches the normalized form used for header-name lookups.
const TRACE_ID_HEADER: &str = "x-amzn-trace-id";

/// Names of the environment variables inspected by `RecursionDetectionInterceptor`.
mod env {
    /// Its presence is taken as the signal that the code is running inside a Lambda function.
    pub(super) const LAMBDA_FUNCTION_NAME: &str = "AWS_LAMBDA_FUNCTION_NAME";
    /// Holds the trace ID that is copied into the outgoing request header.
    pub(super) const TRACE_ID: &str = "_X_AMZN_TRACE_ID";
}

23/// Recursion Detection Interceptor
24///
25/// This interceptor inspects the value of the `AWS_LAMBDA_FUNCTION_NAME` and `_X_AMZN_TRACE_ID` environment
26/// variables to detect if the request is being invoked in a Lambda function. If it is, the `X-Amzn-Trace-Id` header
27/// will be set. This enables downstream services to prevent accidentally infinitely recursive invocations spawned
28/// from Lambda.
29#[non_exhaustive]
30#[derive(Debug, Default)]
31pub struct RecursionDetectionInterceptor {
32    env: Env,
33}
34
35impl RecursionDetectionInterceptor {
36    /// Creates a new `RecursionDetectionInterceptor`
37    pub fn new() -> Self {
38        Self::default()
39    }
40}
41
42#[dyn_dispatch_hint]
43impl Intercept for RecursionDetectionInterceptor {
44    fn name(&self) -> &'static str {
45        "RecursionDetectionInterceptor"
46    }
47
48    fn modify_before_signing(
49        &self,
50        context: &mut BeforeTransmitInterceptorContextMut<'_>,
51        _runtime_components: &RuntimeComponents,
52        _cfg: &mut ConfigBag,
53    ) -> Result<(), BoxError> {
54        let request = context.request_mut();
55        if request.headers().contains_key(TRACE_ID_HEADER) {
56            return Ok(());
57        }
58
59        if let (Ok(_function_name), Ok(trace_id)) = (
60            self.env.get(env::LAMBDA_FUNCTION_NAME),
61            self.env.get(env::TRACE_ID),
62        ) {
63            request
64                .headers_mut()
65                .insert(TRACE_ID_HEADER, encode_header(trace_id.as_bytes()));
66        }
67        Ok(())
68    }
69}
70
71/// Encodes a byte slice as a header.
72///
73/// ASCII control characters are percent encoded which ensures that all byte sequences are valid headers
74fn encode_header(value: &[u8]) -> HeaderValue {
75    let value: Cow<'_, str> = percent_encode(value, CONTROLS).into();
76    HeaderValue::from_bytes(value.as_bytes()).expect("header is encoded, header must be valid")
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use aws_smithy_protocol_test::{assert_ok, validate_headers};
    use aws_smithy_runtime_api::client::interceptors::context::{Input, InterceptorContext};
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use http_1x::HeaderValue;
    use proptest::{prelude::*, proptest};
    use serde::Deserialize;
    use std::collections::HashMap;

    proptest! {
        // `encode_header` must accept arbitrary bytes without panicking.
        #[test]
        fn header_encoding_never_panics(s in any::<Vec<u8>>()) {
            encode_header(&s);
        }
    }

    /// Pins the exact encoding of every byte value: control characters and bytes >= 0x80 are
    /// percent-encoded, while printable ASCII passes through unchanged.
    #[test]
    fn every_char() {
        let buff = (0..=255).collect::<Vec<u8>>();
        assert_eq!(
            encode_header(&buff),
            HeaderValue::from_static(
                r##"%00%01%02%03%04%05%06%07%08%09%0A%0B%0C%0D%0E%0F%10%11%12%13%14%15%16%17%18%19%1A%1B%1C%1D%1E%1F !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~%7F%80%81%82%83%84%85%86%87%88%89%8A%8B%8C%8D%8E%8F%90%91%92%93%94%95%96%97%98%99%9A%9B%9C%9D%9E%9F%A0%A1%A2%A3%A4%A5%A6%A7%A8%A9%AA%AB%AC%AD%AE%AF%B0%B1%B2%B3%B4%B5%B6%B7%B8%B9%BA%BB%BC%BD%BE%BF%C0%C1%C2%C3%C4%C5%C6%C7%C8%C9%CA%CB%CC%CD%CE%CF%D0%D1%D2%D3%D4%D5%D6%D7%D8%D9%DA%DB%DC%DD%DE%DF%E0%E1%E2%E3%E4%E5%E6%E7%E8%E9%EA%EB%EC%ED%EE%EF%F0%F1%F2%F3%F4%F5%F6%F7%F8%F9%FA%FB%FC%FD%FE%FF"##
            )
        );
    }

    /// Runs every case from the shared JSON test data through the interceptor.
    #[test]
    fn run_tests() {
        let test_cases: Vec<TestCase> =
            serde_json::from_str(include_str!("../test-data/recursion-detection.json"))
                .expect("invalid test case");
        for test_case in test_cases {
            check(test_case)
        }
    }

    /// One entry of `recursion-detection.json` (camelCase keys in the JSON).
    #[derive(Deserialize)]
    #[serde(rename_all = "camelCase")]
    struct TestCase {
        /// Environment variables visible to the interceptor.
        env: HashMap<String, String>,
        /// Headers on the request before the interceptor runs, as `"Name: value"` strings.
        request_headers_before: Vec<String>,
        /// Headers expected on the request afterwards, as `"Name: value"` strings.
        request_headers_after: Vec<String>,
    }

    impl TestCase {
        /// Builds an `Env` containing exactly this case's variables.
        fn env(&self) -> Env {
            Env::from(self.env.clone())
        }

        /// Headers on the input request
        fn request_headers_before(&self) -> impl Iterator<Item = (&str, &str)> {
            Self::split_headers(&self.request_headers_before)
        }

        /// Headers on the output request
        fn request_headers_after(&self) -> impl Iterator<Item = (&str, &str)> {
            Self::split_headers(&self.request_headers_after)
        }

        /// Split text headers on `: `
        fn split_headers(headers: &[String]) -> impl Iterator<Item = (&str, &str)> {
            headers
                .iter()
                .map(|header| header.split_once(": ").expect("header must contain :"))
        }
    }

    /// Applies the interceptor to a request built from `test_case`.
    ///
    /// Asserts that no header is duplicated and that the resulting headers match the
    /// expected ones.
    fn check(test_case: TestCase) {
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let env = test_case.env();
        let mut request = http_1x::Request::builder();
        for (name, value) in test_case.request_headers_before() {
            request = request.header(name, value);
        }
        let request = request
            .body(SdkBody::empty())
            .expect("must be valid")
            .try_into()
            .unwrap();
        // Walk the context through its phases so it reaches the before-transmit phase that
        // `modify_before_signing` operates in.
        let mut context = InterceptorContext::new(Input::doesnt_matter());
        context.enter_serialization_phase();
        context.set_request(request);
        let _ = context.take_input();
        context.enter_before_transmit_phase();
        let mut config = ConfigBag::base();

        let mut ctx = Into::into(&mut context);
        RecursionDetectionInterceptor { env }
            .modify_before_signing(&mut ctx, &rc, &mut config)
            .expect("interceptor must succeed");
        let mutated_request = context.request().expect("request is set");
        for (name, _) in mutated_request.headers() {
            assert_eq!(
                mutated_request.headers().get_all(name).count(),
                1,
                "No duplicated headers"
            )
        }
        assert_ok(validate_headers(
            mutated_request.headers(),
            test_case.request_headers_after(),
        ))
    }
}