Skip to main content

aws_smithy_runtime/client/
interceptors.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::box_error::BoxError;
7use aws_smithy_runtime_api::client::interceptors::context::{
8    BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
9    FinalizerInterceptorContextMut, FinalizerInterceptorContextRef,
10};
11use aws_smithy_runtime_api::client::interceptors::context::{
12    Error, Input, InterceptorContext, Output,
13};
14use aws_smithy_runtime_api::client::interceptors::{
15    dyn_dispatch_hint, Intercept, InterceptorError, OverriddenHooks, SharedInterceptor,
16};
17use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
18use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::ConfigBag;
21use aws_smithy_types::error::display::DisplayErrorContext;
22use std::error::Error as StdError;
23use std::fmt;
24use std::marker::PhantomData;
25
/// Generates a `pub(crate)` method on `Interceptors` that runs one hook across
/// every interceptor.
///
/// - `mut <hook>, <HINT>` generates a method taking `&mut InterceptorContext`;
///   each interceptor receives a mutable view converted from it via `into()`.
/// - `ref <hook>, <HINT>` generates a method taking `&InterceptorContext`;
///   each interceptor receives a read-only view converted from it via `into()`.
///
/// `<HINT>` names the `OverriddenHooks` constant for the hook. Interceptors that
/// do not override the hook, or that are not enabled in the `ConfigBag`, are
/// skipped. All remaining interceptors run even if one fails. The last failure is
/// returned through `InterceptorError::<hook>`, and each earlier failure it
/// replaces is logged at debug level.
macro_rules! interceptor_impl_fn {
    (mut $interceptor:ident, $hint:ident) => {
        pub(crate) fn $interceptor(
            self,
            ctx: &mut InterceptorContext,
            runtime_components: &RuntimeComponents,
            cfg: &mut ConfigBag,
        ) -> Result<(), InterceptorError> {
            tracing::trace!(concat!(
                "running `",
                stringify!($interceptor),
                "` interceptors"
            ));
            // The concrete view type is inferred from the hook's parameter type.
            let mut view = ctx.into();
            let mut latest: Option<(&str, BoxError)> = None;
            let eligible = self
                .into_iter()
                .filter(|candidate| candidate.hook_overridden(OverriddenHooks::$hint));
            for candidate in eligible {
                let enabled = match candidate.if_enabled(cfg) {
                    Some(enabled) => enabled,
                    None => continue,
                };
                if let Err(error) = enabled.$interceptor(&mut view, runtime_components, cfg) {
                    // Only the newest failure is returned; log the one being replaced.
                    if let Some((previous_name, previous_error)) = latest.take() {
                        tracing::debug!(
                            "{}::{}: {}",
                            previous_name,
                            stringify!($interceptor),
                            DisplayErrorContext(&*previous_error)
                        );
                    }
                    latest = Some((enabled.name(), error));
                }
            }
            match latest {
                Some((name, error)) => Err(InterceptorError::$interceptor(name, error)),
                None => Ok(()),
            }
        }
    };
    (ref $interceptor:ident, $hint:ident) => {
        pub(crate) fn $interceptor(
            self,
            ctx: &InterceptorContext,
            runtime_components: &RuntimeComponents,
            cfg: &mut ConfigBag,
        ) -> Result<(), InterceptorError> {
            tracing::trace!(concat!(
                "running `",
                stringify!($interceptor),
                "` interceptors"
            ));
            // The concrete view type is inferred from the hook's parameter type.
            let view = ctx.into();
            let mut latest: Option<(&str, BoxError)> = None;
            let eligible = self
                .into_iter()
                .filter(|candidate| candidate.hook_overridden(OverriddenHooks::$hint));
            for candidate in eligible {
                let enabled = match candidate.if_enabled(cfg) {
                    Some(enabled) => enabled,
                    None => continue,
                };
                if let Err(error) = enabled.$interceptor(&view, runtime_components, cfg) {
                    // Only the newest failure is returned; log the one being replaced.
                    if let Some((previous_name, previous_error)) = latest.take() {
                        tracing::debug!(
                            "{}::{}: {}",
                            previous_name,
                            stringify!($interceptor),
                            DisplayErrorContext(&*previous_error)
                        );
                    }
                    latest = Some((enabled.name(), error));
                }
            }
            match latest {
                Some((name, error)) => Err(InterceptorError::$interceptor(name, error)),
                None => Ok(()),
            }
        }
    };
}
100
101#[derive(Debug)]
102pub(crate) struct Interceptors<I> {
103    interceptors: I,
104}
105
106impl<I> Interceptors<I>
107where
108    I: Iterator<Item = SharedInterceptor>,
109{
110    pub(crate) fn new(interceptors: I) -> Self {
111        Self { interceptors }
112    }
113
114    fn into_iter(self) -> impl Iterator<Item = ConditionallyEnabledInterceptor> {
115        self.interceptors.map(ConditionallyEnabledInterceptor)
116    }
117
118    pub(crate) fn read_before_execution(
119        self,
120        operation: bool,
121        ctx: &InterceptorContext<Input, Output, Error>,
122        cfg: &mut ConfigBag,
123    ) -> Result<(), InterceptorError> {
124        tracing::trace!(
125            "running {} `read_before_execution` interceptors",
126            if operation { "operation" } else { "client" }
127        );
128        let mut result: Result<(), (&str, BoxError)> = Ok(());
129        let ctx: BeforeSerializationInterceptorContextRef<'_> = ctx.into();
130        for interceptor in self.into_iter() {
131            if interceptor.hook_overridden(OverriddenHooks::READ_BEFORE_EXECUTION) {
132                if let Some(interceptor) = interceptor.if_enabled(cfg) {
133                    if let Err(new_error) = interceptor.read_before_execution(&ctx, cfg) {
134                        if let Err(last_error) = result {
135                            tracing::debug!(
136                                "{}::{}: {}",
137                                last_error.0,
138                                "read_before_execution",
139                                DisplayErrorContext(&*last_error.1)
140                            );
141                        }
142                        result = Err((interceptor.name(), new_error));
143                    }
144                }
145            }
146        }
147        result.map_err(|(name, err)| InterceptorError::read_before_execution(name, err))
148    }
149
150    interceptor_impl_fn!(mut modify_before_serialization, MODIFY_BEFORE_SERIALIZATION);
151    interceptor_impl_fn!(ref read_before_serialization, READ_BEFORE_SERIALIZATION);
152    interceptor_impl_fn!(ref read_after_serialization, READ_AFTER_SERIALIZATION);
153    interceptor_impl_fn!(mut modify_before_retry_loop, MODIFY_BEFORE_RETRY_LOOP);
154    interceptor_impl_fn!(ref read_before_attempt, READ_BEFORE_ATTEMPT);
155    interceptor_impl_fn!(mut modify_before_signing, MODIFY_BEFORE_SIGNING);
156    interceptor_impl_fn!(ref read_before_signing, READ_BEFORE_SIGNING);
157    interceptor_impl_fn!(ref read_after_signing, READ_AFTER_SIGNING);
158    interceptor_impl_fn!(mut modify_before_transmit, MODIFY_BEFORE_TRANSMIT);
159    interceptor_impl_fn!(ref read_before_transmit, READ_BEFORE_TRANSMIT);
160    interceptor_impl_fn!(ref read_after_transmit, READ_AFTER_TRANSMIT);
161    interceptor_impl_fn!(
162        mut modify_before_deserialization,
163        MODIFY_BEFORE_DESERIALIZATION
164    );
165    interceptor_impl_fn!(ref read_before_deserialization, READ_BEFORE_DESERIALIZATION);
166    interceptor_impl_fn!(ref read_after_deserialization, READ_AFTER_DESERIALIZATION);
167
168    pub(crate) fn modify_before_attempt_completion(
169        self,
170        ctx: &mut InterceptorContext<Input, Output, Error>,
171        runtime_components: &RuntimeComponents,
172        cfg: &mut ConfigBag,
173    ) -> Result<(), InterceptorError> {
174        tracing::trace!("running `modify_before_attempt_completion` interceptors");
175        let mut result: Result<(), (&str, BoxError)> = Ok(());
176        let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
177        for interceptor in self.into_iter() {
178            if interceptor.hook_overridden(OverriddenHooks::MODIFY_BEFORE_ATTEMPT_COMPLETION) {
179                if let Some(interceptor) = interceptor.if_enabled(cfg) {
180                    if let Err(new_error) = interceptor.modify_before_attempt_completion(
181                        &mut ctx,
182                        runtime_components,
183                        cfg,
184                    ) {
185                        if let Err(last_error) = result {
186                            tracing::debug!(
187                                "{}::{}: {}",
188                                last_error.0,
189                                "modify_before_attempt_completion",
190                                DisplayErrorContext(&*last_error.1)
191                            );
192                        }
193                        result = Err((interceptor.name(), new_error));
194                    }
195                }
196            }
197        }
198        result.map_err(|(name, err)| InterceptorError::modify_before_attempt_completion(name, err))
199    }
200
201    pub(crate) fn read_after_attempt(
202        self,
203        ctx: &InterceptorContext<Input, Output, Error>,
204        runtime_components: &RuntimeComponents,
205        cfg: &mut ConfigBag,
206    ) -> Result<(), InterceptorError> {
207        tracing::trace!("running `read_after_attempt` interceptors");
208        let mut result: Result<(), (&str, BoxError)> = Ok(());
209        let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
210        for interceptor in self.into_iter() {
211            if interceptor.hook_overridden(OverriddenHooks::READ_AFTER_ATTEMPT) {
212                if let Some(interceptor) = interceptor.if_enabled(cfg) {
213                    if let Err(new_error) =
214                        interceptor.read_after_attempt(&ctx, runtime_components, cfg)
215                    {
216                        if let Err(last_error) = result {
217                            tracing::debug!(
218                                "{}::{}: {}",
219                                last_error.0,
220                                "read_after_attempt",
221                                DisplayErrorContext(&*last_error.1)
222                            );
223                        }
224                        result = Err((interceptor.name(), new_error));
225                    }
226                }
227            }
228        }
229        result.map_err(|(name, err)| InterceptorError::read_after_attempt(name, err))
230    }
231
232    pub(crate) fn modify_before_completion(
233        self,
234        ctx: &mut InterceptorContext<Input, Output, Error>,
235        runtime_components: &RuntimeComponents,
236        cfg: &mut ConfigBag,
237    ) -> Result<(), InterceptorError> {
238        tracing::trace!("running `modify_before_completion` interceptors");
239        let mut result: Result<(), (&str, BoxError)> = Ok(());
240        let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
241        for interceptor in self.into_iter() {
242            if interceptor.hook_overridden(OverriddenHooks::MODIFY_BEFORE_COMPLETION) {
243                if let Some(interceptor) = interceptor.if_enabled(cfg) {
244                    if let Err(new_error) =
245                        interceptor.modify_before_completion(&mut ctx, runtime_components, cfg)
246                    {
247                        if let Err(last_error) = result {
248                            tracing::debug!(
249                                "{}::{}: {}",
250                                last_error.0,
251                                "modify_before_completion",
252                                DisplayErrorContext(&*last_error.1)
253                            );
254                        }
255                        result = Err((interceptor.name(), new_error));
256                    }
257                }
258            }
259        }
260        result.map_err(|(name, err)| InterceptorError::modify_before_completion(name, err))
261    }
262
263    pub(crate) fn read_after_execution(
264        self,
265        ctx: &InterceptorContext<Input, Output, Error>,
266        runtime_components: &RuntimeComponents,
267        cfg: &mut ConfigBag,
268    ) -> Result<(), InterceptorError> {
269        tracing::trace!("running `read_after_execution` interceptors");
270        let mut result: Result<(), (&str, BoxError)> = Ok(());
271        let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
272        for interceptor in self.into_iter() {
273            if interceptor.hook_overridden(OverriddenHooks::READ_AFTER_EXECUTION) {
274                if let Some(interceptor) = interceptor.if_enabled(cfg) {
275                    if let Err(new_error) =
276                        interceptor.read_after_execution(&ctx, runtime_components, cfg)
277                    {
278                        if let Err(last_error) = result {
279                            tracing::debug!(
280                                "{}::{}: {}",
281                                last_error.0,
282                                "read_after_execution",
283                                DisplayErrorContext(&*last_error.1)
284                            );
285                        }
286                        result = Err((interceptor.name(), new_error));
287                    }
288                }
289            }
290        }
291        result.map_err(|(name, err)| InterceptorError::read_after_execution(name, err))
292    }
293}
294
295/// A interceptor wrapper to conditionally enable the interceptor based on
296/// [`DisableInterceptor`](aws_smithy_runtime_api::client::interceptors::DisableInterceptor)
297struct ConditionallyEnabledInterceptor(SharedInterceptor);
298impl ConditionallyEnabledInterceptor {
299    fn if_enabled(&self, cfg: &ConfigBag) -> Option<&dyn Intercept> {
300        if self.0.enabled(cfg) {
301            Some(&self.0)
302        } else {
303            None
304        }
305    }
306
307    fn hook_overridden(&self, hint: OverriddenHooks) -> bool {
308        self.0.overridden_hooks().contains(hint)
309    }
310}
311
/// Interceptor that maps the request with a given function.
///
/// The function receives the request by value and returns either a replacement
/// request or an error of type `E`.
pub struct MapRequestInterceptor<F, E> {
    f: F,
    _phantom: PhantomData<E>,
}

