postrust_auth/
jwt.rs

1//! JWT token validation.
2
3use super::{AuthResult, JwtConfig, JwtError};
4use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// Validate a JWT token and extract claims.
9pub fn validate_token(token: &str, config: &JwtConfig) -> Result<AuthResult, JwtError> {
10    let secret = config.secret.as_ref().ok_or_else(|| {
11        JwtError::InvalidToken("No JWT secret configured".into())
12    })?;
13
14    // Decode secret
15    let key_bytes = if config.secret_is_base64 {
16        base64_decode(secret)?
17    } else {
18        secret.as_bytes().to_vec()
19    };
20
21    let key = DecodingKey::from_secret(&key_bytes);
22
23    // Set up validation
24    let mut validation = Validation::new(Algorithm::HS256);
25    validation.validate_exp = true;
26    validation.validate_nbf = true;
27
28    if let Some(aud) = &config.audience {
29        validation.set_audience(&[aud]);
30    } else {
31        validation.validate_aud = false;
32    }
33
34    // Decode and validate
35    let token_data = decode::<Claims>(token, &key, &validation)
36        .map_err(|e| map_jwt_error(e))?;
37
38    let claims = token_data.claims;
39
40    // Extract role
41    let role = claims
42        .extra
43        .get(&config.role_claim_key)
44        .and_then(|v| v.as_str())
45        .map(|s| s.to_string())
46        .or_else(|| config.anon_role.clone())
47        .ok_or(JwtError::MissingRole)?;
48
49    // Build claims map
50    let mut claims_map = claims.extra;
51    if let Some(sub) = claims.sub {
52        claims_map.insert("sub".into(), serde_json::Value::String(sub));
53    }
54    if let Some(iss) = claims.iss {
55        claims_map.insert("iss".into(), serde_json::Value::String(iss));
56    }
57    if let Some(exp) = claims.exp {
58        claims_map.insert("exp".into(), serde_json::Value::Number(exp.into()));
59    }
60
61    Ok(AuthResult {
62        role,
63        claims: claims_map,
64    })
65}
66
67/// Standard and custom JWT claims.
68#[derive(Debug, Serialize, Deserialize)]
69pub struct Claims {
70    /// Subject
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub sub: Option<String>,
73    /// Issuer
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub iss: Option<String>,
76    /// Expiration time
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub exp: Option<i64>,
79    /// Not before
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub nbf: Option<i64>,
82    /// Issued at
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub iat: Option<i64>,
85    /// Audience
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub aud: Option<String>,
88    /// Custom claims
89    #[serde(flatten)]
90    pub extra: HashMap<String, serde_json::Value>,
91}
92
93/// Decode base64 secret.
94fn base64_decode(s: &str) -> Result<Vec<u8>, JwtError> {
95    use base64::{engine::general_purpose::STANDARD, Engine};
96    STANDARD
97        .decode(s)
98        .map_err(|e| JwtError::InvalidToken(format!("Invalid base64 secret: {}", e)))
99}
100
101/// Map jsonwebtoken error to JwtError.
102fn map_jwt_error(e: jsonwebtoken::errors::Error) -> JwtError {
103    use jsonwebtoken::errors::ErrorKind;
104
105    match e.kind() {
106        ErrorKind::ExpiredSignature => JwtError::Expired,
107        ErrorKind::ImmatureSignature => JwtError::NotYetValid,
108        ErrorKind::InvalidSignature => JwtError::InvalidSignature,
109        ErrorKind::InvalidAudience => JwtError::InvalidAudience,
110        _ => JwtError::InvalidToken(e.to_string()),
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};

    /// Shared HMAC secret used to both sign and verify test tokens.
    const SECRET: &str = "test_secret_key_at_least_32_bytes!";

    /// Sign `claims` with `secret` using the default JWT header.
    fn make_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// Build claims whose only populated field is `exp`, set `offset_secs` from now.
    fn claims_expiring_in(offset_secs: i64) -> Claims {
        Claims {
            sub: None,
            iss: None,
            exp: Some(chrono::Utc::now().timestamp() + offset_secs),
            nbf: None,
            iat: None,
            aud: None,
            extra: HashMap::new(),
        }
    }

    /// A default `JwtConfig` with only `secret` overridden.
    fn config_with_secret(secret: &str) -> JwtConfig {
        JwtConfig {
            secret: Some(secret.into()),
            ..Default::default()
        }
    }

    #[test]
    fn test_validate_valid_token() {
        let mut claims = claims_expiring_in(3600);
        claims.sub = Some("user123".into());
        claims
            .extra
            .insert("role".into(), serde_json::Value::String("web_user".into()));

        let token = make_token(&claims, SECRET);
        let result = validate_token(&token, &config_with_secret(SECRET)).unwrap();

        assert_eq!(result.role, "web_user");
        assert_eq!(result.get_claim("sub").unwrap(), "user123");
    }

    #[test]
    fn test_validate_expired_token() {
        // `exp` lies an hour in the past.
        let token = make_token(&claims_expiring_in(-3600), SECRET);

        let result = validate_token(&token, &config_with_secret(SECRET));
        assert!(matches!(result, Err(JwtError::Expired)));
    }
}