layered/
intercept.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use std::fmt::Debug;
5use std::ops::ControlFlow;
6#[cfg(any(feature = "tower-service", test))]
7use std::pin::Pin;
8use std::sync::Arc;
9#[cfg(any(feature = "tower-service", test))]
10use std::task::{Context, Poll};
11
12use crate::Service;
13
14/// Middleware for observing and modifying service inputs and outputs.
15///
16/// Useful for logging, debugging, metrics, validation, and other cross-cutting concerns.
17///
18/// # Examples
19///
20/// Simple usage that observes inputs and outputs without modification:
21///
22/// ```
23/// # use layered::{Execute, Stack, Intercept, Service};
24/// # async fn example() {
25/// let stack = (
26///     Intercept::layer()
27///         .on_input(|input| println!("request: {input}"))
28///         .on_output(|output| println!("response: {output}")),
29///     Execute::new(|input: String| async move { input }),
30/// );
31///
32/// let service = stack.into_service();
33/// let response = service.execute("input".to_string()).await;
34/// # }
35/// ```
36///
37/// Advanced usage of `Intercept` allows you to modify and observe inputs and outputs:
38///
39/// ```
40/// # use layered::{Execute, Stack, Intercept, Service};
41/// # async fn example() {
42/// let stack = (
43///     Intercept::<String, String, _>::layer()
44///         .on_input(|input| println!("request: {input}")) // input observers are called first
45///         .on_input(|input| println!("another: {input}")) // multiple observers supported
46///         .modify_input(|input| input.to_uppercase()) // then inputs are modified
47///         .modify_input(|input| input.to_lowercase()) // multiple modifications supported
48///         .on_output(|output| println!("response: {output}")) // output observers called first
49///         .on_output(|output| println!("another response: {output}")) // multiple observers supported
50///         .modify_output(|output| output.trim().to_string()) // then outputs are modified
51///         .modify_output(|output| format!("result: {output}")), // multiple modifications supported
52///     Execute::new(|input: String| async move { input }),
53/// );
54///
55/// let service = stack.into_service();
56/// let response = service.execute("input".to_string()).await;
57/// # }
58/// ```
59pub struct Intercept<In, Out, S> {
60    inner: Arc<InterceptInner<In, Out>>,
61    service: S,
62}
63
64impl<In, Out, S: Clone> Clone for Intercept<In, Out, S> {
65    fn clone(&self) -> Self {
66        Self {
67            inner: Arc::clone(&self.inner),
68            service: self.service.clone(),
69        }
70    }
71}
72
73impl<In, Out, S: Debug> Debug for Intercept<In, Out, S> {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("Intercept").field("service", &self.service).finish_non_exhaustive()
76    }
77}
78
79/// Builder for creating `Intercept` middleware.
80///
81/// Provides a fluent API for configuring input and output observers and modifiers.
82/// Create with `Intercept::layer()`.
83///
84/// # Examples
85///
86/// ```
87/// # use layered::{Execute, Stack, Intercept, Service};
88/// # async fn example() {
89/// let stack = (
90///     Intercept::layer(), // Create a new interception layer
91///     Execute::new(|input: String| async move { input }),
92/// );
93///
94/// let service = stack.into_service();
95/// let response = service.execute("input".to_string()).await;
96/// # }
97/// ```
98#[derive(Clone)]
99pub struct InterceptLayer<In, Out> {
100    on_input: Vec<OnInput<In>>,
101    modify_input: Vec<ModifyInput<In, Out>>,
102    modify_output: Vec<ModifyOutput<Out>>,
103    on_output: Vec<OnOutput<Out>>,
104}
105
106impl<In, Out> Debug for InterceptLayer<In, Out> {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        f.debug_struct("InterceptLayer")
109            .field("on_input", &self.on_input.len())
110            .field("modify_input", &self.modify_input.len())
111            .field("modify_output", &self.modify_output.len())
112            .field("on_output", &self.on_output.len())
113            .finish_non_exhaustive()
114    }
115}
116
117impl<In, Out> Intercept<In, Out, ()> {
118    /// Creates a new `InterceptLayer` for building interception middleware.
119    ///
120    /// # Examples
121    ///
122    /// ```
123    /// # use layered::{Execute, Stack, Intercept, Service};
124    /// # async fn example() {
125    /// let stack = (
126    ///     Intercept::layer(), // Create a new interception layer, no observers yet
127    ///     Execute::new(|input: String| async move { input }),
128    /// );
129    ///
130    /// let service = stack.into_service();
131    /// let response = service.execute("input".to_string()).await;
132    /// # }
133    /// ```
134    #[must_use]
135    pub fn layer() -> InterceptLayer<In, Out> {
136        InterceptLayer {
137            on_input: Vec::default(),
138            modify_input: Vec::default(),
139            modify_output: Vec::default(),
140            on_output: Vec::default(),
141        }
142    }
143}
144
145impl<In: Send, Out, S> Service<In> for Intercept<In, Out, S>
146where
147    S: Service<In, Out = Out>,
148{
149    type Out = Out;
150
151    /// Executes the wrapped service with interception and modification.
152    ///
153    /// Execution order: input observers → input modifications → service execution
154    /// → output observers → output modifications. Input modifications can short-circuit
155    /// execution by returning `ControlFlow::Break`.
156    async fn execute(&self, mut input: In) -> Self::Out {
157        match self.inner.before_execute(input) {
158            ControlFlow::Break(output) => return output,
159            ControlFlow::Continue(new_input) => input = new_input,
160        }
161
162        let output = self.service.execute(input).await;
163
164        self.inner.after_execute(output)
165    }
166}
167
/// Future produced by [`Intercept`] when it is driven through the tower
/// [`Service`](tower_service::Service) interface.
#[cfg(any(feature = "tower-service", test))]
pub struct InterceptFuture<Out> {
    /// Type-erased, heap-pinned future that resolves to the final output.
    inner: Pin<Box<dyn Future<Output = Out> + Send>>,
}
173
#[cfg(any(feature = "tower-service", test))]
impl<Out> Debug for InterceptFuture<Out> {
    /// Prints just the type name; the boxed future is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("InterceptFuture");
        builder.finish_non_exhaustive()
    }
}
180
#[cfg(any(feature = "tower-service", test))]
impl<Out> Future for InterceptFuture<Out> {
    type Output = Out;

