1use aws_smithy_runtime_api::box_error::BoxError;
7use aws_smithy_runtime_api::client::interceptors::context::{
8 BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
9 FinalizerInterceptorContextMut, FinalizerInterceptorContextRef,
10};
11use aws_smithy_runtime_api::client::interceptors::context::{
12 Error, Input, InterceptorContext, Output,
13};
14use aws_smithy_runtime_api::client::interceptors::{
15 dyn_dispatch_hint, Intercept, InterceptorError, OverriddenHooks, SharedInterceptor,
16};
17use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
18use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::ConfigBag;
21use aws_smithy_types::error::display::DisplayErrorContext;
22use std::error::Error as StdError;
23use std::fmt;
24use std::marker::PhantomData;
25
// Generates an `Interceptors` method that runs a single hook across every interceptor.
//
// `mut` arms produce a method taking `&mut InterceptorContext`; `ref` arms produce one taking
// `&InterceptorContext`. Either way the context is converted once (via `Into`) into the
// hook-specific context type expected by `Intercept::$interceptor`. Each interceptor is then
// visited in order: it is invoked only if it overrides the hook (`OverriddenHooks::$hint`) and
// is enabled for `cfg`. A failure does not stop the remaining interceptors. The most recent
// error is kept, and any error it displaces is logged at debug level.
macro_rules! interceptor_impl_fn {
    (mut $interceptor:ident, $hint:ident) => {
        pub(crate) fn $interceptor(
            self,
            ctx: &mut InterceptorContext,
            runtime_components: &RuntimeComponents,
            cfg: &mut ConfigBag,
        ) -> Result<(), InterceptorError> {
            tracing::trace!(concat!(
                "running `",
                stringify!($interceptor),
                "` interceptors"
            ));
            let mut ctx = ctx.into();
            let mut outcome: Result<(), (&str, BoxError)> = Ok(());
            let hooked = self
                .into_iter()
                .filter(|candidate| candidate.hook_overridden(OverriddenHooks::$hint));
            for candidate in hooked {
                if let Some(enabled) = candidate.if_enabled(cfg) {
                    match enabled.$interceptor(&mut ctx, runtime_components, cfg) {
                        Ok(()) => {}
                        Err(err) => {
                            // Only the newest error is returned; surface the one being replaced.
                            if let Err((prev_name, prev_err)) = outcome {
                                tracing::debug!(
                                    "{}::{}: {}",
                                    prev_name,
                                    stringify!($interceptor),
                                    DisplayErrorContext(&*prev_err)
                                );
                            }
                            outcome = Err((enabled.name(), err));
                        }
                    }
                }
            }
            outcome.map_err(|(name, err)| InterceptorError::$interceptor(name, err))
        }
    };
    (ref $interceptor:ident, $hint:ident) => {
        pub(crate) fn $interceptor(
            self,
            ctx: &InterceptorContext,
            runtime_components: &RuntimeComponents,
            cfg: &mut ConfigBag,
        ) -> Result<(), InterceptorError> {
            tracing::trace!(concat!(
                "running `",
                stringify!($interceptor),
                "` interceptors"
            ));
            let ctx = ctx.into();
            let mut outcome: Result<(), (&str, BoxError)> = Ok(());
            let hooked = self
                .into_iter()
                .filter(|candidate| candidate.hook_overridden(OverriddenHooks::$hint));
            for candidate in hooked {
                if let Some(enabled) = candidate.if_enabled(cfg) {
                    match enabled.$interceptor(&ctx, runtime_components, cfg) {
                        Ok(()) => {}
                        Err(err) => {
                            // Only the newest error is returned; surface the one being replaced.
                            if let Err((prev_name, prev_err)) = outcome {
                                tracing::debug!(
                                    "{}::{}: {}",
                                    prev_name,
                                    stringify!($interceptor),
                                    DisplayErrorContext(&*prev_err)
                                );
                            }
                            outcome = Err((enabled.name(), err));
                        }
                    }
                }
            }
            outcome.map_err(|(name, err)| InterceptorError::$interceptor(name, err))
        }
    };
}
100
/// A sequence of interceptors whose hooks can be run in order.
///
/// `I` is the source of the interceptors; the hook-running methods in this file take `self`
/// by value and consume it, so a new `Interceptors` is constructed for each hook run.
#[derive(Debug)]
pub(crate) struct Interceptors<I> {
    // The underlying interceptor source, consumed when a hook is run.
    interceptors: I,
}
105
106impl<I> Interceptors<I>
107where
108 I: Iterator<Item = SharedInterceptor>,
109{
110 pub(crate) fn new(interceptors: I) -> Self {
111 Self { interceptors }
112 }
113
114 fn into_iter(self) -> impl Iterator<Item = ConditionallyEnabledInterceptor> {
115 self.interceptors.map(ConditionallyEnabledInterceptor)
116 }
117
118 pub(crate) fn read_before_execution(
119 self,
120 operation: bool,
121 ctx: &InterceptorContext<Input, Output, Error>,
122 cfg: &mut ConfigBag,
123 ) -> Result<(), InterceptorError> {
124 tracing::trace!(
125 "running {} `read_before_execution` interceptors",
126 if operation { "operation" } else { "client" }
127 );
128 let mut result: Result<(), (&str, BoxError)> = Ok(());
129 let ctx: BeforeSerializationInterceptorContextRef<'_> = ctx.into();
130 for interceptor in self.into_iter() {
131 if interceptor.hook_overridden(OverriddenHooks::READ_BEFORE_EXECUTION) {
132 if let Some(interceptor) = interceptor.if_enabled(cfg) {
133 if let Err(new_error) = interceptor.read_before_execution(&ctx, cfg) {
134 if let Err(last_error) = result {
135 tracing::debug!(
136 "{}::{}: {}",
137 last_error.0,
138 "read_before_execution",
139 DisplayErrorContext(&*last_error.1)
140 );
141 }
142 result = Err((interceptor.name(), new_error));
143 }
144 }
145 }
146 }
147 result.map_err(|(name, err)| InterceptorError::read_before_execution(name, err))
148 }
149
150 interceptor_impl_fn!(mut modify_before_serialization, MODIFY_BEFORE_SERIALIZATION);
151 interceptor_impl_fn!(ref read_before_serialization, READ_BEFORE_SERIALIZATION);
152 interceptor_impl_fn!(ref read_after_serialization, READ_AFTER_SERIALIZATION);
153 interceptor_impl_fn!(mut modify_before_retry_loop, MODIFY_BEFORE_RETRY_LOOP);
154 interceptor_impl_fn!(ref read_before_attempt, READ_BEFORE_ATTEMPT);
155 interceptor_impl_fn!(mut modify_before_signing, MODIFY_BEFORE_SIGNING);
156 interceptor_impl_fn!(ref read_before_signing, READ_BEFORE_SIGNING);
157 interceptor_impl_fn!(ref read_after_signing, READ_AFTER_SIGNING);
158 interceptor_impl_fn!(mut modify_before_transmit, MODIFY_BEFORE_TRANSMIT);
159 interceptor_impl_fn!(ref read_before_transmit, READ_BEFORE_TRANSMIT);
160 interceptor_impl_fn!(ref read_after_transmit, READ_AFTER_TRANSMIT);
161 interceptor_impl_fn!(
162 mut modify_before_deserialization,
163 MODIFY_BEFORE_DESERIALIZATION
164 );
165 interceptor_impl_fn!(ref read_before_deserialization, READ_BEFORE_DESERIALIZATION);
166 interceptor_impl_fn!(ref read_after_deserialization, READ_AFTER_DESERIALIZATION);
167
168 pub(crate) fn modify_before_attempt_completion(
169 self,
170 ctx: &mut InterceptorContext<Input, Output, Error>,
171 runtime_components: &RuntimeComponents,
172 cfg: &mut ConfigBag,
173 ) -> Result<(), InterceptorError> {
174 tracing::trace!("running `modify_before_attempt_completion` interceptors");
175 let mut result: Result<(), (&str, BoxError)> = Ok(());
176 let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
177 for interceptor in self.into_iter() {
178 if interceptor.hook_overridden(OverriddenHooks::MODIFY_BEFORE_ATTEMPT_COMPLETION) {
179 if let Some(interceptor) = interceptor.if_enabled(cfg) {
180 if let Err(new_error) = interceptor.modify_before_attempt_completion(
181 &mut ctx,
182 runtime_components,
183 cfg,
184 ) {
185 if let Err(last_error) = result {
186 tracing::debug!(
187 "{}::{}: {}",
188 last_error.0,
189 "modify_before_attempt_completion",
190 DisplayErrorContext(&*last_error.1)
191 );
192 }
193 result = Err((interceptor.name(), new_error));
194 }
195 }
196 }
197 }
198 result.map_err(|(name, err)| InterceptorError::modify_before_attempt_completion(name, err))
199 }
200
201 pub(crate) fn read_after_attempt(
202 self,
203 ctx: &InterceptorContext<Input, Output, Error>,
204 runtime_components: &RuntimeComponents,
205 cfg: &mut ConfigBag,
206 ) -> Result<(), InterceptorError> {
207 tracing::trace!("running `read_after_attempt` interceptors");
208 let mut result: Result<(), (&str, BoxError)> = Ok(());
209 let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
210 for interceptor in self.into_iter() {
211 if interceptor.hook_overridden(OverriddenHooks::READ_AFTER_ATTEMPT) {
212 if let Some(interceptor) = interceptor.if_enabled(cfg) {
213 if let Err(new_error) =
214 interceptor.read_after_attempt(&ctx, runtime_components, cfg)
215 {
216 if let Err(last_error) = result {
217 tracing::debug!(
218 "{}::{}: {}",
219 last_error.0,
220 "read_after_attempt",
221 DisplayErrorContext(&*last_error.1)
222 );
223 }
224 result = Err((interceptor.name(), new_error));
225 }
226 }
227 }
228 }
229 result.map_err(|(name, err)| InterceptorError::read_after_attempt(name, err))
230 }
231
232 pub(crate) fn modify_before_completion(
233 self,
234 ctx: &mut InterceptorContext<Input, Output, Error>,
235 runtime_components: &RuntimeComponents,
236 cfg: &mut ConfigBag,
237 ) -> Result<(), InterceptorError> {
238 tracing::trace!("running `modify_before_completion` interceptors");
239 let mut result: Result<(), (&str, BoxError)> = Ok(());
240 let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
241 for interceptor in self.into_iter() {
242 if interceptor.hook_overridden(OverriddenHooks::MODIFY_BEFORE_COMPLETION) {
243 if let Some(interceptor) = interceptor.if_enabled(cfg) {
244 if let Err(new_error) =
245 interceptor.modify_before_completion(&mut ctx, runtime_components, cfg)
246 {
247 if let Err(last_error) = result {
248 tracing::debug!(
249 "{}::{}: {}",
250 last_error.0,
251 "modify_before_completion",
252 DisplayErrorContext(&*last_error.1)
253 );
254 }
255 result = Err((interceptor.name(), new_error));
256 }
257 }
258 }
259 }
260 result.map_err(|(name, err)| InterceptorError::modify_before_completion(name, err))
261 }
262
263 pub(crate) fn read_after_execution(
264 self,
265 ctx: &InterceptorContext<Input, Output, Error>,
266 runtime_components: &RuntimeComponents,
267 cfg: &mut ConfigBag,
268 ) -> Result<(), InterceptorError> {
269 tracing::trace!("running `read_after_execution` interceptors");
270 let mut result: Result<(), (&str, BoxError)> = Ok(());
271 let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
272 for interceptor in self.into_iter() {
273 if interceptor.hook_overridden(OverriddenHooks::READ_AFTER_EXECUTION) {
274 if let Some(interceptor) = interceptor.if_enabled(cfg) {
275 if let Err(new_error) =
276 interceptor.read_after_execution(&ctx, runtime_components, cfg)
277 {
278 if let Err(last_error) = result {
279 tracing::debug!(
280 "{}::{}: {}",
281 last_error.0,
282 "read_after_execution",
283 DisplayErrorContext(&*last_error.1)
284 );
285 }
286 result = Err((interceptor.name(), new_error));
287 }
288 }
289 }
290 }
291 result.map_err(|(name, err)| InterceptorError::read_after_execution(name, err))
292 }
293}
294
295struct ConditionallyEnabledInterceptor(SharedInterceptor);
298impl ConditionallyEnabledInterceptor {
299 fn if_enabled(&self, cfg: &ConfigBag) -> Option<&dyn Intercept> {
300 if self.0.enabled(cfg) {
301 Some(&self.0)
302 } else {
303 None
304 }
305 }
306
307 fn hook_overridden(&self, hint: OverriddenHooks) -> bool {
308 self.0.overridden_hooks().contains(hint)
309 }
310}
311
312pub struct MapRequestInterceptor<F, E> {
314 f: F,
315 _phantom: PhantomData<E>,
316}
317
318impl<F, E> fmt::Debug for MapRequestInterceptor<F, E> {
319 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320 write!(f, "MapRequestInterceptor")
321 }
322}
323
324impl<F, E> MapRequestInterceptor<F, E> {
325 pub fn new(f: F) -> Self {
327 Self {
328 f,
329 _phantom: PhantomData,
330 }
331 }
332}
333
334#[dyn_dispatch_hint]
335impl<F, E> Intercept for MapRequestInterceptor<F, E>
336where
337 F: Fn(HttpRequest) -> Result<HttpRequest, E> + Send + Sync + 'static,
338 E: StdError + Send + Sync + 'static,
339{
340 fn name(&self) -> &'static str {
341 "MapRequestInterceptor"
342 }
343
344 fn modify_before_signing(
345 &self,
346 context: &mut BeforeTransmitInterceptorContextMut<'_>,
347 _runtime_components: &RuntimeComponents,
348 _cfg: &mut ConfigBag,
349 ) -> Result<(), BoxError> {
350 let mut request = HttpRequest::new(SdkBody::taken());
351 std::mem::swap(&mut request, context.request_mut());
352 let mut mapped = (self.f)(request)?;
353 std::mem::swap(&mut mapped, context.request_mut());
354
355 Ok(())
356 }
357}
358
359pub struct MutateRequestInterceptor<F> {
361 f: F,
362}
363
364impl<F> fmt::Debug for MutateRequestInterceptor<F> {
365 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366 write!(f, "MutateRequestInterceptor")
367 }
368}
369
370impl<F> MutateRequestInterceptor<F> {
371 pub fn new(f: F) -> Self {
373 Self { f }
374 }
375}
376
377#[dyn_dispatch_hint]
378impl<F> Intercept for MutateRequestInterceptor<F>
379where
380 F: Fn(&mut HttpRequest) + Send + Sync + 'static,
381{
382 fn name(&self) -> &'static str {
383 "MutateRequestInterceptor"
384 }
385
386 fn modify_before_signing(
387 &self,
388 context: &mut BeforeTransmitInterceptorContextMut<'_>,
389 _runtime_components: &RuntimeComponents,
390 _cfg: &mut ConfigBag,
391 ) -> Result<(), BoxError> {
392 let request = context.request_mut();
393 (self.f)(request);
394
395 Ok(())
396 }
397}
398
#[cfg(all(test, feature = "test-util"))]
mod tests {
    use super::*;
    use aws_smithy_runtime_api::box_error::BoxError;
    use aws_smithy_runtime_api::client::interceptors::context::{
        BeforeTransmitInterceptorContextRef, Input, InterceptorContext,
    };
    use aws_smithy_runtime_api::client::interceptors::{
        disable_interceptor, Intercept, SharedInterceptor,
    };
    use aws_smithy_runtime_api::client::runtime_components::{
        RuntimeComponents, RuntimeComponentsBuilder,
    };
    use aws_smithy_types::config_bag::ConfigBag;

