1use crate::{
7 AuthContext, AuthenticationManager,
8 jwt::JwtError,
9 middleware::mcp_auth::{AuthExtractionError, McpAuthConfig, McpRequestContext},
10 security::RequestSecurityValidator,
11 session::{Session, SessionError, SessionManager},
12};
13use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
14use std::collections::HashMap;
15use std::sync::Arc;
16use thiserror::Error;
17use tracing::{debug, error, info, warn};
18
19#[derive(Debug, Error)]
21pub enum SessionMiddlewareError {
22 #[error("Session error: {0}")]
23 SessionError(#[from] SessionError),
24
25 #[error("Authentication error: {0}")]
26 AuthError(#[from] AuthExtractionError),
27
28 #[error("JWT validation failed: {0}")]
29 JwtError(#[from] JwtError),
30
31 #[error("Invalid session token format")]
32 InvalidTokenFormat,
33
34 #[error("Session required but not provided")]
35 SessionRequired,
36}
37
38#[derive(Debug, Clone)]
40pub struct SessionMiddlewareConfig {
41 pub auth_config: McpAuthConfig,
43
44 pub enable_sessions: bool,
46
47 pub require_sessions: bool,
49
50 pub enable_jwt_auth: bool,
52
53 pub jwt_header_name: String,
55
56 pub session_header_name: String,
58
59 pub auto_create_sessions: bool,
61
62 pub auto_session_duration: Option<chrono::Duration>,
64
65 pub extend_sessions_on_access: bool,
67
68 pub session_exempt_methods: Vec<String>,
70}
71
72impl Default for SessionMiddlewareConfig {
73 fn default() -> Self {
74 Self {
75 auth_config: McpAuthConfig::default(),
76 enable_sessions: true,
77 require_sessions: false, enable_jwt_auth: true,
79 jwt_header_name: "Authorization".to_string(),
80 session_header_name: "X-Session-ID".to_string(),
81 auto_create_sessions: true,
82 auto_session_duration: Some(chrono::Duration::hours(24)),
83 extend_sessions_on_access: true,
84 session_exempt_methods: vec!["initialize".to_string(), "ping".to_string()],
85 }
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct SessionRequestContext {
92 pub base_context: McpRequestContext,
94
95 pub session: Option<Session>,
97
98 pub jwt_authenticated: bool,
100
101 pub auto_created_session: bool,
103}
104
105impl SessionRequestContext {
106 pub fn new(base_context: McpRequestContext) -> Self {
107 Self {
108 base_context,
109 session: None,
110 jwt_authenticated: false,
111 auto_created_session: false,
112 }
113 }
114
115 pub fn with_session(mut self, session: Session, auto_created: bool) -> Self {
116 self.session = Some(session);
117 self.auto_created_session = auto_created;
118 self
119 }
120
121 pub fn with_jwt_auth(mut self) -> Self {
122 self.jwt_authenticated = true;
123 self
124 }
125
126 pub fn session_id(&self) -> Option<&str> {
128 self.session.as_ref().map(|s| s.session_id.as_str())
129 }
130
131 pub fn user_id(&self) -> Option<String> {
133 if let Some(session) = &self.session {
134 Some(session.user_id.clone())
135 } else if let Some(auth_context) = &self.base_context.auth.auth_context {
136 auth_context.api_key_id.clone()
137 } else {
138 None
139 }
140 }
141}
142
143pub struct SessionMiddleware {
145 auth_manager: Arc<AuthenticationManager>,
147
148 session_manager: Arc<SessionManager>,
150
151 security_validator: Arc<RequestSecurityValidator>,
153
154 config: SessionMiddlewareConfig,
156}
157
158impl SessionMiddleware {
159 pub fn new(
161 auth_manager: Arc<AuthenticationManager>,
162 session_manager: Arc<SessionManager>,
163 security_validator: Arc<RequestSecurityValidator>,
164 config: SessionMiddlewareConfig,
165 ) -> Self {
166 Self {
167 auth_manager,
168 session_manager,
169 security_validator,
170 config,
171 }
172 }
173
174 pub fn with_default_config(
176 auth_manager: Arc<AuthenticationManager>,
177 session_manager: Arc<SessionManager>,
178 ) -> Self {
179 Self::new(
180 auth_manager,
181 session_manager,
182 Arc::new(RequestSecurityValidator::default()),
183 SessionMiddlewareConfig::default(),
184 )
185 }
186
187 pub async fn process_request(
189 &self,
190 request: Request,
191 headers: Option<&HashMap<String, String>>,
192 ) -> Result<(Request, SessionRequestContext), McpError> {
193 if let Err(security_error) = self
195 .security_validator
196 .validate_request(&request, None)
197 .await
198 {
199 error!("Request security validation failed: {}", security_error);
200 return Err(McpError::invalid_request(&format!(
201 "Security validation failed: {}",
202 security_error
203 )));
204 }
205
206 let sanitized_request = self.security_validator.sanitize_request(request).await;
207
208 let request_id = match &sanitized_request.id {
210 Some(id) => id.to_string(),
211 None => uuid::Uuid::new_v4().to_string(),
212 };
213
214 let mut base_context = McpRequestContext::new(request_id);
215 let mut session_context = SessionRequestContext::new(base_context.clone());
216
217 if let Some(headers) = headers {
219 if let Some(ip_header) = &self.config.auth_config.client_ip_header {
220 if let Some(client_ip) = headers.get(ip_header) {
221 base_context = base_context.with_client_ip(client_ip.clone());
222 }
223 }
224 }
225
226 if self.should_skip_auth(&sanitized_request.method) {
228 debug!(
229 "Skipping authentication for method: {}",
230 sanitized_request.method
231 );
232 session_context.base_context = base_context;
233 return Ok((sanitized_request, session_context));
234 }
235
236 let auth_result = self.authenticate_request(headers).await;
238
239 match auth_result {
240 Ok((auth_context, auth_method, session)) => {
241 base_context = base_context.with_auth(auth_context.clone(), auth_method.clone());
243
244 if auth_method.starts_with("JWT") {
245 session_context = session_context.with_jwt_auth();
246 }
247
248 if let Some(session) = session {
249 session_context = session_context.with_session(session, false);
250 } else if self.config.auto_create_sessions && !session_context.jwt_authenticated {
251 match self.create_auto_session(&auth_context, headers).await {
253 Ok(session) => {
254 session_context = session_context.with_session(session, true);
255 info!(
256 "Auto-created session for user: {:?}",
257 auth_context.api_key_id
258 );
259 }
260 Err(e) => {
261 warn!("Failed to auto-create session: {}", e);
262 }
263 }
264 }
265
266 if let Err(e) = self
268 .check_method_permissions(&sanitized_request.method, &base_context)
269 .await
270 {
271 error!("Method permission check failed: {}", e);
272 return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
273 }
274
275 session_context.base_context = base_context;
276 debug!("Request authenticated successfully");
277 Ok((sanitized_request, session_context))
278 }
279 Err(e) => {
280 if self.config.auth_config.require_auth {
281 warn!("Authentication failed: {}", e);
282 Err(McpError::invalid_request(&format!(
283 "Authentication required: {}",
284 e
285 )))
286 } else {
287 debug!("Authentication failed but not required: {}", e);
288 session_context.base_context = base_context;
289 Ok((sanitized_request, session_context))
290 }
291 }
292 }
293 }
294
295 async fn authenticate_request(
297 &self,
298 headers: Option<&HashMap<String, String>>,
299 ) -> Result<(AuthContext, String, Option<Session>), SessionMiddlewareError> {
300 if let Some(headers) = headers {
301 if self.config.enable_jwt_auth {
303 if let Ok((auth_context, method)) = self.try_jwt_authentication(headers).await {
304 return Ok((auth_context, method, None));
305 }
306 }
307
308 if self.config.enable_sessions {
310 if let Ok((auth_context, session)) = self.try_session_authentication(headers).await
311 {
312 return Ok((auth_context, "Session".to_string(), Some(session)));
313 }
314 }
315
316 if let Ok((auth_context, method)) = self.try_api_key_authentication(headers).await {
318 return Ok((auth_context, method, None));
319 }
320 }
321
322 Err(SessionMiddlewareError::AuthError(
323 AuthExtractionError::NoAuth,
324 ))
325 }
326
327 async fn try_jwt_authentication(
329 &self,
330 headers: &HashMap<String, String>,
331 ) -> Result<(AuthContext, String), SessionMiddlewareError> {
332 if let Some(auth_header) = headers.get(&self.config.jwt_header_name) {
333 if auth_header.starts_with("Bearer ") {
334 let token = &auth_header[7..];
335 let auth_context = self.session_manager.validate_jwt_token(token).await?;
336 return Ok((auth_context, "JWT".to_string()));
337 }
338 }
339
340 Err(SessionMiddlewareError::AuthError(
341 AuthExtractionError::NoAuth,
342 ))
343 }
344
345 async fn try_session_authentication(
347 &self,
348 headers: &HashMap<String, String>,
349 ) -> Result<(AuthContext, Session), SessionMiddlewareError> {
350 if let Some(session_id) = headers.get(&self.config.session_header_name) {
351 let session = self.session_manager.validate_session(session_id).await?;
352 return Ok((session.auth_context.clone(), session));
353 }
354
355 Err(SessionMiddlewareError::AuthError(
356 AuthExtractionError::NoAuth,
357 ))
358 }
359
360 async fn try_api_key_authentication(
362 &self,
363 headers: &HashMap<String, String>,
364 ) -> Result<(AuthContext, String), SessionMiddlewareError> {
365 if let Some(auth_header) = headers.get(&self.config.auth_config.auth_header_name) {
367 if let Ok((auth_context, method)) = self.parse_auth_header(auth_header).await {
368 return Ok((auth_context, method));
369 }
370 }
371
372 if let Some(api_key) = headers.get("X-API-Key") {
374 if let Ok(auth_context) = self.validate_api_key(api_key).await {
375 return Ok((auth_context, "X-API-Key".to_string()));
376 }
377 }
378
379 Err(SessionMiddlewareError::AuthError(
380 AuthExtractionError::NoAuth,
381 ))
382 }
383
384 async fn parse_auth_header(
386 &self,
387 auth_header: &str,
388 ) -> Result<(AuthContext, String), SessionMiddlewareError> {
389 let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
390 if parts.len() != 2 {
391 return Err(SessionMiddlewareError::AuthError(
392 AuthExtractionError::InvalidFormat(
393 "Invalid Authorization header format".to_string(),
394 ),
395 ));
396 }
397
398 match parts[0] {
399 "Bearer" => {
400 let auth_context = self.validate_api_key(parts[1]).await?;
401 Ok((auth_context, "Bearer".to_string()))
402 }
403 "Basic" => {
404 use base64::{Engine as _, engine::general_purpose};
405 let decoded = general_purpose::STANDARD.decode(parts[1]).map_err(|_| {
406 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
407 "Invalid Base64 in Basic auth".to_string(),
408 ))
409 })?;
410
411 let decoded_str = String::from_utf8(decoded).map_err(|_| {
412 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
413 "Invalid UTF-8 in Basic auth".to_string(),
414 ))
415 })?;
416
417 let auth_parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
418 if auth_parts.is_empty() {
419 return Err(SessionMiddlewareError::AuthError(
420 AuthExtractionError::InvalidFormat(
421 "Basic auth must contain username".to_string(),
422 ),
423 ));
424 }
425
426 let auth_context = self.validate_api_key(auth_parts[0]).await?;
427 Ok((auth_context, "Basic".to_string()))
428 }
429 _ => Err(SessionMiddlewareError::AuthError(
430 AuthExtractionError::UnsupportedMethod(parts[0].to_string()),
431 )),
432 }
433 }
434
435 async fn validate_api_key(&self, api_key: &str) -> Result<AuthContext, SessionMiddlewareError> {
437 let auth_result = self
438 .auth_manager
439 .validate_api_key(api_key, None)
440 .await
441 .map_err(|e| {
442 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(format!(
443 "API key validation failed: {}",
444 e
445 )))
446 })?;
447
448 auth_result.ok_or_else(|| {
449 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
450 "Invalid API key".to_string(),
451 ))
452 })
453 }
454
455 async fn create_auto_session(
457 &self,
458 auth_context: &AuthContext,
459 headers: Option<&HashMap<String, String>>,
460 ) -> Result<Session, SessionError> {
461 let client_ip = headers
462 .and_then(|h| {
463 self.config
464 .auth_config
465 .client_ip_header
466 .as_ref()
467 .and_then(|ip_header| h.get(ip_header))
468 })
469 .cloned();
470
471 let user_agent = headers.and_then(|h| h.get("User-Agent")).cloned();
472
473 let user_id = auth_context.api_key_id.clone().unwrap_or_else(|| {
474 auth_context
475 .user_id
476 .clone()
477 .unwrap_or_else(|| "unknown".to_string())
478 });
479
480 let (session, _) = self
481 .session_manager
482 .create_session(
483 user_id,
484 auth_context.clone(),
485 self.config.auto_session_duration,
486 client_ip,
487 user_agent,
488 )
489 .await?;
490
491 Ok(session)
492 }
493
494 fn should_skip_auth(&self, method: &str) -> bool {
496 self.config
497 .auth_config
498 .anonymous_methods
499 .contains(&method.to_string())
500 || self
501 .config
502 .session_exempt_methods
503 .contains(&method.to_string())
504 }
505
506 async fn check_method_permissions(
508 &self,
509 _method: &str,
510 _context: &McpRequestContext,
511 ) -> Result<(), String> {
512 Ok(())
515 }
516
517 pub async fn process_response(
519 &self,
520 response: Response,
521 context: &SessionRequestContext,
522 ) -> Result<(Response, HashMap<String, String>), McpError> {
523 let mut response_headers = HashMap::new();
524
525 if let Some(session) = &context.session {
527 response_headers.insert(
528 self.config.session_header_name.clone(),
529 session.session_id.clone(),
530 );
531
532 if context.auto_created_session {
533 response_headers.insert("X-Session-Created".to_string(), "true".to_string());
534 }
535 }
536
537 Ok((response, response_headers))
538 }
539
540 pub fn session_manager(&self) -> &SessionManager {
542 &self.session_manager
543 }
544
545 pub fn auth_manager(&self) -> &AuthenticationManager {
547 &self.auth_manager
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AuthConfig,
        session::{MemorySessionStorage, SessionConfig},
    };

    /// Middleware with in-memory auth and session storage and the default config.
    async fn build_middleware() -> SessionMiddleware {
        let auth = crate::AuthenticationManager::new(AuthConfig::memory())
            .await
            .unwrap();
        let sessions = SessionManager::new(
            SessionConfig::default(),
            Arc::new(MemorySessionStorage::new()),
        );

        SessionMiddleware::with_default_config(Arc::new(auth), Arc::new(sessions))
    }

    #[tokio::test]
    async fn test_session_middleware_creation() {
        let mw = build_middleware().await;
        assert!(mw.config.enable_sessions);
    }

    #[tokio::test]
    async fn test_anonymous_request_processing() {
        let mw = build_middleware().await;

        // `initialize` is in the default exempt list, so no credentials are needed.
        let req = Request {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: serde_json::json!({}),
            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(1)),
        };

        let (_, ctx) = mw
            .process_request(req, None)
            .await
            .expect("exempt method should be processed without credentials");

        assert!(ctx.session.is_none());
        assert!(ctx.base_context.auth.is_anonymous);
    }
}