pulseengine_mcp_auth/middleware/
session_middleware.rs

//! Session-aware MCP authentication middleware.
//!
//! Extends the basic MCP authentication middleware with session management,
//! JWT token validation, and additional request-security checks.

6use crate::{
7    AuthContext, AuthenticationManager,
8    jwt::JwtError,
9    middleware::mcp_auth::{AuthExtractionError, McpAuthConfig, McpRequestContext},
10    security::RequestSecurityValidator,
11    session::{Session, SessionError, SessionManager},
12};
13use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
14use std::collections::HashMap;
15use std::sync::Arc;
16use thiserror::Error;
17use tracing::{debug, error, info, warn};
18
19/// Errors specific to session middleware
20#[derive(Debug, Error)]
21pub enum SessionMiddlewareError {
22    #[error("Session error: {0}")]
23    SessionError(#[from] SessionError),
24
25    #[error("Authentication error: {0}")]
26    AuthError(#[from] AuthExtractionError),
27
28    #[error("JWT validation failed: {0}")]
29    JwtError(#[from] JwtError),
30
31    #[error("Invalid session token format")]
32    InvalidTokenFormat,
33
34    #[error("Session required but not provided")]
35    SessionRequired,
36}
37
38/// Enhanced configuration for session-aware middleware
39#[derive(Debug, Clone)]
40pub struct SessionMiddlewareConfig {
41    /// Base MCP auth configuration
42    pub auth_config: McpAuthConfig,
43
44    /// Enable session management
45    pub enable_sessions: bool,
46
47    /// Require sessions for authenticated requests
48    pub require_sessions: bool,
49
50    /// Enable JWT token authentication
51    pub enable_jwt_auth: bool,
52
53    /// JWT token header name
54    pub jwt_header_name: String,
55
56    /// Session ID header name
57    pub session_header_name: String,
58
59    /// Enable automatic session creation for API keys
60    pub auto_create_sessions: bool,
61
62    /// Session duration for auto-created sessions
63    pub auto_session_duration: Option<chrono::Duration>,
64
65    /// Enable session extension on access
66    pub extend_sessions_on_access: bool,
67
68    /// Methods that bypass session requirements
69    pub session_exempt_methods: Vec<String>,
70}
71
72impl Default for SessionMiddlewareConfig {
73    fn default() -> Self {
74        Self {
75            auth_config: McpAuthConfig::default(),
76            enable_sessions: true,
77            require_sessions: false, // Optional by default
78            enable_jwt_auth: true,
79            jwt_header_name: "Authorization".to_string(),
80            session_header_name: "X-Session-ID".to_string(),
81            auto_create_sessions: true,
82            auto_session_duration: Some(chrono::Duration::hours(24)),
83            extend_sessions_on_access: true,
84            session_exempt_methods: vec!["initialize".to_string(), "ping".to_string()],
85        }
86    }
87}
88
89/// Enhanced request context with session information
90#[derive(Debug, Clone)]
91pub struct SessionRequestContext {
92    /// Base request context
93    pub base_context: McpRequestContext,
94
95    /// Active session (if any)
96    pub session: Option<Session>,
97
98    /// Whether request used JWT authentication
99    pub jwt_authenticated: bool,
100
101    /// Session was created automatically
102    pub auto_created_session: bool,
103}
104
105impl SessionRequestContext {
106    pub fn new(base_context: McpRequestContext) -> Self {
107        Self {
108            base_context,
109            session: None,
110            jwt_authenticated: false,
111            auto_created_session: false,
112        }
113    }
114
115    pub fn with_session(mut self, session: Session, auto_created: bool) -> Self {
116        self.session = Some(session);
117        self.auto_created_session = auto_created;
118        self
119    }
120
121    pub fn with_jwt_auth(mut self) -> Self {
122        self.jwt_authenticated = true;
123        self
124    }
125
126    /// Get the session ID if available
127    pub fn session_id(&self) -> Option<&str> {
128        self.session.as_ref().map(|s| s.session_id.as_str())
129    }
130
131    /// Get the user ID from session or auth context
132    pub fn user_id(&self) -> Option<String> {
133        if let Some(session) = &self.session {
134            Some(session.user_id.clone())
135        } else if let Some(auth_context) = &self.base_context.auth.auth_context {
136            auth_context.api_key_id.clone()
137        } else {
138            None
139        }
140    }
141}
142
143/// Session-aware MCP authentication middleware
144pub struct SessionMiddleware {
145    /// Authentication manager
146    auth_manager: Arc<AuthenticationManager>,
147
148    /// Session manager
149    session_manager: Arc<SessionManager>,
150
151    /// Security validator
152    security_validator: Arc<RequestSecurityValidator>,
153
154    /// Middleware configuration
155    config: SessionMiddlewareConfig,
156}
157
158impl SessionMiddleware {
159    /// Create new session middleware
160    pub fn new(
161        auth_manager: Arc<AuthenticationManager>,
162        session_manager: Arc<SessionManager>,
163        security_validator: Arc<RequestSecurityValidator>,
164        config: SessionMiddlewareConfig,
165    ) -> Self {
166        Self {
167            auth_manager,
168            session_manager,
169            security_validator,
170            config,
171        }
172    }
173
174    /// Create with default configuration
175    pub fn with_default_config(
176        auth_manager: Arc<AuthenticationManager>,
177        session_manager: Arc<SessionManager>,
178    ) -> Self {
179        Self::new(
180            auth_manager,
181            session_manager,
182            Arc::new(RequestSecurityValidator::default()),
183            SessionMiddlewareConfig::default(),
184        )
185    }
186
187    /// Process an incoming MCP request with session awareness
188    pub async fn process_request(
189        &self,
190        request: Request,
191        headers: Option<&HashMap<String, String>>,
192    ) -> Result<(Request, SessionRequestContext), McpError> {
193        // Step 1: Security validation (same as before)
194        if let Err(security_error) = self
195            .security_validator
196            .validate_request(&request, None)
197            .await
198        {
199            error!("Request security validation failed: {}", security_error);
200            return Err(McpError::invalid_request(&format!(
201                "Security validation failed: {}",
202                security_error
203            )));
204        }
205
206        let sanitized_request = self.security_validator.sanitize_request(request).await;
207
208        // Step 2: Extract request ID and create base context
209        let request_id = match &sanitized_request.id {
210            serde_json::Value::String(s) => s.clone(),
211            serde_json::Value::Number(n) => n.to_string(),
212            serde_json::Value::Null => uuid::Uuid::new_v4().to_string(),
213            _ => uuid::Uuid::new_v4().to_string(),
214        };
215
216        let mut base_context = McpRequestContext::new(request_id);
217        let mut session_context = SessionRequestContext::new(base_context.clone());
218
219        // Step 3: Extract client IP
220        if let Some(headers) = headers {
221            if let Some(ip_header) = &self.config.auth_config.client_ip_header {
222                if let Some(client_ip) = headers.get(ip_header) {
223                    base_context = base_context.with_client_ip(client_ip.clone());
224                }
225            }
226        }
227
228        // Step 4: Check if this method requires authentication/sessions
229        if self.should_skip_auth(&sanitized_request.method) {
230            debug!(
231                "Skipping authentication for method: {}",
232                sanitized_request.method
233            );
234            session_context.base_context = base_context;
235            return Ok((sanitized_request, session_context));
236        }
237
238        // Step 5: Try different authentication methods
239        let auth_result = self.authenticate_request(headers).await;
240
241        match auth_result {
242            Ok((auth_context, auth_method, session)) => {
243                // Authentication successful
244                base_context = base_context.with_auth(auth_context.clone(), auth_method.clone());
245
246                if auth_method.starts_with("JWT") {
247                    session_context = session_context.with_jwt_auth();
248                }
249
250                if let Some(session) = session {
251                    session_context = session_context.with_session(session, false);
252                } else if self.config.auto_create_sessions && !session_context.jwt_authenticated {
253                    // Auto-create session for API key authentication
254                    match self.create_auto_session(&auth_context, headers).await {
255                        Ok(session) => {
256                            session_context = session_context.with_session(session, true);
257                            info!(
258                                "Auto-created session for user: {:?}",
259                                auth_context.api_key_id
260                            );
261                        }
262                        Err(e) => {
263                            warn!("Failed to auto-create session: {}", e);
264                        }
265                    }
266                }
267
268                // Check method permissions
269                if let Err(e) = self
270                    .check_method_permissions(&sanitized_request.method, &base_context)
271                    .await
272                {
273                    error!("Method permission check failed: {}", e);
274                    return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
275                }
276
277                session_context.base_context = base_context;
278                debug!("Request authenticated successfully");
279                Ok((sanitized_request, session_context))
280            }
281            Err(e) => {
282                if self.config.auth_config.require_auth {
283                    warn!("Authentication failed: {}", e);
284                    Err(McpError::invalid_request(&format!(
285                        "Authentication required: {}",
286                        e
287                    )))
288                } else {
289                    debug!("Authentication failed but not required: {}", e);
290                    session_context.base_context = base_context;
291                    Ok((sanitized_request, session_context))
292                }
293            }
294        }
295    }
296
297    /// Authenticate request using multiple methods
298    async fn authenticate_request(
299        &self,
300        headers: Option<&HashMap<String, String>>,
301    ) -> Result<(AuthContext, String, Option<Session>), SessionMiddlewareError> {
302        if let Some(headers) = headers {
303            // Try JWT authentication first
304            if self.config.enable_jwt_auth {
305                if let Ok((auth_context, method)) = self.try_jwt_authentication(headers).await {
306                    return Ok((auth_context, method, None));
307                }
308            }
309
310            // Try session ID authentication
311            if self.config.enable_sessions {
312                if let Ok((auth_context, session)) = self.try_session_authentication(headers).await
313                {
314                    return Ok((auth_context, "Session".to_string(), Some(session)));
315                }
316            }
317
318            // Fall back to traditional API key authentication
319            if let Ok((auth_context, method)) = self.try_api_key_authentication(headers).await {
320                return Ok((auth_context, method, None));
321            }
322        }
323
324        Err(SessionMiddlewareError::AuthError(
325            AuthExtractionError::NoAuth,
326        ))
327    }
328
329    /// Try JWT token authentication
330    async fn try_jwt_authentication(
331        &self,
332        headers: &HashMap<String, String>,
333    ) -> Result<(AuthContext, String), SessionMiddlewareError> {
334        if let Some(auth_header) = headers.get(&self.config.jwt_header_name) {
335            if auth_header.starts_with("Bearer ") {
336                let token = &auth_header[7..];
337                let auth_context = self.session_manager.validate_jwt_token(token).await?;
338                return Ok((auth_context, "JWT".to_string()));
339            }
340        }
341
342        Err(SessionMiddlewareError::AuthError(
343            AuthExtractionError::NoAuth,
344        ))
345    }
346
347    /// Try session ID authentication
348    async fn try_session_authentication(
349        &self,
350        headers: &HashMap<String, String>,
351    ) -> Result<(AuthContext, Session), SessionMiddlewareError> {
352        if let Some(session_id) = headers.get(&self.config.session_header_name) {
353            let session = self.session_manager.validate_session(session_id).await?;
354            return Ok((session.auth_context.clone(), session));
355        }
356
357        Err(SessionMiddlewareError::AuthError(
358            AuthExtractionError::NoAuth,
359        ))
360    }
361
362    /// Try API key authentication
363    async fn try_api_key_authentication(
364        &self,
365        headers: &HashMap<String, String>,
366    ) -> Result<(AuthContext, String), SessionMiddlewareError> {
367        // Try Authorization header
368        if let Some(auth_header) = headers.get(&self.config.auth_config.auth_header_name) {
369            if let Ok((auth_context, method)) = self.parse_auth_header(auth_header).await {
370                return Ok((auth_context, method));
371            }
372        }
373
374        // Try X-API-Key header
375        if let Some(api_key) = headers.get("X-API-Key") {
376            if let Ok(auth_context) = self.validate_api_key(api_key).await {
377                return Ok((auth_context, "X-API-Key".to_string()));
378            }
379        }
380
381        Err(SessionMiddlewareError::AuthError(
382            AuthExtractionError::NoAuth,
383        ))
384    }
385
386    /// Parse Authorization header
387    async fn parse_auth_header(
388        &self,
389        auth_header: &str,
390    ) -> Result<(AuthContext, String), SessionMiddlewareError> {
391        let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
392        if parts.len() != 2 {
393            return Err(SessionMiddlewareError::AuthError(
394                AuthExtractionError::InvalidFormat(
395                    "Invalid Authorization header format".to_string(),
396                ),
397            ));
398        }
399
400        match parts[0] {
401            "Bearer" => {
402                let auth_context = self.validate_api_key(parts[1]).await?;
403                Ok((auth_context, "Bearer".to_string()))
404            }
405            "Basic" => {
406                use base64::{Engine as _, engine::general_purpose};
407                let decoded = general_purpose::STANDARD.decode(parts[1]).map_err(|_| {
408                    SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
409                        "Invalid Base64 in Basic auth".to_string(),
410                    ))
411                })?;
412
413                let decoded_str = String::from_utf8(decoded).map_err(|_| {
414                    SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
415                        "Invalid UTF-8 in Basic auth".to_string(),
416                    ))
417                })?;
418
419                let auth_parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
420                if auth_parts.is_empty() {
421                    return Err(SessionMiddlewareError::AuthError(
422                        AuthExtractionError::InvalidFormat(
423                            "Basic auth must contain username".to_string(),
424                        ),
425                    ));
426                }
427
428                let auth_context = self.validate_api_key(auth_parts[0]).await?;
429                Ok((auth_context, "Basic".to_string()))
430            }
431            _ => Err(SessionMiddlewareError::AuthError(
432                AuthExtractionError::UnsupportedMethod(parts[0].to_string()),
433            )),
434        }
435    }
436
437    /// Validate API key and return auth context
438    async fn validate_api_key(&self, api_key: &str) -> Result<AuthContext, SessionMiddlewareError> {
439        let auth_result = self
440            .auth_manager
441            .validate_api_key(api_key, None)
442            .await
443            .map_err(|e| {
444                SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(format!(
445                    "API key validation failed: {}",
446                    e
447                )))
448            })?;
449
450        auth_result.ok_or_else(|| {
451            SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
452                "Invalid API key".to_string(),
453            ))
454        })
455    }
456
457    /// Create automatic session for API key authentication
458    async fn create_auto_session(
459        &self,
460        auth_context: &AuthContext,
461        headers: Option<&HashMap<String, String>>,
462    ) -> Result<Session, SessionError> {
463        let client_ip = headers
464            .and_then(|h| {
465                self.config
466                    .auth_config
467                    .client_ip_header
468                    .as_ref()
469                    .and_then(|ip_header| h.get(ip_header))
470            })
471            .cloned();
472
473        let user_agent = headers.and_then(|h| h.get("User-Agent")).cloned();
474
475        let user_id = auth_context.api_key_id.clone().unwrap_or_else(|| {
476            auth_context
477                .user_id
478                .clone()
479                .unwrap_or_else(|| "unknown".to_string())
480        });
481
482        let (session, _) = self
483            .session_manager
484            .create_session(
485                user_id,
486                auth_context.clone(),
487                self.config.auto_session_duration,
488                client_ip,
489                user_agent,
490            )
491            .await?;
492
493        Ok(session)
494    }
495
496    /// Check if authentication should be skipped for this method
497    fn should_skip_auth(&self, method: &str) -> bool {
498        self.config
499            .auth_config
500            .anonymous_methods
501            .contains(&method.to_string())
502            || self
503                .config
504                .session_exempt_methods
505                .contains(&method.to_string())
506    }
507
508    /// Check method-specific permissions (placeholder - would integrate with permission system)
509    async fn check_method_permissions(
510        &self,
511        _method: &str,
512        _context: &McpRequestContext,
513    ) -> Result<(), String> {
514        // This would integrate with the permission system
515        // For now, just return Ok
516        Ok(())
517    }
518
519    /// Process response (add session headers if needed)
520    pub async fn process_response(
521        &self,
522        response: Response,
523        context: &SessionRequestContext,
524    ) -> Result<(Response, HashMap<String, String>), McpError> {
525        let mut response_headers = HashMap::new();
526
527        // Add session ID to response headers if session exists
528        if let Some(session) = &context.session {
529            response_headers.insert(
530                self.config.session_header_name.clone(),
531                session.session_id.clone(),
532            );
533
534            if context.auto_created_session {
535                response_headers.insert("X-Session-Created".to_string(), "true".to_string());
536            }
537        }
538
539        Ok((response, response_headers))
540    }
541
542    /// Get session manager for external access
543    pub fn session_manager(&self) -> &SessionManager {
544        &self.session_manager
545    }
546
547    /// Get authentication manager
548    pub fn auth_manager(&self) -> &AuthenticationManager {
549        &self.auth_manager
550    }
551}
552
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AuthConfig,
        session::{MemorySessionStorage, SessionConfig},
    };

    /// Middleware backed by in-memory authentication and session storage.
    async fn create_test_middleware() -> SessionMiddleware {
        let auth_manager = crate::AuthenticationManager::new(AuthConfig::memory())
            .await
            .map(Arc::new)
            .unwrap();
        let session_manager = Arc::new(SessionManager::new(
            SessionConfig::default(),
            Arc::new(MemorySessionStorage::new()),
        ));

        SessionMiddleware::with_default_config(auth_manager, session_manager)
    }

    #[tokio::test]
    async fn test_session_middleware_creation() {
        let middleware = create_test_middleware().await;

        // The default configuration enables session management.
        assert!(middleware.config.enable_sessions);
    }

    #[tokio::test]
    async fn test_anonymous_request_processing() {
        let middleware = create_test_middleware().await;

        // `initialize` needs no credentials, so no headers are supplied.
        let request = Request {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: serde_json::json!({}),
            id: serde_json::Value::Number(1.into()),
        };

        let (_, context) = middleware
            .process_request(request, None)
            .await
            .expect("anonymous method should be accepted without credentials");

        assert!(context.session.is_none());
        assert!(context.base_context.auth.is_anonymous);
    }
}