1use std::fmt::Debug;
5use std::ops::ControlFlow;
6#[cfg(any(feature = "tower-service", test))]
7use std::pin::Pin;
8use std::sync::Arc;
9#[cfg(any(feature = "tower-service", test))]
10use std::task::{Context, Poll};
11
12use crate::Service;
13
/// Service wrapper that runs registered interception callbacks around a wrapped service.
///
/// The callbacks live in a shared, reference-counted [`InterceptInner`]; the wrapped
/// service is stored by value in `service`.
pub struct Intercept<In, Out, S> {
    /// Shared set of input/output callbacks applied before and after the wrapped service runs.
    inner: Arc<InterceptInner<In, Out>>,
    /// The wrapped service that receives the input after the input callbacks have run.
    service: S,
}
63
64impl<In, Out, S: Clone> Clone for Intercept<In, Out, S> {
65 fn clone(&self) -> Self {
66 Self {
67 inner: Arc::clone(&self.inner),
68 service: self.service.clone(),
69 }
70 }
71}
72
73impl<In, Out, S: Debug> Debug for Intercept<In, Out, S> {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("Intercept").field("service", &self.service).finish_non_exhaustive()
76 }
77}
78
79#[derive(Clone)]
99pub struct InterceptLayer<In, Out> {
100 on_input: Vec<OnInput<In>>,
101 modify_input: Vec<ModifyInput<In, Out>>,
102 modify_output: Vec<ModifyOutput<Out>>,
103 on_output: Vec<OnOutput<Out>>,
104}
105
106impl<In, Out> Debug for InterceptLayer<In, Out> {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("InterceptLayer")
109 .field("on_input", &self.on_input.len())
110 .field("modify_input", &self.modify_input.len())
111 .field("modify_output", &self.modify_output.len())
112 .field("on_output", &self.on_output.len())
113 .finish_non_exhaustive()
114 }
115}
116
117impl<In, Out> Intercept<In, Out, ()> {
118 #[must_use]
135 pub fn layer() -> InterceptLayer<In, Out> {
136 InterceptLayer {
137 on_input: Vec::default(),
138 modify_input: Vec::default(),
139 modify_output: Vec::default(),
140 on_output: Vec::default(),
141 }
142 }
143}
144
145impl<In: Send, Out, S> Service<In> for Intercept<In, Out, S>
146where
147 S: Service<In, Out = Out>,
148{
149 type Out = Out;
150
151 async fn execute(&self, mut input: In) -> Self::Out {
157 match self.inner.before_execute(input) {
158 ControlFlow::Break(output) => return output,
159 ControlFlow::Continue(new_input) => input = new_input,
160 }
161
162 let output = self.service.execute(input).await;
163
164 self.inner.after_execute(output)
165 }
166}
167
/// Future type returned by the `tower_service::Service` implementation of `Intercept`.
///
/// It wraps a boxed, pinned, `Send` future and forwards polling to it.
#[cfg(any(feature = "tower-service", test))]
pub struct InterceptFuture<Out> {
    /// The type-erased future that produces the final output.
    inner: Pin<Box<dyn Future<Output = Out> + Send>>,
}

/// The boxed future is opaque, so only the type name is printed.
#[cfg(any(feature = "tower-service", test))]
impl<Out> Debug for InterceptFuture<Out> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("InterceptFuture");
        out.finish_non_exhaustive()
    }
}

#[cfg(any(feature = "tower-service", test))]
impl<Out> Future for InterceptFuture<Out> {
    type Output = Out;

