// axum_security/jwt/session.rs

1use std::convert::Infallible;
2
3use axum::{
4    extract::{FromRequestParts, OptionalFromRequestParts},
5    http::{Extensions, StatusCode, request::Parts},
6};
7
/// Newtype wrapper around a value of type `T` that this module's extractor
/// implementations look up in the request extensions.
///
/// The wrapped value is public and reachable through `.0`.
// NOTE(review): presumably `T` is the decoded JWT claims/session payload placed in
// the extensions by an authentication layer elsewhere in the crate — confirm.
#[derive(Debug, Clone)]
pub struct Jwt<T>(pub T);
10
11impl<T: Send + Sync + 'static> Jwt<T> {
12    pub fn from_extensions(extensions: &mut Extensions) -> Option<Self> {
13        extensions.remove()
14    }
15}
16
17impl<S, T> FromRequestParts<S> for Jwt<T>
18where
19    S: Send + Sync,
20    T: Send + Sync + 'static,
21{
22    type Rejection = StatusCode;
23
24    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
25        if let Some(session) = <Jwt<T>>::from_extensions(&mut parts.extensions) {
26            Ok(session)
27        } else {
28            Err(StatusCode::UNAUTHORIZED)
29        }
30    }
31}
32
33impl<S, T> OptionalFromRequestParts<S> for Jwt<T>
34where
35    S: Send + Sync,
36    T: Send + Sync + 'static,
37{
38    type Rejection = Infallible;
39
40    async fn from_request_parts(
41        parts: &mut Parts,
42        _state: &S,
43    ) -> Result<Option<Self>, Self::Rejection> {
44        Ok(<Jwt<T>>::from_extensions(&mut parts.extensions))
45    }
46}
47
/// Tests for the `Jwt<T>` request-parts extractors.
#[cfg(test)]
mod extract_jwt {
    use axum::{
        extract::FromRequestParts,
        http::{Request, StatusCode},
    };

    use crate::jwt::Jwt;

    /// A `Jwt` stored as a request extension is handed back by the extractor.
    #[tokio::test]
    async fn extract() {
        let jwt = Jwt(1i32);

        let (mut parts, _) = Request::builder()
            .extension(jwt.clone())
            .body(())
            .unwrap()
            .into_parts();

        let extracted_jwt = Jwt::<i32>::from_request_parts(&mut parts, &())
            .await
            .unwrap();

        // `assert_eq!` reports both operands on failure, unlike `assert!(a == b)`.
        assert_eq!(jwt.0, extracted_jwt.0);
    }

    /// Without a stored `Jwt` the required extractor rejects with 401.
    #[tokio::test]
    async fn extract_rejection() {
        let (mut parts, _) = Request::builder().body(()).unwrap().into_parts();

        let rejection = Jwt::<i32>::from_request_parts(&mut parts, &())
            .await
            .unwrap_err();

        assert_eq!(rejection, StatusCode::UNAUTHORIZED);
    }

    /// The optional extractor yields `Some` when a `Jwt` is stored.
    #[tokio::test]
    async fn extract_optional() {
        // Imported inside the function so that the unqualified
        // `from_request_parts` calls in the other tests stay unambiguous.
        use axum::extract::OptionalFromRequestParts;

        let (mut parts, _) = Request::builder()
            .extension(Jwt(1i32))
            .body(())
            .unwrap()
            .into_parts();

        let extracted =
            <Jwt<i32> as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .unwrap();

        assert_eq!(extracted.map(|jwt| jwt.0), Some(1));
    }

    /// The optional extractor yields `None` instead of rejecting when no `Jwt` is stored.
    #[tokio::test]
    async fn extract_optional_missing() {
        // Imported inside the function so that the unqualified
        // `from_request_parts` calls in the other tests stay unambiguous.
        use axum::extract::OptionalFromRequestParts;

        let (mut parts, _) = Request::builder().body(()).unwrap().into_parts();

        let extracted =
            <Jwt<i32> as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .unwrap();

        assert!(extracted.is_none());
    }
}