Skip to main content

mcp_kit/transport/
sse.rs

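//! SSE (Server-Sent Events) + HTTP POST transport for MCP.
//!
//! - `GET  /sse` — opens an SSE stream; the first `endpoint` event tells the client
//!   where to POST, and server-to-client messages follow as `message` events.
//! - `POST /message?sessionId=<id>` — the client sends JSON-RPC messages to the server.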
5use std::{convert::Infallible, sync::Arc};
6
7use crate::server::{core::McpServer, session::Session};
8use crate::{error::McpResult, protocol::JsonRpcMessage};
9use axum::{
10    extract::{Query, State as AxumState},
11    http::StatusCode,
12    response::{sse::Event, IntoResponse, Response, Sse},
13    routing::{get, post},
14    Json as AxumJson, Router as AxumRouter,
15};
16use dashmap::DashMap;
17use futures_util::stream;
18use std::future::Future;
19use tokio::sync::mpsc;
20use tracing::{error, info};
21use uuid::Uuid;
22
23// ─── Shared SSE state ─────────────────────────────────────────────────────────
24
25type SessionTx = mpsc::Sender<JsonRpcMessage>;
26type SessionData = (SessionTx, Arc<tokio::sync::Mutex<Session>>);
27
28#[derive(Clone)]
29pub struct SseState {
30    pub server: Arc<McpServer>,
31    pub sessions: Arc<DashMap<String, SessionData>>,
32}
33
34// ─── SseTransport ─────────────────────────────────────────────────────────────
35
36pub struct SseTransport {
37    server: McpServer,
38    addr: std::net::SocketAddr,
39}
40
41impl SseTransport {
42    pub fn new(server: McpServer, addr: impl Into<std::net::SocketAddr>) -> Self {
43        Self {
44            server,
45            addr: addr.into(),
46        }
47    }
48
49    pub async fn serve(self) -> McpResult<()> {
50        let state = SseState {
51            server: Arc::new(self.server),
52            sessions: Arc::new(DashMap::new()),
53        };
54
55        let app = AxumRouter::new()
56            .route("/sse", get(sse_handler))
57            .route("/message", post(message_handler))
58            .with_state(state);
59
60        info!(addr = %self.addr, "SSE transport listening");
61
62        let listener = tokio::net::TcpListener::bind(self.addr)
63            .await
64            .map_err(crate::error::McpError::Io)?;
65
66        axum::serve(listener, app)
67            .await
68            .map_err(|e| crate::error::McpError::Transport(e.to_string()))?;
69
70        Ok(())
71    }
72}
73
74// ─── GET /sse ─────────────────────────────────────────────────────────────────
75
76async fn sse_handler(AxumState(state): AxumState<SseState>) -> Response {
77    let session_id = Uuid::new_v4().to_string();
78    let (tx, rx) = mpsc::channel::<JsonRpcMessage>(64);
79    let session = Arc::new(tokio::sync::Mutex::new(Session::new()));
80
81    state
82        .sessions
83        .insert(session_id.clone(), (tx, session.clone()));
84    info!(session_id = %session_id, "New SSE connection");
85
86    let sid = session_id.clone();
87    let init_event = Event::default()
88        .event("endpoint")
89        .data(format!("/message?sessionId={sid}"));
90
91    let stream = stream::unfold(
92        (rx, session_id.clone(), state.clone()),
93        |(rx, sid, state)| async move {
94            let mut rx = rx;
95            match rx.recv().await {
96                Some(msg) => {
97                    let data = serde_json::to_string(&msg).unwrap_or_default();
98                    let event = Event::default().event("message").data(data);
99                    Some((Ok::<_, Infallible>(event), (rx, sid, state)))
100                }
101                None => {
102                    state.sessions.remove(&sid);
103                    None
104                }
105            }
106        },
107    );
108
109    let combined = futures_util::StreamExt::chain(
110        stream::once(async move { Ok::<_, Infallible>(init_event) }),
111        stream,
112    );
113
114    Sse::new(combined)
115        .keep_alive(axum::response::sse::KeepAlive::default())
116        .into_response()
117}
118
119// ─── POST /message ────────────────────────────────────────────────────────────
120
121#[derive(serde::Deserialize)]
122struct MessageQuery {
123    #[serde(rename = "sessionId")]
124    session_id: String,
125}
126
127async fn message_handler(
128    AxumState(state): AxumState<SseState>,
129    Query(query): Query<MessageQuery>,
130    AxumJson(msg): AxumJson<JsonRpcMessage>,
131) -> impl IntoResponse {
132    let entry = state.sessions.get(&query.session_id);
133    let Some(entry) = entry else {
134        return (StatusCode::NOT_FOUND, "Session not found").into_response();
135    };
136
137    let (tx, session_arc) = entry.value().clone();
138    drop(entry);
139
140    let mut session = session_arc.lock().await;
141    let server = state.server.clone();
142
143    match server.handle_message(msg, &mut session).await {
144        Some(response) => {
145            if tx.send(response).await.is_err() {
146                error!(session_id = %query.session_id, "Failed to send SSE response");
147            }
148            StatusCode::OK.into_response()
149        }
150        None => StatusCode::ACCEPTED.into_response(),
151    }
152}
153
154/// Extension trait that adds `.serve_sse()` to `McpServer`.
155pub trait ServeSseExt {
156    fn serve_sse(
157        self,
158        addr: impl Into<std::net::SocketAddr>,
159    ) -> impl Future<Output = McpResult<()>> + Send;
160}
161
162impl ServeSseExt for McpServer {
163    fn serve_sse(
164        self,
165        addr: impl Into<std::net::SocketAddr>,
166    ) -> impl Future<Output = McpResult<()>> + Send {
167        SseTransport::new(self, addr).serve()
168    }
169}