tower_jwt/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{convert::Infallible, future::Future, marker::PhantomData, pin::Pin, task::Poll};
4
5use async_trait::async_trait;
6use headers::{authorization::Bearer, Authorization, HeaderMapExt};
7use http::{Request, Response, StatusCode};
8use jsonwebtoken::decode;
9pub use jsonwebtoken::{DecodingKey, Validation};
10use pin_project::pin_project;
11use serde::Deserialize;
12use tower::{Layer, Service};
13use tracing::{error, trace};
14
15/// Trait to get a decoding key asynchronously
16#[async_trait]
17pub trait DecodingKeyFn: Send + Sync + Clone {
18    type Error: std::error::Error + Send;
19
20    async fn decoding_key(&self) -> Result<DecodingKey, Self::Error>;
21}
22
23#[async_trait]
24impl<F, O> DecodingKeyFn for F
25where
26    F: Fn() -> O + Sync + Send + Clone,
27    O: Future<Output = DecodingKey> + Send,
28{
29    type Error = Infallible;
30
31    async fn decoding_key(&self) -> Result<DecodingKey, Self::Error> {
32        Ok((self)().await)
33    }
34}
35
36#[async_trait]
37impl DecodingKeyFn for DecodingKey {
38    type Error = Infallible;
39
40    async fn decoding_key(&self) -> Result<DecodingKey, Self::Error> {
41        Ok(self.clone())
42    }
43}
44
45/// Layer to validate JWT tokens with a decoding key. Valid claims are added to the request extension
46///
47/// It can also be used with tonic. See:
48/// https://github.com/hyperium/tonic/blob/master/examples/src/tower/server.rs
49#[derive(Clone)]
50pub struct JwtLayer<Claim, F = DecodingKey> {
51    /// User provided function to get the decoding key from
52    decoding_key_fn: F,
53    /// The validation to apply when parsing the token
54    validation: Validation,
55    _phantom: PhantomData<Claim>,
56}
57
58impl<Claim, F: DecodingKeyFn> JwtLayer<Claim, F> {
59    /// Create a new layer to validate JWT tokens with the given decoding key
60    /// Tokens will only be accepted if they pass the validation
61    pub fn new(validation: Validation, decoding_key_fn: F) -> Self {
62        Self {
63            decoding_key_fn,
64            validation,
65            _phantom: PhantomData,
66        }
67    }
68}
69
70impl<S, Claim, F: DecodingKeyFn> Layer<S> for JwtLayer<Claim, F> {
71    type Service = Jwt<S, Claim, F>;
72
73    fn layer(&self, inner: S) -> Self::Service {
74        Jwt {
75            inner,
76            decoding_key_fn: self.decoding_key_fn.clone(),
77            validation: Box::new(self.validation.clone()),
78            _phantom: self._phantom,
79        }
80    }
81}
82
83/// Middleware for validating a valid JWT token is present on "authorization: bearer <token>"
84#[derive(Clone)]
85pub struct Jwt<S, Claim, F> {
86    inner: S,
87    decoding_key_fn: F,
88    // Using a Box here to reduce cloning it the whole time
89    validation: Box<Validation>,
90    _phantom: PhantomData<Claim>,
91}
92
/// Boxed, type-erased, `Send` future; used to hold the pending decoding-key lookup.
type AsyncTraitFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
94
95#[pin_project(project = JwtFutureProj, project_replace = JwtFutureProjOwn)]
96pub enum JwtFuture<
97    DecKeyFn: DecodingKeyFn,
98    TService: Service<Request<ReqBody>, Response = Response<ResBody>>,
99    ReqBody,
100    ResBody,
101    Claim,
102> {
103    // If there was an error return a BAD_REQUEST.
104    Error,
105
106    // We are ready to call the inner service.
107    WaitForFuture {
108        #[pin]
109        future: TService::Future,
110    },
111
112    // We have a token and need to run our logic.
113    HasTokenWaitingForDecodingKey {
114        bearer: Authorization<Bearer>,
115        request: Request<ReqBody>,
116        #[pin]
117        decoding_key_future: AsyncTraitFuture<Result<DecodingKey, DecKeyFn::Error>>,
118        validation: Box<Validation>,
119        service: TService,
120        _phantom: PhantomData<Claim>,
121    },
122}
123
// Drives a request through the `JwtFuture` states. Every rejection is answered directly with an
// empty (`Default::default()`) body; the inner service is only called once the token is valid.
impl<DecKeyFn, TService, ReqBody, ResBody, Claim> Future
    for JwtFuture<DecKeyFn, TService, ReqBody, ResBody, Claim>
