pulseengine_mcp_auth/middleware/
session_middleware.rs

1//! Session-Aware MCP Authentication Middleware
2//!
3//! This middleware extends the basic MCP authentication to include session management,
4//! JWT token validation, and enhanced security features.
5
6use crate::{
7    AuthContext, AuthenticationManager,
8    jwt::JwtError,
9    middleware::mcp_auth::{AuthExtractionError, McpAuthConfig, McpRequestContext},
10    security::RequestSecurityValidator,
11    session::{Session, SessionError, SessionManager},
12};
13use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
14use std::collections::HashMap;
15use std::sync::Arc;
16use thiserror::Error;
17use tracing::{debug, error, info, warn};
18
/// Errors specific to session middleware
///
/// The first three variants wrap lower-level errors and are produced
/// automatically by `?` through their `#[from]` conversions. The last two are
/// standalone conditions with fixed messages.
#[derive(Debug, Error)]
pub enum SessionMiddlewareError {
    /// A session-layer failure, converted from [`SessionError`].
    #[error("Session error: {0}")]
    SessionError(#[from] SessionError),

    /// A credential extraction/validation failure, converted from
    /// [`AuthExtractionError`]. Also used by this module to signal "no usable
    /// credentials" via `AuthExtractionError::NoAuth`.
    #[error("Authentication error: {0}")]
    AuthError(#[from] AuthExtractionError),

    /// A JWT failure, converted from [`JwtError`].
    #[error("JWT validation failed: {0}")]
    JwtError(#[from] JwtError),

    /// The session token was not in the expected format.
    // NOTE(review): not constructed by any code visible in this file — confirm it is used elsewhere.
    #[error("Invalid session token format")]
    InvalidTokenFormat,

    /// A session was mandatory for the request but none was supplied.
    // NOTE(review): not constructed by any code visible in this file — confirm it is used elsewhere.
    #[error("Session required but not provided")]
    SessionRequired,
}
37
38/// Enhanced configuration for session-aware middleware
39#[derive(Debug, Clone)]
40pub struct SessionMiddlewareConfig {
41    /// Base MCP auth configuration
42    pub auth_config: McpAuthConfig,
43
44    /// Enable session management
45    pub enable_sessions: bool,
46
47    /// Require sessions for authenticated requests
48    pub require_sessions: bool,
49
50    /// Enable JWT token authentication
51    pub enable_jwt_auth: bool,
52
53    /// JWT token header name
54    pub jwt_header_name: String,
55
56    /// Session ID header name
57    pub session_header_name: String,
58
59    /// Enable automatic session creation for API keys
60    pub auto_create_sessions: bool,
61
62    /// Session duration for auto-created sessions
63    pub auto_session_duration: Option<chrono::Duration>,
64
65    /// Enable session extension on access
66    pub extend_sessions_on_access: bool,
67
68    /// Methods that bypass session requirements
69    pub session_exempt_methods: Vec<String>,
70}
71
72impl Default for SessionMiddlewareConfig {
73    fn default() -> Self {
74        Self {
75            auth_config: McpAuthConfig::default(),
76            enable_sessions: true,
77            require_sessions: false, // Optional by default
78            enable_jwt_auth: true,
79            jwt_header_name: "Authorization".to_string(),
80            session_header_name: "X-Session-ID".to_string(),
81            auto_create_sessions: true,
82            auto_session_duration: Some(chrono::Duration::hours(24)),
83            extend_sessions_on_access: true,
84            session_exempt_methods: vec!["initialize".to_string(), "ping".to_string()],
85        }
86    }
87}
88
89/// Enhanced request context with session information
90#[derive(Debug, Clone)]
91pub struct SessionRequestContext {
92    /// Base request context
93    pub base_context: McpRequestContext,
94
95    /// Active session (if any)
96    pub session: Option<Session>,
97
98    /// Whether request used JWT authentication
99    pub jwt_authenticated: bool,
100
101    /// Session was created automatically
102    pub auto_created_session: bool,
103}
104
105impl SessionRequestContext {
106    pub fn new(base_context: McpRequestContext) -> Self {
107        Self {
108            base_context,
109            session: None,
110            jwt_authenticated: false,
111            auto_created_session: false,
112        }
113    }
114
115    pub fn with_session(mut self, session: Session, auto_created: bool) -> Self {
116        self.session = Some(session);
117        self.auto_created_session = auto_created;
118        self
119    }
120
121    pub fn with_jwt_auth(mut self) -> Self {
122        self.jwt_authenticated = true;
123        self
124    }
125
126    /// Get the session ID if available
127    pub fn session_id(&self) -> Option<&str> {
128        self.session.as_ref().map(|s| s.session_id.as_str())
129    }
130
131    /// Get the user ID from session or auth context
132    pub fn user_id(&self) -> Option<String> {
133        if let Some(session) = &self.session {
134            Some(session.user_id.clone())
135        } else if let Some(auth_context) = &self.base_context.auth.auth_context {
136            auth_context.api_key_id.clone()
137        } else {
138            None
139        }
140    }
141}
142
/// Session-aware MCP authentication middleware
///
/// Bundles the collaborators needed to validate, authenticate and attach
/// sessions to MCP requests. All collaborators are held behind `Arc`.
pub struct SessionMiddleware {
    /// Authentication manager
    ///
    /// Used to validate API keys.
    auth_manager: Arc<AuthenticationManager>,

    /// Session manager
    ///
    /// Used to validate JWT tokens and session IDs, and to create sessions.
    session_manager: Arc<SessionManager>,

    /// Security validator
    ///
    /// Validates and sanitizes every request before authentication runs.
    security_validator: Arc<RequestSecurityValidator>,

    /// Middleware configuration
    config: SessionMiddlewareConfig,
}
157
158impl SessionMiddleware {
159    /// Create new session middleware
160    pub fn new(
161        auth_manager: Arc<AuthenticationManager>,
162        session_manager: Arc<SessionManager>,
163        security_validator: Arc<RequestSecurityValidator>,
164        config: SessionMiddlewareConfig,
165    ) -> Self {
166        Self {
167            auth_manager,
168            session_manager,
169            security_validator,
170            config,
171        }
172    }
173
174    /// Create with default configuration
175    pub fn with_default_config(
176        auth_manager: Arc<AuthenticationManager>,
177        session_manager: Arc<SessionManager>,
178    ) -> Self {
179        Self::new(
180            auth_manager,
181            session_manager,
182            Arc::new(RequestSecurityValidator::default()),
183            SessionMiddlewareConfig::default(),
184        )
185    }
186
187    /// Process an incoming MCP request with session awareness
188    pub async fn process_request(
189        &self,
190        request: Request,
191        headers: Option<&HashMap<String, String>>,
192    ) -> Result<(Request, SessionRequestContext), McpError> {
193        // Step 1: Security validation (same as before)
194        if let Err(security_error) = self
195            .security_validator
196            .validate_request(&request, None)
197            .await
198        {
199            error!("Request security validation failed: {}", security_error);
200            return Err(McpError::invalid_request(&format!(
201                "Security validation failed: {}",
202                security_error
203            )));
204        }
205
206        let sanitized_request = self.security_validator.sanitize_request(request).await;
207
208        // Step 2: Extract request ID and create base context
209        let request_id = match &sanitized_request.id {
210            Some(id) => id.to_string(),
211            None => uuid::Uuid::new_v4().to_string(),
212        };
213
214        let mut base_context = McpRequestContext::new(request_id);
215        let mut session_context = SessionRequestContext::new(base_context.clone());
216
217        // Step 3: Extract client IP
218        if let Some(headers) = headers {
219            if let Some(ip_header) = &self.config.auth_config.client_ip_header {
220                if let Some(client_ip) = headers.get(ip_header) {
221                    base_context = base_context.with_client_ip(client_ip.clone());
222                }
223            }
224        }
225
226        // Step 4: Check if this method requires authentication/sessions
227        if self.should_skip_auth(&sanitized_request.method) {
228            debug!(
229                "Skipping authentication for method: {}",
230                sanitized_request.method
231            );
232            session_context.base_context = base_context;
233            return Ok((sanitized_request, session_context));
234        }
235
236        // Step 5: Try different authentication methods
237        let auth_result = self.authenticate_request(headers).await;
238
239        match auth_result {
240            Ok((auth_context, auth_method, session)) => {
241                // Authentication successful
242                base_context = base_context.with_auth(auth_context.clone(), auth_method.clone());
243
244                if auth_method.starts_with("JWT") {
245                    session_context = session_context.with_jwt_auth();
246                }
247
248                if let Some(session) = session {
249                    session_context = session_context.with_session(session, false);
250                } else if self.config.auto_create_sessions && !session_context.jwt_authenticated {
251                    // Auto-create session for API key authentication
252                    match self.create_auto_session(&auth_context, headers).await {
253                        Ok(session) => {
254                            session_context = session_context.with_session(session, true);
255                            info!(
256                                "Auto-created session for user: {:?}",
257                                auth_context.api_key_id
258                            );
259                        }
260                        Err(e) => {
261                            warn!("Failed to auto-create session: {}", e);
262                        }
263                    }
264                }
265
266                // Check method permissions
267                if let Err(e) = self
268                    .check_method_permissions(&sanitized_request.method, &base_context)
269                    .await
270                {
271                    error!("Method permission check failed: {}", e);
272                    return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
273                }
274
275                session_context.base_context = base_context;
276                debug!("Request authenticated successfully");
277                Ok((sanitized_request, session_context))
278            }
279            Err(e) => {
280                if self.config.auth_config.require_auth {
281                    warn!("Authentication failed: {}", e);
282                    Err(McpError::invalid_request(&format!(
283                        "Authentication required: {}",
284                        e
285                    )))
286                } else {
287                    debug!("Authentication failed but not required: {}", e);
288                    session_context.base_context = base_context;
289                    Ok((sanitized_request, session_context))
290                }
291            }
292        }
293    }
294
295    /// Authenticate request using multiple methods
296    async fn authenticate_request(
297        &self,
298        headers: Option<&HashMap<String, String>>,
299    ) -> Result<(AuthContext, String, Option<Session>), SessionMiddlewareError> {
300        if let Some(headers) = headers {
301            // Try JWT authentication first
302            if self.config.enable_jwt_auth {
303                if let Ok((auth_context, method)) = self.try_jwt_authentication(headers).await {
304                    return Ok((auth_context, method, None));
305                }
306            }
307
308            // Try session ID authentication
309            if self.config.enable_sessions {
310                if let Ok((auth_context, session)) = self.try_session_authentication(headers).await
311                {
312                    return Ok((auth_context, "Session".to_string(), Some(session)));
313                }
314            }
315
316            // Fall back to traditional API key authentication
317            if let Ok((auth_context, method)) = self.try_api_key_authentication(headers).await {
318                return Ok((auth_context, method, None));
319            }
320        }
321
322        Err(SessionMiddlewareError::AuthError(
323            AuthExtractionError::NoAuth,
324        ))
325    }
326
327    /// Try JWT token authentication
328    async fn try_jwt_authentication(
329        &self,
330        headers: &HashMap<String, String>,
331    ) -> Result<(AuthContext, String), SessionMiddlewareError> {
332        if let Some(auth_header) = headers.get(&self.config.jwt_header_name) {
333            if auth_header.starts_with("Bearer ") {
334                let token = &auth_header[7..];
335                let auth_context = self.session_manager.validate_jwt_token(token).await?;
336                return Ok((auth_context, "JWT".to_string()));
337            }
338        }
339
340        Err(SessionMiddlewareError::AuthError(
341            AuthExtractionError::NoAuth,
342        ))
343    }
344
345    /// Try session ID authentication
346    async fn try_session_authentication(
347        &self,
348        headers: &HashMap<String, String>,
349    ) -> Result<(AuthContext, Session), SessionMiddlewareError> {
350        if let Some(session_id) = headers.get(&self.config.session_header_name) {
351            let session = self.session_manager.validate_session(session_id).await?;
352            return Ok((session.auth_context.clone(), session));
353        }
354
355        Err(SessionMiddlewareError::AuthError(
356            AuthExtractionError::NoAuth,
357        ))
358    }
359
360    /// Try API key authentication
361    async fn try_api_key_authentication(
362        &self,
363        headers: &HashMap<String, String>,
364    ) -> Result<(AuthContext, String), SessionMiddlewareError> {
365        // Try Authorization header
366        if let Some(auth_header) = headers.get(&self.config.auth_config.auth_header_name) {
367            if let Ok((auth_context, method)) = self.parse_auth_header(auth_header).await {
368                return Ok((auth_context, method));
369            }
370        }
371
372        // Try X-API-Key header
373        if let Some(api_key) = headers.get("X-API-Key") {
374            if let Ok(auth_context) = self.validate_api_key(api_key).await {
375                return Ok((auth_context, "X-API-Key".to_string()));
376            }
377        }
378
379        Err(SessionMiddlewareError::AuthError(
380            AuthExtractionError::NoAuth,
381        ))
382    }
383
384    /// Parse Authorization header
385    async fn parse_auth_header(
386        &self,
387        auth_header: &str,
388    ) -> Result<(AuthContext, String), SessionMiddlewareError> {
389        let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
390        if parts.len() != 2 {
391            return Err(SessionMiddlewareError::AuthError(
392                AuthExtractionError::InvalidFormat(
393                    "Invalid Authorization header format".to_string(),
394                ),
395            ));
396        }
397
398        match parts[0] {
399            "Bearer" => {
400                let auth_context = self.validate_api_key(parts[1]).await?;
401                Ok((auth_context, "Bearer".to_string()))
402            }
403            "Basic" => {
404                use base64::{Engine as _, engine::general_purpose};
405                let decoded = general_purpose::STANDARD.decode(parts[1]).map_err(|_| {
406                    SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
407                        "Invalid Base64 in Basic auth".to_string(),
408                    ))
409                })?;
410
411                let decoded_str = String::from_utf8(decoded).map_err(|_| {
412                    SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
413                        "Invalid UTF-8 in Basic auth".to_string(),
414                    ))
415                })?;
416
417                let auth_parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
418                if auth_parts.is_empty() {
419                    return Err(SessionMiddlewareError::AuthError(
420                        AuthExtractionError::InvalidFormat(
421                            "Basic auth must contain username".to_string(),
422                        ),
423                    ));
424                }
425
426                let auth_context = self.validate_api_key(auth_parts[0]).await?;
427                Ok((auth_context, "Basic".to_string()))
428            }
429            _ => Err(SessionMiddlewareError::AuthError(
430                AuthExtractionError::UnsupportedMethod(parts[0].to_string()),
431            )),
432        }
433    }
434
435    /// Validate API key and return auth context
436    async fn validate_api_key(&self, api_key: &str) -> Result<AuthContext, SessionMiddlewareError> {
437        let auth_result = self
438            .auth_manager
439            .validate_api_key(api_key, None)
440            .await
441            .map_err(|e| {
442                SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(format!(
443                    "API key validation failed: {}",
444                    e
445                )))
446            })?;
447
448        auth_result.ok_or_else(|| {
449            SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
450                "Invalid API key".to_string(),
451            ))
452        })
453    }
454
455    /// Create automatic session for API key authentication
456    async fn create_auto_session(
457        &self,
458        auth_context: &AuthContext,
459        headers: Option<&HashMap<String, String>>,
460    ) -> Result<Session, SessionError> {
461        let client_ip = headers
462            .and_then(|h| {
463                self.config
464                    .auth_config
465                    .client_ip_header
466                    .as_ref()
467                    .and_then(|ip_header| h.get(ip_header))
468            })
469            .cloned();
470
471        let user_agent = headers.and_then(|h| h.get("User-Agent")).cloned();
472
473        let user_id = auth_context.api_key_id.clone().unwrap_or_else(|| {
474            auth_context
475                .user_id
476                .clone()
477                .unwrap_or_else(|| "unknown".to_string())
478        });
479
480        let (session, _) = self
481            .session_manager
482            .create_session(
483                user_id,
484                auth_context.clone(),
485                self.config.auto_session_duration,
486                client_ip,
487                user_agent,
488            )
489            .await?;
490
491        Ok(session)
492    }
493
494    /// Check if authentication should be skipped for this method
495    fn should_skip_auth(&self, method: &str) -> bool {
496        self.config
497            .auth_config
498            .anonymous_methods
499            .contains(&method.to_string())
500            || self
501                .config
502                .session_exempt_methods
503                .contains(&method.to_string())
504    }
505
    /// Check method-specific permissions (placeholder - would integrate with permission system)
    ///
    /// Both arguments are currently ignored and every call returns `Ok(())`,
    /// so no method is ever denied by this check. The `Err(String)` form is
    /// what the caller reports as an "Access denied" error.
    async fn check_method_permissions(
        &self,
        _method: &str,
        _context: &McpRequestContext,
    ) -> Result<(), String> {
        // This would integrate with the permission system
        // For now, just return Ok
        Ok(())
    }