    /// Forwards the poll to the boxed inner future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The only field is a `Pin<Box<_>>`, so `Self` is `Unpin` and can be unwrapped.
        let this = self.get_mut();
        this.inner.as_mut().poll(cx)
    }
}
189
#[cfg(any(feature = "tower-service", test))]
impl<Req, Res, Err, S> tower_service::Service<Req> for Intercept<Req, Result<Res, Err>, S>
where
    Err: Send + 'static,
    Req: Send + 'static,
    Res: Send + 'static,
    S: tower_service::Service<Req, Response = Res, Error = Err> + Send + Sync + 'static,
    S::Future: Send + 'static,
{
    type Response = Res;
    type Error = Err;
    type Future = InterceptFuture<Result<Res, Err>>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Runs the input hooks synchronously, then returns a future that awaits the wrapped
    /// service and applies the output hooks to its result.
    ///
    /// If an input modification breaks out early, the returned future resolves to that
    /// value; the wrapped service is not called and the output hooks are not run.
    fn call(&mut self, req: Req) -> Self::Future {
        let req = match self.inner.before_execute(req) {
            ControlFlow::Continue(accepted) => accepted,
            ControlFlow::Break(early_result) => {
                return InterceptFuture {
                    inner: Box::pin(async move { early_result }),
                };
            }
        };

        // The hook set is moved into the future so the output hooks can run after `call` returns.
        let hooks = Arc::clone(&self.inner);
        let pending = self.service.call(req);
        let wrapped = async move {
            let result = pending.await;
            hooks.after_execute(result)
        };

        InterceptFuture { inner: Box::pin(wrapped) }
    }
}
229
230impl<In, Out> InterceptLayer<In, Out> {
231    /// Adds an observer for incoming inputs.
232    ///
233    /// Called before input modifications. Multiple observers execute in registration order.
234    ///
235    /// # Examples
236    ///
237    /// ```
238    /// # use layered::{Execute, Stack, Intercept, Service};
239    /// # async fn example() {
240    /// let stack = (
241    ///     Intercept::layer()
242    ///         .on_input(|input| println!("processing: {input}"))
243    ///         .on_input(|input| println!("another: {input}")),
244    ///     Execute::new(|input: String| async move { input }),
245    /// );
246    ///
247    /// let service = stack.into_service();
248    /// let response = service.execute("input".to_string()).await;
249    /// # }
250    /// ```
251    #[must_use]
252    pub fn on_input<F>(mut self, f: F) -> Self
253    where
254        F: Fn(&In) + Send + Sync + 'static,
255    {
256        self.on_input.push(OnInput(Arc::new(f)));
257        self
258    }
259
260    /// Adds an observer for outgoing outputs.
261    ///
262    /// Called before output modifications. Multiple observers execute in registration order.
263    ///
264    /// # Examples
265    ///
266    /// ```
267    /// # use layered::{Execute, Stack, Intercept, Service};
268    /// # async fn example() {
269    /// let stack = (
270    ///     Intercept::layer()
271    ///         .on_output(|output| println!("response: {output}"))
272    ///         .on_output(|output| println!("another response: {output}")),
273    ///     Execute::new(|input: String| async move { input }),
274    /// );
275    ///
276    /// let service = stack.into_service();
277    /// let response = service.execute("input".to_string()).await;
278    /// # }
279    /// ```
280    #[must_use]
281    pub fn on_output<F>(mut self, f: F) -> Self
282    where
283        F: Fn(&Out) + Send + Sync + 'static,
284    {
285        self.on_output.push(OnOutput(Arc::new(f)));
286        self
287    }
288
289    /// Adds a transformation for incoming inputs.
290    ///
291    /// Transforms inputs before service execution. Multiple modifications apply
292    /// in registration order, each receiving the previous output.
293    ///
294    /// # Examples
295    ///
296    /// ```
297    /// # use layered::{Execute, Stack, Intercept, Service};
298    /// # async fn example() {
299    /// let stack = (
300    ///     Intercept::layer()
301    ///         .modify_input(|input: String| input.trim().to_string())
302    ///         .modify_input(|input| input.to_lowercase()),
303    ///     Execute::new(|input: String| async move { input }),
304    /// );
305    ///
306    /// let service = stack.into_service();
307    /// let response = service.execute("input".to_string()).await;
308    /// # }
309    /// ```
310    #[must_use]
311    pub fn modify_input<F>(self, f: F) -> Self
312    where
313        F: Fn(In) -> In + Send + Sync + 'static,
314    {
315        self.input_control_flow(move |input| ControlFlow::Continue(f(input)))
316    }
317
318    /// Adds a modification function with control flow for incoming requests.
319    /// Returns `ControlFlow::Break` to short-circuit execution and return early.
320    pub(crate) fn input_control_flow<F>(mut self, f: F) -> Self
321    where
322        F: Fn(In) -> ControlFlow<Out, In> + Send + Sync + 'static,
323    {
324        self.modify_input.push(ModifyInput(Arc::new(f)));
325        self
326    }
327
328    /// Adds a transformation for outgoing outputs.
329    ///
330    /// Transforms outputs after service execution. Multiple modifications apply
331    /// in registration order, each receiving the previous output.
332    ///
333    /// # Examples
334    ///
335    /// ```
336    /// # use layered::{Execute, Stack, Intercept, Service};
337    /// # async fn example() {
338    /// let stack = (
339    ///     Intercept::layer()
340    ///         .modify_output(|output: String| output.trim().to_string())
341    ///         .modify_output(|output| format!("Result: {}", output)),
342    ///     Execute::new(|input: String| async move { input }),
343    /// );
344    ///
345    /// let service = stack.into_service();
346    /// let response = service.execute("input".to_string()).await;
347    /// # }
348    /// ```
349    #[must_use]
350    pub fn modify_output<F>(mut self, f: F) -> Self
351    where
352        F: Fn(Out) -> Out + Send + Sync + 'static,
353    {
354        self.modify_output.push(ModifyOutput(Arc::new(f)));
355        self
356    }
357}
358
359impl<In, Out, S> crate::Layer<S> for InterceptLayer<In, Out> {
360    type Service = Intercept<In, Out, S>;
361
362    fn layer(&self, inner: S) -> Self::Service {
363        let intercept_inner = InterceptInner {
364            modify_input: self.modify_input.clone().into(),
365            on_input: self.on_input.clone().into(),
366            modify_output: self.modify_output.clone().into(),
367            on_output: self.on_output.clone().into(),
368        };
369
370        Intercept {
371            inner: Arc::new(intercept_inner),
372            service: inner,
373        }
374    }
375}
376
/// Shared handle to an input observer callback.
struct OnInput<In>(Arc<dyn Fn(&In) + Send + Sync>);

