1use async_trait::async_trait;
7use futures_util::{
8 sink::SinkExt,
9 stream::{SplitSink, SplitStream, StreamExt},
10};
11use serde_json::Value;
12use std::{collections::HashMap, sync::Arc, time::Duration};
13use tokio::{
14 net::{TcpListener, TcpStream},
15 sync::{Mutex, RwLock, broadcast, mpsc},
16 time::timeout,
17};
18use tokio_tungstenite::{
19 MaybeTlsStream, WebSocketStream, accept_async, connect_async, tungstenite::Message,
20};
21use url::Url;
22
23use crate::core::error::{McpError, McpResult};
24use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
25use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
26
27type RequestHandler = Arc<
29 RwLock<
30 Option<
31 Arc<
32 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
33 + Send
34 + Sync,
35 >,
36 >,
37 >,
38>;
39
40pub struct WebSocketClientTransport {
49 ws_sender: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
50 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
51 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
52 config: TransportConfig,
53 state: Arc<RwLock<ConnectionState>>,
54 url: String,
55 message_handler: Option<tokio::task::JoinHandle<()>>,
56}
57
58impl WebSocketClientTransport {
59 pub async fn new<S: AsRef<str>>(url: S) -> McpResult<Self> {
67 Self::with_config(url, TransportConfig::default()).await
68 }
69
70 pub async fn with_config<S: AsRef<str>>(url: S, config: TransportConfig) -> McpResult<Self> {
79 let url_str = url.as_ref();
80
81 let _url_parsed = Url::parse(url_str)
83 .map_err(|e| McpError::WebSocket(format!("Invalid WebSocket URL: {e}")))?;
84
85 tracing::debug!("Connecting to WebSocket: {}", url_str);
86
87 let connect_timeout = Duration::from_millis(config.connect_timeout_ms.unwrap_or(30_000));
89
90 let (ws_stream, _) = timeout(connect_timeout, connect_async(url_str))
91 .await
92 .map_err(|_| McpError::WebSocket("Connection timeout".to_string()))?
93 .map_err(|e| McpError::WebSocket(format!("Failed to connect: {e}")))?;
94
95 let (ws_sender, ws_receiver) = ws_stream.split();
96
97 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
98 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
99 let state = Arc::new(RwLock::new(ConnectionState::Connected));
100
101 let message_handler = tokio::spawn(Self::handle_messages(
103 ws_receiver,
104 pending_requests.clone(),
105 notification_sender,
106 state.clone(),
107 ));
108
109 Ok(Self {
110 ws_sender: Some(ws_sender),
111 pending_requests,
112 notification_receiver: Some(notification_receiver),
113 config,
114 state,
115 url: url_str.to_string(),
116 message_handler: Some(message_handler),
117 })
118 }
119
120 async fn handle_messages(
121 mut ws_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
122 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
123 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
124 state: Arc<RwLock<ConnectionState>>,
125 ) {
126 while let Some(message) = ws_receiver.next().await {
127 match message {
128 Ok(Message::Text(text)) => {
129 tracing::trace!("Received WebSocket message: {}", text);
130
131 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
133 let mut pending = pending_requests.lock().await;
134 if let Some(sender) = pending.remove(&response.id) {
135 if sender.send(response).is_err() {
136 tracing::warn!("Failed to send response to waiting request");
137 }
138 } else {
139 tracing::warn!(
140 "Received response for unknown request ID: {:?}",
141 response.id
142 );
143 }
144 }
145 else if let Ok(notification) =
147 serde_json::from_str::<JsonRpcNotification>(&text)
148 {
149 if notification_sender.send(notification).is_err() {
150 tracing::debug!("Notification receiver dropped");
151 break;
152 }
153 } else {
154 tracing::warn!("Failed to parse WebSocket message: {}", text);
155 }
156 }
157 Ok(Message::Close(_)) => {
158 tracing::info!("WebSocket connection closed");
159 *state.write().await = ConnectionState::Disconnected;
160 break;
161 }
162 Ok(Message::Ping(_data)) => {
163 tracing::trace!("Received WebSocket ping");
164 }
166 Ok(Message::Pong(_)) => {
167 tracing::trace!("Received WebSocket pong");
168 }
169 Ok(Message::Binary(_)) => {
170 tracing::warn!("Received unexpected binary WebSocket message");
171 }
172 Ok(Message::Frame(_)) => {
173 tracing::trace!("Received WebSocket frame (internal)");
174 }
176 Err(e) => {
177 tracing::error!("WebSocket error: {}", e);
178 *state.write().await = ConnectionState::Error(e.to_string());
179 break;
180 }
181 }
182 }
183
184 tracing::debug!("WebSocket message handler exiting");
185 }
186
187 async fn send_message(&mut self, message: Message) -> McpResult<()> {
188 if let Some(ref mut sender) = self.ws_sender {
189 sender
190 .send(message)
191 .await
192 .map_err(|e| McpError::WebSocket(format!("Failed to send message: {e}")))?;
193 } else {
194 return Err(McpError::WebSocket("WebSocket not connected".to_string()));
195 }
196 Ok(())
197 }
198}
199
200#[async_trait]
201impl Transport for WebSocketClientTransport {
202 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
203 let (sender, receiver) = tokio::sync::oneshot::channel();
204
205 {
207 let mut pending = self.pending_requests.lock().await;
208 pending.insert(request.id.clone(), sender);
209 }
210
211 let request_text =
213 serde_json::to_string(&request).map_err(|e| McpError::Serialization(e.to_string()))?;
214
215 tracing::trace!("Sending WebSocket request: {}", request_text);
216
217 self.send_message(Message::Text(request_text.into()))
218 .await?;
219
220 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
222
223 let response = timeout(timeout_duration, receiver)
224 .await
225 .map_err(|_| McpError::WebSocket("Request timeout".to_string()))?
226 .map_err(|_| McpError::WebSocket("Response channel closed".to_string()))?;
227
228 Ok(response)
229 }
230
231 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
232 let notification_text = serde_json::to_string(¬ification)
233 .map_err(|e| McpError::Serialization(e.to_string()))?;
234
235 tracing::trace!("Sending WebSocket notification: {}", notification_text);
236
237 self.send_message(Message::Text(notification_text.into()))
238 .await
239 }
240
241 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
242 if let Some(ref mut receiver) = self.notification_receiver {
243 match receiver.try_recv() {
244 Ok(notification) => Ok(Some(notification)),
245 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
246 Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::WebSocket(
247 "Notification channel disconnected".to_string(),
248 )),
249 }
250 } else {
251 Ok(None)
252 }
253 }
254
255 async fn close(&mut self) -> McpResult<()> {
256 tracing::debug!("Closing WebSocket connection");
257
258 *self.state.write().await = ConnectionState::Closing;
259
260 if let Some(ref mut sender) = self.ws_sender {
262 let _ = sender.send(Message::Close(None)).await;
263 }
264
265 if let Some(handle) = self.message_handler.take() {
267 handle.abort();
268 }
269
270 self.ws_sender = None;
271 self.notification_receiver = None;
272
273 *self.state.write().await = ConnectionState::Disconnected;
274
275 Ok(())
276 }
277
278 fn is_connected(&self) -> bool {
279 self.ws_sender.is_some()
281 }
282
283 fn connection_info(&self) -> String {
284 format!("WebSocket transport (url: {})", self.url)
285 }
286}
287
288struct WebSocketConnection {
294 sender: SplitSink<WebSocketStream<TcpStream>, Message>,
295 _id: String, }
297
298pub struct WebSocketServerTransport {
303 bind_addr: String,
304 config: TransportConfig, clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
306 request_handler: RequestHandler,
307 server_handle: Option<tokio::task::JoinHandle<()>>,
308 running: Arc<RwLock<bool>>,
309 shutdown_sender: Option<broadcast::Sender<()>>,
310}
311
312impl WebSocketServerTransport {
313 pub fn new<S: Into<String>>(bind_addr: S) -> Self {
321 Self::with_config(bind_addr, TransportConfig::default())
322 }
323
324 pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
333 let (shutdown_sender, _) = broadcast::channel(1);
334
335 Self {
336 bind_addr: bind_addr.into(),
337 config,
338 clients: Arc::new(RwLock::new(HashMap::new())),
339 request_handler: Arc::new(RwLock::new(None)),
340 server_handle: None,
341 running: Arc::new(RwLock::new(false)),
342 shutdown_sender: Some(shutdown_sender),
343 }
344 }
345
346 pub async fn set_request_handler<F>(&mut self, handler: F)
351 where
352 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
353 + Send
354 + Sync
355 + 'static,
356 {
357 let mut request_handler = self.request_handler.write().await;
358 *request_handler = Some(Arc::new(handler));
359 }
360
361 pub fn config(&self) -> &TransportConfig {
363 &self.config
364 }
365
366 pub fn max_message_size(&self) -> Option<usize> {
368 self.config.max_message_size
369 }
370
371 async fn handle_client_connection(
372 stream: TcpStream,
373 clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
374 request_handler: RequestHandler,
375 mut shutdown_receiver: broadcast::Receiver<()>,
376 ) {
377 let client_id = uuid::Uuid::new_v4().to_string();
378
379 let ws_stream = match accept_async(stream).await {
380 Ok(ws) => ws,
381 Err(e) => {
382 tracing::error!("Failed to accept WebSocket connection: {}", e);
383 return;
384 }
385 };
386
387 tracing::info!("New WebSocket client connected: {}", client_id);
388
389 let (ws_sender, mut ws_receiver) = ws_stream.split();
390
391 {
393 let mut clients_guard = clients.write().await;
394 clients_guard.insert(
395 client_id.clone(),
396 WebSocketConnection {
397 sender: ws_sender,
398 _id: client_id.clone(),
399 },
400 );
401 }
402
403 loop {
405 tokio::select! {
406 message = ws_receiver.next() => {
407 match message {
408 Some(Ok(Message::Text(text))) => {
409 tracing::trace!("Received message from {}: {}", client_id, text);
410
411 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&text) {
413 let handler_guard = request_handler.read().await;
414 if let Some(ref handler) = *handler_guard {
415 let response_rx = handler(request.clone());
416 drop(handler_guard);
417
418 match response_rx.await {
419 Ok(response) => {
420 let response_text = match serde_json::to_string(&response) {
421 Ok(text) => text,
422 Err(e) => {
423 tracing::error!("Failed to serialize response: {}", e);
424 continue;
425 }
426 };
427
428 let mut clients_guard = clients.write().await;
430 if let Some(client) = clients_guard.get_mut(&client_id) {
431 if let Err(e) = client.sender.send(Message::Text(response_text.into())).await {
432 tracing::error!("Failed to send response to client {}: {}", client_id, e);
433 break;
434 }
435 }
436 }
437 Err(_) => {
438 tracing::error!("Request handler channel closed for client {}", client_id);
439 }
440 }
441 } else {
442 tracing::warn!("No request handler configured for client {}", client_id);
443 }
444 }
445 else if let Ok(_notification) = serde_json::from_str::<JsonRpcNotification>(&text) {
447 tracing::trace!("Received notification from client {}", client_id);
448 } else {
450 tracing::warn!("Failed to parse message from client {}: {}", client_id, text);
451 }
452 }
453 Some(Ok(Message::Close(_))) => {
454 tracing::info!("Client {} disconnected", client_id);
455 break;
456 }
457 Some(Ok(Message::Ping(data))) => {
458 tracing::trace!("Received ping from client {}", client_id);
459 let mut clients_guard = clients.write().await;
460 if let Some(client) = clients_guard.get_mut(&client_id) {
461 if let Err(e) = client.sender.send(Message::Pong(data)).await {
462 tracing::error!("Failed to send pong to client {}: {}", client_id, e);
463 break;
464 }
465 }
466 }
467 Some(Ok(Message::Pong(_))) => {
468 tracing::trace!("Received pong from client {}", client_id);
469 }
470 Some(Ok(Message::Binary(_))) => {
471 tracing::warn!("Received unexpected binary message from client {}", client_id);
472 }
473 Some(Ok(Message::Frame(_))) => {
474 tracing::trace!("Received WebSocket frame from client {} (internal)", client_id);
475 }
477 Some(Err(e)) => {
478 tracing::error!("WebSocket error for client {}: {}", client_id, e);
479 break;
480 }
481 None => {
482 tracing::info!("WebSocket stream ended for client {}", client_id);
483 break;
484 }
485 }
486 }
487 _ = shutdown_receiver.recv() => {
488 tracing::info!("Shutting down connection for client {}", client_id);
489 break;
490 }
491 }
492 }
493
494 {
496 let mut clients_guard = clients.write().await;
497 clients_guard.remove(&client_id);
498 }
499
500 tracing::info!("Client {} connection handler exiting", client_id);
501 }
502}
503
504#[async_trait]
505impl ServerTransport for WebSocketServerTransport {
506 async fn start(&mut self) -> McpResult<()> {
507 tracing::info!("Starting WebSocket server on {}", self.bind_addr);
508
509 let listener = TcpListener::bind(&self.bind_addr).await.map_err(|e| {
510 McpError::WebSocket(format!("Failed to bind to {}: {}", self.bind_addr, e))
511 })?;
512
513 let clients = self.clients.clone();
514 let request_handler = self.request_handler.clone();
515 let running = self.running.clone();
516 let shutdown_sender = self.shutdown_sender.as_ref().unwrap().clone();
517
518 *running.write().await = true;
519
520 let server_handle = tokio::spawn(async move {
521 let mut shutdown_receiver = shutdown_sender.subscribe();
522
523 loop {
524 tokio::select! {
525 result = listener.accept() => {
526 match result {
527 Ok((stream, addr)) => {
528 tracing::debug!("New connection from: {}", addr);
529
530 tokio::spawn(Self::handle_client_connection(
531 stream,
532 clients.clone(),
533 request_handler.clone(),
534 shutdown_sender.subscribe(),
535 ));
536 }
537 Err(e) => {
538 tracing::error!("Failed to accept connection: {}", e);
539 }
540 }
541 }
542 _ = shutdown_receiver.recv() => {
543 tracing::info!("WebSocket server shutting down");
544 break;
545 }
546 }
547 }
548 });
549
550 self.server_handle = Some(server_handle);
551
552 tracing::info!(
553 "WebSocket server started successfully on {}",
554 self.bind_addr
555 );
556 Ok(())
557 }
558
559 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
560 let notification_text = serde_json::to_string(¬ification)
561 .map_err(|e| McpError::Serialization(e.to_string()))?;
562
563 let mut clients_guard = self.clients.write().await;
564 let mut disconnected_clients = Vec::new();
565
566 for (client_id, client) in clients_guard.iter_mut() {
567 if let Err(e) = client
568 .sender
569 .send(Message::Text(notification_text.clone().into()))
570 .await
571 {
572 tracing::error!("Failed to send notification to client {}: {}", client_id, e);
573 disconnected_clients.push(client_id.clone());
574 }
575 }
576
577 for client_id in disconnected_clients {
579 clients_guard.remove(&client_id);
580 }
581
582 Ok(())
583 }
584
585 async fn stop(&mut self) -> McpResult<()> {
586 tracing::info!("Stopping WebSocket server");
587
588 *self.running.write().await = false;
589
590 if let Some(ref sender) = self.shutdown_sender {
592 let _ = sender.send(());
593 }
594
595 if let Some(handle) = self.server_handle.take() {
597 handle.abort();
598 }
599
600 let mut clients_guard = self.clients.write().await;
602 for (client_id, client) in clients_guard.iter_mut() {
603 tracing::debug!("Closing connection for client {}", client_id);
604 let _ = client.sender.send(Message::Close(None)).await;
605 }
606 clients_guard.clear();
607
608 Ok(())
609 }
610
611 fn is_running(&self) -> bool {
612 self.server_handle.is_some()
614 }
615
616 fn server_info(&self) -> String {
617 format!("WebSocket server transport (bind: {})", self.bind_addr)
618 }
619
620 fn set_request_handler(&mut self, handler: crate::transport::traits::ServerRequestHandler) {
621 let _ws_handler = Arc::new(move |request: JsonRpcRequest| {
623 let (tx, rx) = tokio::sync::oneshot::channel();
624 let handler_future = handler(request);
625 tokio::spawn(async move {
626 let result = handler_future.await;
627 let _ = tx.send(result.unwrap_or_else(|e| JsonRpcResponse {
628 jsonrpc: "2.0".to_string(),
629 id: serde_json::Value::Null,
630 result: Some(serde_json::json!({
631 "error": {
632 "code": -32603,
633 "message": e.to_string()
634 }
635 })),
636 }));
637 });
638 rx
639 });
640
641 tokio::spawn(async move {
643 });
646 }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built server remembers its address and has not started yet.
    #[test]
    fn test_websocket_server_creation() {
        let server = WebSocketServerTransport::new("127.0.0.1:0");
        assert_eq!(server.bind_addr, "127.0.0.1:0");
        assert!(!server.is_running());
    }

    /// `with_config` keeps both the address and the supplied configuration.
    #[test]
    fn test_websocket_server_with_config() {
        let limit = 64 * 1024;
        let mut config = TransportConfig::default();
        config.max_message_size = Some(limit);

        let server = WebSocketServerTransport::with_config("0.0.0.0:9090", config);
        assert_eq!(server.bind_addr, "0.0.0.0:9090");
        assert_eq!(server.config.max_message_size, Some(limit));
    }

    /// A string that is not a URL is rejected with a WebSocket error.
    #[tokio::test]
    async fn test_websocket_client_invalid_url() {
        let result = WebSocketClientTransport::new("invalid-url").await;
        assert!(result.is_err());

        let Err(McpError::WebSocket(msg)) = result else {
            panic!("Expected WebSocket error");
        };
        assert!(msg.contains("Invalid WebSocket URL"));
    }

    /// The info string mentions the target host. The check only runs if a
    /// connection happens to succeed; the port is presumably unused.
    #[tokio::test]
    async fn test_websocket_client_connection_info() {
        let url = "ws://localhost:9999/test";
        match WebSocketClientTransport::new(url).await {
            Ok(transport) => assert!(transport.connection_info().contains("localhost:9999")),
            Err(_) => {}
        }
    }
}