1use std::sync::Arc;
19
20use async_trait::async_trait;
21use serde_json::{Value, json};
22use tokio::process::Command;
23use tracing::{Instrument, field};
24
25use rmcp::model::{
26 CallToolRequestParams, CallToolResult, Content, Implementation, ListToolsResult,
27 PaginatedRequestParams, ProtocolVersion, ServerCapabilities, ServerInfo, Tool as RmcpTool,
28};
29use rmcp::service::{Peer, RequestContext, RoleClient, RoleServer, RunningService, ServiceExt};
30use rmcp::transport::{ConfigureCommandExt, TokioChildProcess, stdio as rmcp_stdio};
31use rmcp::{ErrorData as McpError, ServerHandler};
32
33use crate::transport::McpTransport;
34use rig_compose::registry::{KernelError, ToolRegistry};
35use rig_compose::tool::ToolSchema;
36
37#[derive(Clone)]
47struct RegistryServer {
48 registry: Arc<ToolRegistry>,
49 info: ServerInfo,
50}
51
52impl RegistryServer {
53 fn new(registry: Arc<ToolRegistry>) -> Self {
54 #[allow(clippy::field_reassign_with_default)]
58 let server_info = {
59 let mut s = Implementation::default();
60 s.name = env!("CARGO_PKG_NAME").to_string();
61 s.version = env!("CARGO_PKG_VERSION").to_string();
62 s
63 };
64 #[allow(clippy::field_reassign_with_default)]
65 let info = {
66 let mut i = ServerInfo::default();
67 i.protocol_version = ProtocolVersion::default();
68 i.capabilities = ServerCapabilities::builder().enable_tools().build();
69 i.server_info = server_info;
70 i
71 };
72 Self { registry, info }
73 }
74}
75
76fn schema_to_rmcp_tool(s: ToolSchema) -> RmcpTool {
77 let input_obj = match s.args_schema {
78 Value::Object(map) => map,
79 _ => Default::default(),
80 };
81 let output_obj = match s.result_schema {
82 Value::Object(map) if !map.is_empty() => Some(Arc::new(map)),
83 _ => None,
84 };
85 #[allow(clippy::field_reassign_with_default)]
86 {
87 let mut tool = RmcpTool::default();
88 tool.name = s.name.into();
89 tool.description = Some(s.description.into());
90 tool.input_schema = Arc::new(input_obj);
91 tool.output_schema = output_obj;
92 tool
93 }
94}
95
96impl ServerHandler for RegistryServer {
97 fn get_info(&self) -> ServerInfo {
98 self.info.clone()
99 }
100
101 async fn list_tools(
102 &self,
103 _request: Option<PaginatedRequestParams>,
104 _context: RequestContext<RoleServer>,
105 ) -> Result<ListToolsResult, McpError> {
106 let span = tracing::info_span!(
107 "mcp.stdio_server.list_tools",
108 mcp.transport = "stdio_server",
109 mcp.tool_count = field::Empty,
110 );
111 let span_for_record = span.clone();
112
113 async move {
114 let tools: Vec<_> = self
115 .registry
116 .schemas()
117 .into_iter()
118 .map(schema_to_rmcp_tool)
119 .collect();
120 span_for_record.record("mcp.tool_count", tools.len() as u64);
121 Ok(ListToolsResult {
122 tools,
123 next_cursor: None,
124 meta: None,
125 })
126 }
127 .instrument(span)
128 .await
129 }
130
131 async fn call_tool(
132 &self,
133 request: CallToolRequestParams,
134 _context: RequestContext<RoleServer>,
135 ) -> Result<CallToolResult, McpError> {
136 let name = request.name.to_string();
137 let span = tracing::info_span!(
138 "mcp.stdio_server.call_tool",
139 mcp.transport = "stdio_server",
140 mcp.tool_name = %name,
141 mcp.error = field::Empty,
142 );
143 let span_for_record = span.clone();
144
145 async move {
146 let args = request
147 .arguments
148 .map(Value::Object)
149 .unwrap_or_else(|| json!({}));
150 match self.registry.invoke(&name, args).await {
151 Ok(value) => Ok(CallToolResult::structured(value)),
152 Err(e) => {
153 span_for_record.record("mcp.error", e.to_string());
154 Ok(CallToolResult::error(vec![Content::text(e.to_string())]))
155 }
156 }
157 }
158 .instrument(span)
159 .await
160 }
161}
162
163pub async fn serve_stdio(registry: ToolRegistry) -> Result<(), KernelError> {
166 let span = tracing::info_span!(
167 "mcp.stdio.serve",
168 mcp.transport = "stdio",
169 mcp.error = field::Empty,
170 );
171 let span_for_record = span.clone();
172
173 async move {
174 let server = RegistryServer::new(Arc::new(registry));
175 let service = server.serve(rmcp_stdio()).await.map_err(|e| {
176 let error = KernelError::ToolFailed(format!("mcp.serve: {e}"));
177 span_for_record.record("mcp.error", error.to_string());
178 error
179 })?;
180 service.waiting().await.map_err(|e| {
181 let error = KernelError::ToolFailed(format!("mcp.serve: {e}"));
182 span_for_record.record("mcp.error", error.to_string());
183 error
184 })?;
185 Ok(())
186 }
187 .instrument(span)
188 .await
189}
190
191pub struct StdioTransport {
203 endpoint: String,
204 peer: Peer<RoleClient>,
205 _service: Arc<RunningService<RoleClient, ()>>,
209}
210
211impl StdioTransport {
212 pub async fn spawn(
217 endpoint: impl Into<String>,
218 program: impl AsRef<std::ffi::OsStr>,
219 args: &[&str],
220 ) -> Result<Self, KernelError> {
221 let endpoint = endpoint.into();
222 let program = program.as_ref().to_owned();
223 let program_name = program.to_string_lossy().to_string();
224 let argv: Vec<String> = args.iter().map(|s| (*s).to_string()).collect();
225 let span = tracing::info_span!(
226 "mcp.stdio.spawn",
227 mcp.transport = "stdio",
228 mcp.endpoint = %endpoint,
229 mcp.program = %program_name,
230 mcp.arg_count = argv.len() as u64,
231 mcp.error = field::Empty,
232 );
233 let span_for_record = span.clone();
234
235 async move {
236 let cmd = Command::new(&program).configure(|c| {
237 c.args(&argv);
238 });
239 let transport = TokioChildProcess::new(cmd).map_err(|e| {
240 let error = KernelError::ToolFailed(format!("mcp.spawn: {e}"));
241 span_for_record.record("mcp.error", error.to_string());
242 error
243 })?;
244 let service = ().serve(transport).await.map_err(|e| {
245 let error = KernelError::ToolFailed(format!("mcp.connect: {e}"));
246 span_for_record.record("mcp.error", error.to_string());
247 error
248 })?;
249 let peer = service.peer().clone();
250 Ok(Self {
251 endpoint,
252 peer,
253 _service: Arc::new(service),
254 })
255 }
256 .instrument(span)
257 .await
258 }
259}
260
261#[async_trait]
262impl McpTransport for StdioTransport {
263 fn endpoint(&self) -> &str {
264 &self.endpoint
265 }
266
267 async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
268 let span = tracing::info_span!(
269 "mcp.stdio.list_tools",
270 mcp.transport = "stdio",
271 mcp.endpoint = %self.endpoint,
272 mcp.tool_count = field::Empty,
273 mcp.error = field::Empty,
274 );
275 let span_for_record = span.clone();
276
277 async move {
278 let tools = self.peer.list_all_tools().await.map_err(|e| {
279 let error = KernelError::ToolFailed(format!("tools/list: {e}"));
280 span_for_record.record("mcp.error", error.to_string());
281 error
282 })?;
283 span_for_record.record("mcp.tool_count", tools.len() as u64);
284 Ok(tools.into_iter().map(rmcp_tool_to_schema).collect())
285 }
286 .instrument(span)
287 .await
288 }
289
290 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
291 let span = tracing::info_span!(
292 "mcp.stdio.call_tool",
293 mcp.transport = "stdio",
294 mcp.endpoint = %self.endpoint,
295 mcp.tool_name = %name,
296 mcp.error = field::Empty,
297 );
298 let span_for_record = span.clone();
299
300 async move {
301 let arguments = match args {
302 Value::Object(map) => Some(map),
303 Value::Null => None,
304 other => {
305 let error = KernelError::InvalidArgument(format!(
306 "tools/call requires an object or null arguments, got {other}"
307 ));
308 span_for_record.record("mcp.error", error.to_string());
309 return Err(error);
310 }
311 };
312 let params = {
313 #[allow(clippy::field_reassign_with_default)]
314 let mut p = CallToolRequestParams::default();
315 p.name = name.to_string().into();
316 p.arguments = arguments;
317 p
318 };
319 let result = self.peer.call_tool(params).await.map_err(|e| {
320 let error = KernelError::ToolFailed(format!("tools/call: {e}"));
321 span_for_record.record("mcp.error", error.to_string());
322 error
323 })?;
324
325 if result.is_error.unwrap_or(false) {
326 let msg = result
327 .content
328 .iter()
329 .find_map(|c| c.as_text().map(|t| t.text.clone()))
330 .unwrap_or_else(|| "tool returned error".to_string());
331 let error = KernelError::ToolFailed(msg);
332 span_for_record.record("mcp.error", error.to_string());
333 return Err(error);
334 }
335
336 if let Some(v) = result.structured_content {
339 return Ok(v);
340 }
341 if let Some(text) = result
342 .content
343 .iter()
344 .find_map(|c| c.as_text().map(|t| t.text.clone()))
345 {
346 if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
347 return Ok(parsed);
348 }
349 return Ok(Value::String(text));
350 }
351 Ok(Value::Null)
352 }
353 .instrument(span)
354 .await
355 }
356}
357
358fn rmcp_tool_to_schema(t: RmcpTool) -> ToolSchema {
359 ToolSchema {
360 name: t.name.to_string(),
361 description: t.description.map(|d| d.to_string()).unwrap_or_default(),
362 args_schema: Value::Object((*t.input_schema).clone()),
363 result_schema: t
364 .output_schema
365 .map(|s| Value::Object((*s).clone()))
366 .unwrap_or(Value::Null),
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use rig_compose::tool::LocalTool;
    use serde_json::json;
    use std::sync::Arc;

    /// Builds a registry holding a single `math.mul` tool that multiplies the
    /// integer arguments `a` and `b`; a missing or non-integer argument counts
    /// as 0.
    fn echo_registry() -> ToolRegistry {
        let schema = ToolSchema {
            name: "math.mul".into(),
            description: "multiply".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "integer"}),
        };

        let registry = ToolRegistry::new();
        registry.register(Arc::new(LocalTool::new(schema, |args: Value| async move {
            let lhs = args["a"].as_i64().unwrap_or(0);
            let rhs = args["b"].as_i64().unwrap_or(0);
            Ok(json!(lhs * rhs))
        })));
        registry
    }

    /// Looks the tool up in the registry and invokes it directly, without
    /// going through any MCP transport.
    #[tokio::test]
    async fn registry_server_round_trip_via_tool_trait() {
        let registry = echo_registry();
        let multiply = registry.get("math.mul").unwrap();
        let product = multiply.invoke(json!({"a": 6, "b": 7})).await.unwrap();
        assert_eq!(product, json!(42));
    }
}