aws_runtime/
invocation_id.rs1use std::fmt::Debug;
7use std::sync::{Arc, Mutex};
8
9use fastrand::Rng;
10use http_1x::{HeaderName, HeaderValue};
11
12use aws_smithy_runtime_api::box_error::BoxError;
13use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut;
14use aws_smithy_runtime_api::client::interceptors::{dyn_dispatch_hint, Intercept};
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
17#[cfg(feature = "test-util")]
18pub use test_util::{NoInvocationIdGenerator, PredefinedInvocationIdGenerator};
19
20#[allow(clippy::declare_interior_mutable_const)] const AMZ_SDK_INVOCATION_ID: HeaderName = HeaderName::from_static("amz-sdk-invocation-id");
22
23pub trait InvocationIdGenerator: Debug + Send + Sync {
25 fn generate(&self) -> Result<Option<InvocationId>, BoxError>;
28}
29
30#[derive(Clone, Debug)]
32pub struct SharedInvocationIdGenerator(Arc<dyn InvocationIdGenerator>);
33
34impl SharedInvocationIdGenerator {
35 pub fn new(gen: impl InvocationIdGenerator + 'static) -> Self {
37 Self(Arc::new(gen))
38 }
39}
40
41impl InvocationIdGenerator for SharedInvocationIdGenerator {
42 fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
43 self.0.generate()
44 }
45}
46
47impl Storable for SharedInvocationIdGenerator {
48 type Storer = StoreReplace<Self>;
49}
50
51#[derive(Debug, Default)]
53pub struct DefaultInvocationIdGenerator {
54 rng: Mutex<Rng>,
55}
56
57impl DefaultInvocationIdGenerator {
58 pub fn new() -> Self {
60 Default::default()
61 }
62
63 pub fn with_seed(seed: u64) -> Self {
65 Self {
66 rng: Mutex::new(Rng::with_seed(seed)),
67 }
68 }
69}
70
71impl InvocationIdGenerator for DefaultInvocationIdGenerator {
72 fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
73 let mut rng = self.rng.lock().unwrap();
74 let mut random_bytes = [0u8; 16];
75 rng.fill(&mut random_bytes);
76
77 let id = uuid::Builder::from_random_bytes(random_bytes).into_uuid();
78 Ok(Some(InvocationId::new(id.to_string())))
79 }
80}
81
82#[non_exhaustive]
84#[derive(Debug, Default)]
85pub struct InvocationIdInterceptor {
86 default: DefaultInvocationIdGenerator,
87}
88
89impl InvocationIdInterceptor {
90 pub fn new() -> Self {
92 Self::default()
93 }
94}
95
96#[dyn_dispatch_hint]
97impl Intercept for InvocationIdInterceptor {
98 fn name(&self) -> &'static str {
99 "InvocationIdInterceptor"
100 }
101
102 fn modify_before_retry_loop(
103 &self,
104 _ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
105 _runtime_components: &RuntimeComponents,
106 cfg: &mut ConfigBag,
107 ) -> Result<(), BoxError> {
108 let gen = cfg
109 .load::<SharedInvocationIdGenerator>()
110 .map(|gen| gen as &dyn InvocationIdGenerator)
111 .unwrap_or(&self.default);
112 if let Some(id) = gen.generate()? {
113 cfg.interceptor_state().store_put::<InvocationId>(id);
114 }
115
116 Ok(())
117 }
118
119 fn modify_before_transmit(
120 &self,
121 ctx: &mut BeforeTransmitInterceptorContextMut<'_>,
122 _runtime_components: &RuntimeComponents,
123 cfg: &mut ConfigBag,
124 ) -> Result<(), BoxError> {
125 let headers = ctx.request_mut().headers_mut();
126 if let Some(id) = cfg.load::<InvocationId>() {
127 headers.append(AMZ_SDK_INVOCATION_ID, id.0.clone());
128 }
129 Ok(())
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
135pub struct InvocationId(HeaderValue);
136
137impl InvocationId {
138 pub fn new(invocation_id: String) -> Self {
143 Self(
144 HeaderValue::try_from(invocation_id)
145 .expect("invocation ID must be a valid HTTP header value"),
146 )
147 }
148}
149
150impl Storable for InvocationId {
151 type Storer = StoreReplace<Self>;
152}
153
#[cfg(feature = "test-util")]
mod test_util {
    use super::*;

    impl InvocationId {
        /// Builds an invocation ID from a static string, for use in tests.
        ///
        /// # Panics
        /// Panics (via `HeaderValue::from_static`) if `uuid` is not a valid header value.
        pub fn new_from_str(uuid: &'static str) -> Self {
            Self(HeaderValue::from_static(uuid))
        }
    }

    /// Test generator that hands out a fixed list of invocation IDs, in order.
    #[derive(Debug)]
    pub struct PredefinedInvocationIdGenerator {
        // Kept in reverse order so that `Vec::pop` yields the IDs front-to-back.
        pre_generated_ids: Arc<Mutex<Vec<InvocationId>>>,
    }

    impl PredefinedInvocationIdGenerator {
        /// Creates a generator that returns `invocation_ids` in the order given.
        pub fn new(invocation_ids: Vec<InvocationId>) -> Self {
            let reversed: Vec<InvocationId> = invocation_ids.into_iter().rev().collect();
            Self {
                pre_generated_ids: Arc::new(Mutex::new(reversed)),
            }
        }
    }

    impl InvocationIdGenerator for PredefinedInvocationIdGenerator {
        /// Returns the next predefined ID.
        ///
        /// # Panics
        /// Panics once every predefined ID has been handed out.
        fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
            let mut remaining = self
                .pre_generated_ids
                .lock()
                .expect("this will never be under contention");
            let next = remaining
                .pop()
                .expect("testers will provide enough invocation IDs");
            Ok(Some(next))
        }
    }

    /// Test generator that never produces an invocation ID, so no header is sent.
    #[derive(Debug, Default)]
    pub struct NoInvocationIdGenerator;

    impl NoInvocationIdGenerator {
        /// Creates the generator.
        pub fn new() -> Self {
            NoInvocationIdGenerator
        }
    }

    impl InvocationIdGenerator for NoInvocationIdGenerator {
        fn generate(&self) -> Result<Option<InvocationId>, BoxError> {
            Ok(None)
        }
    }
}
213
#[cfg(test)]
mod tests {
    use aws_smithy_runtime_api::client::interceptors::context::{
        BeforeTransmitInterceptorContextMut, Input, InterceptorContext,
    };
    use aws_smithy_runtime_api::client::interceptors::Intercept;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::config_bag::ConfigBag;

