1use std::sync::Arc;
14
15use async_trait::async_trait;
16use serde_json::Value;
17
18use rig_compose::registry::{KernelError, ToolRegistry};
19use rig_compose::tool::{Tool, ToolSchema};
20
21#[async_trait]
25pub trait McpTransport: Send + Sync {
26 fn endpoint(&self) -> &str;
29
30 async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError>;
33
34 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError>;
37}
38
39pub struct McpTool {
43 transport: Arc<dyn McpTransport>,
44 schema: ToolSchema,
45}
46
47impl McpTool {
48 pub fn new(transport: Arc<dyn McpTransport>, schema: ToolSchema) -> Self {
49 Self { transport, schema }
50 }
51
52 pub async fn from_transport(
57 transport: Arc<dyn McpTransport>,
58 ) -> Result<Vec<Arc<dyn Tool>>, KernelError> {
59 let schemas = transport.list_tools().await?;
60 Ok(schemas
61 .into_iter()
62 .map(|schema| {
63 let t: Arc<dyn Tool> = Arc::new(McpTool {
64 transport: transport.clone(),
65 schema,
66 });
67 t
68 })
69 .collect())
70 }
71}
72
73#[async_trait]
74impl Tool for McpTool {
75 fn schema(&self) -> ToolSchema {
76 self.schema.clone()
77 }
78
79 fn name(&self) -> rig_compose::tool::ToolName {
80 self.schema.name.clone()
81 }
82
83 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
84 self.transport.call_tool(&self.schema.name, args).await
85 }
86}
87
88pub struct LoopbackTransport {
100 endpoint: String,
101 registry: ToolRegistry,
102}
103
104impl LoopbackTransport {
105 pub fn new(endpoint: impl Into<String>, registry: ToolRegistry) -> Self {
106 Self {
107 endpoint: endpoint.into(),
108 registry,
109 }
110 }
111}
112
113#[async_trait]
114impl McpTransport for LoopbackTransport {
115 fn endpoint(&self) -> &str {
116 &self.endpoint
117 }
118
119 async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
120 Ok(self.registry.schemas())
124 }
125
126 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
127 self.registry.invoke(name, args).await
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use rig_compose::tool::LocalTool;
    use serde_json::json;

    /// Builds a registry holding one `math.add` tool that sums the integer args `a` and `b`.
    /// A missing or non-integer argument counts as 0.
    fn make_registry() -> ToolRegistry {
        let registry = ToolRegistry::new();
        let schema = ToolSchema {
            name: "math.add".into(),
            description: "add two ints".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "integer"}),
        };
        let adder = LocalTool::new(schema, |args| async move {
            let lhs = args["a"].as_i64().unwrap_or(0);
            let rhs = args["b"].as_i64().unwrap_or(0);
            Ok(json!(lhs + rhs))
        });
        registry.register(Arc::new(adder));
        registry
    }

    #[tokio::test]
    async fn loopback_transport_round_trip() {
        let transport: Arc<dyn McpTransport> =
            Arc::new(LoopbackTransport::new("loopback://test", make_registry()));

        let listed = transport.list_tools().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "math.add");

        let sum = transport
            .call_tool("math.add", json!({"a": 2, "b": 3}))
            .await
            .unwrap();
        assert_eq!(sum, json!(5));
    }

    #[tokio::test]
    async fn mcp_tool_indistinguishable_from_local() {
        // Server side: a plain registry exposed through the loopback transport.
        let transport: Arc<dyn McpTransport> =
            Arc::new(LoopbackTransport::new("loopback://test", make_registry()));

        // Client side: the remote tools are registered exactly like local ones.
        let client = ToolRegistry::new();
        let remote_tools = McpTool::from_transport(transport).await.unwrap();
        for remote in remote_tools {
            client.register(remote);
        }

        let answer = client
            .invoke("math.add", json!({"a": 10, "b": 32}))
            .await
            .unwrap();
        assert_eq!(answer, json!(42));
    }
}