// cortexai_mcp/sse_server.rs

//! SSE Server Transport for MCP
//!
//! Implements HTTP-based MCP server transport using Server-Sent Events (SSE)
//! for server-to-client communication and POST requests for client-to-server.
//!
//! # Architecture
//!
//! ```text
//! ┌──────────────────────────────────────────────────────────────┐
//! │                     SSE Server Transport                     │
//! ├──────────────────────────────────────────────────────────────┤
//! │                                                              │
//! │  GET /sse ─────────────► SSE stream (server → client)        │
//! │                          - Sends the `endpoint` event first  │
//! │                          - Then streams responses            │
//! │                                                              │
//! │  POST /message ────────► JSON-RPC request (client → server)  │
//! │                          - Handled, then 202 Accepted        │
//! │                          - Response delivered via SSE        │
//! │                                                              │
//! └──────────────────────────────────────────────────────────────┘
//! ```
//!
//! A POST without a `sessionId` query parameter (or naming a session that is no
//! longer open) receives its JSON-RPC response directly in the HTTP body instead.
//!
//! # Example
//!
//! ```rust,ignore
//! use cortexai_mcp::{McpServer, SseServerConfig};
//!
//! let server = McpServer::builder()
//!     .name("my-server")
//!     .add_tool(my_tool)
//!     .build();
//!
//! // Run with SSE transport on port 3000
//! server.run_sse(SseServerConfig::default()).await?;
//! ```

use axum::{
    extract::State,
    http::{header, Method},
    response::sse::{Event, KeepAlive, Sse},
    routing::{get, post},
    Json, Router,
};
use futures::stream::Stream;
use http::StatusCode;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
use tower_http::cors::{Any, CorsLayer};
use tracing::{debug, error, info, warn};
use uuid::Uuid;

use crate::error::McpError;
use crate::protocol::JsonRpcResponse;
use crate::server::McpServer;

// =============================================================================
// SSE Server Configuration
// =============================================================================

/// Settings for the SSE server transport.
///
/// The defaults bind `127.0.0.1:3000`, serve the event stream at `/sse`,
/// accept client messages at `/message`, enable CORS and send a keep-alive
/// every 30 seconds.
#[derive(Debug, Clone)]
pub struct SseServerConfig {
    /// Host the HTTP listener binds to.
    pub host: String,
    /// TCP port the HTTP listener binds to.
    pub port: u16,
    /// Route serving the SSE event stream (`GET`).
    pub sse_path: String,
    /// Route accepting JSON-RPC messages from clients (`POST`).
    pub message_path: String,
    /// Whether to attach a CORS layer to the routes.
    pub enable_cors: bool,
    /// Interval, in seconds, between SSE keep-alive messages.
    pub keep_alive_secs: u64,
}

impl Default for SseServerConfig {
    fn default() -> Self {
        SseServerConfig {
            host: String::from("127.0.0.1"),
            port: 3000,
            sse_path: String::from("/sse"),
            message_path: String::from("/message"),
            enable_cors: true,
            keep_alive_secs: 30,
        }
    }
}

impl SseServerConfig {
    /// Loopback-only configuration (`127.0.0.1`) listening on `port`.
    pub fn localhost(port: u16) -> Self {
        Self {
            port,
            ..Self::default()
        }
    }

    /// Configuration listening on every interface (`0.0.0.0`) at `port`.
    ///
    /// This exposes the server to the network; prefer [`SseServerConfig::localhost`]
    /// unless remote clients must connect.
    pub fn public(port: u16) -> Self {
        Self {
            host: "0.0.0.0".to_owned(),
            ..Self::localhost(port)
        }
    }
}

