use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use rig_compose::registry::{KernelError, ToolRegistry};
use rig_compose::tool::{Tool, ToolSchema};
/// A connection to an MCP tool server.
///
/// Implementors report the server's tool catalogue and forward invocations to it.
/// The `Send + Sync` bound lets a transport be shared as `Arc<dyn McpTransport>`.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Identifier of the server this transport talks to (e.g. `loopback://test`).
    fn endpoint(&self) -> &str;

    /// Fetches the schema of every tool the server advertises.
    ///
    /// # Errors
    /// Returns a [`KernelError`] if the catalogue cannot be obtained.
    async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError>;

    /// Invokes the tool called `name` on the server with the JSON `args`.
    ///
    /// # Errors
    /// Returns a [`KernelError`] produced by the transport or by the tool itself.
    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError>;
}
/// A [`Tool`] whose implementation is reached through an [`McpTransport`].
///
/// Every invocation is forwarded to the transport, so once registered it is used
/// exactly like a locally implemented tool.
pub struct McpTool {
    /// Connection used to reach the server hosting this tool.
    transport: Arc<dyn McpTransport>,
    /// Schema describing this tool; its `name` is what gets sent on invocation.
    schema: ToolSchema,
}
impl McpTool {
    /// Wraps a single remote tool described by `schema`, reached via `transport`.
    pub fn new(transport: Arc<dyn McpTransport>, schema: ToolSchema) -> Self {
        Self { transport, schema }
    }

    /// Discovers every tool exposed by `transport` and wraps each one as an
    /// `Arc<dyn Tool>`. All returned tools share the same transport handle.
    ///
    /// # Errors
    /// Propagates the [`KernelError`] from [`McpTransport::list_tools`].
    pub async fn from_transport(
        transport: Arc<dyn McpTransport>,
    ) -> Result<Vec<Arc<dyn Tool>>, KernelError> {
        let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
        for schema in transport.list_tools().await? {
            // Only the `Arc` is cloned here (a refcount bump), not the transport itself.
            tools.push(Arc::new(Self::new(Arc::clone(&transport), schema)));
        }
        Ok(tools)
    }
}
#[async_trait]
impl Tool for McpTool {
fn schema(&self) -> ToolSchema {
self.schema.clone()
}
fn name(&self) -> rig_compose::tool::ToolName {
self.schema.name.clone()
}
async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
self.transport.call_tool(&self.schema.name, args).await
}
}
/// An in-process [`McpTransport`] that serves the tools of a local [`ToolRegistry`].
///
/// No I/O is involved: listing and invocation are delegated straight to the registry.
pub struct LoopbackTransport {
    /// Value reported by [`McpTransport::endpoint`].
    endpoint: String,
    /// Registry whose tools this transport exposes.
    registry: ToolRegistry,
}
impl LoopbackTransport {
    /// Creates a loopback transport that reports `endpoint` and serves `registry`.
    pub fn new(endpoint: impl Into<String>, registry: ToolRegistry) -> Self {
        let endpoint: String = endpoint.into();
        Self { endpoint, registry }
    }
}
#[async_trait]
impl McpTransport for LoopbackTransport {
    /// Returns the endpoint string given to [`LoopbackTransport::new`].
    fn endpoint(&self) -> &str {
        self.endpoint.as_str()
    }

    /// Reports the schemas currently held by the wrapped registry. Never fails.
    async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
        let schemas = self.registry.schemas();
        Ok(schemas)
    }

    /// Invokes `name` directly on the wrapped registry and returns its result unchanged.
    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
        let registry = &self.registry;
        registry.invoke(name, args).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rig_compose::tool::LocalTool;
    use serde_json::json;

    /// Builds a registry holding one local `math.add` tool. The tool sums the
    /// fields `a` and `b`; a field that is absent or not an `i64` counts as 0.
    fn make_registry() -> ToolRegistry {
        let add_schema = ToolSchema {
            name: "math.add".into(),
            description: "add two ints".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "integer"}),
        };
        let add_tool = LocalTool::new(add_schema, |args| async move {
            let lhs = args["a"].as_i64().unwrap_or(0);
            let rhs = args["b"].as_i64().unwrap_or(0);
            Ok(json!(lhs + rhs))
        });
        let registry = ToolRegistry::new();
        registry.register(Arc::new(add_tool));
        registry
    }

    /// Wraps a freshly built `make_registry()` in a loopback transport.
    fn loopback() -> Arc<dyn McpTransport> {
        Arc::new(LoopbackTransport::new("loopback://test", make_registry()))
    }

    #[tokio::test]
    async fn loopback_transport_round_trip() {
        let transport = loopback();

        let schemas = transport.list_tools().await.unwrap();
        assert_eq!(schemas.len(), 1);
        assert_eq!(schemas[0].name, "math.add");

        let sum = transport
            .call_tool("math.add", json!({"a": 2, "b": 3}))
            .await
            .unwrap();
        assert_eq!(sum, json!(5));
    }

    #[tokio::test]
    async fn mcp_tool_indistinguishable_from_local() {
        let remote_tools = McpTool::from_transport(loopback()).await.unwrap();

        let client = ToolRegistry::new();
        for remote in remote_tools {
            client.register(remote);
        }

        let answer = client
            .invoke("math.add", json!({"a": 10, "b": 32}))
            .await
            .unwrap();
        assert_eq!(answer, json!(42));
    }
}