use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use rmcp::ServiceExt;
use rmcp::model::{CallToolRequestParams, CallToolResult, Content, Tool as McpToolDef};
use rmcp::service::{Peer, RoleClient, RunningService};
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use serde_json::Value;
use tokio::sync::RwLock;
use tracing::{debug, info};
use crate::error::ToolError;
use crate::tool::{BoxedTool, DynTool, ToolDefinition};
/// Builder for an [`McpServer`] that is launched as a child process and
/// reached over stdio.
///
/// Obtained from [`McpServer::stdio`]; configure it with the chained setters
/// and finish with `connect`.
pub struct StdioBuilder {
    /// Executable to spawn.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables for the child process.
    envs: HashMap<String, String>,
    /// Working directory for the child process, if overridden.
    working_dir: Option<PathBuf>,
    /// Explicit display name; when `None`, one is derived from `command`.
    name: Option<String>,
}

// Hand-written instead of derived so that environment variable *values*
// (commonly API keys and tokens) never leak into logs or panic messages via
// `{:?}`. Only the variable names are shown.
impl std::fmt::Debug for StdioBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StdioBuilder")
            .field("command", &self.command)
            .field("args", &self.args)
            .field("env_keys", &self.envs.keys().collect::<Vec<_>>())
            .field("working_dir", &self.working_dir)
            .field("name", &self.name)
            .finish()
    }
}
impl StdioBuilder {
    /// Sets one environment variable for the spawned process.
    ///
    /// Setting the same `key` again replaces the earlier value.
    #[must_use]
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.envs.insert(key, value);
        self
    }

    /// Sets several environment variables for the spawned process.
    ///
    /// Pairs are applied in iteration order, so for duplicate keys the last
    /// value wins.
    #[must_use]
    pub fn envs(
        mut self,
        vars: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
    ) -> Self {
        self.envs
            .extend(vars.into_iter().map(|(key, value)| (key.into(), value.into())));
        self
    }

    /// Sets the working directory the process is started in.
    #[must_use]
    pub fn working_dir(self, dir: impl Into<PathBuf>) -> Self {
        Self {
            working_dir: Some(dir.into()),
            ..self
        }
    }

    /// Overrides the display name of the resulting server.
    ///
    /// Without it, the name defaults to `stdio:<command>`.
    #[must_use]
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Spawns the configured command and initializes an MCP client session
    /// over its stdio.
    ///
    /// # Errors
    ///
    /// Returns a runtime error if the process cannot be spawned or if MCP
    /// initialization with it fails.
    pub async fn connect(self) -> crate::Result<McpServer> {
        let Self {
            command,
            args,
            envs,
            working_dir,
            name,
        } = self;

        // Only the variable names are logged; values may be secrets.
        info!(
            command = %command,
            args = ?args,
            envs = ?envs.keys().collect::<Vec<_>>(),
            "Connecting to MCP server via stdio",
        );

        let mut process = tokio::process::Command::new(&command);
        process.args(&args);
        // An empty map adds nothing, so the inherited environment is untouched.
        process.envs(&envs);
        if let Some(dir) = working_dir.as_deref() {
            process.current_dir(dir);
        }

        let transport = TokioChildProcess::new(process).map_err(|err| {
            crate::error::AgentError::runtime(format!(
                "Failed to spawn MCP server process '{command}': {err}"
            ))
        })?;
        let service = ().serve(transport).await.map_err(|err| {
            crate::error::AgentError::runtime(format!(
                "Failed to initialize MCP connection to '{command}': {err}"
            ))
        })?;

        let name = name.unwrap_or_else(|| format!("stdio:{command}"));
        info!(name = %name, "MCP server connected");
        Ok(McpServer {
            service: Arc::new(RwLock::new(service)),
            cached_tools: Arc::new(RwLock::new(None)),
            name,
        })
    }
}
/// Builder for an [`McpServer`] reached over the streamable HTTP transport.
///
/// Obtained from [`McpServer::http`]; configure it with the chained setters
/// and finish with `connect`.
pub struct HttpBuilder {
    /// Endpoint URL of the MCP server.
    url: String,
    /// Optional bearer token used to authenticate requests.
    bearer_token: Option<String>,
    /// Explicit display name; when `None`, one is derived from `url`.
    name: Option<String>,
}