114// =============================================================================
115// SSE Server State
116// =============================================================================
117
118/// Internal state for SSE server
119struct SseServerState {
120    /// Reference to the MCP server
121    mcp_server: Arc<McpServer>,
122    /// Configuration
123    config: SseServerConfig,
124    /// Active SSE sessions: session_id -> response sender
125    sessions: RwLock<HashMap<String, mpsc::Sender<JsonRpcResponse>>>,
126    /// Broadcast channel for shutdown
127    shutdown_tx: broadcast::Sender<()>,
128}
129
130impl SseServerState {
131    fn new(
132        mcp_server: Arc<McpServer>,
133        config: SseServerConfig,
134        shutdown_tx: broadcast::Sender<()>,
135    ) -> Self {
136        Self {
137            mcp_server,
138            config,
139            sessions: RwLock::new(HashMap::new()),
140            shutdown_tx,
141        }
142    }
143
144    fn register_session(&self, session_id: String, sender: mpsc::Sender<JsonRpcResponse>) {
145        self.sessions.write().insert(session_id, sender);
146    }
147
148    fn unregister_session(&self, session_id: &str) {
149        self.sessions.write().remove(session_id);
150    }
151
152    fn get_session_sender(&self, session_id: &str) -> Option<mpsc::Sender<JsonRpcResponse>> {
153        self.sessions.read().get(session_id).cloned()
154    }
155}
156
157// =============================================================================
158// SSE Server Extension for McpServer
159// =============================================================================
160
161impl McpServer {
162    /// Run the server with SSE transport
163    ///
164    /// This starts an HTTP server with:
165    /// - GET /sse - SSE endpoint for server-to-client streaming
166    /// - POST /message - Endpoint for client-to-server requests
167    pub async fn run_sse(self: Arc<Self>, config: SseServerConfig) -> Result<(), McpError> {
168        let (shutdown_tx, _) = broadcast::channel::<()>(1);
169        let state = Arc::new(SseServerState::new(
170            self.clone(),
171            config.clone(),
172            shutdown_tx,
173        ));
174
175        let mut app = Router::new()
176            .route(&config.sse_path, get(handle_sse))
177            .route(&config.message_path, post(handle_message))
178            .with_state(state.clone());
179
180        if config.enable_cors {
181            let cors = CorsLayer::new()
182                .allow_origin(Any)
183                .allow_methods([Method::GET, Method::POST])
184                .allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
185            app = app.layer(cors);
186        }
187
188        let addr: SocketAddr = format!("{}:{}", config.host, config.port)
189            .parse()
190            .map_err(|e| McpError::Transport(format!("Invalid address: {}", e)))?;
191
192        info!(
193            "Starting MCP SSE server on http://{}{}",
194            addr, config.sse_path
195        );
196        info!("Message endpoint: http://{}{}", addr, config.message_path);
197
198        let listener = tokio::net::TcpListener::bind(addr)
199            .await
200            .map_err(|e| McpError::Transport(format!("Failed to bind: {}", e)))?;
201
202        axum::serve(listener, app)
203            .await
204            .map_err(|e| McpError::Transport(format!("Server error: {}", e)))?;
205
206        Ok(())
207    }
208
209    /// Run the server with SSE transport and return the router
210    /// (for embedding in existing Axum applications)
211    pub fn sse_router(self: Arc<Self>, config: SseServerConfig) -> Router {
212        let (shutdown_tx, _) = broadcast::channel::<()>(1);
213        let state = Arc::new(SseServerState::new(
214            self.clone(),
215            config.clone(),
216            shutdown_tx,
217        ));
218
219        let mut router = Router::new()
220            .route(&config.sse_path, get(handle_sse))
221            .route(&config.message_path, post(handle_message))
222            .with_state(state);
223
224        if config.enable_cors {
225            let cors = CorsLayer::new()
226                .allow_origin(Any)
227                .allow_methods([Method::GET, Method::POST])
228                .allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
229            router = router.layer(cors);
230        }
231
232        router
233    }
234}
235
236// =============================================================================
237// Request/Response Types
238// =============================================================================
239
240/// SSE endpoint event - tells client where to POST messages
241#[derive(Debug, serde::Serialize)]
242struct EndpointEvent {
243    endpoint: String,
244}
245
246// =============================================================================
247// HTTP Handlers
248// =============================================================================
249
250/// Handle SSE connection
251async fn handle_sse(
252    State(state): State<Arc<SseServerState>>,
253) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
254    let session_id = Uuid::new_v4().to_string();
255    let (tx, mut rx) = mpsc::channel::<JsonRpcResponse>(100);
256
257    state.register_session(session_id.clone(), tx);
258
259    let config = state.config.clone();
260    let state_clone = state.clone();
261    let session_id_clone = session_id.clone();
262
263    info!("New SSE session: {}", session_id);
264
265    let stream = async_stream::stream! {
266        // First, send the endpoint event
267        let endpoint = EndpointEvent {
268            endpoint: format!("{}?sessionId={}", config.message_path, session_id_clone),
269        };
270        let endpoint_json = serde_json::to_string(&endpoint).unwrap();
271        yield Ok(Event::default().event("endpoint").data(endpoint_json));
272
273        debug!("Sent endpoint event for session {}", session_id_clone);
274
275        // Then stream responses
276        let mut shutdown_rx = state_clone.shutdown_tx.subscribe();
277        loop {
278            tokio::select! {
279                Some(response) = rx.recv() => {
280                    match serde_json::to_string(&response) {
281                        Ok(json) => {
282                            debug!("Sending SSE message: {}", json);
283                            yield Ok(Event::default().event("message").data(json));
284                        }
285                        Err(e) => {
286                            error!("Failed to serialize response: {}", e);
287                        }
288                    }
289                }
290                _ = shutdown_rx.recv() => {
291                    info!("SSE session {} shutting down", session_id_clone);
292                    break;
293                }
294            }
295        }
296
297        state_clone.unregister_session(&session_id_clone);
298        info!("SSE session {} closed", session_id_clone);
299    };
300
301    Sse::new(stream).keep_alive(
302        KeepAlive::new()
303            .interval(std::time::Duration::from_secs(state.config.keep_alive_secs))
304            .text("ping"),
305    )
306}
307
308/// Query parameters for message endpoint
309#[derive(Debug, Default, serde::Deserialize)]
310struct MessageQuery {
311    #[serde(rename = "sessionId")]
312    session_id: Option<String>,
313}
314
315/// Handle incoming JSON-RPC message
316async fn handle_message(
317    State(state): State<Arc<SseServerState>>,
318    axum::extract::Query(query): axum::extract::Query<MessageQuery>,
319    Json(body): Json<serde_json::Value>,
320) -> (StatusCode, Json<serde_json::Value>) {
321    let session_id = query.session_id;
322
323    debug!(
324        "Received message for session {:?}: {}",
325        session_id,
326        serde_json::to_string_pretty(&body).unwrap_or_default()
327    );
328
329    // Parse as JSON-RPC request
330    let request = match serde_json::from_value::<crate::protocol::JsonRpcRequest>(body.clone()) {
331        Ok(req) => req,
332        Err(e) => {
333            error!("Failed to parse JSON-RPC request: {}", e);
334            return (
335                StatusCode::BAD_REQUEST,
336                Json(serde_json::json!({
337                    "jsonrpc": "2.0",
338                    "id": null,
339                    "error": {
340                        "code": -32700,
341                        "message": format!("Parse error: {}", e)
342                    }
343                })),
344            );
345        }
346    };
347
348    // Handle the request
349    let response = state.mcp_server.handle_request(request).await;
350
351    // If we have a session, send via SSE; otherwise return directly
352    if let Some(ref sid) = session_id {
353        if let Some(sender) = state.get_session_sender(sid) {
354            if sender.send(response.clone()).await.is_ok() {
355                // Return accepted - response will come via SSE
356                return (
357                    StatusCode::ACCEPTED,
358                    Json(serde_json::json!({"status": "accepted"})),
359                );
360            }
361        }
362    }
363
364    // No session or send failed - return response directly
365    let response_json = serde_json::to_value(&response).unwrap_or_default();
366    (StatusCode::OK, Json(response_json))
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::FnTool;
    use axum::extract::Query;
    use serde_json::json;

    fn create_test_server() -> Arc<McpServer> {
        McpServer::builder()
            .name("test-sse-server")
            .version("1.0.0")
            .add_tool(FnTool::new(
                "echo",
                "Echoes input",
                json!({
                    "type": "object",
                    "properties": {
                        "message": {"type": "string"}
                    }
                }),
                |args| {
                    let msg = args["message"].as_str().unwrap_or("no message");
                    Ok(json!({"echoed": msg}))
                },
            ))
            .build()
    }

    /// Handler state wrapping the test server with the default configuration.
    fn create_test_state() -> Arc<SseServerState> {
        let (shutdown_tx, _) = broadcast::channel::<()>(1);
        Arc::new(SseServerState::new(
            create_test_server(),
            SseServerConfig::default(),
            shutdown_tx,
        ))
    }

    /// A minimal well-formed JSON-RPC request body.
    fn list_tools_request() -> serde_json::Value {
        json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}})
    }

    fn query_for(session: Option<&str>) -> Query<MessageQuery> {
        Query(MessageQuery {
            session_id: session.map(str::to_string),
        })
    }

    #[test]
    fn test_sse_config_default() {
        let config = SseServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert_eq!(config.sse_path, "/sse");
        assert_eq!(config.message_path, "/message");
        assert!(config.enable_cors);
        assert_eq!(config.keep_alive_secs, 30);
    }

    #[test]
    fn test_sse_config_localhost() {
        let config = SseServerConfig::localhost(8080);
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 8080);
    }

    #[test]
    fn test_sse_config_public() {
        let config = SseServerConfig::public(9000);
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 9000);
    }

    #[tokio::test]
    async fn test_sse_router_creation() {
        let server = create_test_server();
        let config = SseServerConfig::default();
        let _router = server.sse_router(config);
        // Router created successfully
    }

    #[tokio::test]
    async fn test_session_registration() {
        let server = create_test_server();
        let (shutdown_tx, _) = broadcast::channel::<()>(1);
        let state = SseServerState::new(server, SseServerConfig::default(), shutdown_tx);

        let (tx, _rx) = mpsc::channel::<JsonRpcResponse>(10);
        state.register_session("test-session".to_string(), tx);

        assert!(state.get_session_sender("test-session").is_some());
        assert!(state.get_session_sender("nonexistent").is_none());

        state.unregister_session("test-session");
        assert!(state.get_session_sender("test-session").is_none());
    }

    #[tokio::test]
    async fn test_message_rejects_non_request_body() {
        let (status, _) = handle_message(
            State(create_test_state()),
            query_for(None),
            Json(json!("not a request")),
        )
        .await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_message_without_session_returns_response_directly() {
        let (status, Json(body)) = handle_message(
            State(create_test_state()),
            query_for(None),
            Json(list_tools_request()),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_object());
    }

    #[tokio::test]
    async fn test_message_with_unknown_session_falls_back_to_direct_response() {
        let (status, _) = handle_message(
            State(create_test_state()),
            query_for(Some("missing")),
            Json(list_tools_request()),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_message_with_session_is_delivered_over_sse() {
        let state = create_test_state();
        let (tx, mut rx) = mpsc::channel::<JsonRpcResponse>(10);
        state.register_session("s1".to_string(), tx);

        let (status, _) = handle_message(
            State(state.clone()),
            query_for(Some("s1")),
            Json(list_tools_request()),
        )
        .await;

        assert_eq!(status, StatusCode::ACCEPTED);
        assert!(rx.try_recv().is_ok());
    }

    #[tokio::test]
    async fn test_message_with_closed_session_falls_back_to_direct_response() {
        let state = create_test_state();
        let (tx, rx) = mpsc::channel::<JsonRpcResponse>(10);
        state.register_session("gone".to_string(), tx);
        drop(rx);

        let (status, _) = handle_message(
            State(state),
            query_for(Some("gone")),
            Json(list_tools_request()),
        )
        .await;

        assert_eq!(status, StatusCode::OK);
    }
}