where
    DecKeyFn: DecodingKeyFn + 'static,
    TService: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: Default,
    for<'de> Claim: Deserialize<'de> + Send + Sync + Clone + 'static,
{
    type Output = Result<TService::Response, TService::Error>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        match self.as_mut().project() {
            // Rejected before any token handling: answer with `400 Bad Request`.
            JwtFutureProj::Error => {
                let response = Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(Default::default())
                    .unwrap();
                Poll::Ready(Ok(response))
            }
            // The inner service was called; forward whatever it produces.
            JwtFutureProj::WaitForFuture { future } => future.poll(cx),
            JwtFutureProj::HasTokenWaitingForDecodingKey {
                bearer,
                decoding_key_future,
                validation,
                ..
            } => match decoding_key_future.poll(cx) {
                Poll::Pending => Poll::Pending,
                // No key means the token cannot be checked: report `503 Service Unavailable`.
                Poll::Ready(Err(error)) => {
                    error!(
                        error = &error as &dyn std::error::Error,
                        "failed to get decoding key"
                    );
                    let response = Response::builder()
                        .status(StatusCode::SERVICE_UNAVAILABLE)
                        .body(Default::default())
                        .unwrap();

                    Poll::Ready(Ok(response))
                }
                Poll::Ready(Ok(decoding_key)) => {
                    // Surrounding whitespace is trimmed from the token before decoding.
                    let claim_result = RequestClaim::<Claim>::from_token(
                        bearer.token().trim(),
                        &decoding_key,
                        validation,
                    );
                    match claim_result {
                        // Respond with the status code chosen by `RequestClaim::from_token`.
                        Err(code) => {
                            error!(code = %code, "failed to decode JWT");

                            let response = Response::builder()
                                .status(code)
                                .body(Default::default())
                                .unwrap();

                            Poll::Ready(Ok(response))
                        }
                        Ok(claim) => {
                            // Take ownership of the request and service by swapping in `Error`
                            // as a placeholder; it is overwritten by `WaitForFuture` below.
                            let owned = self.as_mut().project_replace(JwtFuture::Error);
                            match owned {
                                JwtFutureProjOwn::HasTokenWaitingForDecodingKey {
                                    mut request, mut service, ..
                                } => {
                                    request.extensions_mut().insert(claim);
                                    // NOTE(review): tower requires `poll_ready` to have succeeded
                                    // on `service` before `call`; confirm that `Jwt::call` stores
                                    // the instance that was actually polled ready.
                                    let future = service.call(request);
                                    self.as_mut().set(JwtFuture::WaitForFuture { future });
                                    // Poll the new state right away so the inner future runs and
                                    // registers the waker in this same call.
                                    self.poll(cx)
                                },
                                _ => unreachable!("We know that we're in the 'HasTokenWaitingForDecodingKey' state"),
                            }
                        }
                    }
                }
            },
        }
    }
}
200
201impl<S, ReqBody, ResBody, Claim, F> Service<Request<ReqBody>> for Jwt<S, Claim, F>
202where
203    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Send + Clone + 'static,
204    S::Future: Send + 'static,
205    ResBody: Default,
206    F: DecodingKeyFn + 'static,
207    <F as DecodingKeyFn>::Error: 'static,
208    for<'de> Claim: Deserialize<'de> + Send + Sync + Clone + 'static,
209{
210    type Response = S::Response;
211    type Error = S::Error;
212    type Future = JwtFuture<F, S, ReqBody, ResBody, Claim>;
213
214    fn poll_ready(
215        &mut self,
216        cx: &mut std::task::Context<'_>,
217    ) -> std::task::Poll<Result<(), Self::Error>> {
218        self.inner.poll_ready(cx)
219    }
220
221    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
222        match req.headers().typed_try_get::<Authorization<Bearer>>() {
223            Ok(Some(bearer)) => {
224                let decoding_key_fn = self.decoding_key_fn.clone();
225                let decoding_key_future =
226                    Box::pin(async move { decoding_key_fn.decoding_key().await });
227                Self::Future::HasTokenWaitingForDecodingKey {
228                    bearer,
229                    request: req,
230                    decoding_key_future,
231                    validation: self.validation.clone(),
232                    service: self.inner.clone(),
233                    _phantom: self._phantom,
234                }
235            }
236            Ok(None) => {
237                let future = self.inner.call(req);
238
239                Self::Future::WaitForFuture { future }
240            }
241            Err(_) => Self::Future::Error,
242        }
243    }
244}
245
/// The validated claim from a JWT token.
///
/// The middleware inserts it into the request extensions after successful validation.
///
/// The struct itself places no bounds on `T`; the deserialization bound lives on the `impl`
/// that decodes tokens, so the type can be named and cloned freely in generic code.
#[derive(Clone, Debug)]
pub struct RequestClaim<T> {
    /// The claim from the token
    pub claim: T,

