use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::RwLock;
use super::types::{McpConfig, McpError, McpToolInfo};
use crate::tool::{Capability, Tool, ToolDefinition};
use crate::tool_error::ToolError;
/// Client handle for a single MCP server.
///
/// Connection state and the tool-list cache are stored behind `Arc<RwLock<..>>`,
/// which lets the `&self` methods of this type update them from async code.
pub struct McpClient {
    /// URL passed to `McpClient::connect`.
    server_url: String,
    /// Configuration passed to `McpClient::connect`.
    config: McpConfig,
    /// Cached result of `list_tools`; `None` until first populated and again after `disconnect`.
    tools_cache: Arc<RwLock<Option<Vec<McpToolInfo>>>>,
    /// Whether the client currently considers itself connected (cleared by `disconnect`).
    connected: Arc<RwLock<bool>>,
}
impl McpClient {
pub async fn connect(url: &str, config: McpConfig) -> Result<Self, McpError> {
let is_localhost = url.contains("localhost") || url.contains("127.0.0.1");
if config.require_tls && !is_localhost {
if !url.starts_with("wss://") && !url.starts_with("https://") {
return Err(McpError::TlsRequired);
}
}
let client = Self {
server_url: url.to_string(),
config,
tools_cache: Arc::new(RwLock::new(None)),
connected: Arc::new(RwLock::new(false)),
};
*client.connected.write().await = true;
Ok(client)
}
pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
{
let cache = self.tools_cache.read().await;
if let Some(ref tools) = *cache {
return Ok(tools.clone());
}
}
let tools = Vec::new();
*self.tools_cache.write().await = Some(tools.clone());
Ok(tools)
}
pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
if !*self.connected.read().await {
return Err(McpError::ConnectionFailed("Not connected".into()));
}
Ok(serde_json::json!({
"tool": name,
"args": args,
"result": null,
"status": "mcp_call_placeholder"
}))
}
pub fn server_url(&self) -> &str {
&self.server_url
}
pub async fn is_connected(&self) -> bool {
*self.connected.read().await
}
pub async fn disconnect(&self) {
*self.connected.write().await = false;
*self.tools_cache.write().await = None;
}
}
/// Exposes a single remote MCP tool through the local `Tool` trait.
pub struct McpToolAdapter {
    /// Shared client used to forward calls to the MCP server.
    client: Arc<McpClient>,
    /// Metadata describing the remote tool (name, description, input schema).
    info: McpToolInfo,
    /// Local tool definition built from `info` in `McpToolAdapter::new`.
    definition: ToolDefinition,
}
impl McpToolAdapter {
    /// Wraps the remote tool described by `info` so it can be invoked through the
    /// local `Tool` trait via `client`.
    ///
    /// `ToolDefinition::new` is given `&'static str` arguments, so the tool's name,
    /// description and JSON-serialized input schema must outlive every adapter.
    /// Leaking fresh copies on every call would grow memory without bound whenever
    /// adapters are rebuilt (e.g. after a reconnect). Instead the strings are
    /// interned: each distinct string is leaked at most once per process.
    pub fn new(client: Arc<McpClient>, info: McpToolInfo) -> Self {
        let name = Self::intern(&info.name);
        let description = Self::intern(&info.description);
        // A serialization failure falls back to an empty schema string.
        let schema_json = serde_json::to_string(&info.input_schema).unwrap_or_default();
        let parameters = Self::intern(&schema_json);
        let definition = ToolDefinition::new(name, description, parameters);
        Self {
            client,
            info,
            definition,
        }
    }

    /// Returns the MCP metadata this adapter was built from.
    pub fn info(&self) -> &McpToolInfo {
        &self.info
    }

    /// Returns a `'static` copy of `s`. A new allocation is leaked only the first
    /// time a given string is seen in this process; later calls reuse it.
    fn intern(s: &str) -> &'static str {
        use std::collections::HashSet;
        use std::sync::{Mutex, OnceLock, PoisonError};

        static INTERNED: OnceLock<Mutex<HashSet<&'static str>>> = OnceLock::new();

        let mut interned = INTERNED
            .get_or_init(|| Mutex::new(HashSet::new()))
            .lock()
            // The set only ever grows by whole entries, so a poisoned lock still
            // guards consistent data.
            .unwrap_or_else(PoisonError::into_inner);
        if let Some(&existing) = interned.get(s) {
            return existing;
        }
        let leaked: &'static str = Box::leak(s.to_owned().into_boxed_str());
        interned.insert(leaked);
        leaked
    }
}
#[async_trait]
impl Tool for McpToolAdapter {
    /// Returns the definition assembled when the adapter was constructed.
    fn definition(&self) -> &ToolDefinition {
        &self.definition
    }

    /// Remote MCP tools always require network access.
    fn capabilities(&self) -> Vec<Capability> {
        Vec::from([Capability::Network])
    }

    /// Fixed 30-second budget for a single remote call.
    fn timeout(&self) -> Duration {
        const CALL_TIMEOUT_SECS: u64 = 30;
        Duration::from_secs(CALL_TIMEOUT_SECS)
    }

    /// Forwards `args` to the remote tool. Any client error is converted into an
    /// execution failure tagged with this tool's name.
    async fn execute(&self, args: Value) -> Result<Value, ToolError> {
        let tool_name = &self.info.name;
        match self.client.call_tool(tool_name, args).await {
            Ok(value) => Ok(value),
            Err(err) => Err(ToolError::execution_failed(tool_name, err.to_string())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connects to a local server with TLS enforcement disabled.
    async fn insecure_local_client() -> McpClient {
        McpClient::connect("ws://localhost:8080", McpConfig::default().allow_insecure())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_connect_localhost() {
        let config = McpConfig::default().allow_insecure();
        let client = McpClient::connect("ws://localhost:8080", config).await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_connect_tls_required() {
        let config = McpConfig::default();
        let result = McpClient::connect("ws://remote.server:8080", config).await;
        assert!(matches!(result, Err(McpError::TlsRequired)));
    }

    #[tokio::test]
    async fn test_connect_tls_allowed() {
        let config = McpConfig::default();
        let result = McpClient::connect("wss://remote.server:8080", config).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_connect_loopback_exempt_from_tls() {
        // The default config requires TLS (see test_connect_tls_required), but
        // loopback hosts are exempt.
        for url in ["ws://localhost:8080", "ws://127.0.0.1:8080"] {
            let result = McpClient::connect(url, McpConfig::default()).await;
            assert!(result.is_ok(), "expected {url} to be accepted without TLS");
        }
    }

    #[tokio::test]
    async fn test_list_tools_empty() {
        let client = insecure_local_client().await;
        let tools = client.list_tools().await.unwrap();
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_call_tool() {
        let client = insecure_local_client().await;
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_call_tool_after_disconnect_fails() {
        let client = insecure_local_client().await;
        client.disconnect().await;
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await;
        assert!(matches!(result, Err(McpError::ConnectionFailed(_))));
    }

    #[tokio::test]
    async fn test_disconnect() {
        let client = insecure_local_client().await;
        assert!(client.is_connected().await);
        client.disconnect().await;
        assert!(!client.is_connected().await);
    }
}