pulseengine_mcp_security_middleware/
auth.rs1use crate::error::{SecurityError, SecurityResult};
4use crate::utils::{current_timestamp, secure_compare, validate_api_key_format};
5use chrono::{DateTime, Utc};
6use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use uuid::Uuid;
10
11#[derive(Debug, Clone)]
13pub struct AuthContext {
14    pub user_id: String,
16
17    pub roles: Vec<String>,
19
20    pub api_key: Option<String>,
22
23    pub jwt_claims: Option<JwtClaims>,
25
26    pub authenticated_at: DateTime<Utc>,
28
29    pub request_id: String,
31
32    pub metadata: HashMap<String, String>,
34}
35
36impl AuthContext {
37    pub fn new(user_id: String) -> Self {
39        Self {
40            user_id,
41            roles: Vec::new(),
42            api_key: None,
43            jwt_claims: None,
44            authenticated_at: Utc::now(),
45            request_id: crate::utils::generate_request_id(),
46            metadata: HashMap::new(),
47        }
48    }
49
50    pub fn with_role<S: Into<String>>(mut self, role: S) -> Self {
52        self.roles.push(role.into());
53        self
54    }
55
56    pub fn with_roles<I, S>(mut self, roles: I) -> Self
58    where
59        I: IntoIterator<Item = S>,
60        S: Into<String>,
61    {
62        self.roles.extend(roles.into_iter().map(|r| r.into()));
63        self
64    }
65
66    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
68        self.api_key = Some(api_key.into());
69        self
70    }
71
72    pub fn with_jwt_claims(mut self, claims: JwtClaims) -> Self {
74        self.jwt_claims = Some(claims);
75        self
76    }
77
78    pub fn with_metadata<K, V>(mut self, key: K, value: V) -> Self
80    where
81        K: Into<String>,
82        V: Into<String>,
83    {
84        self.metadata.insert(key.into(), value.into());
85        self
86    }
87
88    pub fn has_role(&self, role: &str) -> bool {
90        self.roles.contains(&role.to_string())
91    }
92
93    pub fn has_any_role<I>(&self, roles: I) -> bool
95    where
96        I: IntoIterator,
97        I::Item: AsRef<str>,
98    {
99        for role in roles {
100            if self.has_role(role.as_ref()) {
101                return true;
102            }
103        }
104        false
105    }
106}
107
/// JWT payload using the RFC 7519 registered claim names, plus optional
/// application-specific `roles` and `metadata`.
///
/// The timestamp fields (`exp`, `iat`, `nbf`) are compared against
/// `current_timestamp()`; the tests in this file treat them as Unix seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject: the user the token was issued for.
    pub sub: String,

    /// Expiration time.
    pub exp: u64,

    /// Issued-at time.
    pub iat: u64,

    /// Not-before time.
    /// NOTE(review): with the default serde derive, `None` serializes as `null`
    /// rather than omitting the claim.
    pub nbf: Option<u64>,

    /// Unique token identifier (a UUID v4 string when built via `JwtClaims::new`).
    pub jti: String,

    /// Issuer.
    pub iss: String,

    /// Audience (a single string, not a list).
    pub aud: String,

    /// Optional roles granted to the subject.
    pub roles: Option<Vec<String>>,

    /// Optional free-form JSON metadata.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
138
139impl JwtClaims {
140    pub fn new(user_id: String, issuer: String, audience: String, expires_in_seconds: u64) -> Self {
142        let now = current_timestamp();
143
144        Self {
145            sub: user_id,
146            exp: now + expires_in_seconds,
147            iat: now,
148            nbf: Some(now),
149            jti: Uuid::new_v4().to_string(),
150            iss: issuer,
151            aud: audience,
152            roles: None,
153            metadata: None,
154        }
155    }
156
157    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
159        self.roles = Some(roles);
160        self
161    }
162
163    pub fn is_expired(&self) -> bool {
165        current_timestamp() > self.exp
166    }
167}
168
/// Issues and validates HS256-signed JWTs for a single issuer/audience pair.
///
/// Holds the shared secret (used by `create_token`) alongside the pre-built
/// decoding key and validation rules used by `validate_token`. The hand-written
/// `Debug` impl prints `[REDACTED]` in place of the secret.
pub struct TokenValidator {
    /// Key built from the shared secret; used to verify token signatures.
    decoding_key: DecodingKey,

