Skip to main content

aws_runtime/
invocation_id.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::fmt::Debug;
7use std::sync::{Arc, Mutex};
8
9use fastrand::Rng;
10use http_1x::{HeaderName, HeaderValue};
11
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
14use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
17#[cfg(feature = "test-util")]
18pub use test_util::{NoInvocationIdGenerator, PredefinedInvocationIdGenerator};
19
/// Name of the header that carries the invocation ID on outgoing requests.
// Clippy's `declare_interior_mutable_const` lint flags `HeaderName` constants. Allowing it is
// sound here because the constant is only ever used by value and never mutated.
#[allow(clippy::declare_interior_mutable_const)] // we will never mutate this
const AMZ_SDK_INVOCATION_ID: HeaderName = HeaderName::from_static("amz-sdk-invocation-id");
22
23/// A generator for returning new invocation IDs on demand.
24pub trait InvocationIdGenerator: Debug + Send + Sync {
25    /// Call this function to receive a new [`InvocationId`] or an error explaining why one couldn't
26    /// be provided.
27    fn generate(&self) -> Result<Option<InvocationId>, BoxError>;
28}
29
30/// Dynamic dispatch implementation of [`InvocationIdGenerator`]
31#[derive(Clone, Debug)]
32pub struct SharedInvocationIdGenerator(Arc<dyn InvocationIdGenerator>);
33
34impl SharedInvocationIdGenerator {
35    /// Creates a new [`SharedInvocationIdGenerator`].
36    pub fn new(gen: impl InvocationIdGenerator + 'static) -> Self {
37        Self(Arc::new(gen))
38    }
39}
40
41impl InvocationIdGenerator for SharedInvocationIdGenerator {
42    fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
43        self.0.generate()
44    }
45}
46
47impl Storable for SharedInvocationIdGenerator {
48    type Storer = StoreReplace<Self>;
49}
50
51/// An invocation ID generator that uses random UUIDs for the invocation ID.
52#[derive(Debug, Default)]
53pub struct DefaultInvocationIdGenerator {
54    rng: Mutex<Rng>,
55}
56
57impl DefaultInvocationIdGenerator {
58    /// Creates a new [`DefaultInvocationIdGenerator`].
59    pub fn new() -> Self {
60        Default::default()
61    }
62
63    /// Creates a [`DefaultInvocationIdGenerator`] with the given seed.
64    pub fn with_seed(seed: u64) -> Self {
65        Self {
66            rng: Mutex::new(Rng::with_seed(seed)),
67        }
68    }
69}
70
71impl InvocationIdGenerator for DefaultInvocationIdGenerator {
72    fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
73        let mut rng = self.rng.lock().unwrap();
74        let mut random_bytes = [0u8; 16];
75        rng.fill(&mut random_bytes);
76
77        let id = uuid::Builder::from_random_bytes(random_bytes).into_uuid();
78        Ok(Some(InvocationId::new(id.to_string())))
79    }
80}
81
82/// This interceptor generates a UUID and attaches it to all request attempts made as part of this operation.
83#[non_exhaustive]
84#[derive(Debug, Default)]
85pub struct InvocationIdInterceptor {
86    default: DefaultInvocationIdGenerator,
87}
88
89impl InvocationIdInterceptor {
90    /// Creates a new `InvocationIdInterceptor`
91    pub fn new() -> Self {
92        Self::default()
93    }
94}
95
96#[dyn_dispatch_hint]
97impl Intercept for InvocationIdInterceptor {
98    fn name(&self) -> &'static str {
99        "InvocationIdInterceptor"
100    }
101
102    fn modify_before_retry_loop(
103        &self,
104        _ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
105        _runtime_components: &RuntimeComponents,
106        cfg: &mut ConfigBag,
107    ) -> Result<(), BoxError> {
108        let gen = cfg
109            .load::<SharedInvocationIdGenerator>()
110            .map(|gen| gen as &dyn InvocationIdGenerator)
111            .unwrap_or(&self.default);
112        if let Some(id) = gen.generate()? {
113            cfg.interceptor_state().store_put::<InvocationId>(id);
114        }
115
116        Ok(())
117    }
118
119    fn modify_before_transmit(
120        &self,
121        ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
122        _runtime_components: &RuntimeComponents,
123        cfg: &mut ConfigBag,
124    ) -> Result<(), BoxError> {
125        let headers = ctx.request_mut().headers_mut();
126        if let Some(id) = cfg.load::<InvocationId>() {
127            headers.append(AMZ_SDK_INVOCATION_ID, id.0.clone());
128        }
129        Ok(())
130    }
131}
132
133/// InvocationId provides a consistent ID across retries
134#[derive(Debug, Clone, PartialEq, Eq)]
135pub struct InvocationId(HeaderValue);
136
137impl InvocationId {
138    /// Create an invocation ID with the given value.
139    ///
140    /// # Panics
141    /// This constructor will panic if the given invocation ID is not a valid HTTP header value.
142    pub fn new(invocation_id: String) -> Self {
143        Self(
144            HeaderValue::try_from(invocation_id)
145                .expect("invocation ID must be a valid HTTP header value"),
146        )
147    }
148}
149
150impl Storable for InvocationId {
151    type Storer = StoreReplace<Self>;
152}
153
#[cfg(feature = "test-util")]
mod test_util {
    use super::*;

    impl InvocationId {
        /// Create a new invocation ID from a `&'static str`.
        ///
        /// # Panics
        /// Panics if `uuid` is not a valid HTTP header value (see [`HeaderValue::from_static`]).
        pub fn new_from_str(uuid: &'static str) -> Self {
            InvocationId(HeaderValue::from_static(uuid))
        }
    }

