pulseengine_mcp_auth/middleware/mcp_auth.rs

1//! MCP Authentication Middleware
2//!
3//! This middleware provides comprehensive authentication and authorization
4//! for MCP requests, integrating with the AuthenticationManager and
5//! permission system.
6
7use crate::{AuthContext, AuthenticationManager, models::Role, security::RequestSecurityValidator};
8use async_trait::async_trait;
9use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
10use std::collections::HashMap;
11use std::sync::Arc;
12use thiserror::Error;
13use tracing::{debug, error, warn};
14
15/// Errors that can occur during authentication extraction
16#[derive(Debug, Error)]
17pub enum AuthExtractionError {
18    #[error("No authentication provided")]
19    NoAuth,
20
21    #[error("Invalid authentication format: {0}")]
22    InvalidFormat(String),
23
24    #[error("Authentication method not supported: {0}")]
25    UnsupportedMethod(String),
26
27    #[error("Missing required header: {0}")]
28    MissingHeader(String),
29}
30
31/// Configuration for MCP authentication middleware
32#[derive(Debug, Clone)]
33pub struct McpAuthConfig {
34    /// Require authentication for all requests
35    pub require_auth: bool,
36
37    /// Allow anonymous access to specific methods
38    pub anonymous_methods: Vec<String>,
39
40    /// Methods that require specific roles
41    pub method_role_requirements: HashMap<String, Vec<Role>>,
42
43    /// Enable permission checking for tools and resources
44    pub enable_permission_checking: bool,
45
46    /// Custom authentication header name (default: "Authorization")
47    pub auth_header_name: String,
48
49    /// Enable audit logging for authentication events
50    pub enable_audit_logging: bool,
51
52    /// Client IP header name for proxy environments
53    pub client_ip_header: Option<String>,
54}
55
56impl Default for McpAuthConfig {
57    fn default() -> Self {
58        Self {
59            require_auth: true,
60            anonymous_methods: vec!["initialize".to_string(), "ping".to_string()],
61            method_role_requirements: HashMap::new(),
62            enable_permission_checking: true,
63            auth_header_name: "Authorization".to_string(),
64            enable_audit_logging: true,
65            client_ip_header: Some("X-Forwarded-For".to_string()),
66        }
67    }
68}
69
70/// Authentication context extracted from request
71#[derive(Debug, Clone)]
72pub struct McpAuthContext {
73    /// Authenticated API key context
74    pub auth_context: Option<AuthContext>,
75
76    /// Client IP address
77    pub client_ip: Option<String>,
78
79    /// Authentication method used
80    pub auth_method: Option<String>,
81
82    /// Whether the request is anonymous
83    pub is_anonymous: bool,
84}
85
86/// Request context that includes authentication and metadata
87#[derive(Debug, Clone)]
88pub struct McpRequestContext {
89    /// Unique request identifier
90    pub request_id: String,
91
92    /// Authentication context
93    pub auth: McpAuthContext,
94
95    /// Request timestamp
96    pub timestamp: chrono::DateTime<chrono::Utc>,
97
98    /// Additional metadata
99    pub metadata: HashMap<String, String>,
100}
101
102impl McpRequestContext {
103    pub fn new(request_id: String) -> Self {
104        Self {
105            request_id,
106            auth: McpAuthContext {
107                auth_context: None,
108                client_ip: None,
109                auth_method: None,
110                is_anonymous: true,
111            },
112            timestamp: chrono::Utc::now(),
113            metadata: HashMap::new(),
114        }
115    }
116
117    pub fn with_auth(mut self, auth_context: AuthContext, auth_method: String) -> Self {
118        self.auth.auth_context = Some(auth_context);
119        self.auth.auth_method = Some(auth_method);
120        self.auth.is_anonymous = false;
121        self
122    }
123
124    pub fn with_client_ip(mut self, client_ip: String) -> Self {
125        self.auth.client_ip = Some(client_ip);
126        self
127    }
128}
129
130/// MCP Authentication Middleware
131pub struct McpAuthMiddleware {
132    /// Authentication manager for key validation
133    auth_manager: Arc<AuthenticationManager>,
134
135    /// Middleware configuration
136    config: McpAuthConfig,
137
138    /// Request security validator
139    security_validator: Arc<RequestSecurityValidator>,
140}
141
142impl McpAuthMiddleware {
143    /// Create a new MCP authentication middleware
144    pub fn new(auth_manager: Arc<AuthenticationManager>, config: McpAuthConfig) -> Self {
145        Self {
146            auth_manager,
147            config,
148            security_validator: Arc::new(RequestSecurityValidator::default()),
149        }
150    }
151
152    /// Create with custom security validator
153    pub fn with_security_validator(
154        auth_manager: Arc<AuthenticationManager>,
155        config: McpAuthConfig,
156        security_validator: Arc<RequestSecurityValidator>,
157    ) -> Self {
158        Self {
159            auth_manager,
160            config,
161            security_validator,
162        }
163    }
164
165    /// Create middleware with default configuration
166    pub fn with_default_config(auth_manager: Arc<AuthenticationManager>) -> Self {
167        Self::new(auth_manager, McpAuthConfig::default())
168    }
169
170    /// Get access to the security validator for monitoring violations
171    pub fn security_validator(&self) -> &RequestSecurityValidator {
172        &self.security_validator
173    }
174
175    /// Process an incoming MCP request
176    pub async fn process_request(
177        &self,
178        request: Request,
179        headers: Option<&HashMap<String, String>>,
180    ) -> Result<(Request, McpRequestContext), McpError> {
181        // Step 1: Validate request security first
182        if let Err(security_error) = self
183            .security_validator
184            .validate_request(&request, None)
185            .await
186        {
187            error!("Request security validation failed: {}", security_error);
188            return Err(McpError::invalid_request(&format!(
189                "Security validation failed: {}",
190                security_error
191            )));
192        }
193
194        // Step 2: Sanitize request if needed
195        let sanitized_request = self.security_validator.sanitize_request(request).await;
196
197        let request_id = match &sanitized_request.id {
198            Some(id) => id.to_string(),
199            None => uuid::Uuid::new_v4().to_string(),
200        };
201        let mut context = McpRequestContext::new(request_id);
202
203        // Extract client IP if available
204        if let Some(headers) = headers {
205            if let Some(ip_header) = &self.config.client_ip_header {
206                if let Some(client_ip) = headers.get(ip_header) {
207                    context = context.with_client_ip(client_ip.clone());
208                }
209            }
210        }
211
212        // Check if authentication is required for this method
213        if self.should_skip_auth(&sanitized_request.method) {
214            debug!(
215                "Skipping authentication for method: {}",
216                sanitized_request.method
217            );
218            return Ok((sanitized_request, context));
219        }
220
221        // Extract authentication from headers
222        let auth_result = if let Some(headers) = headers {
223            self.extract_authentication(headers).await
224        } else {
225            Err(AuthExtractionError::NoAuth)
226        };
227
228        match auth_result {
229            Ok((auth_context, auth_method)) => {
230                // Authentication successful
231                context = context.with_auth(auth_context, auth_method);
232
233                // Check method-specific role requirements
234                if let Err(e) = self
235                    .check_method_permissions(&sanitized_request.method, &context)
236                    .await
237                {
238                    error!("Method permission check failed: {}", e);
239                    return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
240                }
241
242                debug!("Request authenticated successfully");
243                Ok((sanitized_request, context))
244            }
245            Err(e) => {
246                if self.config.require_auth {
247                    warn!("Authentication failed: {}", e);
248                    Err(McpError::invalid_request(&format!(
249                        "Authentication required: {}",
250                        e
251                    )))
252                } else {
253                    debug!("Authentication failed but not required: {}", e);
254                    Ok((sanitized_request, context))
255                }
256            }
257        }
258    }
259
260    /// Process an outgoing MCP response
261    pub async fn process_response(
262        &self,
263        response: Response,
264        _context: &McpRequestContext,
265    ) -> Result<Response, McpError> {
266        // Add security headers or process response as needed
267        // For now, just pass through
268        Ok(response)
269    }
270
271    /// Extract authentication from request headers
272    async fn extract_authentication(
273        &self,
274        headers: &HashMap<String, String>,
275    ) -> Result<(AuthContext, String), AuthExtractionError> {
276        // Try to extract from Authorization header
277        if let Some(auth_header) = headers.get(&self.config.auth_header_name) {
278            return self.parse_auth_header(auth_header).await;
279        }
280
281        // Try to extract from X-API-Key header
282        if let Some(api_key) = headers.get("X-API-Key") {
283            return self.validate_api_key(api_key, "X-API-Key").await;
284        }
285
286        Err(AuthExtractionError::NoAuth)
287    }
288
289    /// Parse the Authorization header
290    async fn parse_auth_header(
291        &self,
292        auth_header: &str,
293    ) -> Result<(AuthContext, String), AuthExtractionError> {
294        let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
295        if parts.len() != 2 {
296            return Err(AuthExtractionError::InvalidFormat(
297                "Authorization header must be in format 'Type Token'".to_string(),
298            ));
299        }
300
301        let auth_type = parts[0].to_lowercase();
302        let token = parts[1];
303
304        match auth_type.as_str() {
305            "bearer" => self.validate_api_key(token, "Bearer").await,
306            "apikey" => self.validate_api_key(token, "ApiKey").await,
307            _ => Err(AuthExtractionError::UnsupportedMethod(auth_type)),
308        }
309    }
310
311    /// Validate an API key
312    async fn validate_api_key(
313        &self,
314        api_key: &str,
315        method: &str,
316    ) -> Result<(AuthContext, String), AuthExtractionError> {
317        match self.auth_manager.validate_api_key(api_key, None).await {
318            Ok(Some(auth_context)) => Ok((auth_context, method.to_string())),
319            Ok(None) => Err(AuthExtractionError::InvalidFormat(
320                "Invalid API key".to_string(),
321            )),
322            Err(e) => {
323                error!("API key validation failed: {}", e);
324                Err(AuthExtractionError::InvalidFormat(
325                    "Authentication failed".to_string(),
326                ))
327            }
328        }
329    }
330
331    /// Check if authentication should be skipped for a method
332    fn should_skip_auth(&self, method: &str) -> bool {
333        if !self.config.require_auth {
334            return true;
335        }
336
337        self.config.anonymous_methods.contains(&method.to_string())
338    }
339
340    /// Check method-specific role requirements
341    async fn check_method_permissions(
342        &self,
343        method: &str,
344        context: &McpRequestContext,
345    ) -> Result<(), String> {
346        // If no specific requirements, allow
347        if let Some(required_roles) = self.config.method_role_requirements.get(method) {
348            if let Some(auth_context) = &context.auth.auth_context {
349                // Check if user has one of the required roles
350                let has_required_role = auth_context
351                    .roles
352                    .iter()
353                    .any(|role| required_roles.contains(role));
354                if !has_required_role {
355                    return Err(format!(
356                        "Method '{}' requires one of these roles: {:?}, but user has roles: {:?}",
357                        method, required_roles, auth_context.roles
358                    ));
359                }
360            } else {
361                return Err(format!("Method '{}' requires authentication", method));
362            }
363        }
364
365        Ok(())
366    }
367}
368
369/// Trait for middleware that can process MCP requests and responses
370#[async_trait]
371pub trait McpMiddleware: Send + Sync {
372    /// Process an incoming request
373    async fn process_request(
374        &self,
375        request: Request,
376        context: &McpRequestContext,
377    ) -> Result<Request, McpError>;
378
379    /// Process an outgoing response
380    async fn process_response(
381        &self,
382        response: Response,
383        context: &McpRequestContext,
384    ) -> Result<Response, McpError>;
385}
386
387#[async_trait]
388impl McpMiddleware for McpAuthMiddleware {
389    async fn process_request(
390        &self,
391        request: Request,
392        _context: &McpRequestContext,
393    ) -> Result<Request, McpError> {
394        // This implementation assumes context has already been created
395        // by the initial process_request call
396        Ok(request)
397    }
398
399    async fn process_response(
400        &self,
401        response: Response,
402        context: &McpRequestContext,
403    ) -> Result<Response, McpError> {
404        self.process_response(response, context).await
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthConfig;

    /// Build a middleware with the default config, backed by an in-memory auth manager.
    async fn default_middleware() -> McpAuthMiddleware {
        let manager = AuthenticationManager::new(AuthConfig::memory())
            .await
            .unwrap();
        McpAuthMiddleware::with_default_config(Arc::new(manager))
    }

    #[tokio::test]
    async fn test_auth_middleware_creation() {
        let middleware = default_middleware().await;
        let config = &middleware.config;

        assert!(config.require_auth);
        assert!(!config.anonymous_methods.is_empty());
    }

    #[tokio::test]
    async fn test_anonymous_method_detection() {
        let middleware = default_middleware().await;

        for anonymous in ["initialize", "ping"].iter() {
            assert!(middleware.should_skip_auth(anonymous));
        }
        assert!(!middleware.should_skip_auth("tools/call"));
    }

    #[tokio::test]
    async fn test_auth_header_parsing() {
        let middleware = default_middleware().await;

        // A header without a token part is malformed.
        let malformed = middleware.parse_auth_header("invalid").await;
        assert!(malformed.is_err());

        // `Basic` is not among the supported schemes.
        let unsupported = middleware.parse_auth_header("Basic token123").await;
        assert!(matches!(
            unsupported,
            Err(AuthExtractionError::UnsupportedMethod(_))
        ));
    }
}