1use crate::security::{audit::AuditLogger, SecurityError};
7use anyhow::{anyhow, Result};
8use chrono::{Duration, Utc};
9use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::{HashMap, HashSet};
13use std::env;
14use std::sync::Arc;
15use tracing::{debug, error, warn};
16use uuid::Uuid;
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct Claims {
21 pub sub: String, pub client_id: String, pub scope: Vec<String>, pub iat: i64, pub exp: i64, pub jti: String, }
28
/// Which kind of credential authenticated a request.
///
/// The enum carries no data, so it is `Copy`, and it derives `Eq`/`Hash`
/// alongside `PartialEq` so it can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    /// `Authorization: ApiKey <key>` or `X-Api-Key` header.
    ApiKey,
    /// `Authorization: Bearer <jwt>` header.
    JwtToken,
    /// `X-Client-Cert-Thumbprint` header matched against the allow-list.
    Certificate,
    /// No credential. Not produced by the validators in this module;
    /// presumably used by callers when authentication is disabled — verify.
    None,
}
37
38#[derive(Debug, Clone)]
40pub struct AuthContext {
41 pub client_id: String,
42 pub user_id: String,
43 pub method: AuthMethod,
44 pub scopes: Vec<String>,
45 pub expires_at: Option<chrono::DateTime<Utc>>,
46 pub request_id: String,
47}
48
49#[derive(Debug, Clone)]
51pub struct MCPAuthConfig {
52 pub enabled: bool,
53 pub jwt_secret: String,
54 pub jwt_expiry_seconds: u64,
55 pub api_keys: HashMap<String, ApiKeyInfo>,
56 pub allowed_certificates: HashSet<String>,
57 pub require_scope: Vec<String>,
58 pub performance_target_ms: u64,
59}
60
61#[derive(Debug, Clone)]
63pub struct ApiKeyInfo {
64 pub client_id: String,
65 pub scopes: Vec<String>,
66 pub expires_at: Option<chrono::DateTime<Utc>>,
67 pub last_used: Option<chrono::DateTime<Utc>>,
68 pub usage_count: u64,
69}
70
71impl Default for MCPAuthConfig {
72 fn default() -> Self {
73 Self {
74 enabled: false,
75 jwt_secret: env::var("MCP_JWT_SECRET").unwrap_or_else(|_| {
76 "change-me-in-production-super-secret-key-minimum-32-chars".to_string()
77 }),
78 jwt_expiry_seconds: env::var("MCP_JWT_EXPIRY_SECONDS")
79 .ok()
80 .and_then(|s| s.parse().ok())
81 .unwrap_or(3600), api_keys: Self::load_api_keys_from_env(),
83 allowed_certificates: Self::load_certificates_from_env(),
84 require_scope: vec!["mcp:read".to_string(), "mcp:write".to_string()],
85 performance_target_ms: 5, }
87 }
88}
89
90impl MCPAuthConfig {
91 fn load_api_keys_from_env() -> HashMap<String, ApiKeyInfo> {
93 let mut api_keys = HashMap::new();
94
95 if let Ok(keys_json) = env::var("MCP_API_KEYS") {
97 match serde_json::from_str::<HashMap<String, Value>>(&keys_json) {
98 Ok(keys) => {
99 for (key, info) in keys {
100 if let Ok(client_id) = info
101 .get("client_id")
102 .and_then(|v| v.as_str())
103 .ok_or("Missing client_id")
104 {
105 let scopes = info
106 .get("scopes")
107 .and_then(|v| v.as_array())
108 .map(|arr| {
109 arr.iter()
110 .filter_map(|s| s.as_str().map(String::from))
111 .collect()
112 })
113 .unwrap_or_else(|| {
114 vec!["mcp:read".to_string(), "mcp:write".to_string()]
115 });
116
117 let expires_at = info
118 .get("expires_at")
119 .and_then(|v| v.as_str())
120 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
121 .map(|dt| dt.with_timezone(&Utc));
122
123 api_keys.insert(
124 key,
125 ApiKeyInfo {
126 client_id: client_id.to_string(),
127 scopes,
128 expires_at,
129 last_used: None,
130 usage_count: 0,
131 },
132 );
133 }
134 }
135 }
136 Err(e) => {
137 warn!("Failed to parse MCP_API_KEYS: {}", e);
138 }
139 }
140 }
141
142 if api_keys.is_empty() {
144 if let Ok(api_key) = env::var("MCP_API_KEY") {
145 let client_id =
146 env::var("MCP_CLIENT_ID").unwrap_or_else(|_| "default-client".to_string());
147
148 api_keys.insert(
149 api_key,
150 ApiKeyInfo {
151 client_id,
152 scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
153 expires_at: None,
154 last_used: None,
155 usage_count: 0,
156 },
157 );
158 }
159 }
160
161 api_keys
162 }
163
164 fn load_certificates_from_env() -> HashSet<String> {
166 let mut certs = HashSet::new();
167
168 if let Ok(cert_thumbprints) = env::var("MCP_ALLOWED_CERTS") {
169 for thumbprint in cert_thumbprints.split(',') {
170 certs.insert(thumbprint.trim().to_string());
171 }
172 }
173
174 certs
175 }
176
177 pub fn from_env() -> Self {
179 let is_production = env::var("ENVIRONMENT")
181 .unwrap_or_else(|_| "development".to_string())
182 .to_lowercase()
183 == "production";
184
185 let auth_enabled = env::var("MCP_AUTH_ENABLED")
187 .map(|s| s.parse().unwrap_or(true))
188 .unwrap_or(true);
189
190 if is_production && !auth_enabled {
192 eprintln!("WARNING: Authentication is disabled in production environment! This is a security risk.");
193 }
194
195 Self {
196 enabled: auth_enabled,
197 ..Self::default()
198 }
199 }
200}
201
202pub struct MCPAuth {
204 config: MCPAuthConfig,
205 encoding_key: EncodingKey,
206 decoding_key: DecodingKey,
207 audit_logger: Arc<AuditLogger>,
208 revoked_tokens: Arc<tokio::sync::RwLock<HashSet<String>>>,
209}
210
211impl MCPAuth {
212 pub fn new(config: MCPAuthConfig, audit_logger: Arc<AuditLogger>) -> Result<Self> {
214 let encoding_key = EncodingKey::from_secret(config.jwt_secret.as_bytes());
215 let decoding_key = DecodingKey::from_secret(config.jwt_secret.as_bytes());
216
217 Ok(Self {
218 config,
219 encoding_key,
220 decoding_key,
221 audit_logger,
222 revoked_tokens: Arc::new(tokio::sync::RwLock::new(HashSet::new())),
223 })
224 }
225
226 pub async fn authenticate_request(
228 &self,
229 method: &str,
230 params: Option<&Value>,
231 headers: &HashMap<String, String>,
232 ) -> Result<Option<AuthContext>> {
233 let start_time = std::time::Instant::now();
234
235 if !self.config.enabled {
237 return Ok(None);
238 }
239
240 let request_id = Uuid::new_v4().to_string();
241
242 let auth_result = if let Some(auth_header) = headers.get("authorization") {
244 if let Some(token) = auth_header.strip_prefix("Bearer ") {
245 self.validate_jwt_token(token, &request_id).await
246 } else if let Some(api_key) = auth_header.strip_prefix("ApiKey ") {
247 self.validate_api_key(api_key, &request_id).await
248 } else {
249 Err(anyhow!("Invalid authorization header format"))
250 }
251 } else if let Some(cert_thumbprint) = headers.get("x-client-cert-thumbprint") {
252 self.validate_certificate(cert_thumbprint, &request_id)
253 .await
254 } else if let Some(api_key) = headers.get("x-api-key") {
255 self.validate_api_key(api_key, &request_id).await
256 } else {
257 Err(anyhow!("No authentication credentials provided"))
258 };
259
260 let elapsed = start_time.elapsed();
261
262 if elapsed.as_millis() > self.config.performance_target_ms as u128 {
264 warn!(
265 "Authentication took {}ms, exceeding target of {}ms",
266 elapsed.as_millis(),
267 self.config.performance_target_ms
268 );
269 }
270
271 match auth_result {
272 Ok(context) => {
273 debug!(
274 "Authentication successful for client: {}",
275 context.client_id
276 );
277
278 self.audit_logger
280 .log_auth_event(&context.client_id, &context.user_id, method, true, None)
281 .await;
282
283 Ok(Some(context))
284 }
285 Err(e) => {
286 error!("Authentication failed: {}", e);
287
288 let client_id = headers
290 .get("x-client-id")
291 .or_else(|| headers.get("client-id"))
292 .map(|s| s.as_str())
293 .unwrap_or("unknown");
294
295 self.audit_logger
296 .log_auth_event(client_id, "unknown", method, false, Some(&e.to_string()))
297 .await;
298
299 Err(SecurityError::AuthenticationFailed {
300 message: e.to_string(),
301 }
302 .into())
303 }
304 }
305 }
306
307 async fn validate_jwt_token(&self, token: &str, request_id: &str) -> Result<AuthContext> {
309 {
311 let revoked = self.revoked_tokens.read().await;
312 if revoked.contains(token) {
313 return Err(anyhow!("Token has been revoked"));
314 }
315 }
316
317 let mut validation = Validation::new(Algorithm::HS256);
318 validation.set_required_spec_claims(&["sub", "exp", "iat"]);
319
320 let token_data = decode::<Claims>(token, &self.decoding_key, &validation)
321 .map_err(|e| anyhow!("Invalid JWT token: {}", e))?;
322
323 let claims = token_data.claims;
324
325 let now = Utc::now().timestamp();
327 if claims.exp < now {
328 return Err(anyhow!("Token has expired"));
329 }
330
331 if !self.has_required_scopes(&claims.scope) {
333 return Err(anyhow!("Insufficient permissions"));
334 }
335
336 Ok(AuthContext {
337 client_id: claims.client_id,
338 user_id: claims.sub,
339 method: AuthMethod::JwtToken,
340 scopes: claims.scope,
341 expires_at: chrono::DateTime::from_timestamp(claims.exp, 0),
342 request_id: request_id.to_string(),
343 })
344 }
345
346 async fn validate_api_key(&self, api_key: &str, request_id: &str) -> Result<AuthContext> {
348 let api_key_info = self
349 .config
350 .api_keys
351 .get(api_key)
352 .ok_or_else(|| anyhow!("Invalid API key"))?;
353
354 if let Some(expires_at) = api_key_info.expires_at {
356 if Utc::now() > expires_at {
357 return Err(anyhow!("API key has expired"));
358 }
359 }
360
361 if !self.has_required_scopes(&api_key_info.scopes) {
363 return Err(anyhow!("Insufficient permissions"));
364 }
365
366 Ok(AuthContext {
367 client_id: api_key_info.client_id.clone(),
368 user_id: api_key_info.client_id.clone(), method: AuthMethod::ApiKey,
370 scopes: api_key_info.scopes.clone(),
371 expires_at: api_key_info.expires_at,
372 request_id: request_id.to_string(),
373 })
374 }
375
376 async fn validate_certificate(
378 &self,
379 thumbprint: &str,
380 request_id: &str,
381 ) -> Result<AuthContext> {
382 if !self.config.allowed_certificates.contains(thumbprint) {
383 return Err(anyhow!("Certificate not allowed"));
384 }
385
386 Ok(AuthContext {
389 client_id: format!("cert-{thumbprint}"),
390 user_id: format!("cert-{thumbprint}"),
391 method: AuthMethod::Certificate,
392 scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
393 expires_at: None, request_id: request_id.to_string(),
395 })
396 }
397
398 fn has_required_scopes(&self, provided_scopes: &[String]) -> bool {
400 if self.config.require_scope.is_empty() {
401 return true;
402 }
403
404 self.config
405 .require_scope
406 .iter()
407 .all(|required| provided_scopes.contains(required))
408 }
409
410 pub async fn generate_token(
412 &self,
413 client_id: &str,
414 user_id: &str,
415 scopes: Vec<String>,
416 ) -> Result<String> {
417 let now = Utc::now();
418 let exp = now + Duration::seconds(self.config.jwt_expiry_seconds as i64);
419
420 let claims = Claims {
421 sub: user_id.to_string(),
422 client_id: client_id.to_string(),
423 scope: scopes,
424 iat: now.timestamp(),
425 exp: exp.timestamp(),
426 jti: Uuid::new_v4().to_string(),
427 };
428
429 encode(&Header::default(), &claims, &self.encoding_key)
430 .map_err(|e| anyhow!("Failed to generate token: {}", e))
431 }
432
433 pub async fn revoke_token(&self, token: &str) -> Result<()> {
435 let mut revoked = self.revoked_tokens.write().await;
436 revoked.insert(token.to_string());
437 debug!("Token revoked");
438 Ok(())
439 }
440
441 pub fn validate_tool_access(&self, context: &AuthContext, tool_name: &str) -> Result<()> {
443 let required_scope = match tool_name {
445 "store_memory" | "harvest_conversation" | "migrate_memory" | "delete_memory" => {
446 "mcp:write"
447 }
448 "search_memory"
449 | "get_statistics"
450 | "what_did_you_remember"
451 | "get_harvester_metrics" => "mcp:read",
452 _ => "mcp:read", };
454
455 if !context.scopes.contains(&required_scope.to_string()) {
456 return Err(SecurityError::AuthorizationFailed {
457 message: format!("Tool '{tool_name}' requires '{required_scope}' scope"),
458 }
459 .into());
460 }
461
462 Ok(())
463 }
464
465 pub async fn get_stats(&self) -> serde_json::Value {
467 let revoked_count = self.revoked_tokens.read().await.len();
468
469 serde_json::json!({
470 "enabled": self.config.enabled,
471 "api_keys_configured": self.config.api_keys.len(),
472 "certificates_allowed": self.config.allowed_certificates.len(),
473 "revoked_tokens": revoked_count,
474 "performance_target_ms": self.config.performance_target_ms,
475 "jwt_expiry_seconds": self.config.jwt_expiry_seconds,
476 })
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::AuditConfig;
    use std::collections::HashMap;

    /// Audit logger with every category enabled, shared by all tests.
    fn create_test_audit_logger() -> Arc<AuditLogger> {
        let audit_config = AuditConfig {
            enabled: true,
            log_all_requests: true,
            log_data_access: true,
            log_modifications: true,
            log_auth_events: true,
            retention_days: 30,
        };
        Arc::new(AuditLogger::new(audit_config).unwrap())
    }

    fn full_scopes() -> Vec<String> {
        vec!["mcp:read".to_string(), "mcp:write".to_string()]
    }

    fn create_test_config() -> MCPAuthConfig {
        let mut api_keys = HashMap::new();
        api_keys.insert(
            "test-key-123".to_string(),
            ApiKeyInfo {
                client_id: "test-client".to_string(),
                scopes: full_scopes(),
                expires_at: None,
                last_used: None,
                usage_count: 0,
            },
        );

        let mut certs = HashSet::new();
        certs.insert("abc123def456".to_string());

        MCPAuthConfig {
            enabled: true,
            jwt_secret: "test-secret-key-minimum-32-characters-long".to_string(),
            jwt_expiry_seconds: 3600,
            api_keys,
            allowed_certificates: certs,
            require_scope: vec!["mcp:read".to_string()],
            performance_target_ms: 5,
        }
    }

    async fn create_test_auth() -> MCPAuth {
        MCPAuth::new(create_test_config(), create_test_audit_logger()).unwrap()
    }

    #[tokio::test]
    async fn test_api_key_authentication() {
        let auth = create_test_auth().await;

        let mut headers = HashMap::new();
        headers.insert(
            "authorization".to_string(),
            "ApiKey test-key-123".to_string(),
        );

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        let context = result.unwrap().unwrap();
        assert_eq!(context.client_id, "test-client");
        assert_eq!(context.method, AuthMethod::ApiKey);
        assert!(context.scopes.contains(&"mcp:read".to_string()));
    }

    #[tokio::test]
    async fn test_jwt_authentication() {
        let auth = create_test_auth().await;

        let token = auth
            .generate_token("test-client", "test-user", full_scopes())
            .await
            .unwrap();

        let mut headers = HashMap::new();
        headers.insert("authorization".to_string(), format!("Bearer {token}"));

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        let context = result.unwrap().unwrap();
        assert_eq!(context.client_id, "test-client");
        assert_eq!(context.user_id, "test-user");
        assert_eq!(context.method, AuthMethod::JwtToken);
    }

    #[tokio::test]
    async fn test_jwt_signed_with_other_secret_rejected() {
        let mut other_config = create_test_config();
        other_config.jwt_secret = "a-completely-different-secret-of-32-plus-chars".to_string();
        let other = MCPAuth::new(other_config, create_test_audit_logger()).unwrap();
        let token = other
            .generate_token("test-client", "test-user", full_scopes())
            .await
            .unwrap();

        let auth = create_test_auth().await;
        let mut headers = HashMap::new();
        headers.insert("authorization".to_string(), format!("Bearer {token}"));

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_certificate_authentication() {
        let auth = create_test_auth().await;

        let mut headers = HashMap::new();
        headers.insert(
            "x-client-cert-thumbprint".to_string(),
            "abc123def456".to_string(),
        );

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        let context = result.unwrap().unwrap();
        assert_eq!(context.client_id, "cert-abc123def456");
        assert_eq!(context.method, AuthMethod::Certificate);
    }

    #[tokio::test]
    async fn test_invalid_api_key() {
        let auth = create_test_auth().await;

        let mut headers = HashMap::new();
        headers.insert(
            "authorization".to_string(),
            "ApiKey invalid-key".to_string(),
        );

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_missing_credentials_rejected() {
        let auth = create_test_auth().await;

        let headers = HashMap::new();
        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_expired_api_key_rejected() {
        let mut config = create_test_config();
        config.api_keys.insert(
            "expired-key".to_string(),
            ApiKeyInfo {
                client_id: "old-client".to_string(),
                scopes: full_scopes(),
                expires_at: Some(Utc::now() - Duration::seconds(60)),
                last_used: None,
                usage_count: 0,
            },
        );
        let auth = MCPAuth::new(config, create_test_audit_logger()).unwrap();

        let mut headers = HashMap::new();
        headers.insert("x-api-key".to_string(), "expired-key".to_string());

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_missing_required_scope_rejected() {
        let mut config = create_test_config();
        config.api_keys.insert(
            "write-only-key".to_string(),
            ApiKeyInfo {
                client_id: "writer".to_string(),
                scopes: vec!["mcp:write".to_string()],
                expires_at: None,
                last_used: None,
                usage_count: 0,
            },
        );
        let auth = MCPAuth::new(config, create_test_audit_logger()).unwrap();

        let mut headers = HashMap::new();
        headers.insert("x-api-key".to_string(), "write-only-key".to_string());

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tool_access_validation() {
        let auth = create_test_auth().await;

        let context = AuthContext {
            client_id: "test-client".to_string(),
            user_id: "test-user".to_string(),
            method: AuthMethod::ApiKey,
            scopes: vec!["mcp:read".to_string()],
            expires_at: None,
            request_id: "test-request".to_string(),
        };

        // Read-only tools are allowed with mcp:read.
        assert!(auth.validate_tool_access(&context, "search_memory").is_ok());
        assert!(auth
            .validate_tool_access(&context, "get_statistics")
            .is_ok());

        // Mutating tools need mcp:write.
        assert!(auth.validate_tool_access(&context, "store_memory").is_err());
        assert!(auth
            .validate_tool_access(&context, "delete_memory")
            .is_err());
    }

    #[tokio::test]
    async fn test_token_revocation() {
        let auth = create_test_auth().await;

        let token = auth
            .generate_token("test-client", "test-user", vec!["mcp:read".to_string()])
            .await
            .unwrap();

        let mut headers = HashMap::new();
        headers.insert("authorization".to_string(), format!("Bearer {token}"));

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        auth.revoke_token(&token).await.unwrap();

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_disabled_authentication() {
        let mut config = create_test_config();
        config.enabled = false;
        let auth = MCPAuth::new(config, create_test_audit_logger()).unwrap();

        let headers = HashMap::new();
        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }
}