impl<F, E> fmt::Debug for MapRequestInterceptor<F, E> {
    /// Prints only the type name; the wrapped function is opaque.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("MapRequestInterceptor")
    }
}

impl<F, E> MapRequestInterceptor<F, E> {
    /// Creates a new `MapRequestInterceptor`.
    pub fn new(f: F) -> Self {
        MapRequestInterceptor {
            _phantom: PhantomData,
            f,
        }
    }
}
333
334#[dyn_dispatch_hint]
335impl<F, E> Intercept for MapRequestInterceptor<F, E>
336where
337    F: Fn(HttpRequest) -> Result<HttpRequest, E> + Send + Sync + 'static,
338    E: StdError + Send + Sync + 'static,
339{
340    fn name(&self) -> &'static str {
341        "MapRequestInterceptor"
342    }
343
344    fn modify_before_signing(
345        &self,
346        context: &mut BeforeTransmitInterceptorContextMut<'_>,
347        _runtime_components: &RuntimeComponents,
348        _cfg: &mut ConfigBag,
349    ) -> Result<(), BoxError> {
350        let mut request = HttpRequest::new(SdkBody::taken());
351        std::mem::swap(&mut request, context.request_mut());
352        let mut mapped = (self.f)(request)?;
353        std::mem::swap(&mut mapped, context.request_mut());
354
355        Ok(())
356    }
357}
358
/// Interceptor that mutates the request with a given function.
///
/// The function receives a mutable reference to the request and edits it in place.
pub struct MutateRequestInterceptor<F> {
    f: F,
}