// Hand-written instead of derived so the bearer token is never printed by
// `{:?}` (e.g. in logs or panic messages); only its presence is shown.
impl std::fmt::Debug for HttpBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpBuilder")
            .field("url", &self.url)
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("name", &self.name)
            .finish()
    }
}
impl HttpBuilder {
    /// Sets the bearer token for authenticating to the server.
    ///
    /// The token is passed to the transport config's `auth_header`. Whether a
    /// `Bearer ` prefix is added is decided by rmcp; check its docs.
    #[must_use]
    pub fn bearer_auth(self, token: impl Into<String>) -> Self {
        Self {
            bearer_token: Some(token.into()),
            ..self
        }
    }

    /// Overrides the display name of the resulting server.
    ///
    /// Without it, the name defaults to `http:<url>`.
    #[must_use]
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Initializes an MCP client session with the server at the configured URL.
    ///
    /// # Errors
    ///
    /// Returns a runtime error if MCP initialization with the server fails.
    pub async fn connect(self) -> crate::Result<McpServer> {
        let Self {
            url,
            bearer_token,
            name,
        } = self;
        info!(url = %url, "Connecting to MCP server via HTTP");

        let base_config = StreamableHttpClientTransportConfig::with_uri(url.clone());
        let config = match bearer_token {
            Some(token) => base_config.auth_header(token),
            None => base_config,
        };
        let transport = StreamableHttpClientTransport::from_config(config);

        let service: RunningService<RoleClient, ()> =
            ().serve(transport).await.map_err(|err| {
                crate::error::AgentError::runtime(format!(
                    "Failed to initialize MCP connection to '{url}': {err}"
                ))
            })?;

        let name = name.unwrap_or_else(|| format!("http:{url}"));
        info!(name = %name, "MCP server connected");
        Ok(McpServer {
            service: Arc::new(RwLock::new(service)),
            cached_tools: Arc::new(RwLock::new(None)),
            name,
        })
    }
}
/// Handle to a connected MCP server.
///
/// Internal state sits behind `Arc`s, so the clones handed to individual tools
/// share one session and one tool cache.
pub struct McpServer {
    /// Name used to identify this server in logs and error messages.
    name: String,
    /// The running rmcp client session. Shared (read) access is used for
    /// requests; exclusive (write) access is used to close it.
    service: Arc<RwLock<RunningService<RoleClient, ()>>>,
    /// Tool definitions from the last successful listing, or `None` if the
    /// tools have not been fetched yet.
    cached_tools: Arc<RwLock<Option<Vec<McpToolDef>>>>,
}
// Only the name is printed. The session handle and tool cache are omitted,
// which `finish_non_exhaustive` marks with a trailing `..`.
impl std::fmt::Debug for McpServer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("McpServer");
        out.field("name", &self.name);
        out.finish_non_exhaustive()
    }
}
impl McpServer {
pub fn stdio(
command: impl AsRef<str>,
args: impl IntoIterator<Item = impl AsRef<str>>,
) -> StdioBuilder {
StdioBuilder {
command: command.as_ref().to_owned(),
args: args.into_iter().map(|a| a.as_ref().to_owned()).collect(),
envs: HashMap::new(),
working_dir: None,
name: None,
}
}
pub fn http(url: impl Into<String>) -> HttpBuilder {
HttpBuilder {
url: url.into(),
bearer_token: None,
name: None,
}
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
pub async fn peer(&self) -> Peer<RoleClient> {
let svc = self.service.read().await;
svc.peer().clone()
}
pub async fn list_tools(&self) -> crate::Result<Vec<McpToolDef>> {
{
let cache = self.cached_tools.read().await;
if let Some(ref tools) = *cache {
return Ok(tools.clone());
}
}
self.refresh_tools().await
}
pub async fn refresh_tools(&self) -> crate::Result<Vec<McpToolDef>> {
let svc = self.service.read().await;
let tools = svc.peer().list_all_tools().await.map_err(|e| {
crate::error::AgentError::runtime(format!(
"Failed to list tools from MCP server '{}': {e}",
self.name
))
})?;
debug!(
server = %self.name,
count = tools.len(),
"Discovered MCP tools",
);
let mut cache = self.cached_tools.write().await;
*cache = Some(tools.clone());
Ok(tools)
}
pub async fn call_tool(
&self,
name: impl Into<Cow<'static, str>>,
arguments: Value,
) -> Result<String, ToolError> {
let tool_name = name.into();
let args_obj = match arguments {
Value::Object(map) => Some(map),
Value::Null => None,
other => {
return Err(ToolError::InvalidArguments(format!(
"MCP tool arguments must be a JSON object, got: {other}"
)));
}
};
let svc = self.service.read().await;
let result: CallToolResult = svc
.peer()
.call_tool(CallToolRequestParams {
meta: None,
name: tool_name.clone(),
arguments: args_obj,
task: None,
})
.await
.map_err(|e| {
ToolError::Execution(format!("MCP tool '{tool_name}' call failed: {e}"))
})?;
if result.is_error == Some(true) {
let text = extract_text_from_contents(&result.content);
return Err(ToolError::Execution(format!(
"MCP tool '{tool_name}' returned error: {text}"
)));
}
Ok(extract_text_from_contents(&result.content))
}
pub async fn tools(&self) -> crate::Result<Vec<BoxedTool>> {
let mcp_tools = self.list_tools().await?;
let server = Arc::new(self.clone_inner());
Ok(mcp_tools
.into_iter()
.map(|t| -> BoxedTool {
Box::new(McpTool {
server: Arc::clone(&server),
tool_def: t,
})
})
.collect())
}
pub async fn close(&self) -> crate::Result<()> {
let mut svc = self.service.write().await;
svc.close().await.map_err(|e| {
crate::error::AgentError::runtime(format!(
"Failed to close MCP server '{}': {e}",
self.name
))
})?;
info!(server = %self.name, "MCP server connection closed");
Ok(())
}
fn clone_inner(&self) -> Self {
Self {
service: Arc::clone(&self.service),
cached_tools: Arc::clone(&self.cached_tools),
name: self.name.clone(),
}
}
}
/// Adapter exposing one MCP server tool through the crate's `DynTool` trait.
struct McpTool {
    /// Tool definition as reported by the server (name, description, schema).
    tool_def: McpToolDef,
    /// Server that calls to this tool are forwarded to.
    server: Arc<McpServer>,
}
// Shows the tool name and the owning server's name; the full schema and the
// server handle are not printed.
impl std::fmt::Debug for McpTool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { server, tool_def } = self;
        f.debug_struct("McpTool")
            .field("name", &tool_def.name)
            .field("server", &server.name)
            .finish()
    }
}
#[async_trait]
impl DynTool for McpTool {
    /// The tool's name as reported by the server.
    fn name(&self) -> &str {
        &self.tool_def.name
    }