    /// Delegates to the boxed future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The only field is a `Pin<Box<_>>`, so the wrapper is `Unpin` and can be
        // accessed through a plain mutable reference.
        let this = self.get_mut();
        this.inner.as_mut().poll(cx)
    }
}
189
/// `tower_service::Service` adapter: the intercepted output type is `Result<Res, Err>`.
#[cfg(any(feature = "tower-service", test))]
impl<Req, Res, Err, S> tower_service::Service<Req> for Intercept<Req, Result<Res, Err>, S>
where
    Err: Send + 'static,
    Req: Send + 'static,
    Res: Send + 'static,
    S: tower_service::Service<Req, Response = Res, Error = Err> + Send + Sync + 'static,
    S::Future: Send + 'static,
{
    type Response = Res;
    type Error = Err;
    type Future = InterceptFuture<Result<Res, Err>>;

    /// Readiness is decided entirely by the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Runs the input callbacks synchronously, then returns a future that awaits the
    /// wrapped service and applies the output callbacks to its result.
    ///
    /// When an input callback breaks with a result, the wrapped service is not called
    /// and the returned future resolves to that result.
    fn call(&mut self, req: Req) -> Self::Future {
        match self.inner.before_execute(req) {
            ControlFlow::Break(early) => InterceptFuture {
                inner: Box::pin(async move { early }),
            },
            ControlFlow::Continue(forwarded) => {
                // The future must be `'static`, so it owns its own handle to the callbacks.
                let hooks = Arc::clone(&self.inner);
                let pending = self.service.call(forwarded);

                InterceptFuture {
                    inner: Box::pin(async move {
                        let response = pending.await;
                        hooks.after_execute(response)
                    }),
                }
            }
        }
    }
}
229
230impl<In, Out> InterceptLayer<In, Out> {
231 #[must_use]
252 pub fn on_input<F>(mut self, f: F) -> Self
253 where
254 F: Fn(&In) + Send + Sync + 'static,
255 {
256 self.on_input.push(OnInput(Arc::new(f)));
257 self
258 }
259
260 #[must_use]
281 pub fn on_output<F>(mut self, f: F) -> Self
282 where
283 F: Fn(&Out) + Send + Sync + 'static,
284 {
285 self.on_output.push(OnOutput(Arc::new(f)));
286 self
287 }
288
289 #[must_use]
311 pub fn modify_input<F>(self, f: F) -> Self
312 where
313 F: Fn(In) -> In + Send + Sync + 'static,
314 {
315 self.input_control_flow(move |input| ControlFlow::Continue(f(input)))
316 }
317
318 pub(crate) fn input_control_flow<F>(mut self, f: F) -> Self
321 where
322 F: Fn(In) -> ControlFlow<Out, In> + Send + Sync + 'static,
323 {
324 self.modify_input.push(ModifyInput(Arc::new(f)));
325 self
326 }
327
328 #[must_use]
350 pub fn modify_output<F>(mut self, f: F) -> Self
351 where
352 F: Fn(Out) -> Out + Send + Sync + 'static,
353 {
354 self.modify_output.push(ModifyOutput(Arc::new(f)));
355 self
356 }
357}
358
359impl<In, Out, S> crate::Layer<S> for InterceptLayer<In, Out> {
360 type Service = Intercept<In, Out, S>;
361
362 fn layer(&self, inner: S) -> Self::Service {
363 let intercept_inner = InterceptInner {
364 modify_input: self.modify_input.clone().into(),
365 on_input: self.on_input.clone().into(),
366 modify_output: self.modify_output.clone().into(),
367 on_output: self.on_output.clone().into(),
368 };
369
370 Intercept {
371 inner: Arc::new(intercept_inner),
372 service: inner,
373 }
374 }
375}
376
/// Shared, thread-safe observer invoked with a reference to the input.
struct OnInput<In>(Arc<dyn Fn(&In) + Send + Sync>);

// Written by hand so that cloning does not require `In: Clone`; only the `Arc` is cloned.
impl<In> Clone for OnInput<In> {
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
384
/// Shared, thread-safe observer invoked with a reference to the output.
struct OnOutput<Out>(Arc<dyn Fn(&Out) + Send + Sync>);

// Written by hand so that cloning does not require `Out: Clone`; only the `Arc` is cloned.
impl<Out> Clone for OnOutput<Out> {
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
392
/// Shared, thread-safe input step: takes the input by value and either continues with an
/// input or breaks with an output.
struct ModifyInput<In, Out>(Arc<dyn Fn(In) -> ControlFlow<Out, In> + Send + Sync>);

// Written by hand so that cloning needs no bounds on `In`/`Out`; only the `Arc` is cloned.
impl<In, Out> Clone for ModifyInput<In, Out> {
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
400
/// Shared, thread-safe output step: takes the output by value and returns its replacement.
struct ModifyOutput<Out>(Arc<dyn Fn(Out) -> Out + Send + Sync>);

// Written by hand so that cloning does not require `Out: Clone`; only the `Arc` is cloned.
impl<Out> Clone for ModifyOutput<Out> {
    fn clone(&self) -> Self {
        let Self(callback) = self;
        Self(Arc::clone(callback))
    }
}
408
409struct InterceptInner<In, Out> {
410 modify_input: Arc<[ModifyInput<In, Out>]>,
411 on_input: Arc<[OnInput<In>]>,
412 modify_output: Arc<[ModifyOutput<Out>]>,
413 on_output: Arc<[OnOutput<Out>]>,
414}
415
416impl<In, Out> InterceptInner<In, Out> {
417 #[inline]
418 fn before_execute(&self, mut input: In) -> ControlFlow<Out, In> {
419 for on_input in self.on_input.iter() {
420 on_input.0(&input);
421 }
422
423 for modify in self.modify_input.iter() {
424 match modify.0(input) {
425 ControlFlow::Break(output) => return ControlFlow::Break(output),
426 ControlFlow::Continue(new_input) => input = new_input,
427 }
428 }
429
430 ControlFlow::Continue(input)
431 }
432
433 #[inline]
434 fn after_execute(&self, mut output: Out) -> Out {
435 for on_output in self.on_output.iter() {
436 on_output.0(&output);
437 }
438
439 for modify in self.modify_output.iter() {
440 output = modify.0(output);
441 }
442
443 output
444 }
445}
446
// Unit tests for the interception layer, covering both the crate's own `Service` trait
// and the `tower_service::Service` adapter. Coverage instrumentation is switched off for
// this module when the `coverage_nightly` cfg is set.
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use std::future::poll_fn;
    use std::sync::atomic::{AtomicU16, Ordering};

