1use chrono::{Duration, Utc};
7use jsonwebtoken::{
8 decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12use thiserror::Error;
13
14use crate::models::{AuthContext, Role};
15
/// Errors produced while building a [`JwtManager`], or issuing, validating, or revoking tokens.
#[derive(Debug, Error)]
pub enum JwtError {
    /// The signing key could not be built, the algorithm is unsupported, or encoding/signing a
    /// token failed. The payload carries the underlying reason.
    #[error("Token generation failed: {0}")]
    Generation(String),

    /// A token was rejected (bad signature, issuer, audience, `nbf`, revoked, wrong token type,
    /// ...), the verification key could not be parsed, or revocation was requested while the
    /// blacklist is disabled. The payload carries the reason.
    #[error("Token validation failed: {0}")]
    Validation(String),

    /// The token's `exp` claim is in the past (jsonwebtoken `ExpiredSignature`).
    #[error("Token expired")]
    Expired,

    /// The token is structurally malformed (jsonwebtoken `InvalidToken`), or could not be decoded
    /// at all by `decode_token_info`.
    #[error("Invalid token format")]
    InvalidFormat,

    /// Required claims are absent. Not constructed in this file; presumably raised by callers.
    #[error("Missing claims: {0}")]
    MissingClaims(String),

    /// The authenticated caller lacks a required permission. Not constructed in this file;
    /// presumably raised by callers that inspect `AuthContext::permissions`.
    #[error("Insufficient permissions")]
    InsufficientPermissions,
}
37
/// Claims carried in every token issued by [`JwtManager`].
///
/// Serialized with serde as the JWT payload. The field order below is the key order of the
/// encoded payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
    /// Issuer (`iss`), taken from `JwtConfig::issuer`.
    pub iss: String,

    /// Subject (`sub`); becomes `AuthContext::user_id`.
    pub sub: String,

    /// Audiences (`aud`), taken from `JwtConfig::audience`. Always encoded as a JSON array.
    pub aud: Vec<String>,

    /// Expiry (`exp`) as a Unix timestamp in seconds.
    pub exp: i64,

    /// Not-before (`nbf`) as a Unix timestamp in seconds; set to the issue time.
    pub nbf: i64,

    /// Issued-at (`iat`) as a Unix timestamp in seconds.
    pub iat: i64,

    /// Token ID (`jti`): a random UUID v4 per token, used as the revocation key.
    pub jti: String,

    /// Roles granted to the subject. Always empty for refresh tokens.
    pub roles: Vec<Role>,

    /// Optional API key identifier; becomes `AuthContext::api_key_id`.
    pub key_id: Option<String>,

    /// Client IP recorded at issuance. Not checked during validation in this file.
    pub client_ip: Option<String>,

    /// Optional session identifier, carried over from refresh to access tokens.
    pub session_id: Option<String>,

    /// Granted scopes. Refresh tokens carry exactly `["refresh"]`.
    pub scope: Vec<String>,

    /// Kind of token, used to stop refresh tokens from authenticating requests and vice versa.
    pub token_type: TokenType,
}
81
/// Kind of token, serialized in snake_case (`"access"`, `"refresh"`, `"authorization"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    /// Short-lived token accepted by `JwtManager::token_to_auth_context`.
    Access,
    /// Longer-lived token that can only be exchanged via `JwtManager::refresh_access_token`.
    Refresh,
    /// Not issued anywhere in this file. NOTE(review): confirm where (or whether) this is used.
    Authorization,
}
93
/// Settings used by [`JwtManager`] to issue and validate tokens.
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Value written to, and required in, the `iss` claim.
    pub issuer: String,

    /// Values written to the `aud` claim; validation requires a matching audience.
    pub audience: Vec<String>,

    /// Signing algorithm. `JwtManager::new` supports HS256/384/512, RS256/384/512 and
    /// ES256/384 only.
    pub algorithm: Algorithm,

    /// Key material: the raw shared secret for HS* algorithms, or PEM-encoded key bytes for
    /// RS*/ES* algorithms.
    pub signing_secret: Vec<u8>,

    /// How long an access token stays valid after issuance.
    pub access_token_lifetime: Duration,

    /// How long a refresh token stays valid after issuance.
    pub refresh_token_lifetime: Duration,

    /// Enables the in-memory revocation list. When `false`, `revoke_token` returns an error and
    /// validation skips the revocation check.
    pub enable_blacklist: bool,
}
118
119impl Default for JwtConfig {
120 fn default() -> Self {
121 Self {
122 issuer: "pulseengine-mcp-auth".to_string(),
123 audience: vec!["mcp-server".to_string()],
124 algorithm: Algorithm::HS256,
125 signing_secret: b"default-secret-change-in-production".to_vec(),
126 access_token_lifetime: Duration::hours(1),
127 refresh_token_lifetime: Duration::days(7),
128 enable_blacklist: true,
129 }
130 }
131}
132
133pub struct JwtManager {
135 config: JwtConfig,
136 encoding_key: EncodingKey,
137 decoding_key: DecodingKey,
138 validation: Validation,
139 blacklist: tokio::sync::RwLock<HashSet<String>>,
141}
142
143impl JwtManager {
144 pub fn new(config: JwtConfig) -> Result<Self, JwtError> {
146 let encoding_key = match config.algorithm {
147 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
148 EncodingKey::from_secret(&config.signing_secret)
149 }
150 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
151 EncodingKey::from_rsa_pem(&config.signing_secret)
152 .map_err(|e| JwtError::Generation(format!("Invalid RSA private key: {}", e)))?
153 }
154 Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_pem(&config.signing_secret)
155 .map_err(|e| JwtError::Generation(format!("Invalid EC private key: {}", e)))?,
156 _ => return Err(JwtError::Generation("Unsupported algorithm".to_string())),
157 };
158
159 let decoding_key = match config.algorithm {
160 Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
161 DecodingKey::from_secret(&config.signing_secret)
162 }
163 Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
164 DecodingKey::from_rsa_pem(&config.signing_secret)
165 .map_err(|e| JwtError::Validation(format!("Invalid RSA public key: {}", e)))?
166 }
167 Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(&config.signing_secret)
168 .map_err(|e| JwtError::Validation(format!("Invalid EC public key: {}", e)))?,
169 _ => return Err(JwtError::Validation("Unsupported algorithm".to_string())),
170 };
171
172 let mut validation = Validation::new(config.algorithm);
173 validation.set_audience(&config.audience);
174 validation.set_issuer(&[&config.issuer]);
175 validation.validate_exp = true;
176 validation.validate_nbf = true;
177
178 Ok(Self {
179 config,
180 encoding_key,
181 decoding_key,
182 validation,
183 blacklist: tokio::sync::RwLock::new(HashSet::new()),
184 })
185 }
186
187 pub async fn generate_access_token(
189 &self,
190 subject: String,
191 roles: Vec<Role>,
192 key_id: Option<String>,
193 client_ip: Option<String>,
194 session_id: Option<String>,
195 scope: Vec<String>,
196 ) -> Result<String, JwtError> {
197 let now = Utc::now();
198 let exp = now + self.config.access_token_lifetime;
199
200 let claims = TokenClaims {
201 iss: self.config.issuer.clone(),
202 sub: subject,
203 aud: self.config.audience.clone(),
204 exp: exp.timestamp(),
205 nbf: now.timestamp(),
206 iat: now.timestamp(),
207 jti: uuid::Uuid::new_v4().to_string(),
208 roles,
209 key_id,
210 client_ip,
211 session_id,
212 scope,
213 token_type: TokenType::Access,
214 };
215
216 let header = Header::new(self.config.algorithm);
217 encode(&header, &claims, &self.encoding_key)
218 .map_err(|e| JwtError::Generation(e.to_string()))
219 }
220
221 pub async fn generate_refresh_token(
223 &self,
224 subject: String,
225 key_id: Option<String>,
226 session_id: Option<String>,
227 ) -> Result<String, JwtError> {
228 let now = Utc::now();
229 let exp = now + self.config.refresh_token_lifetime;
230
231 let claims = TokenClaims {
232 iss: self.config.issuer.clone(),
233 sub: subject,
234 aud: self.config.audience.clone(),
235 exp: exp.timestamp(),
236 nbf: now.timestamp(),
237 iat: now.timestamp(),
238 jti: uuid::Uuid::new_v4().to_string(),
239 roles: vec![], key_id,
241 client_ip: None,
242 session_id,
243 scope: vec!["refresh".to_string()],
244 token_type: TokenType::Refresh,
245 };
246
247 let header = Header::new(self.config.algorithm);
248 encode(&header, &claims, &self.encoding_key)
249 .map_err(|e| JwtError::Generation(e.to_string()))
250 }
251
252 pub async fn validate_token(&self, token: &str) -> Result<TokenData<TokenClaims>, JwtError> {
254 let token_data = decode::<TokenClaims>(token, &self.decoding_key, &self.validation)
255 .map_err(|e| match e.kind() {
256 jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::Expired,
257 jsonwebtoken::errors::ErrorKind::InvalidToken => JwtError::InvalidFormat,
258 _ => JwtError::Validation(e.to_string()),
259 })?;
260
261 if self.config.enable_blacklist {
263 let blacklist = self.blacklist.read().await;
264 if blacklist.contains(&token_data.claims.jti) {
265 return Err(JwtError::Validation("Token has been revoked".to_string()));
266 }
267 }
268
269 Ok(token_data)
270 }
271
272 pub async fn token_to_auth_context(&self, token: &str) -> Result<AuthContext, JwtError> {
274 let token_data = self.validate_token(token).await?;
275 let claims = token_data.claims;
276
277 if claims.token_type != TokenType::Access {
279 return Err(JwtError::Validation(
280 "Only access tokens can be used for authentication".to_string(),
281 ));
282 }
283
284 let permissions: Vec<String> = claims
286 .roles
287 .iter()
288 .flat_map(|role| self.get_permissions_for_role(role))
289 .collect();
290
291 Ok(AuthContext {
292 user_id: Some(claims.sub),
293 roles: claims.roles,
294 api_key_id: claims.key_id,
295 permissions,
296 })
297 }
298
299 pub async fn refresh_access_token(
301 &self,
302 refresh_token: &str,
303 new_roles: Vec<Role>,
304 client_ip: Option<String>,
305 scope: Vec<String>,
306 ) -> Result<String, JwtError> {
307 let token_data = self.validate_token(refresh_token).await?;
308 let claims = token_data.claims;
309
310 if claims.token_type != TokenType::Refresh {
312 return Err(JwtError::Validation(
313 "Invalid token type for refresh".to_string(),
314 ));
315 }
316
317 self.generate_access_token(
319 claims.sub,
320 new_roles,
321 claims.key_id,
322 client_ip,
323 claims.session_id,
324 scope,
325 )
326 .await
327 }
328
329 pub async fn revoke_token(&self, token: &str) -> Result<(), JwtError> {
331 if !self.config.enable_blacklist {
332 return Err(JwtError::Validation(
333 "Token blacklisting is disabled".to_string(),
334 ));
335 }
336
337 let token_data = self.validate_token(token).await?;
338 let mut blacklist = self.blacklist.write().await;
339 blacklist.insert(token_data.claims.jti);
340
341 Ok(())
342 }
343
344 pub async fn cleanup_blacklist(&self) -> usize {
346 if !self.config.enable_blacklist {
347 return 0;
348 }
349
350 let mut blacklist = self.blacklist.write().await;
351 let initial_size = blacklist.len();
352
353 blacklist.clear();
356
357 initial_size
358 }
359
360 fn get_permissions_for_role(&self, role: &Role) -> Vec<String> {
362 match role {
363 Role::Admin => vec![
364 "admin.*".to_string(),
365 "key.*".to_string(),
366 "user.*".to_string(),
367 "system.*".to_string(),
368 ],
369 Role::Operator => vec![
370 "device.*".to_string(),
371 "monitor.*".to_string(),
372 "key.create".to_string(),
373 "key.list".to_string(),
374 ],
375 Role::Monitor => vec![
376 "monitor.*".to_string(),
377 "health.check".to_string(),
378 "status.read".to_string(),
379 ],
380 Role::Device { allowed_devices } => allowed_devices
381 .iter()
382 .map(|device| format!("device.{}", device))
383 .collect(),
384 Role::Custom { permissions } => permissions.clone(),
385 }
386 }
387
388 pub fn decode_token_info(&self, token: &str) -> Result<TokenClaims, JwtError> {
390 let mut validation = Validation::new(self.config.algorithm);
391 validation.validate_exp = false;
392 validation.validate_nbf = false;
393 validation.validate_aud = false;
394 validation.insecure_disable_signature_validation();
395
396 let token_data = decode::<TokenClaims>(token, &self.decoding_key, &validation)
397 .map_err(|_| JwtError::InvalidFormat)?;
398
399 Ok(token_data.claims)
400 }
401}
402
/// An access token and its matching refresh token, as returned by
/// `JwtManager::generate_token_pair`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenPair {
    /// Signed access token.
    pub access_token: String,
    /// Signed refresh token for the same subject, key id and session.
    pub refresh_token: String,
    /// Token scheme; always `"Bearer"` when built by `generate_token_pair`.
    pub token_type: String,
    /// Access-token lifetime in seconds (the configured lifetime, not time remaining).
    pub expires_in: i64,
    /// Scopes granted to the access token.
    pub scope: Vec<String>,
}
417
418impl JwtManager {
419 pub async fn generate_token_pair(
421 &self,
422 subject: String,
423 roles: Vec<Role>,
424 key_id: Option<String>,
425 client_ip: Option<String>,
426 session_id: Option<String>,
427 scope: Vec<String>,
428 ) -> Result<TokenPair, JwtError> {
429 let access_token = self
430 .generate_access_token(
431 subject.clone(),
432 roles,
433 key_id.clone(),
434 client_ip,
435 session_id.clone(),
436 scope.clone(),
437 )
438 .await?;
439
440 let refresh_token = self
441 .generate_refresh_token(subject, key_id, session_id)
442 .await?;
443
444 Ok(TokenPair {
445 access_token,
446 refresh_token,
447 token_type: "Bearer".to_string(),
448 expires_in: self.config.access_token_lifetime.num_seconds(),
449 scope,
450 })
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Issues an access token for "test-user" with `roles` and no optional claims.
    async fn issue_access_token(jwt_manager: &JwtManager, roles: Vec<Role>) -> String {
        jwt_manager
            .generate_access_token(
                "test-user".to_string(),
                roles,
                None,
                None,
                None,
                vec!["test".to_string()],
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_jwt_token_generation_and_validation() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let roles = vec![Role::Admin];
        let subject = "test-user".to_string();
        let scope = vec!["read".to_string(), "write".to_string()];

        let token = jwt_manager
            .generate_access_token(
                subject.clone(),
                roles.clone(),
                Some("key123".to_string()),
                Some("192.168.1.1".to_string()),
                Some("session123".to_string()),
                scope.clone(),
            )
            .await
            .unwrap();

        let token_data = jwt_manager.validate_token(&token).await.unwrap();
        assert_eq!(token_data.claims.sub, subject);
        assert_eq!(token_data.claims.roles, roles);
        assert_eq!(token_data.claims.token_type, TokenType::Access);
    }

    #[tokio::test]
    async fn test_jwt_token_pair() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let roles = vec![Role::Monitor];
        let subject = "test-user".to_string();
        let scope = vec!["monitor".to_string()];

        let token_pair = jwt_manager
            .generate_token_pair(subject.clone(), roles, None, None, None, scope.clone())
            .await
            .unwrap();

        let access_data = jwt_manager
            .validate_token(&token_pair.access_token)
            .await
            .unwrap();
        assert_eq!(access_data.claims.token_type, TokenType::Access);

        let refresh_data = jwt_manager
            .validate_token(&token_pair.refresh_token)
            .await
            .unwrap();
        assert_eq!(refresh_data.claims.token_type, TokenType::Refresh);

        assert_eq!(token_pair.token_type, "Bearer");
        assert_eq!(token_pair.scope, scope);
    }

    #[tokio::test]
    async fn test_jwt_token_revocation() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let token = issue_access_token(&jwt_manager, vec![Role::Admin]).await;

        assert!(jwt_manager.validate_token(&token).await.is_ok());

        jwt_manager.revoke_token(&token).await.unwrap();

        // Revocation is reported as a validation failure, not as expiry or bad format.
        assert!(matches!(
            jwt_manager.validate_token(&token).await,
            Err(JwtError::Validation(_))
        ));
    }

    #[tokio::test]
    async fn test_auth_context_extraction() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let roles = vec![Role::Admin, Role::Monitor];
        let token = jwt_manager
            .generate_access_token(
                "test-user".to_string(),
                roles.clone(),
                Some("key123".to_string()),
                None,
                None,
                vec!["admin".to_string()],
            )
            .await
            .unwrap();

        let auth_context = jwt_manager.token_to_auth_context(&token).await.unwrap();

        assert_eq!(auth_context.user_id, Some("test-user".to_string()));
        assert_eq!(auth_context.roles, roles);
        assert_eq!(auth_context.api_key_id, Some("key123".to_string()));
        assert!(!auth_context.permissions.is_empty());
    }

    #[tokio::test]
    async fn test_refresh_flow_enforces_token_types() {
        let jwt_manager = JwtManager::new(JwtConfig::default()).unwrap();
        let scope = vec!["read".to_string()];

        let pair = jwt_manager
            .generate_token_pair(
                "test-user".to_string(),
                vec![Role::Monitor],
                None,
                None,
                None,
                scope.clone(),
            )
            .await
            .unwrap();

        // A refresh token must not authenticate requests directly.
        assert!(jwt_manager
            .token_to_auth_context(&pair.refresh_token)
            .await
            .is_err());

        // An access token must not be accepted where a refresh token is required.
        assert!(jwt_manager
            .refresh_access_token(&pair.access_token, vec![Role::Monitor], None, scope.clone())
            .await
            .is_err());

        let new_access = jwt_manager
            .refresh_access_token(&pair.refresh_token, vec![Role::Monitor], None, scope)
            .await
            .unwrap();
        let auth_context = jwt_manager.token_to_auth_context(&new_access).await.unwrap();
        assert_eq!(auth_context.user_id, Some("test-user".to_string()));
        assert_eq!(auth_context.roles, vec![Role::Monitor]);
    }

    #[tokio::test]
    async fn test_revoked_refresh_token_cannot_be_exchanged() {
        let jwt_manager = JwtManager::new(JwtConfig::default()).unwrap();

        let pair = jwt_manager
            .generate_token_pair(
                "test-user".to_string(),
                vec![Role::Monitor],
                None,
                None,
                None,
                vec![],
            )
            .await
            .unwrap();

        jwt_manager.revoke_token(&pair.refresh_token).await.unwrap();

        assert!(jwt_manager
            .refresh_access_token(&pair.refresh_token, vec![Role::Monitor], None, vec![])
            .await
            .is_err());
        // Revocation is per-jti, so the sibling access token is unaffected.
        assert!(jwt_manager.validate_token(&pair.access_token).await.is_ok());
    }

    #[tokio::test]
    async fn test_token_signed_with_other_secret_is_rejected() {
        let issuer = JwtManager::new(JwtConfig::default()).unwrap();
        let other = JwtManager::new(JwtConfig {
            signing_secret: b"a-completely-different-secret".to_vec(),
            ..JwtConfig::default()
        })
        .unwrap();

        let token = issue_access_token(&issuer, vec![Role::Admin]).await;

        assert!(issuer.validate_token(&token).await.is_ok());
        assert!(other.validate_token(&token).await.is_err());
    }

    #[tokio::test]
    async fn test_revocation_disabled() {
        let jwt_manager = JwtManager::new(JwtConfig {
            enable_blacklist: false,
            ..JwtConfig::default()
        })
        .unwrap();

        let token = issue_access_token(&jwt_manager, vec![Role::Admin]).await;

        assert!(jwt_manager.revoke_token(&token).await.is_err());
        assert!(jwt_manager.validate_token(&token).await.is_ok());
        assert_eq!(jwt_manager.cleanup_blacklist().await, 0);
    }
}
572}