use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::sync::Mutex;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
use synwire_core::BoxFuture;
use synwire_core::agents::error::AgentError;
use synwire_core::mcp::traits::{
McpConnectionState, McpServerStatus, McpToolDescriptor, McpTransport,
};
/// WebSocket stream over a TCP socket that may or may not be wrapped in TLS;
/// this is the stream type stored after `connect_async` succeeds.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Mutable connection data of a transport, kept together so one mutex guards both fields.
struct Inner {
/// Live WebSocket stream; `None` until a connection has been stored, and
/// again after the stream has been taken or cleared.
stream: Option<WsStream>,
/// Connection lifecycle state, reported to callers through `status()`.
state: McpConnectionState,
}
/// Manual `Debug` implementation: the stream itself is not printed, only
/// whether one is currently held.
impl std::fmt::Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reduce the stream to a flag so the output stays short.
        let has_stream = self.stream.is_some();

        let mut builder = f.debug_struct("Inner");
        builder.field("state", &self.state);
        builder.field("stream", &has_stream);
        builder.finish()
    }
}
/// MCP transport that exchanges JSON-RPC 2.0 messages with a server over a
/// WebSocket connection.
#[derive(Debug)]
pub struct WebSocketMcpTransport {
/// Server label used in error messages, log events and status reports.
name: String,
/// WebSocket endpoint the transport connects to.
url: String,
/// Optional authentication token supplied to `new`.
auth_token: Option<String>,
/// Stream and lifecycle state, guarded by an async mutex.
state: Arc<Mutex<Inner>>,
/// Source of JSON-RPC request ids; starts at 1 and grows by one per request.
next_id: AtomicU64,
/// Number of `call_tool` invocations that returned `Ok`.
calls_succeeded: AtomicU64,
/// Number of `call_tool` invocations that returned `Err`.
calls_failed: AtomicU64,
}
impl WebSocketMcpTransport {
#[must_use]
pub fn new(
name: impl Into<String>,
url: impl Into<String>,
auth_token: Option<String>,
) -> Self {
Self {
name: name.into(),
url: url.into(),
auth_token,
state: Arc::new(Mutex::new(Inner {
stream: None,
state: McpConnectionState::Disconnected,
})),
next_id: AtomicU64::new(1),
calls_succeeded: AtomicU64::new(0),
calls_failed: AtomicU64::new(0),
}
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn url(&self) -> &str {
&self.url
}
fn next_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::Relaxed)
}
#[allow(clippy::significant_drop_tightening)]
async fn rpc(&self, method: &str, params: Value) -> Result<Value, AgentError> {
let id = self.next_id();
let request = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
});
let text = serde_json::to_string(&request)
.map_err(|e| AgentError::Vfs(format!("JSON serialization error: {e}")))?;
let mut guard = self.state.lock().await;
let stream = guard.stream.as_mut().ok_or_else(|| {
AgentError::Vfs(format!(
"WebSocket MCP server '{}' not connected",
self.name
))
})?;
stream
.send(Message::Text(text))
.await
.map_err(|e| AgentError::Vfs(format!("WebSocket send error: {e}")))?;
let raw = stream.next().await.ok_or_else(|| {
AgentError::Vfs(format!(
"WebSocket MCP server '{}' closed connection",
self.name
))
})?;
let msg = raw.map_err(|e| AgentError::Vfs(format!("WebSocket receive error: {e}")))?;
let text = match msg {
Message::Text(t) => t,
Message::Binary(b) => String::from_utf8(b)
.map_err(|e| AgentError::Vfs(format!("Invalid UTF-8 in WebSocket message: {e}")))?,
Message::Close(_) => {
return Err(AgentError::Vfs(format!(
"WebSocket MCP server '{}' sent close frame",
self.name
)));
}
_ => {
return Err(AgentError::Vfs(format!(
"Unexpected WebSocket message type from server '{}'",
self.name
)));
}
};
let response: Value = serde_json::from_str(&text)
.map_err(|e| AgentError::Vfs(format!("Failed to parse JSON-RPC response: {e}")))?;
if let Some(error) = response.get("error") {
return Err(AgentError::Vfs(format!(
"MCP JSON-RPC error from server '{}': {}",
self.name, error
)));
}
Ok(response["result"].clone())
}
async fn open_connection(&self) -> Result<WsStream, AgentError> {
let mut request = self.url.as_str();
let _ = &self.auth_token; let url_str = self.url.clone();
let (stream, _response) = connect_async(request).await.map_err(|e| {
AgentError::Vfs(format!("WebSocket connection to '{url_str}' failed: {e}"))
})?;
let _ = &mut request;
Ok(stream)
}
}
impl McpTransport for WebSocketMcpTransport {
    /// Opens the WebSocket connection unless the transport is already connected.
    ///
    /// The state is set to `Connecting` while the handshake runs and to
    /// `Connected` once the stream has been stored. The mutex is released
    /// during the handshake so `status()` stays responsive.
    ///
    /// # Errors
    ///
    /// Propagates the error from `open_connection`; in that case the state is
    /// reset to `Disconnected` rather than being left at `Connecting`.
    fn connect(&self) -> BoxFuture<'_, Result<(), AgentError>> {
        Box::pin(async move {
            {
                let mut guard = self.state.lock().await;
                if guard.state == McpConnectionState::Connected {
                    return Ok(());
                }
                guard.state = McpConnectionState::Connecting;
            }

            let stream = match self.open_connection().await {
                Ok(stream) => stream,
                Err(err) => {
                    // Do not leave the transport stuck in `Connecting`.
                    self.state.lock().await.state = McpConnectionState::Disconnected;
                    return Err(err);
                }
            };

            {
                let mut guard = self.state.lock().await;
                guard.stream = Some(stream);
                guard.state = McpConnectionState::Connected;
            }

            tracing::info!(server = %self.name, url = %self.url, "WebSocket MCP server connected");
            Ok(())
        })
    }

    /// Discards the current stream, marks the transport `Reconnecting` and
    /// then runs `connect()`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `connect()`.
    fn reconnect(&self) -> BoxFuture<'_, Result<(), AgentError>> {
        Box::pin(async move {
            {
                let mut guard = self.state.lock().await;
                guard.state = McpConnectionState::Reconnecting;
                guard.stream = None;
            }
            tracing::info!(server = %self.name, "Reconnecting WebSocket MCP server");
            self.connect().await
        })
    }

    /// Reports the server name, current connection state and call counters.
    ///
    /// `enabled` is `true` only while the state is `Connected`.
    fn status(&self) -> BoxFuture<'_, McpServerStatus> {
        Box::pin(async move {
            let guard = self.state.lock().await;
            McpServerStatus {
                name: self.name.clone(),
                state: guard.state,
                calls_succeeded: self.calls_succeeded.load(Ordering::Relaxed),
                calls_failed: self.calls_failed.load(Ordering::Relaxed),
                enabled: guard.state == McpConnectionState::Connected,
            }
        })
    }

    /// Requests `tools/list` from the server and converts the reply into descriptors.
    ///
    /// Entries without a string `"name"` are skipped; a missing
    /// `"description"` becomes an empty string.
    ///
    /// # Errors
    ///
    /// Returns an error when the RPC fails or the result has no `"tools"` array.
    fn list_tools(&self) -> BoxFuture<'_, Result<Vec<McpToolDescriptor>, AgentError>> {
        Box::pin(async move {
            let result = self.rpc("tools/list", serde_json::json!({})).await?;
            let tools = result["tools"].as_array().ok_or_else(|| {
                AgentError::Vfs(format!(
                    "MCP server '{}' returned invalid tools/list response",
                    self.name
                ))
            })?;
            let descriptors = tools
                .iter()
                .filter_map(|t| {
                    let name = t["name"].as_str()?;
                    Some(McpToolDescriptor {
                        name: name.to_owned(),
                        description: t["description"].as_str().unwrap_or_default().to_owned(),
                        input_schema: t["inputSchema"].clone(),
                    })
                })
                .collect();
            Ok(descriptors)
        })
    }

    /// Invokes `tool_name` on the server via `tools/call` with the given arguments.
    ///
    /// Increments `calls_succeeded` or `calls_failed` according to the outcome
    /// and returns the RPC result unchanged.
    fn call_tool(
        &self,
        tool_name: &str,
        arguments: Value,
    ) -> BoxFuture<'_, Result<Value, AgentError>> {
        // Owned copy so the returned future does not borrow `tool_name`.
        let tool_name = tool_name.to_owned();
        Box::pin(async move {
            let result = self
                .rpc(
                    "tools/call",
                    serde_json::json!({ "name": tool_name, "arguments": arguments }),
                )
                .await;
            let counter = if result.is_ok() {
                &self.calls_succeeded
            } else {
                &self.calls_failed
            };
            let _ = counter.fetch_add(1, Ordering::Relaxed);
            result
        })
    }

    /// Marks the transport `Shutdown` and closes the stream if one is stored.
    ///
    /// The stream is taken and the state updated under a single lock
    /// acquisition, and the lock is released before the close handshake. This
    /// also makes the call safe when no stream is stored: no second lock is
    /// attempted while the first guard is still alive.
    ///
    /// Errors from the close handshake are ignored (best effort); the call
    /// always returns `Ok(())`.
    fn disconnect(&self) -> BoxFuture<'_, Result<(), AgentError>> {
        Box::pin(async move {
            let stream = {
                let mut guard = self.state.lock().await;
                guard.state = McpConnectionState::Shutdown;
                guard.stream.take()
            };

            if let Some(mut stream) = stream {
                // Best-effort close: the connection is being discarded anyway.
                let _ = stream.close(None).await;
            }

            tracing::info!(server = %self.name, "WebSocket MCP server disconnected");
            Ok(())
        })
    }
}