1use crate::error::{AuthError, Result};
4use crate::models::Claims;
5use chrono::{Duration, Utc};
6use jsonwebtoken::{
7 decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation,
8};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct JwtManager {
15 config: JwtConfig,
16 keys: Arc<RwLock<KeyStore>>,
17}
18
19#[derive(Clone)]
20pub struct JwtConfig {
21 pub algorithm: Algorithm,
22 pub issuer: String,
23 pub audience: String,
24 pub access_token_ttl: Duration,
25 pub refresh_token_ttl: Duration,
26}
27
28struct KeyStore {
29 active_key: KeyPair,
30 retired_keys: Vec<KeyPair>,
31 key_id_counter: u64,
32}
33
34struct KeyPair {
35 id: String,
36 encoding: EncodingKey,
37 decoding: DecodingKey,
38 algorithm: Algorithm,
39 created_at: chrono::DateTime<Utc>,
40}
41
42impl JwtManager {
43 pub fn new(config: JwtConfig, private_key: &str, public_key: &str) -> Result<Self> {
44 let encoding = Self::create_encoding_key(&config.algorithm, private_key)?;
45 let decoding = Self::create_decoding_key(&config.algorithm, public_key)?;
46
47 let key_pair = KeyPair {
48 id: "key_1".to_string(),
49 encoding,
50 decoding,
51 algorithm: config.algorithm,
52 created_at: Utc::now(),
53 };
54
55 let keys = Arc::new(RwLock::new(KeyStore {
56 active_key: key_pair,
57 retired_keys: Vec::new(),
58 key_id_counter: 1,
59 }));
60
61 Ok(Self { config, keys })
62 }
63
64 fn create_encoding_key(algorithm: &Algorithm, private_key: &str) -> Result<EncodingKey> {
65 match algorithm {
66 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
67 EncodingKey::from_rsa_pem(private_key.as_bytes())
68 .map_err(|e| AuthError::CryptoError(e.to_string()))
69 }
70 Algorithm::ES256 | Algorithm::ES384 => {
71 EncodingKey::from_ec_pem(private_key.as_bytes())
72 .map_err(|e| AuthError::CryptoError(e.to_string()))
73 }
74 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
75 Ok(EncodingKey::from_secret(private_key.as_bytes()))
76 }
77 _ => Err(AuthError::ConfigError(format!("Unsupported algorithm: {:?}", algorithm))),
78 }
79 }
80
81 fn create_decoding_key(algorithm: &Algorithm, public_key: &str) -> Result<DecodingKey> {
82 match algorithm {
83 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
84 DecodingKey::from_rsa_pem(public_key.as_bytes())
85 .map_err(|e| AuthError::CryptoError(e.to_string()))
86 }
87 Algorithm::ES256 | Algorithm::ES384 => {
88 DecodingKey::from_ec_pem(public_key.as_bytes())
89 .map_err(|e| AuthError::CryptoError(e.to_string()))
90 }
91 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
92 Ok(DecodingKey::from_secret(public_key.as_bytes()))
93 }
94 _ => Err(AuthError::ConfigError(format!("Unsupported algorithm: {:?}", algorithm))),
95 }
96 }
97
98 pub async fn create_token(&self, claims: &Claims) -> Result<String> {
99 let keys = self.keys.read().await;
100
101 let mut header = Header::new(self.config.algorithm);
102 header.kid = Some(keys.active_key.id.clone());
103
104 encode(&header, claims, &keys.active_key.encoding)
105 .map_err(|e| AuthError::CryptoError(e.to_string()))
106 }
107
108 pub async fn verify_token(&self, token: &str) -> Result<Claims> {
109 let keys = self.keys.read().await;
110
111 let mut validation = Validation::new(self.config.algorithm);
112 validation.set_issuer(&[&self.config.issuer]);
113 validation.set_audience(&[&self.config.audience]);
114
115 if let Ok(token_data) = decode::<Claims>(token, &keys.active_key.decoding, &validation) {
117 return Ok(token_data.claims);
118 }
119
120 for key in &keys.retired_keys {
122 if let Ok(token_data) = decode::<Claims>(token, &key.decoding, &validation) {
123 return Ok(token_data.claims);
124 }
125 }
126
127 Err(AuthError::InvalidToken("Unable to verify token signature".to_string()))
128 }
129
130 pub async fn rotate_keys(&self, new_private_key: &str, new_public_key: &str) -> Result<()> {
131 let encoding = Self::create_encoding_key(&self.config.algorithm, new_private_key)?;
132 let decoding = Self::create_decoding_key(&self.config.algorithm, new_public_key)?;
133
134 let mut keys = self.keys.write().await;
135 keys.key_id_counter += 1;
136
137 let new_key = KeyPair {
138 id: format!("key_{}", keys.key_id_counter),
139 encoding,
140 decoding,
141 algorithm: self.config.algorithm,
142 created_at: Utc::now(),
143 };
144
145 let old_key = std::mem::replace(&mut keys.active_key, new_key);
147 keys.retired_keys.push(old_key);
148
149 if keys.retired_keys.len() > 3 {
151 keys.retired_keys.remove(0);
152 }
153
154 tracing::info!("JWT keys rotated successfully");
155 Ok(())
156 }
157
158 pub async fn get_jwks(&self) -> Result<JwkSet> {
159 let keys = self.keys.read().await;
160
161 let mut jwks = Vec::new();
162
163 jwks.push(Jwk {
165 kid: keys.active_key.id.clone(),
166 kty: "RSA".to_string(), alg: format!("{:?}", keys.active_key.algorithm),
168 use_: "sig".to_string(),
169 });
170
171 for key in &keys.retired_keys {
173 jwks.push(Jwk {
174 kid: key.id.clone(),
175 kty: "RSA".to_string(),
176 alg: format!("{:?}", key.algorithm),
177 use_: "sig".to_string(),
178 });
179 }
180
181 Ok(JwkSet { keys: jwks })
182 }
183
184 pub fn create_claims(
185 &self,
186 user_id: Uuid,
187 email: String,
188 roles: Vec<String>,
189 permissions: Vec<String>,
190 session_id: Uuid,
191 scopes: Vec<String>,
192 device_id: Option<String>,
193 ) -> Claims {
194 let now = Utc::now();
195 let exp = now + self.config.access_token_ttl;
196
197 Claims {
198 sub: user_id,
199 email,
200 roles,
201 permissions,
202 session_id,
203 iat: now.timestamp(),
204 exp: exp.timestamp(),
205 nbf: now.timestamp(),
206 iss: self.config.issuer.clone(),
207 aud: self.config.audience.clone(),
208 jti: Uuid::new_v4().to_string(),
209 scopes,
210 device_id,
211 }
212 }
213}
214
215#[derive(Debug, Serialize, Deserialize)]
216pub struct JwkSet {
217 pub keys: Vec<Jwk>,
218}
219
220#[derive(Debug, Serialize, Deserialize)]
221pub struct Jwk {
222 pub kid: String,
223 pub kty: String,
224 pub alg: String,
225 #[serde(rename = "use")]
226 pub use_: String,
227}