Skip to main content

rig_mcp/
stdio.rs

1//! Stdio MCP transport, backed by the official [`rmcp`] SDK.
2//!
3//! This module bridges between [`rig_compose`]'s transport-agnostic
4//! [`Tool`](rig_compose::tool::Tool) surface and rmcp's spec-compliant
5//! MCP implementation. Everything spec-related (JSON-RPC framing,
6//! capability negotiation, version handshakes) is delegated to rmcp;
7//! we only translate at the seam.
8//!
9//! Public surface (kept stable across the rmcp migration):
10//!
11//! * [`StdioTransport::spawn`] — spawn a child binary and speak MCP
12//!   over its stdio. Implements [`McpTransport`] so the resulting
13//!   handle is interchangeable with any other transport.
14//! * [`serve_stdio`] — expose a [`ToolRegistry`] as an MCP server on
15//!   the current process's stdin/stdout. Intended for `--mcp-serve`
16//!   style CLI flags.
17
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use serde_json::{Value, json};
22use tokio::process::Command;
23
24use rmcp::model::{
25    CallToolRequestParams, CallToolResult, Content, Implementation, ListToolsResult,
26    PaginatedRequestParams, ProtocolVersion, ServerCapabilities, ServerInfo, Tool as RmcpTool,
27};
28use rmcp::service::{Peer, RequestContext, RoleClient, RoleServer, RunningService, ServiceExt};
29use rmcp::transport::{ConfigureCommandExt, TokioChildProcess, stdio as rmcp_stdio};
30use rmcp::{ErrorData as McpError, ServerHandler};
31
32use crate::transport::McpTransport;
33use rig_compose::registry::{KernelError, ToolRegistry};
34use rig_compose::tool::ToolSchema;
35
36// =============================================================================
37// Server side: expose a ToolRegistry as an rmcp ServerHandler
38// =============================================================================
39
40/// Adapter that wears [`ServerHandler`] over a [`ToolRegistry`]. Every
41/// `tools/list` is answered from `registry.schemas()`; every
42/// `tools/call` dispatches to `registry.invoke()`. No prompts,
43/// resources, or sampling are advertised — clients see a tools-only
44/// server.
45#[derive(Clone)]
46struct RegistryServer {
47    registry: Arc<ToolRegistry>,
48    info: ServerInfo,
49}
50
51impl RegistryServer {
52    fn new(registry: Arc<ToolRegistry>) -> Self {
53        // rmcp's `Implementation` and `ServerInfo` are `#[non_exhaustive]`,
54        // so we can't use a struct literal. Build via `Default::default`
55        // and assign field-by-field.
56        #[allow(clippy::field_reassign_with_default)]
57        let server_info = {
58            let mut s = Implementation::default();
59            s.name = env!("CARGO_PKG_NAME").to_string();
60            s.version = env!("CARGO_PKG_VERSION").to_string();
61            s
62        };
63        #[allow(clippy::field_reassign_with_default)]
64        let info = {
65            let mut i = ServerInfo::default();
66            i.protocol_version = ProtocolVersion::default();
67            i.capabilities = ServerCapabilities::builder().enable_tools().build();
68            i.server_info = server_info;
69            i
70        };
71        Self { registry, info }
72    }
73}
74
75fn schema_to_rmcp_tool(s: ToolSchema) -> RmcpTool {
76    let input_obj = match s.args_schema {
77        Value::Object(map) => map,
78        _ => Default::default(),
79    };
80    let output_obj = match s.result_schema {
81        Value::Object(map) if !map.is_empty() => Some(Arc::new(map)),
82        _ => None,
83    };
84    #[allow(clippy::field_reassign_with_default)]
85    {
86        let mut tool = RmcpTool::default();
87        tool.name = s.name.into();
88        tool.description = Some(s.description.into());
89        tool.input_schema = Arc::new(input_obj);
90        tool.output_schema = output_obj;
91        tool
92    }
93}
94
95impl ServerHandler for RegistryServer {
96    fn get_info(&self) -> ServerInfo {
97        self.info.clone()
98    }
99
100    async fn list_tools(
101        &self,
102        _request: Option<PaginatedRequestParams>,
103        _context: RequestContext<RoleServer>,
104    ) -> Result<ListToolsResult, McpError> {
105        let tools = self
106            .registry
107            .schemas()
108            .into_iter()
109            .map(schema_to_rmcp_tool)
110            .collect();
111        Ok(ListToolsResult {
112            tools,
113            next_cursor: None,
114            meta: None,
115        })
116    }
117
118    async fn call_tool(
119        &self,
120        request: CallToolRequestParams,
121        _context: RequestContext<RoleServer>,
122    ) -> Result<CallToolResult, McpError> {
123        let name = request.name.to_string();
124        let args = request
125            .arguments
126            .map(Value::Object)
127            .unwrap_or_else(|| json!({}));
128        match self.registry.invoke(&name, args).await {
129            Ok(value) => Ok(CallToolResult::structured(value)),
130            Err(e) => Ok(CallToolResult::error(vec![Content::text(e.to_string())])),
131        }
132    }
133}
134
135/// Serve `registry` over stdin/stdout using rmcp's spec-compliant stdio
136/// transport. Returns when the peer disconnects.
137pub async fn serve_stdio(registry: ToolRegistry) -> Result<(), KernelError> {
138    let server = RegistryServer::new(Arc::new(registry));
139    let service = server
140        .serve(rmcp_stdio())
141        .await
142        .map_err(|e| KernelError::ToolFailed(format!("mcp.serve: {e}")))?;
143    service
144        .waiting()
145        .await
146        .map_err(|e| KernelError::ToolFailed(format!("mcp.serve: {e}")))?;
147    Ok(())
148}
149
150// =============================================================================
151// Client side: spawn a child process and speak MCP over its stdio
152// =============================================================================
153
154/// Production stdio MCP client. Wraps an [`rmcp`] running service so
155/// that callers see only the [`McpTransport`] trait.
156///
157/// The cloneable [`Peer`] is cached at construction time so every
158/// `list_tools` / `call_tool` is a lock-free dispatch into rmcp.
159/// Concurrent calls fan out without serialising on a transport-level
160/// mutex; rmcp itself multiplexes the underlying stdio channel.
161pub struct StdioTransport {
162    endpoint: String,
163    peer: Peer<RoleClient>,
164    /// Keeps the rmcp service task alive for the lifetime of the
165    /// transport. Held but never read — dropping the transport drops
166    /// the service, which closes the child's stdio.
167    _service: Arc<RunningService<RoleClient, ()>>,
168}
169
170impl StdioTransport {
171    /// Spawn `program` with `args` and connect over its stdio.
172    ///
173    /// `endpoint` is a free-form identifier surfaced via
174    /// [`McpTransport::endpoint`]; it has no protocol meaning.
175    pub async fn spawn(
176        endpoint: impl Into<String>,
177        program: impl AsRef<std::ffi::OsStr>,
178        args: &[&str],
179    ) -> Result<Self, KernelError> {
180        let program = program.as_ref().to_owned();
181        let argv: Vec<String> = args.iter().map(|s| (*s).to_string()).collect();
182        let cmd = Command::new(&program).configure(|c| {
183            c.args(&argv);
184        });
185        let transport = TokioChildProcess::new(cmd)
186            .map_err(|e| KernelError::ToolFailed(format!("mcp.spawn: {e}")))?;
187        let service = ()
188            .serve(transport)
189            .await
190            .map_err(|e| KernelError::ToolFailed(format!("mcp.connect: {e}")))?;
191        let peer = service.peer().clone();
192        Ok(Self {
193            endpoint: endpoint.into(),
194            peer,
195            _service: Arc::new(service),
196        })
197    }
198}
199
200#[async_trait]
201impl McpTransport for StdioTransport {
202    fn endpoint(&self) -> &str {
203        &self.endpoint
204    }
205
206    async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
207        let tools = self
208            .peer
209            .list_all_tools()
210            .await
211            .map_err(|e| KernelError::ToolFailed(format!("tools/list: {e}")))?;
212        Ok(tools.into_iter().map(rmcp_tool_to_schema).collect())
213    }
214
215    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
216        let arguments = match args {
217            Value::Object(map) => Some(map),
218            Value::Null => None,
219            other => {
220                return Err(KernelError::InvalidArgument(format!(
221                    "tools/call requires an object or null arguments, got {other}"
222                )));
223            }
224        };
225        let params = {
226            #[allow(clippy::field_reassign_with_default)]
227            let mut p = CallToolRequestParams::default();
228            p.name = name.to_string().into();
229            p.arguments = arguments;
230            p
231        };
232        let result = self
233            .peer
234            .call_tool(params)
235            .await
236            .map_err(|e| KernelError::ToolFailed(format!("tools/call: {e}")))?;
237
238        if result.is_error.unwrap_or(false) {
239            let msg = result
240                .content
241                .iter()
242                .find_map(|c| c.as_text().map(|t| t.text.clone()))
243                .unwrap_or_else(|| "tool returned error".to_string());
244            return Err(KernelError::ToolFailed(msg));
245        }
246
247        // Prefer typed structured content; fall back to first text block parsed
248        // as JSON, then to the raw text wrapped in a string Value.
249        if let Some(v) = result.structured_content {
250            return Ok(v);
251        }
252        if let Some(text) = result
253            .content
254            .iter()
255            .find_map(|c| c.as_text().map(|t| t.text.clone()))
256        {
257            if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
258                return Ok(parsed);
259            }
260            return Ok(Value::String(text));
261        }
262        Ok(Value::Null)
263    }
264}
265
266fn rmcp_tool_to_schema(t: RmcpTool) -> ToolSchema {
267    ToolSchema {
268        name: t.name.to_string(),
269        description: t.description.map(|d| d.to_string()).unwrap_or_default(),
270        args_schema: Value::Object((*t.input_schema).clone()),
271        result_schema: t
272            .output_schema
273            .map(|s| Value::Object((*s).clone()))
274            .unwrap_or(Value::Null),
275    }
276}
277
278// =============================================================================
// Tests — in-process checks of the registry/rmcp adapters (no child process is spawned)
280// =============================================================================
281
#[cfg(test)]
mod tests {
    use super::*;
    use rig_compose::tool::LocalTool;
    use serde_json::json;
    use std::sync::Arc;

