1use async_trait::async_trait;
7use futures_util::{
8 sink::SinkExt,
9 stream::{SplitSink, SplitStream, StreamExt},
10};
11use serde_json::Value;
12use std::{collections::HashMap, sync::Arc, time::Duration};
13use tokio::{
14 net::{TcpListener, TcpStream},
15 sync::{broadcast, mpsc, Mutex, RwLock},
16 time::timeout,
17};
18use tokio_tungstenite::{
19 accept_async, connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
20};
21use url::Url;
22
23use crate::core::error::{McpError, McpResult};
24use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
25use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
26
27pub struct WebSocketClientTransport {
36 ws_sender: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
37 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
38 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
39 config: TransportConfig,
40 state: Arc<RwLock<ConnectionState>>,
41 url: String,
42 message_handler: Option<tokio::task::JoinHandle<()>>,
43}
44
45impl WebSocketClientTransport {
46 pub async fn new<S: AsRef<str>>(url: S) -> McpResult<Self> {
54 Self::with_config(url, TransportConfig::default()).await
55 }
56
57 pub async fn with_config<S: AsRef<str>>(url: S, config: TransportConfig) -> McpResult<Self> {
66 let url_str = url.as_ref();
67 let url_parsed = Url::parse(url_str)
68 .map_err(|e| McpError::WebSocket(format!("Invalid WebSocket URL: {}", e)))?;
69
70 tracing::debug!("Connecting to WebSocket: {}", url_str);
71
72 let connect_timeout = Duration::from_millis(config.connect_timeout_ms.unwrap_or(30_000));
74
75 let (ws_stream, _) = timeout(connect_timeout, connect_async(&url_parsed))
76 .await
77 .map_err(|_| McpError::WebSocket("Connection timeout".to_string()))?
78 .map_err(|e| McpError::WebSocket(format!("Failed to connect: {}", e)))?;
79
80 let (ws_sender, ws_receiver) = ws_stream.split();
81
82 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
83 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
84 let state = Arc::new(RwLock::new(ConnectionState::Connected));
85
86 let message_handler = tokio::spawn(Self::handle_messages(
88 ws_receiver,
89 pending_requests.clone(),
90 notification_sender,
91 state.clone(),
92 ));
93
94 Ok(Self {
95 ws_sender: Some(ws_sender),
96 pending_requests,
97 notification_receiver: Some(notification_receiver),
98 config,
99 state,
100 url: url_str.to_string(),
101 message_handler: Some(message_handler),
102 })
103 }
104
105 async fn handle_messages(
106 mut ws_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
107 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
108 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
109 state: Arc<RwLock<ConnectionState>>,
110 ) {
111 while let Some(message) = ws_receiver.next().await {
112 match message {
113 Ok(Message::Text(text)) => {
114 tracing::trace!("Received WebSocket message: {}", text);
115
116 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
118 let mut pending = pending_requests.lock().await;
119 if let Some(sender) = pending.remove(&response.id) {
120 if let Err(_) = sender.send(response) {
121 tracing::warn!("Failed to send response to waiting request");
122 }
123 } else {
124 tracing::warn!(
125 "Received response for unknown request ID: {:?}",
126 response.id
127 );
128 }
129 }
130 else if let Ok(notification) =
132 serde_json::from_str::<JsonRpcNotification>(&text)
133 {
134 if let Err(_) = notification_sender.send(notification) {
135 tracing::debug!("Notification receiver dropped");
136 break;
137 }
138 } else {
139 tracing::warn!("Failed to parse WebSocket message: {}", text);
140 }
141 }
142 Ok(Message::Close(_)) => {
143 tracing::info!("WebSocket connection closed");
144 *state.write().await = ConnectionState::Disconnected;
145 break;
146 }
147 Ok(Message::Ping(_data)) => {
148 tracing::trace!("Received WebSocket ping");
149 }
151 Ok(Message::Pong(_)) => {
152 tracing::trace!("Received WebSocket pong");
153 }
154 Ok(Message::Binary(_)) => {
155 tracing::warn!("Received unexpected binary WebSocket message");
156 }
157 Ok(Message::Frame(_)) => {
158 tracing::trace!("Received WebSocket frame (internal)");
159 }
161 Err(e) => {
162 tracing::error!("WebSocket error: {}", e);
163 *state.write().await = ConnectionState::Error(e.to_string());
164 break;
165 }
166 }
167 }
168
169 tracing::debug!("WebSocket message handler exiting");
170 }
171
172 async fn send_message(&mut self, message: Message) -> McpResult<()> {
173 if let Some(ref mut sender) = self.ws_sender {
174 sender
175 .send(message)
176 .await
177 .map_err(|e| McpError::WebSocket(format!("Failed to send message: {}", e)))?;
178 } else {
179 return Err(McpError::WebSocket("WebSocket not connected".to_string()));
180 }
181 Ok(())
182 }
183}
184
185#[async_trait]
186impl Transport for WebSocketClientTransport {
187 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
188 let (sender, receiver) = tokio::sync::oneshot::channel();
189
190 {
192 let mut pending = self.pending_requests.lock().await;
193 pending.insert(request.id.clone(), sender);
194 }
195
196 let request_text =
198 serde_json::to_string(&request).map_err(|e| McpError::Serialization(e))?;
199
200 tracing::trace!("Sending WebSocket request: {}", request_text);
201
202 self.send_message(Message::Text(request_text)).await?;
203
204 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
206
207 let response = timeout(timeout_duration, receiver)
208 .await
209 .map_err(|_| McpError::WebSocket("Request timeout".to_string()))?
210 .map_err(|_| McpError::WebSocket("Response channel closed".to_string()))?;
211
212 Ok(response)
213 }
214
215 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
216 let notification_text =
217 serde_json::to_string(¬ification).map_err(|e| McpError::Serialization(e))?;
218
219 tracing::trace!("Sending WebSocket notification: {}", notification_text);
220
221 self.send_message(Message::Text(notification_text)).await
222 }
223
224 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
225 if let Some(ref mut receiver) = self.notification_receiver {
226 match receiver.try_recv() {
227 Ok(notification) => Ok(Some(notification)),
228 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
229 Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::WebSocket(
230 "Notification channel disconnected".to_string(),
231 )),
232 }
233 } else {
234 Ok(None)
235 }
236 }
237
238 async fn close(&mut self) -> McpResult<()> {
239 tracing::debug!("Closing WebSocket connection");
240
241 *self.state.write().await = ConnectionState::Closing;
242
243 if let Some(ref mut sender) = self.ws_sender {
245 let _ = sender.send(Message::Close(None)).await;
246 }
247
248 if let Some(handle) = self.message_handler.take() {
250 handle.abort();
251 }
252
253 self.ws_sender = None;
254 self.notification_receiver = None;
255
256 *self.state.write().await = ConnectionState::Disconnected;
257
258 Ok(())
259 }
260
261 fn is_connected(&self) -> bool {
262 self.ws_sender.is_some()
264 }
265
266 fn connection_info(&self) -> String {
267 format!("WebSocket transport (url: {})", self.url)
268 }
269}
270
271struct WebSocketConnection {
277 sender: SplitSink<WebSocketStream<TcpStream>, Message>,
278 id: String,
279}
280
281pub struct WebSocketServerTransport {
286 bind_addr: String,
287 config: TransportConfig,
288 clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
289 request_handler: Arc<
290 RwLock<
291 Option<
292 Arc<
293 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
294 + Send
295 + Sync,
296 >,
297 >,
298 >,
299 >,
300 server_handle: Option<tokio::task::JoinHandle<()>>,
301 running: Arc<RwLock<bool>>,
302 shutdown_sender: Option<broadcast::Sender<()>>,
303}
304
305impl WebSocketServerTransport {
306 pub fn new<S: Into<String>>(bind_addr: S) -> Self {
314 Self::with_config(bind_addr, TransportConfig::default())
315 }
316
317 pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
326 let (shutdown_sender, _) = broadcast::channel(1);
327
328 Self {
329 bind_addr: bind_addr.into(),
330 config,
331 clients: Arc::new(RwLock::new(HashMap::new())),
332 request_handler: Arc::new(RwLock::new(None)),
333 server_handle: None,
334 running: Arc::new(RwLock::new(false)),
335 shutdown_sender: Some(shutdown_sender),
336 }
337 }
338
339 pub async fn set_request_handler<F>(&mut self, handler: F)
344 where
345 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
346 + Send
347 + Sync
348 + 'static,
349 {
350 let mut request_handler = self.request_handler.write().await;
351 *request_handler = Some(Arc::new(handler));
352 }
353
354 async fn handle_client_connection(
355 stream: TcpStream,
356 clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
357 request_handler: Arc<
358 RwLock<
359 Option<
360 Arc<
361 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
362 + Send
363 + Sync,
364 >,
365 >,
366 >,
367 >,
368 mut shutdown_receiver: broadcast::Receiver<()>,
369 ) {
370 let client_id = uuid::Uuid::new_v4().to_string();
371
372 let ws_stream = match accept_async(stream).await {
373 Ok(ws) => ws,
374 Err(e) => {
375 tracing::error!("Failed to accept WebSocket connection: {}", e);
376 return;
377 }
378 };
379
380 tracing::info!("New WebSocket client connected: {}", client_id);
381
382 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
383
384 {
386 let mut clients_guard = clients.write().await;
387 clients_guard.insert(
388 client_id.clone(),
389 WebSocketConnection {
390 sender: ws_sender,
391 id: client_id.clone(),
392 },
393 );
394 }
395
396 loop {
398 tokio::select! {
399 message = ws_receiver.next() => {
400 match message {
401 Some(Ok(Message::Text(text))) => {
402 tracing::trace!("Received message from {}: {}", client_id, text);
403
404 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&text) {
406 let handler_guard = request_handler.read().await;
407 if let Some(ref handler) = *handler_guard {
408 let response_rx = handler(request.clone());
409 drop(handler_guard);
410
411 match response_rx.await {
412 Ok(response) => {
413 let response_text = match serde_json::to_string(&response) {
414 Ok(text) => text,
415 Err(e) => {
416 tracing::error!("Failed to serialize response: {}", e);
417 continue;
418 }
419 };
420
421 let mut clients_guard = clients.write().await;
423 if let Some(client) = clients_guard.get_mut(&client_id) {
424 if let Err(e) = client.sender.send(Message::Text(response_text)).await {
425 tracing::error!("Failed to send response to client {}: {}", client_id, e);
426 break;
427 }
428 }
429 }
430 Err(_) => {
431 tracing::error!("Request handler channel closed for client {}", client_id);
432 }
433 }
434 } else {
435 tracing::warn!("No request handler configured for client {}", client_id);
436 }
437 }
438 else if let Ok(_notification) = serde_json::from_str::<JsonRpcNotification>(&text) {
440 tracing::trace!("Received notification from client {}", client_id);
441 } else {
443 tracing::warn!("Failed to parse message from client {}: {}", client_id, text);
444 }
445 }
446 Some(Ok(Message::Close(_))) => {
447 tracing::info!("Client {} disconnected", client_id);
448 break;
449 }
450 Some(Ok(Message::Ping(data))) => {
451 tracing::trace!("Received ping from client {}", client_id);
452 let mut clients_guard = clients.write().await;
453 if let Some(client) = clients_guard.get_mut(&client_id) {
454 if let Err(e) = client.sender.send(Message::Pong(data)).await {
455 tracing::error!("Failed to send pong to client {}: {}", client_id, e);
456 break;
457 }
458 }
459 }
460 Some(Ok(Message::Pong(_))) => {
461 tracing::trace!("Received pong from client {}", client_id);
462 }
463 Some(Ok(Message::Binary(_))) => {
464 tracing::warn!("Received unexpected binary message from client {}", client_id);
465 }
466 Some(Ok(Message::Frame(_))) => {
467 tracing::trace!("Received WebSocket frame from client {} (internal)", client_id);
468 }
470 Some(Err(e)) => {
471 tracing::error!("WebSocket error for client {}: {}", client_id, e);
472 break;
473 }
474 None => {
475 tracing::info!("WebSocket stream ended for client {}", client_id);
476 break;
477 }
478 }
479 }
480 _ = shutdown_receiver.recv() => {
481 tracing::info!("Shutting down connection for client {}", client_id);
482 break;
483 }
484 }
485 }
486
487 {
489 let mut clients_guard = clients.write().await;
490 clients_guard.remove(&client_id);
491 }
492
493 tracing::info!("Client {} connection handler exiting", client_id);
494 }
495}
496
497#[async_trait]
498impl ServerTransport for WebSocketServerTransport {
499 async fn start(&mut self) -> McpResult<()> {
500 tracing::info!("Starting WebSocket server on {}", self.bind_addr);
501
502 let listener = TcpListener::bind(&self.bind_addr).await.map_err(|e| {
503 McpError::WebSocket(format!("Failed to bind to {}: {}", self.bind_addr, e))
504 })?;
505
506 let clients = self.clients.clone();
507 let request_handler = self.request_handler.clone();
508 let running = self.running.clone();
509 let shutdown_sender = self.shutdown_sender.as_ref().unwrap().clone();
510
511 *running.write().await = true;
512
513 let server_handle = tokio::spawn(async move {
514 let mut shutdown_receiver = shutdown_sender.subscribe();
515
516 loop {
517 tokio::select! {
518 result = listener.accept() => {
519 match result {
520 Ok((stream, addr)) => {
521 tracing::debug!("New connection from: {}", addr);
522
523 tokio::spawn(Self::handle_client_connection(
524 stream,
525 clients.clone(),
526 request_handler.clone(),
527 shutdown_sender.subscribe(),
528 ));
529 }
530 Err(e) => {
531 tracing::error!("Failed to accept connection: {}", e);
532 }
533 }
534 }
535 _ = shutdown_receiver.recv() => {
536 tracing::info!("WebSocket server shutting down");
537 break;
538 }
539 }
540 }
541 });
542
543 self.server_handle = Some(server_handle);
544
545 tracing::info!(
546 "WebSocket server started successfully on {}",
547 self.bind_addr
548 );
549 Ok(())
550 }
551
552 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
553 let handler_guard = self.request_handler.read().await;
554
555 if let Some(ref handler) = *handler_guard {
556 let response_rx = handler(request);
557 drop(handler_guard);
558
559 match response_rx.await {
560 Ok(response) => Ok(response),
561 Err(_) => Err(McpError::WebSocket(
562 "Request handler channel closed".to_string(),
563 )),
564 }
565 } else {
566 Ok(JsonRpcResponse {
567 jsonrpc: "2.0".to_string(),
568 id: request.id,
569 result: None,
570 error: Some(crate::protocol::types::JsonRpcError {
571 code: crate::protocol::types::METHOD_NOT_FOUND,
572 message: "No request handler configured".to_string(),
573 data: None,
574 }),
575 })
576 }
577 }
578
579 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
580 let notification_text =
581 serde_json::to_string(¬ification).map_err(|e| McpError::Serialization(e))?;
582
583 let mut clients_guard = self.clients.write().await;
584 let mut disconnected_clients = Vec::new();
585
586 for (client_id, client) in clients_guard.iter_mut() {
587 if let Err(e) = client
588 .sender
589 .send(Message::Text(notification_text.clone()))
590 .await
591 {
592 tracing::error!("Failed to send notification to client {}: {}", client_id, e);
593 disconnected_clients.push(client_id.clone());
594 }
595 }
596
597 for client_id in disconnected_clients {
599 clients_guard.remove(&client_id);
600 }
601
602 Ok(())
603 }
604
605 async fn stop(&mut self) -> McpResult<()> {
606 tracing::info!("Stopping WebSocket server");
607
608 *self.running.write().await = false;
609
610 if let Some(ref sender) = self.shutdown_sender {
612 let _ = sender.send(());
613 }
614
615 if let Some(handle) = self.server_handle.take() {
617 handle.abort();
618 }
619
620 let mut clients_guard = self.clients.write().await;
622 for (client_id, client) in clients_guard.iter_mut() {
623 tracing::debug!("Closing connection for client {}", client_id);
624 let _ = client.sender.send(Message::Close(None)).await;
625 }
626 clients_guard.clear();
627
628 Ok(())
629 }
630
631 fn is_running(&self) -> bool {
632 self.server_handle.is_some()
634 }
635
636 fn server_info(&self) -> String {
637 format!("WebSocket server transport (bind: {})", self.bind_addr)
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_server_creation() {
        let server = WebSocketServerTransport::new("127.0.0.1:0");
        assert_eq!(server.bind_addr, "127.0.0.1:0");
        assert!(!server.is_running(), "a freshly created server must not be running");
    }

    #[test]
    fn test_websocket_server_with_config() {
        let mut config = TransportConfig::default();
        config.max_message_size = Some(64 * 1024);

        let server = WebSocketServerTransport::with_config("0.0.0.0:9090", config);
        assert_eq!(server.bind_addr, "0.0.0.0:9090");
        assert_eq!(server.config.max_message_size, Some(64 * 1024));
    }

    #[tokio::test]
    async fn test_websocket_client_invalid_url() {
        match WebSocketClientTransport::new("invalid-url").await {
            Err(McpError::WebSocket(msg)) => {
                assert!(msg.contains("Invalid WebSocket URL"), "unexpected message: {msg}")
            }
            Err(_) => panic!("Expected WebSocket error"),
            Ok(_) => panic!("Expected WebSocket error, but connecting succeeded"),
        }
    }

    #[tokio::test]
    async fn test_websocket_client_connection_info() {
        // Normally nothing listens here, so the connect fails and the test is a no-op;
        // the assertion only runs if a server happens to accept the connection.
        let url = "ws://localhost:9999/test";
        if let Ok(transport) = WebSocketClientTransport::new(url).await {
            assert!(transport.connection_info().contains("localhost:9999"));
        }
    }
}