1use crate::{
7 jwt::JwtError,
8 middleware::mcp_auth::{AuthExtractionError, McpAuthConfig, McpRequestContext},
9 security::RequestSecurityValidator,
10 session::{Session, SessionError, SessionManager},
11 AuthContext, AuthenticationManager,
12};
13use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
14use std::collections::HashMap;
15use std::sync::Arc;
16use thiserror::Error;
17use tracing::{debug, error, info, warn};
18
/// Errors produced while authenticating a request through the session middleware.
///
/// The `#[from]` variants let `?` convert the wrapped error types automatically.
#[derive(Debug, Error)]
pub enum SessionMiddlewareError {
    /// A session-layer failure; wraps [`SessionError`].
    #[error("Session error: {0}")]
    SessionError(#[from] SessionError),

    /// Credentials were missing, malformed, or rejected; wraps [`AuthExtractionError`].
    #[error("Authentication error: {0}")]
    AuthError(#[from] AuthExtractionError),

    /// A JWT-layer failure; wraps [`JwtError`].
    #[error("JWT validation failed: {0}")]
    JwtError(#[from] JwtError),

    /// The session token was malformed.
    ///
    /// NOTE(review): not constructed anywhere in the portion of this file that was reviewed.
    #[error("Invalid session token format")]
    InvalidTokenFormat,

    /// A session was mandatory but the request did not carry one.
    ///
    /// NOTE(review): not constructed anywhere in the portion of this file that was reviewed.
    #[error("Session required but not provided")]
    SessionRequired,
}
37
/// Configuration for [`SessionMiddleware`].
#[derive(Debug, Clone)]
pub struct SessionMiddlewareConfig {
    /// Base authentication settings. This file reads `client_ip_header`,
    /// `anonymous_methods`, `require_auth` and `auth_header_name` from it.
    pub auth_config: McpAuthConfig,

    /// When `true`, requests may authenticate by presenting a session ID in
    /// `session_header_name`.
    pub enable_sessions: bool,

    /// NOTE(review): not read by any code in the reviewed portion of this file;
    /// presumably meant to reject authenticated requests without a session — confirm.
    pub require_sessions: bool,

    /// When `true`, a `Bearer` token in `jwt_header_name` is first tried as a JWT.
    pub enable_jwt_auth: bool,

    /// Name of the header inspected for a `Bearer <jwt>` value.
    pub jwt_header_name: String,

    /// Name of the header that carries the session ID, both on incoming requests
    /// and on the response headers produced by `process_response`.
    pub session_header_name: String,

    /// Consulted by `process_request` when an authenticated, non-JWT request
    /// arrives without a session, to decide whether one should be created.
    pub auto_create_sessions: bool,

    /// Duration passed to the session manager when a session is auto-created;
    /// `None` is forwarded unchanged.
    pub auto_session_duration: Option<chrono::Duration>,

    /// NOTE(review): not read by any code in the reviewed portion of this file;
    /// presumably meant to refresh a session's expiry on each use — confirm.
    pub extend_sessions_on_access: bool,

    /// Request methods for which authentication is skipped entirely
    /// (checked alongside `auth_config.anonymous_methods`).
    pub session_exempt_methods: Vec<String>,
}
71
72impl Default for SessionMiddlewareConfig {
73 fn default() -> Self {
74 Self {
75 auth_config: McpAuthConfig::default(),
76 enable_sessions: true,
77 require_sessions: false, enable_jwt_auth: true,
79 jwt_header_name: "Authorization".to_string(),
80 session_header_name: "X-Session-ID".to_string(),
81 auto_create_sessions: true,
82 auto_session_duration: Some(chrono::Duration::hours(24)),
83 extend_sessions_on_access: true,
84 session_exempt_methods: vec!["initialize".to_string(), "ping".to_string()],
85 }
86 }
87}
88
/// Per-request context produced by the session middleware: the base MCP request
/// context plus session-related state.
#[derive(Debug, Clone)]
pub struct SessionRequestContext {
    /// The underlying MCP request context (request ID, client IP, auth state).
    pub base_context: McpRequestContext,

    /// The session attached to this request, if any.
    pub session: Option<Session>,

    /// `true` when the request was authenticated via JWT.
    pub jwt_authenticated: bool,

    /// `true` when `session` was created by the middleware while handling this
    /// request rather than supplied by the client.
    pub auto_created_session: bool,
}
104
105impl SessionRequestContext {
106 pub fn new(base_context: McpRequestContext) -> Self {
107 Self {
108 base_context,
109 session: None,
110 jwt_authenticated: false,
111 auto_created_session: false,
112 }
113 }
114
115 pub fn with_session(mut self, session: Session, auto_created: bool) -> Self {
116 self.session = Some(session);
117 self.auto_created_session = auto_created;
118 self
119 }
120
121 pub fn with_jwt_auth(mut self) -> Self {
122 self.jwt_authenticated = true;
123 self
124 }
125
126 pub fn session_id(&self) -> Option<&str> {
128 self.session.as_ref().map(|s| s.session_id.as_str())
129 }
130
131 pub fn user_id(&self) -> Option<String> {
133 if let Some(session) = &self.session {
134 Some(session.user_id.clone())
135 } else if let Some(auth_context) = &self.base_context.auth.auth_context {
136 auth_context.api_key_id.clone()
137 } else {
138 None
139 }
140 }
141}
142
/// Middleware that validates, sanitizes and authenticates incoming MCP requests
/// and attaches session information to them.
pub struct SessionMiddleware {
    /// Validates API keys.
    auth_manager: Arc<AuthenticationManager>,

    /// Validates JWT tokens and session IDs, and creates sessions.
    session_manager: Arc<SessionManager>,

    /// Validates and sanitizes each request before authentication.
    security_validator: Arc<RequestSecurityValidator>,

    /// Behaviour switches and header names for this middleware instance.
    config: SessionMiddlewareConfig,
}
157
158impl SessionMiddleware {
159 pub fn new(
161 auth_manager: Arc<AuthenticationManager>,
162 session_manager: Arc<SessionManager>,
163 security_validator: Arc<RequestSecurityValidator>,
164 config: SessionMiddlewareConfig,
165 ) -> Self {
166 Self {
167 auth_manager,
168 session_manager,
169 security_validator,
170 config,
171 }
172 }
173
174 pub fn with_default_config(
176 auth_manager: Arc<AuthenticationManager>,
177 session_manager: Arc<SessionManager>,
178 ) -> Self {
179 Self::new(
180 auth_manager,
181 session_manager,
182 Arc::new(RequestSecurityValidator::default()),
183 SessionMiddlewareConfig::default(),
184 )
185 }
186
187 pub async fn process_request(
189 &self,
190 request: Request,
191 headers: Option<&HashMap<String, String>>,
192 ) -> Result<(Request, SessionRequestContext), McpError> {
193 if let Err(security_error) = self
195 .security_validator
196 .validate_request(&request, None)
197 .await
198 {
199 error!("Request security validation failed: {}", security_error);
200 return Err(McpError::invalid_request(&format!(
201 "Security validation failed: {}",
202 security_error
203 )));
204 }
205
206 let sanitized_request = self.security_validator.sanitize_request(request).await;
207
208 let request_id = match &sanitized_request.id {
210 serde_json::Value::String(s) => s.clone(),
211 serde_json::Value::Number(n) => n.to_string(),
212 serde_json::Value::Null => uuid::Uuid::new_v4().to_string(),
213 _ => uuid::Uuid::new_v4().to_string(),
214 };
215
216 let mut base_context = McpRequestContext::new(request_id);
217 let mut session_context = SessionRequestContext::new(base_context.clone());
218
219 if let Some(headers) = headers {
221 if let Some(ip_header) = &self.config.auth_config.client_ip_header {
222 if let Some(client_ip) = headers.get(ip_header) {
223 base_context = base_context.with_client_ip(client_ip.clone());
224 }
225 }
226 }
227
228 if self.should_skip_auth(&sanitized_request.method) {
230 debug!(
231 "Skipping authentication for method: {}",
232 sanitized_request.method
233 );
234 session_context.base_context = base_context;
235 return Ok((sanitized_request, session_context));
236 }
237
238 let auth_result = self.authenticate_request(headers).await;
240
241 match auth_result {
242 Ok((auth_context, auth_method, session)) => {
243 base_context = base_context.with_auth(auth_context.clone(), auth_method.clone());
245
246 if auth_method.starts_with("JWT") {
247 session_context = session_context.with_jwt_auth();
248 }
249
250 if let Some(session) = session {
251 session_context = session_context.with_session(session, false);
252 } else if self.config.auto_create_sessions && !session_context.jwt_authenticated {
253 match self.create_auto_session(&auth_context, headers).await {
255 Ok(session) => {
256 session_context = session_context.with_session(session, true);
257 info!(
258 "Auto-created session for user: {:?}",
259 auth_context.api_key_id
260 );
261 }
262 Err(e) => {
263 warn!("Failed to auto-create session: {}", e);
264 }
265 }
266 }
267
268 if let Err(e) = self
270 .check_method_permissions(&sanitized_request.method, &base_context)
271 .await
272 {
273 error!("Method permission check failed: {}", e);
274 return Err(McpError::invalid_request(&format!("Access denied: {}", e)));
275 }
276
277 session_context.base_context = base_context;
278 debug!("Request authenticated successfully");
279 Ok((sanitized_request, session_context))
280 }
281 Err(e) => {
282 if self.config.auth_config.require_auth {
283 warn!("Authentication failed: {}", e);
284 Err(McpError::invalid_request(&format!(
285 "Authentication required: {}",
286 e
287 )))
288 } else {
289 debug!("Authentication failed but not required: {}", e);
290 session_context.base_context = base_context;
291 Ok((sanitized_request, session_context))
292 }
293 }
294 }
295 }
296
297 async fn authenticate_request(
299 &self,
300 headers: Option<&HashMap<String, String>>,
301 ) -> Result<(AuthContext, String, Option<Session>), SessionMiddlewareError> {
302 if let Some(headers) = headers {
303 if self.config.enable_jwt_auth {
305 if let Ok((auth_context, method)) = self.try_jwt_authentication(headers).await {
306 return Ok((auth_context, method, None));
307 }
308 }
309
310 if self.config.enable_sessions {
312 if let Ok((auth_context, session)) = self.try_session_authentication(headers).await
313 {
314 return Ok((auth_context, "Session".to_string(), Some(session)));
315 }
316 }
317
318 if let Ok((auth_context, method)) = self.try_api_key_authentication(headers).await {
320 return Ok((auth_context, method, None));
321 }
322 }
323
324 Err(SessionMiddlewareError::AuthError(
325 AuthExtractionError::NoAuth,
326 ))
327 }
328
329 async fn try_jwt_authentication(
331 &self,
332 headers: &HashMap<String, String>,
333 ) -> Result<(AuthContext, String), SessionMiddlewareError> {
334 if let Some(auth_header) = headers.get(&self.config.jwt_header_name) {
335 if auth_header.starts_with("Bearer ") {
336 let token = &auth_header[7..];
337 let auth_context = self.session_manager.validate_jwt_token(token).await?;
338 return Ok((auth_context, "JWT".to_string()));
339 }
340 }
341
342 Err(SessionMiddlewareError::AuthError(
343 AuthExtractionError::NoAuth,
344 ))
345 }
346
347 async fn try_session_authentication(
349 &self,
350 headers: &HashMap<String, String>,
351 ) -> Result<(AuthContext, Session), SessionMiddlewareError> {
352 if let Some(session_id) = headers.get(&self.config.session_header_name) {
353 let session = self.session_manager.validate_session(session_id).await?;
354 return Ok((session.auth_context.clone(), session));
355 }
356
357 Err(SessionMiddlewareError::AuthError(
358 AuthExtractionError::NoAuth,
359 ))
360 }
361
362 async fn try_api_key_authentication(
364 &self,
365 headers: &HashMap<String, String>,
366 ) -> Result<(AuthContext, String), SessionMiddlewareError> {
367 if let Some(auth_header) = headers.get(&self.config.auth_config.auth_header_name) {
369 if let Ok((auth_context, method)) = self.parse_auth_header(auth_header).await {
370 return Ok((auth_context, method));
371 }
372 }
373
374 if let Some(api_key) = headers.get("X-API-Key") {
376 if let Ok(auth_context) = self.validate_api_key(api_key).await {
377 return Ok((auth_context, "X-API-Key".to_string()));
378 }
379 }
380
381 Err(SessionMiddlewareError::AuthError(
382 AuthExtractionError::NoAuth,
383 ))
384 }
385
386 async fn parse_auth_header(
388 &self,
389 auth_header: &str,
390 ) -> Result<(AuthContext, String), SessionMiddlewareError> {
391 let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
392 if parts.len() != 2 {
393 return Err(SessionMiddlewareError::AuthError(
394 AuthExtractionError::InvalidFormat(
395 "Invalid Authorization header format".to_string(),
396 ),
397 ));
398 }
399
400 match parts[0] {
401 "Bearer" => {
402 let auth_context = self.validate_api_key(parts[1]).await?;
403 Ok((auth_context, "Bearer".to_string()))
404 }
405 "Basic" => {
406 use base64::{engine::general_purpose, Engine as _};
407 let decoded = general_purpose::STANDARD.decode(parts[1]).map_err(|_| {
408 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
409 "Invalid Base64 in Basic auth".to_string(),
410 ))
411 })?;
412
413 let decoded_str = String::from_utf8(decoded).map_err(|_| {
414 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
415 "Invalid UTF-8 in Basic auth".to_string(),
416 ))
417 })?;
418
419 let auth_parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
420 if auth_parts.is_empty() {
421 return Err(SessionMiddlewareError::AuthError(
422 AuthExtractionError::InvalidFormat(
423 "Basic auth must contain username".to_string(),
424 ),
425 ));
426 }
427
428 let auth_context = self.validate_api_key(auth_parts[0]).await?;
429 Ok((auth_context, "Basic".to_string()))
430 }
431 _ => Err(SessionMiddlewareError::AuthError(
432 AuthExtractionError::UnsupportedMethod(parts[0].to_string()),
433 )),
434 }
435 }
436
437 async fn validate_api_key(&self, api_key: &str) -> Result<AuthContext, SessionMiddlewareError> {
439 let auth_result = self
440 .auth_manager
441 .validate_api_key(api_key, None)
442 .await
443 .map_err(|e| {
444 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(format!(
445 "API key validation failed: {}",
446 e
447 )))
448 })?;
449
450 auth_result.ok_or_else(|| {
451 SessionMiddlewareError::AuthError(AuthExtractionError::InvalidFormat(
452 "Invalid API key".to_string(),
453 ))
454 })
455 }
456
457 async fn create_auto_session(
459 &self,
460 auth_context: &AuthContext,
461 headers: Option<&HashMap<String, String>>,
462 ) -> Result<Session, SessionError> {
463 let client_ip = headers
464 .and_then(|h| {
465 self.config
466 .auth_config
467 .client_ip_header
468 .as_ref()
469 .and_then(|ip_header| h.get(ip_header))
470 })
471 .cloned();
472
473 let user_agent = headers.and_then(|h| h.get("User-Agent")).cloned();
474
475 let user_id = auth_context.api_key_id.clone().unwrap_or_else(|| {
476 auth_context
477 .user_id
478 .clone()
479 .unwrap_or_else(|| "unknown".to_string())
480 });
481
482 let (session, _) = self
483 .session_manager
484 .create_session(
485 user_id,
486 auth_context.clone(),
487 self.config.auto_session_duration,
488 client_ip,
489 user_agent,
490 )
491 .await?;
492
493 Ok(session)
494 }
495
496 fn should_skip_auth(&self, method: &str) -> bool {
498 self.config
499 .auth_config
500 .anonymous_methods
501 .contains(&method.to_string())
502 || self
503 .config
504 .session_exempt_methods
505 .contains(&method.to_string())
506 }
507
508 async fn check_method_permissions(
510 &self,
511 _method: &str,
512 _context: &McpRequestContext,
513 ) -> Result<(), String> {
514 Ok(())
517 }
518
519 pub async fn process_response(
521 &self,
522 response: Response,
523 context: &SessionRequestContext,
524 ) -> Result<(Response, HashMap<String, String>), McpError> {
525 let mut response_headers = HashMap::new();
526
527 if let Some(session) = &context.session {
529 response_headers.insert(
530 self.config.session_header_name.clone(),
531 session.session_id.clone(),
532 );
533
534 if context.auto_created_session {
535 response_headers.insert("X-Session-Created".to_string(), "true".to_string());
536 }
537 }
538
539 Ok((response, response_headers))
540 }
541
542 pub fn session_manager(&self) -> &SessionManager {
544 &self.session_manager
545 }
546
547 pub fn auth_manager(&self) -> &AuthenticationManager {
549 &self.auth_manager
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        session::{MemorySessionStorage, SessionConfig},
        AuthConfig,
    };

    /// Builds a middleware with default configuration, backed by in-memory
    /// authentication and session storage.
    async fn create_test_middleware() -> SessionMiddleware {
        let auth = crate::AuthenticationManager::new(AuthConfig::memory())
            .await
            .unwrap();
        let sessions = SessionManager::new(
            SessionConfig::default(),
            Arc::new(MemorySessionStorage::new()),
        );

        SessionMiddleware::with_default_config(Arc::new(auth), Arc::new(sessions))
    }

    /// The default configuration has sessions enabled.
    #[tokio::test]
    async fn test_session_middleware_creation() {
        let middleware = create_test_middleware().await;
        let sessions_enabled = middleware.config.enable_sessions;

        assert!(sessions_enabled);
    }

    /// `initialize` is exempt from authentication, so a request without any
    /// headers succeeds and yields an anonymous, session-less context.
    #[tokio::test]
    async fn test_anonymous_request_processing() {
        let middleware = create_test_middleware().await;

        let request = Request {
            jsonrpc: String::from("2.0"),
            method: String::from("initialize"),
            params: serde_json::json!({}),
            id: serde_json::Value::Number(1.into()),
        };

        let outcome = middleware.process_request(request, None).await;
        assert!(outcome.is_ok());

        let context = match outcome {
            Ok((_, context)) => context,
            Err(_) => unreachable!(),
        };
        assert!(context.session.is_none());
        assert!(context.base_context.auth.is_anonymous);
    }
}