use std::time::Duration;
use rmcp::{
RoleClient, ServiceError, ServiceExt,
model::{
CallToolRequestParams, ClientCapabilities, ClientInfo, Implementation,
InitializeRequestParams, Tool,
},
service::RunningService,
transport::StreamableHttpClientTransport,
};
use serde_json::Value;
use crate::base::error::{TestError, TestsResult, ToolCallError};
pub struct McpClient {
client: RunningService<RoleClient, InitializeRequestParams>,
}
impl McpClient {
pub async fn new(server_url: &str) -> TestsResult<Self> {
let transport = StreamableHttpClientTransport::from_uri(server_url);
let client_info = ClientInfo::new(
ClientCapabilities::default(),
Implementation::new("vibe-tests", "0.1.0"),
);
let client = client_info
.serve(transport)
.await
.map_err(|e| TestError::Setup(format!("Failed to connect MCP: {}", e)))?;
Ok(Self { client })
}
pub async fn list_tools(&self) -> TestsResult<Vec<Tool>> {
let response = self
.client
.list_tools(None)
.await
.map_err(|e| TestError::Setup(format!("Failed to list tools: {}", e)))?;
Ok(response.tools)
}
pub async fn call_tool(
&self,
name: String,
args: Value,
timeout: Duration,
) -> TestsResult<String> {
let params = if let Some(obj) = args.as_object() {
CallToolRequestParams::new(name.clone()).with_arguments(obj.clone())
} else {
CallToolRequestParams::new(name.clone())
};
let result = tokio::time::timeout(timeout, self.client.call_tool(params))
.await
.map_err(|_| {
TestError::Timeout(format!(
"Tool '{}' call timed out after {:?}",
name, timeout
))
})?;
match result {
Ok(r) => {
let content = r
.content
.into_iter()
.next()
.and_then(|c| c.as_text().map(|t| t.text.clone()))
.unwrap_or(String::new());
Ok(content)
}
Err(e) => {
tracing::error!("Tool call failed: {}", e);
match e {
ServiceError::McpError(e) => Err(TestError::ToolCall(ToolCallError {
tool: Some(name.clone()),
args: Some(serde_json::to_string(&args).unwrap_or_default()),
code: e.code.0,
})),
_ => Err(TestError::ToolCall(ToolCallError {
tool: Some(name.clone()),
args: Some(serde_json::to_string(&args).unwrap_or_default()),
code: -1,
})),
}
}
}
}
}