    use futures::executor::block_on;
    use tower_service::Service as TowerService;

    use super::*;
    use crate::{Execute, Layer, Stack};

    // Compile-time check that the public types implement the expected traits.
    #[test]
    pub fn ensure_types() {
        static_assertions::assert_impl_all!(Intercept::<String, String, ()>: Debug, Clone, Send, Sync);
        static_assertions::assert_impl_all!(InterceptLayer::<String, String>: Debug, Clone, Send, Sync);
    }

    // Input modifiers are applied in registration order ("1" then "2") and every input
    // observer is invoked exactly once per execution.
    #[test]
    #[expect(clippy::similar_names, reason = "Test")]
    fn input_modification_order() {
        let called = Arc::new(AtomicU16::default());
        let called_clone = Arc::clone(&called);

        let called2 = Arc::new(AtomicU16::default());
        let called2_clone = Arc::clone(&called2);

        let stack = (
            Intercept::layer()
                .modify_input(|input: String| format!("{input}1"))
                .modify_input(|input: String| format!("{input}2"))
                .on_input(move |_input| {
                    called.fetch_add(1, Ordering::Relaxed);
                })
                .on_input(move |_input| {
                    called2.fetch_add(1, Ordering::Relaxed);
                }),
            // Echo service: returns its input unchanged.
            Execute::new(|input: String| async move { input }),
        );

        let service = stack.into_service();
        let response = block_on(service.execute("test".to_string()));
        assert_eq!(called_clone.load(Ordering::Relaxed), 1);
        assert_eq!(called2_clone.load(Ordering::Relaxed), 1);
        assert_eq!(response, "test12");
    }

    // Output modifiers are applied in registration order ("1" then "2") and every output
    // observer is invoked exactly once per execution.
    #[test]
    #[expect(clippy::similar_names, reason = "Test")]
    fn out_modification_order() {
        let called = Arc::new(AtomicU16::default());
        let called_clone = Arc::clone(&called);

        let called2 = Arc::new(AtomicU16::default());
        let called2_clone = Arc::clone(&called2);

        let stack = (
            Intercept::layer()
                .modify_output(|output: String| format!("{output}1"))
                .modify_output(|output: String| format!("{output}2"))
                .on_output(move |_output| {
                    called.fetch_add(1, Ordering::Relaxed);
                })
                .on_output(move |_output| {
                    called2.fetch_add(1, Ordering::Relaxed);
                }),
            // Echo service: returns its input unchanged.
            Execute::new(|input: String| async move { input }),
        );

        let service = stack.into_service();
        let response = block_on(service.execute("test".to_string()));
        assert_eq!(called_clone.load(Ordering::Relaxed), 1);
        assert_eq!(called2_clone.load(Ordering::Relaxed), 1);
        assert_eq!(response, "test12");
    }

    // Same scenario as `input_modification_order`, but driven through the
    // `tower_service::Service` interface (`poll_ready` followed by `call`).
    #[test]
    #[expect(clippy::similar_names, reason = "Test")]
    fn tower_service() {
        let called = Arc::new(AtomicU16::default());
        let called_clone = Arc::clone(&called);

        let called2 = Arc::new(AtomicU16::default());
        let called2_clone = Arc::clone(&called2);

        let stack = (
            Intercept::layer()
                .modify_input(|input: String| format!("{input}1"))
                .modify_input(|input: String| format!("{input}2"))
                .on_input(move |_input| {
                    called.fetch_add(1, Ordering::Relaxed);
                })
                .on_input(move |_input| {
                    called2.fetch_add(1, Ordering::Relaxed);
                }),
            // Echo service that wraps its input in `Ok`.
            Execute::new(|input: String| async move { Ok::<_, String>(input) }),
        );

        let mut service = stack.into_service();
        let future = async move {
            poll_fn(|cx| service.poll_ready(cx)).await.unwrap();
            let response = service.call("test".to_string()).await.unwrap();
            assert_eq!(response, "test12");
        };

        block_on(future);

        assert_eq!(called_clone.load(Ordering::Relaxed), 1);
        assert_eq!(called2_clone.load(Ordering::Relaxed), 1);
    }

    // Tower service stub whose `poll_ready` result is fixed at construction time.
    struct MockService {
        // Value returned (cloned) from every `poll_ready` call.
        poll_ready_response: Poll<Result<(), String>>,
    }

