agent_client_protocol/mcp_server/
registry.rs1use std::{collections::HashSet, sync::Arc};
4
5use futures::future::BoxFuture;
6use rustc_hash::FxHashMap;
7use schemars::{JsonSchema, generate::SchemaSettings};
8use serde_json::{Map, Value};
9
10use crate::{Error, Role};
11
12use super::{McpConnectionTo, McpTool};
13
/// JSON object (string keys mapped to JSON values) holding a tool's
/// input or output schema.
pub type McpToolSchema = Map<String, Value>;
16
/// Policy describing which tool names are currently enabled.
///
/// The policy is expressed either as a list of names to exclude or as a
/// list of names to include; see the variants for the exact semantics.
#[derive(Clone, Debug)]
pub enum EnabledTools {
    /// Every name is enabled except the ones contained in the set.
    DenyList(HashSet<String>),
    /// Only the names contained in the set are enabled.
    AllowList(HashSet<String>),
}

impl Default for EnabledTools {
    /// Starts with an empty deny list, i.e. every name is enabled.
    fn default() -> Self {
        Self::DenyList(HashSet::default())
    }
}

impl EnabledTools {
    /// Reports whether `name` is enabled under this policy.
    ///
    /// For a deny list the name is enabled when it is absent from the set;
    /// for an allow list it is enabled when it is present.
    #[must_use]
    pub fn is_enabled(&self, name: &str) -> bool {
        // Reduce both variants to "the set" plus "what membership means".
        let (names, listed_means_enabled) = match self {
            Self::AllowList(names) => (names, true),
            Self::DenyList(names) => (names, false),
        };
        names.contains(name) == listed_means_enabled
    }
}
45
46#[derive(Clone, Debug)]
48pub struct McpToolMetadata {
49 name: String,
50 title: Option<String>,
51 description: String,
52 input_schema: Arc<McpToolSchema>,
53 output_schema: Option<Arc<McpToolSchema>>,
54}
55
56impl McpToolMetadata {
57 fn from_tool<R: Role, M: McpTool<R>>(tool: &M) -> Self {
58 Self {
59 name: tool.name(),
60 title: tool.title(),
61 description: tool.description(),
62 input_schema: schema_for_type::<M::Input>(),
63 output_schema: schema_for_output::<M::Output>(),
64 }
65 }
66
67 #[must_use]
69 pub fn name(&self) -> &str {
70 &self.name
71 }
72
73 #[must_use]
75 pub fn title(&self) -> Option<&str> {
76 self.title.as_deref()
77 }
78
79 #[must_use]
81 pub fn description(&self) -> &str {
82 &self.description
83 }
84
85 #[must_use]
87 pub fn input_schema(&self) -> &Arc<McpToolSchema> {
88 &self.input_schema
89 }
90
91 #[must_use]
93 pub fn output_schema(&self) -> Option<&Arc<McpToolSchema>> {
94 self.output_schema.as_ref()
95 }
96}
97
98pub struct RegisteredMcpTool<Counterpart: Role> {
100 metadata: McpToolMetadata,
101 tool: Arc<dyn ErasedMcpTool<Counterpart>>,
102}
103
104impl<Counterpart: Role + std::fmt::Debug> std::fmt::Debug for RegisteredMcpTool<Counterpart> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("RegisteredMcpTool")
107 .field("metadata", &self.metadata)
108 .field("has_structured_output", &self.has_structured_output())
109 .finish_non_exhaustive()
110 }
111}
112
113impl<Counterpart: Role> RegisteredMcpTool<Counterpart> {
114 fn new(tool: impl McpTool<Counterpart> + 'static) -> Self {
115 let metadata = McpToolMetadata::from_tool(&tool);
116 Self {
117 metadata,
118 tool: make_erased_mcp_tool(tool),
119 }
120 }
121
122 #[must_use]
124 pub fn metadata(&self) -> &McpToolMetadata {
125 &self.metadata
126 }
127
128 #[must_use]
130 pub fn name(&self) -> &str {
131 self.metadata.name()
132 }
133
134 #[must_use]
136 pub fn has_structured_output(&self) -> bool {
137 self.metadata.output_schema().is_some()
138 }
139
140 pub fn call_tool(
142 &self,
143 input: Value,
144 connection: McpConnectionTo<Counterpart>,
145 ) -> BoxFuture<'_, Result<Value, Error>> {
146 self.tool.call_tool(input, connection)
147 }
148}
149
150#[derive(Debug)]
152pub struct McpToolRegistry<Counterpart: Role> {
153 instructions: Option<String>,
154 tool_indices: FxHashMap<String, usize>,
155 tools: Vec<RegisteredMcpTool<Counterpart>>,
156 enabled_tools: EnabledTools,
157}
158
159impl<Counterpart: Role> Default for McpToolRegistry<Counterpart> {
160 fn default() -> Self {
161 Self {
162 instructions: None,
163 tool_indices: FxHashMap::default(),
164 tools: Vec::new(),
165 enabled_tools: EnabledTools::default(),
166 }
167 }
168}
169
170impl<Counterpart: Role> McpToolRegistry<Counterpart> {
171 pub fn set_instructions(&mut self, instructions: impl ToString) {
173 self.instructions = Some(instructions.to_string());
174 }
175
176 #[must_use]
178 pub fn instructions(&self) -> Option<&str> {
179 self.instructions.as_deref()
180 }
181
182 pub fn register_tool(&mut self, tool: impl McpTool<Counterpart> + 'static) {
184 let registered_tool = RegisteredMcpTool::new(tool);
185 let name = registered_tool.name().to_string();
186
187 if let Some(&index) = self.tool_indices.get(&name) {
188 self.tools[index] = registered_tool;
189 } else {
190 self.tool_indices.insert(name, self.tools.len());
191 self.tools.push(registered_tool);
192 }
193 }
194
195 pub fn tools(&self) -> impl Iterator<Item = &RegisteredMcpTool<Counterpart>> {
197 self.tools.iter()
198 }
199
200 pub fn enabled_tools(&self) -> impl Iterator<Item = &RegisteredMcpTool<Counterpart>> {
202 self.tools
203 .iter()
204 .filter(|tool| self.enabled_tools.is_enabled(tool.name()))
205 }
206
207 #[must_use]
209 pub fn tool(&self, name: &str) -> Option<&RegisteredMcpTool<Counterpart>> {
210 self.tool_indices
211 .get(name)
212 .and_then(|&index| self.tools.get(index))
213 }
214
215 #[must_use]
217 pub fn enabled_tool(&self, name: &str) -> Option<&RegisteredMcpTool<Counterpart>> {
218 self.tool(name)
219 .filter(|tool| self.enabled_tools.is_enabled(tool.name()))
220 }
221
222 #[must_use]
224 pub fn contains_tool(&self, name: &str) -> bool {
225 self.tool_indices.contains_key(name)
226 }
227
228 pub fn disable_all_tools(&mut self) {
231 self.enabled_tools = EnabledTools::AllowList(HashSet::new());
232 }
233
234 pub fn enable_all_tools(&mut self) {
237 self.enabled_tools = EnabledTools::DenyList(HashSet::new());
238 }
239
240 pub fn disable_tool(&mut self, name: &str) -> Result<(), Error> {
244 if !self.contains_tool(name) {
245 return Err(Error::invalid_request().data(format!("unknown tool: {name}")));
246 }
247 match &mut self.enabled_tools {
248 EnabledTools::DenyList(deny) => {
249 deny.insert(name.to_string());
250 }
251 EnabledTools::AllowList(allow) => {
252 allow.remove(name);
253 }
254 }
255 Ok(())
256 }
257
258 pub fn enable_tool(&mut self, name: &str) -> Result<(), Error> {
262 if !self.contains_tool(name) {
263 return Err(Error::invalid_request().data(format!("unknown tool: {name}")));
264 }
265 match &mut self.enabled_tools {
266 EnabledTools::DenyList(deny) => {
267 deny.remove(name);
268 }
269 EnabledTools::AllowList(allow) => {
270 allow.insert(name.to_string());
271 }
272 }
273 Ok(())
274 }
275}
276
/// Type-erased view of a tool that takes and returns untyped JSON values,
/// allowing tools with different input/output types to be stored behind a
/// single `Arc<dyn ErasedMcpTool<_>>`.
trait ErasedMcpTool<Counterpart: Role>: Send + Sync {
    /// Invokes the tool with the JSON `input` over `connection`, returning a
    /// boxed future that resolves to the JSON result or an [`Error`].
    fn call_tool(
        &self,
        input: Value,
        connection: McpConnectionTo<Counterpart>,
    ) -> BoxFuture<'_, Result<Value, Error>>;
}
285
286fn make_erased_mcp_tool<R, M>(tool: M) -> Arc<dyn ErasedMcpTool<R>>
287where
288 R: Role,
289 M: McpTool<R> + 'static,
290{
291 struct ErasedMcpToolImpl<M> {
292 tool: M,
293 }
294
295 impl<R, M> ErasedMcpTool<R> for ErasedMcpToolImpl<M>
296 where
297 R: Role,
298 M: McpTool<R>,
299 {
300 fn call_tool(
301 &self,
302 input: Value,
303 context: McpConnectionTo<R>,
304 ) -> BoxFuture<'_, Result<Value, Error>> {
305 Box::pin(async move {
306 let input = serde_json::from_value(input).map_err(crate::util::internal_error)?;
307 serde_json::to_value(self.tool.call_tool(input, context).await?)
308 .map_err(crate::util::internal_error)
309 })
310 }
311 }
312
313 Arc::new(ErasedMcpToolImpl { tool })
314}
315
316fn schema_for_type<T: JsonSchema>() -> Arc<McpToolSchema> {
317 let settings = SchemaSettings::draft2020_12();
318 let generator = settings.into_generator();
319 let schema = generator.into_root_schema_for::<T>();
320 let object = serde_json::to_value(schema).expect("failed to serialize schema");
321 let Value::Object(object) = object else {
322 panic!("Schema serialization produced non-object value: expected JSON object");
323 };
324 Arc::new(object)
325}
326
327fn schema_for_output<T: JsonSchema>() -> Option<Arc<McpToolSchema>> {
328 let schema = schema_for_type::<T>();
329 match schema.get("type") {
330 Some(Value::String(t)) if t == "object" => Some(schema),
331 _ => None,
332 }
333}