pulseengine_mcp_transport/
http.rs

1//! HTTP transport with Server-Sent Events (SSE) support
2
3use crate::{
4    RequestHandler, Transport, TransportError,
5    batch::{JsonRpcMessage, process_batch},
6    validation::validate_message_string,
7};
8use async_trait::async_trait;
9use axum::response::sse::{Event, KeepAlive};
10use axum::{
11    Router,
12    extract::{Query, State},
13    http::{
14        HeaderMap, StatusCode,
15        header::{AUTHORIZATION, ORIGIN},
16    },
17    response::{IntoResponse, Response as AxumResponse, Sse},
18    routing::{get, post},
19};
20// futures_util used for async_stream
21// mcp_protocol types are imported via batch module
22use serde::Deserialize;
23use serde_json;
24use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
25use tokio::sync::{Mutex, RwLock, broadcast};
26use tower::ServiceBuilder;
27use tower_http::cors::CorsLayer;
28use tracing::{debug, error, info, warn};
29use uuid::Uuid;
30
/// Configuration for [`HttpTransport`].
///
/// Usually built with struct-update syntax on top of [`HttpConfig::default`],
/// e.g. `HttpConfig { port: 8080, ..Default::default() }`.
///
/// `Debug` output redacts `valid_tokens` so configurations can be logged
/// without leaking credentials.
#[derive(Clone, PartialEq, Eq)]
pub struct HttpConfig {
    /// Port to bind to.
    pub port: u16,
    /// Host to bind to, as an IP literal (hostnames are not resolved).
    /// The default is the IPv4 loopback address `127.0.0.1`.
    pub host: String,
    /// Maximum message size in bytes, enforced when `validate_messages` is on.
    pub max_message_size: usize,
    /// Enable CORS.
    pub enable_cors: bool,
    /// Allowed `Origin` header values. `None` disables origin checking; when
    /// `Some`, requests without an `Origin` header are rejected.
    pub allowed_origins: Option<Vec<String>>,
    /// Validate incoming messages before dispatching them.
    pub validate_messages: bool,
    /// Idle time, in seconds, after which a session is discarded.
    pub session_timeout_secs: u64,
    /// Require an `Authorization: Bearer <token>` header on every request.
    pub require_auth: bool,
    /// Bearer tokens accepted when `require_auth` is set.
    pub valid_tokens: Vec<String>,
}