    impl MockService {
        // Creates a stub that always reports `poll_ready_response` from `poll_ready`.
        fn new(poll_ready_response: Poll<Result<(), String>>) -> Self {
            Self { poll_ready_response }
        }
    }

    impl TowerService<String> for MockService {
        type Response = String;
        type Error = String;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.poll_ready_response.clone()
        }

        // Echoes the request back as a successful response.
        fn call(&mut self, req: String) -> Self::Future {
            Box::pin(async move { Ok(req) })
        }
    }

    // `poll_ready` on the wrapper reports `Pending` when the wrapped service does.
    #[test]
    fn poll_ready_propagates_pending() {
        let mock_service = MockService::new(Poll::Pending);
        let intercept_layer = InterceptLayer {
            on_input: Vec::default(),
            modify_input: Vec::default(),
            modify_output: Vec::default(),
            on_output: Vec::default(),
        };
        let mut intercept = intercept_layer.layer(mock_service);

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = intercept.poll_ready(&mut cx);
        assert!(result.is_pending());
    }

    // `poll_ready` on the wrapper surfaces the wrapped service's error unchanged.
    #[test]
    fn poll_ready_propagates_error() {
        let mock_service = MockService::new(Poll::Ready(Err("service error".to_string())));
        let intercept_layer = InterceptLayer {
            on_input: Vec::default(),
            modify_input: Vec::default(),
            modify_output: Vec::default(),
            on_output: Vec::default(),
        };
        let mut intercept = intercept_layer.layer(mock_service);

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = intercept.poll_ready(&mut cx);
        match result {
            Poll::Ready(Err(err)) => assert_eq!(err, "service error"),
            _ => panic!("Expected Poll::Ready(Err), got {result:?}"),
        }
    }

    // `poll_ready` on the wrapper reports readiness when the wrapped service is ready.
    #[test]
    fn poll_ready_propagates_success() {
        let mock_service = MockService::new(Poll::Ready(Ok(())));
        let intercept_layer = InterceptLayer {
            on_input: Vec::default(),
            modify_input: Vec::default(),
            modify_output: Vec::default(),
            on_output: Vec::default(),
        };
        let mut intercept = intercept_layer.layer(mock_service);

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let result = intercept.poll_ready(&mut cx);
        match result {
            Poll::Ready(Ok(())) => (),
            _ => panic!("Expected Poll::Ready(Ok(())), got {result:?}"),
        }
    }

    // Pins the exact `Debug` rendering of `Intercept`: only the wrapped service is shown.
    #[test]
    fn debug_intercept() {
        let debug_str = format!("{:?}", Intercept::<String, String, ()>::layer().layer("inner"));

        assert_eq!(debug_str, "Intercept { service: \"inner\", .. }");
    }

    // Pins the exact `Debug` rendering of an empty `InterceptLayer`: per-group counts.
    #[test]
    fn debug_intercept_layer() {
        let debug_str = format!("{:?}", Intercept::<String, String, ()>::layer());

        assert_eq!(
            debug_str,
            "InterceptLayer { on_input: 0, modify_input: 0, modify_output: 0, on_output: 0, .. }"
        );
    }

    // Cloning an `Intercept` clones the wrapped service.
    #[test]
    fn clone_intercept() {
        let cloned = Intercept::<String, String, ()>::layer().layer("inner").clone();

        assert_eq!(cloned.service, "inner");
    }

    // `InterceptFuture` can be debug-formatted and the output names the type.
    #[test]
    fn debug_intercept_future() {
        let future: InterceptFuture<String> = InterceptFuture {
            inner: Box::pin(async { "test".to_string() }),
        };
        let debug_str = format!("{future:?}");
        assert!(debug_str.contains("InterceptFuture"));
    }

    // A breaking input step returns its output directly; the result shows that the
    // wrapped service's value ("should not run") was not used.
    #[test]
    fn short_circuit_layered() {
        let stack = (
            Intercept::layer().input_control_flow(|_: String| ControlFlow::Break("rejected".into())),
            Execute::new(|_: String| async { "should not run".to_string() }),
        );
        let svc = stack.into_service();
        assert_eq!(block_on(svc.execute("test".into())), "rejected");
    }

    // Same short-circuit scenario through the `tower_service::Service` interface.
    #[test]
    fn short_circuit_tower() {
        let stack = (
            Intercept::layer().input_control_flow(|_: String| ControlFlow::Break(Ok("rejected".into()))),
            Execute::new(|_: String| async { Ok::<_, ()>("should not run".into()) }),
        );
        let mut svc = stack.into_service();
        let res = block_on(async {
            poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
            svc.call("test".into()).await
        });
        assert_eq!(res, Ok("rejected".to_string()));
    }
}