codex_memory/mcp_server/
auth.rs

1//! MCP Authentication Middleware
2//!
3//! This module provides authentication middleware for MCP requests,
4//! supporting API keys, JWT tokens, and certificate-based authentication.
5
6use crate::security::{audit::AuditLogger, SecurityError};
7use anyhow::{anyhow, Result};
8use chrono::{Duration, Utc};
9use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::{HashMap, HashSet};
13use std::env;
14use std::sync::Arc;
15use tracing::{debug, error, warn};
16use uuid::Uuid;
17
18/// JWT claims structure
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct Claims {
21    pub sub: String,        // Subject (user ID)
22    pub client_id: String,  // Client identifier
23    pub scope: Vec<String>, // Permissions/scopes
24    pub iat: i64,           // Issued at
25    pub exp: i64,           // Expiration time
26    pub jti: String,        // JWT ID (for revocation)
27}
28
/// Which kind of credential authenticated a request.
#[derive(Clone, Debug, PartialEq)]
pub enum AuthMethod {
    /// `Authorization: ApiKey <key>` or `x-api-key` header.
    ApiKey,
    /// `Authorization: Bearer <jwt>` header.
    JwtToken,
    /// Client certificate thumbprint supplied via `x-client-cert-thumbprint`.
    Certificate,
    /// No credential; not produced by `MCPAuth` in this module.
    None,
}
37
38/// Authentication context for validated requests
39#[derive(Debug, Clone)]
40pub struct AuthContext {
41    pub client_id: String,
42    pub user_id: String,
43    pub method: AuthMethod,
44    pub scopes: Vec<String>,
45    pub expires_at: Option<chrono::DateTime<Utc>>,
46    pub request_id: String,
47}
48
49/// MCP Authentication configuration
50#[derive(Debug, Clone)]
51pub struct MCPAuthConfig {
52    pub enabled: bool,
53    pub jwt_secret: String,
54    pub jwt_expiry_seconds: u64,
55    pub api_keys: HashMap<String, ApiKeyInfo>,
56    pub allowed_certificates: HashSet<String>,
57    pub require_scope: Vec<String>,
58    pub performance_target_ms: u64,
59}
60
61/// API Key information
62#[derive(Debug, Clone)]
63pub struct ApiKeyInfo {
64    pub client_id: String,
65    pub scopes: Vec<String>,
66    pub expires_at: Option<chrono::DateTime<Utc>>,
67    pub last_used: Option<chrono::DateTime<Utc>>,
68    pub usage_count: u64,
69}
70
71impl Default for MCPAuthConfig {
72    fn default() -> Self {
73        Self {
74            enabled: false,
75            jwt_secret: env::var("MCP_JWT_SECRET").unwrap_or_else(|_| {
76                "change-me-in-production-super-secret-key-minimum-32-chars".to_string()
77            }),
78            jwt_expiry_seconds: env::var("MCP_JWT_EXPIRY_SECONDS")
79                .ok()
80                .and_then(|s| s.parse().ok())
81                .unwrap_or(3600), // 1 hour
82            api_keys: Self::load_api_keys_from_env(),
83            allowed_certificates: Self::load_certificates_from_env(),
84            require_scope: vec!["mcp:read".to_string(), "mcp:write".to_string()],
85            performance_target_ms: 5, // Must be <5ms per requirement
86        }
87    }
88}
89
90impl MCPAuthConfig {
91    /// Load API keys from environment variables
92    fn load_api_keys_from_env() -> HashMap<String, ApiKeyInfo> {
93        let mut api_keys = HashMap::new();
94
95        // Load from MCP_API_KEYS environment variable (JSON format)
96        if let Ok(keys_json) = env::var("MCP_API_KEYS") {
97            match serde_json::from_str::<HashMap<String, Value>>(&keys_json) {
98                Ok(keys) => {
99                    for (key, info) in keys {
100                        if let Ok(client_id) = info
101                            .get("client_id")
102                            .and_then(|v| v.as_str())
103                            .ok_or("Missing client_id")
104                        {
105                            let scopes = info
106                                .get("scopes")
107                                .and_then(|v| v.as_array())
108                                .map(|arr| {
109                                    arr.iter()
110                                        .filter_map(|s| s.as_str().map(String::from))
111                                        .collect()
112                                })
113                                .unwrap_or_else(|| {
114                                    vec!["mcp:read".to_string(), "mcp:write".to_string()]
115                                });
116
117                            let expires_at = info
118                                .get("expires_at")
119                                .and_then(|v| v.as_str())
120                                .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok())
121                                .map(|dt| dt.with_timezone(&Utc));
122
123                            api_keys.insert(
124                                key,
125                                ApiKeyInfo {
126                                    client_id: client_id.to_string(),
127                                    scopes,
128                                    expires_at,
129                                    last_used: None,
130                                    usage_count: 0,
131                                },
132                            );
133                        }
134                    }
135                }
136                Err(e) => {
137                    warn!("Failed to parse MCP_API_KEYS: {}", e);
138                }
139            }
140        }
141
142        // Fallback: single API key from MCP_API_KEY
143        if api_keys.is_empty() {
144            if let Ok(api_key) = env::var("MCP_API_KEY") {
145                let client_id =
146                    env::var("MCP_CLIENT_ID").unwrap_or_else(|_| "default-client".to_string());
147
148                api_keys.insert(
149                    api_key,
150                    ApiKeyInfo {
151                        client_id,
152                        scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
153                        expires_at: None,
154                        last_used: None,
155                        usage_count: 0,
156                    },
157                );
158            }
159        }
160
161        api_keys
162    }
163
164    /// Load allowed certificates from environment variables
165    fn load_certificates_from_env() -> HashSet<String> {
166        let mut certs = HashSet::new();
167
168        if let Ok(cert_thumbprints) = env::var("MCP_ALLOWED_CERTS") {
169            for thumbprint in cert_thumbprints.split(',') {
170                certs.insert(thumbprint.trim().to_string());
171            }
172        }
173
174        certs
175    }
176
177    /// Create configuration from environment variables
178    pub fn from_env() -> Self {
179        // In production, authentication should be enabled by default for security
180        let is_production = env::var("ENVIRONMENT")
181            .unwrap_or_else(|_| "development".to_string())
182            .to_lowercase()
183            == "production";
184
185        // Default to enabled unless explicitly disabled
186        let auth_enabled = env::var("MCP_AUTH_ENABLED")
187            .map(|s| s.parse().unwrap_or(true))
188            .unwrap_or(true);
189
190        // Warn if authentication is disabled in production
191        if is_production && !auth_enabled {
192            eprintln!("WARNING: Authentication is disabled in production environment! This is a security risk.");
193        }
194
195        Self {
196            enabled: auth_enabled,
197            ..Self::default()
198        }
199    }
200}
201
202/// MCP Authentication middleware
203pub struct MCPAuth {
204    config: MCPAuthConfig,
205    encoding_key: EncodingKey,
206    decoding_key: DecodingKey,
207    audit_logger: Arc<AuditLogger>,
208    revoked_tokens: Arc<tokio::sync::RwLock<HashSet<String>>>,
209}
210
211impl MCPAuth {
212    /// Create a new authentication middleware
213    pub fn new(config: MCPAuthConfig, audit_logger: Arc<AuditLogger>) -> Result<Self> {
214        let encoding_key = EncodingKey::from_secret(config.jwt_secret.as_bytes());
215        let decoding_key = DecodingKey::from_secret(config.jwt_secret.as_bytes());
216
217        Ok(Self {
218            config,
219            encoding_key,
220            decoding_key,
221            audit_logger,
222            revoked_tokens: Arc::new(tokio::sync::RwLock::new(HashSet::new())),
223        })
224    }
225
226    /// Authenticate an MCP request
227    pub async fn authenticate_request(
228        &self,
229        method: &str,
230        params: Option<&Value>,
231        headers: &HashMap<String, String>,
232    ) -> Result<Option<AuthContext>> {
233        let start_time = std::time::Instant::now();
234
235        // Skip authentication if disabled
236        if !self.config.enabled {
237            return Ok(None);
238        }
239
240        let request_id = Uuid::new_v4().to_string();
241
242        // Determine authentication method and validate
243        let auth_result = if let Some(auth_header) = headers.get("authorization") {
244            if let Some(token) = auth_header.strip_prefix("Bearer ") {
245                self.validate_jwt_token(token, &request_id).await
246            } else if let Some(api_key) = auth_header.strip_prefix("ApiKey ") {
247                self.validate_api_key(api_key, &request_id).await
248            } else {
249                Err(anyhow!("Invalid authorization header format"))
250            }
251        } else if let Some(cert_thumbprint) = headers.get("x-client-cert-thumbprint") {
252            self.validate_certificate(cert_thumbprint, &request_id)
253                .await
254        } else if let Some(api_key) = headers.get("x-api-key") {
255            self.validate_api_key(api_key, &request_id).await
256        } else {
257            Err(anyhow!("No authentication credentials provided"))
258        };
259
260        let elapsed = start_time.elapsed();
261
262        // Check performance requirement
263        if elapsed.as_millis() > self.config.performance_target_ms as u128 {
264            warn!(
265                "Authentication took {}ms, exceeding target of {}ms",
266                elapsed.as_millis(),
267                self.config.performance_target_ms
268            );
269        }
270
271        match auth_result {
272            Ok(context) => {
273                debug!(
274                    "Authentication successful for client: {}",
275                    context.client_id
276                );
277
278                // Log successful authentication
279                self.audit_logger
280                    .log_auth_event(&context.client_id, &context.user_id, method, true, None)
281                    .await;
282
283                Ok(Some(context))
284            }
285            Err(e) => {
286                error!("Authentication failed: {}", e);
287
288                // Log failed authentication
289                let client_id = headers
290                    .get("x-client-id")
291                    .or_else(|| headers.get("client-id"))
292                    .map(|s| s.as_str())
293                    .unwrap_or("unknown");
294
295                self.audit_logger
296                    .log_auth_event(client_id, "unknown", method, false, Some(&e.to_string()))
297                    .await;
298
299                Err(SecurityError::AuthenticationFailed {
300                    message: e.to_string(),
301                }
302                .into())
303            }
304        }
305    }
306
307    /// Validate JWT token
308    async fn validate_jwt_token(&self, token: &str, request_id: &str) -> Result<AuthContext> {
309        // Check if token is revoked
310        {
311            let revoked = self.revoked_tokens.read().await;
312            if revoked.contains(token) {
313                return Err(anyhow!("Token has been revoked"));
314            }
315        }
316
317        let mut validation = Validation::new(Algorithm::HS256);
318        validation.set_required_spec_claims(&["sub", "exp", "iat"]);
319
320        let token_data = decode::<Claims>(token, &self.decoding_key, &validation)
321            .map_err(|e| anyhow!("Invalid JWT token: {}", e))?;
322
323        let claims = token_data.claims;
324
325        // Verify token is not expired
326        let now = Utc::now().timestamp();
327        if claims.exp < now {
328            return Err(anyhow!("Token has expired"));
329        }
330
331        // Verify required scopes
332        if !self.has_required_scopes(&claims.scope) {
333            return Err(anyhow!("Insufficient permissions"));
334        }
335
336        Ok(AuthContext {
337            client_id: claims.client_id,
338            user_id: claims.sub,
339            method: AuthMethod::JwtToken,
340            scopes: claims.scope,
341            expires_at: chrono::DateTime::from_timestamp(claims.exp, 0),
342            request_id: request_id.to_string(),
343        })
344    }
345
346    /// Validate API key
347    async fn validate_api_key(&self, api_key: &str, request_id: &str) -> Result<AuthContext> {
348        let api_key_info = self
349            .config
350            .api_keys
351            .get(api_key)
352            .ok_or_else(|| anyhow!("Invalid API key"))?;
353
354        // Check if key is expired
355        if let Some(expires_at) = api_key_info.expires_at {
356            if Utc::now() > expires_at {
357                return Err(anyhow!("API key has expired"));
358            }
359        }
360
361        // Verify required scopes
362        if !self.has_required_scopes(&api_key_info.scopes) {
363            return Err(anyhow!("Insufficient permissions"));
364        }
365
366        Ok(AuthContext {
367            client_id: api_key_info.client_id.clone(),
368            user_id: api_key_info.client_id.clone(), // Use client_id as user_id for API keys
369            method: AuthMethod::ApiKey,
370            scopes: api_key_info.scopes.clone(),
371            expires_at: api_key_info.expires_at,
372            request_id: request_id.to_string(),
373        })
374    }
375
376    /// Validate client certificate
377    async fn validate_certificate(
378        &self,
379        thumbprint: &str,
380        request_id: &str,
381    ) -> Result<AuthContext> {
382        if !self.config.allowed_certificates.contains(thumbprint) {
383            return Err(anyhow!("Certificate not allowed"));
384        }
385
386        // For certificate-based auth, we grant full access
387        // In production, you'd extract more info from the certificate
388        Ok(AuthContext {
389            client_id: format!("cert-{thumbprint}"),
390            user_id: format!("cert-{thumbprint}"),
391            method: AuthMethod::Certificate,
392            scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
393            expires_at: None, // Certificates don't expire in this context
394            request_id: request_id.to_string(),
395        })
396    }
397
398    /// Check if provided scopes meet requirements
399    fn has_required_scopes(&self, provided_scopes: &[String]) -> bool {
400        if self.config.require_scope.is_empty() {
401            return true;
402        }
403
404        self.config
405            .require_scope
406            .iter()
407            .all(|required| provided_scopes.contains(required))
408    }
409
410    /// Generate JWT token for a client
411    pub async fn generate_token(
412        &self,
413        client_id: &str,
414        user_id: &str,
415        scopes: Vec<String>,
416    ) -> Result<String> {
417        let now = Utc::now();
418        let exp = now + Duration::seconds(self.config.jwt_expiry_seconds as i64);
419
420        let claims = Claims {
421            sub: user_id.to_string(),
422            client_id: client_id.to_string(),
423            scope: scopes,
424            iat: now.timestamp(),
425            exp: exp.timestamp(),
426            jti: Uuid::new_v4().to_string(),
427        };
428
429        encode(&Header::default(), &claims, &self.encoding_key)
430            .map_err(|e| anyhow!("Failed to generate token: {}", e))
431    }
432
433    /// Revoke a JWT token
434    pub async fn revoke_token(&self, token: &str) -> Result<()> {
435        let mut revoked = self.revoked_tokens.write().await;
436        revoked.insert(token.to_string());
437        debug!("Token revoked");
438        Ok(())
439    }
440
441    /// Validate tool access permissions
442    pub fn validate_tool_access(&self, context: &AuthContext, tool_name: &str) -> Result<()> {
443        // Map tools to required scopes
444        let required_scope = match tool_name {
445            "store_memory" | "harvest_conversation" | "migrate_memory" | "delete_memory" => {
446                "mcp:write"
447            }
448            "search_memory"
449            | "get_statistics"
450            | "what_did_you_remember"
451            | "get_harvester_metrics" => "mcp:read",
452            _ => "mcp:read", // Default to read access
453        };
454
455        if !context.scopes.contains(&required_scope.to_string()) {
456            return Err(SecurityError::AuthorizationFailed {
457                message: format!("Tool '{tool_name}' requires '{required_scope}' scope"),
458            }
459            .into());
460        }
461
462        Ok(())
463    }
464
465    /// Get authentication statistics
466    pub async fn get_stats(&self) -> serde_json::Value {
467        let revoked_count = self.revoked_tokens.read().await.len();
468
469        serde_json::json!({
470            "enabled": self.config.enabled,
471            "api_keys_configured": self.config.api_keys.len(),
472            "certificates_allowed": self.config.allowed_certificates.len(),
473            "revoked_tokens": revoked_count,
474            "performance_target_ms": self.config.performance_target_ms,
475            "jwt_expiry_seconds": self.config.jwt_expiry_seconds,
476        })
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::AuditConfig;

    /// Test configuration:
    /// - authentication enabled;
    /// - one API key, `test-key-123`;
    /// - one allowed certificate, `abc123def456`;
    /// - only `mcp:read` required.
    fn create_test_config() -> MCPAuthConfig {
        let mut api_keys = HashMap::new();
        api_keys.insert(
            "test-key-123".to_string(),
            ApiKeyInfo {
                client_id: "test-client".to_string(),
                scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
                expires_at: None,
                last_used: None,
                usage_count: 0,
            },
        );

        let mut certs = HashSet::new();
        certs.insert("abc123def456".to_string());

        MCPAuthConfig {
            enabled: true,
            jwt_secret: "test-secret-key-minimum-32-characters-long".to_string(),
            jwt_expiry_seconds: 3600,
            api_keys,
            allowed_certificates: certs,
            require_scope: vec!["mcp:read".to_string()],
            performance_target_ms: 5,
        }
    }

    /// Audit logger with every event category enabled.
    fn create_audit_logger() -> Arc<AuditLogger> {
        let audit_config = AuditConfig {
            enabled: true,
            log_all_requests: true,
            log_data_access: true,
            log_modifications: true,
            log_auth_events: true,
            retention_days: 30,
        };
        Arc::new(AuditLogger::new(audit_config).unwrap())
    }

    fn create_auth(config: MCPAuthConfig) -> MCPAuth {
        MCPAuth::new(config, create_audit_logger()).unwrap()
    }

    async fn create_test_auth() -> MCPAuth {
        create_auth(create_test_config())
    }

    fn single_header(name: &str, value: &str) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert(name.to_string(), value.to_string());
        headers
    }

    #[tokio::test]
    async fn test_api_key_authentication() {
        let auth = create_test_auth().await;
        let headers = single_header("authorization", "ApiKey test-key-123");

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        let context = result.unwrap().unwrap();
        assert_eq!(context.client_id, "test-client");
        assert_eq!(context.method, AuthMethod::ApiKey);
        assert!(context.scopes.contains(&"mcp:read".to_string()));
    }

    #[tokio::test]
    async fn test_jwt_authentication() {
        let auth = create_test_auth().await;

        let token = auth
            .generate_token(
                "test-client",
                "test-user",
                vec!["mcp:read".to_string(), "mcp:write".to_string()],
            )
            .await
            .unwrap();

        let headers = single_header("authorization", &format!("Bearer {token}"));

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        let context = result.unwrap().unwrap();
        assert_eq!(context.client_id, "test-client");
        assert_eq!(context.user_id, "test-user");
        assert_eq!(context.method, AuthMethod::JwtToken);
    }

    #[tokio::test]
    async fn test_certificate_authentication() {
        let auth = create_test_auth().await;
        let headers = single_header("x-client-cert-thumbprint", "abc123def456");

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        let context = result.unwrap().unwrap();
        assert_eq!(context.client_id, "cert-abc123def456");
        assert_eq!(context.method, AuthMethod::Certificate);
    }

    #[tokio::test]
    async fn test_unknown_certificate_rejected() {
        let auth = create_test_auth().await;
        let headers = single_header("x-client-cert-thumbprint", "not-allowed");

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_invalid_api_key() {
        let auth = create_test_auth().await;
        let headers = single_header("authorization", "ApiKey invalid-key");

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_expired_api_key() {
        let mut config = create_test_config();
        config.api_keys.insert(
            "expired-key".to_string(),
            ApiKeyInfo {
                client_id: "old-client".to_string(),
                scopes: vec!["mcp:read".to_string(), "mcp:write".to_string()],
                expires_at: Some(Utc::now() - Duration::seconds(60)),
                last_used: None,
                usage_count: 0,
            },
        );
        let auth = create_auth(config);
        let headers = single_header("x-api-key", "expired-key");

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_missing_credentials() {
        let auth = create_test_auth().await;

        let result = auth
            .authenticate_request("tools/call", None, &HashMap::new())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_malformed_authorization_header() {
        let auth = create_test_auth().await;
        let headers = single_header("authorization", "Basic dXNlcjpwYXNz");

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_jwt_signed_with_other_secret_is_rejected() {
        let auth = create_test_auth().await;

        let mut other_config = create_test_config();
        other_config.jwt_secret = "a-completely-different-secret-of-sufficient-length".to_string();
        let other = create_auth(other_config);

        let token = other
            .generate_token("test-client", "test-user", vec!["mcp:read".to_string()])
            .await
            .unwrap();
        let headers = single_header("authorization", &format!("Bearer {token}"));

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tool_access_validation() {
        let auth = create_test_auth().await;

        let context = AuthContext {
            client_id: "test-client".to_string(),
            user_id: "test-user".to_string(),
            method: AuthMethod::ApiKey,
            scopes: vec!["mcp:read".to_string()],
            expires_at: None,
            request_id: "test-request".to_string(),
        };

        // Should allow read operations
        assert!(auth.validate_tool_access(&context, "search_memory").is_ok());
        assert!(auth
            .validate_tool_access(&context, "get_statistics")
            .is_ok());

        // Should deny write operations
        assert!(auth.validate_tool_access(&context, "store_memory").is_err());
        assert!(auth
            .validate_tool_access(&context, "delete_memory")
            .is_err());
    }

    #[tokio::test]
    async fn test_token_revocation() {
        let auth = create_test_auth().await;

        let token = auth
            .generate_token("test-client", "test-user", vec!["mcp:read".to_string()])
            .await
            .unwrap();

        // Token should work initially
        let headers = single_header("authorization", &format!("Bearer {token}"));

        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());

        auth.revoke_token(&token).await.unwrap();

        // Token should no longer work
        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_disabled_authentication() {
        let mut config = create_test_config();
        config.enabled = false;
        let auth = create_auth(config);

        let headers = HashMap::new();
        let result = auth
            .authenticate_request("tools/call", None, &headers)
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none()); // Should return None when disabled
    }
}