    /// The original token that was parsed
    pub token: String,
}
258
259impl<T> RequestClaim<T>
260where
261    for<'de> T: Deserialize<'de>,
262{
263    pub fn from_token(
264        token: &str,
265        decoding_key: &DecodingKey,
266        validation: &Validation,
267    ) -> Result<Self, StatusCode> {
268        trace!("converting token to claim");
269        let claim: T = decode(token, decoding_key, validation)
270            .map_err(|err| {
271                error!(
272                    error = &err as &dyn std::error::Error,
273                    "failed to convert token to claim"
274                );
275                match err.kind() {
276                    jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
277                        StatusCode::from_u16(499).unwrap() // Expired status code which is safe to unwrap
278                    }
279                    jsonwebtoken::errors::ErrorKind::InvalidSignature
280                    | jsonwebtoken::errors::ErrorKind::InvalidAlgorithmName
281                    | jsonwebtoken::errors::ErrorKind::InvalidIssuer
282                    | jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
283                        StatusCode::UNAUTHORIZED
284                    }
285                    jsonwebtoken::errors::ErrorKind::InvalidToken
286                    | jsonwebtoken::errors::ErrorKind::InvalidAlgorithm
287                    | jsonwebtoken::errors::ErrorKind::Base64(_)
288                    | jsonwebtoken::errors::ErrorKind::Json(_)
289                    | jsonwebtoken::errors::ErrorKind::Utf8(_) => StatusCode::BAD_REQUEST,
290                    jsonwebtoken::errors::ErrorKind::MissingAlgorithm => {
291                        StatusCode::INTERNAL_SERVER_ERROR
292                    }
293                    jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
294                    _ => StatusCode::INTERNAL_SERVER_ERROR,
295                }
296            })?
297            .claims;
298
299        Ok(Self {
300            claim,
301            token: token.to_string(),
302        })
303    }
304}
305
#[cfg(test)]
mod tests {
    use axum::{routing::get, Extension, Router};
    use chrono::{Duration, Utc};
    use http_body_util::{BodyExt, Empty};
    use jsonwebtoken::{encode, EncodingKey, Header};
    use ring::{
        rand,
        signature::{self, Ed25519KeyPair, KeyPair},
    };
    use serde::Serialize;
    use std::ops::Add;
    use tower::ServiceExt;

    use super::*;

    /// Registered JWT claims used by the tests.
    #[derive(Deserialize, Clone, Serialize)]
    pub struct Claim {
        /// Expiration time (as UTC timestamp).
        pub exp: usize,
        /// Issued at (as UTC timestamp).
        iat: usize,
        /// Issuer.
        iss: String,
        /// Not Before (as UTC timestamp).
        nbf: usize,
        /// Subject (whom token refers to).
        pub sub: String,
    }

