use super::types::{McpConfig, McpError, McpToolInfo};
use crate::tool::{Capability, Tool, ToolDefinition};
use crate::tool_error::ToolError;
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use tracing::{error, info};
/// Outgoing JSON-RPC request frame written to the WebSocket.
#[derive(Serialize, Debug)]
struct JsonRpcRequest {
    /// Protocol version tag (the client always fills in `"2.0"`).
    jsonrpc: String,
    /// Name of the remote method, e.g. `tools/call`.
    method: String,
    /// Method parameters, forwarded verbatim.
    params: Value,
    /// Correlation id; the server's reply is matched back to the caller by it.
    id: u64,
}
#[derive(Debug, Deserialize)]
struct JsonRpcResponse {
_jsonrpc: String,
result: Option<Value>,
error: Option<JsonRpcError>,
id: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct JsonRpcError {
_code: i32,
message: String,
#[allow(dead_code)]
data: Option<Value>,
}
enum McpCommand {
Call {
method: String,
params: Value,
resp_tx: oneshot::Sender<Result<Value, McpError>>,
},
Shutdown,
}
pub struct McpClient {
server_url: String,
command_tx: mpsc::Sender<McpCommand>,
connected: Arc<RwLock<bool>>,
tools_cache: Arc<RwLock<Option<Vec<McpToolInfo>>>>,
}
impl McpClient {
pub async fn connect(url: &str, config: McpConfig) -> Result<Self, McpError> {
let is_localhost = url.contains("localhost")
|| url.contains("127.0.0.1")
|| url.contains("[::1]")
|| url.contains("0.0.0.0");
if config.require_tls
&& !is_localhost
&& !url.starts_with("wss://")
&& !url.starts_with("https://")
{
return Err(McpError::TlsRequired);
}
let (ws_stream, _) = connect_async(url)
.await
.map_err(|e| McpError::ConnectionFailed(e.to_string()))?;
info!(url = url, "Connected to MCP server");
let (command_tx, mut command_rx) = mpsc::channel::<McpCommand>(32);
let connected = Arc::new(RwLock::new(true));
let connected_clone = connected.clone();
tokio::spawn(async move {
let (mut ws_tx, mut ws_rx) = ws_stream.split();
let mut pending_requests: HashMap<u64, oneshot::Sender<Result<Value, McpError>>> =
HashMap::new();
let mut next_id = 1u64;
loop {
tokio::select! {
Some(cmd) = command_rx.recv() => {
match cmd {
McpCommand::Call { method, params, resp_tx } => {
let id = next_id;
next_id += 1;
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
method,
params,
id,
};
let json = serde_json::to_string(&req).unwrap();
if let Err(e) = ws_tx.send(Message::Text(json)).await {
error!("WS send failed: {}", e);
let _ = resp_tx.send(Err(McpError::ConnectionFailed(e.to_string())));
break;
}
pending_requests.insert(id, resp_tx);
}
McpCommand::Shutdown => break,
}
}
Some(msg) = ws_rx.next() => {
match msg {
Ok(Message::Text(text)) => {
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&text) {
if let Some(id) = resp.id {
if let Some(tx) = pending_requests.remove(&id) {
if let Some(err) = resp.error {
let _ = tx.send(Err(McpError::ExecutionFailed(err.message)));
} else {
let _ = tx.send(Ok(resp.result.unwrap_or(Value::Null)));
}
}
}
}
}
Ok(Message::Close(_)) => {
info!("MCP server closed connection");
break;
}
Err(e) => {
error!("WS read error: {}", e);
break;
}
_ => {}
}
}
}
}
*connected_clone.write().await = false;
});
let client = Self {
server_url: url.to_string(),
command_tx,
connected,
tools_cache: Arc::new(RwLock::new(None)),
};
client.initialize().await?;
Ok(client)
}
async fn initialize(&self) -> Result<(), McpError> {
let params = serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "vex-client",
"version": "0.1.5"
}
});
self.call_raw("initialize", params).await?;
Ok(())
}
async fn call_raw(&self, method: &str, params: Value) -> Result<Value, McpError> {
let (resp_tx, resp_rx) = oneshot::channel();
self.command_tx
.send(McpCommand::Call {
method: method.to_string(),
params,
resp_tx,
})
.await
.map_err(|_| McpError::ConnectionFailed("Channel closed".into()))?;
resp_rx
.await
.map_err(|_| McpError::ConnectionFailed("Response channel closed".into()))?
}
pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
if let Some(ref tools) = *self.tools_cache.read().await {
return Ok(tools.clone());
}
let resp = self.call_raw("tools/list", Value::Null).await?;
let tools: Vec<McpToolInfo> = serde_json::from_value(resp["tools"].clone())
.map_err(|e| McpError::Serialization(e.to_string()))?;
*self.tools_cache.write().await = Some(tools.clone());
Ok(tools)
}
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
let params = serde_json::json!({
"name": name,
"arguments": args
});
self.call_raw("tools/call", params).await
}
pub fn server_url(&self) -> &str {
&self.server_url
}
pub async fn is_connected(&self) -> bool {
*self.connected.read().await
}
pub async fn disconnect(&self) {
let _ = self.command_tx.send(McpCommand::Shutdown).await;
}
}
/// Presents one tool advertised by an MCP server through the local [`Tool`] trait.
pub struct McpToolAdapter {
    /// Shared connection used to invoke the remote tool.
    client: Arc<McpClient>,
    /// Server-provided metadata; `info.name` is the remote tool name.
    info: McpToolInfo,
    /// Local tool definition built once, when the adapter is created.
    definition: ToolDefinition,
}
impl McpToolAdapter {
    /// Wraps the remote tool described by `info`, reachable through `client`.
    ///
    /// The tool's name, its description, and its input schema (serialized to
    /// JSON, or `""` if serialization fails) are converted to `&'static str`
    /// with `Box::leak` before being handed to `ToolDefinition::new`.
    ///
    /// NOTE(review): each call leaks these strings permanently. Avoid creating
    /// adapters repeatedly (e.g. on every reconnect).
    pub fn new(client: Arc<McpClient>, info: McpToolInfo) -> Self {
        fn leak(owned: String) -> &'static str {
            Box::leak(owned.into_boxed_str())
        }
        let schema_json = serde_json::to_string(&info.input_schema).unwrap_or_default();
        let definition = ToolDefinition::new(
            leak(info.name.clone()),
            leak(info.description.clone()),
            leak(schema_json),
        );
        Self {
            client,
            info,
            definition,
        }
    }
}
#[async_trait]
impl Tool for McpToolAdapter {
    /// Definition built from the server-provided tool metadata.
    fn definition(&self) -> &ToolDefinition {
        &self.definition
    }

    /// Remote tools always require network access.
    fn capabilities(&self) -> Vec<Capability> {
        vec![Capability::Network]
    }

    /// Fixed 30-second budget per invocation.
    fn timeout(&self) -> std::time::Duration {
        std::time::Duration::new(30, 0)
    }

    /// Forwards `args` to the remote tool of the same name. Any client error
    /// becomes an execution failure tagged with the tool name.
    async fn execute(&self, args: Value) -> Result<Value, ToolError> {
        let tool_name = &self.info.name;
        match self.client.call_tool(tool_name, args).await {
            Ok(output) => Ok(output),
            Err(failure) => Err(ToolError::execution_failed(tool_name, failure.to_string())),
        }
    }
}