1use std::{
2 fmt,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use axum::{body::BoxBody, response::IntoResponse};
9use futures::Future;
10use http::{Request, Response};
11use http_body::Body;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::errors::ApiResult;
16
17pub trait Interception: Send + Sync {
18 type Carryover: Send + Sync;
19
20 fn on_request(&self, parts: &mut http::request::Parts) -> ApiResult<Self::Carryover>;
21
22 fn on_response(
23 &self,
24 carryover: Self::Carryover,
25 parts: &mut http::response::Parts,
26 ) -> ApiResult<()>;
27}
28
29#[derive(Default, Clone)]
30pub struct BasicInterception {
31 on_request: Option<Arc<dyn Fn(&mut http::request::Parts) -> ApiResult<()> + Send + Sync>>,
32 on_response: Option<Arc<dyn Fn(&mut http::response::Parts) -> ApiResult<()> + Send + Sync>>,
33}
34
35impl BasicInterception {
36 pub fn on_request(
37 mut self,
38 func: impl Fn(&mut http::request::Parts) -> ApiResult<()> + Send + Sync + 'static,
39 ) -> Self {
40 self.on_request = Some(Arc::new(func));
41 self
42 }
43
44 pub fn on_response(
45 mut self,
46 func: impl Fn(&mut http::response::Parts) -> ApiResult<()> + Send + Sync + 'static,
47 ) -> Self {
48 self.on_response = Some(Arc::new(func));
49 self
50 }
51}
52
53impl Interception for BasicInterception {
54 type Carryover = ();
55
56 fn on_request(&self, parts: &mut http::request::Parts) -> ApiResult<Self::Carryover> {
57 if let Some(on_request) = &self.on_request {
58 on_request(parts)?;
59 }
60 Ok(())
61 }
62
63 fn on_response(&self, _: Self::Carryover, parts: &mut http::response::Parts) -> ApiResult<()> {
64 if let Some(on_response) = &self.on_response {
65 on_response(parts)?;
66 }
67 Ok(())
68 }
69}
70
71#[derive(Clone)]
72pub struct InterceptorLayer<I: Interception>(pub Arc<I>);
73
74impl<S, I: Interception> Layer<S> for InterceptorLayer<I> {
75 type Service = Interceptor<S, I>;
76
77 fn layer(&self, service: S) -> Self::Service {
78 Interceptor(service, self.0.clone())
79 }
80}
81
82#[derive(Clone)]
83pub struct Interceptor<S, I: Interception>(S, Arc<I>);
84
85#[pin_project::pin_project]
86pub struct InterceptorFuture<S, I: Interception, ReqBody, ResBody>
87where
88 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
89 S::Error: fmt::Display + 'static,
90{
91 #[pin]
92 inner: S::Future,
93 interception: Arc<I>,
94 carryover: Option<I::Carryover>,
95}
96
97impl<S, I: Interception, ReqBody> Future for InterceptorFuture<S, I, ReqBody, BoxBody>
98where
99 S: Service<Request<ReqBody>, Response = Response<BoxBody>>,
100 S::Error: fmt::Display + 'static,
101{
102 type Output = <S::Future as Future>::Output;
103
104 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105 let this = self.project();
106 match this.inner.poll(cx) {
107 Poll::Pending => Poll::Pending,
108 Poll::Ready(Ok(mut response)) => {
109 let (mut parts, body) = response.into_parts();
110 if let Err(e) = this
111 .interception
112 .on_response(this.carryover.take().unwrap(), &mut parts)
113 {
114 return Poll::Ready(Ok(e.into_response()));
115 }
116 response = Response::from_parts(parts, body);
117 Poll::Ready(Ok(response))
118 }
119 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
120 }
121 }
122}
123
124impl<S, I: Interception + 'static, ReqBody> Service<Request<ReqBody>> for Interceptor<S, I>
125where
126 S: Service<Request<ReqBody>, Response = Response<BoxBody>>,
127 ReqBody: Body + 'static,
128 S: 'static,
129 S::Error: Send + fmt::Display + 'static,
130 S::Future: Send,
131{
132 type Response = Response<BoxBody>;
133 type Error = S::Error;
134 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
135
136 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137 self.0.poll_ready(cx)
138 }
139
140 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
141 let (mut parts, body) = req.into_parts();
142 let carryover = match self.1.on_request(&mut parts) {
143 Ok(x) => x,
144 Err(e) => return Box::pin(futures::future::ready(Ok(e.into_response()))),
145 };
146 req = Request::from_parts(parts, body);
147 let future = self.0.call(req);
148
149 Box::pin(InterceptorFuture::<S, I, ReqBody, BoxBody> {
150 inner: future,
151 interception: self.1.clone(),
152 carryover: Some(carryover),
153 })
154 }
155}