    /// The server-provided description, or `"MCP tool"` when none was given.
    fn description(&self) -> String {
        match self.tool_def.description.as_deref() {
            Some(text) => text.to_owned(),
            None => "MCP tool".to_owned(),
        }
    }

    /// Builds a `ToolDefinition` from the tool's name, description, and
    /// input JSON schema.
    fn definition(&self) -> ToolDefinition {
        let schema = self.tool_def.input_schema.as_ref().clone();
        ToolDefinition::new(
            self.tool_def.name.as_ref(),
            self.description(),
            Value::Object(schema),
        )
    }

    /// Forwards `args` to the server and returns the tool's text output as a
    /// JSON string value.
    async fn call_json(&self, args: Value) -> Result<Value, ToolError> {
        let tool_name = self.tool_def.name.clone();
        self.server
            .call_tool(tool_name, args)
            .await
            .map(Value::String)
    }
}
/// Flattens MCP result content into a single string.
///
/// If there is at least one text item, all text items are joined with `\n`
/// in order and any non-text items are skipped. If there are no text items,
/// the whole content list is serialized to JSON so the data is not silently
/// dropped. A serialization failure yields an empty string.
fn extract_text_from_contents(contents: &[Content]) -> String {
    // Decide on "are there text items", not "is the output non-empty": a tool
    // that legitimately returns empty text must yield "", not the raw JSON of
    // its content list. `join` also keeps separators between empty items.
    let texts: Vec<&str> = contents
        .iter()
        .filter_map(|content| content.as_text())
        .map(|text| text.text.as_str())
        .collect();
    if texts.is_empty() {
        serde_json::to_string(contents).unwrap_or_default()
    } else {
        texts.join("\n")
    }
}