    /// A "generator" that returns [`InvocationId`]s from a predefined list, in list order.
    ///
    /// # Panics
    /// [`generate`](InvocationIdGenerator::generate) panics once every predefined ID has been
    /// handed out.
    #[derive(Debug)]
    pub struct PredefinedInvocationIdGenerator {
        // Kept in reverse order so that `pop` yields IDs in the order they were supplied.
        // The generator is never cloned, so the list is owned directly rather than behind an `Arc`.
        pre_generated_ids: Mutex<Vec<InvocationId>>,
    }

    impl PredefinedInvocationIdGenerator {
        /// Given a `Vec<InvocationId>`, create a new [`PredefinedInvocationIdGenerator`].
        pub fn new(mut invocation_ids: Vec<InvocationId>) -> Self {
            // We're going to pop ids off of the end of the list, so we need to reverse the list or else
            // we'll be popping the ids in reverse order, confusing the poor test writer.
            invocation_ids.reverse();

            Self {
                pre_generated_ids: Mutex::new(invocation_ids),
            }
        }
    }

    impl InvocationIdGenerator for PredefinedInvocationIdGenerator {
        fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
            let next = self
                .pre_generated_ids
                .lock()
                // `lock` blocks under contention; it only fails if a previous holder panicked.
                .expect("PredefinedInvocationIdGenerator lock was poisoned")
                .pop()
                .expect("testers will provide enough invocation IDs");
            Ok(Some(next))
        }
    }

    /// A "generator" that always returns `None`.
    #[derive(Debug, Default)]
    pub struct NoInvocationIdGenerator;

    impl NoInvocationIdGenerator {
        /// Create a new [`NoInvocationIdGenerator`].
        pub fn new() -> Self {
            Self
        }
    }

    impl InvocationIdGenerator for NoInvocationIdGenerator {
        fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
            Ok(None)
        }
    }
}
213
#[cfg(test)]
mod tests {
    use aws_smithy_runtime_api::client::interceptors::context::{
        BeforeTransmitInterceptorContextMut, Input, InterceptorContext,
    };
    use aws_smithy_runtime_api::client::interceptors::Intercept;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::config_bag::ConfigBag;

    use super::*;

    /// Builds an interceptor context that holds an empty request and has entered the
    /// before-transmit phase, as both interceptor hooks under test expect.
    fn before_transmit_context() -> InterceptorContext {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.enter_serialization_phase();
        ctx.set_request(HttpRequest::empty());
        let _ = ctx.take_input();
        ctx.enter_before_transmit_phase();
        ctx
    }

    /// Runs `modify_before_retry_loop` followed by `modify_before_transmit` on a fresh
    /// `InvocationIdInterceptor`, panicking if either hook fails.
    fn run_both_hooks(ctx: &mut BeforeTransmitInterceptorContextMut<'_>, cfg: &mut ConfigBag) {
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let interceptor = InvocationIdInterceptor::new();
        interceptor.modify_before_retry_loop(ctx, &rc, cfg).unwrap();
        interceptor.modify_before_transmit(ctx, &rc, cfg).unwrap();
    }

    fn expect_header<'a>(
        context: &'a BeforeTransmitInterceptorContextMut<'_>,
        header_name: &str,
    ) -> &'a str {
        context.request().headers().get(header_name).unwrap()
    }

    #[test]
    fn default_id_generator() {
        let mut ctx = before_transmit_context();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();
        let mut cfg = ConfigBag::base();
        run_both_hooks(&mut ctx, &mut cfg);

        let expected = cfg.load::<InvocationId>().expect("invocation ID was set");
        let header = expect_header(&ctx, "amz-sdk-invocation-id");
        assert_eq!(expected.0, header, "the invocation ID in the config bag must match the invocation ID in the request header");
        // UUID should include 32 chars and 4 dashes
        assert_eq!(header.len(), 36);
        // A random (v4) UUID carries its version digit right after the second dash.
        assert_eq!(&header[14..15], "4");
    }

    #[test]
    fn seeded_default_generator_is_reproducible() {
        let first = DefaultInvocationIdGenerator::with_seed(42)
            .generate()
            .unwrap();
        let second = DefaultInvocationIdGenerator::with_seed(42)
            .generate()
            .unwrap();
        assert!(first.is_some());
        assert_eq!(first, second);
    }

    #[cfg(feature = "test-util")]
    #[test]
    fn custom_id_generator() {
        use aws_smithy_types::config_bag::Layer;
        let mut ctx = before_transmit_context();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();

        let mut cfg = ConfigBag::base();
        let mut layer = Layer::new("test");
        layer.store_put(SharedInvocationIdGenerator::new(
            PredefinedInvocationIdGenerator::new(vec![InvocationId::new(
                "the-best-invocation-id".into(),
            )]),
        ));
        cfg.push_layer(layer);

        run_both_hooks(&mut ctx, &mut cfg);

        let header = expect_header(&ctx, "amz-sdk-invocation-id");
        assert_eq!("the-best-invocation-id", header);
    }

    #[cfg(feature = "test-util")]
    #[test]
    fn no_invocation_id_generator_adds_no_header() {
        use aws_smithy_types::config_bag::Layer;
        let mut ctx = before_transmit_context();
        let mut ctx: BeforeTransmitInterceptorContextMut<'_> = (&mut ctx).into();

        let mut cfg = ConfigBag::base();
        let mut layer = Layer::new("test");
        layer.store_put(SharedInvocationIdGenerator::new(
            NoInvocationIdGenerator::new(),
        ));
        cfg.push_layer(layer);

        run_both_hooks(&mut ctx, &mut cfg);

        assert!(cfg.load::<InvocationId>().is_none());
        assert!(ctx
            .request()
            .headers()
            .get("amz-sdk-invocation-id")
            .is_none());
    }
}