use rmcp::{
ServiceExt,
model::CallToolRequestParams,
service::{Peer, RoleClient, RunningService},
transport::{StreamableHttpClientTransport, TokioChildProcess},
};
use serde_json::Value;
use std::sync::Arc;
use std::time::Duration;
use tokio::process::Command;
use tracing::{error, instrument};
use crate::raw::shared::{FunctionDefinition, ToolDefinition as RawTool, ToolKind as ToolType};
use crate::tool_trait::Tool;
/// Errors reported while connecting to or talking with an MCP server.
///
/// Apart from [`McpError::Spawn`], the variants carry the underlying error
/// rendered to a `String` rather than the original error value.
#[derive(Debug, thiserror::Error)]
pub enum McpError {
/// The transport for a child-process server could not be created
/// (converted from `std::io::Error` via `?`).
#[error("failed to spawn MCP server process: {0}")]
Spawn(#[from] std::io::Error),
/// `serve` on the chosen transport failed while setting up the client.
#[error("MCP client initialisation failed: {0}")]
Init(String),
/// The server's tool list could not be retrieved.
#[error("failed to list tools from MCP server: {0}")]
ListTools(String),
/// A tool invocation failed.
// NOTE(review): no code visible in this file constructs this variant — confirm it is still used.
#[error("MCP tool call failed: {0}")]
Call(String),
}
/// A [`Tool`] implementation backed by the tools of a connected MCP server.
///
/// Cloning is cheap with respect to the connection: the peer and the running
/// service are held behind `Arc`s, while the tool list is copied.
#[derive(Clone)]
pub struct McpTool {
/// Tool definitions fetched from the server when the value was constructed.
tools: Vec<RawTool>,
/// Client peer used to issue `call_tool` requests.
peer: Arc<Peer<RoleClient>>,
/// Type-erased handle to the running service; never read in this file.
// NOTE(review): presumably held only to keep the service alive for as long as any clone exists — confirm.
_service: Arc<dyn std::any::Any + Send + Sync>,
/// Optional prefix; when set, tool names are exposed as `<name>__<tool>`.
name: Option<String>,
/// Upper bound on the length (`String::len`) of the serialised final result; `None` disables it.
max_output_chars: Option<usize>,
/// Upper bound on the number of content items kept from a tool result; `None` disables it.
max_content_items: Option<usize>,
/// Optional deadline applied to each tool call; `None` waits indefinitely.
call_timeout: Option<Duration>,
}
/// Default limit on the serialised final result length, applied to newly constructed tools.
const DEFAULT_MAX_OUTPUT_CHARS: usize = 8_000;
/// Default limit on the number of content items kept per tool result, applied to newly constructed tools.
const DEFAULT_MAX_CONTENT_ITEMS: usize = 50;
impl McpTool {
/// Launches `program` with `args` through [`TokioChildProcess`] and connects
/// an MCP client over that transport.
///
/// # Errors
///
/// * [`McpError::Spawn`] if the child-process transport cannot be created.
/// * [`McpError::Init`] if serving the client over the transport fails.
/// * [`McpError::ListTools`] if the initial tool listing fails.
#[instrument(skip(args), fields(program = program.as_ref()))]
pub async fn stdio(
    program: impl AsRef<str>,
    args: &[impl AsRef<str>],
) -> Result<Self, McpError> {
    let mut command = Command::new(program.as_ref());
    command.args(args.iter().map(|arg| arg.as_ref()));

    let transport = TokioChildProcess::new(command)?;
    let service = ()
        .serve(transport)
        .await
        .map_err(|err| McpError::Init(err.to_string()))?;

    Self::from_service(service).await
}
/// Same as [`McpTool::stdio`], but sets the child's working directory to
/// `cwd` before the transport is created.
///
/// # Errors
///
/// * [`McpError::Spawn`] if the child-process transport cannot be created.
/// * [`McpError::Init`] if serving the client over the transport fails.
/// * [`McpError::ListTools`] if the initial tool listing fails.
#[instrument(skip(args, cwd), fields(program = program.as_ref()))]
pub async fn stdio_with_cwd(
    program: impl AsRef<str>,
    args: &[impl AsRef<str>],
    cwd: impl AsRef<std::path::Path>,
) -> Result<Self, McpError> {
    let mut command = Command::new(program.as_ref());
    command.args(args.iter().map(|arg| arg.as_ref()));
    command.current_dir(cwd);

    let transport = TokioChildProcess::new(command)?;
    let service = ()
        .serve(transport)
        .await
        .map_err(|err| McpError::Init(err.to_string()))?;

    Self::from_service(service).await
}
/// Connects an MCP client to the server at `url` using
/// [`StreamableHttpClientTransport`].
///
/// # Errors
///
/// * [`McpError::Init`] if serving the client over the transport fails.
/// * [`McpError::ListTools`] if the initial tool listing fails.
#[instrument(fields(url = url.as_ref()))]
pub async fn http(url: impl AsRef<str>) -> Result<Self, McpError> {
    let transport = StreamableHttpClientTransport::from_uri(url.as_ref());
    let service = ()
        .serve(transport)
        .await
        .map_err(|err| McpError::Init(err.to_string()))?;

    Self::from_service(service).await
}
/// Connects an MCP client over an arbitrary caller-supplied transport.
///
/// Only compiled when the `mcp` feature is enabled.
///
/// # Errors
///
/// * [`McpError::Init`] if serving the client over the transport fails.
/// * [`McpError::ListTools`] if the initial tool listing fails.
#[cfg(feature = "mcp")]
pub async fn from_transport<T, E, A>(transport: T) -> Result<Self, McpError>
where
    T: rmcp::transport::IntoTransport<RoleClient, E, A>,
    E: std::error::Error + Send + Sync + 'static,
{
    use rmcp::ServiceExt;

    let service = ()
        .serve(transport)
        .await
        .map_err(|err| McpError::Init(err.to_string()))?;

    Self::from_service(service).await
}
async fn from_service<S>(running: RunningService<RoleClient, S>) -> Result<Self, McpError>
where
S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
let peer = running.peer().clone();
let tools = Self::fetch_tools(&peer).await?;
Ok(Self {
tools,
peer: Arc::new(peer),
_service: Arc::new(running),
name: None,
max_output_chars: Some(DEFAULT_MAX_OUTPUT_CHARS),
max_content_items: Some(DEFAULT_MAX_CONTENT_ITEMS),
call_timeout: None,
})
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn with_max_output_chars(mut self, max: usize) -> Self {
self.max_output_chars = Some(max);
self
}
pub fn with_max_content_items(mut self, max: usize) -> Self {
self.max_content_items = Some(max);
self
}
pub fn with_output_limits(mut self, max_chars: usize, max_items: usize) -> Self {
self.max_output_chars = Some(max_chars);
self.max_content_items = Some(max_items);
self
}
pub fn without_output_limits(mut self) -> Self {
self.max_output_chars = None;
self.max_content_items = None;
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.call_timeout = Some(timeout);
self
}
pub fn without_timeout(mut self) -> Self {
self.call_timeout = None;
self
}
pub fn output_limits(&self) -> (Option<usize>, Option<usize>) {
(self.max_output_chars, self.max_content_items)
}
/// Lists every tool the server offers and converts each one into a
/// function-style [`RawTool`].
///
/// The server's input schema becomes the function's `parameters` object and
/// `strict` is left unset.
///
/// # Errors
///
/// Returns [`McpError::ListTools`] if the listing request fails.
async fn fetch_tools(peer: &Peer<RoleClient>) -> Result<Vec<RawTool>, McpError> {
    let listed = match peer.list_all_tools().await {
        Ok(listed) => listed,
        Err(err) => return Err(McpError::ListTools(err.to_string())),
    };

    let mut converted = Vec::with_capacity(listed.len());
    for remote in listed {
        converted.push(RawTool {
            kind: ToolType::Function,
            function: FunctionDefinition {
                name: remote.name.to_string(),
                description: remote.description.as_deref().map(str::to_string),
                parameters: Value::Object(remote.input_schema.as_ref().clone()),
                strict: None,
            },
        });
    }
    Ok(converted)
}
}
#[async_trait::async_trait]
impl Tool for McpTool {
fn raw_tools(&self) -> Vec<RawTool> {
match &self.name {
None => self.tools.clone(),
Some(prefix) => self
.tools
.iter()
.map(|t| {
let mut t = t.clone();
t.function.name = format!("{}__{}", prefix, t.function.name);
t
})
.collect(),
}
}
async fn call(
&self,
name: &str,
args: Value,
) -> futures::stream::BoxStream<'static, crate::tool_trait::ToolOutput> {
let real_name = match &self.name {
Some(prefix) => {
let pfx = format!("{}__", prefix);
name.strip_prefix(&pfx).unwrap_or(name)
}
None => name,
};
let arguments = args.as_object().cloned().map(|m| m.into_iter().collect());
let owned_name: std::borrow::Cow<'static, str> = real_name.to_string().into();
let params = match arguments {
Some(args) => CallToolRequestParams::new(owned_name).with_arguments(args),
None => CallToolRequestParams::new(owned_name),
};
let timeout_opt = self.call_timeout;
let max_content_items = self.max_content_items;
let max_output_chars = self.max_output_chars;
let name_str = name.to_string();
let peer = self.peer.clone();
let join = tokio::spawn(async move {
if let Some(timeout) = timeout_opt {
match tokio::time::timeout(timeout, peer.call_tool(params)).await {
Ok(res) => res,
Err(_) => Err(rmcp::ServiceError::Timeout { timeout }),
}
} else {
peer.call_tool(params).await
}
});
let result_stream = async_stream::stream! {
let result = match join.await {
Ok(r) => r,
Err(e) => {
error!(tool = %name_str, "MCP tool call task panicked: {e}");
yield crate::tool_trait::ToolOutput::Result(vec![crate::request::Content::text(format!("{{\"error\":\"{e}\"}}"))]);
return;
}
};
match result {
Ok(result) => {
let mut contents: Vec<Value> = result
.content
.into_iter()
.filter_map(|item| serde_json::to_value(item).ok())
.collect();
if let Some(max_items) = max_content_items
&& contents.len() > max_items {
contents.truncate(max_items);
}
if contents.len() > 1 {
let last = contents.pop().unwrap();
for item in contents {
let text = serde_json::to_string(&item).unwrap_or_default();
yield crate::tool_trait::ToolOutput::Progress(text);
}
let result_value = last;
let json_string = serde_json::to_string(&result_value).unwrap_or_default();
let text = if let Some(max_chars) = max_output_chars {
if json_string.len() > max_chars {
let mut limit = max_chars.saturating_sub(40);
limit = json_string.floor_char_boundary(limit);
format!("{}...<truncated {} chars>", &json_string[..limit], json_string.len())
} else {
json_string
}
} else {
json_string
};
yield crate::tool_trait::ToolOutput::Result(vec![crate::request::Content::text(text)]);
} else {
let result_value = match contents.len() {
0 => serde_json::json!({ "result": null }),
_ => contents.into_iter().next().unwrap(),
};
let json_string = serde_json::to_string(&result_value).unwrap_or_default();
let text = if let Some(max_chars) = max_output_chars {
if json_string.len() > max_chars {
let mut limit = max_chars.saturating_sub(40);
limit = json_string.floor_char_boundary(limit);
format!("{}...<truncated {} chars>", &json_string[..limit], json_string.len())
} else {
json_string
}
} else {
json_string
};
yield crate::tool_trait::ToolOutput::Result(vec![crate::request::Content::text(text)]);
}
}
Err(e) => {
error!(tool = %name_str, error = %e, "MCP tool call failed");
yield crate::tool_trait::ToolOutput::Result(vec![crate::request::Content::text(format!("{{\"error\":\"{e}\"}}"))]);
}
}
};
use futures::StreamExt;
result_stream.boxed()
}
}