Skip to main content

axum_jwt/
layer.rs

1//! Middleware types and traits.
2//!
3//! If you just need to access token data in a handler, use
4//! the [`Token`] and [`Claims`](crate::Claims) extractors directly.
5//!
6//! The [`layer`] function creates a configurable axum [middleware] layer.
7//! When a request is made to a handler wrapped in this layer, the
8//! JSON Web Token is validated. If validation succeeds, the handler is called.
9//! If validation fails, a `401 Unauthorized` status code is returned, though
10//! more fine-grained [configuration] is possible.
11//!
12//! [middleware]: https://docs.rs/axum/latest/axum/middleware/index.html
13//! [configuration]: #configuration
14//!
15//! # Examples
16//!
17//! ```
18//! use {
19//!     axum::{Router, routing},
20//!     axum_jwt::{Decoder, jsonwebtoken::DecodingKey},
21//! };
22//!
23//! // This handler will be called only if the token is successfully validated.
24//! async fn hello() -> String {
//!     "Hello, Anonymous!".to_owned()
26//! }
27//!
28//! # async fn f() -> std::io::Result<()> {
29//! let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
30//!
31//! let app = Router::new()
32//!     .route("/", routing::get(hello))
33//!     .layer(axum_jwt::layer(decoder));
34//!
35//! # use tokio::net::TcpListener;
36//! let listener = TcpListener::bind("0.0.0.0:3000").await?;
37//! axum::serve(listener, app).await?;
38//! # Ok(())
39//! # }
40//! ```
41//!
42//! # Configuration
43//!
44//! The [`layer`] function accepts a [decoder](Decoder) that defines how to
45//! decode and validate the token.
46//!
//! Additionally, the layer itself can be configured: set a
//! [filter](JwtLayer::with_filter) to choose the claims type and perform
//! extra checks, store the token in [extensions](JwtLayer::store_to_extension)
//! so a handler can later retrieve it via an extractor, or specify a custom
//! way of [extracting](JwtLayer::with_extract) the token from the request.
53
54use {
55    crate::{
56        decode::Decoder,
57        error::Error,
58        extract::{Bearer, Extract, Token},
59    },
60    axum_core::{
61        extract::Request,
62        response::{IntoResponse, Response},
63    },
64    http::{Extensions, StatusCode},
65    jsonwebtoken::TokenData,
66    serde::de::{DeserializeOwned, IgnoredAny},
67    std::{
68        any,
69        convert::Infallible,
70        fmt,
71        marker::PhantomData,
72        mem,
73        pin::Pin,
74        task::{self, Context, Poll},
75    },
76    tower_layer::Layer,
77    tower_service::Service,
78};
79
80/// Creates a [layer](JwtLayer) for middleware.
81///
82/// # Examples
83///
84/// ```
85/// use {
86///     axum::{Router, routing},
87///     axum_jwt::{Decoder, jsonwebtoken::DecodingKey},
88/// };
89///
90/// // This handler will be called only if the token is successfully validated.
91/// async fn hello() -> String {
92///     "Hello, Anonimus!".to_owned()
93/// }
94///
95/// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
96///
97/// let app = Router::new()
98///     .route("/", routing::get(hello))
99///     .layer(axum_jwt::layer(decoder));
100/// # let _: Router = app;
101/// ```
102pub fn layer(decoder: Decoder) -> JwtLayer {
103    JwtLayer {
104        decoder,
105        validate: Discard,
106        store: |_, _| {},
107        extract: PhantomData,
108    }
109}
110
/// Layer type for creating middleware.
///
/// To configure the layer and create the middleware service, call
/// the [`layer`] function.
///
/// Type parameters: `I` is the claims type the token is decoded into, `H` is
/// the additional filter run on the decoded token, and `X` is the type used
/// to extract the raw token from the request.
pub struct JwtLayer<I = IgnoredAny, H = Discard, X = Bearer> {
    /// Decodes and validates the raw token; cloned into every service built
    /// by this layer.
    decoder: Decoder,
    /// Additional filter run on the decoded token.
    validate: H,
    /// Hook that receives the accepted token and the request extensions;
    /// a no-op unless `store_to_extension` was called.
    store: fn(Token<I>, &mut Extensions),
    /// Type-level marker selecting the extractor `X`; no value is stored.
    extract: PhantomData<fn() -> X>,
}
121
122impl<I, X> JwtLayer<I, Discard, X> {
123    /// Sets a filter for additional validation.
124    ///
125    /// By default, the layer only validates the token header, ignoring all
126    /// its claims. This method allows you to specify an arbitrary data type
127    /// for the claims and perform additional token checks.
128    ///
129    /// The claims type must implement [`Deserialize`](serde::Deserialize).
130    ///
131    /// # Examples
132    ///
133    /// ```
134    /// use {
135    ///     axum::{Router, routing},
136    ///     axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
137    ///     serde::Deserialize,
138    /// };
139    ///
140    /// #[derive(Deserialize)]
141    /// struct User {
142    ///     roles: Vec<String>,
143    /// }
144    ///
145    /// // Checks that the user's token contains the admin role.
146    /// fn check_access(t: &Token<User>) -> bool {
147    ///     t.claims.roles.iter().any(|role| role == "admin")
148    /// }
149    ///
150    /// // Called only if the role check is successful.
151    /// async fn hello() -> String {
152    ///     "Hello, Admin!".to_owned()
153    /// }
154    ///
155    /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
156    ///
157    /// let app = Router::new()
158    ///     .route("/", routing::get(hello))
159    ///     .layer(axum_jwt::layer(decoder).with_filter(check_access));
160    /// # let _: Router = app;
161    /// ```
162    ///
163    /// If you also need to use the token inside the handler,
164    /// see the [`store_to_extension`](JwtLayer::store_to_extension) method.
165    ///
166    /// # Callback return value
167    ///
168    /// The return value of the provided callback can be:
169    ///
170    /// * `bool`: where `true` means validation succeeded, and `false` means
171    ///   it failed returning an HTTP status code `401 Unauthorized`.
172    /// * `Result<(), E>`: where `Ok(())` means validation succeeded
173    ///   and `Err(e)` means it failed. The error type must implement
174    ///   [`IntoResponse`], which will be called on failure to return the
175    ///   corresponding response.
176    pub fn with_filter<H, N, O>(self, validate: H) -> JwtLayer<N, H, X>
177    where
178        H: FnMut(&Token<N>) -> O,
179        N: DeserializeOwned,
180        O: Output,
181    {
182        JwtLayer {
183            decoder: self.decoder,
184            validate,
185            store: |_, _| {},
186            extract: PhantomData,
187        }
188    }
189}
190
191impl<I, H, X> JwtLayer<I, H, X> {
192    /// Configures the layer to store the token in the [extension].
193    ///
194    /// [extension]: https://docs.rs/axum/latest/axum/struct.Extension.html
195    ///
196    /// If you just need to access token data in a handler, use
197    /// the [`Token`] and [`Claims`](crate::Claims) extractors directly.
198    /// This function is useful only if you want to use the middleware
199    /// but still access token data in some handlers.
200    ///
201    /// After calling this method, the middleware will store the parsed token
202    /// in the extension, which can later be retrieved, for example,
203    /// in a handler.
204    ///
205    /// The token is stored only after *successful validation*, including
206    /// the configured [filter](JwtLayer::with_filter). This is usually what
207    /// you want, but if you reuse the handler elsewhere, keep in mind that
208    /// extracting it from [`Extension`] may fail.
209    ///
210    /// [`Extension`]: https://docs.rs/axum/latest/axum/struct.Extension.html
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// use {
216    ///     axum::{Extension, Router, routing},
217    ///     axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
218    ///     serde::Deserialize,
219    /// };
220    ///
221    /// // To store a value in the extension, it must implement `Clone`.
222    /// #[derive(Clone, Deserialize)]
223    /// struct User {
224    ///     sub: String,
225    ///     roles: Vec<String>,
226    /// }
227    ///
228    /// // Checks that the user's token contains the admin role.
229    /// fn check_access(t: &Token<User>) -> bool {
230    ///     t.claims.roles.iter().any(|role| role == "admin")
231    /// }
232    ///
233    /// // Called only if the role check is successful.
234    /// async fn hello(Extension(t): Extension<Token<User>>) -> String {
235    ///     // We can also access the parsed token
236    ///     format!("Hello, {}!", t.claims.sub)
237    /// }
238    ///
239    /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
240    ///
241    /// let app = Router::new()
242    ///     .route("/", routing::get(hello))
243    ///     .layer(
244    ///         axum_jwt::layer(decoder)
245    ///             .with_filter(check_access)
246    ///             .store_to_extension(),
247    ///     );
248    /// # let _: Router = app;
249    /// ```
250    ///
251    /// <section class="warning">
252    ///
253    /// Note that the `store_to_extension` call comes after setting
254    /// the `with_filter` filter.
255    ///
256    /// This is important because
257    /// `store_to_extension` applies to the current token type.
258    /// The `with_filter` call may change the type and therefore it always
259    /// resets the `store_to_extension` configuration.
260    ///
261    /// ```
262    /// # use {
263    /// #     axum::{Router, routing},
264    /// #     axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
265    /// # };
266    /// # fn check_access(t: &Token) -> bool { true }
267    /// # async fn hello() {}
268    /// # let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
269    /// let app = Router::new()
270    ///     .route("/", routing::get(hello))
271    ///     .layer(
272    ///         axum_jwt::layer(decoder)
273    ///             // Incorrect order!
274    ///             .store_to_extension()
275    ///             .with_filter(check_access),
276    ///     );
277    /// # let _: Router = app;
278    /// ```
279    ///
280    /// </section>
281    ///
282    /// # Read header only
283    ///
284    /// If you just need to retrieve the token without any additional payload,
285    /// omit the `with_filter` call and use the `Token` type without
286    /// a parameter:
287    ///
288    /// ```
289    /// use {
290    ///     axum::{Extension, Router, routing},
291    ///     axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
292    /// };
293    ///
294    /// async fn hello(Extension(t): Extension<Token>) -> String {
295    ///     // Access the parsed token
296    ///     format!("Decoded with {:?} algorithm", t.header.alg)
297    /// }
298    ///
299    /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
300    ///
301    /// let app = Router::new()
302    ///     .route("/", routing::get(hello))
303    ///     .layer(axum_jwt::layer(decoder).store_to_extension());
304    /// # let _: Router = app;
305    /// ```
306    pub fn store_to_extension(mut self) -> Self
307    where
308        I: Clone + Send + Sync + 'static,
309    {
310        self.store = |claims, extensions| {
311            extensions.insert(claims);
312        };
313
314        self
315    }
316}
317
318impl<I, H> JwtLayer<I, H> {
319    /// Applies a token extractor to the layer.
320    ///
321    /// By default, the token is extracted from the `Authorization` header using
322    /// the `Bearer` scheme. If you want to change this behavior, create a new type
323    /// and implement [`Extract`] for it. Then, you can pass this type into the
324    /// layer configuration:
325    ///
326    /// ```
327    /// use {
328    ///     axum::{Extension, Router, http::request::Parts, routing},
329    ///     axum_jwt::{Decoder, Extract, Token, jsonwebtoken::DecodingKey},
330    /// };
331    ///
332    /// struct Custom;
333    ///
334    /// impl Extract for Custom {
335    ///     fn extract(parts: &mut Parts) -> Option<&str> {
336    ///         parts.headers.get("X-Auth-Token")?.to_str().ok()
337    ///     }
338    /// }
339    ///
340    /// async fn hello(Extension(t): Extension<Token>) -> String {
341    ///     format!("Decoded with {:?} algorithm", t.header.alg)
342    /// }
343    ///
344    /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
345    ///
346    /// let app = Router::new()
347    ///     .route("/", routing::get(hello))
348    ///     .layer(axum_jwt::layer(decoder).with_extract(Custom));
349    /// # let _: Router = app;
350    /// ```
351    pub fn with_extract<X>(self, extract: X) -> JwtLayer<I, H, X>
352    where
353        X: Extract,
354    {
355        _ = extract;
356        JwtLayer {
357            decoder: self.decoder,
358            validate: self.validate,
359            store: self.store,
360            extract: PhantomData,
361        }
362    }
363}
364
365impl<I, H, X> Clone for JwtLayer<I, H, X>
366where
367    H: Clone,
368{
369    fn clone(&self) -> Self {
370        Self {
371            decoder: self.decoder.clone(),
372            validate: self.validate.clone(),
373            store: self.store,
374            extract: PhantomData,
375        }
376    }
377}
378
379impl<I, H, X> fmt::Debug for JwtLayer<I, H, X> {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        f.debug_struct("JwtLayer")
382            .field("decoder", &self.decoder)
383            .field("validate", &"..")
384            .field("store", &"..")
385            .field("extract", &any::type_name::<H>())
386            .finish()
387    }
388}
389
390impl<S, I, H, X> Layer<S> for JwtLayer<I, H, X>
391where
392    H: Clone,
393{
394    type Service = Jwt<S, I, H, X>;
395
396    fn layer(&self, svc: S) -> Self::Service {
397        Jwt {
398            svc,
399            decoder: self.decoder.clone(),
400            validate: self.validate.clone(),
401            store: self.store,
402            extract: PhantomData,
403        }
404    }
405}
406
/// Trait for additional token validation.
///
/// Implemented by [`Discard`] (accepts every token) and by any
/// `FnMut(&Token<I>) -> O` closure or function whose return type `O`
/// implements [`Output`].
pub trait Validate<I> {
    /// Result of a validation; turned into an optional rejection response
    /// by [`Output::output`].
    type Output: Output;
    /// Checks the decoded `input` token.
    ///
    /// Takes `&mut self` so that stateful `FnMut` filters can be used.
    fn validate(&mut self, input: &Token<I>) -> Self::Output;
}
412
/// The output value of the [validation](Validate).
///
/// Converts a filter's result into either acceptance or a rejection
/// response. Implemented for `bool` and for `Result<(), E>` where
/// `E: IntoResponse`.
pub trait Output {
    /// Returns `None` when the request may proceed to the wrapped service,
    /// or `Some(response)` to short-circuit with that response.
    fn output(self) -> Option<Response>;
}
417
418impl<E> Output for Result<(), E>
419where
420    E: IntoResponse,
421{
422    fn output(self) -> Option<Response> {
423        self.err().map(E::into_response)
424    }
425}
426
427impl Output for bool {
428    fn output(self) -> Option<Response> {
429        if self {
430            None
431        } else {
432            Some(StatusCode::UNAUTHORIZED.into_response())
433        }
434    }
435}
436
/// Discards any token data and returns success.
///
/// This is the filter installed by [`layer`]: every token that the decoder
/// has already accepted passes it unconditionally.
#[derive(Clone, Copy, Debug, Default)]
pub struct Discard;
440
441impl<I> Validate<I> for Discard {
442    type Output = bool;
443
444    fn validate(&mut self, _: &Token<I>) -> Self::Output {
445        true
446    }
447}
448
449impl<F, I, O> Validate<I> for F
450where
451    F: FnMut(&Token<I>) -> O,
452    I: DeserializeOwned,
453    O: Output,
454{
455    type Output = O;
456
457    fn validate(&mut self, input: &Token<I>) -> Self::Output {
458        self(input)
459    }
460}
461
/// Axum [middleware] for token validation.
///
/// [middleware]: https://docs.rs/axum/latest/axum/middleware/index.html
///
/// To configure the layer and create the middleware service, call
/// the [`layer`] function.
pub struct Jwt<S, I, H = Discard, X = Bearer> {
    /// The wrapped service, called only for requests whose token was
    /// accepted.
    svc: S,
    /// Decodes and validates the raw token.
    decoder: Decoder,
    /// Additional filter run on the decoded token.
    validate: H,
    /// Hook that receives the accepted token and the request extensions;
    /// a no-op unless `store_to_extension` was configured on the layer.
    store: fn(Token<I>, &mut Extensions),
    /// Type-level marker selecting the extractor `X`; no value is stored.
    extract: PhantomData<fn() -> X>,
}
475
476impl<S, I, H, X> Clone for Jwt<S, I, H, X>
477where
478    S: Clone,
479    H: Clone,
480{
481    fn clone(&self) -> Self {
482        Self {
483            svc: self.svc.clone(),
484            decoder: self.decoder.clone(),
485            validate: self.validate.clone(),
486            store: self.store,
487            extract: PhantomData,
488        }
489    }
490}
491
492impl<S, I, H, X> fmt::Debug for Jwt<S, I, H, X>
493where
494    S: fmt::Debug,
495{
496    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
497        f.debug_struct("Jwt")
498            .field("svc", &self.svc)
499            .field("decoder", &self.decoder)
500            .field("validate", &"..")
501            .field("store", &"..")
502            .field("extract", &any::type_name::<X>())
503            .finish()
504    }
505}
506
507impl<S, I, H, X> Service<Request> for Jwt<S, I, H, X>
508where
509    S: Service<Request> + Clone,
510    I: DeserializeOwned,
511    H: Validate<I>,
512    X: Extract,
513    Result<S::Response, S::Error>: IntoResponse,
514{
515    type Response = Response;
516    type Error = Infallible;
517    type Future = JwtFuture<S>;
518
519    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
520        Poll::Ready(Ok(()))
521    }
522
523    fn call(&mut self, req: Request) -> Self::Future {
524        let validate = |parts| -> Result<Token<I>, Error> {
525            let token = X::extract(parts).ok_or(Error::Extract)?;
526            let TokenData { header, claims }: TokenData<I> =
527                self.decoder.decode(token).map_err(Error::Jwt)?;
528
529            Ok(Token::new(header, claims))
530        };
531
532        let (mut parts, body) = req.into_parts();
533        match validate(&mut parts) {
534            Ok(token) => {
535                if let Some(res) = self.validate.validate(&token).output() {
536                    return JwtFuture::ready(res);
537                }
538
539                (self.store)(token, &mut parts.extensions);
540
541                let req = Request::from_parts(parts, body);
542                let clone = self.svc.clone();
543                let svc = mem::replace(&mut self.svc, clone);
544                JwtFuture::not_ready(svc, req)
545            }
546            Err(e) => JwtFuture::ready(e.into_response()),
547        }
548    }
549}
550
pin_project_lite::pin_project! {
    /// Middleware future.
    ///
    /// Resolves either to a response produced by the middleware itself
    /// (e.g. a rejection) or to the wrapped service's result converted with
    /// [`IntoResponse`]. Its error type is [`Infallible`].
    pub struct JwtFuture<S>
    where
        S: Service<Request>,
    {
        // Pinned because the `Called` variant holds the inner service's
        // future, which may be `!Unpin`.
        #[pin]
        state: State<S, S::Future>,
    }
}
561
562impl<S> JwtFuture<S>
563where
564    S: Service<Request>,
565{
566    fn not_ready(svc: S, req: Request) -> Self {
567        Self {
568            state: State::NotReady { svc, req },
569        }
570    }
571
572    fn ready(res: Response) -> Self {
573        Self {
574            state: State::Ready { res },
575        }
576    }
577}
578
impl<S> Future for JwtFuture<S>
where
    S: Service<Request>,
    Result<S::Response, S::Error>: IntoResponse,
{
    type Output = Result<Response, Infallible>;

    /// Drives the state machine.
    ///
    /// Accepted requests go `NotReady` -> `Called` -> `Done`. Responses
    /// built by the middleware go `Ready` -> `Done`. Every path ends in
    /// `Poll::Ready(Ok(response))`; errors from the wrapped service are
    /// converted into responses.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.project().state;
        let res = loop {
            match state.as_mut().project() {
                StateProj::NotReady { svc, req } => {
                    // Wait for the wrapped service; a readiness error is
                    // turned into a response instead of being propagated.
                    if let Err(e) = task::ready!(svc.poll_ready(cx)) {
                        state.set(State::Done);
                        break Err(e).into_response();
                    }

                    // The projection only yields `&mut Request`, so move the
                    // request out and leave a default value in its place.
                    let req = mem::take(req);
                    let fut = svc.call(req);
                    // Replacing the state drops `svc`. The loop then polls
                    // `fut` right away.
                    state.set(State::Called { fut });
                }
                StateProj::Called { fut } => {
                    let res = task::ready!(fut.poll(cx));
                    state.set(State::Done);
                    break res.into_response();
                }
                StateProj::Ready { res } => {
                    // Move the prepared response out before marking done.
                    let res = mem::take(res);
                    state.set(State::Done);
                    break res;
                }
                StateProj::Done => panic!("polled after completion"),
            }
        };

        Poll::Ready(Ok(res))
    }
}
617
pin_project_lite::pin_project! {
    // States of `JwtFuture`:
    // - `NotReady`: request accepted; waiting for `svc` to become ready.
    // - `Called`: polling the wrapped service's future `fut` (pinned).
    // - `Ready`: the middleware already produced `res` (e.g. a rejection).
    // - `Done`: the output has been returned; polling again panics.
    #[project = StateProj]
    enum State<S, F> {
        NotReady { svc: S, req: Request },
        Called {
            #[pin]
            fut: F,
        },
        Ready { res: Response },
        Done,
    }
}