Skip to main content

rig_mcp/
transport.rs

//! Model Context Protocol transport abstraction.
//!
//! The kernel is transport-agnostic: a [`Tool`] is a typed async
//! function regardless of whether it runs in-process or behind a remote
//! MCP server. This module defines the trait that real MCP transports
//! (stdio, http+SSE, websocket) implement, plus an [`McpTool`] adapter
//! that turns any transport into a kernel [`Tool`].
//!
//! A concrete [`LoopbackTransport`] is included so the abstraction can be
//! exercised end-to-end in tests without an external MCP crate. Production
//! transports (`rmcp`, custom stdio, etc.) plug in by implementing
//! [`McpTransport`] — no kernel changes required.
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use serde_json::Value;
17
18use rig_compose::registry::{KernelError, ToolRegistry};
19use rig_compose::tool::{Tool, ToolSchema};
20
21/// Bidirectional MCP transport. Real implementations layer JSON-RPC
22/// framing, capability negotiation, and reconnection on top of this; the
23/// kernel sees only `list_tools` + `call_tool`.
24#[async_trait]
25pub trait McpTransport: Send + Sync {
26    /// Stable identifier for this transport instance (typically the
27    /// server URI or stdio command).
28    fn endpoint(&self) -> &str;
29
30    /// Discover the tools exposed by the remote endpoint. Called at
31    /// registration time; the returned schemas are authoritative.
32    async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError>;
33
34    /// Invoke a named tool. Implementations MUST round-trip the result
35    /// JSON without modification so callers can rely on schema fidelity.
36    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError>;
37}
38
39/// Kernel-facing wrapper that exposes one tool from a remote MCP server
40/// as a local [`Tool`]. Skills cannot tell `McpTool` apart from a local
41/// `rig_compose::Tool` implementation.
42pub struct McpTool {
43    transport: Arc<dyn McpTransport>,
44    schema: ToolSchema,
45}
46
47impl McpTool {
48    pub fn new(transport: Arc<dyn McpTransport>, schema: ToolSchema) -> Self {
49        Self { transport, schema }
50    }
51
52    /// Discover all tools exposed by `transport` and wrap each as an
53    /// [`McpTool`]. Register the returned vec with a
54    /// `rig_compose::registry::ToolRegistry` to merge them into a global
55    /// registry.
56    pub async fn from_transport(
57        transport: Arc<dyn McpTransport>,
58    ) -> Result<Vec<Arc<dyn Tool>>, KernelError> {
59        let schemas = transport.list_tools().await?;
60        Ok(schemas
61            .into_iter()
62            .map(|schema| {
63                let t: Arc<dyn Tool> = Arc::new(McpTool {
64                    transport: transport.clone(),
65                    schema,
66                });
67                t
68            })
69            .collect())
70    }
71}
72
73#[async_trait]
74impl Tool for McpTool {
75    fn schema(&self) -> ToolSchema {
76        self.schema.clone()
77    }
78
79    fn name(&self) -> rig_compose::tool::ToolName {
80        self.schema.name.clone()
81    }
82
83    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
84        self.transport.call_tool(&self.schema.name, args).await
85    }
86}
87
88// =============================================================================
89// LoopbackTransport — in-process transport over a local ToolRegistry
90// =============================================================================
91
92/// Pure-Rust transport that round-trips calls through a local
93/// [`ToolRegistry`]. Useful for testing the MCP composition story without
94/// spawning an external process.
95///
96/// `LoopbackTransport` also doubles as the building block for
97/// `McpToolServer`-style exports in a future commit: any registry can be
98/// wrapped in a transport and then attached to a real MCP server crate.
99pub struct LoopbackTransport {
100    endpoint: String,
101    registry: ToolRegistry,
102}
103
104impl LoopbackTransport {
105    pub fn new(endpoint: impl Into<String>, registry: ToolRegistry) -> Self {
106        Self {
107            endpoint: endpoint.into(),
108            registry,
109        }
110    }
111}
112
113#[async_trait]
114impl McpTransport for LoopbackTransport {
115    fn endpoint(&self) -> &str {
116        &self.endpoint
117    }
118
119    async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
120        // The registry doesn't expose iteration directly to keep the
121        // trait surface small; for the loopback transport we walk the
122        // inner DashMap via a public helper added below.
123        Ok(self.registry.schemas())
124    }
125
126    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
127        self.registry.invoke(name, args).await
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use rig_compose::tool::LocalTool;
    use serde_json::json;

    /// Builds a registry containing a single `math.add` tool that sums the
    /// integer fields `a` and `b`, treating a missing or non-integer field
    /// as 0.
    fn make_registry() -> ToolRegistry {
        let schema = ToolSchema {
            name: "math.add".into(),
            description: "add two ints".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "integer"}),
        };
        let adder = LocalTool::new(schema, |args| async move {
            let lhs = args["a"].as_i64().unwrap_or(0);
            let rhs = args["b"].as_i64().unwrap_or(0);
            Ok(json!(lhs + rhs))
        });

        let registry = ToolRegistry::new();
        registry.register(Arc::new(adder));
        registry
    }

    /// Discovery and invocation through the loopback transport must reach
    /// the tool registered on the backing registry.
    #[tokio::test]
    async fn loopback_transport_round_trip() {
        let loopback = LoopbackTransport::new("loopback://test", make_registry());
        let transport: Arc<dyn McpTransport> = Arc::new(loopback);

        let listed = transport.list_tools().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "math.add");

        let args = json!({"a": 2, "b": 3});
        let sum = transport.call_tool("math.add", args).await.unwrap();
        assert_eq!(sum, json!(5));
    }

    /// A tool registered server-side, exposed over loopback, and
    /// re-registered client-side as an `McpTool` must give the same answer
    /// through the client registry as a direct local invocation would.
    #[tokio::test]
    async fn mcp_tool_indistinguishable_from_local() {
        let server_side = LoopbackTransport::new("loopback://test", make_registry());
        let transport: Arc<dyn McpTransport> = Arc::new(server_side);

        let client = ToolRegistry::new();
        let wrapped = McpTool::from_transport(transport).await.unwrap();
        for tool in wrapped {
            client.register(tool);
        }

        let args = json!({"a": 10, "b": 32});
        let answer = client.invoke("math.add", args).await.unwrap();
        assert_eq!(answer, json!(42));
    }
}