1use std::sync::Arc;
19
20use async_trait::async_trait;
21use serde_json::{Value, json};
22use tokio::process::Command;
23
24use rmcp::model::{
25 CallToolRequestParams, CallToolResult, Content, Implementation, ListToolsResult,
26 PaginatedRequestParams, ProtocolVersion, ServerCapabilities, ServerInfo, Tool as RmcpTool,
27};
28use rmcp::service::{Peer, RequestContext, RoleClient, RoleServer, RunningService, ServiceExt};
29use rmcp::transport::{ConfigureCommandExt, TokioChildProcess, stdio as rmcp_stdio};
30use rmcp::{ErrorData as McpError, ServerHandler};
31
32use crate::transport::McpTransport;
33use rig_compose::registry::{KernelError, ToolRegistry};
34use rig_compose::tool::ToolSchema;
35
36#[derive(Clone)]
46struct RegistryServer {
47 registry: Arc<ToolRegistry>,
48 info: ServerInfo,
49}
50
51impl RegistryServer {
52 fn new(registry: Arc<ToolRegistry>) -> Self {
53 #[allow(clippy::field_reassign_with_default)]
57 let server_info = {
58 let mut s = Implementation::default();
59 s.name = env!("CARGO_PKG_NAME").to_string();
60 s.version = env!("CARGO_PKG_VERSION").to_string();
61 s
62 };
63 #[allow(clippy::field_reassign_with_default)]
64 let info = {
65 let mut i = ServerInfo::default();
66 i.protocol_version = ProtocolVersion::default();
67 i.capabilities = ServerCapabilities::builder().enable_tools().build();
68 i.server_info = server_info;
69 i
70 };
71 Self { registry, info }
72 }
73}
74
75fn schema_to_rmcp_tool(s: ToolSchema) -> RmcpTool {
76 let input_obj = match s.args_schema {
77 Value::Object(map) => map,
78 _ => Default::default(),
79 };
80 let output_obj = match s.result_schema {
81 Value::Object(map) if !map.is_empty() => Some(Arc::new(map)),
82 _ => None,
83 };
84 #[allow(clippy::field_reassign_with_default)]
85 {
86 let mut tool = RmcpTool::default();
87 tool.name = s.name.into();
88 tool.description = Some(s.description.into());
89 tool.input_schema = Arc::new(input_obj);
90 tool.output_schema = output_obj;
91 tool
92 }
93}
94
95impl ServerHandler for RegistryServer {
96 fn get_info(&self) -> ServerInfo {
97 self.info.clone()
98 }
99
100 async fn list_tools(
101 &self,
102 _request: Option<PaginatedRequestParams>,
103 _context: RequestContext<RoleServer>,
104 ) -> Result<ListToolsResult, McpError> {
105 let tools = self
106 .registry
107 .schemas()
108 .into_iter()
109 .map(schema_to_rmcp_tool)
110 .collect();
111 Ok(ListToolsResult {
112 tools,
113 next_cursor: None,
114 meta: None,
115 })
116 }
117
118 async fn call_tool(
119 &self,
120 request: CallToolRequestParams,
121 _context: RequestContext<RoleServer>,
122 ) -> Result<CallToolResult, McpError> {
123 let name = request.name.to_string();
124 let args = request
125 .arguments
126 .map(Value::Object)
127 .unwrap_or_else(|| json!({}));
128 match self.registry.invoke(&name, args).await {
129 Ok(value) => Ok(CallToolResult::structured(value)),
130 Err(e) => Ok(CallToolResult::error(vec![Content::text(e.to_string())])),
131 }
132 }
133}
134
135pub async fn serve_stdio(registry: ToolRegistry) -> Result<(), KernelError> {
138 let server = RegistryServer::new(Arc::new(registry));
139 let service = server
140 .serve(rmcp_stdio())
141 .await
142 .map_err(|e| KernelError::ToolFailed(format!("mcp.serve: {e}")))?;
143 service
144 .waiting()
145 .await
146 .map_err(|e| KernelError::ToolFailed(format!("mcp.serve: {e}")))?;
147 Ok(())
148}
149
/// MCP client transport that talks to a tool server running as a child process,
/// connected over the child's stdio.
///
/// Construct it with [`StdioTransport::spawn`].
pub struct StdioTransport {
    /// Caller-supplied label, reported by `McpTransport::endpoint`.
    endpoint: String,
    /// Client handle used to send `tools/list` and `tools/call` requests.
    peer: Peer<RoleClient>,
    /// Owns the running rmcp client service. It is never read here, only held.
    /// NOTE(review): presumably dropping it shuts down the connection and child;
    /// confirm against rmcp's `RunningService` docs before changing its lifetime.
    _service: Arc<RunningService<RoleClient, ()>>,
}
169
170impl StdioTransport {
171 pub async fn spawn(
176 endpoint: impl Into<String>,
177 program: impl AsRef<std::ffi::OsStr>,
178 args: &[&str],
179 ) -> Result<Self, KernelError> {
180 let program = program.as_ref().to_owned();
181 let argv: Vec<String> = args.iter().map(|s| (*s).to_string()).collect();
182 let cmd = Command::new(&program).configure(|c| {
183 c.args(&argv);
184 });
185 let transport = TokioChildProcess::new(cmd)
186 .map_err(|e| KernelError::ToolFailed(format!("mcp.spawn: {e}")))?;
187 let service = ()
188 .serve(transport)
189 .await
190 .map_err(|e| KernelError::ToolFailed(format!("mcp.connect: {e}")))?;
191 let peer = service.peer().clone();
192 Ok(Self {
193 endpoint: endpoint.into(),
194 peer,
195 _service: Arc::new(service),
196 })
197 }
198}
199
200#[async_trait]
201impl McpTransport for StdioTransport {
202 fn endpoint(&self) -> &str {
203 &self.endpoint
204 }
205
206 async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
207 let tools = self
208 .peer
209 .list_all_tools()
210 .await
211 .map_err(|e| KernelError::ToolFailed(format!("tools/list: {e}")))?;
212 Ok(tools.into_iter().map(rmcp_tool_to_schema).collect())
213 }
214
215 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
216 let arguments = match args {
217 Value::Object(map) => Some(map),
218 Value::Null => None,
219 other => {
220 return Err(KernelError::InvalidArgument(format!(
221 "tools/call requires an object or null arguments, got {other}"
222 )));
223 }
224 };
225 let params = {
226 #[allow(clippy::field_reassign_with_default)]
227 let mut p = CallToolRequestParams::default();
228 p.name = name.to_string().into();
229 p.arguments = arguments;
230 p
231 };
232 let result = self
233 .peer
234 .call_tool(params)
235 .await
236 .map_err(|e| KernelError::ToolFailed(format!("tools/call: {e}")))?;
237
238 if result.is_error.unwrap_or(false) {
239 let msg = result
240 .content
241 .iter()
242 .find_map(|c| c.as_text().map(|t| t.text.clone()))
243 .unwrap_or_else(|| "tool returned error".to_string());
244 return Err(KernelError::ToolFailed(msg));
245 }
246
247 if let Some(v) = result.structured_content {
250 return Ok(v);
251 }
252 if let Some(text) = result
253 .content
254 .iter()
255 .find_map(|c| c.as_text().map(|t| t.text.clone()))
256 {
257 if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
258 return Ok(parsed);
259 }
260 return Ok(Value::String(text));
261 }
262 Ok(Value::Null)
263 }
264}
265
266fn rmcp_tool_to_schema(t: RmcpTool) -> ToolSchema {
267 ToolSchema {
268 name: t.name.to_string(),
269 description: t.description.map(|d| d.to_string()).unwrap_or_default(),
270 args_schema: Value::Object((*t.input_schema).clone()),
271 result_schema: t
272 .output_schema
273 .map(|s| Value::Object((*s).clone()))
274 .unwrap_or(Value::Null),
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use rig_compose::tool::LocalTool;
    use serde_json::json;
    use std::sync::Arc;

    /// Builds a registry holding a single `math.mul` tool that returns `a * b`
    /// (missing or non-integer operands count as 0).
    fn echo_registry() -> ToolRegistry {
        let reg = ToolRegistry::new();
        reg.register(Arc::new(LocalTool::new(
            ToolSchema {
                name: "math.mul".into(),
                description: "multiply".into(),
                args_schema: json!({"type": "object"}),
                result_schema: json!({"type": "integer"}),
            },
            |args: Value| async move {
                let a = args["a"].as_i64().unwrap_or(0);
                let b = args["b"].as_i64().unwrap_or(0);
                Ok(json!(a * b))
            },
        )));
        reg
    }

    /// Invokes the registered tool directly through the `Tool` trait (no MCP transport).
    #[tokio::test]
    async fn registry_server_round_trip_via_tool_trait() {
        let registry = echo_registry();
        let tool = registry.get("math.mul").unwrap();
        let out = tool.invoke(json!({"a": 6, "b": 7})).await.unwrap();
        assert_eq!(out, json!(42));
    }

    /// Converting to the rmcp wire type and back preserves name, description and
    /// object-rooted input/output schemas.
    #[test]
    fn schema_conversion_round_trips() {
        let args = json!({"type": "object", "properties": {"a": {"type": "integer"}}});
        let result = json!({"type": "object", "properties": {"product": {"type": "integer"}}});
        let back = rmcp_tool_to_schema(schema_to_rmcp_tool(ToolSchema {
            name: "math.mul".into(),
            description: "multiply".into(),
            args_schema: args.clone(),
            result_schema: result.clone(),
        }));
        assert_eq!(back.name, "math.mul");
        assert_eq!(back.description, "multiply");
        assert_eq!(back.args_schema, args);
        assert_eq!(back.result_schema, result);
    }

    /// A `null` result schema is not advertised and comes back as `null`.
    #[test]
    fn null_result_schema_is_not_advertised() {
        let tool = schema_to_rmcp_tool(ToolSchema {
            name: "noop".into(),
            description: String::new(),
            args_schema: json!({"type": "object"}),
            result_schema: Value::Null,
        });
        assert!(tool.output_schema.is_none());
        assert_eq!(rmcp_tool_to_schema(tool).result_schema, Value::Null);
    }
}
319}