impl std::fmt::Debug for HttpConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpConfig")
            .field("port", &self.port)
            .field("host", &self.host)
            .field("max_message_size", &self.max_message_size)
            .field("enable_cors", &self.enable_cors)
            .field("allowed_origins", &self.allowed_origins)
            .field("validate_messages", &self.validate_messages)
            .field("session_timeout_secs", &self.session_timeout_secs)
            .field("require_auth", &self.require_auth)
            // Only the count is shown; the tokens themselves are secrets.
            .field(
                "valid_tokens",
                &format_args!("<{} redacted>", self.valid_tokens.len()),
            )
            .finish()
    }
}
53
54impl Default for HttpConfig {
55    fn default() -> Self {
56        Self {
57            port: 3000,
58            host: "127.0.0.1".to_string(),
59            max_message_size: 10 * 1024 * 1024, // 10MB
60            enable_cors: true,
61            allowed_origins: None,
62            validate_messages: true,
63            session_timeout_secs: 300, // 5 minutes
64            require_auth: false,
65            valid_tokens: vec![],
66        }
67    }
68}
69
70/// Session information for HTTP clients
71#[derive(Clone)]
72struct SessionInfo {
73    #[allow(dead_code)]
74    id: String,
75    #[allow(dead_code)]
76    created_at: std::time::Instant,
77    last_activity: std::time::Instant,
78    event_sender: broadcast::Sender<String>,
79    // Keep at least one receiver alive to prevent channel closure
80    #[allow(dead_code)]
81    _keepalive_receiver: Arc<Mutex<broadcast::Receiver<String>>>,
82}
83
84/// Shared state for HTTP transport
85#[derive(Clone)]
86struct HttpState {
87    handler: Arc<RequestHandler>,
88    config: HttpConfig,
89    sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
90}
91
92/// Query parameters for SSE endpoint
93#[derive(Debug, Deserialize)]
94struct SseQuery {
95    #[serde(rename = "sessionId")]
96    session_id: Option<String>,
97    #[serde(rename = "lastEventId")]
98    #[allow(dead_code)]
99    last_event_id: Option<String>,
100    #[serde(rename = "transportType")]
101    #[allow(dead_code)]
102    transport_type: Option<String>,
103    #[allow(dead_code)]
104    url: Option<String>,
105}
106
107/// HTTP transport for MCP protocol
108///
109/// Implements the MCP HTTP transport specification:
110/// - HTTP POST for client-to-server messages
111/// - Server-Sent Events (SSE) for server-to-client streaming
112/// - Session management with Mcp-Session-Id header
113/// - Origin validation and CORS support
114/// - Authentication support
115pub struct HttpTransport {
116    config: HttpConfig,
117    state: Option<HttpState>,
118    server_handle: Option<tokio::task::JoinHandle<()>>,
119}
120
121impl HttpTransport {
122    /// Create a new HTTP transport with default configuration
123    pub fn new(port: u16) -> Self {
124        let config = HttpConfig {
125            port,
126            ..Default::default()
127        };
128
129        Self {
130            config,
131            state: None,
132            server_handle: None,
133        }
134    }
135
136    /// Get the configuration
137    pub fn config(&self) -> &HttpConfig {
138        &self.config
139    }
140
141    /// Check if the transport is initialized
142    pub fn is_initialized(&self) -> bool {
143        self.state.is_some()
144    }
145
146    /// Check if the server is running
147    pub fn is_running(&self) -> bool {
148        self.server_handle.is_some()
149    }
150
151    /// Send a message to all connected SSE clients
152    pub async fn broadcast_message(&self, message: &str) -> Result<(), TransportError> {
153        if let Some(ref state) = self.state {
154            let sessions = state.sessions.read().await;
155            for (session_id, session) in sessions.iter() {
156                if let Err(e) = session.event_sender.send(message.to_string()) {
157                    debug!("Failed to send to session {}: {}", session_id, e);
158                }
159            }
160            Ok(())
161        } else {
162            Err(TransportError::Connection(
163                "Transport not started".to_string(),
164            ))
165        }
166    }
167
168    /// Create a new HTTP transport with custom configuration
169    pub fn with_config(config: HttpConfig) -> Self {
170        Self {
171            config,
172            state: None,
173            server_handle: None,
174        }
175    }
176
177    /// Create or get session
178    async fn ensure_session(state: Arc<HttpState>, session_id: Option<String>) -> String {
179        if let Some(id) = session_id {
180            // Check if session exists
181            let sessions = state.sessions.read().await;
182            if sessions.contains_key(&id) {
183                return id;
184            }
185            // If session doesn't exist, create it with the provided ID
186            drop(sessions);
187            let (tx, keepalive_rx) = broadcast::channel(1024);
188            let session_info = SessionInfo {
189                id: id.clone(),
190                created_at: std::time::Instant::now(),
191                last_activity: std::time::Instant::now(),
192                event_sender: tx,
193                _keepalive_receiver: Arc::new(Mutex::new(keepalive_rx)),
194            };
195            let mut sessions = state.sessions.write().await;
196            sessions.insert(id.clone(), session_info);
197            info!("Created session with provided ID: {}", id);
198            return id;
199        }
200
201        // Create new session with generated ID
202        let session_id = Uuid::new_v4().to_string();
203        let (tx, keepalive_rx) = broadcast::channel(1024);
204        let session_info = SessionInfo {
205            id: session_id.clone(),
206            created_at: std::time::Instant::now(),
207            last_activity: std::time::Instant::now(),
208            event_sender: tx,
209            _keepalive_receiver: Arc::new(Mutex::new(keepalive_rx)),
210        };
211
212        {
213            let mut sessions = state.sessions.write().await;
214            sessions.insert(session_id.clone(), session_info);
215        }
216
217        debug!("Created new session: {}", session_id);
218        session_id
219    }
220
221    /// Update session activity
222    async fn update_session_activity(state: Arc<HttpState>, session_id: &str) {
223        let mut sessions = state.sessions.write().await;
224        if let Some(session) = sessions.get_mut(session_id) {
225            session.last_activity = std::time::Instant::now();
226        }
227    }
228
229    /// Clean up expired sessions
230    async fn cleanup_sessions(state: Arc<HttpState>) {
231        let timeout = Duration::from_secs(state.config.session_timeout_secs);
232        let now = std::time::Instant::now();
233
234        let mut sessions = state.sessions.write().await;
235        sessions.retain(|id, session| {
236            let expired = now.duration_since(session.last_activity) > timeout;
237            if expired {
238                debug!("Removing expired session: {}", id);
239            }
240            !expired
241        });
242    }
243
244    /// Validate origin header
245    pub fn validate_origin(config: &HttpConfig, headers: &HeaderMap) -> Result<(), TransportError> {
246        if let Some(allowed_origins) = &config.allowed_origins {
247            if let Some(origin) = headers.get(ORIGIN) {
248                let origin_str = origin
249                    .to_str()
250                    .map_err(|_| TransportError::Protocol("Invalid Origin header".to_string()))?;
251
252                if !allowed_origins.contains(&origin_str.to_string()) {
253                    return Err(TransportError::Protocol(format!(
254                        "Origin not allowed: {origin_str}"
255                    )));
256                }
257            } else {
258                return Err(TransportError::Protocol(
259                    "Missing Origin header".to_string(),
260                ));
261            }
262        }
263
264        Ok(())
265    }
266
267    /// Validate authentication
268    pub fn validate_auth(config: &HttpConfig, headers: &HeaderMap) -> Result<(), TransportError> {
269        if !config.require_auth {
270            return Ok(());
271        }
272
273        let auth_header = headers
274            .get(AUTHORIZATION)
275            .ok_or_else(|| TransportError::Protocol("Missing Authorization header".to_string()))?;
276
277        let auth_str = auth_header
278            .to_str()
279            .map_err(|_| TransportError::Protocol("Invalid Authorization header".to_string()))?;
280
281        if let Some(token) = auth_str.strip_prefix("Bearer ") {
282            if config.valid_tokens.contains(&token.to_string()) {
283                Ok(())
284            } else {
285                Err(TransportError::Protocol("Invalid bearer token".to_string()))
286            }
287        } else {
288            Err(TransportError::Protocol(
289                "Invalid Authorization format, expected Bearer token".to_string(),
290            ))
291        }
292    }
293}
294
295/// Query parameters for POST messages endpoint
296#[derive(Debug, Clone, Deserialize)]
297struct PostQuery {
298    #[serde(alias = "sessionId")]
299    session_id: Option<String>,
300}
301
302/// Handle POST requests (client-to-server messages)
303async fn handle_post(
304    State(state): State<Arc<HttpState>>,
305    Query(query): Query<PostQuery>,
306    headers: HeaderMap,
307    body: String,
308) -> Result<AxumResponse<String>, StatusCode> {
309    info!("Received POST request with session query: {:?}", query);
310    debug!("Raw request body: {}", body);
311
312    // Parse JSON directly to handle both wrapped and direct JSON-RPC formats
313    let request_value: serde_json::Value = match serde_json::from_str(&body) {
314        Ok(v) => v,
315        Err(e) => {
316            warn!("Failed to parse JSON: {}", e);
317            return Err(StatusCode::BAD_REQUEST);
318        }
319    };
320
321    // Extract the actual message (handle both wrapped {"message": {...}} and direct {...} formats)
322    let message = if let Some(wrapped_message) = request_value.get("message") {
323        // Wrapped format: {"message": {"jsonrpc": "2.0", ...}}
324        wrapped_message.clone()
325    } else if request_value.get("jsonrpc").is_some() {
326        // Direct JSON-RPC format: {"jsonrpc": "2.0", ...}
327        request_value
328    } else {
329        warn!("Invalid request format - no 'message' field and no 'jsonrpc' field");
330        return Err(StatusCode::BAD_REQUEST);
331    };
332
333    info!("Request message: {:?}", message);
334
335    // Validate origin
336    if let Err(e) = HttpTransport::validate_origin(&state.config, &headers) {
337        warn!("Origin validation failed: {}", e);
338        return Err(StatusCode::FORBIDDEN);
339    }
340
341    // Validate authentication
342    if let Err(e) = HttpTransport::validate_auth(&state.config, &headers) {
343        warn!("Authentication failed: {}", e);
344        return Err(StatusCode::UNAUTHORIZED);
345    }
346
347    // Get session ID from query parameter (MCP standard) or header (fallback)
348    let session_id_from_request = query.session_id.or_else(|| {
349        headers
350            .get("Mcp-Session-Id")
351            .and_then(|v| v.to_str().ok())
352            .map(|s| s.to_string())
353    });
354
355    // Ensure session exists (create if needed)
356    let session_id = HttpTransport::ensure_session(state.clone(), session_id_from_request).await;
357
358    // Validate message
359    let message_json = serde_json::to_string(&message).map_err(|_| StatusCode::BAD_REQUEST)?;
360
361    if state.config.validate_messages {
362        if let Err(e) = validate_message_string(&message_json, Some(state.config.max_message_size))
363        {
364            warn!("Message validation failed: {}", e);
365            return Err(StatusCode::BAD_REQUEST);
366        }
367    }
368
369    // Parse and process message
370    let message = JsonRpcMessage::parse(&message_json).map_err(|_| StatusCode::BAD_REQUEST)?;
371
372    // Validate JSON-RPC structure
373    if let Err(e) = message.validate() {
374        warn!("JSON-RPC validation failed: {}", e);
375        return Err(StatusCode::BAD_REQUEST);
376    }
377
378    // Update session activity
379    {
380        HttpTransport::update_session_activity(state.clone(), &session_id).await;
381    }
382
383    // Process the message
384    match process_batch(message, &state.handler).await {
385        Ok(Some(response_message)) => {
386            let response_json = response_message
387                .to_string()
388                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
389
390            // Implement proper MCP backwards compatibility protocol
391            let accept_header = headers
392                .get("accept")
393                .and_then(|v| v.to_str().ok())
394                .unwrap_or("");
395
396            // Determine transport mode based on Accept header
397            // MCP Inspector sends:
398            // - SSE mode: "text/event-stream"
399            // - Streamable HTTP mode: "text/event-stream, application/json"
400            debug!("Received Accept header: '{}'", accept_header);
401
402            let wants_json_response = if accept_header.contains("application/json") {
403                // If both are present, this is streamable HTTP mode from MCP Inspector
404                // For "text/event-stream, application/json" - this is streamable HTTP
405                true
406            } else {
407                // Only "text/event-stream" or no JSON at all - use SSE mode
408                false
409            };
410
411            debug!(
412                "Transport mode selected: {}",
413                if wants_json_response {
414                    "streamable-http"
415                } else {
416                    "sse"
417                }
418            );
419
420            if wants_json_response {
421                // New Streamable HTTP transport - return response directly
422                info!(
423                    "Using Streamable HTTP transport, returning response directly for session: {}, Accept: {}",
424                    session_id, accept_header
425                );
426                debug!("Direct response: {}", response_json);
427                Ok(AxumResponse::builder()
428                    .status(StatusCode::OK)
429                    .header("Content-Type", "application/json")
430                    .header("Mcp-Session-Id", session_id)
431                    .body(response_json)
432                    .unwrap())
433            } else {
434                // Legacy HTTP+SSE transport - send through SSE
435                info!(
436                    "Using legacy HTTP+SSE transport for session: {}, Accept: {}",
437                    session_id, accept_header
438                );
439                debug!("Response to send through SSE: {}", response_json);
440
441                let sessions = state.sessions.read().await;
442                info!("Active sessions: {}", sessions.len());
443
444                if let Some(session) = sessions.get(&session_id) {
445                    info!("Found session {}, sending response", session_id);
446                    match session.event_sender.send(response_json.clone()) {
447                        Ok(num_receivers) => {
448                            info!(
449                                "Response sent successfully to {} receivers on session: {}",
450                                num_receivers, session_id
451                            );
452                        }
453                        Err(e) => {
454                            warn!("Failed to send response through SSE: {}", e);
455                        }
456                    }
457                } else {
458                    warn!(
459                        "Session {} not found for response, trying any active session",
460                        session_id
461                    );
462                    // Fallback: try to send to any active session (for MCP Inspector compatibility)
463                    let mut sent = false;
464                    for (sid, session) in sessions.iter() {
465                        match session.event_sender.send(response_json.clone()) {
466                            Ok(num_receivers) => {
467                                info!(
468                                    "Response sent successfully to {} receivers on fallback session: {}",
469                                    num_receivers, sid
470                                );
471                                sent = true;
472                                break;
473                            }
474                            Err(e) => {
475                                debug!("Failed to send to session {}: {}", sid, e);
476                            }
477                        }
478                    }
479                    if !sent {
480                        warn!("No active sessions available to send response");
481                    }
482                }
483
484                // Return 204 No Content (response sent through SSE)
485                Ok(AxumResponse::builder()
486                    .status(StatusCode::NO_CONTENT)
487                    .header("Mcp-Session-Id", session_id)
488                    .body("".to_string())
489                    .unwrap())
490            }
491        }
492        Ok(None) => {
493            // No response needed (notifications only)
494            Ok(AxumResponse::builder()
495                .status(StatusCode::NO_CONTENT)
496                .body("".to_string())
497                .unwrap())
498        }
499        Err(e) => {
500            error!("Failed to process message: {}", e);
501
502            // Create error response
503            let error_response = pulseengine_mcp_protocol::Response {
504                jsonrpc: "2.0".to_string(),
505                id: None,
506                result: None,
507                error: Some(pulseengine_mcp_protocol::Error::internal_error(
508                    e.to_string(),
509                )),
510            };
511
512            if let Ok(error_json) = serde_json::to_string(&error_response) {
513                // Use same transport detection logic as for success responses
514                let accept_header = headers
515                    .get("accept")
516                    .and_then(|v| v.to_str().ok())
517                    .unwrap_or("");
518                let wants_json_response = accept_header.contains("application/json");
519
520                if wants_json_response {
521                    // New Streamable HTTP transport - return error directly
522                    debug!(
523                        "Using Streamable HTTP transport, returning error directly: {}",
524                        error_json
525                    );
526                    Ok(AxumResponse::builder()
527                        .status(StatusCode::OK)
528                        .header("Content-Type", "application/json")
529                        .header("Mcp-Session-Id", session_id)
530                        .body(error_json)
531                        .unwrap())
532                } else {
533                    // Legacy HTTP+SSE transport - send through SSE
534                    debug!(
535                        "Using legacy HTTP+SSE transport, sending error through SSE: {}",
536                        error_json
537                    );
538                    let sessions = state.sessions.read().await;
539                    if let Some(session) = sessions.get(&session_id) {
540                        if let Err(e) = session.event_sender.send(error_json.clone()) {
541                            warn!("Failed to send error through SSE: {}", e);
542                        } else {
543                            debug!(
544                                "Error response sent successfully to session: {}",
545                                session_id
546                            );
547                        }
548                    } else {
549                        warn!(
550                            "Session {} not found for error response, trying any active session",
551                            session_id
552                        );
553                        // Fallback: try to send to any active session (for MCP Inspector compatibility)
554                        let mut sent = false;
555                        for (sid, session) in sessions.iter() {
556                            if session.event_sender.send(error_json.clone()).is_ok() {
557                                debug!(
558                                    "Error response sent successfully to fallback session: {}",
559                                    sid
560                                );
561                                sent = true;
562                                break;
563                            }
564                        }
565                        if !sent {
566                            warn!("No active sessions available to send error response");
567                        }
568                    }
569
570                    // Return 204 No Content (error sent through SSE)
571                    Ok(AxumResponse::builder()
572                        .status(StatusCode::NO_CONTENT)
573                        .body("".to_string())
574                        .unwrap())
575                }
576            } else {
577                // Fallback to HTTP error
578                Err(StatusCode::INTERNAL_SERVER_ERROR)
579            }
580        }
581    }
582}
583
584/// Handle SSE requests (server-to-client streaming)
585async fn handle_sse(
586    uri: axum::http::Uri,
587    State(state): State<Arc<HttpState>>,
588    headers: HeaderMap,
589    Query(query): Query<SseQuery>,
590) -> Result<axum::response::Response, StatusCode> {
591    info!(
592        "Received SSE request - URI: {}, query string: {:?}, parsed query: {:?}",
593        uri,
594        uri.query(),
595        query
596    );
597    info!("Headers: {:?}", headers);
598
599    // Validate origin
600    if let Err(e) = HttpTransport::validate_origin(&state.config, &headers) {
601        warn!("Origin validation failed: {}", e);
602        return Err(StatusCode::FORBIDDEN);
603    }
604
605    // Validate authentication
606    if let Err(e) = HttpTransport::validate_auth(&state.config, &headers) {
607        warn!("Authentication failed: {}", e);
608        return Err(StatusCode::UNAUTHORIZED);
609    }
610
611    // Get or create session
612    let session_id = HttpTransport::ensure_session(state.clone(), query.session_id).await;
613
614    // All MCP clients expect SSE with "endpoint" event first (based on official Python SDK)
615    info!("Creating MCP-compliant SSE stream with endpoint event");
616
617    // Get the event receiver for this session
618    let receiver = {
619        let sessions = state.sessions.read().await;
620        sessions
621            .get(&session_id)
622            .map(|session| session.event_sender.subscribe())
623            .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
624    };
625
626    info!("Starting SSE stream for session: {}", session_id);
627
628    // Clone session_id for headers since it will be moved into the stream
629    let session_id_for_header = session_id.clone();
630
631    // Create SSE stream following official MCP Python SDK pattern
632    let stream = async_stream::stream! {
633        let mut event_counter = 0u64;
634
635        // Send "endpoint" event first (as per official MCP SDK)
636        // Use camelCase sessionId to match MCP Inspector expectations
637        let endpoint_url = format!("/messages?sessionId={session_id}");
638        info!("Sending 'endpoint' event for session: {} with URL: {}", session_id, endpoint_url);
639        event_counter += 1;
640        yield Ok::<_, axum::Error>(Event::default()
641            .id(event_counter.to_string())
642            .event("endpoint")
643            .data(endpoint_url));
644
645        // Stream events from the receiver
646        let mut receiver = receiver;
647        loop {
648            tokio::select! {
649                Ok(data) = receiver.recv() => {
650                    event_counter += 1;
651                    yield Ok::<_, axum::Error>(Event::default()
652                        .id(event_counter.to_string())
653                        .event("message")
654                        .data(data));
655                }
656                _ = tokio::time::sleep(Duration::from_secs(30)) => {
657                    // Send periodic ping to keep connection alive
658                    event_counter += 1;
659                    yield Ok::<_, axum::Error>(Event::default()
660                        .id(event_counter.to_string())
661                        .event("ping")
662                        .data(serde_json::json!({
663                            "type": "ping",
664                            "timestamp": chrono::Utc::now().to_rfc3339()
665                        }).to_string()));
666                }
667            }
668        }
669    };
670
671    // Build SSE response with proper headers and keep-alive
672    let sse = Sse::new(stream).keep_alive(
673        KeepAlive::new()
674            .interval(Duration::from_secs(15))
675            .text("keep-alive"),
676    );
677
678    // Convert to Response and add headers
679    let mut response = sse.into_response();
680    response.headers_mut().insert(
681        axum::http::header::CACHE_CONTROL,
682        "no-cache".parse().unwrap(),
683    );
684    response.headers_mut().insert(
685        axum::http::header::CONNECTION,
686        "keep-alive".parse().unwrap(),
687    );
688    response
689        .headers_mut()
690        .insert("X-Accel-Buffering", "no".parse().unwrap());
691
692    // Add session ID header as per MCP spec
693    response
694        .headers_mut()
695        .insert("Mcp-Session-Id", session_id_for_header.parse().unwrap());
696
697    Ok(response)
698}
699
/// Liveness probe behind `GET /health`; always answers with the plain-text body `OK`.
async fn handle_health() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
704
705#[async_trait]
706impl Transport for HttpTransport {
707    async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
708        info!(
709            "Starting HTTP transport on {}:{}",
710            self.config.host, self.config.port
711        );
712
713        let state = Arc::new(HttpState {
714            handler: Arc::new(handler),
715            config: self.config.clone(),
716            sessions: Arc::new(RwLock::new(HashMap::new())),
717        });
718
719        // Build CORS layer - be very permissive for MCP Inspector
720        let cors = CorsLayer::very_permissive().expose_headers(vec![
721            axum::http::header::HeaderName::from_static("mcp-session-id"),
722            axum::http::header::HeaderName::from_static("content-type"),
723        ]);
724
725        // Build router
726        let app = Router::new()
727            .route("/messages", post(handle_post))
728            .route("/sse", get(handle_sse))
729            .route("/health", get(handle_health))
730            .layer(ServiceBuilder::new().layer(cors))
731            .with_state(state.clone());
732
733        // Start session cleanup task
734        let cleanup_state = state.clone();
735        tokio::spawn(async move {
736            let mut interval = tokio::time::interval(Duration::from_secs(60));
737            loop {
738                interval.tick().await;
739                HttpTransport::cleanup_sessions(cleanup_state.clone()).await;
740            }
741        });
742
743        // Start server
744        let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
745            .parse()
746            .map_err(|e| TransportError::Config(format!("Invalid address: {e}")))?;
747
748        let listener = tokio::net::TcpListener::bind(addr)
749            .await
750            .map_err(|e| TransportError::Connection(format!("Failed to bind to {addr}: {e}")))?;
751
752        info!("HTTP transport listening on {}", addr);
753        info!("Endpoints:");
754        info!("  POST   http://{}/messages   - MCP messages", addr);
755        info!("  GET    http://{}/sse        - Server-Sent Events", addr);
756        info!("  GET    http://{}/health     - Health check", addr);
757
758        let server_handle = tokio::spawn(async move {
759            if let Err(e) = axum::serve(listener, app).await {
760                error!("HTTP server error: {}", e);
761            }
762        });
763
764        self.state = Some(HttpState {
765            handler: state.handler.clone(),
766            config: state.config.clone(),
767            sessions: state.sessions.clone(),
768        });
769        self.server_handle = Some(server_handle);
770
771        Ok(())
772    }
773
774    async fn stop(&mut self) -> Result<(), TransportError> {
775        info!("Stopping HTTP transport");
776
777        if let Some(handle) = self.server_handle.take() {
778            handle.abort();
779        }
780
781        self.state = None;
782        Ok(())
783    }
784
785    async fn health_check(&self) -> Result<(), TransportError> {
786        if self.state.is_some() {
787            Ok(())
788        } else {
789            Err(TransportError::Connection(
790                "HTTP transport not running".to_string(),
791            ))
792        }
793    }
794}
795
796#[cfg(test)]
797mod tests {
798    use super::*;
799    use axum::extract::{Query, State};
800    use axum::http::{HeaderMap, HeaderValue, StatusCode};
801    use pulseengine_mcp_protocol::{Error as McpError, Response};
802    use serde_json::json;
803    use std::sync::Arc;
804    use tokio::sync::RwLock;
805
806    // Mock handler for testing
807    fn mock_handler(
808        request: pulseengine_mcp_protocol::Request,
809    ) -> std::pin::Pin<
810        Box<dyn std::future::Future<Output = pulseengine_mcp_protocol::Response> + Send>,
811    > {
812        Box::pin(async move {
813            Response {
814                jsonrpc: "2.0".to_string(),
815                id: request.id,
816                result: Some(json!({"echo": request.method})),
817                error: None,
818            }
819        })
820    }
821
822    // Mock handler that returns an error
823    fn mock_error_handler(
824        request: pulseengine_mcp_protocol::Request,
825    ) -> std::pin::Pin<
826        Box<dyn std::future::Future<Output = pulseengine_mcp_protocol::Response> + Send>,
827    > {
828        Box::pin(async move {
829            Response {
830                jsonrpc: "2.0".to_string(),
831                id: request.id,
832                result: None,
833                error: Some(McpError::method_not_found(format!(
834                    "Method '{}' not supported",
835                    request.method
836                ))),
837            }
838        })
839    }
840
841    // Mock handler that returns None (for notifications)
842    fn mock_notification_handler(
843        _request: pulseengine_mcp_protocol::Request,
844    ) -> std::pin::Pin<
845        Box<dyn std::future::Future<Output = pulseengine_mcp_protocol::Response> + Send>,
846    > {
847        Box::pin(async move {
848            Response {
849                jsonrpc: "2.0".to_string(),
850                id: None,
851                result: None,
852                error: None,
853            }
854        })
855    }
856
857    fn create_test_state() -> Arc<HttpState> {
858        let config = HttpConfig::default();
859        Arc::new(HttpState {
860            handler: Arc::new(Box::new(mock_handler)),
861            config,
862            sessions: Arc::new(RwLock::new(HashMap::new())),
863        })
864    }
865
866    fn create_test_headers() -> HeaderMap {
867        let mut headers = HeaderMap::new();
868        headers.insert("Content-Type", "application/json".parse().unwrap());
869        headers
870    }
871
872    // === HttpConfig Tests ===
873
874    #[test]
875    fn test_http_config_default() {
876        let config = HttpConfig::default();
877        assert_eq!(config.port, 3000);
878        assert_eq!(config.host, "127.0.0.1");
879        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
880        assert!(config.enable_cors);
881        assert!(config.allowed_origins.is_none());
882        assert!(config.validate_messages);
883        assert_eq!(config.session_timeout_secs, 300);
884        assert!(!config.require_auth);
885        assert!(config.valid_tokens.is_empty());
886    }
887
888    #[test]
889    fn test_http_config_custom() {
890        let config = HttpConfig {
891            port: 8080,
892            host: "0.0.0.0".to_string(),
893            max_message_size: 1024,
894            enable_cors: false,
895            allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
896            validate_messages: true,
897            session_timeout_secs: 600,
898            require_auth: true,
899            valid_tokens: vec!["test-token".to_string()],
900        };
901
902        let transport = HttpTransport::with_config(config.clone());
903        assert_eq!(transport.config.port, 8080);
904        assert_eq!(transport.config.host, "0.0.0.0");
905        assert_eq!(transport.config.max_message_size, 1024);
906        assert!(!transport.config.enable_cors);
907        assert_eq!(
908            transport.config.allowed_origins,
909            Some(vec!["http://localhost:3000".to_string()])
910        );
911        assert!(transport.config.validate_messages);
912        assert_eq!(transport.config.session_timeout_secs, 600);
913        assert!(transport.config.require_auth);
914        assert_eq!(transport.config.valid_tokens, vec!["test-token"]);
915    }
916
917    // === HttpTransport Construction Tests ===
918
919    #[test]
920    fn test_http_transport_new() {
921        let transport = HttpTransport::new(8080);
922        assert_eq!(transport.config.port, 8080);
923        assert_eq!(transport.config.host, "127.0.0.1");
924        assert!(!transport.is_initialized());
925        assert!(!transport.is_running());
926    }
927
928    #[test]
929    fn test_http_transport_with_config() {
930        let config = HttpConfig {
931            port: 9000,
932            host: "192.168.1.1".to_string(),
933            ..Default::default()
934        };
935        let transport = HttpTransport::with_config(config);
936        assert_eq!(transport.config.port, 9000);
937        assert_eq!(transport.config.host, "192.168.1.1");
938        assert!(!transport.is_initialized());
939        assert!(!transport.is_running());
940    }
941
942    #[test]
943    fn test_http_transport_config_access() {
944        let transport = HttpTransport::new(4000);
945        let config = transport.config();
946        assert_eq!(config.port, 4000);
947    }
948
949    // === SessionInfo Tests ===
950
951    #[test]
952    fn test_session_info_creation() {
953        let (tx, rx) = broadcast::channel(1024);
954        let session = SessionInfo {
955            id: "test-session".to_string(),
956            created_at: std::time::Instant::now(),
957            last_activity: std::time::Instant::now(),
958            event_sender: tx,
959            _keepalive_receiver: Arc::new(Mutex::new(rx)),
960        };
961
962        assert_eq!(session.id, "test-session");
963    }
964
965    // === Query Parameter Tests ===
966
967    #[test]
968    fn test_query_deserialization() {
969        // Basic query parsing tests - using axum's built-in functionality
970        let query = PostQuery {
971            session_id: Some("test123".to_string()),
972        };
973        assert_eq!(query.session_id, Some("test123".to_string()));
974    }
975
976    // === Origin Validation Tests ===
977
978    #[test]
979    fn test_validate_origin_no_restrictions() {
980        let config = HttpConfig {
981            allowed_origins: None,
982            ..Default::default()
983        };
984
985        let headers = HeaderMap::new();
986        assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
987
988        let mut headers = HeaderMap::new();
989        headers.insert(ORIGIN, "http://any-origin.com".parse().unwrap());
990        assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
991    }
992
993    #[test]
994    fn test_validate_origin_with_allowed_origins() {
995        let config = HttpConfig {
996            allowed_origins: Some(vec![
997                "http://localhost:3000".to_string(),
998                "https://example.com".to_string(),
999            ]),
1000            ..Default::default()
1001        };
1002
1003        let mut headers = HeaderMap::new();
1004        headers.insert(ORIGIN, "http://localhost:3000".parse().unwrap());
1005        assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
1006
1007        headers.insert(ORIGIN, "https://example.com".parse().unwrap());
1008        assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
1009
1010        headers.insert(ORIGIN, "http://evil.com".parse().unwrap());
1011        assert!(HttpTransport::validate_origin(&config, &headers).is_err());
1012    }
1013
1014    #[test]
1015    fn test_validate_origin_missing_header() {
1016        let config = HttpConfig {
1017            allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
1018            ..Default::default()
1019        };
1020
1021        let headers = HeaderMap::new();
1022        let result = HttpTransport::validate_origin(&config, &headers);
1023        assert!(result.is_err());
1024        assert!(
1025            result
1026                .unwrap_err()
1027                .to_string()
1028                .contains("Missing Origin header")
1029        );
1030    }
1031
1032    #[test]
1033    fn test_validate_origin_invalid_header() {
1034        let config = HttpConfig {
1035            allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
1036            ..Default::default()
1037        };
1038
1039        let mut headers = HeaderMap::new();
1040        // Create an invalid UTF-8 header value
1041        headers.insert(ORIGIN, HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap());
1042        let result = HttpTransport::validate_origin(&config, &headers);
1043        assert!(result.is_err());
1044        assert!(
1045            result
1046                .unwrap_err()
1047                .to_string()
1048                .contains("Invalid Origin header")
1049        );
1050    }
1051
1052    // === Authentication Tests ===
1053
1054    #[test]
1055    fn test_validate_auth_disabled() {
1056        let config = HttpConfig {
1057            require_auth: false,
1058            ..Default::default()
1059        };
1060
1061        let headers = HeaderMap::new();
1062        assert!(HttpTransport::validate_auth(&config, &headers).is_ok());
1063    }
1064
1065    #[test]
1066    fn test_validate_auth_missing_header() {
1067        let config = HttpConfig {
1068            require_auth: true,
1069            valid_tokens: vec!["valid-token".to_string()],
1070            ..Default::default()
1071        };
1072
1073        let headers = HeaderMap::new();
1074        let result = HttpTransport::validate_auth(&config, &headers);
1075        assert!(result.is_err());
1076        assert!(
1077            result
1078                .unwrap_err()
1079                .to_string()
1080                .contains("Missing Authorization header")
1081        );
1082    }
1083
1084    #[test]
1085    fn test_validate_auth_invalid_header() {
1086        let config = HttpConfig {
1087            require_auth: true,
1088            valid_tokens: vec!["valid-token".to_string()],
1089            ..Default::default()
1090        };
1091
1092        let mut headers = HeaderMap::new();
1093        headers.insert(
1094            AUTHORIZATION,
1095            HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap(),
1096        );
1097        let result = HttpTransport::validate_auth(&config, &headers);
1098        assert!(result.is_err());
1099        assert!(
1100            result
1101                .unwrap_err()
1102                .to_string()
1103                .contains("Invalid Authorization header")
1104        );
1105    }
1106
1107    #[test]
1108    fn test_validate_auth_valid_bearer_token() {
1109        let config = HttpConfig {
1110            require_auth: true,
1111            valid_tokens: vec!["valid-token".to_string(), "another-token".to_string()],
1112            ..Default::default()
1113        };
1114
1115        let mut headers = HeaderMap::new();
1116        headers.insert(AUTHORIZATION, "Bearer valid-token".parse().unwrap());
1117        assert!(HttpTransport::validate_auth(&config, &headers).is_ok());
1118
1119        headers.insert(AUTHORIZATION, "Bearer another-token".parse().unwrap());
1120        assert!(HttpTransport::validate_auth(&config, &headers).is_ok());
1121    }
1122
1123    #[test]
1124    fn test_validate_auth_invalid_bearer_token() {
1125        let config = HttpConfig {
1126            require_auth: true,
1127            valid_tokens: vec!["valid-token".to_string()],
1128            ..Default::default()
1129        };
1130
1131        let mut headers = HeaderMap::new();
1132        headers.insert(AUTHORIZATION, "Bearer invalid-token".parse().unwrap());
1133        let result = HttpTransport::validate_auth(&config, &headers);
1134        assert!(result.is_err());
1135        assert!(
1136            result
1137                .unwrap_err()
1138                .to_string()
1139                .contains("Invalid bearer token")
1140        );
1141    }
1142
1143    #[test]
1144    fn test_validate_auth_invalid_format() {
1145        let config = HttpConfig {
1146            require_auth: true,
1147            valid_tokens: vec!["valid-token".to_string()],
1148            ..Default::default()
1149        };
1150
1151        let mut headers = HeaderMap::new();
1152        headers.insert(AUTHORIZATION, "Basic dXNlcjpwYXNz".parse().unwrap());
1153        let result = HttpTransport::validate_auth(&config, &headers);
1154        assert!(result.is_err());
1155        assert!(
1156            result
1157                .unwrap_err()
1158                .to_string()
1159                .contains("Invalid Authorization format")
1160        );
1161
1162        headers.insert(AUTHORIZATION, "just-a-token".parse().unwrap());
1163        let result = HttpTransport::validate_auth(&config, &headers);
1164        assert!(result.is_err());
1165        assert!(
1166            result
1167                .unwrap_err()
1168                .to_string()
1169                .contains("Invalid Authorization format")
1170        );
1171    }
1172
1173    // === Session Management Tests ===
1174
1175    #[tokio::test]
1176    async fn test_ensure_session_new() {
1177        let state = create_test_state();
1178
1179        // Create session without providing session ID
1180        let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1181        assert!(!session_id.is_empty());
1182
1183        // Verify session exists
1184        let sessions = state.sessions.read().await;
1185        assert!(sessions.contains_key(&session_id));
1186        assert_eq!(sessions.len(), 1);
1187    }
1188
1189    #[tokio::test]
1190    async fn test_ensure_session_with_provided_id() {
1191        let state = create_test_state();
1192
1193        // Create session with provided ID
1194        let provided_id = "my-session-123".to_string();
1195        let session_id =
1196            HttpTransport::ensure_session(state.clone(), Some(provided_id.clone())).await;
1197        assert_eq!(session_id, provided_id);
1198
1199        // Verify session exists
1200        let sessions = state.sessions.read().await;
1201        assert!(sessions.contains_key(&session_id));
1202        assert_eq!(sessions.len(), 1);
1203    }
1204
1205    #[tokio::test]
1206    async fn test_ensure_session_existing() {
1207        let state = create_test_state();
1208
1209        // Create session first time
1210        let session_id = "existing-session".to_string();
1211        let result1 = HttpTransport::ensure_session(state.clone(), Some(session_id.clone())).await;
1212        assert_eq!(result1, session_id);
1213
1214        // Try to create session with same ID
1215        let result2 = HttpTransport::ensure_session(state.clone(), Some(session_id.clone())).await;
1216        assert_eq!(result2, session_id);
1217
1218        // Verify only one session exists
1219        let sessions = state.sessions.read().await;
1220        assert_eq!(sessions.len(), 1);
1221    }
1222
1223    #[tokio::test]
1224    async fn test_update_session_activity() {
1225        let state = create_test_state();
1226
1227        // Create session
1228        let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1229
1230        // Get initial activity time
1231        let initial_activity = {
1232            let sessions = state.sessions.read().await;
1233            sessions.get(&session_id).unwrap().last_activity
1234        };
1235
1236        // Wait a bit and update activity
1237        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
1238        HttpTransport::update_session_activity(state.clone(), &session_id).await;
1239
1240        // Verify activity was updated
1241        let updated_activity = {
1242            let sessions = state.sessions.read().await;
1243            sessions.get(&session_id).unwrap().last_activity
1244        };
1245
1246        assert!(updated_activity > initial_activity);
1247    }
1248
1249    #[tokio::test]
1250    async fn test_update_session_activity_nonexistent() {
1251        let state = create_test_state();
1252
1253        // Try to update activity for non-existent session
1254        HttpTransport::update_session_activity(state.clone(), "nonexistent").await;
1255
1256        // Should not crash, and no sessions should exist
1257        let sessions = state.sessions.read().await;
1258        assert_eq!(sessions.len(), 0);
1259    }
1260
1261    #[tokio::test]
1262    async fn test_cleanup_sessions() {
1263        let config = HttpConfig {
1264            session_timeout_secs: 1, // 1 second timeout for testing
1265            ..Default::default()
1266        };
1267
1268        let state = Arc::new(HttpState {
1269            handler: Arc::new(Box::new(mock_handler)),
1270            config,
1271            sessions: Arc::new(RwLock::new(HashMap::new())),
1272        });
1273
1274        // Create a session
1275        let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1276
1277        // Manually set the session as old
1278        {
1279            let mut sessions = state.sessions.write().await;
1280            if let Some(session) = sessions.get_mut(&session_id) {
1281                session.last_activity =
1282                    std::time::Instant::now() - std::time::Duration::from_secs(2);
1283            }
1284        }
1285
1286        // Run cleanup
1287        HttpTransport::cleanup_sessions(state.clone()).await;
1288
1289        // Session should be removed
1290        let sessions = state.sessions.read().await;
1291        assert_eq!(sessions.len(), 0);
1292    }
1293
1294    #[tokio::test]
1295    async fn test_cleanup_sessions_keeps_active() {
1296        let state = create_test_state();
1297
1298        // Create sessions
1299        let session_id1 = HttpTransport::ensure_session(state.clone(), None).await;
1300        let session_id2 = HttpTransport::ensure_session(state.clone(), None).await;
1301
1302        // Run cleanup (sessions should remain as they're recent)
1303        HttpTransport::cleanup_sessions(state.clone()).await;
1304
1305        // Both sessions should still exist
1306        let sessions = state.sessions.read().await;
1307        assert_eq!(sessions.len(), 2);
1308        assert!(sessions.contains_key(&session_id1));
1309        assert!(sessions.contains_key(&session_id2));
1310    }
1311
1312    // === Broadcast Message Tests ===
1313
1314    #[tokio::test]
1315    async fn test_broadcast_message_not_initialized() {
1316        let transport = HttpTransport::new(3000);
1317        let result = transport.broadcast_message("test message").await;
1318        assert!(result.is_err());
1319        assert!(
1320            result
1321                .unwrap_err()
1322                .to_string()
1323                .contains("Transport not started")
1324        );
1325    }
1326
1327    #[tokio::test]
1328    async fn test_broadcast_message_with_sessions() {
1329        let state = create_test_state();
1330
1331        // Simulate initialized transport
1332        let transport = HttpTransport {
1333            config: HttpConfig::default(),
1334            state: Some((*state).clone()),
1335            server_handle: None,
1336        };
1337
1338        // Create a session
1339        let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1340
1341        // Get a receiver to test the broadcast
1342        let mut receiver = {
1343            let sessions = state.sessions.read().await;
1344            sessions.get(&session_id).unwrap().event_sender.subscribe()
1345        };
1346
1347        // Broadcast a message
1348        let result = transport.broadcast_message("test broadcast").await;
1349        assert!(result.is_ok());
1350
1351        // Verify message was received
1352        let received = receiver.recv().await.unwrap();
1353        assert_eq!(received, "test broadcast");
1354    }
1355
1356    #[tokio::test]
1357    async fn test_broadcast_message_no_sessions() {
1358        let state = create_test_state();
1359
1360        // Simulate initialized transport with no sessions
1361        let transport = HttpTransport {
1362            config: HttpConfig::default(),
1363            state: Some((*state).clone()),
1364            server_handle: None,
1365        };
1366
1367        // Broadcast a message (should succeed even with no sessions)
1368        let result = transport.broadcast_message("test broadcast").await;
1369        assert!(result.is_ok());
1370    }
1371
1372    // === Handle POST Tests ===
1373
1374    #[tokio::test]
1375    async fn test_handle_post_valid_wrapped_message() {
1376        let state = create_test_state();
1377        let query = PostQuery {
1378            session_id: Some("test-session".to_string()),
1379        };
1380        let headers = create_test_headers();
1381        let body = json!({
1382            "message": {
1383                "jsonrpc": "2.0",
1384                "method": "ping",
1385                "params": {},
1386                "id": 1
1387            }
1388        })
1389        .to_string();
1390
1391        let result = handle_post(State(state), Query(query), headers, body).await;
1392        assert!(result.is_ok());
1393    }
1394
1395    #[tokio::test]
1396    async fn test_handle_post_valid_direct_message() {
1397        let state = create_test_state();
1398        let query = PostQuery { session_id: None };
1399        let headers = create_test_headers();
1400        let body = json!({
1401            "jsonrpc": "2.0",
1402            "method": "ping",
1403            "params": {},
1404            "id": 1
1405        })
1406        .to_string();
1407
1408        let result = handle_post(State(state), Query(query), headers, body).await;
1409        assert!(result.is_ok());
1410    }
1411
1412    #[tokio::test]
1413    async fn test_handle_post_invalid_json() {
1414        let state = create_test_state();
1415        let query = PostQuery { session_id: None };
1416        let headers = create_test_headers();
1417        let body = "invalid json".to_string();
1418
1419        let result = handle_post(State(state), Query(query), headers, body).await;
1420        assert!(result.is_err());
1421        assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1422    }
1423
1424    #[tokio::test]
1425    async fn test_handle_post_invalid_format() {
1426        let state = create_test_state();
1427        let query = PostQuery { session_id: None };
1428        let headers = create_test_headers();
1429        let body = json!({
1430            "not_jsonrpc": "data"
1431        })
1432        .to_string();
1433
1434        let result = handle_post(State(state), Query(query), headers, body).await;
1435        assert!(result.is_err());
1436        assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1437    }
1438
1439    #[tokio::test]
1440    async fn test_handle_post_origin_validation_failure() {
1441        let config = HttpConfig {
1442            allowed_origins: Some(vec!["http://allowed.com".to_string()]),
1443            ..Default::default()
1444        };
1445
1446        let state = Arc::new(HttpState {
1447            handler: Arc::new(Box::new(mock_handler)),
1448            config,
1449            sessions: Arc::new(RwLock::new(HashMap::new())),
1450        });
1451
1452        let query = PostQuery { session_id: None };
1453        let mut headers = create_test_headers();
1454        headers.insert(ORIGIN, "http://evil.com".parse().unwrap());
1455        let body = json!({
1456            "jsonrpc": "2.0",
1457            "method": "ping",
1458            "params": {},
1459            "id": 1
1460        })
1461        .to_string();
1462
1463        let result = handle_post(State(state), Query(query), headers, body).await;
1464        assert!(result.is_err());
1465        assert_eq!(result.unwrap_err(), StatusCode::FORBIDDEN);
1466    }
1467
1468    #[tokio::test]
1469    async fn test_handle_post_auth_failure() {
1470        let config = HttpConfig {
1471            require_auth: true,
1472            valid_tokens: vec!["valid-token".to_string()],
1473            ..Default::default()
1474        };
1475
1476        let state = Arc::new(HttpState {
1477            handler: Arc::new(Box::new(mock_handler)),
1478            config,
1479            sessions: Arc::new(RwLock::new(HashMap::new())),
1480        });
1481
1482        let query = PostQuery { session_id: None };
1483        let mut headers = create_test_headers();
1484        headers.insert(AUTHORIZATION, "Bearer invalid-token".parse().unwrap());
1485        let body = json!({
1486            "jsonrpc": "2.0",
1487            "method": "ping",
1488            "params": {},
1489            "id": 1
1490        })
1491        .to_string();
1492
1493        let result = handle_post(State(state), Query(query), headers, body).await;
1494        assert!(result.is_err());
1495        assert_eq!(result.unwrap_err(), StatusCode::UNAUTHORIZED);
1496    }
1497
1498    #[tokio::test]
1499    async fn test_handle_post_message_validation_failure() {
1500        let config = HttpConfig {
1501            validate_messages: true,
1502            max_message_size: 10, // Very small limit
1503            ..Default::default()
1504        };
1505
1506        let state = Arc::new(HttpState {
1507            handler: Arc::new(Box::new(mock_handler)),
1508            config,
1509            sessions: Arc::new(RwLock::new(HashMap::new())),
1510        });
1511
1512        let query = PostQuery { session_id: None };
1513        let headers = create_test_headers();
1514        let body = json!({
1515            "jsonrpc": "2.0",
1516            "method": "this_is_a_very_long_method_name_that_exceeds_the_size_limit",
1517            "params": {},
1518            "id": 1
1519        })
1520        .to_string();
1521
1522        let result = handle_post(State(state), Query(query), headers, body).await;
1523        assert!(result.is_err());
1524        assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1525    }
1526
1527    #[tokio::test]
1528    async fn test_handle_post_streamable_http_mode() {
1529        let state = create_test_state();
1530        let query = PostQuery { session_id: None };
1531        let mut headers = create_test_headers();
1532        headers.insert(
1533            "accept",
1534            "text/event-stream, application/json".parse().unwrap(),
1535        );
1536        let body = json!({
1537            "jsonrpc": "2.0",
1538            "method": "ping",
1539            "params": {},
1540            "id": 1
1541        })
1542        .to_string();
1543
1544        let result = handle_post(State(state), Query(query), headers, body).await;
1545        assert!(result.is_ok());
1546
1547        let response = result.unwrap();
1548        assert_eq!(response.status(), StatusCode::OK);
1549        assert!(
1550            response
1551                .headers()
1552                .get("Content-Type")
1553                .unwrap()
1554                .to_str()
1555                .unwrap()
1556                .contains("application/json")
1557        );
1558        assert!(response.headers().contains_key("Mcp-Session-Id"));
1559    }
1560
1561    #[tokio::test]
1562    async fn test_handle_post_sse_mode() {
1563        let state = create_test_state();
1564        let query = PostQuery { session_id: None };
1565        let mut headers = create_test_headers();
1566        headers.insert("accept", "text/event-stream".parse().unwrap());
1567        let body = json!({
1568            "jsonrpc": "2.0",
1569            "method": "ping",
1570            "params": {},
1571            "id": 1
1572        })
1573        .to_string();
1574
1575        let result = handle_post(State(state), Query(query), headers, body).await;
1576        assert!(result.is_ok());
1577
1578        let response = result.unwrap();
1579        assert_eq!(response.status(), StatusCode::NO_CONTENT);
1580        assert!(response.headers().contains_key("Mcp-Session-Id"));
1581    }
1582
1583    #[tokio::test]
1584    async fn test_handle_post_notification_response() {
1585        let state = Arc::new(HttpState {
1586            handler: Arc::new(Box::new(mock_notification_handler)),
1587            config: HttpConfig::default(),
1588            sessions: Arc::new(RwLock::new(HashMap::new())),
1589        });
1590
1591        let query = PostQuery { session_id: None };
1592        let headers = create_test_headers();
1593        let body = json!({
1594            "jsonrpc": "2.0",
1595            "method": "notification",
1596            "params": {}
1597        })
1598        .to_string();
1599
1600        let result = handle_post(State(state), Query(query), headers, body).await;
1601        assert!(result.is_ok());
1602
1603        let response = result.unwrap();
1604        assert_eq!(response.status(), StatusCode::NO_CONTENT);
1605    }
1606
1607    #[tokio::test]
1608    async fn test_handle_post_processing_error() {
1609        let state = Arc::new(HttpState {
1610            handler: Arc::new(Box::new(mock_error_handler)),
1611            config: HttpConfig::default(),
1612            sessions: Arc::new(RwLock::new(HashMap::new())),
1613        });
1614
1615        let query = PostQuery { session_id: None };
1616        let mut headers = create_test_headers();
1617        headers.insert("accept", "application/json".parse().unwrap());
1618        let body = json!({
1619            "jsonrpc": "2.0",
1620            "method": "unknown_method",
1621            "params": {},
1622            "id": 1
1623        })
1624        .to_string();
1625
1626        let result = handle_post(State(state), Query(query), headers, body).await;
1627        assert!(result.is_ok());
1628
1629        let response = result.unwrap();
1630        assert_eq!(response.status(), StatusCode::OK);
1631
1632        // Response should contain error information
1633        let body_str = response.body();
1634        assert!(body_str.contains("error"));
1635    }
1636
1637    #[tokio::test]
1638    async fn test_handle_post_session_id_from_header() {
1639        let state = create_test_state();
1640        let query = PostQuery { session_id: None };
1641        let mut headers = create_test_headers();
1642        headers.insert("Mcp-Session-Id", "header-session-123".parse().unwrap());
1643        let body = json!({
1644            "jsonrpc": "2.0",
1645            "method": "ping",
1646            "params": {},
1647            "id": 1
1648        })
1649        .to_string();
1650
1651        let result = handle_post(State(state.clone()), Query(query), headers, body).await;
1652        assert!(result.is_ok());
1653
1654        // Verify session was created with the header session ID
1655        let sessions = state.sessions.read().await;
1656        assert!(sessions.contains_key("header-session-123"));
1657    }
1658
1659    // === Handle SSE Tests ===
1660
1661    #[tokio::test]
1662    async fn test_handle_sse_basic() {
1663        let state = create_test_state();
1664        let query = SseQuery {
1665            session_id: Some("sse-test-session".to_string()),
1666            last_event_id: None,
1667            transport_type: None,
1668            url: None,
1669        };
1670        let headers = create_test_headers();
1671        let uri = "http://localhost:3000/sse?sessionId=sse-test-session"
1672            .parse()
1673            .unwrap();
1674
1675        let result = handle_sse(uri, State(state.clone()), headers, Query(query)).await;
1676        assert!(result.is_ok());
1677
1678        let response = result.unwrap();
1679        assert_eq!(response.status(), StatusCode::OK);
1680        assert!(response.headers().contains_key("Mcp-Session-Id"));
1681        assert_eq!(
1682            response
1683                .headers()
1684                .get("content-type")
1685                .unwrap()
1686                .to_str()
1687                .unwrap(),
1688            "text/event-stream"
1689        );
1690
1691        // Verify session was created
1692        let sessions = state.sessions.read().await;
1693        assert!(sessions.contains_key("sse-test-session"));
1694    }
1695
1696    #[tokio::test]
1697    async fn test_handle_sse_origin_validation_failure() {
1698        let config = HttpConfig {
1699            allowed_origins: Some(vec!["http://allowed.com".to_string()]),
1700            ..Default::default()
1701        };
1702
1703        let state = Arc::new(HttpState {
1704            handler: Arc::new(Box::new(mock_handler)),
1705            config,
1706            sessions: Arc::new(RwLock::new(HashMap::new())),
1707        });
1708
1709        let query = SseQuery {
1710            session_id: None,
1711            last_event_id: None,
1712            transport_type: None,
1713            url: None,
1714        };
1715        let mut headers = create_test_headers();
1716        headers.insert(ORIGIN, "http://evil.com".parse().unwrap());
1717        let uri = "http://localhost:3000/sse".parse().unwrap();
1718
1719        let result = handle_sse(uri, State(state), headers, Query(query)).await;
1720        assert!(result.is_err());
1721        assert_eq!(result.unwrap_err(), StatusCode::FORBIDDEN);
1722    }
1723
1724    #[tokio::test]
1725    async fn test_handle_sse_auth_failure() {
1726        let config = HttpConfig {
1727            require_auth: true,
1728            valid_tokens: vec!["valid-token".to_string()],
1729            ..Default::default()
1730        };
1731
1732        let state = Arc::new(HttpState {
1733            handler: Arc::new(Box::new(mock_handler)),
1734            config,
1735            sessions: Arc::new(RwLock::new(HashMap::new())),
1736        });
1737
1738        let query = SseQuery {
1739            session_id: None,
1740            last_event_id: None,
1741            transport_type: None,
1742            url: None,
1743        };
1744        let headers = create_test_headers();
1745        let uri = "http://localhost:3000/sse".parse().unwrap();
1746
1747        let result = handle_sse(uri, State(state), headers, Query(query)).await;
1748        assert!(result.is_err());
1749        assert_eq!(result.unwrap_err(), StatusCode::UNAUTHORIZED);
1750    }
1751
1752    // === Handle Health Tests ===
1753
1754    #[tokio::test]
1755    async fn test_handle_health() {
1756        let result = handle_health().await;
1757        assert_eq!(result, "OK");
1758    }
1759
1760    // === Transport Trait Implementation Tests ===
1761
1762    #[tokio::test]
1763    async fn test_transport_start_invalid_address() {
1764        let config = HttpConfig {
1765            host: "invalid-host-name-that-does-not-exist".to_string(),
1766            port: 0,
1767            ..Default::default()
1768        };
1769        let mut transport = HttpTransport::with_config(config);
1770
1771        let result = transport.start(Box::new(mock_handler)).await;
1772        assert!(result.is_err());
1773        assert!(result.unwrap_err().to_string().contains("Invalid address"));
1774    }
1775
1776    #[tokio::test]
1777    async fn test_transport_start_and_stop() {
1778        let mut transport = HttpTransport::new(0); // Use port 0 for OS-assigned port
1779
1780        // Initially not running
1781        assert!(!transport.is_initialized());
1782        assert!(!transport.is_running());
1783
1784        // Start transport
1785        let result = transport.start(Box::new(mock_handler)).await;
1786        assert!(result.is_ok());
1787        assert!(transport.is_initialized());
1788        assert!(transport.is_running());
1789
1790        // Health check should pass
1791        assert!(transport.health_check().await.is_ok());
1792
1793        // Stop transport
1794        let result = transport.stop().await;
1795        assert!(result.is_ok());
1796        assert!(!transport.is_running());
1797    }
1798
1799    #[tokio::test]
1800    async fn test_transport_health_check_not_running() {
1801        let transport = HttpTransport::new(3000);
1802        let result = transport.health_check().await;
1803        assert!(result.is_err());
1804        assert!(
1805            result
1806                .unwrap_err()
1807                .to_string()
1808                .contains("HTTP transport not running")
1809        );
1810    }
1811
1812    // === Integration Tests ===
1813
1814    #[tokio::test]
1815    async fn test_full_session_lifecycle() {
1816        let state = create_test_state();
1817
1818        // Create session
1819        let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1820        assert!(!session_id.is_empty());
1821
1822        // Update activity
1823        HttpTransport::update_session_activity(state.clone(), &session_id).await;
1824
1825        // Send a message through the session
1826        let message = "test message";
1827        {
1828            let sessions = state.sessions.read().await;
1829            let session = sessions.get(&session_id).unwrap();
1830            let result = session.event_sender.send(message.to_string());
1831            assert!(result.is_ok());
1832        }
1833
1834        // Clean up sessions (recent session should remain)
1835        HttpTransport::cleanup_sessions(state.clone()).await;
1836        {
1837            let sessions = state.sessions.read().await;
1838            assert!(sessions.contains_key(&session_id));
1839        }
1840    }
1841
1842    #[tokio::test]
1843    async fn test_multiple_sessions() {
1844        let state = create_test_state();
1845
1846        // Create multiple sessions
1847        let session1 =
1848            HttpTransport::ensure_session(state.clone(), Some("session-1".to_string())).await;
1849        let session2 =
1850            HttpTransport::ensure_session(state.clone(), Some("session-2".to_string())).await;
1851        let session3 = HttpTransport::ensure_session(state.clone(), None).await;
1852
1853        assert_eq!(session1, "session-1");
1854        assert_eq!(session2, "session-2");
1855        assert!(!session3.is_empty());
1856        assert_ne!(session3, session1);
1857        assert_ne!(session3, session2);
1858
1859        // Verify all sessions exist
1860        let sessions = state.sessions.read().await;
1861        assert_eq!(sessions.len(), 3);
1862        assert!(sessions.contains_key(&session1));
1863        assert!(sessions.contains_key(&session2));
1864        assert!(sessions.contains_key(&session3));
1865    }
1866
1867    #[tokio::test]
1868    async fn test_message_format_variations() {
1869        let state = create_test_state();
1870        let query = PostQuery { session_id: None };
1871        let headers = create_test_headers();
1872
1873        // Test wrapped format
1874        let wrapped_body = json!({
1875            "message": {
1876                "jsonrpc": "2.0",
1877                "method": "test",
1878                "params": {"key": "value"},
1879                "id": 1
1880            }
1881        })
1882        .to_string();
1883
1884        let result = handle_post(
1885            State(state.clone()),
1886            Query(query.clone()),
1887            headers.clone(),
1888            wrapped_body,
1889        )
1890        .await;
1891        assert!(result.is_ok());
1892
1893        // Test direct format
1894        let direct_body = json!({
1895            "jsonrpc": "2.0",
1896            "method": "test",
1897            "params": {"key": "value"},
1898            "id": 2
1899        })
1900        .to_string();
1901
1902        let result = handle_post(State(state), Query(query), headers, direct_body).await;
1903        assert!(result.is_ok());
1904    }
1905
1906    #[tokio::test]
1907    async fn test_error_handling_edge_cases() {
1908        let state = create_test_state();
1909        let query = PostQuery { session_id: None };
1910        let headers = create_test_headers();
1911
1912        // Test malformed JSON-RPC (missing required fields)
1913        let invalid_jsonrpc = json!({
1914            "jsonrpc": "1.0", // Wrong version
1915            "method": "test"
1916            // Missing id
1917        })
1918        .to_string();
1919
1920        let result = handle_post(State(state), Query(query), headers, invalid_jsonrpc).await;
1921        assert!(result.is_err());
1922        assert_eq!(result.unwrap_err(), StatusCode::BAD_REQUEST);
1923    }
1924
1925    // === Configuration Edge Cases ===
1926
1927    #[test]
1928    fn test_config_extreme_values() {
1929        let config = HttpConfig {
1930            port: 65535,
1931            host: "0.0.0.0".to_string(),
1932            max_message_size: 0,
1933            enable_cors: true,
1934            allowed_origins: Some(vec![]),
1935            validate_messages: false,
1936            session_timeout_secs: 0,
1937            require_auth: false,
1938            valid_tokens: vec![],
1939        };
1940
1941        let transport = HttpTransport::with_config(config);
1942        assert_eq!(transport.config.port, 65535);
1943        assert_eq!(transport.config.max_message_size, 0);
1944        assert_eq!(transport.config.session_timeout_secs, 0);
1945        assert!(
1946            transport
1947                .config
1948                .allowed_origins
1949                .as_ref()
1950                .unwrap()
1951                .is_empty()
1952        );
1953    }
1954
1955    #[test]
1956    fn test_session_info_timing() {
1957        let now = std::time::Instant::now();
1958        let (tx, rx) = broadcast::channel(1024);
1959
1960        let session = SessionInfo {
1961            id: "timing-test".to_string(),
1962            created_at: now,
1963            last_activity: now,
1964            event_sender: tx,
1965            _keepalive_receiver: Arc::new(Mutex::new(rx)),
1966        };
1967
1968        assert!(session.created_at <= std::time::Instant::now());
1969        assert!(session.last_activity <= std::time::Instant::now());
1970    }
1971
1972    // === Broadcast Channel Edge Cases ===
1973
1974    #[tokio::test]
1975    async fn test_broadcast_channel_receiver_drop() {
1976        let state = create_test_state();
1977
1978        // Create session and get receiver
1979        let session_id = HttpTransport::ensure_session(state.clone(), None).await;
1980        let receiver = {
1981            let sessions = state.sessions.read().await;
1982            sessions.get(&session_id).unwrap().event_sender.subscribe()
1983        };
1984
1985        // Drop the receiver
1986        drop(receiver);
1987
1988        // Simulate initialized transport
1989        let transport = HttpTransport {
1990            config: HttpConfig::default(),
1991            state: Some((*state).clone()),
1992            server_handle: None,
1993        };
1994
1995        // Broadcasting should still work (might log warnings but not fail)
1996        let result = transport.broadcast_message("test after drop").await;
1997        assert!(result.is_ok());
1998    }
1999
2000    #[tokio::test]
2001    async fn test_session_channel_capacity() {
2002        let state = create_test_state();
2003        let session_id = HttpTransport::ensure_session(state.clone(), None).await;
2004
2005        // Get sender and fill the channel beyond capacity
2006        let sender = {
2007            let sessions = state.sessions.read().await;
2008            sessions.get(&session_id).unwrap().event_sender.clone()
2009        };
2010
2011        // Send many messages to test channel behavior
2012        for i in 0..2000 {
2013            // More than the 1024 capacity
2014            let _ = sender.send(format!("message-{i}"));
2015        }
2016
2017        // This should not crash the test
2018        // Channel is tested for capacity behavior
2019    }
2020}