axum_webtools/security/
jwt.rs

1use axum::{
2    extract::FromRequestParts,
3    http::{request::Parts, StatusCode},
4    response::{IntoResponse, Response},
5    Json, RequestPartsExt,
6};
7
8use chrono::TimeDelta;
9use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
10
11use axum_extra::{
12    headers::{authorization::Bearer, Authorization},
13    TypedHeader,
14};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::fmt::Display;
18
/// Something that can identify the subject of a JWT.
///
/// Implementors must be `Send + Sync` so they can be shared across threads.
pub trait JwtToken
where
    Self: Send + Sync,
{
    /// Returns the subject identifier.
    ///
    /// NOTE(review): no caller in this module uses this trait; presumably the
    /// value is meant for the token's `sub` claim — confirm with callers.
    fn subject(&self) -> String;
}
22
/// Reads a mandatory environment variable.
///
/// # Panics
///
/// Panics with `message` (via `expect`) when `name` is unset or not valid
/// Unicode.
fn required_env_var(name: &str, message: &str) -> String {
    std::env::var(name).expect(message)
}

/// Secret used to sign and verify tokens, read from `JWT_SECRET`.
///
/// # Panics
///
/// Panics if `JWT_SECRET` is unset or not valid Unicode.
fn get_jwt_secret() -> String {
    required_env_var("JWT_SECRET", "JWT_SECRET must be set")
}

/// Expected token issuer (`iss`), read from `JWT_ISSUER`.
///
/// # Panics
///
/// Panics if `JWT_ISSUER` is unset or not valid Unicode.
fn get_jwt_issuer() -> String {
    required_env_var("JWT_ISSUER", "JWT_ISSUER must be set")
}

/// Expected token audience (`aud`), read from `JWT_AUDIENCE`.
///
/// # Panics
///
/// Panics if `JWT_AUDIENCE` is unset or not valid Unicode.
fn get_jwt_audience() -> String {
    required_env_var("JWT_AUDIENCE", "JWT_AUDIENCE must be set")
}
34
35pub fn parse_jwt_token(token: impl Into<String>) -> Result<Claims, jsonwebtoken::errors::Error> {
36    let token = token.into();
37    let jwt_issuer = get_jwt_issuer();
38    let jwt_audience = get_jwt_audience();
39    let jwt_secret = get_jwt_secret();
40    let decode_key = DecodingKey::from_secret(jwt_secret.as_bytes());
41
42    let mut validation = Validation::new(Algorithm::HS256);
43    validation.set_audience(&[jwt_audience]);
44    validation.set_issuer(&[jwt_issuer]);
45    let token_data = decode::<Claims>(token.as_str(), &decode_key, &validation)?;
46    Ok(token_data.claims)
47}
48
49pub fn create_jwt_token(subject: impl Into<String>) -> String {
50    let now = chrono::Utc::now();
51    let expires_at = TimeDelta::try_days(7)
52        .map(|d| now + d)
53        .expect("Failed to calculate expiration date");
54    let issued_at = now.timestamp() as u64;
55    let exp = expires_at.timestamp() as u64;
56    let iss = get_jwt_issuer();
57    let aud = get_jwt_audience();
58    let sub = subject.into();
59
60    let claims = Claims {
61        iss,
62        sub,
63        issued_at,
64        exp,
65        aud,
66    };
67
68    let jwt_secret = get_jwt_secret();
69    let encode_key = EncodingKey::from_secret(jwt_secret.as_bytes());
70    encode(&Header::default(), &claims, &encode_key).unwrap()
71}
72
73#[derive(Debug, Serialize, Deserialize)]
74pub struct Claims {
75    pub sub: String,
76    pub aud: String,
77    pub iss: String,
78    pub issued_at: u64,
79    pub exp: u64,
80}
81
82impl Display for Claims {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        write!(f, "Email: {}", self.sub)
85    }
86}
87
88#[cfg(not(test))]
89impl<S> FromRequestParts<S> for Claims
90where
91    S: Send + Sync,
92{
93    type Rejection = AuthError;
94
95    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
96        let TypedHeader(Authorization(bearer)) = parts
97            .extract::<TypedHeader<Authorization<Bearer>>>()
98            .await
99            .map_err(|_| AuthError::InvalidToken)?;
100
101        let claims = parse_jwt_token(bearer.token()).map_err(|_| AuthError::InvalidToken)?;
102
103        Ok(claims)
104    }
105}
106
/// Test-only replacement for the JWT extractor.
///
/// Instead of verifying a bearer token, it trusts the `X-Claims-Subject`
/// header. It fabricates claims for that subject with fixed issuer/audience
/// values and a seven-day lifetime.
///
/// A missing header, or one that is not visible ASCII, is rejected with
/// [`AuthError::InvalidToken`] instead of panicking. This lets tests
/// exercise the unauthenticated path.
#[cfg(test)]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let sub = parts
            .headers
            .get("X-Claims-Subject")
            .and_then(|value| value.to_str().ok())
            .ok_or(AuthError::InvalidToken)?
            .to_string();

        // Take a single timestamp so `exp - issued_at` is exactly seven days.
        let now = chrono::Utc::now();
        Ok(Claims {
            sub,
            aud: "audience".to_string(),
            iss: "issuer".to_string(),
            issued_at: now.timestamp() as u64,
            exp: (now + chrono::Duration::days(7)).timestamp() as u64,
        })
    }
}
131
132impl IntoResponse for AuthError {
133    fn into_response(self) -> Response {
134        let (status, error_message) = match self {
135            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
136        };
137        let body = Json(json!({
138            "error": error_message,
139        }));
140        (status, body).into_response()
141    }
142}
143
/// Errors produced while authenticating a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The bearer token was missing, malformed, or failed validation.
    InvalidToken,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthError::InvalidToken => f.write_str("Invalid token"),
        }
    }
}

/// Lets `AuthError` be boxed as `dyn Error` and composed with `?`-based error
/// handling elsewhere.
impl std::error::Error for AuthError {}
148
#[cfg(test)]
mod tests {
    use super::*;
    use fake::faker::internet::en::FreeEmail;
    use fake::Fake;

    /// Points the JWT helpers at fixed test configuration.
    /// Every test writes the same values.
    fn setup() {
        for (key, value) in [
            ("JWT_SECRET", "secret"),
            ("JWT_ISSUER", "issuer"),
            ("JWT_AUDIENCE", "audience"),
        ] {
            std::env::set_var(key, value);
        }
    }

    /// Generates a random email address to use as a token subject.
    fn random_email() -> String {
        FreeEmail().fake()
    }

    #[test]
    fn test_create_token() {
        setup();
        let subject = random_email();
        let decoded = parse_jwt_token(create_jwt_token(subject.as_str())).unwrap();
        assert_eq!(subject, decoded.sub);
    }

    #[test]
    fn test_invalid_token() {
        setup();
        // Appending a character corrupts the signature segment.
        let tampered = format!("{}a", create_jwt_token(random_email()));
        assert!(parse_jwt_token(tampered).is_err());
    }
}