Skip to main content

rig_mcp/
stdio.rs

//! Stdio MCP transport, backed by the official [`rmcp`] SDK.
//!
//! This module bridges [`rig_compose`]'s transport-agnostic
//! [`Tool`](rig_compose::tool::Tool) surface and rmcp's spec-compliant
//! MCP implementation. Everything spec-related (JSON-RPC framing,
//! capability negotiation, version handshakes) is delegated to rmcp;
//! this module only translates at the seam.
//!
//! Public surface (kept stable across the rmcp migration):
//!
//! * [`StdioTransport::spawn`] — spawn a child binary and speak MCP
//!   over its stdio. The returned [`StdioTransport`] implements
//!   [`McpTransport`], so it is interchangeable with any other transport.
//! * [`serve_stdio`] — expose a [`ToolRegistry`] as an MCP server on
//!   the current process's stdin/stdout. Intended for `--mcp-serve`
//!   style CLI flags.

18use std::sync::Arc;
19
20use async_trait::async_trait;
21use serde_json::{Value, json};
22use tokio::process::Command;
23use tracing::{Instrument, field};
24
25use rmcp::model::{
26    CallToolRequestParams, CallToolResult, Content, Implementation, ListToolsResult,
27    PaginatedRequestParams, ProtocolVersion, ServerCapabilities, ServerInfo, Tool as RmcpTool,
28};
29use rmcp::service::{Peer, RequestContext, RoleClient, RoleServer, RunningService, ServiceExt};
30use rmcp::transport::{ConfigureCommandExt, TokioChildProcess, stdio as rmcp_stdio};
31use rmcp::{ErrorData as McpError, ServerHandler};
32
33use crate::transport::McpTransport;
34use rig_compose::registry::{KernelError, ToolRegistry};
35use rig_compose::tool::ToolSchema;
36
37// =============================================================================
38// Server side: expose a ToolRegistry as an rmcp ServerHandler
39// =============================================================================
40
41/// Adapter that wears [`ServerHandler`] over a [`ToolRegistry`]. Every
42/// `tools/list` is answered from `registry.schemas()`; every
43/// `tools/call` dispatches to `registry.invoke()`. No prompts,
44/// resources, or sampling are advertised — clients see a tools-only
45/// server.
46#[derive(Clone)]
47struct RegistryServer {
48    registry: Arc<ToolRegistry>,
49    info: ServerInfo,
50}
51
52impl RegistryServer {
53    fn new(registry: Arc<ToolRegistry>) -> Self {
54        // rmcp's `Implementation` and `ServerInfo` are `#[non_exhaustive]`,
55        // so we can't use a struct literal. Build via `Default::default`
56        // and assign field-by-field.
57        #[allow(clippy::field_reassign_with_default)]
58        let server_info = {
59            let mut s = Implementation::default();
60            s.name = env!("CARGO_PKG_NAME").to_string();
61            s.version = env!("CARGO_PKG_VERSION").to_string();
62            s
63        };
64        #[allow(clippy::field_reassign_with_default)]
65        let info = {
66            let mut i = ServerInfo::default();
67            i.protocol_version = ProtocolVersion::default();
68            i.capabilities = ServerCapabilities::builder().enable_tools().build();
69            i.server_info = server_info;
70            i
71        };
72        Self { registry, info }
73    }
74}
75
76fn schema_to_rmcp_tool(s: ToolSchema) -> RmcpTool {
77    let input_obj = match s.args_schema {
78        Value::Object(map) => map,
79        _ => Default::default(),
80    };
81    let output_obj = match s.result_schema {
82        Value::Object(map) if !map.is_empty() => Some(Arc::new(map)),
83        _ => None,
84    };
85    #[allow(clippy::field_reassign_with_default)]
86    {
87        let mut tool = RmcpTool::default();
88        tool.name = s.name.into();
89        tool.description = Some(s.description.into());
90        tool.input_schema = Arc::new(input_obj);
91        tool.output_schema = output_obj;
92        tool
93    }
94}
95
96impl ServerHandler for RegistryServer {
97    fn get_info(&self) -> ServerInfo {
98        self.info.clone()
99    }
100
101    async fn list_tools(
102        &self,
103        _request: Option<PaginatedRequestParams>,
104        _context: RequestContext<RoleServer>,
105    ) -> Result<ListToolsResult, McpError> {
106        let span = tracing::info_span!(
107            "mcp.stdio_server.list_tools",
108            mcp.transport = "stdio_server",
109            mcp.tool_count = field::Empty,
110        );
111        let span_for_record = span.clone();
112
113        async move {
114            let tools: Vec<_> = self
115                .registry
116                .schemas()
117                .into_iter()
118                .map(schema_to_rmcp_tool)
119                .collect();
120            span_for_record.record("mcp.tool_count", tools.len() as u64);
121            Ok(ListToolsResult {
122                tools,
123                next_cursor: None,
124                meta: None,
125            })
126        }
127        .instrument(span)
128        .await
129    }
130
131    async fn call_tool(
132        &self,
133        request: CallToolRequestParams,
134        _context: RequestContext<RoleServer>,
135    ) -> Result<CallToolResult, McpError> {
136        let name = request.name.to_string();
137        let span = tracing::info_span!(
138            "mcp.stdio_server.call_tool",
139            mcp.transport = "stdio_server",
140            mcp.tool_name = %name,
141            mcp.error = field::Empty,
142        );
143        let span_for_record = span.clone();
144
145        async move {
146            let args = request
147                .arguments
148                .map(Value::Object)
149                .unwrap_or_else(|| json!({}));
150            match self.registry.invoke(&name, args).await {
151                Ok(value) => Ok(CallToolResult::structured(value)),
152                Err(e) => {
153                    span_for_record.record("mcp.error", e.to_string());
154                    Ok(CallToolResult::error(vec![Content::text(e.to_string())]))
155                }
156            }
157        }
158        .instrument(span)
159        .await
160    }
161}
162
163/// Serve `registry` over stdin/stdout using rmcp's spec-compliant stdio
164/// transport. Returns when the peer disconnects.
165pub async fn serve_stdio(registry: ToolRegistry) -> Result<(), KernelError> {
166    let span = tracing::info_span!(
167        "mcp.stdio.serve",
168        mcp.transport = "stdio",
169        mcp.error = field::Empty,
170    );
171    let span_for_record = span.clone();
172
173    async move {
174        let server = RegistryServer::new(Arc::new(registry));
175        let service = server.serve(rmcp_stdio()).await.map_err(|e| {
176            let error = KernelError::ToolFailed(format!("mcp.serve: {e}"));
177            span_for_record.record("mcp.error", error.to_string());
178            error
179        })?;
180        service.waiting().await.map_err(|e| {
181            let error = KernelError::ToolFailed(format!("mcp.serve: {e}"));
182            span_for_record.record("mcp.error", error.to_string());
183            error
184        })?;
185        Ok(())
186    }
187    .instrument(span)
188    .await
189}
190
191// =============================================================================
192// Client side: spawn a child process and speak MCP over its stdio
193// =============================================================================
194
195/// Production stdio MCP client. Wraps an [`rmcp`] running service so
196/// that callers see only the [`McpTransport`] trait.
197///
198/// The cloneable [`Peer`] is cached at construction time so every
199/// `list_tools` / `call_tool` is a lock-free dispatch into rmcp.
200/// Concurrent calls fan out without serialising on a transport-level
201/// mutex; rmcp itself multiplexes the underlying stdio channel.
202pub struct StdioTransport {
203    endpoint: String,
204    peer: Peer<RoleClient>,
205    /// Keeps the rmcp service task alive for the lifetime of the
206    /// transport. Held but never read — dropping the transport drops
207    /// the service, which closes the child's stdio.
208    _service: Arc<RunningService<RoleClient, ()>>,
209}
210
211impl StdioTransport {
212    /// Spawn `program` with `args` and connect over its stdio.
213    ///
214    /// `endpoint` is a free-form identifier surfaced via
215    /// [`McpTransport::endpoint`]; it has no protocol meaning.
216    pub async fn spawn(
217        endpoint: impl Into<String>,
218        program: impl AsRef<std::ffi::OsStr>,
219        args: &[&str],
220    ) -> Result<Self, KernelError> {
221        let endpoint = endpoint.into();
222        let program = program.as_ref().to_owned();
223        let program_name = program.to_string_lossy().to_string();
224        let argv: Vec<String> = args.iter().map(|s| (*s).to_string()).collect();
225        let span = tracing::info_span!(
226            "mcp.stdio.spawn",
227            mcp.transport = "stdio",
228            mcp.endpoint = %endpoint,
229            mcp.program = %program_name,
230            mcp.arg_count = argv.len() as u64,
231            mcp.error = field::Empty,
232        );
233        let span_for_record = span.clone();
234
235        async move {
236            let cmd = Command::new(&program).configure(|c| {
237                c.args(&argv);
238            });
239            let transport = TokioChildProcess::new(cmd).map_err(|e| {
240                let error = KernelError::ToolFailed(format!("mcp.spawn: {e}"));
241                span_for_record.record("mcp.error", error.to_string());
242                error
243            })?;
244            let service = ().serve(transport).await.map_err(|e| {
245                let error = KernelError::ToolFailed(format!("mcp.connect: {e}"));
246                span_for_record.record("mcp.error", error.to_string());
247                error
248            })?;
249            let peer = service.peer().clone();
250            Ok(Self {
251                endpoint,
252                peer,
253                _service: Arc::new(service),
254            })
255        }
256        .instrument(span)
257        .await
258    }
259}
260
261#[async_trait]
262impl McpTransport for StdioTransport {
263    fn endpoint(&self) -> &str {
264        &self.endpoint
265    }
266
267    async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
268        let span = tracing::info_span!(
269            "mcp.stdio.list_tools",
270            mcp.transport = "stdio",
271            mcp.endpoint = %self.endpoint,
272            mcp.tool_count = field::Empty,
273            mcp.error = field::Empty,
274        );
275        let span_for_record = span.clone();
276
277        async move {
278            let tools = self.peer.list_all_tools().await.map_err(|e| {
279                let error = KernelError::ToolFailed(format!("tools/list: {e}"));
280                span_for_record.record("mcp.error", error.to_string());
281                error
282            })?;
283            span_for_record.record("mcp.tool_count", tools.len() as u64);
284            Ok(tools.into_iter().map(rmcp_tool_to_schema).collect())
285        }
286        .instrument(span)
287        .await
288    }
289
290    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
291        let span = tracing::info_span!(
292            "mcp.stdio.call_tool",
293            mcp.transport = "stdio",
294            mcp.endpoint = %self.endpoint,
295            mcp.tool_name = %name,
296            mcp.error = field::Empty,
297        );
298        let span_for_record = span.clone();
299
300        async move {
301            let arguments = match args {
302                Value::Object(map) => Some(map),
303                Value::Null => None,
304                other => {
305                    let error = KernelError::InvalidArgument(format!(
306                        "tools/call requires an object or null arguments, got {other}"
307                    ));
308                    span_for_record.record("mcp.error", error.to_string());
309                    return Err(error);
310                }
311            };
312            let params = {
313                #[allow(clippy::field_reassign_with_default)]
314                let mut p = CallToolRequestParams::default();
315                p.name = name.to_string().into();
316                p.arguments = arguments;
317                p
318            };
319            let result = self.peer.call_tool(params).await.map_err(|e| {
320                let error = KernelError::ToolFailed(format!("tools/call: {e}"));
321                span_for_record.record("mcp.error", error.to_string());
322                error
323            })?;
324
325            if result.is_error.unwrap_or(false) {
326                let msg = result
327                    .content
328                    .iter()
329                    .find_map(|c| c.as_text().map(|t| t.text.clone()))
330                    .unwrap_or_else(|| "tool returned error".to_string());
331                let error = KernelError::ToolFailed(msg);
332                span_for_record.record("mcp.error", error.to_string());
333                return Err(error);
334            }
335
336            // Prefer typed structured content; fall back to first text block parsed
337            // as JSON, then to the raw text wrapped in a string Value.
338            if let Some(v) = result.structured_content {
339                return Ok(v);
340            }
341            if let Some(text) = result
342                .content
343                .iter()
344                .find_map(|c| c.as_text().map(|t| t.text.clone()))
345            {
346                if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
347                    return Ok(parsed);
348                }
349                return Ok(Value::String(text));
350            }
351            Ok(Value::Null)
352        }
353        .instrument(span)
354        .await
355    }
356}
357
358fn rmcp_tool_to_schema(t: RmcpTool) -> ToolSchema {
359    ToolSchema {
360        name: t.name.to_string(),
361        description: t.description.map(|d| d.to_string()).unwrap_or_default(),
362        args_schema: Value::Object((*t.input_schema).clone()),
363        result_schema: t
364            .output_schema
365            .map(|s| Value::Object((*s).clone()))
366            .unwrap_or(Value::Null),
367    }
368}
369
370// =============================================================================
371// Tests — round-trip a registry through a real spawn() of the test bin
372// =============================================================================
373
#[cfg(test)]
mod tests {
    use super::*;
    use rig_compose::tool::LocalTool;
    use serde_json::json;
    use std::sync::Arc;

