pulseengine_mcp_security_middleware/
auth.rs1use crate::error::{SecurityError, SecurityResult};
4use crate::utils::{current_timestamp, secure_compare, validate_api_key_format};
5use chrono::{DateTime, Utc};
6use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use uuid::Uuid;
10
11#[derive(Debug, Clone)]
13pub struct AuthContext {
14 pub user_id: String,
16
17 pub roles: Vec<String>,
19
20 pub api_key: Option<String>,
22
23 pub jwt_claims: Option<JwtClaims>,
25
26 pub authenticated_at: DateTime<Utc>,
28
29 pub request_id: String,
31
32 pub metadata: HashMap<String, String>,
34}
35
36impl AuthContext {
37 pub fn new(user_id: String) -> Self {
39 Self {
40 user_id,
41 roles: Vec::new(),
42 api_key: None,
43 jwt_claims: None,
44 authenticated_at: Utc::now(),
45 request_id: crate::utils::generate_request_id(),
46 metadata: HashMap::new(),
47 }
48 }
49
50 pub fn with_role<S: Into<String>>(mut self, role: S) -> Self {
52 self.roles.push(role.into());
53 self
54 }
55
56 pub fn with_roles<I, S>(mut self, roles: I) -> Self
58 where
59 I: IntoIterator<Item = S>,
60 S: Into<String>,
61 {
62 self.roles.extend(roles.into_iter().map(|r| r.into()));
63 self
64 }
65
66 pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
68 self.api_key = Some(api_key.into());
69 self
70 }
71
72 pub fn with_jwt_claims(mut self, claims: JwtClaims) -> Self {
74 self.jwt_claims = Some(claims);
75 self
76 }
77
78 pub fn with_metadata<K, V>(mut self, key: K, value: V) -> Self
80 where
81 K: Into<String>,
82 V: Into<String>,
83 {
84 self.metadata.insert(key.into(), value.into());
85 self
86 }
87
88 pub fn has_role(&self, role: &str) -> bool {
90 self.roles.contains(&role.to_string())
91 }
92
93 pub fn has_any_role<I>(&self, roles: I) -> bool
95 where
96 I: IntoIterator,
97 I::Item: AsRef<str>,
98 {
99 for role in roles {
100 if self.has_role(role.as_ref()) {
101 return true;
102 }
103 }
104 false
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct JwtClaims {
111 pub sub: String,
113
114 pub exp: u64,
116
117 pub iat: u64,
119
120 pub nbf: Option<u64>,
122
123 pub jti: String,
125
126 pub iss: String,
128
129 pub aud: String,
131
132 pub roles: Option<Vec<String>>,
134
135 pub metadata: Option<HashMap<String, serde_json::Value>>,
137}
138
139impl JwtClaims {
140 pub fn new(user_id: String, issuer: String, audience: String, expires_in_seconds: u64) -> Self {
142 let now = current_timestamp();
143
144 Self {
145 sub: user_id,
146 exp: now + expires_in_seconds,
147 iat: now,
148 nbf: Some(now),
149 jti: Uuid::new_v4().to_string(),
150 iss: issuer,
151 aud: audience,
152 roles: None,
153 metadata: None,
154 }
155 }
156
157 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
159 self.roles = Some(roles);
160 self
161 }
162
163 pub fn is_expired(&self) -> bool {
165 current_timestamp() > self.exp
166 }
167}
168
169pub struct TokenValidator {
171 decoding_key: DecodingKey,
173
174 validation: Validation,
176
177 expected_issuer: String,
179
180 expected_audience: String,
182
183 secret: String,
185}
186
187impl std::fmt::Debug for TokenValidator {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 f.debug_struct("TokenValidator")
190 .field("expected_issuer", &self.expected_issuer)
191 .field("expected_audience", &self.expected_audience)
192 .field("secret", &"[REDACTED]")
193 .finish()
194 }
195}
196
197impl TokenValidator {
198 pub fn new(secret: &str, issuer: String, audience: String) -> Self {
200 let decoding_key = DecodingKey::from_secret(secret.as_bytes());
201
202 let mut validation = Validation::new(Algorithm::HS256);
203 validation.set_issuer(&[&issuer]);
204 validation.set_audience(&[&audience]);
205 validation.validate_exp = true;
206 validation.validate_nbf = true;
207
208 Self {
209 decoding_key,
210 validation,
211 expected_issuer: issuer,
212 expected_audience: audience,
213 secret: secret.to_string(),
214 }
215 }
216
217 pub fn validate_token(&self, token: &str) -> SecurityResult<JwtClaims> {
219 let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;
220
221 let claims = token_data.claims;
222
223 if claims.is_expired() {
225 return Err(SecurityError::TokenExpired);
226 }
227
228 if claims.iss != self.expected_issuer {
229 return Err(SecurityError::invalid_token("Invalid issuer"));
230 }
231
232 if claims.aud != self.expected_audience {
233 return Err(SecurityError::invalid_token("Invalid audience"));
234 }
235
236 Ok(claims)
237 }
238
239 pub fn create_token(&self, claims: &JwtClaims) -> SecurityResult<String> {
241 let encoding_key = EncodingKey::from_secret(self.secret.as_bytes());
242
243 let header = Header::new(Algorithm::HS256);
244
245 encode(&header, claims, &encoding_key).map_err(SecurityError::from)
246 }
247}
248
/// In-memory registry mapping hashed API keys to the user that owns them.
///
/// Keys are stored only as the output of `crate::utils::hash_api_key`,
/// never in plain text.
#[derive(Clone, Debug)]
pub struct ApiKeyValidator {
    /// `hash_api_key(key)` → user id.
    api_keys: HashMap<String, String>,
}
255
256impl ApiKeyValidator {
257 pub fn new() -> Self {
259 Self {
260 api_keys: HashMap::new(),
261 }
262 }
263
264 pub fn add_api_key(&mut self, api_key: &str, user_id: String) -> SecurityResult<()> {
266 validate_api_key_format(api_key)?;
267
268 let hash = crate::utils::hash_api_key(api_key);
269 self.api_keys.insert(hash, user_id);
270
271 Ok(())
272 }
273
274 pub fn validate_api_key(&self, api_key: &str) -> SecurityResult<String> {
276 validate_api_key_format(api_key)?;
277
278 let hash = crate::utils::hash_api_key(api_key);
279
280 for (stored_hash, user_id) in &self.api_keys {
282 if secure_compare(&hash, stored_hash) {
283 return Ok(user_id.clone());
284 }
285 }
286
287 Err(SecurityError::InvalidApiKey)
288 }
289
290 pub fn remove_api_key(&mut self, api_key: &str) -> SecurityResult<bool> {
292 validate_api_key_format(api_key)?;
293
294 let hash = crate::utils::hash_api_key(api_key);
295 Ok(self.api_keys.remove(&hash).is_some())
296 }
297
298 pub fn len(&self) -> usize {
300 self.api_keys.len()
301 }
302
303 pub fn is_empty(&self) -> bool {
305 self.api_keys.is_empty()
306 }
307}
308
309impl Default for ApiKeyValidator {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_context_creation() {
        let ctx = AuthContext::new("user123".to_string())
            .with_role("admin")
            .with_roles(vec!["user", "moderator"])
            .with_metadata("key", "value");

        assert_eq!(ctx.user_id, "user123");
        assert!(ctx.has_role("admin"));
        assert!(ctx.has_role("user"));
        assert!(ctx.has_role("moderator"));
        assert!(!ctx.has_role("guest"));
        assert!(ctx.has_any_role(&["admin", "guest"]));
        assert!(!ctx.has_any_role(&["guest", "visitor"]));
        assert_eq!(ctx.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_jwt_claims() {
        let claims = JwtClaims::new(
            "user123".to_string(),
            "test-issuer".to_string(),
            "test-audience".to_string(),
            3600,
        )
        .with_roles(vec!["admin".to_string()]);

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.iss, "test-issuer");
        assert_eq!(claims.aud, "test-audience");
        assert!(!claims.is_expired());
        assert_eq!(claims.roles, Some(vec!["admin".to_string()]));
    }

    #[test]
    fn test_api_key_validator() {
        let mut validator = ApiKeyValidator::new();
        let api_key = crate::utils::generate_api_key();

        // Register a freshly generated key.
        validator
            .add_api_key(&api_key, "user123".to_string())
            .unwrap();
        assert_eq!(validator.len(), 1);

        // The registered key resolves to its owner.
        let user_id = validator.validate_api_key(&api_key).unwrap();
        assert_eq!(user_id, "user123");

        // A different generated key was never registered and must be rejected.
        let invalid_key = crate::utils::generate_api_key();
        assert!(validator.validate_api_key(&invalid_key).is_err());

        // Removal reports success and leaves the registry empty.
        assert!(validator.remove_api_key(&api_key).unwrap());
        assert_eq!(validator.len(), 0);
        assert!(validator.is_empty());
    }

    #[test]
    fn test_token_validator() {
        let validator = TokenValidator::new(
            "test-secret",
            "test-issuer".to_string(),
            "test-audience".to_string(),
        );

        let claims = JwtClaims::new(
            "user123".to_string(),
            "test-issuer".to_string(),
            "test-audience".to_string(),
            3600,
        );

        // Round-trip: sign then verify with the same validator.
        let token = validator.create_token(&claims).unwrap();
        let validated_claims = validator.validate_token(&token).unwrap();

        assert_eq!(validated_claims.sub, "user123");
        assert_eq!(validated_claims.iss, "test-issuer");
        assert_eq!(validated_claims.aud, "test-audience");
    }

    #[test]
    fn test_auth_context_additional_methods() {
        let context = AuthContext::new("user123".to_string())
            .with_roles(vec!["admin".to_string(), "user".to_string()]);

        assert!(context.has_role("admin"));
        assert!(context.has_role("user"));
        assert!(!context.has_role("guest"));

        // has_any_role accepts arrays by value as well as by reference.
        assert!(context.has_any_role(["admin", "guest"]));
        assert!(context.has_any_role(["user", "guest"]));
        assert!(!context.has_any_role(["guest", "moderator"]));

        let context = AuthContext::new("user123".to_string())
            .with_metadata("department", "engineering")
            .with_metadata("level", "senior");

        assert_eq!(context.metadata.get("department").unwrap(), "engineering");
        assert_eq!(context.metadata.get("level").unwrap(), "senior");
    }

    #[test]
    fn test_jwt_claims_expiration() {
        use std::time::{SystemTime, UNIX_EPOCH};

        let claims = JwtClaims::new(
            "user123".to_string(),
            "test_issuer".to_string(),
            "test_audience".to_string(),
            3600, // one hour
        );
        assert!(!claims.is_expired());

        // Move the expiry one hour into the past.
        let mut expired_claims = claims;
        expired_claims.exp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 3600;
        assert!(expired_claims.is_expired());
    }

    #[test]
    fn test_token_validator_edge_cases() {
        use crate::utils::generate_jwt_secret;

        let secret = generate_jwt_secret();
        let validator = TokenValidator::new(
            &secret,
            "test_issuer".to_string(),
            "test_audience".to_string(),
        );

        // Malformed inputs are rejected.
        assert!(validator.validate_token("invalid.token").is_err());
        assert!(validator.validate_token("").is_err());
        assert!(validator.validate_token("not_a_jwt").is_err());

        let valid_claims = JwtClaims::new(
            "user123".to_string(),
            "test_issuer".to_string(),
            "test_audience".to_string(),
            3600, // one hour
        );

        let token = validator.create_token(&valid_claims).unwrap();
        assert!(validator.validate_token(&token).is_ok());

        // Same secret, different expected issuer.
        let wrong_issuer = TokenValidator::new(
            &secret,
            "wrong_issuer".to_string(),
            "test_audience".to_string(),
        );
        assert!(wrong_issuer.validate_token(&token).is_err());
    }

    #[test]
    fn test_token_validator_rejects_wrong_audience_and_secret() {
        let signer = TokenValidator::new(
            "test-secret",
            "test-issuer".to_string(),
            "test-audience".to_string(),
        );
        let claims = JwtClaims::new(
            "user123".to_string(),
            "test-issuer".to_string(),
            "test-audience".to_string(),
            3600,
        );
        let token = signer.create_token(&claims).unwrap();

        // Same secret and issuer, different expected audience.
        let wrong_audience = TokenValidator::new(
            "test-secret",
            "test-issuer".to_string(),
            "other-audience".to_string(),
        );
        assert!(wrong_audience.validate_token(&token).is_err());

        // Matching issuer/audience, but the token was signed with another secret.
        let wrong_secret = TokenValidator::new(
            "a-different-secret",
            "test-issuer".to_string(),
            "test-audience".to_string(),
        );
        assert!(wrong_secret.validate_token(&token).is_err());
    }

    #[test]
    fn test_token_validator_rejects_expired_token() {
        let validator = TokenValidator::new(
            "test-secret",
            "test-issuer".to_string(),
            "test-audience".to_string(),
        );
        let mut claims = JwtClaims::new(
            "user123".to_string(),
            "test-issuer".to_string(),
            "test-audience".to_string(),
            3600,
        );
        // Expire the token an hour before it was issued — well beyond any clock-skew leeway.
        claims.exp = claims.iat - 3600;

        let token = validator.create_token(&claims).unwrap();
        assert!(validator.validate_token(&token).is_err());
    }
}