1use async_trait::async_trait;
7use futures_util::{
8 sink::SinkExt,
9 stream::{SplitSink, SplitStream, StreamExt},
10};
11use serde_json::Value;
12use std::{collections::HashMap, sync::Arc, time::Duration};
13use tokio::{
14 net::{TcpListener, TcpStream},
15 sync::{broadcast, mpsc, Mutex, RwLock},
16 time::timeout,
17};
18use tokio_tungstenite::{
19 accept_async, connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
20};
21use url::Url;
22
23use crate::core::error::{McpError, McpResult};
24use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
25use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
26
27pub struct WebSocketClientTransport {
36 ws_sender: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
37 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
38 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
39 config: TransportConfig,
40 state: Arc<RwLock<ConnectionState>>,
41 url: String,
42 message_handler: Option<tokio::task::JoinHandle<()>>,
43}
44
45impl WebSocketClientTransport {
46 pub async fn new<S: AsRef<str>>(url: S) -> McpResult<Self> {
54 Self::with_config(url, TransportConfig::default()).await
55 }
56
57 pub async fn with_config<S: AsRef<str>>(url: S, config: TransportConfig) -> McpResult<Self> {
66 let url_str = url.as_ref();
67
68 let _url_parsed = Url::parse(url_str)
70 .map_err(|e| McpError::WebSocket(format!("Invalid WebSocket URL: {}", e)))?;
71
72 tracing::debug!("Connecting to WebSocket: {}", url_str);
73
74 let connect_timeout = Duration::from_millis(config.connect_timeout_ms.unwrap_or(30_000));
76
77 let (ws_stream, _) = timeout(connect_timeout, connect_async(url_str))
78 .await
79 .map_err(|_| McpError::WebSocket("Connection timeout".to_string()))?
80 .map_err(|e| McpError::WebSocket(format!("Failed to connect: {}", e)))?;
81
82 let (ws_sender, ws_receiver) = ws_stream.split();
83
84 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
85 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
86 let state = Arc::new(RwLock::new(ConnectionState::Connected));
87
88 let message_handler = tokio::spawn(Self::handle_messages(
90 ws_receiver,
91 pending_requests.clone(),
92 notification_sender,
93 state.clone(),
94 ));
95
96 Ok(Self {
97 ws_sender: Some(ws_sender),
98 pending_requests,
99 notification_receiver: Some(notification_receiver),
100 config,
101 state,
102 url: url_str.to_string(),
103 message_handler: Some(message_handler),
104 })
105 }
106
107 async fn handle_messages(
108 mut ws_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
109 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
110 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111 state: Arc<RwLock<ConnectionState>>,
112 ) {
113 while let Some(message) = ws_receiver.next().await {
114 match message {
115 Ok(Message::Text(text)) => {
116 tracing::trace!("Received WebSocket message: {}", text);
117
118 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
120 let mut pending = pending_requests.lock().await;
121 if let Some(sender) = pending.remove(&response.id) {
122 if let Err(_) = sender.send(response) {
123 tracing::warn!("Failed to send response to waiting request");
124 }
125 } else {
126 tracing::warn!(
127 "Received response for unknown request ID: {:?}",
128 response.id
129 );
130 }
131 }
132 else if let Ok(notification) =
134 serde_json::from_str::<JsonRpcNotification>(&text)
135 {
136 if let Err(_) = notification_sender.send(notification) {
137 tracing::debug!("Notification receiver dropped");
138 break;
139 }
140 } else {
141 tracing::warn!("Failed to parse WebSocket message: {}", text);
142 }
143 }
144 Ok(Message::Close(_)) => {
145 tracing::info!("WebSocket connection closed");
146 *state.write().await = ConnectionState::Disconnected;
147 break;
148 }
149 Ok(Message::Ping(_data)) => {
150 tracing::trace!("Received WebSocket ping");
151 }
153 Ok(Message::Pong(_)) => {
154 tracing::trace!("Received WebSocket pong");
155 }
156 Ok(Message::Binary(_)) => {
157 tracing::warn!("Received unexpected binary WebSocket message");
158 }
159 Ok(Message::Frame(_)) => {
160 tracing::trace!("Received WebSocket frame (internal)");
161 }
163 Err(e) => {
164 tracing::error!("WebSocket error: {}", e);
165 *state.write().await = ConnectionState::Error(e.to_string());
166 break;
167 }
168 }
169 }
170
171 tracing::debug!("WebSocket message handler exiting");
172 }
173
174 async fn send_message(&mut self, message: Message) -> McpResult<()> {
175 if let Some(ref mut sender) = self.ws_sender {
176 sender
177 .send(message)
178 .await
179 .map_err(|e| McpError::WebSocket(format!("Failed to send message: {}", e)))?;
180 } else {
181 return Err(McpError::WebSocket("WebSocket not connected".to_string()));
182 }
183 Ok(())
184 }
185}
186
187#[async_trait]
188impl Transport for WebSocketClientTransport {
189 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
190 let (sender, receiver) = tokio::sync::oneshot::channel();
191
192 {
194 let mut pending = self.pending_requests.lock().await;
195 pending.insert(request.id.clone(), sender);
196 }
197
198 let request_text =
200 serde_json::to_string(&request).map_err(|e| McpError::Serialization(e.to_string()))?;
201
202 tracing::trace!("Sending WebSocket request: {}", request_text);
203
204 self.send_message(Message::Text(request_text.into()))
205 .await?;
206
207 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
209
210 let response = timeout(timeout_duration, receiver)
211 .await
212 .map_err(|_| McpError::WebSocket("Request timeout".to_string()))?
213 .map_err(|_| McpError::WebSocket("Response channel closed".to_string()))?;
214
215 Ok(response)
216 }
217
218 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
219 let notification_text = serde_json::to_string(¬ification)
220 .map_err(|e| McpError::Serialization(e.to_string()))?;
221
222 tracing::trace!("Sending WebSocket notification: {}", notification_text);
223
224 self.send_message(Message::Text(notification_text.into()))
225 .await
226 }
227
228 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
229 if let Some(ref mut receiver) = self.notification_receiver {
230 match receiver.try_recv() {
231 Ok(notification) => Ok(Some(notification)),
232 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
233 Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::WebSocket(
234 "Notification channel disconnected".to_string(),
235 )),
236 }
237 } else {
238 Ok(None)
239 }
240 }
241
242 async fn close(&mut self) -> McpResult<()> {
243 tracing::debug!("Closing WebSocket connection");
244
245 *self.state.write().await = ConnectionState::Closing;
246
247 if let Some(ref mut sender) = self.ws_sender {
249 let _ = sender.send(Message::Close(None)).await;
250 }
251
252 if let Some(handle) = self.message_handler.take() {
254 handle.abort();
255 }
256
257 self.ws_sender = None;
258 self.notification_receiver = None;
259
260 *self.state.write().await = ConnectionState::Disconnected;
261
262 Ok(())
263 }
264
265 fn is_connected(&self) -> bool {
266 self.ws_sender.is_some()
268 }
269
270 fn connection_info(&self) -> String {
271 format!("WebSocket transport (url: {})", self.url)
272 }
273}
274
275struct WebSocketConnection {
281 sender: SplitSink<WebSocketStream<TcpStream>, Message>,
282 _id: String, }
284
285pub struct WebSocketServerTransport {
290 bind_addr: String,
291 config: TransportConfig, clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
293 request_handler: Arc<
294 RwLock<
295 Option<
296 Arc<
297 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
298 + Send
299 + Sync,
300 >,
301 >,
302 >,
303 >,
304 server_handle: Option<tokio::task::JoinHandle<()>>,
305 running: Arc<RwLock<bool>>,
306 shutdown_sender: Option<broadcast::Sender<()>>,
307}
308
309impl WebSocketServerTransport {
310 pub fn new<S: Into<String>>(bind_addr: S) -> Self {
318 Self::with_config(bind_addr, TransportConfig::default())
319 }
320
321 pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
330 let (shutdown_sender, _) = broadcast::channel(1);
331
332 Self {
333 bind_addr: bind_addr.into(),
334 config,
335 clients: Arc::new(RwLock::new(HashMap::new())),
336 request_handler: Arc::new(RwLock::new(None)),
337 server_handle: None,
338 running: Arc::new(RwLock::new(false)),
339 shutdown_sender: Some(shutdown_sender),
340 }
341 }
342
343 pub async fn set_request_handler<F>(&mut self, handler: F)
348 where
349 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
350 + Send
351 + Sync
352 + 'static,
353 {
354 let mut request_handler = self.request_handler.write().await;
355 *request_handler = Some(Arc::new(handler));
356 }
357
358 pub fn config(&self) -> &TransportConfig {
360 &self.config
361 }
362
363 pub fn max_message_size(&self) -> Option<usize> {
365 self.config.max_message_size
366 }
367
368 async fn handle_client_connection(
369 stream: TcpStream,
370 clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
371 request_handler: Arc<
372 RwLock<
373 Option<
374 Arc<
375 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
376 + Send
377 + Sync,
378 >,
379 >,
380 >,
381 >,
382 mut shutdown_receiver: broadcast::Receiver<()>,
383 ) {
384 let client_id = uuid::Uuid::new_v4().to_string();
385
386 let ws_stream = match accept_async(stream).await {
387 Ok(ws) => ws,
388 Err(e) => {
389 tracing::error!("Failed to accept WebSocket connection: {}", e);
390 return;
391 }
392 };
393
394 tracing::info!("New WebSocket client connected: {}", client_id);
395
396 let (ws_sender, mut ws_receiver) = ws_stream.split();
397
398 {
400 let mut clients_guard = clients.write().await;
401 clients_guard.insert(
402 client_id.clone(),
403 WebSocketConnection {
404 sender: ws_sender,
405 _id: client_id.clone(),
406 },
407 );
408 }
409
410 loop {
412 tokio::select! {
413 message = ws_receiver.next() => {
414 match message {
415 Some(Ok(Message::Text(text))) => {
416 tracing::trace!("Received message from {}: {}", client_id, text);
417
418 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&text) {
420 let handler_guard = request_handler.read().await;
421 if let Some(ref handler) = *handler_guard {
422 let response_rx = handler(request.clone());
423 drop(handler_guard);
424
425 match response_rx.await {
426 Ok(response) => {
427 let response_text = match serde_json::to_string(&response) {
428 Ok(text) => text,
429 Err(e) => {
430 tracing::error!("Failed to serialize response: {}", e);
431 continue;
432 }
433 };
434
435 let mut clients_guard = clients.write().await;
437 if let Some(client) = clients_guard.get_mut(&client_id) {
438 if let Err(e) = client.sender.send(Message::Text(response_text.into())).await {
439 tracing::error!("Failed to send response to client {}: {}", client_id, e);
440 break;
441 }
442 }
443 }
444 Err(_) => {
445 tracing::error!("Request handler channel closed for client {}", client_id);
446 }
447 }
448 } else {
449 tracing::warn!("No request handler configured for client {}", client_id);
450 }
451 }
452 else if let Ok(_notification) = serde_json::from_str::<JsonRpcNotification>(&text) {
454 tracing::trace!("Received notification from client {}", client_id);
455 } else {
457 tracing::warn!("Failed to parse message from client {}: {}", client_id, text);
458 }
459 }
460 Some(Ok(Message::Close(_))) => {
461 tracing::info!("Client {} disconnected", client_id);
462 break;
463 }
464 Some(Ok(Message::Ping(data))) => {
465 tracing::trace!("Received ping from client {}", client_id);
466 let mut clients_guard = clients.write().await;
467 if let Some(client) = clients_guard.get_mut(&client_id) {
468 if let Err(e) = client.sender.send(Message::Pong(data)).await {
469 tracing::error!("Failed to send pong to client {}: {}", client_id, e);
470 break;
471 }
472 }
473 }
474 Some(Ok(Message::Pong(_))) => {
475 tracing::trace!("Received pong from client {}", client_id);
476 }
477 Some(Ok(Message::Binary(_))) => {
478 tracing::warn!("Received unexpected binary message from client {}", client_id);
479 }
480 Some(Ok(Message::Frame(_))) => {
481 tracing::trace!("Received WebSocket frame from client {} (internal)", client_id);
482 }
484 Some(Err(e)) => {
485 tracing::error!("WebSocket error for client {}: {}", client_id, e);
486 break;
487 }
488 None => {
489 tracing::info!("WebSocket stream ended for client {}", client_id);
490 break;
491 }
492 }
493 }
494 _ = shutdown_receiver.recv() => {
495 tracing::info!("Shutting down connection for client {}", client_id);
496 break;
497 }
498 }
499 }
500
501 {
503 let mut clients_guard = clients.write().await;
504 clients_guard.remove(&client_id);
505 }
506
507 tracing::info!("Client {} connection handler exiting", client_id);
508 }
509}
510
511#[async_trait]
512impl ServerTransport for WebSocketServerTransport {
513 async fn start(&mut self) -> McpResult<()> {
514 tracing::info!("Starting WebSocket server on {}", self.bind_addr);
515
516 let listener = TcpListener::bind(&self.bind_addr).await.map_err(|e| {
517 McpError::WebSocket(format!("Failed to bind to {}: {}", self.bind_addr, e))
518 })?;
519
520 let clients = self.clients.clone();
521 let request_handler = self.request_handler.clone();
522 let running = self.running.clone();
523 let shutdown_sender = self.shutdown_sender.as_ref().unwrap().clone();
524
525 *running.write().await = true;
526
527 let server_handle = tokio::spawn(async move {
528 let mut shutdown_receiver = shutdown_sender.subscribe();
529
530 loop {
531 tokio::select! {
532 result = listener.accept() => {
533 match result {
534 Ok((stream, addr)) => {
535 tracing::debug!("New connection from: {}", addr);
536
537 tokio::spawn(Self::handle_client_connection(
538 stream,
539 clients.clone(),
540 request_handler.clone(),
541 shutdown_sender.subscribe(),
542 ));
543 }
544 Err(e) => {
545 tracing::error!("Failed to accept connection: {}", e);
546 }
547 }
548 }
549 _ = shutdown_receiver.recv() => {
550 tracing::info!("WebSocket server shutting down");
551 break;
552 }
553 }
554 }
555 });
556
557 self.server_handle = Some(server_handle);
558
559 tracing::info!(
560 "WebSocket server started successfully on {}",
561 self.bind_addr
562 );
563 Ok(())
564 }
565
566 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
567 let handler_guard = self.request_handler.read().await;
568
569 if let Some(ref handler) = *handler_guard {
570 let response_rx = handler(request);
571 drop(handler_guard);
572
573 match response_rx.await {
574 Ok(response) => Ok(response),
575 Err(_) => Err(McpError::WebSocket(
576 "Request handler channel closed".to_string(),
577 )),
578 }
579 } else {
580 Ok(JsonRpcResponse {
581 jsonrpc: "2.0".to_string(),
582 id: request.id,
583 result: None,
584 })
585 }
586 }
587
588 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
589 let notification_text = serde_json::to_string(¬ification)
590 .map_err(|e| McpError::Serialization(e.to_string()))?;
591
592 let mut clients_guard = self.clients.write().await;
593 let mut disconnected_clients = Vec::new();
594
595 for (client_id, client) in clients_guard.iter_mut() {
596 if let Err(e) = client
597 .sender
598 .send(Message::Text(notification_text.clone().into()))
599 .await
600 {
601 tracing::error!("Failed to send notification to client {}: {}", client_id, e);
602 disconnected_clients.push(client_id.clone());
603 }
604 }
605
606 for client_id in disconnected_clients {
608 clients_guard.remove(&client_id);
609 }
610
611 Ok(())
612 }
613
614 async fn stop(&mut self) -> McpResult<()> {
615 tracing::info!("Stopping WebSocket server");
616
617 *self.running.write().await = false;
618
619 if let Some(ref sender) = self.shutdown_sender {
621 let _ = sender.send(());
622 }
623
624 if let Some(handle) = self.server_handle.take() {
626 handle.abort();
627 }
628
629 let mut clients_guard = self.clients.write().await;
631 for (client_id, client) in clients_guard.iter_mut() {
632 tracing::debug!("Closing connection for client {}", client_id);
633 let _ = client.sender.send(Message::Close(None)).await;
634 }
635 clients_guard.clear();
636
637 Ok(())
638 }
639
640 fn is_running(&self) -> bool {
641 self.server_handle.is_some()
643 }
644
645 fn server_info(&self) -> String {
646 format!("WebSocket server transport (bind: {})", self.bind_addr)
647 }
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed server keeps its bind address and is not running.
    #[test]
    fn test_websocket_server_creation() {
        let server = WebSocketServerTransport::new("127.0.0.1:0");

        assert_eq!(server.bind_addr, "127.0.0.1:0");
        assert!(!server.is_running());
    }

    /// A custom configuration is stored unchanged by `with_config`.
    #[test]
    fn test_websocket_server_with_config() {
        let limit = Some(64 * 1024);
        let mut custom = TransportConfig::default();
        custom.max_message_size = limit;

        let server = WebSocketServerTransport::with_config("0.0.0.0:9090", custom);

        assert_eq!(server.bind_addr, "0.0.0.0:9090");
        assert_eq!(server.config.max_message_size, limit);
    }

    /// A string that is not a URL is rejected with a WebSocket error before
    /// any connection attempt.
    #[tokio::test]
    async fn test_websocket_client_invalid_url() {
        let outcome = WebSocketClientTransport::new("invalid-url").await;
        assert!(outcome.is_err());

        match outcome {
            Err(McpError::WebSocket(msg)) => assert!(msg.contains("Invalid WebSocket URL")),
            _ => panic!("Expected WebSocket error"),
        }
    }

    /// When a server happens to be reachable, the connection info mentions
    /// the host and port; otherwise the test is a no-op.
    #[tokio::test]
    async fn test_websocket_client_connection_info() {
        let endpoint = "ws://localhost:9999/test";

        if let Ok(client) = WebSocketClientTransport::new(endpoint).await {
            assert!(client.connection_info().contains("localhost:9999"));
        }
    }
}