1use axum::{
39 extract::State,
40 http::{header, Method},
41 response::sse::{Event, KeepAlive, Sse},
42 routing::{get, post},
43 Json, Router,
44};
45use futures::stream::Stream;
46use http::StatusCode;
47use parking_lot::RwLock;
48use std::collections::HashMap;
49use std::convert::Infallible;
50use std::net::SocketAddr;
51use std::sync::Arc;
52use tokio::sync::{broadcast, mpsc};
53use tower_http::cors::{Any, CorsLayer};
54use tracing::{debug, error, info};
55use uuid::Uuid;
56
57use crate::error::McpError;
58use crate::protocol::JsonRpcResponse;
59use crate::server::McpServer;
60
/// Configuration for the MCP HTTP+SSE transport.
///
/// Build one with [`SseServerConfig::localhost`], [`SseServerConfig::public`], or
/// struct-update syntax over [`Default`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseServerConfig {
    /// Host/interface to listen on (default `127.0.0.1`).
    pub host: String,
    /// TCP port to listen on (default `3000`).
    pub port: u16,
    /// Route of the `GET` endpoint that opens an SSE stream (default `/sse`).
    pub sse_path: String,
    /// Route of the `POST` endpoint that receives JSON-RPC messages (default `/message`).
    pub message_path: String,
    /// Whether to add a permissive CORS layer: any origin, `GET`/`POST`, and the
    /// `Content-Type`/`Accept` headers (default `true`).
    pub enable_cors: bool,
    /// Interval, in seconds, between SSE keep-alive messages (default `30`).
    pub keep_alive_secs: u64,
}
81
82impl Default for SseServerConfig {
83 fn default() -> Self {
84 Self {
85 host: "127.0.0.1".to_string(),
86 port: 3000,
87 sse_path: "/sse".to_string(),
88 message_path: "/message".to_string(),
89 enable_cors: true,
90 keep_alive_secs: 30,
91 }
92 }
93}
94
95impl SseServerConfig {
96 pub fn localhost(port: u16) -> Self {
98 Self {
99 port,
100 ..Default::default()
101 }
102 }
103
104 pub fn public(port: u16) -> Self {
106 Self {
107 host: "0.0.0.0".to_string(),
108 port,
109 ..Default::default()
110 }
111 }
112}
113
/// Shared state handed (behind an `Arc`) to the SSE and message handlers.
struct SseServerState {
    /// MCP server that handles each incoming JSON-RPC request.
    mcp_server: Arc<McpServer>,
    /// Transport configuration this state was built from.
    config: SseServerConfig,
    /// Open SSE sessions keyed by session id; each sender feeds responses into
    /// that session's event stream.
    sessions: RwLock<HashMap<String, mpsc::Sender<JsonRpcResponse>>>,
    /// Broadcast channel SSE streams subscribe to so they can stop on shutdown.
    /// NOTE(review): no visible code sends on this — confirm a shutdown hook is intended.
    shutdown_tx: broadcast::Sender<()>,
}
129
130impl SseServerState {
131 fn new(
132 mcp_server: Arc<McpServer>,
133 config: SseServerConfig,
134 shutdown_tx: broadcast::Sender<()>,
135 ) -> Self {
136 Self {
137 mcp_server,
138 config,
139 sessions: RwLock::new(HashMap::new()),
140 shutdown_tx,
141 }
142 }
143
144 fn register_session(&self, session_id: String, sender: mpsc::Sender<JsonRpcResponse>) {
145 self.sessions.write().insert(session_id, sender);
146 }
147
148 fn unregister_session(&self, session_id: &str) {
149 self.sessions.write().remove(session_id);
150 }
151
152 fn get_session_sender(&self, session_id: &str) -> Option<mpsc::Sender<JsonRpcResponse>> {
153 self.sessions.read().get(session_id).cloned()
154 }
155}
156
157impl McpServer {
162 pub async fn run_sse(self: Arc<Self>, config: SseServerConfig) -> Result<(), McpError> {
168 let (shutdown_tx, _) = broadcast::channel::<()>(1);
169 let state = Arc::new(SseServerState::new(
170 self.clone(),
171 config.clone(),
172 shutdown_tx,
173 ));
174
175 let mut app = Router::new()
176 .route(&config.sse_path, get(handle_sse))
177 .route(&config.message_path, post(handle_message))
178 .with_state(state.clone());
179
180 if config.enable_cors {
181 let cors = CorsLayer::new()
182 .allow_origin(Any)
183 .allow_methods([Method::GET, Method::POST])
184 .allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
185 app = app.layer(cors);
186 }
187
188 let addr: SocketAddr = format!("{}:{}", config.host, config.port)
189 .parse()
190 .map_err(|e| McpError::Transport(format!("Invalid address: {}", e)))?;
191
192 info!(
193 "Starting MCP SSE server on http://{}{}",
194 addr, config.sse_path
195 );
196 info!("Message endpoint: http://{}{}", addr, config.message_path);
197
198 let listener = tokio::net::TcpListener::bind(addr)
199 .await
200 .map_err(|e| McpError::Transport(format!("Failed to bind: {}", e)))?;
201
202 axum::serve(listener, app)
203 .await
204 .map_err(|e| McpError::Transport(format!("Server error: {}", e)))?;
205
206 Ok(())
207 }
208
209 pub fn sse_router(self: Arc<Self>, config: SseServerConfig) -> Router {
212 let (shutdown_tx, _) = broadcast::channel::<()>(1);
213 let state = Arc::new(SseServerState::new(
214 self.clone(),
215 config.clone(),
216 shutdown_tx,
217 ));
218
219 let mut router = Router::new()
220 .route(&config.sse_path, get(handle_sse))
221 .route(&config.message_path, post(handle_message))
222 .with_state(state);
223
224 if config.enable_cors {
225 let cors = CorsLayer::new()
226 .allow_origin(Any)
227 .allow_methods([Method::GET, Method::POST])
228 .allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
229 router = router.layer(cors);
230 }
231
232 router
233 }
234}
235
/// Payload of the `endpoint` SSE event that tells a newly connected client where to
/// POST its JSON-RPC messages.
///
/// Serialized with serde as a JSON object, e.g. `{"endpoint":"/message?sessionId=<id>"}`.
///
/// NOTE(review): the MCP HTTP+SSE transport spec sends the bare URI as the event data
/// rather than a JSON object; confirm the intended clients expect this wrapper.
#[derive(Debug, serde::Serialize)]
struct EndpointEvent {
    /// Message path with this session's `sessionId` query parameter appended.
    endpoint: String,
}
245
246async fn handle_sse(
252 State(state): State<Arc<SseServerState>>,
253) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
254 let session_id = Uuid::new_v4().to_string();
255 let (tx, mut rx) = mpsc::channel::<JsonRpcResponse>(100);
256
257 state.register_session(session_id.clone(), tx);
258
259 let config = state.config.clone();
260 let state_clone = state.clone();
261 let session_id_clone = session_id.clone();
262
263 info!("New SSE session: {}", session_id);
264
265 let stream = async_stream::stream! {
266 let endpoint = EndpointEvent {
268 endpoint: format!("{}?sessionId={}", config.message_path, session_id_clone),
269 };
270 let endpoint_json = serde_json::to_string(&endpoint).unwrap();
271 yield Ok(Event::default().event("endpoint").data(endpoint_json));
272
273 debug!("Sent endpoint event for session {}", session_id_clone);
274
275 let mut shutdown_rx = state_clone.shutdown_tx.subscribe();
277 loop {
278 tokio::select! {
279 Some(response) = rx.recv() => {
280 match serde_json::to_string(&response) {
281 Ok(json) => {
282 debug!("Sending SSE message: {}", json);
283 yield Ok(Event::default().event("message").data(json));
284 }
285 Err(e) => {
286 error!("Failed to serialize response: {}", e);
287 }
288 }
289 }
290 _ = shutdown_rx.recv() => {
291 info!("SSE session {} shutting down", session_id_clone);
292 break;
293 }
294 }
295 }
296
297 state_clone.unregister_session(&session_id_clone);
298 info!("SSE session {} closed", session_id_clone);
299 };
300
301 Sse::new(stream).keep_alive(
302 KeepAlive::new()
303 .interval(std::time::Duration::from_secs(state.config.keep_alive_secs))
304 .text("ping"),
305 )
306}
307
/// Query string accepted by the message endpoint.
#[derive(Debug, Default, serde::Deserialize)]
struct MessageQuery {
    /// Value of the `sessionId` query parameter; `None` when the parameter is absent.
    #[serde(rename = "sessionId")]
    session_id: Option<String>,
}
314
315async fn handle_message(
317 State(state): State<Arc<SseServerState>>,
318 axum::extract::Query(query): axum::extract::Query<MessageQuery>,
319 Json(body): Json<serde_json::Value>,
320) -> (StatusCode, Json<serde_json::Value>) {
321 let session_id = query.session_id;
322
323 debug!(
324 "Received message for session {:?}: {}",
325 session_id,
326 serde_json::to_string_pretty(&body).unwrap_or_default()
327 );
328
329 let request = match serde_json::from_value::<crate::protocol::JsonRpcRequest>(body.clone()) {
331 Ok(req) => req,
332 Err(e) => {
333 error!("Failed to parse JSON-RPC request: {}", e);
334 return (
335 StatusCode::BAD_REQUEST,
336 Json(serde_json::json!({
337 "jsonrpc": "2.0",
338 "id": null,
339 "error": {
340 "code": -32700,
341 "message": format!("Parse error: {}", e)
342 }
343 })),
344 );
345 }
346 };
347
348 let response = state.mcp_server.handle_request(request).await;
350
351 if let Some(ref sid) = session_id {
353 if let Some(sender) = state.get_session_sender(sid) {
354 if sender.send(response.clone()).await.is_ok() {
355 return (
357 StatusCode::ACCEPTED,
358 Json(serde_json::json!({"status": "accepted"})),
359 );
360 }
361 }
362 }
363
364 let response_json = serde_json::to_value(&response).unwrap_or_default();
366 (StatusCode::OK, Json(response_json))
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::FnTool;
    use serde_json::json;

    /// Builds a server exposing a single `echo` tool for the tests below.
    fn create_test_server() -> Arc<McpServer> {
        McpServer::builder()
            .name("test-sse-server")
            .version("1.0.0")
            .add_tool(FnTool::new(
                "echo",
                "Echoes input",
                json!({
                    "type": "object",
                    "properties": {
                        "message": {"type": "string"}
                    }
                }),
                |args| {
                    let msg = args["message"].as_str().unwrap_or("no message");
                    Ok(json!({"echoed": msg}))
                },
            ))
            .build()
    }

    /// Asserts that every field other than `host` and `port` matches `Default`.
    fn assert_non_address_fields_are_default(config: &SseServerConfig) {
        let defaults = SseServerConfig::default();
        assert_eq!(config.sse_path, defaults.sse_path);
        assert_eq!(config.message_path, defaults.message_path);
        assert_eq!(config.enable_cors, defaults.enable_cors);
        assert_eq!(config.keep_alive_secs, defaults.keep_alive_secs);
    }

    #[test]
    fn test_sse_config_default() {
        let config = SseServerConfig::default();
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 3000);
        assert_eq!(config.sse_path, "/sse");
        assert_eq!(config.message_path, "/message");
        assert!(config.enable_cors);
        assert_eq!(config.keep_alive_secs, 30);
    }

    #[test]
    fn test_sse_config_localhost() {
        let config = SseServerConfig::localhost(8080);
        assert_eq!(config.host, "127.0.0.1");
        assert_eq!(config.port, 8080);
        assert_non_address_fields_are_default(&config);
    }

    #[test]
    fn test_sse_config_public() {
        let config = SseServerConfig::public(9000);
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 9000);
        assert_non_address_fields_are_default(&config);
    }

    #[tokio::test]
    async fn test_sse_router_creation() {
        let server = create_test_server();
        let config = SseServerConfig::default();
        let _router = server.sse_router(config);
    }

    #[tokio::test]
    async fn test_session_registration() {
        let server = create_test_server();
        let (shutdown_tx, _) = broadcast::channel::<()>(1);
        let state = SseServerState::new(server, SseServerConfig::default(), shutdown_tx);

        let (tx, _rx) = mpsc::channel::<JsonRpcResponse>(10);
        state.register_session("test-session".to_string(), tx);

        assert!(state.get_session_sender("test-session").is_some());
        assert!(state.get_session_sender("nonexistent").is_none());

        // Removing an unknown session is a no-op and leaves others untouched.
        state.unregister_session("nonexistent");
        assert!(state.get_session_sender("test-session").is_some());

        state.unregister_session("test-session");
        assert!(state.get_session_sender("test-session").is_none());
    }
}