//! WebSocket implementation of the `Transport` trait (`mcp_sdk_rs/transport/websocket.rs`).
use async_trait::async_trait;
use futures::{
    stream::{SplitSink, SplitStream},
    Sink, SinkExt, Stream, StreamExt,
};
use std::{fmt::Display, pin::Pin, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::Mutex;
use tokio_tungstenite::{
    connect_async,
    tungstenite::{
        error::Error as WsError,
        protocol::{frame::coding::CloseCode, CloseFrame, Message as WsMessage},
    },
    WebSocketStream,
};
use url::Url;

use crate::{
    error::Error,
    transport::{Message, Transport},
};

33type WebSocketConnection<S> = WebSocketStream<S>;
35
36pub struct WebSocketTransport<S> {
41 read_connection: Arc<Mutex<SplitStream<WebSocketConnection<S>>>>,
42 write_connection: Arc<Mutex<SplitSink<WebSocketConnection<S>, WsMessage>>>,
43}
44
45impl<S> WebSocketTransport<S>
46where
47 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
48{
49 pub fn from_stream(stream: WebSocketConnection<S>) -> Self {
57 let (w, r) = stream.split();
58 Self {
59 read_connection: Arc::new(Mutex::new(r)),
60 write_connection: Arc::new(Mutex::new(w)),
61 }
62 }
63
64 fn convert_to_ws_message(message: &Message) -> Result<WsMessage, Error> {
66 let json =
67 serde_json::to_string(message).map_err(|e| Error::Serialization(e.to_string()))?;
68 Ok(WsMessage::Text(json))
69 }
70
71 fn parse_ws_message(ws_message: WsMessage) -> Result<Message, Error> {
73 match ws_message {
74 WsMessage::Text(text) => {
75 serde_json::from_str(&text).map_err(|e| Error::Serialization(e.to_string()))
76 }
77 WsMessage::Binary(_) => Err(Error::Transport(
78 "Binary messages not supported".to_string(),
79 )),
80 WsMessage::Ping(_) => Ok(Message::Notification(crate::protocol::Notification {
81 jsonrpc: crate::protocol::JSONRPC_VERSION.to_string(),
82 method: "ping".to_string(),
83 params: None,
84 })),
85 WsMessage::Pong(_) => Ok(Message::Notification(crate::protocol::Notification {
86 jsonrpc: crate::protocol::JSONRPC_VERSION.to_string(),
87 method: "pong".to_string(),
88 params: None,
89 })),
90 WsMessage::Close(_) => Err(Error::Transport("Connection closed".to_string())),
91 WsMessage::Frame(_) => Err(Error::Transport("Raw frames not supported".to_string())),
92 }
93 }
94
95 async fn handle_ws_message<T, E>(
97 connection: &mut T,
98 message: WsMessage,
99 ) -> Result<Option<Message>, Error>
100 where
101 T: Sink<WsMessage, Error = E> + Unpin,
102 E: Display,
103 {
104 match message {
105 WsMessage::Ping(data) => {
106 connection
108 .send(WsMessage::Pong(data))
109 .await
110 .map_err(|e| Error::Transport(e.to_string()))?;
111 Ok(None)
112 }
113 WsMessage::Pong(_) => {
114 Ok(None)
116 }
117 _ => Self::parse_ws_message(message).map(Some),
118 }
119 }
120}
121
122#[async_trait]
123impl<S> Transport for WebSocketTransport<S>
124where
125 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
126{
127 async fn send(&self, message: Message) -> Result<(), Error> {
128 let ws_message = Self::convert_to_ws_message(&message)?;
129 let mut connection = self.write_connection.lock().await;
130 connection
131 .send(ws_message)
132 .await
133 .map_err(|e| Error::Transport(e.to_string()))
134 }
135
136 fn receive(&self) -> Pin<Box<dyn Stream<Item = Result<Message, Error>> + Send>> {
137 let read_connection = self.read_connection.clone();
138 let write_connection = self.write_connection.clone();
139
140 Box::pin(futures::stream::unfold(
141 read_connection,
142 move |read_connection| {
143 let read_connection = read_connection.clone();
144 let write_connection = write_connection.clone();
145 async move {
146 loop {
147 let mut guard = read_connection.lock().await;
148 match guard.next().await {
149 Some(Ok(ws_message)) => {
150 drop(guard);
151 let mut guard = write_connection.lock().await;
152 match Self::handle_ws_message(&mut *guard, ws_message).await {
153 Ok(Some(message)) => {
154 return Some((Ok(message), read_connection.clone()))
155 }
156 Ok(None) => continue, Err(e) => return Some((Err(e), read_connection.clone())),
158 }
159 }
160 Some(Err(e)) => {
161 return Some((
162 Err(Error::Transport(e.to_string())),
163 read_connection.clone(),
164 ))
165 }
166 None => return None,
167 }
168 }
169 }
170 },
171 ))
172 }
173
174 async fn close(&self) -> Result<(), Error> {
175 let mut connection = self.write_connection.lock().await;
176 connection
178 .send(WsMessage::Close(Some(CloseFrame {
179 code: 1000u16.into(), reason: "Client initiated close".into(),
181 })))
182 .await
183 .map_err(|e| Error::Transport(e.to_string()))?;
184 drop(connection);
185
186 let mut connection = self.read_connection.lock().await;
187 while let Some(msg) = connection.next().await {
189 match msg {
190 Ok(WsMessage::Close(_)) => break,
191 Ok(_) => continue,
192 Err(e) => {
193 if matches!(e, WsError::ConnectionClosed) {
194 break;
195 }
196 return Err(Error::Transport(e.to_string()));
197 }
198 }
199 }
200
201 Ok(())
202 }
203}
204
205impl WebSocketTransport<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
207 pub async fn new(url: &str) -> Result<Self, Error> {
219 let url = Url::parse(url).map_err(|e| Error::Transport(e.to_string()))?;
220
221 let (ws_stream, _) = connect_async(url)
222 .await
223 .map_err(|e| Error::Transport(e.to_string()))?;
224 let (w, r) = ws_stream.split();
225 Ok(Self {
226 read_connection: Arc::new(Mutex::new(r)),
227 write_connection: Arc::new(Mutex::new(w)),
228 })
229 }
230}