pulseengine_mcp_transport/
streamable_http.rs1use crate::{RequestHandler, Transport, TransportError};
7use async_trait::async_trait;
8use axum::{
9    extract::{Query, State},
10    http::{HeaderMap, StatusCode},
11    response::IntoResponse,
12    routing::{get, post},
13    Json, Router,
14};
15use serde::Deserialize;
16use serde_json::Value;
17use std::{collections::HashMap, net::SocketAddr, sync::Arc};
18use tokio::sync::RwLock;
19use tower::ServiceBuilder;
20use tower_http::cors::CorsLayer;
21use tracing::{debug, info, warn};
22use uuid::Uuid;
23
/// Network settings for the streamable HTTP transport.
#[derive(Debug, Clone)]
pub struct StreamableHttpConfig {
    /// TCP port the HTTP server listens on.
    pub port: u16,
    /// Host address the server binds to; combined with `port` to form the listen address.
    pub host: String,
    /// Whether a permissive CORS layer is installed on the HTTP router.
    pub enable_cors: bool,
}

impl Default for StreamableHttpConfig {
    /// Loopback host `127.0.0.1`, port `3001`, CORS enabled.
    fn default() -> Self {
        let host = String::from("127.0.0.1");
        Self {
            port: 3001,
            host,
            enable_cors: true,
        }
    }
}
41
/// Bookkeeping record for one client session, stored in the shared session map
/// under the same id.
// The fields are written when a session is created but are not read anywhere in
// this module, so the dead_code lint is silenced for the whole record.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct SessionInfo {
    /// Session identifier (identical to the map key it is stored under).
    id: String,
    /// Moment the session was registered.
    created_at: std::time::Instant,
}
50
51#[derive(Clone)]
53struct AppState {
54    handler: Arc<RequestHandler>,
55    sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
56}
57
58#[derive(Debug, Deserialize)]
60struct StreamQuery {
61    #[serde(rename = "sessionId")]
62    session_id: Option<String>,
63}
64
65pub struct StreamableHttpTransport {
67    config: StreamableHttpConfig,
68    server_handle: Option<tokio::task::JoinHandle<()>>,
69}
70
71impl StreamableHttpTransport {
72    pub fn new(port: u16) -> Self {
73        Self {
74            config: StreamableHttpConfig {
75                port,
76                ..Default::default()
77            },
78            server_handle: None,
79        }
80    }
81
82    async fn ensure_session(state: &AppState, session_id: Option<String>) -> String {
84        if let Some(id) = session_id {
85            let sessions = state.sessions.read().await;
87            if sessions.contains_key(&id) {
88                return id;
89            }
90        }
91
92        let id = Uuid::new_v4().to_string();
94        let session = SessionInfo {
95            id: id.clone(),
96            created_at: std::time::Instant::now(),
97        };
98
99        let mut sessions = state.sessions.write().await;
100        sessions.insert(id.clone(), session);
101        info!("Created new session: {}", id);
102
103        id
104    }
105}
106
107async fn handle_messages(
109    State(state): State<Arc<AppState>>,
110    headers: HeaderMap,
111    body: String,
112) -> impl IntoResponse {
113    debug!("Received POST /messages: {}", body);
114
115    let session_id = headers
117        .get("Mcp-Session-Id")
118        .and_then(|v| v.to_str().ok())
119        .map(|s| s.to_string());
120
121    let session_id = StreamableHttpTransport::ensure_session(&state, session_id).await;
122
123    let request: Value = match serde_json::from_str(&body) {
125        Ok(v) => v,
126        Err(e) => {
127            warn!("Failed to parse request: {}", e);
128            return (
129                StatusCode::BAD_REQUEST,
130                Json(serde_json::json!({
131                    "jsonrpc": "2.0",
132                    "error": {
133                        "code": -32700,
134                        "message": "Parse error"
135                    },
136                    "id": null
137                })),
138            )
139                .into_response();
140        }
141    };
142
143    let mcp_request: pulseengine_mcp_protocol::Request =
145        match serde_json::from_value(request.clone()) {
146            Ok(r) => r,
147            Err(e) => {
148                warn!("Invalid request format: {}", e);
149                return (
150                    StatusCode::BAD_REQUEST,
151                    Json(serde_json::json!({
152                        "jsonrpc": "2.0",
153                        "error": {
154                            "code": -32600,
155                            "message": "Invalid request"
156                        },
157                        "id": request.get("id").cloned().unwrap_or(Value::Null)
158                    })),
159                )
160                    .into_response();
161            }
162        };
163
164    let response = (state.handler)(mcp_request).await;
166
167    let mut headers = HeaderMap::new();
169    headers.insert("Mcp-Session-Id", session_id.parse().unwrap());
170
171    (StatusCode::OK, headers, Json(response)).into_response()
172}
173
174async fn handle_sse(
176    State(state): State<Arc<AppState>>,
177    Query(query): Query<StreamQuery>,
178) -> impl IntoResponse {
179    info!("SSE connection request: {:?}", query);
180
181    let session_id = StreamableHttpTransport::ensure_session(&state, query.session_id).await;
186
187    let response = serde_json::json!({
190        "type": "connection",
191        "status": "connected",
192        "sessionId": session_id,
193        "transport": "streamable-http"
194    });
195
196    Json(response)
197}
198
199#[async_trait]
200impl Transport for StreamableHttpTransport {
201    async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
202        info!(
203            "Starting Streamable HTTP transport on {}:{}",
204            self.config.host, self.config.port
205        );
206
207        let state = Arc::new(AppState {
208            handler: Arc::new(handler),
209            sessions: Arc::new(RwLock::new(HashMap::new())),
210        });
211
212        let app = Router::new()
214            .route("/messages", post(handle_messages))
215            .route("/sse", get(handle_sse))
216            .route("/", get(|| async { "MCP Streamable HTTP Server" }))
217            .layer(ServiceBuilder::new().layer(if self.config.enable_cors {
218                CorsLayer::permissive()
219            } else {
220                CorsLayer::new()
221            }))
222            .with_state(state);
223
224        let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
226            .parse()
227            .map_err(|e| TransportError::Config(format!("Invalid address: {e}")))?;
228
229        let listener = tokio::net::TcpListener::bind(addr)
230            .await
231            .map_err(|e| TransportError::Connection(format!("Failed to bind: {e}")))?;
232
233        info!("Streamable HTTP transport listening on {}", addr);
234        info!("Endpoints:");
235        info!("  POST http://{}/messages - MCP messages", addr);
236        info!("  GET  http://{}/sse      - Session establishment", addr);
237
238        let server_handle = tokio::spawn(async move {
239            if let Err(e) = axum::serve(listener, app).await {
240                tracing::error!("Server error: {}", e);
241            }
242        });
243
244        self.server_handle = Some(server_handle);
245        Ok(())
246    }
247
248    async fn stop(&mut self) -> Result<(), TransportError> {
249        if let Some(handle) = self.server_handle.take() {
250            handle.abort();
251        }
252        Ok(())
253    }
254
255    async fn health_check(&self) -> Result<(), TransportError> {
256        if self.server_handle.is_some() {
257            Ok(())
258        } else {
259            Err(TransportError::Connection("Not running".to_string()))
260        }
261    }
262}