impl<F> fmt::Debug for MutateRequestInterceptor<F> {
    /// Prints only the type name; the wrapped function is opaque.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("MutateRequestInterceptor")
    }
}

impl<F> MutateRequestInterceptor<F> {
    /// Creates a new `MutateRequestInterceptor`.
    pub fn new(f: F) -> Self {
        MutateRequestInterceptor { f }
    }
}
376
377#[dyn_dispatch_hint]
378impl<F> Intercept for MutateRequestInterceptor<F>
379where
380    F: Fn(&mut HttpRequest) + Send + Sync + 'static,
381{
382    fn name(&self) -> &'static str {
383        "MutateRequestInterceptor"
384    }
385
386    fn modify_before_signing(
387        &self,
388        context: &mut BeforeTransmitInterceptorContextMut<'_>,
389        _runtime_components: &RuntimeComponents,
390        _cfg: &mut ConfigBag,
391    ) -> Result<(), BoxError> {
392        let request = context.request_mut();
393        (self.f)(request);
394
395        Ok(())
396    }
397}
398
#[cfg(all(test, feature = "test-util"))]
mod tests {
    use super::*;
    use aws_smithy_runtime_api::box_error::BoxError;
    use aws_smithy_runtime_api::client::interceptors::context::{
        BeforeTransmitInterceptorContextRef, Input, InterceptorContext,
    };
    use aws_smithy_runtime_api::client::interceptors::{
        disable_interceptor, Intercept, SharedInterceptor,
    };
    use aws_smithy_runtime_api::client::runtime_components::{
        RuntimeComponents, RuntimeComponentsBuilder,
    };
    use aws_smithy_types::config_bag::ConfigBag;

