Skip to main content

mcp_kit/transport/
websocket.rs

//! WebSocket transport for MCP servers.
//!
//! Provides bidirectional communication over WebSocket connections,
//! suitable for browser-based clients and real-time applications.
//!
//! The server accepts WebSocket upgrades on `/ws` and `/mcp` (one MCP session
//! per connection) and answers `GET /health` with `OK`.
//!
//! # Example
//! ```rust,ignore
//! use mcp_kit::prelude::*;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let server = McpServer::builder()
//!         .name("ws-server")
//!         .version("1.0.0")
//!         .build();
//!
//!     // `serve_websocket` takes `impl Into<SocketAddr>`, so a string address must be parsed first.
//!     let addr: std::net::SocketAddr = "0.0.0.0:3000".parse()?;
//!     server.serve_websocket(addr).await?;
//!     Ok(())
//! }
//! ```
21
22use std::net::SocketAddr;
23use std::sync::Arc;
24
25use axum::{
26    extract::{
27        ws::{Message, WebSocket, WebSocketUpgrade},
28        State,
29    },
30    response::IntoResponse,
31    routing::get,
32    Router,
33};
34use futures_util::{SinkExt, StreamExt};
35use tokio::sync::mpsc;
36use tracing::{debug, error, info, warn};
37
38use crate::error::McpError;
39use crate::protocol::JsonRpcMessage;
40use crate::server::{session::Session, McpServer, NotificationSender};
41
42/// Extension trait for serving MCP over WebSocket.
43pub trait ServeWebSocketExt {
44    /// Start serving MCP over WebSocket on the given address.
45    fn serve_websocket(
46        self,
47        addr: impl Into<SocketAddr> + Send,
48    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
49
50    /// Start serving MCP over WebSocket with a custom notification channel buffer size.
51    fn serve_websocket_with_buffer(
52        self,
53        addr: impl Into<SocketAddr> + Send,
54        buffer_size: usize,
55    ) -> impl std::future::Future<Output = Result<(), McpError>> + Send;
56}
57
58impl ServeWebSocketExt for McpServer {
59    async fn serve_websocket(self, addr: impl Into<SocketAddr> + Send) -> Result<(), McpError> {
60        self.serve_websocket_with_buffer(addr, 32).await
61    }
62
63    async fn serve_websocket_with_buffer(
64        self,
65        addr: impl Into<SocketAddr> + Send,
66        buffer_size: usize,
67    ) -> Result<(), McpError> {
68        let addr = addr.into();
69        let state = WebSocketState {
70            server: Arc::new(self),
71            buffer_size,
72        };
73
74        let app = Router::new()
75            .route("/ws", get(ws_handler))
76            .route("/mcp", get(ws_handler))
77            .route("/health", get(|| async { "OK" }))
78            .with_state(state);
79
80        info!("Starting WebSocket MCP server on {}", addr);
81
82        let listener = tokio::net::TcpListener::bind(addr).await?;
83        axum::serve(listener, app).await?;
84
85        Ok(())
86    }
87}
88
89#[derive(Clone)]
90struct WebSocketState {
91    server: Arc<McpServer>,
92    buffer_size: usize,
93}
94
95async fn ws_handler(
96    ws: WebSocketUpgrade,
97    State(state): State<WebSocketState>,
98) -> impl IntoResponse {
99    ws.on_upgrade(move |socket| handle_socket(socket, state))
100}
101
102async fn handle_socket(socket: WebSocket, state: WebSocketState) {
103    let mut session = Session::new();
104    let session_id = session.id.clone();
105    info!(session_id = %session_id, "WebSocket client connected");
106
107    // Create notification channel for this session
108    let (_notifier, mut notification_rx) = NotificationSender::channel(state.buffer_size);
109
110    let (mut ws_tx, mut ws_rx) = socket.split();
111
112    // Channel for sending messages back to the client
113    let (tx, mut rx) = mpsc::channel::<String>(state.buffer_size);
114    let tx_for_notifications = tx.clone();
115
116    // Task to forward notifications to the client
117    let notification_task = tokio::spawn(async move {
118        while let Some(notification) = notification_rx.recv().await {
119            let msg = JsonRpcMessage::Notification(notification);
120            if let Ok(json) = serde_json::to_string(&msg) {
121                if tx_for_notifications.send(json).await.is_err() {
122                    break;
123                }
124            }
125        }
126    });
127
128    // Task to send messages to WebSocket
129    let send_task = tokio::spawn(async move {
130        while let Some(msg) = rx.recv().await {
131            if ws_tx.send(Message::Text(msg)).await.is_err() {
132                break;
133            }
134        }
135    });
136
137    // Main message loop
138    let server = state.server.clone();
139    while let Some(msg) = ws_rx.next().await {
140        match msg {
141            Ok(Message::Text(text)) => {
142                debug!(session_id = %session_id, "Received message");
143                match serde_json::from_str::<JsonRpcMessage>(&text) {
144                    Ok(request) => {
145                        if let Some(response) = server.handle_message(request, &mut session).await {
146                            match serde_json::to_string(&response) {
147                                Ok(json) => {
148                                    if tx.send(json).await.is_err() {
149                                        error!(session_id = %session_id, "Failed to send response");
150                                        break;
151                                    }
152                                }
153                                Err(e) => {
154                                    error!("Failed to serialize response: {}", e);
155                                }
156                            }
157                        }
158                    }
159                    Err(e) => {
160                        warn!(session_id = %session_id, error = %e, "Invalid JSON-RPC message");
161                    }
162                }
163            }
164            Ok(Message::Binary(_)) => {
165                warn!(session_id = %session_id, "Received binary message (not supported)");
166            }
167            Ok(Message::Ping(_)) => {
168                debug!(session_id = %session_id, "Received ping");
169                // Axum handles pong automatically
170            }
171            Ok(Message::Pong(_)) => {
172                debug!(session_id = %session_id, "Received pong");
173            }
174            Ok(Message::Close(_)) => {
175                info!(session_id = %session_id, "Client disconnected");
176                break;
177            }
178            Err(e) => {
179                error!(session_id = %session_id, error = %e, "WebSocket error");
180                break;
181            }
182        }
183    }
184
185    // Clean up
186    notification_task.abort();
187    send_task.abort();
188    info!(session_id = %session_id, "WebSocket session ended");
189}
190
191/// A more complete WebSocket transport with proper bidirectional communication.
192pub struct WebSocketTransport {
193    server: Arc<McpServer>,
194    buffer_size: usize,
195}
196
197impl WebSocketTransport {
198    /// Create a new WebSocket transport.
199    pub fn new(server: McpServer, buffer_size: usize) -> Self {
200        Self {
201            server: Arc::new(server),
202            buffer_size,
203        }
204    }
205
206    /// Handle a WebSocket connection.
207    pub async fn handle_connection(&self, socket: WebSocket) {
208        let mut session = Session::new();
209        let session_id = session.id.clone();
210        info!(session_id = %session_id, "WebSocket client connected");
211
212        let (_notifier, mut notification_rx) = NotificationSender::channel(self.buffer_size);
213        let (mut ws_tx, mut ws_rx) = socket.split();
214
215        // Channel for sending responses back to the client
216        let (response_tx, mut response_rx) = mpsc::channel::<String>(self.buffer_size);
217        let tx_for_notifications = response_tx.clone();
218
219        // Task to forward notifications
220        let notification_task = tokio::spawn(async move {
221            while let Some(notification) = notification_rx.recv().await {
222                let msg = JsonRpcMessage::Notification(notification);
223                if let Ok(json) = serde_json::to_string(&msg) {
224                    if tx_for_notifications.send(json).await.is_err() {
225                        break;
226                    }
227                }
228            }
229        });
230
231        // Task to send messages to the client
232        let send_task = tokio::spawn(async move {
233            while let Some(msg) = response_rx.recv().await {
234                if ws_tx.send(Message::Text(msg)).await.is_err() {
235                    break;
236                }
237            }
238        });
239
240        // Main message loop
241        let server = self.server.clone();
242        while let Some(msg) = ws_rx.next().await {
243            match msg {
244                Ok(Message::Text(text)) => match serde_json::from_str::<JsonRpcMessage>(&text) {
245                    Ok(request) => {
246                        if let Some(response) = server.handle_message(request, &mut session).await {
247                            if let Ok(json) = serde_json::to_string(&response) {
248                                let _ = response_tx.send(json).await;
249                            }
250                        }
251                    }
252                    Err(e) => {
253                        warn!(session_id = %session_id, error = %e, "Invalid message");
254                    }
255                },
256                Ok(Message::Close(_)) | Err(_) => break,
257                _ => {}
258            }
259        }
260
261        notification_task.abort();
262        send_task.abort();
263        info!(session_id = %session_id, "WebSocket session ended");
264    }
265}