pulseengine_mcp_transport/
streamable_http.rs

//! Streamable HTTP transport implementation for MCP.
//!
//! This implements the newer streamable-http transport that MCP Inspector expects,
//! which replaces the deprecated SSE transport. It exposes `POST /messages` for
//! JSON-RPC messages and `GET /sse` for session establishment.
5
6use crate::{RequestHandler, Transport, TransportError};
7use async_trait::async_trait;
8use axum::{
9    Json, Router,
10    extract::{Query, State},
11    http::{HeaderMap, StatusCode},
12    response::IntoResponse,
13    routing::{get, post},
14};
15use serde::Deserialize;
16use serde_json::Value;
17use std::{collections::HashMap, net::SocketAddr, sync::Arc};
18use tokio::sync::RwLock;
19use tower::ServiceBuilder;
20use tower_http::cors::CorsLayer;
21use tracing::{debug, info, warn};
22use uuid::Uuid;
23
/// Settings for the Streamable HTTP transport.
///
/// The [`Default`] configuration binds to `127.0.0.1:3001` with CORS enabled.
#[derive(Debug, Clone)]
pub struct StreamableHttpConfig {
    /// TCP port the HTTP server binds to.
    pub port: u16,
    /// Host / IP address the HTTP server binds to.
    pub host: String,
    /// When `true` a permissive CORS layer is installed; otherwise
    /// `CorsLayer::new()` is used.
    pub enable_cors: bool,
}

