Skip to main content

mcp_kit/transport/
sse.rs

//! SSE (Server-Sent Events) + HTTP POST transport for MCP.
//!
//! - `GET /sse` opens an SSE stream over which the client receives messages from the server.
//! - `POST /message?sessionId=<id>` lets the client send JSON-RPC messages to the server.
5use std::{convert::Infallible, sync::Arc};
6
7use crate::server::{core::McpServer, session::Session};
8use crate::{error::McpResult, protocol::JsonRpcMessage};
9use axum::{
10    extract::{Query, State as AxumState},
11    http::StatusCode,
12    response::{sse::Event, IntoResponse, Response, Sse},
13    routing::{get, post},
14    Json as AxumJson, Router as AxumRouter,
15};
16use dashmap::DashMap;
17use futures_util::stream;
18use std::future::Future;
19use tokio::sync::mpsc;
20use tracing::{error, info};
21use uuid::Uuid;
22
23#[cfg(feature = "auth")]
24use crate::auth::{AuthenticatedIdentity, DynAuthProvider};
25#[cfg(feature = "auth")]
26use crate::transport::auth_layer::{auth_middleware, AuthMiddlewareState};
27
28// ─── Shared SSE state ─────────────────────────────────────────────────────────
29
30type SessionTx = mpsc::Sender<JsonRpcMessage>;
31type SessionData = (SessionTx, Arc<tokio::sync::Mutex<Session>>);
32
33#[derive(Clone)]
34pub struct SseState {
35    pub server: Arc<McpServer>,
36    pub sessions: Arc<DashMap<String, SessionData>>,
37    #[cfg(feature = "auth")]
38    pub auth: Option<AuthMiddlewareState>,
39}
40
41// ─── SseTransport ─────────────────────────────────────────────────────────────
42
43pub struct SseTransport {
44    server: McpServer,
45    addr: std::net::SocketAddr,
46    #[cfg(feature = "auth")]
47    auth: Option<AuthMiddlewareState>,
48}
49
50impl SseTransport {
51    pub fn new(server: McpServer, addr: impl Into<std::net::SocketAddr>) -> Self {
52        Self {
53            server,
54            addr: addr.into(),
55            #[cfg(feature = "auth")]
56            auth: None,
57        }
58    }
59
60    /// Require authentication on all requests using the given provider.
61    /// Requests with no or invalid credentials receive HTTP 401.
62    #[cfg(feature = "auth")]
63    pub fn with_auth(mut self, provider: DynAuthProvider) -> Self {
64        self.auth = Some(AuthMiddlewareState {
65            provider,
66            require_auth: true,
67        });
68        self
69    }
70
71    /// Accept an auth provider but allow unauthenticated requests through.
72    /// Authenticated requests will have an identity; unauthenticated ones will not.
73    #[cfg(feature = "auth")]
74    pub fn with_optional_auth(mut self, provider: DynAuthProvider) -> Self {
75        self.auth = Some(AuthMiddlewareState {
76            provider,
77            require_auth: false,
78        });
79        self
80    }
81
82    pub async fn serve(self) -> McpResult<()> {
83        let state = SseState {
84            server: Arc::new(self.server),
85            sessions: Arc::new(DashMap::new()),
86            #[cfg(feature = "auth")]
87            auth: self.auth,
88        };
89
90        let app = build_router(state);
91
92        info!(addr = %self.addr, "SSE transport listening");
93
94        let listener = tokio::net::TcpListener::bind(self.addr)
95            .await
96            .map_err(crate::error::McpError::Io)?;
97
98        axum::serve(listener, app)
99            .await
100            .map_err(|e| crate::error::McpError::Transport(e.to_string()))?;
101
102        Ok(())
103    }
104}
105
106pub(crate) fn build_router(state: SseState) -> AxumRouter {
107    let routes = AxumRouter::new()
108        .route("/sse", get(sse_handler))
109        .route("/message", post(message_handler));
110
111    #[cfg(feature = "auth")]
112    if let Some(auth_state) = state.auth.clone() {
113        return routes
114            .route_layer(axum::middleware::from_fn_with_state(
115                auth_state,
116                auth_middleware,
117            ))
118            .with_state(state);
119    }
120
121    routes.with_state(state)
122}
123
124// ─── GET /sse ─────────────────────────────────────────────────────────────────
125
126async fn sse_handler(
127    AxumState(state): AxumState<SseState>,
128    #[cfg(feature = "auth")] identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
129) -> Response {
130    let session_id = Uuid::new_v4().to_string();
131    let (tx, rx) = mpsc::channel::<JsonRpcMessage>(64);
132    let session = Arc::new(tokio::sync::Mutex::new(Session::new()));
133
134    #[cfg(feature = "auth")]
135    if let Some(axum::Extension(id)) = identity {
136        session.lock().await.identity = Some((*id).clone());
137    }
138
139    state
140        .sessions
141        .insert(session_id.clone(), (tx, session.clone()));
142    info!(session_id = %session_id, "New SSE connection");
143
144    let sid = session_id.clone();
145    let init_event = Event::default()
146        .event("endpoint")
147        .data(format!("/message?sessionId={sid}"));
148
149    let stream = stream::unfold(
150        (rx, session_id.clone(), state.clone()),
151        |(rx, sid, state)| async move {
152            let mut rx = rx;
153            match rx.recv().await {
154                Some(msg) => {
155                    let data = serde_json::to_string(&msg).unwrap_or_default();
156                    let event = Event::default().event("message").data(data);
157                    Some((Ok::<_, Infallible>(event), (rx, sid, state)))
158                }
159                None => {
160                    state.sessions.remove(&sid);
161                    None
162                }
163            }
164        },
165    );
166
167    let combined = futures_util::StreamExt::chain(
168        stream::once(async move { Ok::<_, Infallible>(init_event) }),
169        stream,
170    );
171
172    Sse::new(combined)
173        .keep_alive(axum::response::sse::KeepAlive::default())
174        .into_response()
175}
176
177// ─── POST /message ────────────────────────────────────────────────────────────
178
179#[derive(serde::Deserialize)]
180struct MessageQuery {
181    #[serde(rename = "sessionId")]
182    session_id: String,
183}
184
185async fn message_handler(
186    AxumState(state): AxumState<SseState>,
187    Query(query): Query<MessageQuery>,
188    #[cfg(feature = "auth")] identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
189    AxumJson(msg): AxumJson<JsonRpcMessage>,
190) -> impl IntoResponse {
191    let entry = state.sessions.get(&query.session_id);
192    let Some(entry) = entry else {
193        return (StatusCode::NOT_FOUND, "Session not found").into_response();
194    };
195
196    let (tx, session_arc) = entry.value().clone();
197    drop(entry);
198
199    let mut session = session_arc.lock().await;
200
201    // Refresh the identity on every POST so it reflects the current request's auth.
202    #[cfg(feature = "auth")]
203    if let Some(axum::Extension(id)) = identity {
204        session.identity = Some((*id).clone());
205    }
206
207    let server = state.server.clone();
208
209    match server.handle_message(msg, &mut session).await {
210        Some(response) => {
211            if tx.send(response).await.is_err() {
212                error!(session_id = %query.session_id, "Failed to send SSE response");
213            }
214            StatusCode::OK.into_response()
215        }
216        None => StatusCode::ACCEPTED.into_response(),
217    }
218}
219
220/// Extension trait that adds `.serve_sse()` to `McpServer`.
221pub trait ServeSseExt {
222    fn serve_sse(
223        self,
224        addr: impl Into<std::net::SocketAddr>,
225    ) -> impl Future<Output = McpResult<()>> + Send;
226}
227
228impl ServeSseExt for McpServer {
229    fn serve_sse(
230        self,
231        addr: impl Into<std::net::SocketAddr>,
232    ) -> impl Future<Output = McpResult<()>> + Send {
233        #[cfg(feature = "auth")]
234        {
235            let transport = SseTransport::new(self.clone(), addr);
236            let transport = match (self.auth_provider, self.require_auth) {
237                (Some(provider), true) => transport.with_auth(provider),
238                (Some(provider), false) => transport.with_optional_auth(provider),
239                (None, _) => transport,
240            };
241            transport.serve()
242        }
243        #[cfg(not(feature = "auth"))]
244        SseTransport::new(self, addr).serve()
245    }
246}