    /// `jsonwebtoken` validation rules: algorithm, issuer, audience, exp/nbf checks.
    validation: Validation,

    /// Issuer (`iss`) a token must carry to be accepted.
    expected_issuer: String,

    /// Audience (`aud`) a token must carry to be accepted.
    expected_audience: String,

    /// Raw shared secret, kept so `create_token` can build an encoding key.
    secret: String,
}
186
187impl std::fmt::Debug for TokenValidator {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        f.debug_struct("TokenValidator")
190            .field("expected_issuer", &self.expected_issuer)
191            .field("expected_audience", &self.expected_audience)
192            .field("secret", &"[REDACTED]")
193            .finish()
194    }
195}
196
197impl TokenValidator {
198    pub fn new(secret: &str, issuer: String, audience: String) -> Self {
200        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
201
202        let mut validation = Validation::new(Algorithm::HS256);
203        validation.set_issuer(&[&issuer]);
204        validation.set_audience(&[&audience]);
205        validation.validate_exp = true;
206        validation.validate_nbf = true;
207
208        Self {
209            decoding_key,
210            validation,
211            expected_issuer: issuer,
212            expected_audience: audience,
213            secret: secret.to_string(),
214        }
215    }
216
217    pub fn validate_token(&self, token: &str) -> SecurityResult<JwtClaims> {
219        let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;
220
221        let claims = token_data.claims;
222
223        if claims.is_expired() {
225            return Err(SecurityError::TokenExpired);
226        }
227
228        if claims.iss != self.expected_issuer {
229            return Err(SecurityError::invalid_token("Invalid issuer"));
230        }
231
232        if claims.aud != self.expected_audience {
233            return Err(SecurityError::invalid_token("Invalid audience"));
234        }
235
236        Ok(claims)
237    }
238
239    pub fn create_token(&self, claims: &JwtClaims) -> SecurityResult<String> {
241        let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
242
243        let header = Header::new(Algorithm::HS256);
244
245        encode(&header, claims, &encoding_key).map_err(SecurityError::from)
246    }
247}
248
/// In-memory registry that maps API keys to user IDs.
///
/// Keys are not stored in plain text: the map is keyed by the output of
/// `crate::utils::hash_api_key`.
#[derive(Debug, Clone)]
pub struct ApiKeyValidator {
    /// Hashed API key -> user ID.
    api_keys: HashMap<String, String>,
}
255
256impl ApiKeyValidator {
257    pub fn new() -> Self {
259        Self {
260            api_keys: HashMap::new(),
261        }
262    }
263
264    pub fn add_api_key(&mut self, api_key: &str, user_id: String) -> SecurityResult<()> {
266        validate_api_key_format(api_key)?;
267
268        let hash = crate::utils::hash_api_key(api_key);
269        self.api_keys.insert(hash, user_id);
270
271        Ok(())
272    }
273
274    pub fn validate_api_key(&self, api_key: &str) -> SecurityResult<String> {
276        validate_api_key_format(api_key)?;
277
278        let hash = crate::utils::hash_api_key(api_key);
279
280        for (stored_hash, user_id) in &self.api_keys {
282            if secure_compare(&hash, stored_hash) {
283                return Ok(user_id.clone());
284            }
285        }
286
287        Err(SecurityError::InvalidApiKey)
288    }
289
290    pub fn remove_api_key(&mut self, api_key: &str) -> SecurityResult<bool> {
292        validate_api_key_format(api_key)?;
293
294        let hash = crate::utils::hash_api_key(api_key);
295        Ok(self.api_keys.remove(&hash).is_some())
296    }
297
298    pub fn len(&self) -> usize {
300        self.api_keys.len()
301    }
302
303    pub fn is_empty(&self) -> bool {
305        self.api_keys.is_empty()
306    }
307}
308
309impl Default for ApiKeyValidator {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Claims for `user123` with the given issuer/audience, valid for one hour.
    fn hour_claims(issuer: &str, audience: &str) -> JwtClaims {
        JwtClaims::new(
            "user123".to_string(),
            issuer.to_string(),
            audience.to_string(),
            3600,
        )
    }

    #[test]
    fn test_auth_context_creation() {
        let ctx = AuthContext::new("user123".to_string())
            .with_role("admin")
            .with_roles(vec!["user", "moderator"])
            .with_metadata("key", "value");

        assert_eq!(ctx.user_id, "user123");
        for granted in ["admin", "user", "moderator"] {
            assert!(ctx.has_role(granted));
        }
        assert!(!ctx.has_role("guest"));
        assert!(ctx.has_any_role(&["admin", "guest"]));
        assert!(!ctx.has_any_role(&["guest", "visitor"]));
        assert_eq!(ctx.metadata.get("key").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_jwt_claims() {
        let claims =
            hour_claims("test-issuer", "test-audience").with_roles(vec!["admin".to_string()]);

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.iss, "test-issuer");
        assert_eq!(claims.aud, "test-audience");
        assert!(!claims.is_expired());
        assert_eq!(claims.roles, Some(vec!["admin".to_string()]));
    }

    #[test]
    fn test_api_key_validator() {
        let mut validator = ApiKeyValidator::new();
        let registered = crate::utils::generate_api_key();

        validator
            .add_api_key(&registered, "user123".to_string())
            .unwrap();
        assert_eq!(validator.len(), 1);
        assert_eq!(validator.validate_api_key(&registered).unwrap(), "user123");

        // A freshly generated key was never registered and must be rejected.
        let unknown = crate::utils::generate_api_key();
        assert!(validator.validate_api_key(&unknown).is_err());

        assert!(validator.remove_api_key(&registered).unwrap());
        assert_eq!(validator.len(), 0);
        assert!(validator.is_empty());
    }

    #[test]
    fn test_token_validator() {
        let validator = TokenValidator::new(
            "test-secret",
            "test-issuer".to_string(),
            "test-audience".to_string(),
        );

        let token = validator
            .create_token(&hour_claims("test-issuer", "test-audience"))
            .unwrap();
        let decoded = validator.validate_token(&token).unwrap();

        assert_eq!(decoded.sub, "user123");
        assert_eq!(decoded.iss, "test-issuer");
        assert_eq!(decoded.aud, "test-audience");
    }

    #[test]
    fn test_auth_context_additional_methods() {
        let role_ctx = AuthContext::new("user123".to_string())
            .with_roles(vec!["admin".to_string(), "user".to_string()]);

        assert!(role_ctx.has_role("admin"));
        assert!(role_ctx.has_role("user"));
        assert!(!role_ctx.has_role("guest"));

        assert!(role_ctx.has_any_role(["admin", "guest"]));
        assert!(role_ctx.has_any_role(["user", "guest"]));
        assert!(!role_ctx.has_any_role(["guest", "moderator"]));

        let meta_ctx = AuthContext::new("user123".to_string())
            .with_metadata("department", "engineering")
            .with_metadata("level", "senior");

        assert_eq!(meta_ctx.metadata.get("department").unwrap(), "engineering");
        assert_eq!(meta_ctx.metadata.get("level").unwrap(), "senior");
    }

    #[test]
    fn test_jwt_claims_expiration() {
        use std::time::{SystemTime, UNIX_EPOCH};

        let mut claims = hour_claims("test_issuer", "test_audience");
        assert!(!claims.is_expired());

        // Move the expiry one hour into the past.
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        claims.exp = now_secs - 3600;
        assert!(claims.is_expired());
    }

    #[test]
    fn test_token_validator_edge_cases() {
        use crate::utils::generate_jwt_secret;

        let secret = generate_jwt_secret();
        let validator = TokenValidator::new(
            &secret,
            "test_issuer".to_string(),
            "test_audience".to_string(),
        );

        // Inputs that are not well-formed JWTs must be rejected.
        for garbage in ["invalid.token", "", "not_a_jwt"] {
            assert!(validator.validate_token(garbage).is_err());
        }

        let token = validator
            .create_token(&hour_claims("test_issuer", "test_audience"))
            .unwrap();
        assert!(validator.validate_token(&token).is_ok());

        // Same key but a different expected issuer: the token must not validate.
        let wrong_issuer = TokenValidator::new(
            &secret,
            "wrong_issuer".to_string(),
            "test_audience".to_string(),
        );
        assert!(wrong_issuer.validate_token(&token).is_err());
    }
}