pulseengine_mcp_auth/
jwt.rs

1//! JWT token-based authentication
2//!
3//! This module provides secure JWT token generation and validation
4//! for stateless authentication, complementing the API key system.
5
6use chrono::{Duration, Utc};
7use jsonwebtoken::{
8    Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12use thiserror::Error;
13
14use crate::models::{AuthContext, Role};
15
16/// JWT token errors
17#[derive(Debug, Error)]
18pub enum JwtError {
19    #[error("Token generation failed: {0}")]
20    Generation(String),
21
22    #[error("Token validation failed: {0}")]
23    Validation(String),
24
25    #[error("Token expired")]
26    Expired,
27
28    #[error("Invalid token format")]
29    InvalidFormat,
30
31    #[error("Missing claims: {0}")]
32    MissingClaims(String),
33
34    #[error("Insufficient permissions")]
35    InsufficientPermissions,
36}
37
38/// JWT token claims following RFC 7519
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct TokenClaims {
41    /// Issuer (iss) - who issued the token
42    pub iss: String,
43
44    /// Subject (sub) - the user/key this token represents
45    pub sub: String,
46
47    /// Audience (aud) - intended recipients
48    pub aud: Vec<String>,
49
50    /// Expiration time (exp) - when token expires (Unix timestamp)
51    pub exp: i64,
52
53    /// Not before (nbf) - token not valid before this time
54    pub nbf: i64,
55
56    /// Issued at (iat) - when token was issued
57    pub iat: i64,
58
59    /// JWT ID (jti) - unique identifier for this token
60    pub jti: String,
61
62    // Custom claims for MCP authentication
63    /// User roles
64    pub roles: Vec<Role>,
65
66    /// API key ID this token was derived from
67    pub key_id: Option<String>,
68
69    /// Client IP address
70    pub client_ip: Option<String>,
71
72    /// Session ID for correlation
73    pub session_id: Option<String>,
74
75    /// Scope - what this token can access
76    pub scope: Vec<String>,
77
78    /// Token type (access, refresh, etc.)
79    pub token_type: TokenType,
80}
81
82/// Token types for different use cases
83#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
84#[serde(rename_all = "snake_case")]
85pub enum TokenType {
86    /// Short-lived access token
87    Access,
88    /// Long-lived refresh token
89    Refresh,
90    /// One-time use authorization token
91    Authorization,
92}
93
94/// JWT configuration
95#[derive(Debug, Clone)]
96pub struct JwtConfig {
97    /// Issuer name
98    pub issuer: String,
99
100    /// Default audience
101    pub audience: Vec<String>,
102
103    /// Signing algorithm
104    pub algorithm: Algorithm,
105
106    /// Signing secret (HMAC) or private key (RSA/ECDSA)
107    pub signing_secret: Vec<u8>,
108
109    /// Access token lifetime
110    pub access_token_lifetime: Duration,
111
112    /// Refresh token lifetime
113    pub refresh_token_lifetime: Duration,
114
115    /// Enable token blacklisting
116    pub enable_blacklist: bool,
117}
118
119impl Default for JwtConfig {
120    fn default() -> Self {
121        Self {
122            issuer: "pulseengine-mcp-auth".to_string(),
123            audience: vec!["mcp-server".to_string()],
124            algorithm: Algorithm::HS256,
125            signing_secret: b"default-secret-change-in-production".to_vec(),
126            access_token_lifetime: Duration::hours(1),
127            refresh_token_lifetime: Duration::days(7),
128            enable_blacklist: true,
129        }
130    }
131}
132
133/// JWT token manager
134pub struct JwtManager {
135    config: JwtConfig,
136    encoding_key: EncodingKey,
137    decoding_key: DecodingKey,
138    validation: Validation,
139    /// Blacklisted token JTIs
140    blacklist: tokio::sync::RwLock<HashSet<String>>,
141}
142
143impl JwtManager {
144    /// Create a new JWT manager
145    pub fn new(config: JwtConfig) -> Result<Self, JwtError> {
146        let encoding_key = match config.algorithm {
147            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
148                EncodingKey::from_secret(&config.signing_secret)
149            }
150            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
151                EncodingKey::from_rsa_pem(&config.signing_secret)
152                    .map_err(|e| JwtError::Generation(format!("Invalid RSA private key: {}", e)))?
153            }
154            Algorithm::ES256 | Algorithm::ES384 => EncodingKey::from_ec_pem(&config.signing_secret)
155                .map_err(|e| JwtError::Generation(format!("Invalid EC private key: {}", e)))?,
156            _ => return Err(JwtError::Generation("Unsupported algorithm".to_string())),
157        };
158
159        let decoding_key = match config.algorithm {
160            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
161                DecodingKey::from_secret(&config.signing_secret)
162            }
163            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
164                DecodingKey::from_rsa_pem(&config.signing_secret)
165                    .map_err(|e| JwtError::Validation(format!("Invalid RSA public key: {}", e)))?
166            }
167            Algorithm::ES256 | Algorithm::ES384 => DecodingKey::from_ec_pem(&config.signing_secret)
168                .map_err(|e| JwtError::Validation(format!("Invalid EC public key: {}", e)))?,
169            _ => return Err(JwtError::Validation("Unsupported algorithm".to_string())),
170        };
171
172        let mut validation = Validation::new(config.algorithm);
173        validation.set_audience(&config.audience);
174        validation.set_issuer(&[&config.issuer]);
175        validation.validate_exp = true;
176        validation.validate_nbf = true;
177
178        Ok(Self {
179            config,
180            encoding_key,
181            decoding_key,
182            validation,
183            blacklist: tokio::sync::RwLock::new(HashSet::new()),
184        })
185    }
186
187    /// Generate an access token
188    pub async fn generate_access_token(
189        &self,
190        subject: String,
191        roles: Vec<Role>,
192        key_id: Option<String>,
193        client_ip: Option<String>,
194        session_id: Option<String>,
195        scope: Vec<String>,
196    ) -> Result<String, JwtError> {
197        let now = Utc::now();
198        let exp = now + self.config.access_token_lifetime;
199
200        let claims = TokenClaims {
201            iss: self.config.issuer.clone(),
202            sub: subject,
203            aud: self.config.audience.clone(),
204            exp: exp.timestamp(),
205            nbf: now.timestamp(),
206            iat: now.timestamp(),
207            jti: uuid::Uuid::new_v4().to_string(),
208            roles,
209            key_id,
210            client_ip,
211            session_id,
212            scope,
213            token_type: TokenType::Access,
214        };
215
216        let header = Header::new(self.config.algorithm);
217        encode(&header, &claims, &self.encoding_key)
218            .map_err(|e| JwtError::Generation(e.to_string()))
219    }
220
221    /// Generate a refresh token
222    pub async fn generate_refresh_token(
223        &self,
224        subject: String,
225        key_id: Option<String>,
226        session_id: Option<String>,
227    ) -> Result<String, JwtError> {
228        let now = Utc::now();
229        let exp = now + self.config.refresh_token_lifetime;
230
231        let claims = TokenClaims {
232            iss: self.config.issuer.clone(),
233            sub: subject,
234            aud: self.config.audience.clone(),
235            exp: exp.timestamp(),
236            nbf: now.timestamp(),
237            iat: now.timestamp(),
238            jti: uuid::Uuid::new_v4().to_string(),
239            roles: vec![], // Refresh tokens don't carry roles
240            key_id,
241            client_ip: None,
242            session_id,
243            scope: vec!["refresh".to_string()],
244            token_type: TokenType::Refresh,
245        };
246
247        let header = Header::new(self.config.algorithm);
248        encode(&header, &claims, &self.encoding_key)
249            .map_err(|e| JwtError::Generation(e.to_string()))
250    }
251
252    /// Validate and decode a token
253    pub async fn validate_token(&self, token: &str) -> Result<TokenData<TokenClaims>, JwtError> {
254        let token_data = decode::<TokenClaims>(token, &self.decoding_key, &self.validation)
255            .map_err(|e| match e.kind() {
256                jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtError::Expired,
257                jsonwebtoken::errors::ErrorKind::InvalidToken => JwtError::InvalidFormat,
258                _ => JwtError::Validation(e.to_string()),
259            })?;
260
261        // Check if token is blacklisted
262        if self.config.enable_blacklist {
263            let blacklist = self.blacklist.read().await;
264            if blacklist.contains(&token_data.claims.jti) {
265                return Err(JwtError::Validation("Token has been revoked".to_string()));
266            }
267        }
268
269        Ok(token_data)
270    }
271
272    /// Extract auth context from a valid token
273    pub async fn token_to_auth_context(&self, token: &str) -> Result<AuthContext, JwtError> {
274        let token_data = self.validate_token(token).await?;
275        let claims = token_data.claims;
276
277        // Only access tokens can be used for authentication
278        if claims.token_type != TokenType::Access {
279            return Err(JwtError::Validation(
280                "Only access tokens can be used for authentication".to_string(),
281            ));
282        }
283
284        // Extract permissions from roles
285        let permissions: Vec<String> = claims
286            .roles
287            .iter()
288            .flat_map(|role| self.get_permissions_for_role(role))
289            .collect();
290
291        Ok(AuthContext {
292            user_id: Some(claims.sub),
293            roles: claims.roles,
294            api_key_id: claims.key_id,
295            permissions,
296        })
297    }
298
299    /// Refresh an access token using a refresh token
300    pub async fn refresh_access_token(
301        &self,
302        refresh_token: &str,
303        new_roles: Vec<Role>,
304        client_ip: Option<String>,
305        scope: Vec<String>,
306    ) -> Result<String, JwtError> {
307        let token_data = self.validate_token(refresh_token).await?;
308        let claims = token_data.claims;
309
310        // Verify this is a refresh token
311        if claims.token_type != TokenType::Refresh {
312            return Err(JwtError::Validation(
313                "Invalid token type for refresh".to_string(),
314            ));
315        }
316
317        // Generate new access token
318        self.generate_access_token(
319            claims.sub,
320            new_roles,
321            claims.key_id,
322            client_ip,
323            claims.session_id,
324            scope,
325        )
326        .await
327    }
328
329    /// Revoke a token by adding it to blacklist
330    pub async fn revoke_token(&self, token: &str) -> Result<(), JwtError> {
331        if !self.config.enable_blacklist {
332            return Err(JwtError::Validation(
333                "Token blacklisting is disabled".to_string(),
334            ));
335        }
336
337        let token_data = self.validate_token(token).await?;
338        let mut blacklist = self.blacklist.write().await;
339        blacklist.insert(token_data.claims.jti);
340
341        Ok(())
342    }
343
344    /// Clean up expired tokens from blacklist
345    pub async fn cleanup_blacklist(&self) -> usize {
346        if !self.config.enable_blacklist {
347            return 0;
348        }
349
350        let mut blacklist = self.blacklist.write().await;
351        let initial_size = blacklist.len();
352
353        // For now, just clear all (in production, you'd track expiration times)
354        // This is a simplified implementation
355        blacklist.clear();
356
357        initial_size
358    }
359
360    /// Get permissions for a role (helper method)
361    fn get_permissions_for_role(&self, role: &Role) -> Vec<String> {
362        match role {
363            Role::Admin => vec![
364                "admin.*".to_string(),
365                "key.*".to_string(),
366                "user.*".to_string(),
367                "system.*".to_string(),
368            ],
369            Role::Operator => vec![
370                "device.*".to_string(),
371                "monitor.*".to_string(),
372                "key.create".to_string(),
373                "key.list".to_string(),
374            ],
375            Role::Monitor => vec![
376                "monitor.*".to_string(),
377                "health.check".to_string(),
378                "status.read".to_string(),
379            ],
380            Role::Device { allowed_devices } => allowed_devices
381                .iter()
382                .map(|device| format!("device.{}", device))
383                .collect(),
384            Role::Custom { permissions } => permissions.clone(),
385        }
386    }
387
388    /// Get token info without validating signature (for debugging)
389    pub fn decode_token_info(&self, token: &str) -> Result<TokenClaims, JwtError> {
390        let mut validation = Validation::new(self.config.algorithm);
391        validation.validate_exp = false;
392        validation.validate_nbf = false;
393        validation.validate_aud = false;
394        validation.insecure_disable_signature_validation();
395
396        let token_data = decode::<TokenClaims>(token, &self.decoding_key, &validation)
397            .map_err(|_| JwtError::InvalidFormat)?;
398
399        Ok(token_data.claims)
400    }
401}
402
403/// JWT token pair (access + refresh)
404#[derive(Debug, Clone, Serialize, Deserialize)]
405pub struct TokenPair {
406    /// Short-lived access token
407    pub access_token: String,
408    /// Long-lived refresh token
409    pub refresh_token: String,
410    /// Access token type (always "Bearer")
411    pub token_type: String,
412    /// Access token expires in (seconds)
413    pub expires_in: i64,
414    /// Scope of the access token
415    pub scope: Vec<String>,
416}
417
418impl JwtManager {
419    /// Generate a complete token pair
420    pub async fn generate_token_pair(
421        &self,
422        subject: String,
423        roles: Vec<Role>,
424        key_id: Option<String>,
425        client_ip: Option<String>,
426        session_id: Option<String>,
427        scope: Vec<String>,
428    ) -> Result<TokenPair, JwtError> {
429        let access_token = self
430            .generate_access_token(
431                subject.clone(),
432                roles,
433                key_id.clone(),
434                client_ip,
435                session_id.clone(),
436                scope.clone(),
437            )
438            .await?;
439
440        let refresh_token = self
441            .generate_refresh_token(subject, key_id, session_id)
442            .await?;
443
444        Ok(TokenPair {
445            access_token,
446            refresh_token,
447            token_type: "Bearer".to_string(),
448            expires_in: self.config.access_token_lifetime.num_seconds(),
449            scope,
450        })
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_jwt_token_generation_and_validation() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let roles = vec![Role::Admin];
        let subject = "test-user".to_string();
        let scope = vec!["read".to_string(), "write".to_string()];

        // Generate access token
        let token = jwt_manager
            .generate_access_token(
                subject.clone(),
                roles.clone(),
                Some("key123".to_string()),
                Some("192.168.1.1".to_string()),
                Some("session123".to_string()),
                scope.clone(),
            )
            .await
            .unwrap();

        // Validate token
        let token_data = jwt_manager.validate_token(&token).await.unwrap();
        assert_eq!(token_data.claims.sub, subject);
        assert_eq!(token_data.claims.roles, roles);
        assert_eq!(token_data.claims.token_type, TokenType::Access);
    }

    #[tokio::test]
    async fn test_jwt_token_pair() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let roles = vec![Role::Monitor];
        let subject = "test-user".to_string();
        let scope = vec!["monitor".to_string()];

        // Generate token pair
        let token_pair = jwt_manager
            .generate_token_pair(subject.clone(), roles, None, None, None, scope.clone())
            .await
            .unwrap();

        // Validate access token
        let access_data = jwt_manager
            .validate_token(&token_pair.access_token)
            .await
            .unwrap();
        assert_eq!(access_data.claims.token_type, TokenType::Access);

        // Validate refresh token
        let refresh_data = jwt_manager
            .validate_token(&token_pair.refresh_token)
            .await
            .unwrap();
        assert_eq!(refresh_data.claims.token_type, TokenType::Refresh);

        assert_eq!(token_pair.token_type, "Bearer");
        assert_eq!(token_pair.scope, scope);
    }

    #[tokio::test]
    async fn test_jwt_token_revocation() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let token = jwt_manager
            .generate_access_token(
                "test-user".to_string(),
                vec![Role::Admin],
                None,
                None,
                None,
                vec!["test".to_string()],
            )
            .await
            .unwrap();

        // Token should be valid initially
        assert!(jwt_manager.validate_token(&token).await.is_ok());

        // Revoke token
        jwt_manager.revoke_token(&token).await.unwrap();

        // Token should now be invalid
        assert!(jwt_manager.validate_token(&token).await.is_err());
    }

    #[tokio::test]
    async fn test_revocation_disabled() {
        let config = JwtConfig {
            enable_blacklist: false,
            ..JwtConfig::default()
        };
        let jwt_manager = JwtManager::new(config).unwrap();

        let token = jwt_manager
            .generate_access_token(
                "test-user".to_string(),
                vec![Role::Monitor],
                None,
                None,
                None,
                vec![],
            )
            .await
            .unwrap();

        // With blacklisting off, revocation is refused and the token stays valid.
        assert!(jwt_manager.revoke_token(&token).await.is_err());
        assert!(jwt_manager.validate_token(&token).await.is_ok());
    }

    #[tokio::test]
    async fn test_auth_context_extraction() {
        let config = JwtConfig::default();
        let jwt_manager = JwtManager::new(config).unwrap();

        let roles = vec![Role::Admin, Role::Monitor];
        let token = jwt_manager
            .generate_access_token(
                "test-user".to_string(),
                roles.clone(),
                Some("key123".to_string()),
                None,
                None,
                vec!["admin".to_string()],
            )
            .await
            .unwrap();

        let auth_context = jwt_manager.token_to_auth_context(&token).await.unwrap();

        assert_eq!(auth_context.user_id, Some("test-user".to_string()));
        assert_eq!(auth_context.roles, roles);
        assert_eq!(auth_context.api_key_id, Some("key123".to_string()));
        assert!(!auth_context.permissions.is_empty());
    }

    #[tokio::test]
    async fn test_refresh_flow_and_token_type_enforcement() {
        let jwt_manager = JwtManager::new(JwtConfig::default()).unwrap();

        let pair = jwt_manager
            .generate_token_pair(
                "test-user".to_string(),
                vec![Role::Monitor],
                Some("key123".to_string()),
                None,
                Some("session123".to_string()),
                vec!["monitor".to_string()],
            )
            .await
            .unwrap();

        // A refresh token mints a new access token for the same subject/key/session.
        let new_access = jwt_manager
            .refresh_access_token(
                &pair.refresh_token,
                vec![Role::Monitor],
                None,
                vec!["monitor".to_string()],
            )
            .await
            .unwrap();
        let claims = jwt_manager.validate_token(&new_access).await.unwrap().claims;
        assert_eq!(claims.token_type, TokenType::Access);
        assert_eq!(claims.sub, "test-user");
        assert_eq!(claims.key_id, Some("key123".to_string()));
        assert_eq!(claims.session_id, Some("session123".to_string()));

        // An access token cannot be used to refresh...
        assert!(jwt_manager
            .refresh_access_token(&pair.access_token, vec![], None, vec![])
            .await
            .is_err());
        // ...and a refresh token cannot be used to authenticate.
        assert!(jwt_manager
            .token_to_auth_context(&pair.refresh_token)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_foreign_and_malformed_tokens_rejected() {
        let jwt_manager = JwtManager::new(JwtConfig::default()).unwrap();
        let other_manager = JwtManager::new(JwtConfig {
            signing_secret: b"a-completely-different-secret".to_vec(),
            ..JwtConfig::default()
        })
        .unwrap();

        let token = jwt_manager
            .generate_access_token(
                "test-user".to_string(),
                vec![Role::Admin],
                None,
                None,
                None,
                vec![],
            )
            .await
            .unwrap();

        // A token signed with another key must not validate.
        assert!(other_manager.validate_token(&token).await.is_err());
        // Garbage input must not validate.
        assert!(jwt_manager.validate_token("not-a-jwt").await.is_err());

        // decode_token_info skips signature checks, so it can still read the claims.
        let info = other_manager.decode_token_info(&token).unwrap();
        assert_eq!(info.sub, "test-user");
        assert!(jwt_manager.decode_token_info("not-a-jwt").is_err());
    }
}