516
517    /// Process response (add session headers if needed)
518    pub async fn process_response(
519        &self,
520        response: Response,
521        context: &SessionRequestContext,
522    ) -> Result<(Response, HashMap<String, String>), McpError> {
523        let mut response_headers = HashMap::new();
524
525        // Add session ID to response headers if session exists
526        if let Some(session) = &context.session {
527            response_headers.insert(
528                self.config.session_header_name.clone(),
529                session.session_id.clone(),
530            );
531
532            if context.auto_created_session {
533                response_headers.insert("X-Session-Created".to_string(), "true".to_string());
534            }
535        }
536
537        Ok((response, response_headers))
538    }
539
540    /// Get session manager for external access
541    pub fn session_manager(&self) -> &SessionManager {
542        &self.session_manager
543    }
544
545    /// Get authentication manager
546    pub fn auth_manager(&self) -> &AuthenticationManager {
547        &self.auth_manager
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AuthConfig,
        session::{MemorySessionStorage, SessionConfig},
    };

    /// Build a middleware with default configuration on top of an in-memory
    /// auth configuration and in-memory session storage.
    async fn create_test_middleware() -> SessionMiddleware {
        let auth_manager = crate::AuthenticationManager::new(AuthConfig::memory())
            .await
            .unwrap();

        let storage = Arc::new(MemorySessionStorage::new());
        let session_manager = SessionManager::new(SessionConfig::default(), storage);

        SessionMiddleware::with_default_config(Arc::new(auth_manager), Arc::new(session_manager))
    }

    /// The default configuration has session support switched on.
    #[tokio::test]
    async fn test_session_middleware_creation() {
        let middleware = create_test_middleware().await;

        // Construction succeeding with the expected default is all we check.
        assert!(middleware.config.enable_sessions);
    }

    /// A request for an exempt method with no headers passes through without
    /// a session and with an anonymous auth state.
    #[tokio::test]
    async fn test_anonymous_request_processing() {
        let middleware = create_test_middleware().await;

        // "initialize" is listed in the default `session_exempt_methods`.
        let initialize_request = Request {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: serde_json::json!({}),
            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(1)),
        };

        let outcome = middleware.process_request(initialize_request, None).await;
        assert!(outcome.is_ok());

        let (_, context) = outcome.unwrap();
        assert!(context.session.is_none());
        assert!(context.base_context.auth.is_anonymous);
    }
}
599}