1use crate::{
4    batch::{process_batch, JsonRpcMessage},
5    validation::validate_message_string,
6    RequestHandler, Transport, TransportError,
7};
8use async_trait::async_trait;
9use axum::response::sse::{Event, KeepAlive};
10use axum::{
11    extract::{Query, State},
12    http::{
13        header::{AUTHORIZATION, ORIGIN},
14        HeaderMap, StatusCode,
15    },
16    response::{IntoResponse, Response as AxumResponse, Sse},
17    routing::{get, post},
18    Router,
19};
20use serde::Deserialize;
23use serde_json;
24use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
25use tokio::sync::{broadcast, Mutex, RwLock};
26use tower::ServiceBuilder;
27use tower_http::cors::CorsLayer;
28use tracing::{debug, error, info, warn};
29use uuid::Uuid;
30
/// Settings for the HTTP transport.
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// TCP port the server binds to.
    pub port: u16,
    /// Host the server binds to; joined with `port` and parsed as a socket address on start.
    pub host: String,
    /// Upper bound handed to message validation (presumably bytes — confirm against
    /// `validate_message_string`).
    pub max_message_size: usize,
    /// Intended to toggle CORS handling for the server.
    pub enable_cors: bool,
    /// Origins accepted by origin validation. `None` disables the check; with `Some`, requests
    /// must carry an `Origin` header that exactly equals one of the entries.
    pub allowed_origins: Option<Vec<String>>,
    /// Whether incoming messages are passed through `validate_message_string`.
    pub validate_messages: bool,
    /// Idle time, in seconds, after which a session is removed by session cleanup.
    pub session_timeout_secs: u64,
    /// Whether requests must present a bearer token.
    pub require_auth: bool,
    /// Bearer tokens accepted when `require_auth` is set.
    pub valid_tokens: Vec<String>,
}
53
54impl Default for HttpConfig {
55    fn default() -> Self {
56        Self {
57            port: 3000,
58            host: "127.0.0.1".to_string(),
59            max_message_size: 10 * 1024 * 1024, enable_cors: true,
61            allowed_origins: None,
62            validate_messages: true,
63            session_timeout_secs: 300, require_auth: false,
65            valid_tokens: vec![],
66        }
67    }
68}
69
/// Per-session bookkeeping stored in the transport's session map.
#[derive(Clone)]
struct SessionInfo {
    /// Session identifier (also the key in the session map).
    #[allow(dead_code)]
    id: String,
    /// When the session was created.
    #[allow(dead_code)]
    created_at: std::time::Instant,
    /// Last time the session was touched; compared against the session timeout during cleanup.
    last_activity: std::time::Instant,
    /// Broadcast channel used to push payloads to the SSE streams subscribed to this session.
    event_sender: broadcast::Sender<String>,
    /// Receiver created together with `event_sender` and never read in this file.
    /// NOTE(review): presumably parked here so the channel always has a receiver and
    /// `send` does not fail when no SSE stream is attached — confirm.
    #[allow(dead_code)]
    _keepalive_receiver: Arc<Mutex<broadcast::Receiver<String>>>,
}
83
/// State shared with every request handler.
#[derive(Clone)]
struct HttpState {
    /// Handler that processes decoded JSON-RPC requests.
    handler: Arc<RequestHandler>,
    /// Copy of the transport configuration taken when the transport was started.
    config: HttpConfig,
    /// Live sessions keyed by session id.
    sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
}
91
/// Query-string parameters accepted by the SSE endpoint.
#[derive(Debug, Deserialize)]
struct SseQuery {
    /// Existing session to attach to (`sessionId`); when absent a new session is created.
    #[serde(rename = "sessionId")]
    session_id: Option<String>,
    /// `lastEventId` parameter; deserialized but not read in this file.
    #[serde(rename = "lastEventId")]
    #[allow(dead_code)]
    last_event_id: Option<String>,
    /// `transportType` parameter; deserialized but not read in this file.
    #[serde(rename = "transportType")]
    #[allow(dead_code)]
    transport_type: Option<String>,
    /// `url` parameter; deserialized but not read in this file.
    #[allow(dead_code)]
    url: Option<String>,
}
106
/// HTTP transport exposing a POST message endpoint, an SSE stream and a health endpoint.
pub struct HttpTransport {
    /// Configuration used when the transport is started.
    config: HttpConfig,
    /// Shared handler state; `None` until the transport has been started.
    state: Option<HttpState>,
    /// Handle of the spawned server task; `None` while the transport is not running.
    server_handle: Option<tokio::task::JoinHandle<()>>,
}
120
121impl HttpTransport {
122    pub fn new(port: u16) -> Self {
124        let config = HttpConfig {
125            port,
126            ..Default::default()
127        };
128
129        Self {
130            config,
131            state: None,
132            server_handle: None,
133        }
134    }
135
136    pub async fn broadcast_message(&self, message: &str) -> Result<(), TransportError> {
138        if let Some(ref state) = self.state {
139            let sessions = state.sessions.read().await;
140            for (session_id, session) in sessions.iter() {
141                if let Err(e) = session.event_sender.send(message.to_string()) {
142                    debug!("Failed to send to session {}: {}", session_id, e);
143                }
144            }
145            Ok(())
146        } else {
147            Err(TransportError::Connection(
148                "Transport not started".to_string(),
149            ))
150        }
151    }
152
153    pub fn with_config(config: HttpConfig) -> Self {
155        Self {
156            config,
157            state: None,
158            server_handle: None,
159        }
160    }
161
162    async fn create_session(state: Arc<HttpState>) -> String {
164        let session_id = Uuid::new_v4().to_string();
165        let (tx, keepalive_rx) = broadcast::channel(1024);
168
169        let session_info = SessionInfo {
170            id: session_id.clone(),
171            created_at: std::time::Instant::now(),
172            last_activity: std::time::Instant::now(),
173            event_sender: tx,
174            _keepalive_receiver: Arc::new(Mutex::new(keepalive_rx)),
175        };
176
177        {
178            let mut sessions = state.sessions.write().await;
179            sessions.insert(session_id.clone(), session_info);
180        }
181
182        debug!("Created new session: {}", session_id);
183        session_id
184    }
185
186    async fn update_session_activity(state: Arc<HttpState>, session_id: &str) {
188        let mut sessions = state.sessions.write().await;
189        if let Some(session) = sessions.get_mut(session_id) {
190            session.last_activity = std::time::Instant::now();
191        }
192    }
193
194    async fn cleanup_sessions(state: Arc<HttpState>) {
196        let timeout = Duration::from_secs(state.config.session_timeout_secs);
197        let now = std::time::Instant::now();
198
199        let mut sessions = state.sessions.write().await;
200        sessions.retain(|id, session| {
201            let expired = now.duration_since(session.last_activity) > timeout;
202            if expired {
203                debug!("Removing expired session: {}", id);
204            }
205            !expired
206        });
207    }
208
209    fn validate_origin(config: &HttpConfig, headers: &HeaderMap) -> Result<(), TransportError> {
211        if let Some(allowed_origins) = &config.allowed_origins {
212            if let Some(origin) = headers.get(ORIGIN) {
213                let origin_str = origin
214                    .to_str()
215                    .map_err(|_| TransportError::Protocol("Invalid Origin header".to_string()))?;
216
217                if !allowed_origins.contains(&origin_str.to_string()) {
218                    return Err(TransportError::Protocol(format!(
219                        "Origin not allowed: {origin_str}"
220                    )));
221                }
222            } else {
223                return Err(TransportError::Protocol(
224                    "Missing Origin header".to_string(),
225                ));
226            }
227        }
228
229        Ok(())
230    }
231
232    fn validate_auth(config: &HttpConfig, headers: &HeaderMap) -> Result<(), TransportError> {
234        if !config.require_auth {
235            return Ok(());
236        }
237
238        let auth_header = headers
239            .get(AUTHORIZATION)
240            .ok_or_else(|| TransportError::Protocol("Missing Authorization header".to_string()))?;
241
242        let auth_str = auth_header
243            .to_str()
244            .map_err(|_| TransportError::Protocol("Invalid Authorization header".to_string()))?;
245
246        if let Some(token) = auth_str.strip_prefix("Bearer ") {
247            if config.valid_tokens.contains(&token.to_string()) {
248                Ok(())
249            } else {
250                Err(TransportError::Protocol("Invalid bearer token".to_string()))
251            }
252        } else {
253            Err(TransportError::Protocol(
254                "Invalid Authorization format, expected Bearer token".to_string(),
255            ))
256        }
257    }
258}
259
/// Query-string parameters accepted by the POST message endpoint.
#[derive(Debug, Deserialize)]
struct PostQuery {
    /// Session the message belongs to, passed as `session_id` (the same parameter name the
    /// SSE `endpoint` event advertises).
    session_id: Option<String>,
}
265
266async fn handle_post(
268    State(state): State<Arc<HttpState>>,
269    Query(query): Query<PostQuery>,
270    headers: HeaderMap,
271    body: String,
272) -> Result<AxumResponse<String>, StatusCode> {
273    info!("Received POST request with session query: {:?}", query);
274    debug!("Raw request body: {}", body);
275
276    let request_value: serde_json::Value = match serde_json::from_str(&body) {
278        Ok(v) => v,
279        Err(e) => {
280            warn!("Failed to parse JSON: {}", e);
281            return Err(StatusCode::BAD_REQUEST);
282        }
283    };
284
285    let message = if let Some(wrapped_message) = request_value.get("message") {
287        wrapped_message.clone()
289    } else if request_value.get("jsonrpc").is_some() {
290        request_value
292    } else {
293        warn!("Invalid request format - no 'message' field and no 'jsonrpc' field");
294        return Err(StatusCode::BAD_REQUEST);
295    };
296
297    info!("Request message: {:?}", message);
298
299    if let Err(e) = HttpTransport::validate_origin(&state.config, &headers) {
301        warn!("Origin validation failed: {}", e);
302        return Err(StatusCode::FORBIDDEN);
303    }
304
305    if let Err(e) = HttpTransport::validate_auth(&state.config, &headers) {
307        warn!("Authentication failed: {}", e);
308        return Err(StatusCode::UNAUTHORIZED);
309    }
310
311    let session_id = if let Some(id) = query.session_id {
313        id
314    } else if let Some(id) = headers
315        .get("Mcp-Session-Id")
316        .and_then(|v| v.to_str().ok())
317        .map(|s| s.to_string())
318    {
319        id
320    } else {
321        HttpTransport::create_session(state.clone()).await
323    };
324
325    let message_json = serde_json::to_string(&message).map_err(|_| StatusCode::BAD_REQUEST)?;
327
328    if state.config.validate_messages {
329        if let Err(e) = validate_message_string(&message_json, Some(state.config.max_message_size))
330        {
331            warn!("Message validation failed: {}", e);
332            return Err(StatusCode::BAD_REQUEST);
333        }
334    }
335
336    let message = JsonRpcMessage::parse(&message_json).map_err(|_| StatusCode::BAD_REQUEST)?;
338
339    if let Err(e) = message.validate() {
341        warn!("JSON-RPC validation failed: {}", e);
342        return Err(StatusCode::BAD_REQUEST);
343    }
344
345    {
347        HttpTransport::update_session_activity(state.clone(), &session_id).await;
348    }
349
350    match process_batch(message, &state.handler).await {
352        Ok(Some(response_message)) => {
353            let response_json = response_message
354                .to_string()
355                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
356
357            let accept_header = headers
359                .get("accept")
360                .and_then(|v| v.to_str().ok())
361                .unwrap_or("");
362
363            let wants_json_response = if accept_header.starts_with("application/json") {
367                true } else if accept_header.contains("application/json")
369                && accept_header.contains("text/event-stream")
370            {
371                let json_pos = accept_header.find("application/json").unwrap_or(usize::MAX);
373                let sse_pos = accept_header
374                    .find("text/event-stream")
375                    .unwrap_or(usize::MAX);
376                json_pos < sse_pos } else {
378                accept_header.contains("application/json")
379                    && !accept_header.contains("text/event-stream")
380            };
381
382            if wants_json_response {
383                info!("Using Streamable HTTP transport, returning response directly for session: {}, Accept: {}", session_id, accept_header);
385                debug!("Direct response: {}", response_json);
386                Ok(AxumResponse::builder()
387                    .status(StatusCode::OK)
388                    .header("Content-Type", "application/json")
389                    .header("Mcp-Session-Id", session_id)
390                    .body(response_json)
391                    .unwrap())
392            } else {
393                info!(
395                    "Using legacy HTTP+SSE transport for session: {}, Accept: {}",
396                    session_id, accept_header
397                );
398                debug!("Response to send through SSE: {}", response_json);
399
400                let sessions = state.sessions.read().await;
401                info!("Active sessions: {}", sessions.len());
402
403                if let Some(session) = sessions.get(&session_id) {
404                    info!("Found session {}, sending response", session_id);
405                    match session.event_sender.send(response_json.clone()) {
406                        Ok(num_receivers) => {
407                            info!(
408                                "Response sent successfully to {} receivers on session: {}",
409                                num_receivers, session_id
410                            );
411                        }
412                        Err(e) => {
413                            warn!("Failed to send response through SSE: {}", e);
414                        }
415                    }
416                } else {
417                    warn!(
418                        "Session {} not found for response, trying any active session",
419                        session_id
420                    );
421                    let mut sent = false;
423                    for (sid, session) in sessions.iter() {
424                        match session.event_sender.send(response_json.clone()) {
425                            Ok(num_receivers) => {
426                                info!("Response sent successfully to {} receivers on fallback session: {}", num_receivers, sid);
427                                sent = true;
428                                break;
429                            }
430                            Err(e) => {
431                                debug!("Failed to send to session {}: {}", sid, e);
432                            }
433                        }
434                    }
435                    if !sent {
436                        warn!("No active sessions available to send response");
437                    }
438                }
439
440                Ok(AxumResponse::builder()
442                    .status(StatusCode::NO_CONTENT)
443                    .header("Mcp-Session-Id", session_id)
444                    .body("".to_string())
445                    .unwrap())
446            }
447        }
448        Ok(None) => {
449            Ok(AxumResponse::builder()
451                .status(StatusCode::NO_CONTENT)
452                .body("".to_string())
453                .unwrap())
454        }
455        Err(e) => {
456            error!("Failed to process message: {}", e);
457
458            let error_response = pulseengine_mcp_protocol::Response {
460                jsonrpc: "2.0".to_string(),
461                id: serde_json::Value::Null,
462                result: None,
463                error: Some(pulseengine_mcp_protocol::Error::internal_error(
464                    e.to_string(),
465                )),
466            };
467
468            if let Ok(error_json) = serde_json::to_string(&error_response) {
469                let accept_header = headers
471                    .get("accept")
472                    .and_then(|v| v.to_str().ok())
473                    .unwrap_or("");
474                let wants_json_response = accept_header.contains("application/json")
475                    && !accept_header.contains("text/event-stream");
476
477                if wants_json_response {
478                    debug!(
480                        "Using Streamable HTTP transport, returning error directly: {}",
481                        error_json
482                    );
483                    Ok(AxumResponse::builder()
484                        .status(StatusCode::OK)
485                        .header("Content-Type", "application/json")
486                        .header("Mcp-Session-Id", session_id)
487                        .body(error_json)
488                        .unwrap())
489                } else {
490                    debug!(
492                        "Using legacy HTTP+SSE transport, sending error through SSE: {}",
493                        error_json
494                    );
495                    let sessions = state.sessions.read().await;
496                    if let Some(session) = sessions.get(&session_id) {
497                        if let Err(e) = session.event_sender.send(error_json.clone()) {
498                            warn!("Failed to send error through SSE: {}", e);
499                        } else {
500                            debug!(
501                                "Error response sent successfully to session: {}",
502                                session_id
503                            );
504                        }
505                    } else {
506                        warn!(
507                            "Session {} not found for error response, trying any active session",
508                            session_id
509                        );
510                        let mut sent = false;
512                        for (sid, session) in sessions.iter() {
513                            if session.event_sender.send(error_json.clone()).is_ok() {
514                                debug!(
515                                    "Error response sent successfully to fallback session: {}",
516                                    sid
517                                );
518                                sent = true;
519                                break;
520                            }
521                        }
522                        if !sent {
523                            warn!("No active sessions available to send error response");
524                        }
525                    }
526
527                    Ok(AxumResponse::builder()
529                        .status(StatusCode::NO_CONTENT)
530                        .body("".to_string())
531                        .unwrap())
532                }
533            } else {
534                Err(StatusCode::INTERNAL_SERVER_ERROR)
536            }
537        }
538    }
539}
540
541async fn handle_sse(
543    uri: axum::http::Uri,
544    State(state): State<Arc<HttpState>>,
545    headers: HeaderMap,
546    Query(query): Query<SseQuery>,
547) -> Result<axum::response::Response, StatusCode> {
548    info!(
549        "Received SSE request - URI: {}, query string: {:?}, parsed query: {:?}",
550        uri,
551        uri.query(),
552        query
553    );
554    info!("Headers: {:?}", headers);
555
556    if let Err(e) = HttpTransport::validate_origin(&state.config, &headers) {
558        warn!("Origin validation failed: {}", e);
559        return Err(StatusCode::FORBIDDEN);
560    }
561
562    if let Err(e) = HttpTransport::validate_auth(&state.config, &headers) {
564        warn!("Authentication failed: {}", e);
565        return Err(StatusCode::UNAUTHORIZED);
566    }
567
568    let session_id = if let Some(session_id) = query.session_id {
570        let sessions = state.sessions.read().await;
572        if sessions.contains_key(&session_id) {
573            session_id
574        } else {
575            return Err(StatusCode::BAD_REQUEST);
576        }
577    } else {
578        let new_session_id = HttpTransport::create_session(state.clone()).await;
580        info!("Created new SSE session: {}", new_session_id);
581        new_session_id
582    };
583
584    info!("Creating MCP-compliant SSE stream with endpoint event");
586
587    let receiver = {
589        let sessions = state.sessions.read().await;
590        sessions
591            .get(&session_id)
592            .map(|session| session.event_sender.subscribe())
593            .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
594    };
595
596    info!("Starting SSE stream for session: {}", session_id);
597
598    let stream = async_stream::stream! {
600        let mut event_counter = 0u64;
601
602        let endpoint_url = format!("/messages?session_id={session_id}");
604        info!("Sending 'endpoint' event for session: {} with URL: {}", session_id, endpoint_url);
605        event_counter += 1;
606        yield Ok::<_, axum::Error>(Event::default()
607            .id(event_counter.to_string())
608            .event("endpoint")
609            .data(endpoint_url));
610
611        let mut receiver = receiver;
613        loop {
614            tokio::select! {
615                Ok(data) = receiver.recv() => {
616                    event_counter += 1;
617                    yield Ok::<_, axum::Error>(Event::default()
618                        .id(event_counter.to_string())
619                        .event("message")
620                        .data(data));
621                }
622                _ = tokio::time::sleep(Duration::from_secs(30)) => {
623                    event_counter += 1;
625                    yield Ok::<_, axum::Error>(Event::default()
626                        .id(event_counter.to_string())
627                        .event("ping")
628                        .data(serde_json::json!({
629                            "type": "ping",
630                            "timestamp": chrono::Utc::now().to_rfc3339()
631                        }).to_string()));
632                }
633            }
634        }
635    };
636
637    let sse = Sse::new(stream).keep_alive(
639        KeepAlive::new()
640            .interval(Duration::from_secs(15))
641            .text("keep-alive"),
642    );
643
644    let mut response = sse.into_response();
646    response.headers_mut().insert(
647        axum::http::header::CACHE_CONTROL,
648        "no-cache".parse().unwrap(),
649    );
650    response.headers_mut().insert(
651        axum::http::header::CONNECTION,
652        "keep-alive".parse().unwrap(),
653    );
654    response
655        .headers_mut()
656        .insert("X-Accel-Buffering", "no".parse().unwrap());
657
658    Ok(response)
659}
660
/// Health endpoint handler: always answers with the plain-text body `OK`.
async fn handle_health() -> &'static str {
    "OK"
}
665
666#[async_trait]
667impl Transport for HttpTransport {
668    async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
669        info!(
670            "Starting HTTP transport on {}:{}",
671            self.config.host, self.config.port
672        );
673
674        let state = Arc::new(HttpState {
675            handler: Arc::new(handler),
676            config: self.config.clone(),
677            sessions: Arc::new(RwLock::new(HashMap::new())),
678        });
679
680        let cors = CorsLayer::very_permissive().expose_headers(vec![
682            axum::http::header::HeaderName::from_static("mcp-session-id"),
683            axum::http::header::HeaderName::from_static("content-type"),
684        ]);
685
686        let app = Router::new()
688            .route("/messages", post(handle_post))
689            .route("/sse", get(handle_sse))
690            .route("/health", get(handle_health))
691            .layer(ServiceBuilder::new().layer(cors))
692            .with_state(state.clone());
693
694        let cleanup_state = state.clone();
696        tokio::spawn(async move {
697            let mut interval = tokio::time::interval(Duration::from_secs(60));
698            loop {
699                interval.tick().await;
700                HttpTransport::cleanup_sessions(cleanup_state.clone()).await;
701            }
702        });
703
704        let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
706            .parse()
707            .map_err(|e| TransportError::Config(format!("Invalid address: {e}")))?;
708
709        let listener = tokio::net::TcpListener::bind(addr)
710            .await
711            .map_err(|e| TransportError::Connection(format!("Failed to bind to {addr}: {e}")))?;
712
713        info!("HTTP transport listening on {}", addr);
714        info!("Endpoints:");
715        info!("  POST   http://{}/messages   - MCP messages", addr);
716        info!("  GET    http://{}/sse        - Server-Sent Events", addr);
717        info!("  GET    http://{}/health     - Health check", addr);
718
719        let server_handle = tokio::spawn(async move {
720            if let Err(e) = axum::serve(listener, app).await {
721                error!("HTTP server error: {}", e);
722            }
723        });
724
725        self.state = Some(HttpState {
726            handler: state.handler.clone(),
727            config: state.config.clone(),
728            sessions: state.sessions.clone(),
729        });
730        self.server_handle = Some(server_handle);
731
732        Ok(())
733    }
734
735    async fn stop(&mut self) -> Result<(), TransportError> {
736        info!("Stopping HTTP transport");
737
738        if let Some(handle) = self.server_handle.take() {
739            handle.abort();
740        }
741
742        self.state = None;
743        Ok(())
744    }
745
746    async fn health_check(&self) -> Result<(), TransportError> {
747        if self.state.is_some() {
748            Ok(())
749        } else {
750            Err(TransportError::Connection(
751                "HTTP transport not running".to_string(),
752            ))
753        }
754    }
755}
756
757#[cfg(test)]
758mod tests {
759    use super::*;
760    use serde_json::json;
761
762    fn mock_handler(
764        request: pulseengine_mcp_protocol::Request,
765    ) -> std::pin::Pin<
766        Box<dyn std::future::Future<Output = pulseengine_mcp_protocol::Response> + Send>,
767    > {
768        Box::pin(async move {
769            pulseengine_mcp_protocol::Response {
770                jsonrpc: "2.0".to_string(),
771                id: request.id,
772                result: Some(json!({"echo": request.method})),
773                error: None,
774            }
775        })
776    }
777
778    #[test]
779    fn test_http_config() {
780        let config = HttpConfig {
781            port: 8080,
782            host: "0.0.0.0".to_string(),
783            max_message_size: 1024,
784            enable_cors: false,
785            allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
786            validate_messages: true,
787            session_timeout_secs: 600,
788            require_auth: true,
789            valid_tokens: vec!["test-token".to_string()],
790        };
791
792        let transport = HttpTransport::with_config(config.clone());
793        assert_eq!(transport.config.port, 8080);
794        assert_eq!(transport.config.host, "0.0.0.0");
795        assert!(!transport.config.enable_cors);
796        assert!(transport.config.require_auth);
797    }
798
799    #[test]
800    fn test_validate_origin() {
801        let mut config = HttpConfig::default();
802        config.allowed_origins = Some(vec!["http://localhost:3000".to_string()]);
803
804        let mut headers = HeaderMap::new();
805        headers.insert(ORIGIN, "http://localhost:3000".parse().unwrap());
806
807        assert!(HttpTransport::validate_origin(&config, &headers).is_ok());
808
809        headers.insert(ORIGIN, "http://evil.com".parse().unwrap());
810        assert!(HttpTransport::validate_origin(&config, &headers).is_err());
811    }
812
813    #[test]
814    fn test_validate_auth() {
815        let mut config = HttpConfig::default();
816        config.require_auth = true;
817        config.valid_tokens = vec!["valid-token".to_string()];
818
819        let mut headers = HeaderMap::new();
820        headers.insert(AUTHORIZATION, "Bearer valid-token".parse().unwrap());
821
822        assert!(HttpTransport::validate_auth(&config, &headers).is_ok());
823
824        headers.insert(AUTHORIZATION, "Bearer invalid-token".parse().unwrap());
825        assert!(HttpTransport::validate_auth(&config, &headers).is_err());
826
827        headers.remove(AUTHORIZATION);
828        assert!(HttpTransport::validate_auth(&config, &headers).is_err());
829    }
830
831    #[tokio::test]
832    async fn test_session_management() {
833        let config = HttpConfig::default();
834        let state = Arc::new(HttpState {
835            handler: Arc::new(Box::new(mock_handler)),
836            config,
837            sessions: Arc::new(RwLock::new(HashMap::new())),
838        });
839
840        let session_id = HttpTransport::create_session(state.clone()).await;
842        assert!(!session_id.is_empty());
843
844        {
846            let sessions = state.sessions.read().await;
847            assert!(sessions.contains_key(&session_id));
848        }
849
850        HttpTransport::update_session_activity(state.clone(), &session_id).await;
852
853        HttpTransport::cleanup_sessions(state.clone()).await;
855        {
856            let sessions = state.sessions.read().await;
857            assert!(sessions.contains_key(&session_id));
858        }
859    }
860}