axum_jwt/layer.rs
1//! Middleware types and traits.
2//!
3//! If you just need to access token data in a handler, use
4//! the [`Token`] and [`Claims`](crate::Claims) extractors directly.
5//!
6//! The [`layer`] function creates a configurable axum [middleware] layer.
7//! When a request is made to a handler wrapped in this layer, the
8//! JSON Web Token is validated. If validation succeeds, the handler is called.
9//! If validation fails, a `401 Unauthorized` status code is returned, though
10//! more fine-grained [configuration] is possible.
11//!
12//! [middleware]: https://docs.rs/axum/latest/axum/middleware/index.html
13//! [configuration]: #configuration
14//!
15//! # Examples
16//!
17//! ```
18//! use {
19//! axum::{Router, routing},
20//! axum_jwt::{Decoder, jsonwebtoken::DecodingKey},
21//! };
22//!
23//! // This handler will be called only if the token is successfully validated.
24//! async fn hello() -> String {
//!     "Hello, Anonymous!".to_owned()
26//! }
27//!
28//! # async fn f() -> std::io::Result<()> {
29//! let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
30//!
31//! let app = Router::new()
32//! .route("/", routing::get(hello))
33//! .layer(axum_jwt::layer(decoder));
34//!
35//! # use tokio::net::TcpListener;
36//! let listener = TcpListener::bind("0.0.0.0:3000").await?;
37//! axum::serve(listener, app).await?;
38//! # Ok(())
39//! # }
40//! ```
41//!
42//! # Configuration
43//!
44//! The [`layer`] function accepts a [decoder](Decoder) that defines how to
45//! decode and validate the token.
46//!
47//! Additionally, the layer itself can be
48//! configured: set a [filter](JwtLayer::with_filter) to define the token's
49//! data type and perform extra checks, store the token in
50//! [extensions](JwtLayer::store_to_extension) so it can later be retrieved in
51//! the handler via an extractor, or specify a custom
52//! method of [extracting](JwtLayer::with_extract) the token from the request.
53
54use {
55 crate::{
56 decode::Decoder,
57 error::Error,
58 extract::{Bearer, Extract, Token},
59 },
60 axum_core::{
61 extract::Request,
62 response::{IntoResponse, Response},
63 },
64 http::{Extensions, StatusCode},
65 jsonwebtoken::TokenData,
66 serde::de::{DeserializeOwned, IgnoredAny},
67 std::{
68 any,
69 convert::Infallible,
70 fmt,
71 marker::PhantomData,
72 mem,
73 pin::Pin,
74 task::{self, Context, Poll},
75 },
76 tower_layer::Layer,
77 tower_service::Service,
78};
79
80/// Creates a [layer](JwtLayer) for middleware.
81///
82/// # Examples
83///
84/// ```
85/// use {
86/// axum::{Router, routing},
87/// axum_jwt::{Decoder, jsonwebtoken::DecodingKey},
88/// };
89///
90/// // This handler will be called only if the token is successfully validated.
91/// async fn hello() -> String {
92/// "Hello, Anonimus!".to_owned()
93/// }
94///
95/// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
96///
97/// let app = Router::new()
98/// .route("/", routing::get(hello))
99/// .layer(axum_jwt::layer(decoder));
100/// # let _: Router = app;
101/// ```
102pub fn layer(decoder: Decoder) -> JwtLayer {
103 JwtLayer {
104 decoder,
105 validate: Discard,
106 store: |_, _| {},
107 extract: PhantomData,
108 }
109}
110
111/// Layer type for creating middleware.
112///
113/// To configure the layer and create the middleware service, call
114/// the [`layer`] function.
115pub struct JwtLayer<I = IgnoredAny, H = Discard, X = Bearer> {
116 decoder: Decoder,
117 validate: H,
118 store: fn(Token<I>, &mut Extensions),
119 extract: PhantomData<fn() -> X>,
120}
121
122impl<I, X> JwtLayer<I, Discard, X> {
123 /// Sets a filter for additional validation.
124 ///
125 /// By default, the layer only validates the token header, ignoring all
126 /// its claims. This method allows you to specify an arbitrary data type
127 /// for the claims and perform additional token checks.
128 ///
129 /// The claims type must implement [`Deserialize`](serde::Deserialize).
130 ///
131 /// # Examples
132 ///
133 /// ```
134 /// use {
135 /// axum::{Router, routing},
136 /// axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
137 /// serde::Deserialize,
138 /// };
139 ///
140 /// #[derive(Deserialize)]
141 /// struct User {
142 /// roles: Vec<String>,
143 /// }
144 ///
145 /// // Checks that the user's token contains the admin role.
146 /// fn check_access(t: &Token<User>) -> bool {
147 /// t.claims.roles.iter().any(|role| role == "admin")
148 /// }
149 ///
150 /// // Called only if the role check is successful.
151 /// async fn hello() -> String {
152 /// "Hello, Admin!".to_owned()
153 /// }
154 ///
155 /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
156 ///
157 /// let app = Router::new()
158 /// .route("/", routing::get(hello))
159 /// .layer(axum_jwt::layer(decoder).with_filter(check_access));
160 /// # let _: Router = app;
161 /// ```
162 ///
163 /// If you also need to use the token inside the handler,
164 /// see the [`store_to_extension`](JwtLayer::store_to_extension) method.
165 ///
166 /// # Callback return value
167 ///
168 /// The return value of the provided callback can be:
169 ///
170 /// * `bool`: where `true` means validation succeeded, and `false` means
171 /// it failed returning an HTTP status code `401 Unauthorized`.
172 /// * `Result<(), E>`: where `Ok(())` means validation succeeded
173 /// and `Err(e)` means it failed. The error type must implement
174 /// [`IntoResponse`], which will be called on failure to return the
175 /// corresponding response.
176 pub fn with_filter<H, N, O>(self, validate: H) -> JwtLayer<N, H, X>
177 where
178 H: FnMut(&Token<N>) -> O,
179 N: DeserializeOwned,
180 O: Output,
181 {
182 JwtLayer {
183 decoder: self.decoder,
184 validate,
185 store: |_, _| {},
186 extract: PhantomData,
187 }
188 }
189}
190
191impl<I, H, X> JwtLayer<I, H, X> {
192 /// Configures the layer to store the token in the [extension].
193 ///
194 /// [extension]: https://docs.rs/axum/latest/axum/struct.Extension.html
195 ///
196 /// If you just need to access token data in a handler, use
197 /// the [`Token`] and [`Claims`](crate::Claims) extractors directly.
198 /// This function is useful only if you want to use the middleware
199 /// but still access token data in some handlers.
200 ///
201 /// After calling this method, the middleware will store the parsed token
202 /// in the extension, which can later be retrieved, for example,
203 /// in a handler.
204 ///
205 /// The token is stored only after *successful validation*, including
206 /// the configured [filter](JwtLayer::with_filter). This is usually what
207 /// you want, but if you reuse the handler elsewhere, keep in mind that
208 /// extracting it from [`Extension`] may fail.
209 ///
210 /// [`Extension`]: https://docs.rs/axum/latest/axum/struct.Extension.html
211 ///
212 /// # Examples
213 ///
214 /// ```
215 /// use {
216 /// axum::{Extension, Router, routing},
217 /// axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
218 /// serde::Deserialize,
219 /// };
220 ///
221 /// // To store a value in the extension, it must implement `Clone`.
222 /// #[derive(Clone, Deserialize)]
223 /// struct User {
224 /// sub: String,
225 /// roles: Vec<String>,
226 /// }
227 ///
228 /// // Checks that the user's token contains the admin role.
229 /// fn check_access(t: &Token<User>) -> bool {
230 /// t.claims.roles.iter().any(|role| role == "admin")
231 /// }
232 ///
233 /// // Called only if the role check is successful.
234 /// async fn hello(Extension(t): Extension<Token<User>>) -> String {
235 /// // We can also access the parsed token
236 /// format!("Hello, {}!", t.claims.sub)
237 /// }
238 ///
239 /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
240 ///
241 /// let app = Router::new()
242 /// .route("/", routing::get(hello))
243 /// .layer(
244 /// axum_jwt::layer(decoder)
245 /// .with_filter(check_access)
246 /// .store_to_extension(),
247 /// );
248 /// # let _: Router = app;
249 /// ```
250 ///
251 /// <section class="warning">
252 ///
253 /// Note that the `store_to_extension` call comes after setting
254 /// the `with_filter` filter.
255 ///
256 /// This is important because
257 /// `store_to_extension` applies to the current token type.
258 /// The `with_filter` call may change the type and therefore it always
259 /// resets the `store_to_extension` configuration.
260 ///
261 /// ```
262 /// # use {
263 /// # axum::{Router, routing},
264 /// # axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
265 /// # };
266 /// # fn check_access(t: &Token) -> bool { true }
267 /// # async fn hello() {}
268 /// # let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
269 /// let app = Router::new()
270 /// .route("/", routing::get(hello))
271 /// .layer(
272 /// axum_jwt::layer(decoder)
273 /// // Incorrect order!
274 /// .store_to_extension()
275 /// .with_filter(check_access),
276 /// );
277 /// # let _: Router = app;
278 /// ```
279 ///
280 /// </section>
281 ///
282 /// # Read header only
283 ///
284 /// If you just need to retrieve the token without any additional payload,
285 /// omit the `with_filter` call and use the `Token` type without
286 /// a parameter:
287 ///
288 /// ```
289 /// use {
290 /// axum::{Extension, Router, routing},
291 /// axum_jwt::{Decoder, Token, jsonwebtoken::DecodingKey},
292 /// };
293 ///
294 /// async fn hello(Extension(t): Extension<Token>) -> String {
295 /// // Access the parsed token
296 /// format!("Decoded with {:?} algorithm", t.header.alg)
297 /// }
298 ///
299 /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
300 ///
301 /// let app = Router::new()
302 /// .route("/", routing::get(hello))
303 /// .layer(axum_jwt::layer(decoder).store_to_extension());
304 /// # let _: Router = app;
305 /// ```
306 pub fn store_to_extension(mut self) -> Self
307 where
308 I: Clone + Send + Sync + 'static,
309 {
310 self.store = |claims, extensions| {
311 extensions.insert(claims);
312 };
313
314 self
315 }
316}
317
318impl<I, H> JwtLayer<I, H> {
319 /// Applies a token extractor to the layer.
320 ///
321 /// By default, the token is extracted from the `Authorization` header using
322 /// the `Bearer` scheme. If you want to change this behavior, create a new type
323 /// and implement [`Extract`] for it. Then, you can pass this type into the
324 /// layer configuration:
325 ///
326 /// ```
327 /// use {
328 /// axum::{Extension, Router, http::request::Parts, routing},
329 /// axum_jwt::{Decoder, Extract, Token, jsonwebtoken::DecodingKey},
330 /// };
331 ///
332 /// struct Custom;
333 ///
334 /// impl Extract for Custom {
335 /// fn extract(parts: &mut Parts) -> Option<&str> {
336 /// parts.headers.get("X-Auth-Token")?.to_str().ok()
337 /// }
338 /// }
339 ///
340 /// async fn hello(Extension(t): Extension<Token>) -> String {
341 /// format!("Decoded with {:?} algorithm", t.header.alg)
342 /// }
343 ///
344 /// let decoder = Decoder::from_key(DecodingKey::from_secret(b"secret"));
345 ///
346 /// let app = Router::new()
347 /// .route("/", routing::get(hello))
348 /// .layer(axum_jwt::layer(decoder).with_extract(Custom));
349 /// # let _: Router = app;
350 /// ```
351 pub fn with_extract<X>(self, extract: X) -> JwtLayer<I, H, X>
352 where
353 X: Extract,
354 {
355 _ = extract;
356 JwtLayer {
357 decoder: self.decoder,
358 validate: self.validate,
359 store: self.store,
360 extract: PhantomData,
361 }
362 }
363}
364
365impl<I, H, X> Clone for JwtLayer<I, H, X>
366where
367 H: Clone,
368{
369 fn clone(&self) -> Self {
370 Self {
371 decoder: self.decoder.clone(),
372 validate: self.validate.clone(),
373 store: self.store,
374 extract: PhantomData,
375 }
376 }
377}
378
379impl<I, H, X> fmt::Debug for JwtLayer<I, H, X> {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 f.debug_struct("JwtLayer")
382 .field("decoder", &self.decoder)
383 .field("validate", &"..")
384 .field("store", &"..")
385 .field("extract", &any::type_name::<H>())
386 .finish()
387 }
388}
389
390impl<S, I, H, X> Layer<S> for JwtLayer<I, H, X>
391where
392 H: Clone,
393{
394 type Service = Jwt<S, I, H, X>;
395
396 fn layer(&self, svc: S) -> Self::Service {
397 Jwt {
398 svc,
399 decoder: self.decoder.clone(),
400 validate: self.validate.clone(),
401 store: self.store,
402 extract: PhantomData,
403 }
404 }
405}
406
407/// Trait for additional token validation.
408pub trait Validate<I> {
409 type Output: Output;
410 fn validate(&mut self, input: &Token<I>) -> Self::Output;
411}
412
413/// The output value of the [validation](Validate).
414pub trait Output {
415 fn output(self) -> Option<Response>;
416}
417
418impl<E> Output for Result<(), E>
419where
420 E: IntoResponse,
421{
422 fn output(self) -> Option<Response> {
423 self.err().map(E::into_response)
424 }
425}
426
427impl Output for bool {
428 fn output(self) -> Option<Response> {
429 if self {
430 None
431 } else {
432 Some(StatusCode::UNAUTHORIZED.into_response())
433 }
434 }
435}
436
437/// Discards any token data and returns success.
438#[derive(Clone)]
439pub struct Discard;
440
441impl<I> Validate<I> for Discard {
442 type Output = bool;
443
444 fn validate(&mut self, _: &Token<I>) -> Self::Output {
445 true
446 }
447}
448
449impl<F, I, O> Validate<I> for F
450where
451 F: FnMut(&Token<I>) -> O,
452 I: DeserializeOwned,
453 O: Output,
454{
455 type Output = O;
456
457 fn validate(&mut self, input: &Token<I>) -> Self::Output {
458 self(input)
459 }
460}
461
462/// Axum [middleware] for token validation.
463///
464/// [middleware]: https://docs.rs/axum/latest/axum/middleware/index.html
465///
466/// To configure the layer and create the middleware service, call
467/// the [`layer`] function.
468pub struct Jwt<S, I, H = Discard, X = Bearer> {
469 svc: S,
470 decoder: Decoder,
471 validate: H,
472 store: fn(Token<I>, &mut Extensions),
473 extract: PhantomData<fn() -> X>,
474}
475
476impl<S, I, H, X> Clone for Jwt<S, I, H, X>
477where
478 S: Clone,
479 H: Clone,
480{
481 fn clone(&self) -> Self {
482 Self {
483 svc: self.svc.clone(),
484 decoder: self.decoder.clone(),
485 validate: self.validate.clone(),
486 store: self.store,
487 extract: PhantomData,
488 }
489 }
490}
491
492impl<S, I, H, X> fmt::Debug for Jwt<S, I, H, X>
493where
494 S: fmt::Debug,
495{
496 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
497 f.debug_struct("Jwt")
498 .field("svc", &self.svc)
499 .field("decoder", &self.decoder)
500 .field("validate", &"..")
501 .field("store", &"..")
502 .field("extract", &any::type_name::<X>())
503 .finish()
504 }
505}
506
507impl<S, I, H, X> Service<Request> for Jwt<S, I, H, X>
508where
509 S: Service<Request> + Clone,
510 I: DeserializeOwned,
511 H: Validate<I>,
512 X: Extract,
513 Result<S::Response, S::Error>: IntoResponse,
514{
515 type Response = Response;
516 type Error = Infallible;
517 type Future = JwtFuture<S>;
518
519 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
520 Poll::Ready(Ok(()))
521 }
522
523 fn call(&mut self, req: Request) -> Self::Future {
524 let validate = |parts| -> Result<Token<I>, Error> {
525 let token = X::extract(parts).ok_or(Error::Extract)?;
526 let TokenData { header, claims }: TokenData<I> =
527 self.decoder.decode(token).map_err(Error::Jwt)?;
528
529 Ok(Token::new(header, claims))
530 };
531
532 let (mut parts, body) = req.into_parts();
533 match validate(&mut parts) {
534 Ok(token) => {
535 if let Some(res) = self.validate.validate(&token).output() {
536 return JwtFuture::ready(res);
537 }
538
539 (self.store)(token, &mut parts.extensions);
540
541 let req = Request::from_parts(parts, body);
542 let clone = self.svc.clone();
543 let svc = mem::replace(&mut self.svc, clone);
544 JwtFuture::not_ready(svc, req)
545 }
546 Err(e) => JwtFuture::ready(e.into_response()),
547 }
548 }
549}
550
551pin_project_lite::pin_project! {
552 /// Middleware future.
553 pub struct JwtFuture<S>
554 where
555 S: Service<Request>,
556 {
557 #[pin]
558 state: State<S, S::Future>,
559 }
560}
561
562impl<S> JwtFuture<S>
563where
564 S: Service<Request>,
565{
566 fn not_ready(svc: S, req: Request) -> Self {
567 Self {
568 state: State::NotReady { svc, req },
569 }
570 }
571
572 fn ready(res: Response) -> Self {
573 Self {
574 state: State::Ready { res },
575 }
576 }
577}
578
579impl<S> Future for JwtFuture<S>
580where
581 S: Service<Request>,
582 Result<S::Response, S::Error>: IntoResponse,
583{
584 type Output = Result<Response, Infallible>;
585
586 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
587 let mut state = self.project().state;
588 let res = loop {
589 match state.as_mut().project() {
590 StateProj::NotReady { svc, req } => {
591 if let Err(e) = task::ready!(svc.poll_ready(cx)) {
592 state.set(State::Done);
593 break Err(e).into_response();
594 }
595
596 let req = mem::take(req);
597 let fut = svc.call(req);
598 state.set(State::Called { fut });
599 }
600 StateProj::Called { fut } => {
601 let res = task::ready!(fut.poll(cx));
602 state.set(State::Done);
603 break res.into_response();
604 }
605 StateProj::Ready { res } => {
606 let res = mem::take(res);
607 state.set(State::Done);
608 break res;
609 }
610 StateProj::Done => panic!("polled after completion"),
611 }
612 };
613
614 Poll::Ready(Ok(res))
615 }
616}
617
618pin_project_lite::pin_project! {
619 #[project = StateProj]
620 enum State<S, F> {
621 NotReady { svc: S, req: Request },
622 Called {
623 #[pin]
624 fut: F,
625 },
626 Ready { res: Response },
627 Done,
628 }
629}