1use std::{convert::Infallible, sync::Arc};
6
7use crate::server::{core::McpServer, session::Session};
8use crate::{error::McpResult, protocol::JsonRpcMessage};
9use axum::{
10 extract::{Query, State as AxumState},
11 http::StatusCode,
12 response::{sse::Event, IntoResponse, Response, Sse},
13 routing::{get, post},
14 Json as AxumJson, Router as AxumRouter,
15};
16use dashmap::DashMap;
17use futures_util::stream;
18use std::future::Future;
19use tokio::sync::mpsc;
20use tracing::{error, info};
21use uuid::Uuid;
22
23type SessionTx = mpsc::Sender<JsonRpcMessage>;
26type SessionData = (SessionTx, Arc<tokio::sync::Mutex<Session>>);
27
28#[derive(Clone)]
29pub struct SseState {
30 pub server: Arc<McpServer>,
31 pub sessions: Arc<DashMap<String, SessionData>>,
32}
33
34pub struct SseTransport {
37 server: McpServer,
38 addr: std::net::SocketAddr,
39}
40
41impl SseTransport {
42 pub fn new(server: McpServer, addr: impl Into<std::net::SocketAddr>) -> Self {
43 Self {
44 server,
45 addr: addr.into(),
46 }
47 }
48
49 pub async fn serve(self) -> McpResult<()> {
50 let state = SseState {
51 server: Arc::new(self.server),
52 sessions: Arc::new(DashMap::new()),
53 };
54
55 let app = AxumRouter::new()
56 .route("/sse", get(sse_handler))
57 .route("/message", post(message_handler))
58 .with_state(state);
59
60 info!(addr = %self.addr, "SSE transport listening");
61
62 let listener = tokio::net::TcpListener::bind(self.addr)
63 .await
64 .map_err(crate::error::McpError::Io)?;
65
66 axum::serve(listener, app)
67 .await
68 .map_err(|e| crate::error::McpError::Transport(e.to_string()))?;
69
70 Ok(())
71 }
72}
73
74async fn sse_handler(AxumState(state): AxumState<SseState>) -> Response {
77 let session_id = Uuid::new_v4().to_string();
78 let (tx, rx) = mpsc::channel::<JsonRpcMessage>(64);
79 let session = Arc::new(tokio::sync::Mutex::new(Session::new()));
80
81 state
82 .sessions
83 .insert(session_id.clone(), (tx, session.clone()));
84 info!(session_id = %session_id, "New SSE connection");
85
86 let sid = session_id.clone();
87 let init_event = Event::default()
88 .event("endpoint")
89 .data(format!("/message?sessionId={sid}"));
90
91 let stream = stream::unfold(
92 (rx, session_id.clone(), state.clone()),
93 |(rx, sid, state)| async move {
94 let mut rx = rx;
95 match rx.recv().await {
96 Some(msg) => {
97 let data = serde_json::to_string(&msg).unwrap_or_default();
98 let event = Event::default().event("message").data(data);
99 Some((Ok::<_, Infallible>(event), (rx, sid, state)))
100 }
101 None => {
102 state.sessions.remove(&sid);
103 None
104 }
105 }
106 },
107 );
108
109 let combined = futures_util::StreamExt::chain(
110 stream::once(async move { Ok::<_, Infallible>(init_event) }),
111 stream,
112 );
113
114 Sse::new(combined)
115 .keep_alive(axum::response::sse::KeepAlive::default())
116 .into_response()
117}
118
119#[derive(serde::Deserialize)]
122struct MessageQuery {
123 #[serde(rename = "sessionId")]
124 session_id: String,
125}
126
127async fn message_handler(
128 AxumState(state): AxumState<SseState>,
129 Query(query): Query<MessageQuery>,
130 AxumJson(msg): AxumJson<JsonRpcMessage>,
131) -> impl IntoResponse {
132 let entry = state.sessions.get(&query.session_id);
133 let Some(entry) = entry else {
134 return (StatusCode::NOT_FOUND, "Session not found").into_response();
135 };
136
137 let (tx, session_arc) = entry.value().clone();
138 drop(entry);
139
140 let mut session = session_arc.lock().await;
141 let server = state.server.clone();
142
143 match server.handle_message(msg, &mut session).await {
144 Some(response) => {
145 if tx.send(response).await.is_err() {
146 error!(session_id = %query.session_id, "Failed to send SSE response");
147 }
148 StatusCode::OK.into_response()
149 }
150 None => StatusCode::ACCEPTED.into_response(),
151 }
152}
153
154pub trait ServeSseExt {
156 fn serve_sse(
157 self,
158 addr: impl Into<std::net::SocketAddr>,
159 ) -> impl Future<Output = McpResult<()>> + Send;
160}
161
162impl ServeSseExt for McpServer {
163 fn serve_sse(
164 self,
165 addr: impl Into<std::net::SocketAddr>,
166 ) -> impl Future<Output = McpResult<()>> + Send {
167 SseTransport::new(self, addr).serve()
168 }
169}