Skip to main content

agent_client_protocol/mcp_server/
registry.rs

1//! Runtime-neutral MCP tool registration and dispatch.
2
3use std::{collections::HashSet, sync::Arc};
4
5use futures::future::BoxFuture;
6use rustc_hash::FxHashMap;
7use schemars::{JsonSchema, generate::SchemaSettings};
8use serde_json::{Map, Value};
9
10use crate::{Error, Role};
11
12use super::{McpConnectionTo, McpTool};
13
14/// JSON Schema object used to describe MCP tool inputs and outputs.
15pub type McpToolSchema = Map<String, Value>;
16
/// Policy deciding which registered tools are exposed.
///
/// - `DenyList`: every tool is enabled except the names in the set (the default)
/// - `AllowList`: only the names in the set are enabled
#[derive(Clone, Debug)]
pub enum EnabledTools {
    /// Every tool is enabled unless its name appears in this set.
    DenyList(HashSet<String>),
    /// A tool is enabled only if its name appears in this set.
    AllowList(HashSet<String>),
}

impl Default for EnabledTools {
    /// An empty deny list, so every tool starts out enabled.
    fn default() -> Self {
        Self::DenyList(HashSet::default())
    }
}

impl EnabledTools {
    /// Returns `true` when the tool called `name` is enabled under this policy.
    #[must_use]
    pub fn is_enabled(&self, name: &str) -> bool {
        match self {
            Self::AllowList(allowed) => allowed.contains(name),
            Self::DenyList(denied) => !denied.contains(name),
        }
    }
}
45
46/// Runtime-neutral metadata for an MCP tool.
47#[derive(Clone, Debug)]
48pub struct McpToolMetadata {
49    name: String,
50    title: Option<String>,
51    description: String,
52    input_schema: Arc<McpToolSchema>,
53    output_schema: Option<Arc<McpToolSchema>>,
54}
55
56impl McpToolMetadata {
57    fn from_tool<R: Role, M: McpTool<R>>(tool: &M) -> Self {
58        Self {
59            name: tool.name(),
60            title: tool.title(),
61            description: tool.description(),
62            input_schema: schema_for_type::<M::Input>(),
63            output_schema: schema_for_output::<M::Output>(),
64        }
65    }
66
67    /// The tool name.
68    #[must_use]
69    pub fn name(&self) -> &str {
70        &self.name
71    }
72
73    /// A human-readable title for the tool.
74    #[must_use]
75    pub fn title(&self) -> Option<&str> {
76        self.title.as_deref()
77    }
78
79    /// A description of what the tool does.
80    #[must_use]
81    pub fn description(&self) -> &str {
82        &self.description
83    }
84
85    /// JSON Schema object defining the expected parameters for the tool.
86    #[must_use]
87    pub fn input_schema(&self) -> &Arc<McpToolSchema> {
88        &self.input_schema
89    }
90
91    /// Optional JSON Schema object defining the structure of the tool's output.
92    #[must_use]
93    pub fn output_schema(&self) -> Option<&Arc<McpToolSchema>> {
94        self.output_schema.as_ref()
95    }
96}
97
98/// A registered MCP tool that can be dispatched with erased JSON values.
99pub struct RegisteredMcpTool<Counterpart: Role> {
100    metadata: McpToolMetadata,
101    tool: Arc<dyn ErasedMcpTool<Counterpart>>,
102}
103
104impl<Counterpart: Role + std::fmt::Debug> std::fmt::Debug for RegisteredMcpTool<Counterpart> {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        f.debug_struct("RegisteredMcpTool")
107            .field("metadata", &self.metadata)
108            .field("has_structured_output", &self.has_structured_output())
109            .finish_non_exhaustive()
110    }
111}
112
113impl<Counterpart: Role> RegisteredMcpTool<Counterpart> {
114    fn new(tool: impl McpTool<Counterpart> + 'static) -> Self {
115        let metadata = McpToolMetadata::from_tool(&tool);
116        Self {
117            metadata,
118            tool: make_erased_mcp_tool(tool),
119        }
120    }
121
122    /// Tool metadata.
123    #[must_use]
124    pub fn metadata(&self) -> &McpToolMetadata {
125        &self.metadata
126    }
127
128    /// The tool name.
129    #[must_use]
130    pub fn name(&self) -> &str {
131        self.metadata.name()
132    }
133
134    /// Whether the tool returns structured output.
135    #[must_use]
136    pub fn has_structured_output(&self) -> bool {
137        self.metadata.output_schema().is_some()
138    }
139
140    /// Invoke the registered tool using JSON input and output values.
141    pub fn call_tool(
142        &self,
143        input: Value,
144        connection: McpConnectionTo<Counterpart>,
145    ) -> BoxFuture<'_, Result<Value, Error>> {
146        self.tool.call_tool(input, connection)
147    }
148}
149
150/// Runtime-neutral registry for MCP tools.
151#[derive(Debug)]
152pub struct McpToolRegistry<Counterpart: Role> {
153    instructions: Option<String>,
154    tool_indices: FxHashMap<String, usize>,
155    tools: Vec<RegisteredMcpTool<Counterpart>>,
156    enabled_tools: EnabledTools,
157}
158
159impl<Counterpart: Role> Default for McpToolRegistry<Counterpart> {
160    fn default() -> Self {
161        Self {
162            instructions: None,
163            tool_indices: FxHashMap::default(),
164            tools: Vec::new(),
165            enabled_tools: EnabledTools::default(),
166        }
167    }
168}
169
170impl<Counterpart: Role> McpToolRegistry<Counterpart> {
171    /// Set the server instructions that are provided to the client.
172    pub fn set_instructions(&mut self, instructions: impl ToString) {
173        self.instructions = Some(instructions.to_string());
174    }
175
176    /// Server instructions provided to the client.
177    #[must_use]
178    pub fn instructions(&self) -> Option<&str> {
179        self.instructions.as_deref()
180    }
181
182    /// Register a tool.
183    pub fn register_tool(&mut self, tool: impl McpTool<Counterpart> + 'static) {
184        let registered_tool = RegisteredMcpTool::new(tool);
185        let name = registered_tool.name().to_string();
186
187        if let Some(&index) = self.tool_indices.get(&name) {
188            self.tools[index] = registered_tool;
189        } else {
190            self.tool_indices.insert(name, self.tools.len());
191            self.tools.push(registered_tool);
192        }
193    }
194
195    /// Return all registered tools in registration order.
196    pub fn tools(&self) -> impl Iterator<Item = &RegisteredMcpTool<Counterpart>> {
197        self.tools.iter()
198    }
199
200    /// Return enabled registered tools in registration order.
201    pub fn enabled_tools(&self) -> impl Iterator<Item = &RegisteredMcpTool<Counterpart>> {
202        self.tools
203            .iter()
204            .filter(|tool| self.enabled_tools.is_enabled(tool.name()))
205    }
206
207    /// Return a registered tool by name, even if it is disabled.
208    #[must_use]
209    pub fn tool(&self, name: &str) -> Option<&RegisteredMcpTool<Counterpart>> {
210        self.tool_indices
211            .get(name)
212            .and_then(|&index| self.tools.get(index))
213    }
214
215    /// Return an enabled tool by name.
216    #[must_use]
217    pub fn enabled_tool(&self, name: &str) -> Option<&RegisteredMcpTool<Counterpart>> {
218        self.tool(name)
219            .filter(|tool| self.enabled_tools.is_enabled(tool.name()))
220    }
221
222    /// Check whether a tool is registered.
223    #[must_use]
224    pub fn contains_tool(&self, name: &str) -> bool {
225        self.tool_indices.contains_key(name)
226    }
227
228    /// Disable all tools. After calling this, only tools explicitly enabled
229    /// with [`enable_tool`](Self::enable_tool) will be available.
230    pub fn disable_all_tools(&mut self) {
231        self.enabled_tools = EnabledTools::AllowList(HashSet::new());
232    }
233
234    /// Enable all tools. After calling this, all tools will be available
235    /// except those explicitly disabled with [`disable_tool`](Self::disable_tool).
236    pub fn enable_all_tools(&mut self) {
237        self.enabled_tools = EnabledTools::DenyList(HashSet::new());
238    }
239
240    /// Disable a specific tool by name.
241    ///
242    /// Returns an error if the tool is not registered.
243    pub fn disable_tool(&mut self, name: &str) -> Result<(), Error> {
244        if !self.contains_tool(name) {
245            return Err(Error::invalid_request().data(format!("unknown tool: {name}")));
246        }
247        match &mut self.enabled_tools {
248            EnabledTools::DenyList(deny) => {
249                deny.insert(name.to_string());
250            }
251            EnabledTools::AllowList(allow) => {
252                allow.remove(name);
253            }
254        }
255        Ok(())
256    }
257
258    /// Enable a specific tool by name.
259    ///
260    /// Returns an error if the tool is not registered.
261    pub fn enable_tool(&mut self, name: &str) -> Result<(), Error> {
262        if !self.contains_tool(name) {
263            return Err(Error::invalid_request().data(format!("unknown tool: {name}")));
264        }
265        match &mut self.enabled_tools {
266            EnabledTools::DenyList(deny) => {
267                deny.remove(name);
268            }
269            EnabledTools::AllowList(allow) => {
270                allow.insert(name.to_string());
271            }
272        }
273        Ok(())
274    }
275}
276
277/// Erased version of the MCP tool trait that is dyn-compatible.
278trait ErasedMcpTool<Counterpart: Role>: Send + Sync {
279    fn call_tool(
280        &self,
281        input: Value,
282        connection: McpConnectionTo<Counterpart>,
283    ) -> BoxFuture<'_, Result<Value, Error>>;
284}
285
286fn make_erased_mcp_tool<R, M>(tool: M) -> Arc<dyn ErasedMcpTool<R>>
287where
288    R: Role,
289    M: McpTool<R> + 'static,
290{
291    struct ErasedMcpToolImpl<M> {
292        tool: M,
293    }
294
295    impl<R, M> ErasedMcpTool<R> for ErasedMcpToolImpl<M>
296    where
297        R: Role,
298        M: McpTool<R>,
299    {
300        fn call_tool(
301            &self,
302            input: Value,
303            context: McpConnectionTo<R>,
304        ) -> BoxFuture<'_, Result<Value, Error>> {
305            Box::pin(async move {
306                let input = serde_json::from_value(input).map_err(crate::util::internal_error)?;
307                serde_json::to_value(self.tool.call_tool(input, context).await?)
308                    .map_err(crate::util::internal_error)
309            })
310        }
311    }
312
313    Arc::new(ErasedMcpToolImpl { tool })
314}
315
316fn schema_for_type<T: JsonSchema>() -> Arc<McpToolSchema> {
317    let settings = SchemaSettings::draft2020_12();
318    let generator = settings.into_generator();
319    let schema = generator.into_root_schema_for::<T>();
320    let object = serde_json::to_value(schema).expect("failed to serialize schema");
321    let Value::Object(object) = object else {
322        panic!("Schema serialization produced non-object value: expected JSON object");
323    };
324    Arc::new(object)
325}
326
327fn schema_for_output<T: JsonSchema>() -> Option<Arc<McpToolSchema>> {
328    let schema = schema_for_type::<T>();
329    match schema.get("type") {
330        Some(Value::String(t)) if t == "object" => Some(schema),
331        _ => None,
332    }
333}