impl<In> Clone for OnInput<In> {
    /// Clones the handle only; both copies refer to the same callback.
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
384
/// Shared handle to an output observer callback.
struct OnOutput<Out>(Arc<dyn Fn(&Out) + Send + Sync>);

impl<Out> Clone for OnOutput<Out> {
    /// Clones the handle only; both copies refer to the same callback.
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
392
/// Shared handle to an input transformation that either continues with a new input
/// or breaks with an output.
struct ModifyInput<In, Out>(Arc<dyn Fn(In) -> ControlFlow<Out, In> + Send + Sync>);

impl<In, Out> Clone for ModifyInput<In, Out> {
    /// Clones the handle only; both copies refer to the same callback.
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
400
/// Shared handle to an output transformation callback.
struct ModifyOutput<Out>(Arc<dyn Fn(Out) -> Out + Send + Sync>);

impl<Out> Clone for ModifyOutput<Out> {
    /// Clones the handle only; both copies refer to the same callback.
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
408
409struct InterceptInner<In, Out> {
410    modify_input: Arc<[ModifyInput<In, Out>]>,
411    on_input: Arc<[OnInput<In>]>,
412    modify_output: Arc<[ModifyOutput<Out>]>,
413    on_output: Arc<[OnOutput<Out>]>,
414}
415
416impl<In, Out> InterceptInner<In, Out> {
417    #[inline]
418    fn before_execute(&self, mut input: In) -> ControlFlow<Out, In> {
419        for on_input in self.on_input.iter() {
420            on_input.0(&input);
421        }
422
423        for modify in self.modify_input.iter() {
424            match modify.0(input) {
425                ControlFlow::Break(output) => return ControlFlow::Break(output),
426                ControlFlow::Continue(new_input) => input = new_input,
427            }
428        }
429
430        ControlFlow::Continue(input)
431    }
432
433    #[inline]
434    fn after_execute(&self, mut output: Out) -> Out {
435        for on_output in self.on_output.iter() {
436            on_output.0(&output);
437        }
438
439        for modify in self.modify_output.iter() {
440            output = modify.0(output);
441        }
442
443        output
444    }
445}
446
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use std::future::poll_fn;
    use std::sync::atomic::{AtomicU16, Ordering};

