1use crate::context::ToolContext;
4use crate::error::ToolError;
5use crate::fmt::{TextFormat, TextOptions};
6use crate::schema::mcp_schema;
7use crate::tool::{Tool, ToolCodec};
8use futures::future::BoxFuture;
9use schemars::Schema;
10use serde_json::Value;
11use std::any::TypeId;
12use std::collections::{HashMap, HashSet};
13use std::marker::PhantomData;
14use std::sync::Arc;
15
16#[derive(Debug, Clone)]
18pub struct FormattedResult {
19 pub data: Value,
21 pub text: Option<String>,
24}
25
26pub trait ErasedTool: Send + Sync {
28 fn name(&self) -> &'static str;
30
31 fn description(&self) -> &'static str;
33
34 fn input_schema(&self) -> Schema;
36
37 fn output_schema(&self) -> Option<Schema>;
39
40 fn call_json(
42 &self,
43 args: Value,
44 ctx: &ToolContext,
45 ) -> BoxFuture<'static, Result<Value, ToolError>>;
46
47 fn call_json_formatted(
53 &self,
54 args: Value,
55 ctx: &ToolContext,
56 text_opts: &TextOptions,
57 ) -> BoxFuture<'static, Result<FormattedResult, ToolError>>;
58
59 fn type_id(&self) -> TypeId;
61}
62
63pub struct ToolRegistry {
65 map: HashMap<String, Arc<dyn ErasedTool>>,
66 by_type: HashMap<TypeId, String>,
67}
68
69impl ToolRegistry {
70 pub fn builder() -> ToolRegistryBuilder {
72 ToolRegistryBuilder::default()
73 }
74
75 pub fn list_names(&self) -> Vec<String> {
77 self.map.keys().cloned().collect()
78 }
79
80 pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool>> {
82 self.map.get(name)
83 }
84
85 pub fn subset<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> ToolRegistry {
89 let allowed: HashSet<&str> = names.into_iter().collect();
90
91 let mut map = HashMap::new();
93 for (k, v) in &self.map {
94 if allowed.contains(k.as_str()) {
95 map.insert(k.clone(), v.clone());
96 }
97 }
98
99 let mut by_type = HashMap::new();
102 for (type_id, name) in &self.by_type {
103 if allowed.contains(name.as_str()) {
104 by_type.insert(*type_id, name.clone());
105 }
106 }
107
108 ToolRegistry { map, by_type }
109 }
110
111 pub async fn dispatch_json(
113 &self,
114 name: &str,
115 args: Value,
116 ctx: &ToolContext,
117 ) -> Result<Value, ToolError> {
118 let entry = self
119 .map
120 .get(name)
121 .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
122 entry.call_json(args, ctx).await
123 }
124
125 pub async fn dispatch_json_formatted(
131 &self,
132 name: &str,
133 args: Value,
134 ctx: &ToolContext,
135 text_opts: &TextOptions,
136 ) -> Result<FormattedResult, ToolError> {
137 let entry = self
138 .map
139 .get(name)
140 .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
141 entry.call_json_formatted(args, ctx, text_opts).await
142 }
143
144 pub fn handle<T: Tool>(&self) -> Result<ToolHandle<T>, ToolError> {
148 let type_id = TypeId::of::<T>();
149 self.by_type.get(&type_id).ok_or_else(|| {
150 ToolError::invalid_input(format!(
151 "Tool type not registered: {}",
152 std::any::type_name::<T>()
153 ))
154 })?;
155 Ok(ToolHandle {
156 _marker: PhantomData,
157 })
158 }
159
160 pub fn contains(&self, name: &str) -> bool {
162 self.map.contains_key(name)
163 }
164
165 pub fn len(&self) -> usize {
167 self.map.len()
168 }
169
170 pub fn is_empty(&self) -> bool {
172 self.map.is_empty()
173 }
174
175 pub fn iter_erased(&self) -> Vec<Arc<dyn ErasedTool>> {
180 self.map.values().cloned().collect()
181 }
182
183 pub fn merge_all(regs: impl IntoIterator<Item = ToolRegistry>) -> ToolRegistry {
189 let mut builder = ToolRegistry::builder();
190 for reg in regs {
191 for erased in reg.iter_erased() {
192 builder = builder.register_erased(erased);
193 }
194 }
195 builder.finish()
196 }
197}
198
199#[derive(Default)]
201pub struct ToolRegistryBuilder {
202 items: Vec<(String, TypeId, Arc<dyn ErasedTool>)>,
203}
204
205impl ToolRegistryBuilder {
206 pub fn register<T, C>(mut self, tool: T) -> Self
215 where
216 T: Tool + Clone + 'static,
217 C: ToolCodec<T> + 'static,
218 T::Output: TextFormat,
219 {
220 struct Impl<T: Tool + Clone, C: ToolCodec<T>> {
221 tool: T,
222 _codec: PhantomData<C>,
223 }
224
225 impl<T: Tool + Clone, C: ToolCodec<T>> ErasedTool for Impl<T, C>
226 where
227 T::Output: TextFormat,
228 {
229 fn name(&self) -> &'static str {
230 T::NAME
231 }
232
233 fn description(&self) -> &'static str {
234 T::DESCRIPTION
235 }
236
237 fn input_schema(&self) -> Schema {
238 mcp_schema::cached_schema_for::<C::WireIn>()
239 .as_ref()
240 .clone()
241 }
242
243 fn output_schema(&self) -> Option<Schema> {
244 match mcp_schema::cached_output_schema_for::<C::WireOut>() {
245 Ok(arc) => Some(arc.as_ref().clone()),
246 Err(_) => None,
247 }
248 }
249
250 fn call_json(
251 &self,
252 args: Value,
253 ctx: &ToolContext,
254 ) -> BoxFuture<'static, Result<Value, ToolError>> {
255 let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
256 let ctx = ctx.clone();
257 let tool = self.tool.clone();
258
259 match wire_in {
260 Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
261 Ok(wire) => match C::decode(wire) {
262 Err(e) => Box::pin(async move { Err(e) }),
263 Ok(native_in) => {
264 let fut = tool.call(native_in, &ctx);
265 Box::pin(async move {
266 let out = fut.await?;
267 let wired = C::encode(out)?;
268 serde_json::to_value(wired)
269 .map_err(|e| ToolError::internal(e.to_string()))
270 })
271 }
272 },
273 }
274 }
275
276 fn call_json_formatted(
277 &self,
278 args: Value,
279 ctx: &ToolContext,
280 text_opts: &TextOptions,
281 ) -> BoxFuture<'static, Result<FormattedResult, ToolError>> {
282 let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
283 let ctx = ctx.clone();
284 let tool = self.tool.clone();
285 let text_opts = text_opts.clone();
286
287 match wire_in {
288 Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
289 Ok(wire) => match C::decode(wire) {
290 Err(e) => Box::pin(async move { Err(e) }),
291 Ok(native_in) => {
292 let fut = tool.call(native_in, &ctx);
293 Box::pin(async move {
294 let out = fut.await?;
295 let text = out.fmt_text(&text_opts);
297 let wired = C::encode(out)?;
299 let data = serde_json::to_value(&wired)
300 .map_err(|e| ToolError::internal(e.to_string()))?;
301 Ok(FormattedResult {
302 data,
303 text: Some(text),
304 })
305 })
306 }
307 },
308 }
309 }
310
311 fn type_id(&self) -> TypeId {
312 TypeId::of::<T>()
313 }
314 }
315
316 let erased: Arc<dyn ErasedTool> = Arc::new(Impl::<T, C> {
317 tool,
318 _codec: PhantomData,
319 });
320 self.items
321 .push((T::NAME.to_string(), TypeId::of::<T>(), erased));
322 self
323 }
324
325 pub fn register_erased(mut self, erased: Arc<dyn ErasedTool>) -> Self {
330 let name = erased.name().to_string();
331 let type_id = erased.type_id();
332 self.items.push((name, type_id, erased));
333 self
334 }
335
336 pub fn finish(self) -> ToolRegistry {
338 let mut map = HashMap::new();
339 let mut by_type = HashMap::new();
340 for (name, type_id, erased) in self.items {
341 by_type.insert(type_id, name.clone());
342 map.insert(name, erased);
343 }
344 ToolRegistry { map, by_type }
345 }
346}
347
348pub struct ToolHandle<T: Tool> {
352 _marker: PhantomData<T>,
353}
354
355impl<T: Tool> ToolHandle<T> {
356 pub async fn call(
358 &self,
359 tool: &T,
360 input: T::Input,
361 ctx: &ToolContext,
362 ) -> Result<T::Output, ToolError> {
363 tool.call(input, ctx).await
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal tool that greets whatever string it receives.
    #[derive(Clone)]
    struct TestTool;

    impl Tool for TestTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "test_tool";
        const DESCRIPTION: &'static str = "A test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(format!("Hello, {}!", input)) })
        }
    }

    // Builds a registry holding only `TestTool`, with `()` as its codec.
    fn single_tool_registry() -> ToolRegistry {
        ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish()
    }

    #[test]
    fn test_registry_builder() {
        let registry = single_tool_registry();

        assert!(registry.contains("test_tool"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_registry_list_names() {
        assert_eq!(single_tool_registry().list_names(), vec!["test_tool"]);
    }

    #[test]
    fn test_registry_subset() {
        let registry = single_tool_registry();

        assert!(registry.subset(["test_tool"]).contains("test_tool"));
        assert!(registry.subset(["nonexistent"]).is_empty());
    }

    #[test]
    fn test_tool_handle() {
        assert!(single_tool_registry().handle::<TestTool>().is_ok());
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted() {
        let registry = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let formatted = registry
            .dispatch_json_formatted("test_tool", serde_json::json!("World"), &ctx, &opts)
            .await
            .expect("dispatch to a registered tool should succeed");

        assert_eq!(formatted.data, serde_json::json!("Hello, World!"));
        let text = formatted
            .text
            .expect("formatted dispatch should produce text");
        assert!(text.contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted_unknown_tool() {
        let registry = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let outcome = registry
            .dispatch_json_formatted("nonexistent", serde_json::json!("test"), &ctx, &opts)
            .await;

        assert!(outcome.is_err());
    }

    #[test]
    fn test_iter_erased() {
        let erased = single_tool_registry().iter_erased();

        assert_eq!(erased.len(), 1);
        assert_eq!(erased[0].name(), "test_tool");
    }

    #[test]
    fn test_register_erased_roundtrip() {
        let erased = single_tool_registry()
            .iter_erased()
            .into_iter()
            .next()
            .unwrap();
        let rebuilt = ToolRegistry::builder().register_erased(erased).finish();

        assert_eq!(rebuilt.len(), 1);
        assert!(rebuilt.contains("test_tool"));
        assert_eq!(rebuilt.get("test_tool").unwrap().name(), "test_tool");
    }

    #[test]
    fn test_merge_all_combines_registries() {
        let merged = ToolRegistry::merge_all(vec![single_tool_registry(), single_tool_registry()]);

        // Both inputs carry the same tool name, so they collapse into one entry.
        assert_eq!(merged.len(), 1);
        assert!(merged.contains("test_tool"));
    }

    #[test]
    fn test_merge_all_empty() {
        assert!(ToolRegistry::merge_all(Vec::<ToolRegistry>::new()).is_empty());
    }

    #[test]
    fn test_merge_all_preserves_subset() {
        let subset = ToolRegistry::merge_all(vec![single_tool_registry()]).subset(["test_tool"]);

        assert_eq!(subset.len(), 1);
        assert!(subset.contains("test_tool"));
    }
}