    use super::*;

    /// Builds an interceptor context that holds an empty request and has entered the
    /// before-transmit phase, ready for the interceptor hooks under test.
    fn before_transmit_context() -> InterceptorContext {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.enter_serialization_phase();
        ctx.set_request(HttpRequest::empty());
        let _ = ctx.take_input();
        ctx.enter_before_transmit_phase();
        ctx
    }

    /// Returns the value of `header_name`, panicking if the request lacks it.
    fn expect_header<'a>(
        context: &'a BeforeTransmitInterceptorContextMut<'_>,
        header_name: &str,
    ) -> &'a str {
        context.request().headers().get(header_name).unwrap()
    }

    #[test]
    fn default_id_generator() {
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = before_transmit_context();

        let mut cfg = ConfigBag::base();
        let interceptor = InvocationIdInterceptor::new();
        let mut ctx = Into::into(&mut ctx);
        interceptor
            .modify_before_retry_loop(&mut ctx, &rc, &mut cfg)
            .unwrap();
        interceptor
            .modify_before_transmit(&mut ctx, &rc, &mut cfg)
            .unwrap();

        let expected = cfg.load::<InvocationId>().expect("invocation ID was set");
        let header = expect_header(&ctx, "amz-sdk-invocation-id");
        assert_eq!(expected.0, header, "the invocation ID in the config bag must match the invocation ID in the request header");
        // A hyphenated UUID is 36 characters long.
        assert_eq!(header.len(), 36);
    }

    #[test]
    fn seeded_default_generator_is_deterministic() {
        let first = DefaultInvocationIdGenerator::with_seed(7);
        let second = DefaultInvocationIdGenerator::with_seed(7);

        let first_id = first
            .generate()
            .unwrap()
            .expect("the default generator always produces an ID");
        let second_id = second
            .generate()
            .unwrap()
            .expect("the default generator always produces an ID");
        assert_eq!(
            first_id, second_id,
            "generators with the same seed must produce the same IDs"
        );

        let next_id = first
            .generate()
            .unwrap()
            .expect("the default generator always produces an ID");
        assert_ne!(first_id, next_id, "successive IDs from one generator must differ");
    }

    #[cfg(feature = "test-util")]
    #[test]
    fn custom_id_generator() {
        use aws_smithy_types::config_bag::Layer;
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = before_transmit_context();

        let mut cfg = ConfigBag::base();
        let mut layer = Layer::new("test");
        layer.store_put(SharedInvocationIdGenerator::new(
            PredefinedInvocationIdGenerator::new(vec![InvocationId::new(
                "the-best-invocation-id".into(),
            )]),
        ));
        cfg.push_layer(layer);

        let interceptor = InvocationIdInterceptor::new();
        let mut ctx = Into::into(&mut ctx);
        interceptor
            .modify_before_retry_loop(&mut ctx, &rc, &mut cfg)
            .unwrap();
        interceptor
            .modify_before_transmit(&mut ctx, &rc, &mut cfg)
            .unwrap();

        let header = expect_header(&ctx, "amz-sdk-invocation-id");
        assert_eq!("the-best-invocation-id", header);
    }

    #[cfg(feature = "test-util")]
    #[test]
    fn no_id_generator_sets_no_header() {
        use aws_smithy_types::config_bag::Layer;
        let rc = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let mut ctx = before_transmit_context();

        let mut cfg = ConfigBag::base();
        let mut layer = Layer::new("test");
        layer.store_put(SharedInvocationIdGenerator::new(
            NoInvocationIdGenerator::new(),
        ));
        cfg.push_layer(layer);

        let interceptor = InvocationIdInterceptor::new();
        let mut ctx = Into::into(&mut ctx);
        interceptor
            .modify_before_retry_loop(&mut ctx, &rc, &mut cfg)
            .unwrap();
        interceptor
            .modify_before_transmit(&mut ctx, &rc, &mut cfg)
            .unwrap();

        assert!(cfg.load::<InvocationId>().is_none());
        assert!(ctx
            .request()
            .headers()
            .get("amz-sdk-invocation-id")
            .is_none());
    }
}