    use futures::executor::block_on;
    use tower_service::Service as TowerService;

    use super::*;
    use crate::{Execute, Layer, Stack};

    /// Mock tower service whose `poll_ready` always reports a preconfigured value.
    struct MockService {
        poll_ready_response: Poll<Result<(), String>>,
    }

    impl MockService {
        fn new(poll_ready_response: Poll<Result<(), String>>) -> Self {
            Self { poll_ready_response }
        }
    }

    impl TowerService<String> for MockService {
        type Response = String;
        type Error = String;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.poll_ready_response.clone()
        }

        fn call(&mut self, req: String) -> Self::Future {
            Box::pin(async move { Ok(req) })
        }
    }

    /// Wraps a `MockService` in a hook-free `Intercept` and polls its readiness once.
    fn poll_ready_via_intercept(response: Poll<Result<(), String>>) -> Poll<Result<(), String>> {
        let layer = Intercept::<String, Result<String, String>, ()>::layer();
        let mut intercept = layer.layer(MockService::new(response));

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        intercept.poll_ready(&mut cx)
    }

    #[test]
    pub fn ensure_types() {
        static_assertions::assert_impl_all!(Intercept::<String, String, ()>: Debug, Clone, Send, Sync);
        static_assertions::assert_impl_all!(InterceptLayer::<String, String>: Debug, Clone, Send, Sync);
    }

    #[test]
    #[expect(clippy::similar_names, reason = "Test")]
    fn input_modification_order() {
        let called = Arc::new(AtomicU16::default());
        let called2 = Arc::new(AtomicU16::default());
        let called_clone = Arc::clone(&called);
        let called2_clone = Arc::clone(&called2);

        let layer = Intercept::layer()
            .modify_input(|input: String| format!("{input}1"))
            .modify_input(|input: String| format!("{input}2"))
            .on_input(move |_input| {
                called.fetch_add(1, Ordering::Relaxed);
            })
            .on_input(move |_input| {
                called2.fetch_add(1, Ordering::Relaxed);
            });
        let service = (layer, Execute::new(|input: String| async move { input })).into_service();

        let response = block_on(service.execute("test".to_string()));

        assert_eq!(called_clone.load(Ordering::Relaxed), 1);
        assert_eq!(called2_clone.load(Ordering::Relaxed), 1);
        assert_eq!(response, "test12");
    }

