mcp_kit/transport/
websocket.rs1use std::net::SocketAddr;
23use std::sync::Arc;
24
25use axum::{
26 extract::{
27 ws::{Message, WebSocket, WebSocketUpgrade},
28 State,
29 },
30 response::IntoResponse,
31 routing::get,
32 Router,
33};
34use futures_util::{SinkExt, StreamExt};
35use tokio::sync::mpsc;
36use tracing::{debug, error, info, warn};
37
38use crate::error::McpError;
39use crate::protocol::JsonRpcMessage;
40use crate::server::{session::Session, McpServer, NotificationSender};
41
42pub trait ServeWebSocketExt {
44 fn serve_websocket(
46 self,
47 addr: impl Into<SocketAddr> + Send,
48 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
49
50 fn serve_websocket_with_buffer(
52 self,
53 addr: impl Into<SocketAddr> + Send,
54 buffer_size: usize,
55 ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
56}
57
58impl ServeWebSocketExt for McpServer {
59 async fn serve_websocket(self, addr: impl Into<SocketAddr> + Send) -> Result<(), McpError> {
60 self.serve_websocket_with_buffer(addr, 32).await
61 }
62
63 async fn serve_websocket_with_buffer(
64 self,
65 addr: impl Into<SocketAddr> + Send,
66 buffer_size: usize,
67 ) -> Result<(), McpError> {
68 let addr = addr.into();
69 let state = WebSocketState {
70 server: Arc::new(self),
71 buffer_size,
72 };
73
74 let app = Router::new()
75 .route("/ws", get(ws_handler))
76 .route("/mcp", get(ws_handler))
77 .route("/health", get(|| async { "OK" }))
78 .with_state(state);
79
80 info!("Starting WebSocket MCP server on {}", addr);
81
82 let listener = tokio::net::TcpListener::bind(addr).await?;
83 axum::serve(listener, app).await?;
84
85 Ok(())
86 }
87}
88
89#[derive(Clone)]
90struct WebSocketState {
91 server: Arc<McpServer>,
92 buffer_size: usize,
93}
94
95async fn ws_handler(
96 ws: WebSocketUpgrade,
97 State(state): State<WebSocketState>,
98) -> impl IntoResponse {
99 ws.on_upgrade(move |socket| handle_socket(socket, state))
100}
101
102async fn handle_socket(socket: WebSocket, state: WebSocketState) {
103 let mut session = Session::new();
104 let session_id = session.id.clone();
105 info!(session_id = %session_id, "WebSocket client connected");
106
107 let (_notifier, mut notification_rx) = NotificationSender::channel(state.buffer_size);
109
110 let (mut ws_tx, mut ws_rx) = socket.split();
111
112 let (tx, mut rx) = mpsc::channel::<String>(state.buffer_size);
114 let tx_for_notifications = tx.clone();
115
116 let notification_task = tokio::spawn(async move {
118 while let Some(notification) = notification_rx.recv().await {
119 let msg = JsonRpcMessage::Notification(notification);
120 if let Ok(json) = serde_json::to_string(&msg) {
121 if tx_for_notifications.send(json).await.is_err() {
122 break;
123 }
124 }
125 }
126 });
127
128 let send_task = tokio::spawn(async move {
130 while let Some(msg) = rx.recv().await {
131 if ws_tx.send(Message::Text(msg)).await.is_err() {
132 break;
133 }
134 }
135 });
136
137 let server = state.server.clone();
139 while let Some(msg) = ws_rx.next().await {
140 match msg {
141 Ok(Message::Text(text)) => {
142 debug!(session_id = %session_id, "Received message");
143 match serde_json::from_str::<JsonRpcMessage>(&text) {
144 Ok(request) => {
145 if let Some(response) = server.handle_message(request, &mut session).await {
146 match serde_json::to_string(&response) {
147 Ok(json) => {
148 if tx.send(json).await.is_err() {
149 error!(session_id = %session_id, "Failed to send response");
150 break;
151 }
152 }
153 Err(e) => {
154 error!("Failed to serialize response: {}", e);
155 }
156 }
157 }
158 }
159 Err(e) => {
160 warn!(session_id = %session_id, error = %e, "Invalid JSON-RPC message");
161 }
162 }
163 }
164 Ok(Message::Binary(_)) => {
165 warn!(session_id = %session_id, "Received binary message (not supported)");
166 }
167 Ok(Message::Ping(_)) => {
168 debug!(session_id = %session_id, "Received ping");
169 }
171 Ok(Message::Pong(_)) => {
172 debug!(session_id = %session_id, "Received pong");
173 }
174 Ok(Message::Close(_)) => {
175 info!(session_id = %session_id, "Client disconnected");
176 break;
177 }
178 Err(e) => {
179 error!(session_id = %session_id, error = %e, "WebSocket error");
180 break;
181 }
182 }
183 }
184
185 notification_task.abort();
187 send_task.abort();
188 info!(session_id = %session_id, "WebSocket session ended");
189}
190
191pub struct WebSocketTransport {
193 server: Arc<McpServer>,
194 buffer_size: usize,
195}
196
197impl WebSocketTransport {
198 pub fn new(server: McpServer, buffer_size: usize) -> Self {
200 Self {
201 server: Arc::new(server),
202 buffer_size,
203 }
204 }
205
206 pub async fn handle_connection(&self, socket: WebSocket) {
208 let mut session = Session::new();
209 let session_id = session.id.clone();
210 info!(session_id = %session_id, "WebSocket client connected");
211
212 let (_notifier, mut notification_rx) = NotificationSender::channel(self.buffer_size);
213 let (mut ws_tx, mut ws_rx) = socket.split();
214
215 let (response_tx, mut response_rx) = mpsc::channel::<String>(self.buffer_size);
217 let tx_for_notifications = response_tx.clone();
218
219 let notification_task = tokio::spawn(async move {
221 while let Some(notification) = notification_rx.recv().await {
222 let msg = JsonRpcMessage::Notification(notification);
223 if let Ok(json) = serde_json::to_string(&msg) {
224 if tx_for_notifications.send(json).await.is_err() {
225 break;
226 }
227 }
228 }
229 });
230
231 let send_task = tokio::spawn(async move {
233 while let Some(msg) = response_rx.recv().await {
234 if ws_tx.send(Message::Text(msg)).await.is_err() {
235 break;
236 }
237 }
238 });
239
240 let server = self.server.clone();
242 while let Some(msg) = ws_rx.next().await {
243 match msg {
244 Ok(Message::Text(text)) => match serde_json::from_str::<JsonRpcMessage>(&text) {
245 Ok(request) => {
246 if let Some(response) = server.handle_message(request, &mut session).await {
247 if let Ok(json) = serde_json::to_string(&response) {
248 let _ = response_tx.send(json).await;
249 }
250 }
251 }
252 Err(e) => {
253 warn!(session_id = %session_id, error = %e, "Invalid message");
254 }
255 },
256 Ok(Message::Close(_)) | Err(_) => break,
257 _ => {}
258 }
259 }
260
261 notification_task.abort();
262 send_task.abort();
263 info!(session_id = %session_id, "WebSocket session ended");
264 }
265}