    /// Registry with a single `math.mul` tool that multiplies `a` by `b`.
    fn echo_registry() -> ToolRegistry {
        let reg = ToolRegistry::new();
        reg.register(Arc::new(LocalTool::new(
            ToolSchema {
                name: "math.mul".into(),
                description: "multiply".into(),
                args_schema: json!({"type": "object"}),
                result_schema: json!({"type": "integer"}),
            },
            |args: Value| async move {
                let a = args["a"].as_i64().unwrap_or(0);
                let b = args["b"].as_i64().unwrap_or(0);
                Ok(json!(a * b))
            },
        )));
        reg
    }

    /// Smoke test: a registry tool can be looked up and invoked through the
    /// same `Tool` trait used by skills. This does NOT start `serve_stdio`
    /// or touch the wire. End-to-end coverage lives in the `mcp_serve_cli`
    /// tests in azreal, which spawn the real binary.
    #[tokio::test]
    async fn registry_server_round_trip_via_tool_trait() {
        let registry = echo_registry();
        let tool = registry.get("math.mul").unwrap();
        let out = tool.invoke(json!({"a": 6, "b": 7})).await.unwrap();
        assert_eq!(out, json!(42));
    }

    /// The handler advertises only the tools capability and identifies
    /// itself as this crate.
    #[test]
    fn registry_server_advertises_tools_only() {
        let server = RegistryServer::new(Arc::new(echo_registry()));
        let info = server.get_info();
        assert!(info.capabilities.tools.is_some());
        assert!(info.capabilities.prompts.is_none());
        assert!(info.capabilities.resources.is_none());
        assert_eq!(info.server_info.name, env!("CARGO_PKG_NAME"));
        assert_eq!(info.server_info.version, env!("CARGO_PKG_VERSION"));
    }

    /// Object-typed args/result schemas survive the
    /// rig_compose → rmcp → rig_compose conversion unchanged.
    #[test]
    fn schema_conversion_round_trips_object_schemas() {
        let args = json!({"type": "object", "properties": {"a": {"type": "integer"}}});
        let result = json!({"type": "object", "properties": {"product": {"type": "integer"}}});
        let rmcp_tool = schema_to_rmcp_tool(ToolSchema {
            name: "math.mul".into(),
            description: "multiply".into(),
            args_schema: args.clone(),
            result_schema: result.clone(),
        });
        let back = rmcp_tool_to_schema(rmcp_tool);
        assert_eq!(back.name, "math.mul");
        assert_eq!(back.description, "multiply");
        assert_eq!(back.args_schema, args);
        assert_eq!(back.result_schema, result);
    }

    /// An empty or null result schema is never advertised as an output schema.
    #[test]
    fn empty_result_schema_is_not_advertised() {
        for result_schema in vec![json!({}), Value::Null] {
            let tool = schema_to_rmcp_tool(ToolSchema {
                name: "noop".into(),
                description: "noop".into(),
                args_schema: json!({"type": "object"}),
                result_schema,
            });
            assert!(tool.output_schema.is_none());
        }
    }
}