    /// Interceptor that defines no hook methods of its own.
    #[derive(Debug)]
    struct TestInterceptor;

    impl Intercept for TestInterceptor {
        fn name(&self) -> &'static str {
            "TestInterceptor"
        }
    }

    #[test]
    fn test_disable_interceptors() {
        /// Interceptor whose `read_before_transmit` hook always fails.
        #[derive(Debug)]
        struct PanicInterceptor;

        impl Intercept for PanicInterceptor {
            fn name(&self) -> &'static str {
                "PanicInterceptor"
            }

            fn read_before_transmit(
                &self,
                _context: &BeforeTransmitInterceptorContextRef<'_>,
                _rc: &RuntimeComponents,
                _cfg: &mut ConfigBag,
            ) -> Result<(), BoxError> {
                Err("boom".into())
            }
        }

        let rc = RuntimeComponentsBuilder::for_tests()
            .with_interceptor(SharedInterceptor::new(PanicInterceptor))
            .with_interceptor(SharedInterceptor::new(TestInterceptor))
            .build()
            .unwrap();
        let mut cfg = ConfigBag::base();

        // How many registered interceptors currently report themselves enabled.
        let count_enabled = |cfg: &ConfigBag| {
            Interceptors::new(rc.interceptors())
                .into_iter()
                .filter(|i| i.if_enabled(cfg).is_some())
                .count()
        };
        // Runs the `read_before_transmit` hook across every registered interceptor.
        let run_read_before_transmit = |cfg: &mut ConfigBag| {
            Interceptors::new(rc.interceptors()).read_before_transmit(
                &InterceptorContext::new(Input::doesnt_matter()),
                &rc,
                cfg,
            )
        };

        assert_eq!(count_enabled(&cfg), 2);
        run_read_before_transmit(&mut cfg).expect_err("interceptor returns error");

        cfg.interceptor_state()
            .store_put(disable_interceptor::<PanicInterceptor>("test"));

        assert_eq!(count_enabled(&cfg), 1);
        run_read_before_transmit(&mut cfg).expect("interceptor is now disabled");
    }
}