pulseengine_mcp_transport/
streamable_http.rs1use crate::{RequestHandler, Transport, TransportError};
7use async_trait::async_trait;
8use axum::{
9 Json, Router,
10 extract::{Query, State},
11 http::{HeaderMap, StatusCode},
12 response::IntoResponse,
13 routing::{get, post},
14};
15use serde::Deserialize;
16use serde_json::Value;
17use std::{collections::HashMap, net::SocketAddr, sync::Arc};
18use tokio::sync::RwLock;
19use tower::ServiceBuilder;
20use tower_http::cors::CorsLayer;
21use tracing::{debug, info, warn};
22use uuid::Uuid;
23
/// Settings for the Streamable HTTP transport.
#[derive(Clone, Debug)]
pub struct StreamableHttpConfig {
    /// TCP port the server binds to.
    pub port: u16,
    /// Host part of the bind address; joined with `port` when the transport starts.
    pub host: String,
    /// When `true`, the router is wrapped in a permissive CORS layer.
    pub enable_cors: bool,
}
31
32impl Default for StreamableHttpConfig {
33 fn default() -> Self {
34 Self {
35 port: 3001,
36 host: "127.0.0.1".to_string(),
37 enable_cors: true,
38 }
39 }
40}
41
/// Record stored for every session the server knows about.
// Both fields are written when a session is created but are not read by the
// code visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct SessionInfo {
    /// The session's identifier (also used as its key in the session map).
    id: String,
    /// Moment at which the session record was created.
    created_at: std::time::Instant,
}
50
51#[derive(Clone)]
53struct AppState {
54 handler: Arc<RequestHandler>,
55 sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
56}
57
58#[derive(Debug, Deserialize)]
60struct StreamQuery {
61 #[serde(rename = "sessionId")]
62 session_id: Option<String>,
63}
64
65pub struct StreamableHttpTransport {
67 config: StreamableHttpConfig,
68 server_handle: Option<tokio::task::JoinHandle<()>>,
69}
70
71impl StreamableHttpTransport {
72 pub fn new(port: u16) -> Self {
73 Self {
74 config: StreamableHttpConfig {
75 port,
76 ..Default::default()
77 },
78 server_handle: None,
79 }
80 }
81
82 pub fn config(&self) -> &StreamableHttpConfig {
84 &self.config
85 }
86
87 async fn ensure_session(state: &AppState, session_id: Option<String>) -> String {
89 if let Some(id) = session_id {
90 let sessions = state.sessions.read().await;
92 if sessions.contains_key(&id) {
93 return id;
94 }
95 drop(sessions);
97 let session = SessionInfo {
98 id: id.clone(),
99 created_at: std::time::Instant::now(),
100 };
101 let mut sessions = state.sessions.write().await;
102 sessions.insert(id.clone(), session);
103 info!("Created session with provided ID: {}", id);
104 return id;
105 }
106
107 let id = Uuid::new_v4().to_string();
109 let session = SessionInfo {
110 id: id.clone(),
111 created_at: std::time::Instant::now(),
112 };
113
114 let mut sessions = state.sessions.write().await;
115 sessions.insert(id.clone(), session);
116 info!("Created new session: {}", id);
117
118 id
119 }
120}
121
122async fn handle_messages(
124 State(state): State<Arc<AppState>>,
125 headers: HeaderMap,
126 body: String,
127) -> impl IntoResponse {
128 debug!("Received POST /messages: {}", body);
129
130 let session_id = headers
132 .get("Mcp-Session-Id")
133 .and_then(|v| v.to_str().ok())
134 .map(|s| s.to_string());
135
136 let session_id = StreamableHttpTransport::ensure_session(&state, session_id).await;
137
138 let request: Value = match serde_json::from_str(&body) {
140 Ok(v) => v,
141 Err(e) => {
142 warn!("Failed to parse request: {}", e);
143 return (
144 StatusCode::BAD_REQUEST,
145 Json(serde_json::json!({
146 "jsonrpc": "2.0",
147 "error": {
148 "code": -32700,
149 "message": "Parse error"
150 },
151 "id": null
152 })),
153 )
154 .into_response();
155 }
156 };
157
158 let mcp_request: pulseengine_mcp_protocol::Request =
160 match serde_json::from_value(request.clone()) {
161 Ok(r) => r,
162 Err(e) => {
163 warn!("Invalid request format: {}", e);
164 return (
165 StatusCode::BAD_REQUEST,
166 Json(serde_json::json!({
167 "jsonrpc": "2.0",
168 "error": {
169 "code": -32600,
170 "message": "Invalid request"
171 },
172 "id": request.get("id").cloned().unwrap_or(Value::Null)
173 })),
174 )
175 .into_response();
176 }
177 };
178
179 let response = (state.handler)(mcp_request).await;
181
182 let mut headers = HeaderMap::new();
184 headers.insert("Mcp-Session-Id", session_id.parse().unwrap());
185 debug!("Sending response with session ID: {}", session_id);
186
187 (StatusCode::OK, headers, Json(response)).into_response()
188}
189
190async fn handle_sse(
192 State(state): State<Arc<AppState>>,
193 Query(query): Query<StreamQuery>,
194) -> impl IntoResponse {
195 info!("SSE connection request: {:?}", query);
196
197 let session_id = StreamableHttpTransport::ensure_session(&state, query.session_id).await;
202
203 let response = serde_json::json!({
206 "type": "connection",
207 "status": "connected",
208 "sessionId": session_id,
209 "transport": "streamable-http"
210 });
211
212 let mut headers = HeaderMap::new();
214 headers.insert("Mcp-Session-Id", session_id.parse().unwrap());
215 debug!("SSE response with session ID: {}", session_id);
216
217 (StatusCode::OK, headers, Json(response))
218}
219
220#[async_trait]
221impl Transport for StreamableHttpTransport {
222 async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
223 info!(
224 "Starting Streamable HTTP transport on {}:{}",
225 self.config.host, self.config.port
226 );
227
228 let state = Arc::new(AppState {
229 handler: Arc::new(handler),
230 sessions: Arc::new(RwLock::new(HashMap::new())),
231 });
232
233 let app = Router::new()
235 .route("/messages", post(handle_messages))
236 .route("/sse", get(handle_sse))
237 .route("/", get(|| async { "MCP Streamable HTTP Server" }))
238 .layer(ServiceBuilder::new().layer(if self.config.enable_cors {
239 CorsLayer::permissive()
240 } else {
241 CorsLayer::new()
242 }))
243 .with_state(state);
244
245 let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
247 .parse()
248 .map_err(|e| TransportError::Config(format!("Invalid address: {e}")))?;
249
250 let listener = tokio::net::TcpListener::bind(addr)
251 .await
252 .map_err(|e| TransportError::Connection(format!("Failed to bind: {e}")))?;
253
254 info!("Streamable HTTP transport listening on {}", addr);
255 info!("Endpoints:");
256 info!(" POST http://{}/messages - MCP messages", addr);
257 info!(" GET http://{}/sse - Session establishment", addr);
258
259 let server_handle = tokio::spawn(async move {
260 if let Err(e) = axum::serve(listener, app).await {
261 tracing::error!("Server error: {}", e);
262 }
263 });
264
265 self.server_handle = Some(server_handle);
266 Ok(())
267 }
268
269 async fn stop(&mut self) -> Result<(), TransportError> {
270 if let Some(handle) = self.server_handle.take() {
271 handle.abort();
272 }
273 Ok(())
274 }
275
276 async fn health_check(&self) -> Result<(), TransportError> {
277 if self.server_handle.is_some() {
278 Ok(())
279 } else {
280 Err(TransportError::Connection("Not running".to_string()))
281 }
282 }
283}