pulseengine_mcp_auth/middleware/mcp_auth.rs

1//! MCP Authentication Middleware
2//!
3//! This middleware provides comprehensive authentication and authorization
4//! for MCP requests, integrating with the AuthenticationManager and
5//! permission system.
6
7use crate::{models::Role, security::RequestSecurityValidator, AuthContext, AuthenticationManager};
8use async_trait::async_trait;
9use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
10use std::collections::HashMap;
11use std::sync::Arc;
12use thiserror::Error;
13use tracing::{debug, error, warn};
14
15/// Errors that can occur during authentication extraction
16#[derive(Debug, Error)]
17pub enum AuthExtractionError {
18    #[error("No authentication provided")]
19    NoAuth,
20
21    #[error("Invalid authentication format: {0}")]
22    InvalidFormat(String),
23
24    #[error("Authentication method not supported: {0}")]
25    UnsupportedMethod(String),
26
27    #[error("Missing required header: {0}")]
28    MissingHeader(String),
29}
30
31/// Configuration for MCP authentication middleware
32#[derive(Debug, Clone)]
33pub struct McpAuthConfig {
34    /// Require authentication for all requests
35    pub require_auth: bool,
36
37    /// Allow anonymous access to specific methods
38    pub anonymous_methods: Vec<String>,
39
40    /// Methods that require specific roles
41    pub method_role_requirements: HashMap<String, Vec<Role>>,
42
43    /// Enable permission checking for tools and resources
44    pub enable_permission_checking: bool,
45
46    /// Custom authentication header name (default: "Authorization")
47    pub auth_header_name: String,
48
49    /// Enable audit logging for authentication events
50    pub enable_audit_logging: bool,
51
52    /// Client IP header name for proxy environments
53    pub client_ip_header: Option<String>,
54}
55
56impl Default for McpAuthConfig {
57    fn default() -> Self {
58        Self {
59            require_auth: true,
60            anonymous_methods: vec!["initialize".to_string(), "ping".to_string()],
61            method_role_requirements: HashMap::new(),
62            enable_permission_checking: true,
63            auth_header_name: "Authorization".to_string(),
64            enable_audit_logging: true,
65            client_ip_header: Some("X-Forwarded-For".to_string()),
66        }
67    }
68}
69
70/// Authentication context extracted from request
71#[derive(Debug, Clone)]
72pub struct McpAuthContext {
73    /// Authenticated API key context
74    pub auth_context: Option<AuthContext>,
75
76    /// Client IP address
77    pub client_ip: Option<String>,
78
79    /// Authentication method used
80    pub auth_method: Option<String>,
81
82    /// Whether the request is anonymous
83    pub is_anonymous: bool,
84}
85
86/// Request context that includes authentication and metadata
87#[derive(Debug, Clone)]
88pub struct McpRequestContext {
89    /// Unique request identifier
90    pub request_id: String,
91
92    /// Authentication context
93    pub auth: McpAuthContext,
94
95    /// Request timestamp
96    pub timestamp: chrono::DateTime<chrono::Utc>,
97
98    /// Additional metadata
99    pub metadata: HashMap<String, String>,
100}
101
102impl McpRequestContext {
103    pub fn new(request_id: String) -> Self {
104        Self {
105            request_id,
106            auth: McpAuthContext {
107                auth_context: None,
108                client_ip: None,
109                auth_method: None,
110                is_anonymous: true,
111            },
112            timestamp: chrono::Utc::now(),
113            metadata: HashMap::new(),
114        }
115    }
116
117    pub fn with_auth(mut self, auth_context: AuthContext, auth_method: String) -> Self {
118        self.auth.auth_context = Some(auth_context);
119        self.auth.auth_method = Some(auth_method);
120        self.auth.is_anonymous = false;
121        self
122    }
123
124    pub fn with_client_ip(mut self, client_ip: String) -> Self {
125        self.auth.client_ip = Some(client_ip);
126        self
127    }
128}
129
130/// MCP Authentication Middleware
131pub struct McpAuthMiddleware {
132    /// Authentication manager for key validation
133    auth_manager: Arc<AuthenticationManager>,
134
135    /// Middleware configuration
136    config: McpAuthConfig,
137
138    /// Request security validator
139    security_validator: Arc<RequestSecurityValidator>,
140}
141
142impl McpAuthMiddleware {
143    /// Create a new MCP authentication middleware
144    pub fn new(auth_manager: Arc<AuthenticationManager>, config: McpAuthConfig) -> Self {
145        Self {
146            auth_manager,
147            config,
148            security_validator: Arc::new(RequestSecurityValidator::default()),
149        }
150    }
151
152    /// Create with custom security validator
153    pub fn with_security_validator(
154        auth_manager: Arc<AuthenticationManager>,
155        config: McpAuthConfig,
156        security_validator: Arc<RequestSecurityValidator>,
157    ) -> Self {
158        Self {
159            auth_manager,
160            config,
161            security_validator,
162        }
163    }
164
165    /// Create middleware with default configuration
166    pub fn with_default_config(auth_manager: Arc<AuthenticationManager>) -> Self {
167        Self::new(auth_manager, McpAuthConfig::default())
168    }
169
170    /// Get access to the security validator for monitoring violations
171    pub fn security_validator(&self) -> &RequestSecurityValidator {
172        &self.security_validator
173    }
174
175    /// Process an incoming MCP request
176    pub async fn process_request(
177        &self,
178        request: Request,
179        headers: Option<&HashMap<String, String>>,
180    ) -> Result<(Request, McpRequestContext), McpError> {
181        // Step 1: Validate request security first
182        if let Err(security_error) = self
183            .security_validator
184            .validate_request(&request, None)
185            .await
186        {
187            error!("Request security validation failed: {}", security_error);
188            return Err(McpError::invalid_request(&format!(
189                "Security validation failed: {}",
190                security_error
191            )));
192        }
193
194        // Step 2: Sanitize request if needed
195        let sanitized_request = self.security_validator.sanitize_request(request).await;
196
197        let request_id = match &sanitized_request.id {
198            serde_json::Value::String(s) => s.clone(),
199            serde_json::Value::Number(n) => n.to_string(),
200            serde_json::Value::Null => uuid::Uuid::new_v4().to_string(),
201            _ => uuid::Uuid::new_v4().to_string(),
202        };
203        let mut context = McpRequestContext::new(request_id);
204
205        // Extract client IP if available
206        if let Some(headers) = headers {
207            if let Some(ip_header) = &self.config.client_ip_header {
208                if let Some(client_ip) = headers.get(ip_header) {
209                    context = context.with_client_ip(client_ip.clone());
210                }
211            }
212        }
213
214        // Check if authentication is required for this method
215        if self.should_skip_auth(&sanitized_request.method) {
216            debug!(
217                "Skipping authentication for method: {}",
218                sanitized_request.method
219            );
220            return Ok((sanitized_request, context));
221        }
222
223        // Extract authentication from headers
224        let auth_result = if let Some(headers) = headers {
225            self.extract_authentication(headers).await
226        } else {
227            Err(AuthExtractionError::NoAuth)
228        };
229
230        match auth_result {
231            Ok((auth_context, auth_method)) => {
232                // Authentication successful
233                context = context.with_auth(auth_context, auth_method);
234
235                // Check method-specific role requirements
236                if let Err(e) = self
237                    .check_method_permissions(&sanitized_request.method, &context)
238                    .await
239                {
240                    error!("Method permission check failed: {}", e);
241                    return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
242                }
243
244                debug!("Request authenticated successfully");
245                Ok((sanitized_request, context))
246            }
247            Err(e) => {
248                if self.config.require_auth {
249                    warn!("Authentication failed: {}", e);
250                    Err(McpError::invalid_request(&format!(
251                        "Authentication required: {}",
252                        e
253                    )))
254                } else {
255                    debug!("Authentication failed but not required: {}", e);
256                    Ok((sanitized_request, context))
257                }
258            }
259        }
260    }
261
262    /// Process an outgoing MCP response
263    pub async fn process_response(
264        &self,
265        response: Response,
266        _context: &McpRequestContext,
267    ) -> Result<Response, McpError> {
268        // Add security headers or process response as needed
269        // For now, just pass through
270        Ok(response)
271    }
272
273    /// Extract authentication from request headers
274    async fn extract_authentication(
275        &self,
276        headers: &HashMap<String, String>,
277    ) -> Result<(AuthContext, String), AuthExtractionError> {
278        // Try to extract from Authorization header
279        if let Some(auth_header) = headers.get(&self.config.auth_header_name) {
280            return self.parse_auth_header(auth_header).await;
281        }
282
283        // Try to extract from X-API-Key header
284        if let Some(api_key) = headers.get("X-API-Key") {
285            return self.validate_api_key(api_key, "X-API-Key").await;
286        }
287
288        Err(AuthExtractionError::NoAuth)
289    }
290
291    /// Parse the Authorization header
292    async fn parse_auth_header(
293        &self,
294        auth_header: &str,
295    ) -> Result<(AuthContext, String), AuthExtractionError> {
296        let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
297        if parts.len() != 2 {
298            return Err(AuthExtractionError::InvalidFormat(
299                "Authorization header must be in format 'Type Token'".to_string(),
300            ));
301        }
302
303        let auth_type = parts[0].to_lowercase();
304        let token = parts[1];
305
306        match auth_type.as_str() {
307            "bearer" => self.validate_api_key(token, "Bearer").await,
308            "apikey" => self.validate_api_key(token, "ApiKey").await,
309            _ => Err(AuthExtractionError::UnsupportedMethod(auth_type)),
310        }
311    }
312
313    /// Validate an API key
314    async fn validate_api_key(
315        &self,
316        api_key: &str,
317        method: &str,
318    ) -> Result<(AuthContext, String), AuthExtractionError> {
319        match self.auth_manager.validate_api_key(api_key, None).await {
320            Ok(Some(auth_context)) => Ok((auth_context, method.to_string())),
321            Ok(None) => Err(AuthExtractionError::InvalidFormat(
322                "Invalid API key".to_string(),
323            )),
324            Err(e) => {
325                error!("API key validation failed: {}", e);
326                Err(AuthExtractionError::InvalidFormat(
327                    "Authentication failed".to_string(),
328                ))
329            }
330        }
331    }
332
333    /// Check if authentication should be skipped for a method
334    fn should_skip_auth(&self, method: &str) -> bool {
335        if !self.config.require_auth {
336            return true;
337        }
338
339        self.config.anonymous_methods.contains(&method.to_string())
340    }
341
342    /// Check method-specific role requirements
343    async fn check_method_permissions(
344        &self,
345        method: &str,
346        context: &McpRequestContext,
347    ) -> Result<(), String> {
348        // If no specific requirements, allow
349        if let Some(required_roles) = self.config.method_role_requirements.get(method) {
350            if let Some(auth_context) = &context.auth.auth_context {
351                // Check if user has one of the required roles
352                let has_required_role = auth_context
353                    .roles
354                    .iter()
355                    .any(|role| required_roles.contains(role));
356                if !has_required_role {
357                    return Err(format!(
358                        "Method '{}' requires one of these roles: {:?}, but user has roles: {:?}",
359                        method, required_roles, auth_context.roles
360                    ));
361                }
362            } else {
363                return Err(format!("Method '{}' requires authentication", method));
364            }
365        }
366
367        Ok(())
368    }
369}
370
371/// Trait for middleware that can process MCP requests and responses
372#[async_trait]
373pub trait McpMiddleware: Send + Sync {
374    /// Process an incoming request
375    async fn process_request(
376        &self,
377        request: Request,
378        context: &McpRequestContext,
379    ) -> Result<Request, McpError>;
380
381    /// Process an outgoing response
382    async fn process_response(
383        &self,
384        response: Response,
385        context: &McpRequestContext,
386    ) -> Result<Response, McpError>;
387}
388
389#[async_trait]
390impl McpMiddleware for McpAuthMiddleware {
391    async fn process_request(
392        &self,
393        request: Request,
394        _context: &McpRequestContext,
395    ) -> Result<Request, McpError> {
396        // This implementation assumes context has already been created
397        // by the initial process_request call
398        Ok(request)
399    }
400
401    async fn process_response(
402        &self,
403        response: Response,
404        context: &McpRequestContext,
405    ) -> Result<Response, McpError> {
406        self.process_response(response, context).await
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthConfig;

    /// Builds a middleware backed by an in-memory `AuthenticationManager`, using
    /// the default `McpAuthConfig`.
    async fn default_middleware() -> McpAuthMiddleware {
        let manager = AuthenticationManager::new(AuthConfig::memory())
            .await
            .unwrap();
        McpAuthMiddleware::with_default_config(Arc::new(manager))
    }

    #[tokio::test]
    async fn test_auth_middleware_creation() {
        let middleware = default_middleware().await;
        let config = &middleware.config;

        assert!(!config.anonymous_methods.is_empty());
        assert!(config.require_auth);
    }

    #[tokio::test]
    async fn test_anonymous_method_detection() {
        let middleware = default_middleware().await;

        for anonymous in ["initialize", "ping"].iter() {
            assert!(middleware.should_skip_auth(anonymous));
        }
        assert!(!middleware.should_skip_auth("tools/call"));
    }

    #[tokio::test]
    async fn test_auth_header_parsing() {
        let middleware = default_middleware().await;

        // No separator between scheme and token.
        assert!(middleware.parse_auth_header("invalid").await.is_err());

        // Well-formed, but the scheme is not accepted.
        let basic = middleware.parse_auth_header("Basic token123").await;
        assert!(matches!(
            basic,
            Err(AuthExtractionError::UnsupportedMethod(_))
        ));
    }
}