use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::{Context, Result};
use serde_json::{Value, json};
use super::protocol::{
JsonRpcNotification, JsonRpcRequest, MCP_PROTOCOL_VERSION, METHOD_INITIALIZE,
METHOD_NOTIFICATIONS_INITIALIZED, METHOD_TOOLS_CALL, METHOD_TOOLS_LIST, McpServerInfo,
McpToolDefinition, McpToolResult,
};
use super::transport::{HttpTransport, McpTransport, StdioTransport};
/// Client for a single MCP server connection.
///
/// Owns the transport used to exchange JSON-RPC messages and caches the
/// results of the `initialize` and `tools/list` calls made through it.
pub struct McpClient {
    /// Underlying connection: either the `Stdio` or the `Http` variant of
    /// `McpTransport`, depending on which constructor built the client.
    transport: McpTransport,
    /// Source of JSON-RPC request IDs; every constructor starts it at 1.
    id_counter: AtomicU64,
    /// Tool definitions from the most recent successful `list_tools` call.
    tools: Vec<McpToolDefinition>,
    /// Server metadata from the most recent successful `initialize` call.
    server_info: Option<McpServerInfo>,
}
impl McpClient {
pub fn connect_stdio(command: &str, args: &[&str]) -> Result<Self> {
let transport = StdioTransport::new(command, args)?;
Ok(Self {
transport: McpTransport::Stdio(Box::new(transport)),
id_counter: AtomicU64::new(1),
tools: Vec::new(),
server_info: None,
})
}
pub fn connect_stdio_with_env(
command: &str,
args: &[&str],
env: &std::collections::HashMap<String, String>,
) -> Result<Self> {
let transport = StdioTransport::with_env(command, args, env)?;
Ok(Self {
transport: McpTransport::Stdio(Box::new(transport)),
id_counter: AtomicU64::new(1),
tools: Vec::new(),
server_info: None,
})
}
pub fn connect_http(url: &str) -> Result<Self> {
Self::connect_http_with_headers(url, &std::collections::HashMap::new())
}
pub fn connect_http_with_headers(
url: &str,
headers: &std::collections::HashMap<String, String>,
) -> Result<Self> {
let transport = if headers.is_empty() {
HttpTransport::new(url)
} else {
HttpTransport::with_headers(url, headers)
};
Ok(Self {
transport: McpTransport::Http(transport),
id_counter: AtomicU64::new(1),
tools: Vec::new(),
server_info: None,
})
}
pub fn pid(&self) -> Option<u32> {
self.transport.pid()
}
pub fn target_url(&self) -> Option<String> {
self.transport.target_url()
}
fn next_id(&self) -> u64 {
self.id_counter.fetch_add(1, Ordering::Relaxed)
}
pub async fn initialize(&mut self) -> Result<McpServerInfo> {
let req = JsonRpcRequest::new(
self.next_id(),
METHOD_INITIALIZE,
Some(json!({
"protocolVersion": MCP_PROTOCOL_VERSION,
"capabilities": {
"roots": { "listChanged": false },
"sampling": {}
},
"clientInfo": {
"name": "collet",
"version": env!("CARGO_PKG_VERSION")
}
})),
);
let resp = self
.transport
.send(&req)
.await
.context("MCP initialize failed")?;
if let Some(err) = resp.error {
anyhow::bail!("MCP initialize error {}: {}", err.code, err.message);
}
let result = resp.result.unwrap_or_default();
let server_info_obj = result.get("serverInfo");
let info = McpServerInfo {
name: server_info_obj
.and_then(|s| s.get("name"))
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
version: server_info_obj
.and_then(|s| s.get("version"))
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
protocol_version: result
.get("protocolVersion")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
capabilities: result.get("capabilities").cloned().unwrap_or(Value::Null),
instructions: result
.get("instructions")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
};
let note = JsonRpcNotification::new(METHOD_NOTIFICATIONS_INITIALIZED, None);
if let Err(e) = self.transport.send_notification(¬e).await {
tracing::warn!(error = %e, "MCP notifications/initialized send failed");
}
self.server_info = Some(info.clone());
Ok(info)
}
pub async fn list_tools(&mut self) -> Result<Vec<McpToolDefinition>> {
let req = JsonRpcRequest::new(self.next_id(), METHOD_TOOLS_LIST, Some(json!({})));
let resp = self
.transport
.send(&req)
.await
.context("MCP tools/list failed")?;
if let Some(err) = resp.error {
anyhow::bail!("MCP tools/list error {}: {}", err.code, err.message);
}
let result = resp.result.unwrap_or(json!({"tools": []}));
let tools_value = result.get("tools").cloned().unwrap_or(Value::Array(vec![]));
let tools: Vec<McpToolDefinition> =
serde_json::from_value(tools_value).context("Failed to parse MCP tool list")?;
self.tools = tools.clone();
Ok(tools)
}
pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result<String> {
let req = JsonRpcRequest::new(
self.next_id(),
METHOD_TOOLS_CALL,
Some(json!({
"name": name,
"arguments": arguments,
})),
);
let resp = self
.transport
.send(&req)
.await
.with_context(|| format!("MCP tools/call failed for tool '{name}'"))?;
if let Some(err) = resp.error {
anyhow::bail!(
"MCP tools/call error for '{}': {} — {}",
name,
err.code,
err.message
);
}
let result = resp.result.unwrap_or_default();
let tool_result: McpToolResult =
serde_json::from_value(result).context("Failed to parse McpToolResult")?;
let text: String = tool_result
.content
.iter()
.filter(|c| c.type_ == "text")
.map(|c| c.text.as_str())
.collect::<Vec<_>>()
.join("\n");
if tool_result.is_error {
anyhow::bail!(
"MCP tool '{name}' reported isError=true: {}",
if text.is_empty() {
"<no content>"
} else {
text.as_str()
}
);
}
Ok(text)
}
pub fn tool_definitions(&self) -> Vec<Value> {
self.tools
.iter()
.map(|t| {
json!({
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.input_schema,
}
})
})
.collect()
}
pub fn server_instructions(&self) -> Option<&str> {
self.server_info
.as_ref()
.and_then(|i| i.instructions.as_deref())
}
pub fn raw_tools(&self) -> &[McpToolDefinition] {
&self.tools
}
pub async fn shutdown(&mut self) -> Result<()> {
self.transport.shutdown().await.map_err(anyhow::Error::from)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::protocol::McpToolDefinition;
    use serde_json::json;

    /// Builds a client whose HTTP transport is never contacted.
    ///
    /// The client is pre-seeded with `tools`, so `tool_definitions` can be
    /// exercised offline.
    fn client_with_tools(tools: Vec<McpToolDefinition>) -> McpClient {
        let transport = McpTransport::Http(HttpTransport::new("http://unused"));
        McpClient {
            transport,
            id_counter: AtomicU64::new(1),
            tools,
            server_info: None,
        }
    }

    /// Shorthand for building a tool-definition fixture.
    fn tool_def(name: &str, description: &str, input_schema: Value) -> McpToolDefinition {
        McpToolDefinition {
            name: name.to_owned(),
            description: description.to_owned(),
            input_schema,
        }
    }

    #[test]
    fn test_tool_definitions_basic() {
        let schema = json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" }
            },
            "required": ["path"]
        });
        let client = client_with_tools(vec![tool_def("read_file", "Read a file", schema)]);

        let defs = client.tool_definitions();
        assert_eq!(defs.len(), 1);

        let function = &defs[0]["function"];
        assert_eq!(defs[0]["type"], "function");
        assert_eq!(function["name"], "read_file");
        assert_eq!(function["description"], "Read a file");
        assert!(function["parameters"]["properties"]["path"].is_object());
    }

    #[test]
    fn test_tool_definitions_multiple() {
        let client = client_with_tools(vec![
            tool_def("tool_a", "First tool", json!({"type": "object"})),
            tool_def(
                "tool_b",
                "Second tool",
                json!({"type": "object", "properties": {}}),
            ),
        ]);

        // Output order must follow the cached tool order.
        let names: Vec<Value> = client
            .tool_definitions()
            .iter()
            .map(|def| def["function"]["name"].clone())
            .collect();
        assert_eq!(names, vec![json!("tool_a"), json!("tool_b")]);
    }

    #[test]
    fn test_tool_definitions_empty() {
        let client = client_with_tools(Vec::new());
        assert!(client.tool_definitions().is_empty());
    }
}