impl Default for StreamableHttpConfig {
    /// Loopback-only defaults: `127.0.0.1:3001`, CORS enabled.
    fn default() -> Self {
        let host = String::from("127.0.0.1");
        Self {
            host,
            port: 3001,
            enable_cors: true,
        }
    }
}
41
/// Bookkeeping record for one session known to the server.
#[derive(Debug, Clone)]
struct SessionInfo {
    /// Session identifier; the same string the record is stored under.
    // Written on creation but not read anywhere in this file yet.
    #[allow(dead_code)]
    id: String,
    /// When the session record was created.
    // Written on creation but not read anywhere in this file yet.
    #[allow(dead_code)]
    created_at: std::time::Instant,
}
50
51/// Shared state
52#[derive(Clone)]
53struct AppState {
54    handler: Arc<RequestHandler>,
55    sessions: Arc<RwLock<HashMap<String, SessionInfo>>>,
56}
57
58/// Query parameters for SSE endpoint
59#[derive(Debug, Deserialize)]
60struct StreamQuery {
61    #[serde(rename = "sessionId")]
62    session_id: Option<String>,
63}
64
65/// Streamable HTTP transport
66pub struct StreamableHttpTransport {
67    config: StreamableHttpConfig,
68    server_handle: Option<tokio::task::JoinHandle<()>>,
69}
70
71impl StreamableHttpTransport {
72    pub fn new(port: u16) -> Self {
73        Self {
74            config: StreamableHttpConfig {
75                port,
76                ..Default::default()
77            },
78            server_handle: None,
79        }
80    }
81
82    /// Get the configuration
83    pub fn config(&self) -> &StreamableHttpConfig {
84        &self.config
85    }
86
87    /// Create or get session
88    async fn ensure_session(state: &AppState, session_id: Option<String>) -> String {
89        if let Some(id) = session_id {
90            // Check if session exists
91            let sessions = state.sessions.read().await;
92            if sessions.contains_key(&id) {
93                return id;
94            }
95            // If session doesn't exist, create it with the provided ID
96            drop(sessions);
97            let session = SessionInfo {
98                id: id.clone(),
99                created_at: std::time::Instant::now(),
100            };
101            let mut sessions = state.sessions.write().await;
102            sessions.insert(id.clone(), session);
103            info!("Created session with provided ID: {}", id);
104            return id;
105        }
106
107        // Create new session with generated ID
108        let id = Uuid::new_v4().to_string();
109        let session = SessionInfo {
110            id: id.clone(),
111            created_at: std::time::Instant::now(),
112        };
113
114        let mut sessions = state.sessions.write().await;
115        sessions.insert(id.clone(), session);
116        info!("Created new session: {}", id);
117
118        id
119    }
120}
121
122/// Handle POST requests for client-to-server messages
123async fn handle_messages(
124    State(state): State<Arc<AppState>>,
125    headers: HeaderMap,
126    body: String,
127) -> impl IntoResponse {
128    debug!("Received POST /messages: {}", body);
129
130    // Get or create session
131    let session_id = headers
132        .get("Mcp-Session-Id")
133        .and_then(|v| v.to_str().ok())
134        .map(|s| s.to_string());
135
136    let session_id = StreamableHttpTransport::ensure_session(&state, session_id).await;
137
138    // Parse the request
139    let request: Value = match serde_json::from_str(&body) {
140        Ok(v) => v,
141        Err(e) => {
142            warn!("Failed to parse request: {}", e);
143            return (
144                StatusCode::BAD_REQUEST,
145                Json(serde_json::json!({
146                    "jsonrpc": "2.0",
147                    "error": {
148                        "code": -32700,
149                        "message": "Parse error"
150                    },
151                    "id": null
152                })),
153            )
154                .into_response();
155        }
156    };
157
158    // Convert to MCP Request
159    let mcp_request: pulseengine_mcp_protocol::Request =
160        match serde_json::from_value(request.clone()) {
161            Ok(r) => r,
162            Err(e) => {
163                warn!("Invalid request format: {}", e);
164                return (
165                    StatusCode::BAD_REQUEST,
166                    Json(serde_json::json!({
167                        "jsonrpc": "2.0",
168                        "error": {
169                            "code": -32600,
170                            "message": "Invalid request"
171                        },
172                        "id": request.get("id").cloned().unwrap_or(Value::Null)
173                    })),
174                )
175                    .into_response();
176            }
177        };
178
179    // Process through handler
180    let response = (state.handler)(mcp_request).await;
181
182    // Return JSON response with session header
183    let mut headers = HeaderMap::new();
184    headers.insert("Mcp-Session-Id", session_id.parse().unwrap());
185    debug!("Sending response with session ID: {}", session_id);
186
187    (StatusCode::OK, headers, Json(response)).into_response()
188}
189
190/// Handle SSE requests for server-to-client streaming
191async fn handle_sse(
192    State(state): State<Arc<AppState>>,
193    Query(query): Query<StreamQuery>,
194) -> impl IntoResponse {
195    info!("SSE connection request: {:?}", query);
196
197    // For streamable-http, we need to handle this differently
198    // The client expects an immediate response, not an SSE stream
199
200    // Get or create session
201    let session_id = StreamableHttpTransport::ensure_session(&state, query.session_id).await;
202
203    // Return a simple response indicating the connection is established
204    // This is what MCP Inspector expects for streamable-http
205    let response = serde_json::json!({
206        "type": "connection",
207        "status": "connected",
208        "sessionId": session_id,
209        "transport": "streamable-http"
210    });
211
212    // Include session ID in response header as per MCP spec
213    let mut headers = HeaderMap::new();
214    headers.insert("Mcp-Session-Id", session_id.parse().unwrap());
215    debug!("SSE response with session ID: {}", session_id);
216
217    (StatusCode::OK, headers, Json(response))
218}
219
220#[async_trait]
221impl Transport for StreamableHttpTransport {
222    async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
223        info!(
224            "Starting Streamable HTTP transport on {}:{}",
225            self.config.host, self.config.port
226        );
227
228        let state = Arc::new(AppState {
229            handler: Arc::new(handler),
230            sessions: Arc::new(RwLock::new(HashMap::new())),
231        });
232
233        // Build router
234        let app = Router::new()
235            .route("/messages", post(handle_messages))
236            .route("/sse", get(handle_sse))
237            .route("/", get(|| async { "MCP Streamable HTTP Server" }))
238            .layer(ServiceBuilder::new().layer(if self.config.enable_cors {
239                CorsLayer::permissive()
240            } else {
241                CorsLayer::new()
242            }))
243            .with_state(state);
244
245        // Start server
246        let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
247            .parse()
248            .map_err(|e| TransportError::Config(format!("Invalid address: {e}")))?;
249
250        let listener = tokio::net::TcpListener::bind(addr)
251            .await
252            .map_err(|e| TransportError::Connection(format!("Failed to bind: {e}")))?;
253
254        info!("Streamable HTTP transport listening on {}", addr);
255        info!("Endpoints:");
256        info!("  POST http://{}/messages - MCP messages", addr);
257        info!("  GET  http://{}/sse      - Session establishment", addr);
258
259        let server_handle = tokio::spawn(async move {
260            if let Err(e) = axum::serve(listener, app).await {
261                tracing::error!("Server error: {}", e);
262            }
263        });
264
265        self.server_handle = Some(server_handle);
266        Ok(())
267    }
268
269    async fn stop(&mut self) -> Result<(), TransportError> {
270        if let Some(handle) = self.server_handle.take() {
271            handle.abort();
272        }
273        Ok(())
274    }
275
276    async fn health_check(&self) -> Result<(), TransportError> {
277        if self.server_handle.is_some() {
278            Ok(())
279        } else {
280            Err(TransportError::Connection("Not running".to_string()))
281        }
282    }
283}