pulseengine_mcp_security_middleware/
auth.rs

1//! Authentication and token validation logic
2
3use crate::error::{SecurityError, SecurityResult};
4use crate::utils::{current_timestamp, secure_compare, validate_api_key_format};
5use chrono::{DateTime, Utc};
6use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use uuid::Uuid;
10
11/// Authentication context containing validated user information
12#[derive(Debug, Clone)]
13pub struct AuthContext {
14    /// Unique user identifier
15    pub user_id: String,
16
17    /// User's roles or permissions
18    pub roles: Vec<String>,
19
20    /// API key used for authentication (if applicable)
21    pub api_key: Option<String>,
22
23    /// JWT token claims (if JWT was used)
24    pub jwt_claims: Option<JwtClaims>,
25
26    /// Timestamp when authentication occurred
27    pub authenticated_at: DateTime<Utc>,
28
29    /// Request ID for tracing
30    pub request_id: String,
31
32    /// Additional metadata
33    pub metadata: HashMap<String, String>,
34}
35
36impl AuthContext {
37    /// Create a new authentication context
38    pub fn new(user_id: String) -> Self {
39        Self {
40            user_id,
41            roles: Vec::new(),
42            api_key: None,
43            jwt_claims: None,
44            authenticated_at: Utc::now(),
45            request_id: crate::utils::generate_request_id(),
46            metadata: HashMap::new(),
47        }
48    }
49
50    /// Add a role to the authentication context
51    pub fn with_role<S: Into<String>>(mut self, role: S) -> Self {
52        self.roles.push(role.into());
53        self
54    }
55
56    /// Add multiple roles to the authentication context
57    pub fn with_roles<I, S>(mut self, roles: I) -> Self
58    where
59        I: IntoIterator<Item = S>,
60        S: Into<String>,
61    {
62        self.roles.extend(roles.into_iter().map(|r| r.into()));
63        self
64    }
65
66    /// Set the API key used for authentication
67    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
68        self.api_key = Some(api_key.into());
69        self
70    }
71
72    /// Set the JWT claims
73    pub fn with_jwt_claims(mut self, claims: JwtClaims) -> Self {
74        self.jwt_claims = Some(claims);
75        self
76    }
77
78    /// Add metadata
79    pub fn with_metadata<K, V>(mut self, key: K, value: V) -> Self
80    where
81        K: Into<String>,
82        V: Into<String>,
83    {
84        self.metadata.insert(key.into(), value.into());
85        self
86    }
87
88    /// Check if the user has a specific role
89    pub fn has_role(&self, role: &str) -> bool {
90        self.roles.contains(&role.to_string())
91    }
92
93    /// Check if the user has any of the specified roles
94    pub fn has_any_role<I>(&self, roles: I) -> bool
95    where
96        I: IntoIterator,
97        I::Item: AsRef<str>,
98    {
99        for role in roles {
100            if self.has_role(role.as_ref()) {
101                return true;
102            }
103        }
104        false
105    }
106}
107
108/// JWT claims structure
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct JwtClaims {
111    /// Subject (user ID)
112    pub sub: String,
113
114    /// Expiration time
115    pub exp: u64,
116
117    /// Issued at time
118    pub iat: u64,
119
120    /// Not before time
121    pub nbf: Option<u64>,
122
123    /// JWT ID
124    pub jti: String,
125
126    /// Issuer
127    pub iss: String,
128
129    /// Audience
130    pub aud: String,
131
132    /// Custom roles
133    pub roles: Option<Vec<String>>,
134
135    /// Custom metadata
136    pub metadata: Option<HashMap<String, serde_json::Value>>,
137}
138
139impl JwtClaims {
140    /// Create new JWT claims
141    pub fn new(user_id: String, issuer: String, audience: String, expires_in_seconds: u64) -> Self {
142        let now = current_timestamp();
143
144        Self {
145            sub: user_id,
146            exp: now + expires_in_seconds,
147            iat: now,
148            nbf: Some(now),
149            jti: Uuid::new_v4().to_string(),
150            iss: issuer,
151            aud: audience,
152            roles: None,
153            metadata: None,
154        }
155    }
156
157    /// Add roles to the claims
158    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
159        self.roles = Some(roles);
160        self
161    }
162
163    /// Check if the token is expired
164    pub fn is_expired(&self) -> bool {
165        current_timestamp() > self.exp
166    }
167}
168
169/// Token validator for JWT tokens
170pub struct TokenValidator {
171    /// JWT decoding key
172    decoding_key: DecodingKey,
173
174    /// JWT validation parameters
175    validation: Validation,
176
177    /// Expected issuer
178    expected_issuer: String,
179
180    /// Expected audience
181    expected_audience: String,
182
183    /// Secret for encoding (stored for token creation)
184    secret: String,
185}
186
187impl std::fmt::Debug for TokenValidator {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        f.debug_struct("TokenValidator")
190            .field("expected_issuer", &self.expected_issuer)
191            .field("expected_audience", &self.expected_audience)
192            .field("secret", &"[REDACTED]")
193            .finish()
194    }
195}
196
197impl TokenValidator {
198    /// Create a new token validator
199    pub fn new(secret: &str, issuer: String, audience: String) -> Self {
200        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
201
202        let mut validation = Validation::new(Algorithm::HS256);
203        validation.set_issuer(&[&issuer]);
204        validation.set_audience(&[&audience]);
205        validation.validate_exp = true;
206        validation.validate_nbf = true;
207
208        Self {
209            decoding_key,
210            validation,
211            expected_issuer: issuer,
212            expected_audience: audience,
213            secret: secret.to_string(),
214        }
215    }
216
217    /// Validate a JWT token and return the claims
218    pub fn validate_token(&self, token: &str) -> SecurityResult<JwtClaims> {
219        let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;
220
221        let claims = token_data.claims;
222
223        // Additional validation
224        if claims.is_expired() {
225            return Err(SecurityError::TokenExpired);
226        }
227
228        if claims.iss != self.expected_issuer {
229            return Err(SecurityError::invalid_token("Invalid issuer"));
230        }
231
232        if claims.aud != self.expected_audience {
233            return Err(SecurityError::invalid_token("Invalid audience"));
234        }
235
236        Ok(claims)
237    }
238
239    /// Create a JWT token from claims
240    pub fn create_token(&self, claims: &JwtClaims) -> SecurityResult<String> {
241        let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
242
243        let header = Header::new(Algorithm::HS256);
244
245        encode(&header, claims, &encoding_key).map_err(SecurityError::from)
246    }
247}
248
/// API key validator.
///
/// Keys are not kept in the form they were registered in. The map is keyed
/// by the output of `crate::utils::hash_api_key` and points to the owning
/// user ID.
#[derive(Debug, Clone)]
pub struct ApiKeyValidator {
    /// Map from key hash to user ID.
    api_keys: HashMap<String, String>,
}
255
256impl ApiKeyValidator {
257    /// Create a new API key validator
258    pub fn new() -> Self {
259        Self {
260            api_keys: HashMap::new(),
261        }
262    }
263
264    /// Add an API key (stores the hash)
265    pub fn add_api_key(&mut self, api_key: &str, user_id: String) -> SecurityResult<()> {
266        validate_api_key_format(api_key)?;
267
268        let hash = crate::utils::hash_api_key(api_key);
269        self.api_keys.insert(hash, user_id);
270
271        Ok(())
272    }
273
274    /// Validate an API key and return the user ID
275    pub fn validate_api_key(&self, api_key: &str) -> SecurityResult<String> {
276        validate_api_key_format(api_key)?;
277
278        let hash = crate::utils::hash_api_key(api_key);
279
280        // Use secure comparison to prevent timing attacks
281        for (stored_hash, user_id) in &self.api_keys {
282            if secure_compare(&hash, stored_hash) {
283                return Ok(user_id.clone());
284            }
285        }
286
287        Err(SecurityError::InvalidApiKey)
288    }
289
290    /// Remove an API key
291    pub fn remove_api_key(&mut self, api_key: &str) -> SecurityResult<bool> {
292        validate_api_key_format(api_key)?;
293
294        let hash = crate::utils::hash_api_key(api_key);
295        Ok(self.api_keys.remove(&hash).is_some())
296    }
297
298    /// Get number of stored API keys
299    pub fn len(&self) -> usize {
300        self.api_keys.len()
301    }
302
303    /// Check if no API keys are stored
304    pub fn is_empty(&self) -> bool {
305        self.api_keys.is_empty()
306    }
307}
308
309impl Default for ApiKeyValidator {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{generate_api_key, generate_jwt_secret};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// One-hour claims for `user123` with the given issuer and audience.
    fn hour_long_claims(issuer: &str, audience: &str) -> JwtClaims {
        JwtClaims::new(
            "user123".to_string(),
            issuer.to_string(),
            audience.to_string(),
            3600,
        )
    }

    #[test]
    fn test_auth_context_creation() {
        let ctx = AuthContext::new("user123".to_string())
            .with_role("admin")
            .with_roles(vec!["user", "moderator"])
            .with_metadata("key", "value");

        assert_eq!(ctx.user_id, "user123");
        for granted in ["admin", "user", "moderator"] {
            assert!(ctx.has_role(granted));
        }
        assert!(!ctx.has_role("guest"));
        assert!(ctx.has_any_role(&["admin", "guest"]));
        assert!(!ctx.has_any_role(&["guest", "visitor"]));
        assert_eq!(ctx.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_jwt_claims() {
        let claims = hour_long_claims("test-issuer", "test-audience")
            .with_roles(vec!["admin".to_string()]);

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.iss, "test-issuer");
        assert_eq!(claims.aud, "test-audience");
        assert!(!claims.is_expired());
        assert_eq!(claims.roles, Some(vec!["admin".to_string()]));
    }

    #[test]
    fn test_api_key_validator() {
        let mut keys = ApiKeyValidator::new();
        let registered = generate_api_key();

        keys.add_api_key(&registered, "user123".to_string()).unwrap();
        assert_eq!(keys.len(), 1);

        assert_eq!(keys.validate_api_key(&registered).unwrap(), "user123");

        // A freshly generated, unregistered key must be rejected.
        let stranger = generate_api_key();
        assert!(keys.validate_api_key(&stranger).is_err());

        assert!(keys.remove_api_key(&registered).unwrap());
        assert_eq!(keys.len(), 0);
        assert!(keys.is_empty());
    }

    #[test]
    fn test_token_validator() {
        let validator = TokenValidator::new(
            "test-secret",
            "test-issuer".to_string(),
            "test-audience".to_string(),
        );

        // Round-trip: sign, then verify.
        let token = validator
            .create_token(&hour_long_claims("test-issuer", "test-audience"))
            .unwrap();
        let decoded = validator.validate_token(&token).unwrap();

        assert_eq!(decoded.sub, "user123");
        assert_eq!(decoded.iss, "test-issuer");
        assert_eq!(decoded.aud, "test-audience");
    }

    #[test]
    fn test_auth_context_additional_methods() {
        let with_roles = AuthContext::new("user123".to_string())
            .with_roles(vec!["admin".to_string(), "user".to_string()]);

        assert!(with_roles.has_role("admin"));
        assert!(with_roles.has_role("user"));
        assert!(!with_roles.has_role("guest"));

        assert!(with_roles.has_any_role(["admin", "guest"]));
        assert!(with_roles.has_any_role(["user", "guest"]));
        assert!(!with_roles.has_any_role(["guest", "moderator"]));

        let with_meta = AuthContext::new("user123".to_string())
            .with_metadata("department", "engineering")
            .with_metadata("level", "senior");

        assert_eq!(with_meta.metadata.get("department").unwrap(), "engineering");
        assert_eq!(with_meta.metadata.get("level").unwrap(), "senior");
    }

    #[test]
    fn test_jwt_claims_expiration() {
        let fresh = hour_long_claims("test_issuer", "test_audience");
        assert!(!fresh.is_expired());

        // Move `exp` one hour into the past.
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let stale = JwtClaims {
            exp: now_secs - 3600,
            ..fresh
        };
        assert!(stale.is_expired());
    }

    #[test]
    fn test_token_validator_edge_cases() {
        let secret = generate_jwt_secret();
        let validator = TokenValidator::new(
            &secret,
            "test_issuer".to_string(),
            "test_audience".to_string(),
        );

        for garbage in ["invalid.token", "", "not_a_jwt"] {
            assert!(validator.validate_token(garbage).is_err());
        }

        let token = validator
            .create_token(&hour_long_claims("test_issuer", "test_audience"))
            .unwrap();
        assert!(validator.validate_token(&token).is_ok());

        // Same secret, different expected issuer: must be rejected.
        let other_issuer = TokenValidator::new(
            &secret,
            "wrong_issuer".to_string(),
            "test_audience".to_string(),
        );
        assert!(other_issuer.validate_token(&token).is_err());
    }
}