use crate::tool::{Tool, ToolBox, ToolError};
use anyhow::Result as AnyhowResult;
use async_trait::async_trait;
use log::trace;
use mcp_client_rs::{
client::{Client, ClientBuilder},
MessageContent,
};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// A toolbox whose tools are served by an MCP client.
///
/// Pairs the client handle with the list of tool definitions stored in
/// `tools`, so definitions can be returned without another client call.
pub struct McpToolBox {
/// Shared handle to the MCP client; tool invocations go through it.
client: Arc<Client>,
/// Tool definitions held by this toolbox.
tools: Vec<Tool>,
}
impl McpToolBox {
pub async fn new(
cmd: &str,
args: impl IntoIterator<Item = impl AsRef<str>>,
envs: Option<HashMap<String, String>>,
) -> AnyhowResult<Self> {
trace!("McpToolBox::new for cmd: {cmd}");
let mut builder = ClientBuilder::new(cmd).args(args);
if let Some(envs) = envs {
for (k, v) in envs {
builder = builder.env(&k, &v);
}
}
let client = builder.spawn_and_initialize().await?;
trace!("McpToolBox::new for client initialized");
let mut tools = vec![];
for tool_desc in client.list_tools().await?.tools {
tools.push(Tool {
name: tool_desc.name,
description: Some(tool_desc.description),
schema: Some(tool_desc.input_schema),
});
}
Ok(Self {
client: Arc::new(client),
tools,
})
}
}
#[async_trait]
impl ToolBox for McpToolBox {
    /// Returns a copy of the tool definitions stored in this toolbox.
    fn tools_definitions(&self) -> Result<Vec<Tool>, ToolError> {
        let definitions = self.tools.to_vec();
        Ok(definitions)
    }

    /// Calls the tool `tool_name` with `arguments` through the MCP client.
    ///
    /// Only `MessageContent::Text` entries of the response are kept; every
    /// other content variant is skipped. The kept texts are joined with a
    /// newline, so a response without text entries yields an empty string.
    ///
    /// # Errors
    ///
    /// A client failure is wrapped in an `anyhow::Error` and then converted
    /// into [`ToolError`] by `?`.
    async fn call_tool(&self, tool_name: String, arguments: Value) -> Result<String, ToolError> {
        let response = self
            .client
            .call_tool(&tool_name, arguments)
            .await
            .map_err(anyhow::Error::new)?;

        // Borrow the text parts; `join` produces the single owned string.
        let mut text_parts: Vec<&str> = Vec::new();
        for part in response.content.iter() {
            if let MessageContent::Text { text } = part {
                text_parts.push(text.as_str());
            }
        }

        Ok(text_parts.join("\n"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result as AnyhowResult;
    use serde_json::json;

    /// Creates a toolbox by running `uvx mcp-server-time --local-timezone UTC`.
    ///
    /// NOTE(review): these tests presumably need `uvx` and the
    /// `mcp-server-time` package to be available on the machine — confirm.
    async fn create_test_toolbox() -> AnyhowResult<McpToolBox> {
        let server_args = ["mcp-server-time", "--local-timezone", "UTC"];
        McpToolBox::new("uvx", server_args, None).await
    }

    /// Asserts that `definitions` contains a tool called `expected_name`
    /// that carries both a description and a schema.
    fn assert_tool_is_described(definitions: &[Tool], expected_name: &str) {
        let tool = definitions
            .iter()
            .find(|candidate| candidate.name == expected_name)
            .unwrap_or_else(|| panic!("Expected tool '{expected_name}' not found"));

        assert_eq!(tool.name, expected_name);
        assert!(tool.description.is_some());
        assert!(tool.schema.is_some());
    }

    /// The toolbox exposes at least the two expected tools, each with a
    /// description and a schema.
    #[tokio::test]
    async fn test_new_and_tools_definitions() -> AnyhowResult<()> {
        let toolbox = create_test_toolbox().await?;
        let definitions = toolbox.tools_definitions()?;

        assert!(definitions.len() >= 2);
        assert_tool_is_described(&definitions, "get_current_time");
        assert_tool_is_described(&definitions, "convert_time");
        Ok(())
    }

    /// Calling `convert_time` with valid arguments returns non-empty text.
    #[tokio::test]
    async fn test_call_tool_convert_time() -> AnyhowResult<()> {
        let toolbox = create_test_toolbox().await?;
        let payload = json!({
            "source_timezone": "Europe/Warsaw",
            "target_timezone": "America/New_York",
            "time": "12:00"
        });

        let output = toolbox
            .call_tool("convert_time".to_string(), payload)
            .await?;

        assert!(!output.is_empty());
        Ok(())
    }

    /// Calling a tool name that does not exist is reported as an error.
    #[tokio::test]
    async fn test_call_tool_invalid_tool() -> AnyhowResult<()> {
        let toolbox = create_test_toolbox().await?;

        let outcome = toolbox
            .call_tool("non_existent_tool".to_string(), json!({}))
            .await;

        assert!(outcome.is_err());
        Ok(())
    }
}