axum_util/
interceptor.rs

1use std::{
2    fmt,
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll},
6};
7
8use axum::{body::BoxBody, response::IntoResponse};
9use futures::Future;
10use http::{Request, Response};
11use http_body::Body;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::errors::ApiResult;
16
17pub trait Interception: Send + Sync {
18    type Carryover: Send + Sync;
19
20    fn on_request(&self, parts: &mut http::request::Parts) -> ApiResult<Self::Carryover>;
21
22    fn on_response(
23        &self,
24        carryover: Self::Carryover,
25        parts: &mut http::response::Parts,
26    ) -> ApiResult<()>;
27}
28
29#[derive(Default, Clone)]
30pub struct BasicInterception {
31    on_request: Option<Arc<dyn Fn(&mut http::request::Parts) -> ApiResult<()> + Send + Sync>>,
32    on_response: Option<Arc<dyn Fn(&mut http::response::Parts) -> ApiResult<()> + Send + Sync>>,
33}
34
35impl BasicInterception {
36    pub fn on_request(
37        mut self,
38        func: impl Fn(&mut http::request::Parts) -> ApiResult<()> + Send + Sync + 'static,
39    ) -> Self {
40        self.on_request = Some(Arc::new(func));
41        self
42    }
43
44    pub fn on_response(
45        mut self,
46        func: impl Fn(&mut http::response::Parts) -> ApiResult<()> + Send + Sync + 'static,
47    ) -> Self {
48        self.on_response = Some(Arc::new(func));
49        self
50    }
51}
52
53impl Interception for BasicInterception {
54    type Carryover = ();
55
56    fn on_request(&self, parts: &mut http::request::Parts) -> ApiResult<Self::Carryover> {
57        if let Some(on_request) = &self.on_request {
58            on_request(parts)?;
59        }
60        Ok(())
61    }
62
63    fn on_response(&self, _: Self::Carryover, parts: &mut http::response::Parts) -> ApiResult<()> {
64        if let Some(on_response) = &self.on_response {
65            on_response(parts)?;
66        }
67        Ok(())
68    }
69}
70
71#[derive(Clone)]
72pub struct InterceptorLayer<I: Interception>(pub Arc<I>);
73
74impl<S, I: Interception> Layer<S> for InterceptorLayer<I> {
75    type Service = Interceptor<S, I>;
76
77    fn layer(&self, service: S) -> Self::Service {
78        Interceptor(service, self.0.clone())
79    }
80}
81
82#[derive(Clone)]
83pub struct Interceptor<S, I: Interception>(S, Arc<I>);
84
85#[pin_project::pin_project]
86pub struct InterceptorFuture<S, I: Interception, ReqBody, ResBody>
87where
88    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
89    S::Error: fmt::Display + 'static,
90{
91    #[pin]
92    inner: S::Future,
93    interception: Arc<I>,
94    carryover: Option<I::Carryover>,
95}
96
97impl<S, I: Interception, ReqBody> Future for InterceptorFuture<S, I, ReqBody, BoxBody>
98where
99    S: Service<Request<ReqBody>, Response = Response<BoxBody>>,
100    S::Error: fmt::Display + 'static,
101{
102    type Output = <S::Future as Future>::Output;
103
104    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105        let this = self.project();
106        match this.inner.poll(cx) {
107            Poll::Pending => Poll::Pending,
108            Poll::Ready(Ok(mut response)) => {
109                let (mut parts, body) = response.into_parts();
110                if let Err(e) = this
111                    .interception
112                    .on_response(this.carryover.take().unwrap(), &mut parts)
113                {
114                    return Poll::Ready(Ok(e.into_response()));
115                }
116                response = Response::from_parts(parts, body);
117                Poll::Ready(Ok(response))
118            }
119            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
120        }
121    }
122}
123
124impl<S, I: Interception + 'static, ReqBody> Service<Request<ReqBody>> for Interceptor<S, I>
125where
126    S: Service<Request<ReqBody>, Response = Response<BoxBody>>,
127    ReqBody: Body + 'static,
128    S: 'static,
129    S::Error: Send + fmt::Display + 'static,
130    S::Future: Send,
131{
132    type Response = Response<BoxBody>;
133    type Error = S::Error;
134    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
135
136    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137        self.0.poll_ready(cx)
138    }
139
140    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
141        let (mut parts, body) = req.into_parts();
142        let carryover = match self.1.on_request(&mut parts) {
143            Ok(x) => x,
144            Err(e) => return Box::pin(futures::future::ready(Ok(e.into_response()))),
145        };
146        req = Request::from_parts(parts, body);
147        let future = self.0.call(req);
148
149        Box::pin(InterceptorFuture::<S, I, ReqBody, BoxBody> {
150            inner: future,
151            interception: self.1.clone(),
152            carryover: Some(carryover),
153        })
154    }
155}