    /// Registry with a single `math.mul` tool that multiplies integer
    /// arguments `a` and `b`; missing or non-integer values count as 0.
    fn echo_registry() -> ToolRegistry {
        let reg = ToolRegistry::new();
        reg.register(Arc::new(LocalTool::new(
            ToolSchema {
                name: "math.mul".into(),
                description: "multiply".into(),
                args_schema: json!({"type": "object"}),
                result_schema: json!({"type": "integer"}),
            },
            |args: Value| async move {
                let a = args["a"].as_i64().unwrap_or(0);
                let b = args["b"].as_i64().unwrap_or(0);
                Ok(json!(a * b))
            },
        )));
        reg
    }

    /// Invoke `math.mul` through the same `Tool` trait used by skills.
    ///
    /// This does NOT exercise `serve_stdio` or the MCP wire protocol: no
    /// server is started and nothing is spawned. End-to-end coverage is
    /// expected to live in the `mcp_serve_cli` tests in azreal, which spawn
    /// the real binary.
    #[tokio::test]
    async fn registry_server_round_trip_via_tool_trait() {
        let registry = echo_registry();
        let tool = registry.get("math.mul").unwrap();
        let out = tool.invoke(json!({"a": 6, "b": 7})).await.unwrap();
        assert_eq!(out, json!(42));
    }

    /// The server adapter advertises the tools capability and identifies
    /// itself with this crate's package name and version.
    #[test]
    fn registry_server_advertises_tools_capability() {
        let server = RegistryServer::new(Arc::new(ToolRegistry::new()));
        assert!(server.info.capabilities.tools.is_some());
        assert_eq!(server.info.server_info.name, env!("CARGO_PKG_NAME"));
        assert_eq!(server.info.server_info.version, env!("CARGO_PKG_VERSION"));
    }

    /// An object-typed schema survives the registry -> rmcp -> registry
    /// conversion unchanged.
    #[test]
    fn tool_schema_round_trips_through_rmcp_tool() {
        let args = json!({"type": "object", "properties": {"x": {"type": "number"}}});
        let result = json!({"type": "object", "properties": {"y": {"type": "number"}}});
        let back = rmcp_tool_to_schema(schema_to_rmcp_tool(ToolSchema {
            name: "geo.point".into(),
            description: "make a point".into(),
            args_schema: args.clone(),
            result_schema: result.clone(),
        }));
        assert_eq!(back.name, "geo.point");
        assert_eq!(back.description, "make a point");
        assert_eq!(back.args_schema, args);
        assert_eq!(back.result_schema, result);
    }
}