axum_webtools/security/
jwt.rs

1use axum::{
2    extract::FromRequestParts,
3    http::{request::Parts, StatusCode},
4    response::{IntoResponse, Response},
5    Json, RequestPartsExt,
6};
7
8use chrono::TimeDelta;
9use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
10
11use axum_extra::{
12    headers::{authorization::Bearer, Authorization},
13    TypedHeader,
14};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::fmt::Display;
18
/// Reads a mandatory environment variable.
///
/// # Panics
/// If `key` is unset or its value is not valid Unicode. The message has the form
/// `"<KEY> must be set: <VarError>"`.
fn required_env_var(key: &str) -> String {
    match std::env::var(key) {
        Ok(value) => value,
        Err(err) => panic!("{key} must be set: {err:?}"),
    }
}

/// HMAC secret used to sign and verify tokens, read from `JWT_SECRET` on every call.
///
/// # Panics
/// If `JWT_SECRET` is unset or not valid Unicode.
fn get_jwt_secret() -> String {
    required_env_var("JWT_SECRET")
}

/// Expected token issuer (`iss`), read from `JWT_ISSUER` on every call.
///
/// # Panics
/// If `JWT_ISSUER` is unset or not valid Unicode.
fn get_jwt_issuer() -> String {
    required_env_var("JWT_ISSUER")
}

/// Expected token audience (`aud`), read from `JWT_AUDIENCE` on every call.
///
/// # Panics
/// If `JWT_AUDIENCE` is unset or not valid Unicode.
fn get_jwt_audience() -> String {
    required_env_var("JWT_AUDIENCE")
}
30
31pub fn parse_jwt_token(token: impl Into<String>) -> Result<Claims, jsonwebtoken::errors::Error> {
32    let token = token.into();
33    let jwt_issuer = get_jwt_issuer();
34    let jwt_audience = get_jwt_audience();
35    let jwt_secret = get_jwt_secret();
36    let decode_key = DecodingKey::from_secret(jwt_secret.as_bytes());
37
38    let mut validation = Validation::new(Algorithm::HS256);
39    validation.set_audience(&[jwt_audience]);
40    validation.set_issuer(&[jwt_issuer]);
41    let token_data = decode::<Claims>(token.as_str(), &decode_key, &validation)?;
42    Ok(token_data.claims)
43}
44
45pub fn create_jwt_token(subject: impl Into<String>) -> String {
46    let now = chrono::Utc::now();
47    let expires_at = TimeDelta::try_days(7)
48        .map(|d| now + d)
49        .expect("Failed to calculate expiration date");
50    let issued_at = now.timestamp() as u64;
51    let exp = expires_at.timestamp() as u64;
52    let iss = get_jwt_issuer();
53    let aud = get_jwt_audience();
54    let sub = subject.into();
55
56    let claims = Claims {
57        iss,
58        sub,
59        issued_at,
60        exp,
61        aud,
62    };
63
64    let jwt_secret = get_jwt_secret();
65    let encode_key = EncodingKey::from_secret(jwt_secret.as_bytes());
66    encode(&Header::default(), &claims, &encode_key).unwrap()
67}
68
69#[derive(Debug, Serialize, Deserialize)]
70pub struct Claims {
71    pub sub: String,
72    pub aud: String,
73    pub iss: String,
74    pub issued_at: u64,
75    pub exp: u64,
76}
77
78impl Display for Claims {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        write!(f, "Email: {}", self.sub)
81    }
82}
83
84impl<S> FromRequestParts<S> for Claims
85where
86    S: Send + Sync,
87{
88    type Rejection = AuthError;
89
90    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
91        let TypedHeader(Authorization(bearer)) = parts
92            .extract::<TypedHeader<Authorization<Bearer>>>()
93            .await
94            .map_err(|_| AuthError::InvalidToken)?;
95
96        let claims = parse_jwt_token(bearer.token()).map_err(|_| AuthError::InvalidToken)?;
97
98        Ok(claims)
99    }
100}
101
102impl IntoResponse for AuthError {
103    fn into_response(self) -> Response {
104        let (status, error_message) = match self {
105            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
106        };
107        let body = Json(json!({
108            "error": error_message,
109        }));
110        (status, body).into_response()
111    }
112}
113
/// Reasons a request can be rejected by the [`Claims`] extractor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// Either the `Authorization: Bearer` header was missing or malformed, or the token it
    /// carried failed validation.
    InvalidToken,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthError::InvalidToken => f.write_str("Invalid token"),
        }
    }
}