    #[test]
    #[expect(clippy::similar_names, reason = "Test")]
    fn out_modification_order() {
        let called = Arc::new(AtomicU16::default());
        let called2 = Arc::new(AtomicU16::default());
        let called_clone = Arc::clone(&called);
        let called2_clone = Arc::clone(&called2);

        let layer = Intercept::layer()
            .modify_output(|output: String| format!("{output}1"))
            .modify_output(|output: String| format!("{output}2"))
            .on_output(move |_output| {
                called.fetch_add(1, Ordering::Relaxed);
            })
            .on_output(move |_output| {
                called2.fetch_add(1, Ordering::Relaxed);
            });
        let service = (layer, Execute::new(|input: String| async move { input })).into_service();

        let response = block_on(service.execute("test".to_string()));

        assert_eq!(called_clone.load(Ordering::Relaxed), 1);
        assert_eq!(called2_clone.load(Ordering::Relaxed), 1);
        assert_eq!(response, "test12");
    }

    #[test]
    #[expect(clippy::similar_names, reason = "Test")]
    fn tower_service() {
        let called = Arc::new(AtomicU16::default());
        let called2 = Arc::new(AtomicU16::default());
        let called_clone = Arc::clone(&called);
        let called2_clone = Arc::clone(&called2);

        let layer = Intercept::layer()
            .modify_input(|input: String| format!("{input}1"))
            .modify_input(|input: String| format!("{input}2"))
            .on_input(move |_input| {
                called.fetch_add(1, Ordering::Relaxed);
            })
            .on_input(move |_input| {
                called2.fetch_add(1, Ordering::Relaxed);
            });
        let mut service = (layer, Execute::new(|input: String| async move { Ok::<_, String>(input) })).into_service();

        block_on(async move {
            poll_fn(|cx| service.poll_ready(cx)).await.unwrap();
            let response = service.call("test".to_string()).await.unwrap();
            assert_eq!(response, "test12");
        });

        assert_eq!(called_clone.load(Ordering::Relaxed), 1);
        assert_eq!(called2_clone.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn poll_ready_propagates_pending() {
        let result = poll_ready_via_intercept(Poll::Pending);

        assert!(result.is_pending());
    }

    #[test]
    fn poll_ready_propagates_error() {
        let result = poll_ready_via_intercept(Poll::Ready(Err("service error".to_string())));

        match result {
            Poll::Ready(Err(err)) => assert_eq!(err, "service error"),
            _ => panic!("Expected Poll::Ready(Err), got {result:?}"),
        }
    }

    #[test]
    fn poll_ready_propagates_success() {
        let result = poll_ready_via_intercept(Poll::Ready(Ok(())));

        match result {
            Poll::Ready(Ok(())) => (),
            _ => panic!("Expected Poll::Ready(Ok(())), got {result:?}"),
        }
    }

    #[test]
    fn debug_intercept() {
        let intercept = Intercept::<String, String, ()>::layer().layer("inner");

        assert_eq!(format!("{intercept:?}"), "Intercept { service: \"inner\", .. }");
    }

    #[test]
    fn debug_intercept_layer() {
        let layer = Intercept::<String, String, ()>::layer();

        assert_eq!(
            format!("{layer:?}"),
            "InterceptLayer { on_input: 0, modify_input: 0, modify_output: 0, on_output: 0, .. }"
        );
    }

    #[test]
    fn clone_intercept() {
        let original = Intercept::<String, String, ()>::layer().layer("inner");
        let cloned = original.clone();

        assert_eq!(cloned.service, "inner");
    }

    #[test]
    fn debug_intercept_future() {
        let future: InterceptFuture<String> = InterceptFuture {
            inner: Box::pin(async { "test".to_string() }),
        };

        assert!(format!("{future:?}").contains("InterceptFuture"));
    }

    #[test]
    fn short_circuit_layered() {
        let layer = Intercept::layer().input_control_flow(|_: String| ControlFlow::Break("rejected".into()));
        let svc = (layer, Execute::new(|_: String| async { "should not run".to_string() })).into_service();

        assert_eq!(block_on(svc.execute("test".into())), "rejected");
    }

    #[test]
    fn short_circuit_tower() {
        let layer = Intercept::layer().input_control_flow(|_: String| ControlFlow::Break(Ok("rejected".into())));
        let mut svc = (layer, Execute::new(|_: String| async { Ok::<_, ()>("should not run".into()) })).into_service();

        let res = block_on(async {
            poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
            svc.call("test".into()).await
        });

        assert_eq!(res, Ok("rejected".to_string()));
    }
}