jwt_authorizer/
layer.rs

1use futures_core::ready;
2use futures_util::future::{self, BoxFuture};
3use http::Request;
4use jsonwebtoken::TokenData;
5use pin_project::pin_project;
6use serde::de::DeserializeOwned;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tower_layer::Layer;
12use tower_service::Service;
13
14use crate::authorizer::Authorizer;
15use crate::AuthError;
16
17/// Trait for authorizing requests.
18pub trait Authorize<B> {
19    type Future: Future<Output = Result<Request<B>, AuthError>>;
20
21    /// Authorize the request.
22    ///
23    /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
24    fn authorize(&self, request: Request<B>) -> Self::Future;
25}
26
27impl<S, B, C> Authorize<B> for AuthorizationService<S, C>
28where
29    B: Send + 'static,
30    C: Clone + DeserializeOwned + Send + Sync + 'static,
31{
32    type Future = BoxFuture<'static, Result<Request<B>, AuthError>>;
33
34    /// The authorizers are sequentially applied (check_auth) until one of them validates the token.
35    /// If no authorizer validates the token the request is rejected.
36    ///
37    fn authorize(&self, mut request: Request<B>) -> Self::Future {
38        let tkns_auths: Vec<(String, Arc<Authorizer<C>>)> = self
39            .auths
40            .iter()
41            .filter_map(|a| a.extract_token(request.headers()).map(|t| (t, a.clone())))
42            .collect();
43
44        if tkns_auths.is_empty() {
45            return Box::pin(future::ready(Err(AuthError::MissingToken())));
46        }
47
48        Box::pin(async move {
49            let mut token_data: Result<TokenData<C>, AuthError> = Err(AuthError::NoAuthorizer());
50            for (token, auth) in tkns_auths {
51                token_data = auth.check_auth(token.as_str()).await;
52                if token_data.is_ok() {
53                    break;
54                }
55            }
56            match token_data {
57                Ok(tdata) => {
58                    // Set `token_data` as a request extension so it can be accessed by other
59                    // services down the stack.
60
61                    request.extensions_mut().insert(tdata);
62
63                    Ok(request)
64                }
65                Err(err) => Err(err), // TODO: error containing all errors (not just the last one) or to choose one?
66            }
67        })
68    }
69}
70
71// -------------- Layer -----------------
72
73#[derive(Clone)]
74pub struct AuthorizationLayer<C>
75where
76    C: Clone + DeserializeOwned + Send,
77{
78    auths: Vec<Arc<Authorizer<C>>>,
79}
80
81impl<C> AuthorizationLayer<C>
82where
83    C: Clone + DeserializeOwned + Send,
84{
85    pub fn new(auths: Vec<Arc<Authorizer<C>>>) -> AuthorizationLayer<C> {
86        Self { auths }
87    }
88}
89
90impl<S, C> Layer<S> for AuthorizationLayer<C>
91where
92    C: Clone + DeserializeOwned + Send + Sync,
93{
94    type Service = AuthorizationService<S, C>;
95
96    fn layer(&self, inner: S) -> Self::Service {
97        AuthorizationService::new(inner, self.auths.clone())
98    }
99}
100
101// ----------  AuthorizationService  --------
102
/// Location the bearer token is read from.
///
/// Derives `Debug`, `PartialEq` and `Eq` so the configured source can be logged and
/// compared.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum JwtSource {
    /// Bearer token carried in the `Authorization` header.
    ///
    /// This is the default.
    AuthorizationHeader,
    /// Bearer token carried in a cookie.
    ///
    /// Be careful when using cookies: some precautions must be taken (see RFC 6750).
    // NOTE(review): the `String` presumably names the cookie; confirm against the
    // token-extraction code.
    Cookie(String),
    // TODO: the "Form-Encoded Content Parameter" source (OAuth 2.1, section 5.2.1.2)
    // may be added in the future.
    // FormParam,
}
117
118#[derive(Clone)]
119pub struct AuthorizationService<S, C>
120where
121    C: Clone + DeserializeOwned + Send,
122{
123    pub inner: S,
124    pub auths: Vec<Arc<Authorizer<C>>>,
125}
126
127impl<S, C> AuthorizationService<S, C>
128where
129    C: Clone + DeserializeOwned + Send,
130{
131    pub fn get_ref(&self) -> &S {
132        &self.inner
133    }
134
135    /// Gets a mutable reference to the underlying service.
136    pub fn get_mut(&mut self) -> &mut S {
137        &mut self.inner
138    }
139
140    /// Consumes `self`, returning the underlying service.
141    pub fn into_inner(self) -> S {
142        self.inner
143    }
144}
145
146impl<S, C> AuthorizationService<S, C>
147where
148    C: Clone + DeserializeOwned + Send + Sync,
149{
150    /// Authorize requests using a custom scheme.
151    ///
152    /// The `Authorization` header is required to have the value provided.
153    pub fn new(inner: S, auths: Vec<Arc<Authorizer<C>>>) -> AuthorizationService<S, C> {
154        Self { inner, auths }
155    }
156}
157
158impl<S, C, B> Service<Request<B>> for AuthorizationService<S, C>
159where
160    B: Send + 'static,
161    S: Service<Request<B>> + Clone,
162    S::Response: From<AuthError>,
163    C: Clone + DeserializeOwned + Send + Sync + 'static,
164{
165    type Response = S::Response;
166    type Error = S::Error;
167    type Future = ResponseFuture<S, C, B>;
168
169    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170        self.inner.poll_ready(cx)
171    }
172
173    fn call(&mut self, req: Request<B>) -> Self::Future {
174        let inner = self.inner.clone();
175        // take the service that was ready
176        let inner = std::mem::replace(&mut self.inner, inner);
177
178        let auth_fut = self.authorize(req);
179
180        ResponseFuture {
181            state: State::Authorize { auth_fut },
182            service: inner,
183        }
184    }
185}
186
187#[pin_project]
188/// Response future for [`AuthorizationService`].
189pub struct ResponseFuture<S, C, B>
190where
191    B: Send + 'static,
192    S: Service<Request<B>>,
193    C: Clone + DeserializeOwned + Send + Sync + 'static,
194{
195    #[pin]
196    state: State<<AuthorizationService<S, C> as Authorize<B>>::Future, S::Future>,
197    service: S,
198}
199
200#[pin_project(project = StateProj)]
201enum State<A, SFut> {
202    Authorize {
203        #[pin]
204        auth_fut: A,
205    },
206    Authorized {
207        #[pin]
208        svc_fut: SFut,
209    },
210}
211
212impl<S, C, B> Future for ResponseFuture<S, C, B>
213where
214    B: Send,
215    S: Service<Request<B>>,
216    S::Response: From<AuthError>,
217    C: Clone + DeserializeOwned + Send + Sync,
218{
219    type Output = Result<S::Response, S::Error>;
220
221    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
222        let mut this = self.project();
223
224        loop {
225            match this.state.as_mut().project() {
226                StateProj::Authorize { auth_fut } => {
227                    let auth = ready!(auth_fut.poll(cx));
228                    match auth {
229                        Ok(req) => {
230                            let svc_fut = this.service.call(req);
231                            this.state.set(State::Authorized { svc_fut })
232                        }
233                        Err(res) => {
234                            tracing::info!("err: {:?}", res);
235                            return Poll::Ready(Ok(res.into()));
236                        }
237                    };
238                }
239                StateProj::Authorized { svc_fut } => {
240                    return svc_fut.poll(cx);
241                }
242            }
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::AuthorizationLayer;
    use crate::{authorizer::Authorizer, IntoLayer, JwtAuthorizer, RegisteredClaims};

    /// A single authorizer converts into a layer holding exactly one authorizer.
    #[tokio::test]
    async fn auth_into_layer() {
        let authorizer: Authorizer = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let layer = authorizer.into_layer();
        assert_eq!(1, layer.auths.len());
    }

    /// An array of authorizers converts into a layer holding all of them.
    #[tokio::test]
    async fn auths_into_layer() {
        let first = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let second = JwtAuthorizer::from_secret("bbb").build().await.unwrap();

        let layer: AuthorizationLayer<RegisteredClaims> = [first, second].into_layer();
        assert_eq!(2, layer.auths.len());
    }

    /// A `Vec` of authorizers converts into a layer holding all of them.
    #[tokio::test]
    async fn vec_auths_into_layer() {
        let first = JwtAuthorizer::from_secret("aaa").build().await.unwrap();
        let second = JwtAuthorizer::from_secret("bbb").build().await.unwrap();

        let layer: AuthorizationLayer<RegisteredClaims> = vec![first, second].into_layer();
        assert_eq!(2, layer.auths.len());
    }

    /// A builder converted via `layer()` yields a layer with one authorizer.
    #[tokio::test]
    async fn jwt_auth_to_layer() {
        let builder: JwtAuthorizer = JwtAuthorizer::from_secret("aaa");
        // Silences the deprecation warning for `layer()`, which this test still covers.
        #[allow(deprecated)]
        let layer = builder.layer().await.unwrap();
        assert_eq!(1, layer.auths.len());
    }
}
285}