1use std::{convert::Infallible, sync::Arc};
6
7use crate::server::{core::McpServer, session::Session};
8use crate::{error::McpResult, protocol::JsonRpcMessage};
9use axum::{
10 extract::{Query, State as AxumState},
11 http::StatusCode,
12 response::{sse::Event, IntoResponse, Response, Sse},
13 routing::{get, post},
14 Json as AxumJson, Router as AxumRouter,
15};
16use dashmap::DashMap;
17use futures_util::stream;
18use std::future::Future;
19use tokio::sync::mpsc;
20use tracing::{error, info};
21use uuid::Uuid;
22
23#[cfg(feature = "auth")]
24use crate::auth::{AuthenticatedIdentity, DynAuthProvider};
25#[cfg(feature = "auth")]
26use crate::transport::auth_layer::{auth_middleware, AuthMiddlewareState};
27
28type SessionTx = mpsc::Sender<JsonRpcMessage>;
31type SessionData = (SessionTx, Arc<tokio::sync::Mutex<Session>>);
32
33#[derive(Clone)]
34pub struct SseState {
35 pub server: Arc<McpServer>,
36 pub sessions: Arc<DashMap<String, SessionData>>,
37 #[cfg(feature = "auth")]
38 pub auth: Option<AuthMiddlewareState>,
39}
40
41pub struct SseTransport {
44 server: McpServer,
45 addr: std::net::SocketAddr,
46 #[cfg(feature = "auth")]
47 auth: Option<AuthMiddlewareState>,
48}
49
50impl SseTransport {
51 pub fn new(server: McpServer, addr: impl Into<std::net::SocketAddr>) -> Self {
52 Self {
53 server,
54 addr: addr.into(),
55 #[cfg(feature = "auth")]
56 auth: None,
57 }
58 }
59
60 #[cfg(feature = "auth")]
63 pub fn with_auth(mut self, provider: DynAuthProvider) -> Self {
64 self.auth = Some(AuthMiddlewareState {
65 provider,
66 require_auth: true,
67 });
68 self
69 }
70
71 #[cfg(feature = "auth")]
74 pub fn with_optional_auth(mut self, provider: DynAuthProvider) -> Self {
75 self.auth = Some(AuthMiddlewareState {
76 provider,
77 require_auth: false,
78 });
79 self
80 }
81
82 pub async fn serve(self) -> McpResult<()> {
83 let state = SseState {
84 server: Arc::new(self.server),
85 sessions: Arc::new(DashMap::new()),
86 #[cfg(feature = "auth")]
87 auth: self.auth,
88 };
89
90 let app = build_router(state);
91
92 info!(addr = %self.addr, "SSE transport listening");
93
94 let listener = tokio::net::TcpListener::bind(self.addr)
95 .await
96 .map_err(crate::error::McpError::Io)?;
97
98 axum::serve(listener, app)
99 .await
100 .map_err(|e| crate::error::McpError::Transport(e.to_string()))?;
101
102 Ok(())
103 }
104}
105
106pub(crate) fn build_router(state: SseState) -> AxumRouter {
107 let routes = AxumRouter::new()
108 .route("/sse", get(sse_handler))
109 .route("/message", post(message_handler));
110
111 #[cfg(feature = "auth")]
112 if let Some(auth_state) = state.auth.clone() {
113 return routes
114 .route_layer(axum::middleware::from_fn_with_state(
115 auth_state,
116 auth_middleware,
117 ))
118 .with_state(state);
119 }
120
121 routes.with_state(state)
122}
123
124async fn sse_handler(
127 AxumState(state): AxumState<SseState>,
128 #[cfg(feature = "auth")] identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
129) -> Response {
130 let session_id = Uuid::new_v4().to_string();
131 let (tx, rx) = mpsc::channel::<JsonRpcMessage>(64);
132 let session = Arc::new(tokio::sync::Mutex::new(Session::new()));
133
134 #[cfg(feature = "auth")]
135 if let Some(axum::Extension(id)) = identity {
136 session.lock().await.identity = Some((*id).clone());
137 }
138
139 state
140 .sessions
141 .insert(session_id.clone(), (tx, session.clone()));
142 info!(session_id = %session_id, "New SSE connection");
143
144 let sid = session_id.clone();
145 let init_event = Event::default()
146 .event("endpoint")
147 .data(format!("/message?sessionId={sid}"));
148
149 let stream = stream::unfold(
150 (rx, session_id.clone(), state.clone()),
151 |(rx, sid, state)| async move {
152 let mut rx = rx;
153 match rx.recv().await {
154 Some(msg) => {
155 let data = serde_json::to_string(&msg).unwrap_or_default();
156 let event = Event::default().event("message").data(data);
157 Some((Ok::<_, Infallible>(event), (rx, sid, state)))
158 }
159 None => {
160 state.sessions.remove(&sid);
161 None
162 }
163 }
164 },
165 );
166
167 let combined = futures_util::StreamExt::chain(
168 stream::once(async move { Ok::<_, Infallible>(init_event) }),
169 stream,
170 );
171
172 Sse::new(combined)
173 .keep_alive(axum::response::sse::KeepAlive::default())
174 .into_response()
175}
176
177#[derive(serde::Deserialize)]
180struct MessageQuery {
181 #[serde(rename = "sessionId")]
182 session_id: String,
183}
184
185async fn message_handler(
186 AxumState(state): AxumState<SseState>,
187 Query(query): Query<MessageQuery>,
188 #[cfg(feature = "auth")] identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
189 AxumJson(msg): AxumJson<JsonRpcMessage>,
190) -> impl IntoResponse {
191 let entry = state.sessions.get(&query.session_id);
192 let Some(entry) = entry else {
193 return (StatusCode::NOT_FOUND, "Session not found").into_response();
194 };
195
196 let (tx, session_arc) = entry.value().clone();
197 drop(entry);
198
199 let mut session = session_arc.lock().await;
200
201 #[cfg(feature = "auth")]
203 if let Some(axum::Extension(id)) = identity {
204 session.identity = Some((*id).clone());
205 }
206
207 let server = state.server.clone();
208
209 match server.handle_message(msg, &mut session).await {
210 Some(response) => {
211 if tx.send(response).await.is_err() {
212 error!(session_id = %query.session_id, "Failed to send SSE response");
213 }
214 StatusCode::OK.into_response()
215 }
216 None => StatusCode::ACCEPTED.into_response(),
217 }
218}
219
/// Extension trait for serving an MCP server over the HTTP + SSE transport.
pub trait ServeSseExt {
    /// Serves `self` over the SSE transport at `addr`.
    ///
    /// The returned future is `Send`, so it can be spawned on a multi-threaded runtime.
    fn serve_sse(
        self,
        addr: impl Into<std::net::SocketAddr>,
    ) -> impl Future<Output = McpResult<()>> + Send;
}
227
228impl ServeSseExt for McpServer {
229 fn serve_sse(
230 self,
231 addr: impl Into<std::net::SocketAddr>,
232 ) -> impl Future<Output = McpResult<()>> + Send {
233 #[cfg(feature = "auth")]
234 {
235 let transport = SseTransport::new(self.clone(), addr);
236 let transport = match (self.auth_provider, self.require_auth) {
237 (Some(provider), true) => transport.with_auth(provider),
238 (Some(provider), false) => transport.with_optional_auth(provider),
239 (None, _) => transport,
240 };
241 transport.serve()
242 }
243 #[cfg(not(feature = "auth"))]
244 SseTransport::new(self, addr).serve()
245 }
246}