1use std::{convert::Infallible, sync::Arc, time::Duration};
47
48use crate::server::{core::McpServer, session::Session};
49use crate::{error::McpResult, protocol::JsonRpcMessage};
50use axum::{
51 extract::State as AxumState,
52 http::{HeaderMap, HeaderValue, StatusCode},
53 response::{sse::Event, IntoResponse, Response, Sse},
54 routing::{delete, get, post},
55 Json as AxumJson, Router as AxumRouter,
56};
57use dashmap::DashMap;
58use futures_util::stream;
59use std::future::Future;
60use tokio::sync::mpsc;
61use tracing::info;
62use uuid::Uuid;
63
64#[cfg(feature = "auth")]
65use crate::auth::{AuthenticatedIdentity, DynAuthProvider};
66#[cfg(feature = "auth")]
67use crate::transport::auth_layer::{auth_middleware, AuthMiddlewareState};
68
69const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id";
72const SESSION_TIMEOUT_SECS: u64 = 3600; type NotificationTx = mpsc::Sender<JsonRpcMessage>;
77
78pub(crate) struct SessionEntry {
79 session: Arc<tokio::sync::Mutex<Session>>,
80 notification_tx: Option<NotificationTx>,
81 last_active: std::time::Instant,
82}
83
84#[derive(Clone)]
85pub struct StreamableState {
86 pub(crate) server: Arc<McpServer>,
87 pub(crate) sessions: Arc<DashMap<String, SessionEntry>>,
88 #[cfg(feature = "auth")]
89 #[allow(dead_code)]
90 pub(crate) auth: Option<AuthMiddlewareState>,
91}
92
93impl StreamableState {
94 fn get_or_create_session(
95 &self,
96 session_id: Option<&str>,
97 ) -> (String, Arc<tokio::sync::Mutex<Session>>) {
98 if let Some(sid) = session_id {
99 if let Some(mut entry) = self.sessions.get_mut(sid) {
100 entry.last_active = std::time::Instant::now();
101 return (sid.to_string(), entry.session.clone());
102 }
103 }
104
105 let session_id = Uuid::new_v4().to_string();
107 let session = Arc::new(tokio::sync::Mutex::new(Session::new()));
108 self.sessions.insert(
109 session_id.clone(),
110 SessionEntry {
111 session: session.clone(),
112 notification_tx: None,
113 last_active: std::time::Instant::now(),
114 },
115 );
116 info!(session_id = %session_id, "Created new session");
117 (session_id, session)
118 }
119
120 fn set_notification_channel(&self, session_id: &str, tx: NotificationTx) {
121 if let Some(mut entry) = self.sessions.get_mut(session_id) {
122 entry.notification_tx = Some(tx);
123 }
124 }
125
126 fn remove_notification_channel(&self, session_id: &str) {
127 if let Some(mut entry) = self.sessions.get_mut(session_id) {
128 entry.notification_tx = None;
129 }
130 }
131
132 fn remove_session(&self, session_id: &str) {
133 self.sessions.remove(session_id);
134 info!(session_id = %session_id, "Session terminated");
135 }
136
137 fn cleanup_expired_sessions(&self) {
138 let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS);
139 let now = std::time::Instant::now();
140
141 self.sessions.retain(|sid, entry| {
142 let expired = now.duration_since(entry.last_active) > timeout;
143 if expired {
144 info!(session_id = %sid, "Session expired");
145 }
146 !expired
147 });
148 }
149}
150
151pub struct StreamableTransport {
155 server: McpServer,
156 addr: std::net::SocketAddr,
157 endpoint: String,
158 #[cfg(feature = "auth")]
159 auth: Option<AuthMiddlewareState>,
160}
161
162impl StreamableTransport {
163 pub fn new(server: McpServer, addr: impl Into<std::net::SocketAddr>) -> Self {
165 Self {
166 server,
167 addr: addr.into(),
168 endpoint: "/mcp".to_string(),
169 #[cfg(feature = "auth")]
170 auth: None,
171 }
172 }
173
174 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
176 self.endpoint = endpoint.into();
177 self
178 }
179
180 #[cfg(feature = "auth")]
182 pub fn with_auth(mut self, provider: DynAuthProvider) -> Self {
183 self.auth = Some(AuthMiddlewareState {
184 provider,
185 require_auth: true,
186 });
187 self
188 }
189
190 #[cfg(feature = "auth")]
192 pub fn with_optional_auth(mut self, provider: DynAuthProvider) -> Self {
193 self.auth = Some(AuthMiddlewareState {
194 provider,
195 require_auth: false,
196 });
197 self
198 }
199
200 pub fn build_router(self) -> (AxumRouter, StreamableState) {
202 let state = StreamableState {
203 server: Arc::new(self.server),
204 sessions: Arc::new(DashMap::new()),
205 #[cfg(feature = "auth")]
206 auth: self.auth.clone(),
207 };
208
209 let routes = AxumRouter::new()
210 .route(&self.endpoint, post(handle_post))
211 .route(&self.endpoint, get(handle_get_sse))
212 .route(&self.endpoint, delete(handle_delete));
213
214 #[cfg(feature = "auth")]
215 let routes = if let Some(auth_state) = self.auth {
216 routes.route_layer(axum::middleware::from_fn_with_state(
217 auth_state,
218 auth_middleware,
219 ))
220 } else {
221 routes
222 };
223
224 (routes.with_state(state.clone()), state)
225 }
226
227 pub async fn serve(self) -> McpResult<()> {
229 let addr = self.addr;
230 let (router, state) = self.build_router();
231
232 let cleanup_state = state.clone();
234 tokio::spawn(async move {
235 let mut interval = tokio::time::interval(Duration::from_secs(300)); loop {
237 interval.tick().await;
238 cleanup_state.cleanup_expired_sessions();
239 }
240 });
241
242 info!(addr = %addr, "Streamable HTTP transport listening");
243
244 let listener = tokio::net::TcpListener::bind(addr)
245 .await
246 .map_err(crate::error::McpError::Io)?;
247
248 axum::serve(listener, router)
249 .await
250 .map_err(|e| crate::error::McpError::Transport(e.to_string()))?;
251
252 Ok(())
253 }
254}
255
256async fn handle_post(
259 AxumState(state): AxumState<StreamableState>,
260 headers: HeaderMap,
261 #[cfg(feature = "auth")] identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
262 AxumJson(msg): AxumJson<JsonRpcMessage>,
263) -> Response {
264 let session_id_header = headers
266 .get(MCP_SESSION_ID_HEADER)
267 .and_then(|v| v.to_str().ok());
268
269 let (session_id, session_arc) = state.get_or_create_session(session_id_header);
270
271 let mut session = session_arc.lock().await;
272
273 #[cfg(feature = "auth")]
275 if let Some(axum::Extension(id)) = identity {
276 session.identity = Some((*id).clone());
277 }
278
279 let needs_streaming = matches!(&msg, JsonRpcMessage::Request(req) if
281 req.method == "tools/call" ||
282 req.method == "sampling/createMessage"
283 );
284
285 let server = state.server.clone();
286
287 match server.handle_message(msg, &mut session).await {
289 Some(response) => {
290 if needs_streaming {
291 stream_response(session_id, response)
293 } else {
294 json_response(session_id, response)
296 }
297 }
298 None => {
299 (StatusCode::ACCEPTED, [(MCP_SESSION_ID_HEADER, session_id)]).into_response()
301 }
302 }
303}
304
305fn json_response(session_id: String, response: JsonRpcMessage) -> Response {
306 let mut resp = AxumJson(response).into_response();
307 resp.headers_mut().insert(
308 MCP_SESSION_ID_HEADER,
309 HeaderValue::from_str(&session_id).unwrap_or_else(|_| HeaderValue::from_static("")),
310 );
311 resp
312}
313
314fn stream_response(session_id: String, response: JsonRpcMessage) -> Response {
315 let event = Event::default()
316 .event("message")
317 .data(serde_json::to_string(&response).unwrap_or_default());
318
319 let stream = stream::once(async move { Ok::<_, Infallible>(event) });
320
321 let mut resp = Sse::new(stream)
322 .keep_alive(axum::response::sse::KeepAlive::default())
323 .into_response();
324
325 resp.headers_mut().insert(
326 MCP_SESSION_ID_HEADER,
327 HeaderValue::from_str(&session_id).unwrap_or_else(|_| HeaderValue::from_static("")),
328 );
329
330 resp
331}
332
333async fn handle_get_sse(
336 AxumState(state): AxumState<StreamableState>,
337 headers: HeaderMap,
338 #[cfg(feature = "auth")] _identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
339) -> Response {
340 let Some(session_id) = headers
342 .get(MCP_SESSION_ID_HEADER)
343 .and_then(|v| v.to_str().ok())
344 else {
345 return (
346 StatusCode::BAD_REQUEST,
347 "Mcp-Session-Id header required for SSE",
348 )
349 .into_response();
350 };
351
352 if !state.sessions.contains_key(session_id) {
354 return (StatusCode::NOT_FOUND, "Session not found").into_response();
355 }
356
357 let session_id = session_id.to_string();
358
359 let (tx, rx) = mpsc::channel::<JsonRpcMessage>(64);
361 state.set_notification_channel(&session_id, tx);
362
363 info!(session_id = %session_id, "SSE stream opened");
364
365 let cleanup_state = state.clone();
366 let session_id_for_cleanup = session_id.clone();
367
368 let stream = stream::unfold(rx, move |mut rx| async move {
369 match rx.recv().await {
370 Some(msg) => {
371 let data = serde_json::to_string(&msg).unwrap_or_default();
372 let event = Event::default().event("message").data(data);
373 Some((Ok::<_, Infallible>(event), rx))
374 }
375 None => None,
376 }
377 });
378
379 tokio::spawn(async move {
381 tokio::time::sleep(Duration::from_secs(SESSION_TIMEOUT_SECS)).await;
383 cleanup_state.remove_notification_channel(&session_id_for_cleanup);
384 info!(session_id = %session_id_for_cleanup, "SSE stream timeout cleanup");
385 });
386
387 let mut resp = Sse::new(stream)
388 .keep_alive(axum::response::sse::KeepAlive::default())
389 .into_response();
390
391 resp.headers_mut().insert(
392 MCP_SESSION_ID_HEADER,
393 HeaderValue::from_str(&session_id).unwrap_or_else(|_| HeaderValue::from_static("")),
394 );
395
396 resp
397}
398
399async fn handle_delete(
402 AxumState(state): AxumState<StreamableState>,
403 headers: HeaderMap,
404) -> Response {
405 let Some(session_id) = headers
406 .get(MCP_SESSION_ID_HEADER)
407 .and_then(|v| v.to_str().ok())
408 else {
409 return (StatusCode::BAD_REQUEST, "Mcp-Session-Id header required").into_response();
410 };
411
412 if state.sessions.contains_key(session_id) {
413 state.remove_session(session_id);
414 StatusCode::NO_CONTENT.into_response()
415 } else {
416 (StatusCode::NOT_FOUND, "Session not found").into_response()
417 }
418}
419
420pub trait ServeStreamableExt {
424 fn serve_streamable(
426 self,
427 addr: impl Into<std::net::SocketAddr>,
428 ) -> impl Future<Output = McpResult<()>> + Send;
429}
430
431impl ServeStreamableExt for McpServer {
432 fn serve_streamable(
433 self,
434 addr: impl Into<std::net::SocketAddr>,
435 ) -> impl Future<Output = McpResult<()>> + Send {
436 #[cfg(feature = "auth")]
437 {
438 let transport = StreamableTransport::new(self.clone(), addr);
439 let transport = match (self.auth_provider, self.require_auth) {
440 (Some(provider), true) => transport.with_auth(provider),
441 (Some(provider), false) => transport.with_optional_auth(provider),
442 (None, _) => transport,
443 };
444 transport.serve()
445 }
446 #[cfg(not(feature = "auth"))]
447 StreamableTransport::new(self, addr).serve()
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state around a minimal server with an empty session table.
    fn test_state() -> StreamableState {
        StreamableState {
            server: Arc::new(McpServer::builder().name("test").version("1.0").build()),
            sessions: Arc::new(DashMap::new()),
            #[cfg(feature = "auth")]
            auth: None,
        }
    }

    #[test]
    fn test_session_creation() {
        let state = test_state();

        let (sid1, _) = state.get_or_create_session(None);
        assert!(!sid1.is_empty());
        assert!(state.sessions.contains_key(&sid1));

        let (sid2, _) = state.get_or_create_session(Some(&sid1));
        assert_eq!(sid1, sid2);
    }

    #[test]
    fn test_unknown_session_id_gets_fresh_session() {
        let state = test_state();

        let (sid, _) = state.get_or_create_session(Some("no-such-session"));
        assert_ne!(sid, "no-such-session");
        assert!(state.sessions.contains_key(&sid));
        assert!(!state.sessions.contains_key("no-such-session"));
    }

    #[test]
    fn test_notification_channel_set_and_remove() {
        let state = test_state();
        let (sid, _) = state.get_or_create_session(None);
        let (tx, _rx) = mpsc::channel(1);

        state.set_notification_channel(&sid, tx);
        assert!(state.sessions.get(&sid).unwrap().notification_tx.is_some());

        state.remove_notification_channel(&sid);
        assert!(state.sessions.get(&sid).unwrap().notification_tx.is_none());
    }

    #[test]
    fn test_remove_session() {
        let state = test_state();
        let (sid, _) = state.get_or_create_session(None);

        state.remove_session(&sid);
        assert!(!state.sessions.contains_key(&sid));

        // Removing an unknown session is a harmless no-op.
        state.remove_session(&sid);
        assert!(state.sessions.is_empty());
    }

    #[test]
    fn test_cleanup_expired_sessions() {
        let state = test_state();
        let (stale, _) = state.get_or_create_session(None);
        let (fresh, _) = state.get_or_create_session(None);

        // Some platform clocks cannot represent an instant this far back; skip there.
        let Some(backdated) = std::time::Instant::now()
            .checked_sub(Duration::from_secs(SESSION_TIMEOUT_SECS + 1))
        else {
            return;
        };
        state.sessions.get_mut(&stale).unwrap().last_active = backdated;

        state.cleanup_expired_sessions();

        assert!(!state.sessions.contains_key(&stale));
        assert!(state.sessions.contains_key(&fresh));
    }
}