    /// An interceptor that implements no hooks, only a name.
    #[derive(Debug)]
    struct TestInterceptor;
    impl Intercept for TestInterceptor {
        fn name(&self) -> &'static str {
            "TestInterceptor"
        }
    }

    /// Number of interceptors registered on `rc` that are enabled under `cfg`.
    fn count_enabled(rc: &RuntimeComponents, cfg: &ConfigBag) -> usize {
        Interceptors::new(rc.interceptors())
            .into_iter()
            .filter(|interceptor| interceptor.if_enabled(cfg).is_some())
            .count()
    }

    /// Runs the `read_before_transmit` hook of every interceptor on `rc`
    /// against a throwaway context.
    fn run_read_before_transmit(
        rc: &RuntimeComponents,
        cfg: &mut ConfigBag,
    ) -> Result<(), InterceptorError> {
        let ctx = InterceptorContext::new(Input::doesnt_matter());
        Interceptors::new(rc.interceptors()).read_before_transmit(&ctx, rc, cfg)
    }

    #[test]
    fn test_disable_interceptors() {
        /// An interceptor whose `read_before_transmit` hook always fails.
        #[derive(Debug)]
        struct PanicInterceptor;
        impl Intercept for PanicInterceptor {
            fn name(&self) -> &'static str {
                "PanicInterceptor"
            }

            fn read_before_transmit(
                &self,
                _context: &BeforeTransmitInterceptorContextRef<'_>,
                _rc: &RuntimeComponents,
                _cfg: &mut ConfigBag,
            ) -> Result<(), BoxError> {
                Err("boom".into())
            }
        }

        let rc = RuntimeComponentsBuilder::for_tests()
            .with_interceptor(SharedInterceptor::new(PanicInterceptor))
            .with_interceptor(SharedInterceptor::new(TestInterceptor))
            .build()
            .unwrap();
        let mut cfg = ConfigBag::base();

        // Both interceptors start enabled, so the failing hook surfaces its error.
        assert_eq!(count_enabled(&rc, &cfg), 2);
        run_read_before_transmit(&rc, &mut cfg).expect_err("interceptor returns error");

        // Once the failing interceptor is disabled it no longer runs at all.
        cfg.interceptor_state()
            .store_put(disable_interceptor::<PanicInterceptor>("test"));
        assert_eq!(count_enabled(&rc, &cfg), 1);
        // shouldn't error because interceptors won't run
        run_read_before_transmit(&rc, &mut cfg).expect("interceptor is now disabled");
    }
}