use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{Value, json};
use tokio::process::Command;
use tracing::{Instrument, field};
use rmcp::model::{
CallToolRequestParams, CallToolResult, Content, Implementation, ListToolsResult,
PaginatedRequestParams, ProtocolVersion, ServerCapabilities, ServerInfo, Tool as RmcpTool,
};
use rmcp::service::{Peer, RequestContext, RoleClient, RoleServer, RunningService, ServiceExt};
use rmcp::transport::{ConfigureCommandExt, TokioChildProcess, stdio as rmcp_stdio};
use rmcp::{ErrorData as McpError, ServerHandler};
use crate::transport::McpTransport;
use rig_compose::registry::{KernelError, ToolRegistry};
use rig_compose::tool::ToolSchema;
#[derive(Clone)]
struct RegistryServer {
registry: Arc<ToolRegistry>,
info: ServerInfo,
}
impl RegistryServer {
fn new(registry: Arc<ToolRegistry>) -> Self {
#[allow(clippy::field_reassign_with_default)]
let server_info = {
let mut s = Implementation::default();
s.name = env!("CARGO_PKG_NAME").to_string();
s.version = env!("CARGO_PKG_VERSION").to_string();
s
};
#[allow(clippy::field_reassign_with_default)]
let info = {
let mut i = ServerInfo::default();
i.protocol_version = ProtocolVersion::default();
i.capabilities = ServerCapabilities::builder().enable_tools().build();
i.server_info = server_info;
i
};
Self { registry, info }
}
}
fn schema_to_rmcp_tool(s: ToolSchema) -> RmcpTool {
let input_obj = match s.args_schema {
Value::Object(map) => map,
_ => Default::default(),
};
let output_obj = match s.result_schema {
Value::Object(map) if !map.is_empty() => Some(Arc::new(map)),
_ => None,
};
#[allow(clippy::field_reassign_with_default)]
{
let mut tool = RmcpTool::default();
tool.name = s.name.into();
tool.description = Some(s.description.into());
tool.input_schema = Arc::new(input_obj);
tool.output_schema = output_obj;
tool
}
}
impl ServerHandler for RegistryServer {
fn get_info(&self) -> ServerInfo {
self.info.clone()
}
async fn list_tools(
&self,
_request: Option<PaginatedRequestParams>,
_context: RequestContext<RoleServer>,
) -> Result<ListToolsResult, McpError> {
let span = tracing::info_span!(
"mcp.stdio_server.list_tools",
mcp.transport = "stdio_server",
mcp.tool_count = field::Empty,
);
let span_for_record = span.clone();
async move {
let tools: Vec<_> = self
.registry
.schemas()
.into_iter()
.map(schema_to_rmcp_tool)
.collect();
span_for_record.record("mcp.tool_count", tools.len() as u64);
Ok(ListToolsResult {
tools,
next_cursor: None,
meta: None,
})
}
.instrument(span)
.await
}
async fn call_tool(
&self,
request: CallToolRequestParams,
_context: RequestContext<RoleServer>,
) -> Result<CallToolResult, McpError> {
let name = request.name.to_string();
let span = tracing::info_span!(
"mcp.stdio_server.call_tool",
mcp.transport = "stdio_server",
mcp.tool_name = %name,
mcp.error = field::Empty,
);
let span_for_record = span.clone();
async move {
let args = request
.arguments
.map(Value::Object)
.unwrap_or_else(|| json!({}));
match self.registry.invoke(&name, args).await {
Ok(value) => Ok(CallToolResult::structured(value)),
Err(e) => {
span_for_record.record("mcp.error", e.to_string());
Ok(CallToolResult::error(vec![Content::text(e.to_string())]))
}
}
}
.instrument(span)
.await
}
}
pub async fn serve_stdio(registry: ToolRegistry) -> Result<(), KernelError> {
let span = tracing::info_span!(
"mcp.stdio.serve",
mcp.transport = "stdio",
mcp.error = field::Empty,
);
let span_for_record = span.clone();
async move {
let server = RegistryServer::new(Arc::new(registry));
let service = server.serve(rmcp_stdio()).await.map_err(|e| {
let error = KernelError::ToolFailed(format!("mcp.serve: {e}"));
span_for_record.record("mcp.error", error.to_string());
error
})?;
service.waiting().await.map_err(|e| {
let error = KernelError::ToolFailed(format!("mcp.serve: {e}"));
span_for_record.record("mcp.error", error.to_string());
error
})?;
Ok(())
}
.instrument(span)
.await
}
pub struct StdioTransport {
endpoint: String,
peer: Peer<RoleClient>,
_service: Arc<RunningService<RoleClient, ()>>,
}
impl StdioTransport {
pub async fn spawn(
endpoint: impl Into<String>,
program: impl AsRef<std::ffi::OsStr>,
args: &[&str],
) -> Result<Self, KernelError> {
let endpoint = endpoint.into();
let program = program.as_ref().to_owned();
let program_name = program.to_string_lossy().to_string();
let argv: Vec<String> = args.iter().map(|s| (*s).to_string()).collect();
let span = tracing::info_span!(
"mcp.stdio.spawn",
mcp.transport = "stdio",
mcp.endpoint = %endpoint,
mcp.program = %program_name,
mcp.arg_count = argv.len() as u64,
mcp.error = field::Empty,
);
let span_for_record = span.clone();
async move {
let cmd = Command::new(&program).configure(|c| {
c.args(&argv);
});
let transport = TokioChildProcess::new(cmd).map_err(|e| {
let error = KernelError::ToolFailed(format!("mcp.spawn: {e}"));
span_for_record.record("mcp.error", error.to_string());
error
})?;
let service = ().serve(transport).await.map_err(|e| {
let error = KernelError::ToolFailed(format!("mcp.connect: {e}"));
span_for_record.record("mcp.error", error.to_string());
error
})?;
let peer = service.peer().clone();
Ok(Self {
endpoint,
peer,
_service: Arc::new(service),
})
}
.instrument(span)
.await
}
}
#[async_trait]
impl McpTransport for StdioTransport {
fn endpoint(&self) -> &str {
&self.endpoint
}
async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
let span = tracing::info_span!(
"mcp.stdio.list_tools",
mcp.transport = "stdio",
mcp.endpoint = %self.endpoint,
mcp.tool_count = field::Empty,
mcp.error = field::Empty,
);
let span_for_record = span.clone();
async move {
let tools = self.peer.list_all_tools().await.map_err(|e| {
let error = KernelError::ToolFailed(format!("tools/list: {e}"));
span_for_record.record("mcp.error", error.to_string());
error
})?;
span_for_record.record("mcp.tool_count", tools.len() as u64);
Ok(tools.into_iter().map(rmcp_tool_to_schema).collect())
}
.instrument(span)
.await
}
async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
let span = tracing::info_span!(
"mcp.stdio.call_tool",
mcp.transport = "stdio",
mcp.endpoint = %self.endpoint,
mcp.tool_name = %name,
mcp.error = field::Empty,
);
let span_for_record = span.clone();
async move {
let arguments = match args {
Value::Object(map) => Some(map),
Value::Null => None,
other => {
let error = KernelError::InvalidArgument(format!(
"tools/call requires an object or null arguments, got {other}"
));
span_for_record.record("mcp.error", error.to_string());
return Err(error);
}
};
let params = {
#[allow(clippy::field_reassign_with_default)]
let mut p = CallToolRequestParams::default();
p.name = name.to_string().into();
p.arguments = arguments;
p
};
let result = self.peer.call_tool(params).await.map_err(|e| {
let error = KernelError::ToolFailed(format!("tools/call: {e}"));
span_for_record.record("mcp.error", error.to_string());
error
})?;
if result.is_error.unwrap_or(false) {
let msg = result
.content
.iter()
.find_map(|c| c.as_text().map(|t| t.text.clone()))
.unwrap_or_else(|| "tool returned error".to_string());
let error = KernelError::ToolFailed(msg);
span_for_record.record("mcp.error", error.to_string());
return Err(error);
}
if let Some(v) = result.structured_content {
return Ok(v);
}
if let Some(text) = result
.content
.iter()
.find_map(|c| c.as_text().map(|t| t.text.clone()))
{
if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
return Ok(parsed);
}
return Ok(Value::String(text));
}
Ok(Value::Null)
}
.instrument(span)
.await
}
}
fn rmcp_tool_to_schema(t: RmcpTool) -> ToolSchema {
ToolSchema {
name: t.name.to_string(),
description: t.description.map(|d| d.to_string()).unwrap_or_default(),
args_schema: Value::Object((*t.input_schema).clone()),
result_schema: t
.output_schema
.map(|s| Value::Object((*s).clone()))
.unwrap_or(Value::Null),
}
}
#[cfg(test)]
mod tests {
use super::*;
use rig_compose::tool::LocalTool;
use serde_json::json;
use std::sync::Arc;
fn echo_registry() -> ToolRegistry {
let reg = ToolRegistry::new();
reg.register(Arc::new(LocalTool::new(
ToolSchema {
name: "math.mul".into(),
description: "multiply".into(),
args_schema: json!({"type": "object"}),
result_schema: json!({"type": "integer"}),
},
|args: Value| async move {
let a = args["a"].as_i64().unwrap_or(0);
let b = args["b"].as_i64().unwrap_or(0);
Ok(json!(a * b))
},
)));
reg
}
#[tokio::test]
async fn registry_server_round_trip_via_tool_trait() {
let registry = echo_registry();
let tool = registry.get("math.mul").unwrap();
let out = tool.invoke(json!({"a": 6, "b": 7})).await.unwrap();
assert_eq!(out, json!(42));
}
}