/// Lets `AuthError` be boxed as `dyn Error` and propagated with `?` into generic error types.
impl std::error::Error for AuthError {}
118
#[cfg(test)]
mod tests {
    use super::*;
    use fake::faker::internet::en::FreeEmail;
    use fake::Fake;
    use jsonwebtoken::errors::ErrorKind;
    use std::sync::Once;

    /// Signing secret configured by `setup()`.
    const SECRET: &str = "secret";
    /// Issuer configured by `setup()`.
    const ISSUER: &str = "issuer";
    /// Audience configured by `setup()`.
    const AUDIENCE: &str = "audience";

    /// Configures the JWT environment once per test binary.
    ///
    /// Tests run on parallel threads. Routing the writes through `Once` mutates the
    /// process environment a single time, instead of every test rewriting it while
    /// other tests may be reading it.
    fn setup() {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            std::env::set_var("JWT_SECRET", SECRET);
            std::env::set_var("JWT_ISSUER", ISSUER);
            std::env::set_var("JWT_AUDIENCE", AUDIENCE);
        });
    }

    /// Claims that `parse_jwt_token` should accept under the `setup()` environment.
    /// They expire one hour from now.
    fn valid_claims(sub: &str) -> Claims {
        let now = chrono::Utc::now().timestamp() as u64;
        Claims {
            sub: sub.to_owned(),
            aud: AUDIENCE.to_owned(),
            iss: ISSUER.to_owned(),
            issued_at: now,
            exp: now + 3600,
        }
    }

    /// Signs `claims` with HS256 using `secret`, bypassing `create_jwt_token`.
    fn sign(claims: &Claims, secret: &str) -> String {
        jsonwebtoken::encode(
            &jsonwebtoken::Header::default(),
            claims,
            &jsonwebtoken::EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    #[test]
    fn test_create_token() {
        setup();
        let email: String = FreeEmail().fake();
        let token = create_jwt_token(email.clone());
        let claims = parse_jwt_token(token).unwrap();
        assert_eq!(email, claims.sub);
        assert_eq!(claims.iss, ISSUER);
        assert_eq!(claims.aud, AUDIENCE);
        assert!(claims.exp > claims.issued_at);
    }

    #[test]
    fn test_invalid_token() {
        setup();
        let email: String = FreeEmail().fake();
        let mut token = create_jwt_token(email);
        token.push_str("a");
        let claims = parse_jwt_token(token);
        assert!(claims.is_err());
    }

    /// Positive control: shows that the rejection tests below fail only because of
    /// the single field each one changes.
    #[test]
    fn test_hand_signed_valid_token_is_accepted() {
        setup();
        let token = sign(&valid_claims("user@example.com"), SECRET);
        let claims = parse_jwt_token(token).unwrap();
        assert_eq!(claims.sub, "user@example.com");
    }

    #[test]
    fn test_wrong_secret_is_rejected() {
        setup();
        let token = sign(&valid_claims("user@example.com"), "not-the-secret");
        let err = parse_jwt_token(token).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidSignature));
    }

    #[test]
    fn test_expired_token_is_rejected() {
        setup();
        let mut claims = valid_claims("user@example.com");
        // One hour in the past: well beyond jsonwebtoken's default 60 s leeway.
        claims.exp = claims.issued_at - 3600;
        let err = parse_jwt_token(sign(&claims, SECRET)).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::ExpiredSignature));
    }

    #[test]
    fn test_wrong_audience_is_rejected() {
        setup();
        let mut claims = valid_claims("user@example.com");
        claims.aud = "someone-else".to_owned();
        let err = parse_jwt_token(sign(&claims, SECRET)).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidAudience));
    }

    #[test]
    fn test_wrong_issuer_is_rejected() {
        setup();
        let mut claims = valid_claims("user@example.com");
        claims.iss = "someone-else".to_owned();
        let err = parse_jwt_token(sign(&claims, SECRET)).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::InvalidIssuer));
    }
}