// avl_auth/jwt.rs
//! Advanced JWT implementation with key rotation and multi-algorithm support.

use crate::error::{AuthError, Result};
use crate::models::Claims;
use chrono::{Duration, Utc};
use jsonwebtoken::{
    decode, decode_header, encode, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey,
    Header, Validation,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;

14pub struct JwtManager {
15    config: JwtConfig,
16    keys: Arc<RwLock<KeyStore>>,
17}
18
19#[derive(Clone)]
20pub struct JwtConfig {
21    pub algorithm: Algorithm,
22    pub issuer: String,
23    pub audience: String,
24    pub access_token_ttl: Duration,
25    pub refresh_token_ttl: Duration,
26}
27
28struct KeyStore {
29    active_key: KeyPair,
30    retired_keys: Vec<KeyPair>,
31    key_id_counter: u64,
32}
33
34struct KeyPair {
35    id: String,
36    encoding: EncodingKey,
37    decoding: DecodingKey,
38    algorithm: Algorithm,
39    created_at: chrono::DateTime<Utc>,
40}
41
42impl JwtManager {
43    pub fn new(config: JwtConfig, private_key: &str, public_key: &str) -> Result<Self> {
44        let encoding = Self::create_encoding_key(&config.algorithm, private_key)?;
45        let decoding = Self::create_decoding_key(&config.algorithm, public_key)?;
46
47        let key_pair = KeyPair {
48            id: "key_1".to_string(),
49            encoding,
50            decoding,
51            algorithm: config.algorithm,
52            created_at: Utc::now(),
53        };
54
55        let keys = Arc::new(RwLock::new(KeyStore {
56            active_key: key_pair,
57            retired_keys: Vec::new(),
58            key_id_counter: 1,
59        }));
60
61        Ok(Self { config, keys })
62    }
63
64    fn create_encoding_key(algorithm: &Algorithm, private_key: &str) -> Result<EncodingKey> {
65        match algorithm {
66            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
67                EncodingKey::from_rsa_pem(private_key.as_bytes())
68                    .map_err(|e| AuthError::CryptoError(e.to_string()))
69            }
70            Algorithm::ES256 | Algorithm::ES384 => {
71                EncodingKey::from_ec_pem(private_key.as_bytes())
72                    .map_err(|e| AuthError::CryptoError(e.to_string()))
73            }
74            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
75                Ok(EncodingKey::from_secret(private_key.as_bytes()))
76            }
77            _ => Err(AuthError::ConfigError(format!("Unsupported algorithm: {:?}", algorithm))),
78        }
79    }
80
81    fn create_decoding_key(algorithm: &Algorithm, public_key: &str) -> Result<DecodingKey> {
82        match algorithm {
83            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
84                DecodingKey::from_rsa_pem(public_key.as_bytes())
85                    .map_err(|e| AuthError::CryptoError(e.to_string()))
86            }
87            Algorithm::ES256 | Algorithm::ES384 => {
88                DecodingKey::from_ec_pem(public_key.as_bytes())
89                    .map_err(|e| AuthError::CryptoError(e.to_string()))
90            }
91            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
92                Ok(DecodingKey::from_secret(public_key.as_bytes()))
93            }
94            _ => Err(AuthError::ConfigError(format!("Unsupported algorithm: {:?}", algorithm))),
95        }
96    }
97
98    pub async fn create_token(&self, claims: &Claims) -> Result<String> {
99        let keys = self.keys.read().await;
100
101        let mut header = Header::new(self.config.algorithm);
102        header.kid = Some(keys.active_key.id.clone());
103
104        encode(&header, claims, &keys.active_key.encoding)
105            .map_err(|e| AuthError::CryptoError(e.to_string()))
106    }
107
108    pub async fn verify_token(&self, token: &str) -> Result<Claims> {
109        let keys = self.keys.read().await;
110
111        let mut validation = Validation::new(self.config.algorithm);
112        validation.set_issuer(&[&self.config.issuer]);
113        validation.set_audience(&[&self.config.audience]);
114
115        // Try active key first
116        if let Ok(token_data) = decode::<Claims>(token, &keys.active_key.decoding, &validation) {
117            return Ok(token_data.claims);
118        }
119
120        // Try retired keys
121        for key in &keys.retired_keys {
122            if let Ok(token_data) = decode::<Claims>(token, &key.decoding, &validation) {
123                return Ok(token_data.claims);
124            }
125        }
126
127        Err(AuthError::InvalidToken("Unable to verify token signature".to_string()))
128    }
129
130    pub async fn rotate_keys(&self, new_private_key: &str, new_public_key: &str) -> Result<()> {
131        let encoding = Self::create_encoding_key(&self.config.algorithm, new_private_key)?;
132        let decoding = Self::create_decoding_key(&self.config.algorithm, new_public_key)?;
133
134        let mut keys = self.keys.write().await;
135        keys.key_id_counter += 1;
136
137        let new_key = KeyPair {
138            id: format!("key_{}", keys.key_id_counter),
139            encoding,
140            decoding,
141            algorithm: self.config.algorithm,
142            created_at: Utc::now(),
143        };
144
145        // Move current active key to retired
146        let old_key = std::mem::replace(&mut keys.active_key, new_key);
147        keys.retired_keys.push(old_key);
148
149        // Keep only last 3 retired keys
150        if keys.retired_keys.len() > 3 {
151            keys.retired_keys.remove(0);
152        }
153
154        tracing::info!("JWT keys rotated successfully");
155        Ok(())
156    }
157
158    pub async fn get_jwks(&self) -> Result<JwkSet> {
159        let keys = self.keys.read().await;
160
161        let mut jwks = Vec::new();
162
163        // Add active key
164        jwks.push(Jwk {
165            kid: keys.active_key.id.clone(),
166            kty: "RSA".to_string(), // Simplified, would need to determine from algorithm
167            alg: format!("{:?}", keys.active_key.algorithm),
168            use_: "sig".to_string(),
169        });
170
171        // Add retired keys
172        for key in &keys.retired_keys {
173            jwks.push(Jwk {
174                kid: key.id.clone(),
175                kty: "RSA".to_string(),
176                alg: format!("{:?}", key.algorithm),
177                use_: "sig".to_string(),
178            });
179        }
180
181        Ok(JwkSet { keys: jwks })
182    }
183
184    pub fn create_claims(
185        &self,
186        user_id: Uuid,
187        email: String,
188        roles: Vec<String>,
189        permissions: Vec<String>,
190        session_id: Uuid,
191        scopes: Vec<String>,
192        device_id: Option<String>,
193    ) -> Claims {
194        let now = Utc::now();
195        let exp = now + self.config.access_token_ttl;
196
197        Claims {
198            sub: user_id,
199            email,
200            roles,
201            permissions,
202            session_id,
203            iat: now.timestamp(),
204            exp: exp.timestamp(),
205            nbf: now.timestamp(),
206            iss: self.config.issuer.clone(),
207            aud: self.config.audience.clone(),
208            jti: Uuid::new_v4().to_string(),
209            scopes,
210            device_id,
211        }
212    }
213}
214
215#[derive(Debug, Serialize, Deserialize)]
216pub struct JwkSet {
217    pub keys: Vec<Jwk>,
218}
219
220#[derive(Debug, Serialize, Deserialize)]
221pub struct Jwk {
222    pub kid: String,
223    pub kty: String,
224    pub alg: String,
225    #[serde(rename = "use")]
226    pub use_: String,
227}