    impl Claim {
        /// Create a claim for `sub`, issued by `test-issuer`, valid from now for five minutes.
        pub fn new(sub: String) -> Self {
            let iat = Utc::now();
            let exp = iat.add(Duration::try_minutes(5).unwrap());

            Self {
                exp: exp.timestamp() as usize,
                iat: iat.timestamp() as usize,
                iss: "test-issuer".to_string(),
                nbf: iat.timestamp() as usize,
                sub,
            }
        }

        /// Sign the claim with EdDSA using `encoding_key`.
        pub fn into_token(self, encoding_key: &EncodingKey) -> Result<String, StatusCode> {
            encode(
                &Header::new(jsonwebtoken::Algorithm::EdDSA),
                &self,
                encoding_key,
            )
            .map_err(|err| {
                error!(
                    error = &err as &dyn std::error::Error,
                    "failed to convert claim to token"
                );
                match err.kind() {
                    jsonwebtoken::errors::ErrorKind::Json(_) => StatusCode::INTERNAL_SERVER_ERROR,
                    jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
                    _ => StatusCode::INTERNAL_SERVER_ERROR,
                }
            })
        }
    }

    /// Generate a fresh Ed25519 key pair and return the signing key with its matching
    /// verification key.
    fn ed25519_keys() -> (EncodingKey, DecodingKey) {
        let doc = signature::Ed25519KeyPair::generate_pkcs8(&rand::SystemRandom::new()).unwrap();
        let encoding_key = EncodingKey::from_ed_der(doc.as_ref());
        let pair = Ed25519KeyPair::from_pkcs8(doc.as_ref()).unwrap();
        let public_key = pair.public_key().as_ref().to_vec();
        let decoding_key = DecodingKey::from_ed_der(&public_key);

        (encoding_key, decoding_key)
    }

    #[tokio::test]
    async fn authorization_layer() {
        let claim = Claim::new("ferries".to_string());

        let (encoding_key, decoding_key) = ed25519_keys();

        let mut validation = Validation::new(jsonwebtoken::Algorithm::EdDSA);
        validation.set_issuer(&["test-issuer"]);

        let router = Router::new()
            .route(
                "/",
                get(|claim: Option<Extension<RequestClaim<Claim>>>| async move {
                    if let Some(Extension(claim)) = claim {
                        (StatusCode::OK, format!("Hello, {}", claim.claim.sub))
                    } else {
                        (StatusCode::UNAUTHORIZED, "Not authorized".to_string())
                    }
                }),
            )
            .layer(JwtLayer::<Claim, _>::new(validation, move || {
                let decoding_key = decoding_key.clone();

                async { decoding_key }
            }));

        //////////////////////////////////////////////////////////////////////////
        // Test token missing
        //////////////////////////////////////////////////////////////////////////
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        //////////////////////////////////////////////////////////////////////////
        // Test bearer missing
        //////////////////////////////////////////////////////////////////////////
        let token = claim.clone().into_token(&encoding_key).unwrap();
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .header("authorization", token.clone())
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        //////////////////////////////////////////////////////////////////////////
        // Test valid
        //////////////////////////////////////////////////////////////////////////
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .header("authorization", format!("Bearer {token}"))
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        //////////////////////////////////////////////////////////////////////////
        // Test valid extra padding
        //////////////////////////////////////////////////////////////////////////
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .header("Authorization", format!("Bearer   {token}   "))
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(&body[..], b"Hello, ferries");

        //////////////////////////////////////////////////////////////////////////
        // Test token signed with a different key
        //////////////////////////////////////////////////////////////////////////
        let (foreign_encoding_key, _) = ed25519_keys();
        let foreign_token = claim.clone().into_token(&foreign_encoding_key).unwrap();
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .header("authorization", format!("Bearer {foreign_token}"))
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        // The middleware itself rejects the request with an empty body; the handler's
        // "Not authorized" text must not appear.
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert!(body.is_empty());
    }
}
472}