1use futures_core::ready;
2use futures_util::future::{self, BoxFuture};
3use http::Request;
4use jsonwebtoken::TokenData;
5use pin_project::pin_project;
6use serde::de::DeserializeOwned;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tower_layer::Layer;
12use tower_service::Service;
13
14use crate::authorizer::Authorizer;
15use crate::AuthError;
16
17pub trait Authorize<B> {
19 type Future: Future<Output = Result<Request<B>, AuthError>>;
20
21 fn authorize(&self, request: Request<B>) -> Self::Future;
25}
26
27impl<S, B, C> Authorize<B> for AuthorizationService<S, C>
28where
29 B: Send + 'static,
30 C: Clone + DeserializeOwned + Send + Sync + 'static,
31{
32 type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
33
34 fn authorize(&self, mut request: Request<B>) -> Self::Future {
38 let tkns_auths: Vec<(String, Arc<Authorizer<C>>)> = self
39 .auths
40 .iter()
41 .filter_map(|a| a.extract_token(request.headers()).map(|t| (t, a.clone())))
42 .collect();
43
44 if tkns_auths.is_empty() {
45 return Box::pin(future::ready(Err(AuthError::MissingToken())));
46 }
47
48 Box::pin(async move {
49 let mut token_data: Result<TokenData<C>, AuthError> = Err(AuthError::NoAuthorizer());
50 for (token, auth) in tkns_auths {
51 token_data = auth.check_auth(token.as_str()).await;
52 if token_data.is_ok() {
53 break;
54 }
55 }
56 match token_data {
57 Ok(tdata) => {
58 request.extensions_mut().insert(tdata);
62
63 Ok(request)
64 }
65 Err(err) => Err(err), }
67 })
68 }
69}
70
71#[derive(Clone)]
74pub struct AuthorizationLayer<C>
75where
76 C: Clone + DeserializeOwned + Send,
77{
78 auths: Vec<Arc<Authorizer<C>>>,
79}
80
81impl<C> AuthorizationLayer<C>
82where
83 C: Clone + DeserializeOwned + Send,
84{
85 pub fn new(auths: Vec<Arc<Authorizer<C>>>) -> AuthorizationLayer<C> {
86 Self { auths }
87 }
88}
89
90impl<S, C> Layer<S> for AuthorizationLayer<C>
91where
92 C: Clone + DeserializeOwned + Send + Sync,
93{
94 type Service = AuthorizationService<S, C>;
95
96 fn layer(&self, inner: S) -> Self::Service {
97 AuthorizationService::new(inner, self.auths.clone())
98 }
99}
100
/// Location of the JWT in an incoming request.
///
/// NOTE(review): the extraction logic is not part of this section; the variant descriptions
/// below are inferred from their names and should be confirmed against `extract_token`.
#[derive(Clone)]
pub enum JwtSource {
    /// Presumably the `Authorization` request header.
    AuthorizationHeader,
    /// A cookie; the payload is presumably the cookie's name.
    Cookie(String),
}
117
118#[derive(Clone)]
119pub struct AuthorizationService<S, C>
120where
121 C: Clone + DeserializeOwned + Send,
122{
123 pub inner: S,
124 pub auths: Vec<Arc<Authorizer<C>>>,
125}
126
127impl<S, C> AuthorizationService<S, C>
128where
129 C: Clone + DeserializeOwned + Send,
130{
131 pub fn get_ref(&self) -> &S {
132 &self.inner
133 }
134
135 pub fn get_mut(&mut self) -> &mut S {
137 &mut self.inner
138 }
139
140 pub fn into_inner(self) -> S {
142 self.inner
143 }
144}
145
146impl<S, C> AuthorizationService<S, C>
147where
148 C: Clone + DeserializeOwned + Send + Sync,
149{
150 pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>) -> AuthorizationService<S, C> {
154 Self { inner, auths }
155 }
156}
157
158impl<S, C, B> Service<Request<B>> for AuthorizationService<S, C>
159where
160 B: Send + 'static,
161 S: Service<Request<B>> + Clone,
162 S::Response: From<AuthError>,
163 C: Clone + DeserializeOwned + Send + Sync + 'static,
164{
165 type Response = S::Response;
166 type Error = S::Error;
167 type Future = ResponseFuture<S, C, B>;
168
169 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170 self.inner.poll_ready(cx)
171 }
172
173 fn call(&mut self, req: Request<B>) -> Self::Future {
174 let inner = self.inner.clone();
175 let inner = std::mem::replace(&mut self.inner, inner);
177
178 let auth_fut = self.authorize(req);
179
180 ResponseFuture {
181 state: State::Authorize { auth_fut },
182 service: inner,
183 }
184 }
185}
186
187#[pin_project]
188pub struct ResponseFuture<S, C, B>
190where
191 B: Send + 'static,
192 S: Service<Request<B>>,
193 C: Clone + DeserializeOwned + Send + Sync + 'static,
194{
195 #[pin]
196 state: State<<AuthorizationService<S, C> as Authorize<B>>::Future, S::Future>,
197 service: S,
198}
199
200#[pin_project(project = StateProj)]
201enum State<A, SFut> {
202 Authorize {
203 #[pin]
204 auth_fut: A,
205 },
206 Authorized {
207 #[pin]
208 svc_fut: SFut,
209 },
210}
211
212impl<S, C, B> Future for ResponseFuture<S, C, B>
213where
214 B: Send,
215 S: Service<Request<B>>,
216 S::Response: From<AuthError>,
217 C: Clone + DeserializeOwned + Send + Sync,
218{
219 type Output = Result<S::Response, S::Error>;
220
221 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
222 let mut this = self.project();
223
224 loop {
225 match this.state.as_mut().project() {
226 StateProj::Authorize { auth_fut } => {
227 let auth = ready!(auth_fut.poll(cx));
228 match auth {
229 Ok(req) => {
230 let svc_fut = this.service.call(req);
231 this.state.set(State::Authorized { svc_fut })
232 }
233 Err(res) => {
234 tracing::info!("err: {:?}", res);
235 return Poll::Ready(Ok(res.into()));
236 }
237 };
238 }
239 StateProj::Authorized { svc_fut } => {
240 return svc_fut.poll(cx);
241 }
242 }
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::AuthorizationLayer;
    use crate::{authorizer::Authorizer, IntoLayer, JwtAuthorizer, RegisteredClaims};

    // A single built authorizer converts into a layer holding exactly one authorizer.
    #[tokio::test]
    async fn auth_into_layer() {
        let authorizer: Authorizer = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let built = authorizer.into_layer();
        assert_eq!(1, built.auths.len());
    }

    // An array of authorizers keeps every element.
    #[tokio::test]
    async fn auths_into_layer() {
        let first = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let second = JwtAuthorizer::from_secret("bbb").build().await.unwrap();

        let built: AuthorizationLayer<RegisteredClaims> = [first, second].into_layer();
        assert_eq!(2, built.auths.len());
    }

    // A Vec of authorizers keeps every element.
    #[tokio::test]
    async fn vec_auths_into_layer() {
        let first = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let second = JwtAuthorizer::from_secret("bbb").build().await.unwrap();

        let built: AuthorizationLayer<RegisteredClaims> = vec![first, second].into_layer();
        assert_eq!(2, built.auths.len());
    }

    // The builder's own `layer()` shortcut yields a single-authorizer layer.
    #[tokio::test]
    async fn jwt_auth_to_layer() {
        let builder: JwtAuthorizer = JwtAuthorizer::from_secret("aaa");
        // `layer()` is flagged deprecated (hence the allow); this test keeps the legacy
        // path covered until it is removed.
        #[allow(deprecated)]
        let built = builder.layer().await.unwrap();
        assert_eq!(1, built.auths.len());
    }
}