1use crate::context::ToolContext;
4use crate::error::ToolError;
5use crate::fmt::TextFormat;
6use crate::fmt::TextOptions;
7use crate::schema::mcp_schema;
8use crate::tool::Tool;
9use crate::tool::ToolCodec;
10use futures::future::BoxFuture;
11use schemars::Schema;
12use serde_json::Value;
13use std::any::TypeId;
14use std::collections::HashMap;
15use std::collections::HashSet;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
/// Result of a formatted tool call: the JSON payload together with an
/// optional text rendering of the same output.
#[derive(Debug, Clone)]
pub struct FormattedResult {
    /// JSON-serialized tool output.
    pub data: Value,
    /// Text rendering of the output, when one was produced.
    pub text: Option<String>,
}
28
/// Object-safe, type-erased view of a tool so that tools with different
/// input/output types can be stored and dispatched uniformly via JSON.
pub trait ErasedTool: Send + Sync {
    /// Name under which the tool is registered and dispatched.
    fn name(&self) -> &'static str;

    /// Human-readable description of the tool.
    fn description(&self) -> &'static str;

    /// Schema describing the JSON arguments accepted by the tool.
    fn input_schema(&self) -> Schema;

    /// Schema describing the JSON output, or `None` when no output schema
    /// is available.
    fn output_schema(&self) -> Option<Schema>;

    /// Invokes the tool with raw JSON arguments and resolves to its JSON
    /// output.
    ///
    /// The returned future is `'static`, so it does not borrow `self` or
    /// `ctx`.
    fn call_json(
        &self,
        args: Value,
        ctx: &ToolContext,
    ) -> BoxFuture<'static, Result<Value, ToolError>>;

    /// Like [`ErasedTool::call_json`], but resolves to a
    /// [`FormattedResult`]; `text_opts` is made available for producing the
    /// text rendering.
    fn call_json_formatted(
        &self,
        args: Value,
        ctx: &ToolContext,
        text_opts: &TextOptions,
    ) -> BoxFuture<'static, Result<FormattedResult, ToolError>>;

    /// `TypeId` of the concrete tool type behind this erased wrapper.
    ///
    /// NOTE(review): `std::any::Any` also provides a method named `type_id`
    /// for every `'static` type; if `Any` is in scope, method-call syntax on
    /// a smart pointer may resolve to that one instead of this method.
    fn type_id(&self) -> TypeId;
}
65
/// Collection of type-erased tools, addressable by name and by the `TypeId`
/// of the concrete tool type.
pub struct ToolRegistry {
    /// Tool name -> erased tool.
    map: HashMap<String, Arc<dyn ErasedTool>>,
    /// Concrete tool `TypeId` -> name the tool was registered under.
    by_type: HashMap<TypeId, String>,
}
71
72impl ToolRegistry {
73 pub fn builder() -> ToolRegistryBuilder {
75 ToolRegistryBuilder::default()
76 }
77
78 pub fn list_names(&self) -> Vec<String> {
80 self.map.keys().cloned().collect()
81 }
82
83 pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool>> {
85 self.map.get(name)
86 }
87
88 #[must_use]
92 pub fn subset<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> Self {
93 let allowed: HashSet<&str> = names.into_iter().collect();
94
95 let mut map = HashMap::new();
97 for (k, v) in &self.map {
98 if allowed.contains(k.as_str()) {
99 map.insert(k.clone(), Arc::clone(v));
100 }
101 }
102
103 let mut by_type = HashMap::new();
106 for (type_id, name) in &self.by_type {
107 if allowed.contains(name.as_str()) {
108 by_type.insert(*type_id, name.clone());
109 }
110 }
111
112 Self { map, by_type }
113 }
114
115 pub async fn dispatch_json(
117 &self,
118 name: &str,
119 args: Value,
120 ctx: &ToolContext,
121 ) -> Result<Value, ToolError> {
122 let entry = self
123 .map
124 .get(name)
125 .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {name}")))?;
126 entry.call_json(args, ctx).await
127 }
128
129 pub async fn dispatch_json_formatted(
135 &self,
136 name: &str,
137 args: Value,
138 ctx: &ToolContext,
139 text_opts: &TextOptions,
140 ) -> Result<FormattedResult, ToolError> {
141 let entry = self
142 .map
143 .get(name)
144 .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {name}")))?;
145 entry.call_json_formatted(args, ctx, text_opts).await
146 }
147
148 pub fn handle<T: Tool>(&self) -> Result<ToolHandle<T>, ToolError> {
152 let type_id = TypeId::of::<T>();
153 self.by_type.get(&type_id).ok_or_else(|| {
154 ToolError::invalid_input(format!(
155 "Tool type not registered: {}",
156 std::any::type_name::<T>()
157 ))
158 })?;
159 Ok(ToolHandle {
160 _marker: PhantomData,
161 })
162 }
163
164 pub fn contains(&self, name: &str) -> bool {
166 self.map.contains_key(name)
167 }
168
169 pub fn len(&self) -> usize {
171 self.map.len()
172 }
173
174 pub fn is_empty(&self) -> bool {
176 self.map.is_empty()
177 }
178
179 pub fn iter_erased(&self) -> Vec<Arc<dyn ErasedTool>> {
184 self.map.values().cloned().collect()
185 }
186
187 pub fn merge_all(regs: impl IntoIterator<Item = Self>) -> Self {
193 let mut builder = Self::builder();
194 for reg in regs {
195 for erased in reg.iter_erased() {
196 builder = builder.register_erased(erased);
197 }
198 }
199 builder.finish()
200 }
201}
202
/// Builder that accumulates tool registrations before producing a
/// [`ToolRegistry`].
#[derive(Default)]
pub struct ToolRegistryBuilder {
    /// Pending registrations in insertion order: (tool name, concrete tool
    /// `TypeId`, erased tool).
    items: Vec<(String, TypeId, Arc<dyn ErasedTool>)>,
}
208
209impl ToolRegistryBuilder {
210 #[must_use]
219 pub fn register<T, C>(mut self, tool: T) -> Self
220 where
221 T: Tool + Clone + 'static,
222 C: ToolCodec<T> + 'static,
223 T::Output: TextFormat,
224 {
225 struct Impl<T: Tool + Clone, C: ToolCodec<T>> {
226 tool: T,
227 _codec: PhantomData<C>,
228 }
229
230 impl<T: Tool + Clone, C: ToolCodec<T>> ErasedTool for Impl<T, C>
231 where
232 T::Output: TextFormat,
233 {
234 fn name(&self) -> &'static str {
235 T::NAME
236 }
237
238 fn description(&self) -> &'static str {
239 T::DESCRIPTION
240 }
241
242 fn input_schema(&self) -> Schema {
243 mcp_schema::cached_schema_for::<C::WireIn>()
244 .as_ref()
245 .clone()
246 }
247
248 fn output_schema(&self) -> Option<Schema> {
249 match mcp_schema::cached_output_schema_for::<C::WireOut>() {
250 Ok(arc) => Some(arc.as_ref().clone()),
251 Err(_) => None,
252 }
253 }
254
255 fn call_json(
256 &self,
257 args: Value,
258 ctx: &ToolContext,
259 ) -> BoxFuture<'static, Result<Value, ToolError>> {
260 let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
261 let ctx = ctx.clone();
262 let tool = self.tool.clone();
263
264 match wire_in {
265 Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
266 Ok(wire) => match C::decode(wire) {
267 Err(e) => Box::pin(async move { Err(e) }),
268 Ok(native_in) => {
269 let fut = tool.call(native_in, &ctx);
270 Box::pin(async move {
271 let out = fut.await?;
272 let wired = C::encode(out)?;
273 serde_json::to_value(wired)
274 .map_err(|e| ToolError::internal(e.to_string()))
275 })
276 }
277 },
278 }
279 }
280
281 fn call_json_formatted(
282 &self,
283 args: Value,
284 ctx: &ToolContext,
285 text_opts: &TextOptions,
286 ) -> BoxFuture<'static, Result<FormattedResult, ToolError>> {
287 let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
288 let ctx = ctx.clone();
289 let tool = self.tool.clone();
290 let text_opts = text_opts.clone();
291
292 match wire_in {
293 Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
294 Ok(wire) => match C::decode(wire) {
295 Err(e) => Box::pin(async move { Err(e) }),
296 Ok(native_in) => {
297 let fut = tool.call(native_in, &ctx);
298 Box::pin(async move {
299 let out = fut.await?;
300 let text = out.fmt_text(&text_opts);
302 let wired = C::encode(out)?;
304 let data = serde_json::to_value(&wired)
305 .map_err(|e| ToolError::internal(e.to_string()))?;
306 Ok(FormattedResult {
307 data,
308 text: Some(text),
309 })
310 })
311 }
312 },
313 }
314 }
315
316 fn type_id(&self) -> TypeId {
317 TypeId::of::<T>()
318 }
319 }
320
321 let erased: Arc<dyn ErasedTool> = Arc::new(Impl::<T, C> {
322 tool,
323 _codec: PhantomData,
324 });
325 self.items
326 .push((T::NAME.to_string(), TypeId::of::<T>(), erased));
327 self
328 }
329
330 #[must_use]
335 pub fn register_erased(mut self, erased: Arc<dyn ErasedTool>) -> Self {
336 let name = erased.name().to_string();
337 let type_id = erased.type_id();
338 self.items.push((name, type_id, erased));
339 self
340 }
341
342 pub fn finish(self) -> ToolRegistry {
344 let mut map = HashMap::new();
345 let mut by_type = HashMap::new();
346 for (name, type_id, erased) in self.items {
347 by_type.insert(type_id, name.clone());
348 map.insert(name, erased);
349 }
350 ToolRegistry { map, by_type }
351 }
352}
353
/// Typed token for tool type `T`, handed out by `ToolRegistry::handle`.
///
/// It stores no tool instance, only a marker for `T`.
pub struct ToolHandle<T: Tool> {
    /// Ties the handle to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
360
361impl<T: Tool> ToolHandle<T> {
362 pub async fn call(
364 &self,
365 tool: &T,
366 input: T::Input,
367 ctx: &ToolContext,
368 ) -> Result<T::Output, ToolError> {
369 tool.call(input, ctx).await
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that greets its string input.
    #[derive(Clone)]
    struct TestTool;

    impl Tool for TestTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "test_tool";
        const DESCRIPTION: &'static str = "A test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(format!("Hello, {input}!")) })
        }
    }

    /// Builds a registry that contains only `TestTool` with the `()` codec.
    fn single_tool_registry() -> ToolRegistry {
        ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish()
    }

    #[test]
    fn test_registry_builder() {
        let reg = single_tool_registry();

        assert!(reg.contains("test_tool"));
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
    }

    #[test]
    fn test_registry_list_names() {
        let reg = single_tool_registry();

        assert_eq!(reg.list_names(), vec!["test_tool"]);
    }

    #[test]
    fn test_registry_subset() {
        let reg = single_tool_registry();

        let kept = reg.subset(["test_tool"]);
        assert!(kept.contains("test_tool"));

        let dropped = reg.subset(["nonexistent"]);
        assert!(dropped.is_empty());
    }

    #[test]
    fn test_tool_handle() {
        let reg = single_tool_registry();

        assert!(reg.handle::<TestTool>().is_ok());
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted() {
        let reg = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let result = reg
            .dispatch_json_formatted("test_tool", serde_json::json!("World"), &ctx, &opts)
            .await;

        assert!(result.is_ok());
        let formatted = result.unwrap();
        assert_eq!(formatted.data, serde_json::json!("Hello, World!"));
        assert!(formatted.text.is_some());
        assert!(formatted.text.unwrap().contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted_unknown_tool() {
        let reg = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let result = reg
            .dispatch_json_formatted("nonexistent", serde_json::json!("test"), &ctx, &opts)
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_iter_erased() {
        let tools = single_tool_registry().iter_erased();

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name(), "test_tool");
    }

    #[test]
    fn test_register_erased_roundtrip() {
        let source = single_tool_registry();

        // Move the erased tool from one registry into a fresh one.
        let erased = source.iter_erased().into_iter().next().unwrap();
        let rebuilt = ToolRegistry::builder().register_erased(erased).finish();

        assert_eq!(rebuilt.len(), 1);
        assert!(rebuilt.contains("test_tool"));
        assert_eq!(rebuilt.get("test_tool").unwrap().name(), "test_tool");
    }

    #[test]
    fn test_merge_all_combines_registries() {
        let first = single_tool_registry();
        let second = single_tool_registry();

        let merged = ToolRegistry::merge_all(vec![first, second]);

        // Both registries hold the same tool name, so only one entry remains.
        assert_eq!(merged.len(), 1);
        assert!(merged.contains("test_tool"));
    }

    #[test]
    fn test_merge_all_empty() {
        let merged = ToolRegistry::merge_all(Vec::<ToolRegistry>::new());
        assert!(merged.is_empty());
    }

    #[test]
    fn test_merge_all_preserves_subset() {
        let merged = ToolRegistry::merge_all(vec![single_tool_registry()]);
        let kept = merged.subset(["test_tool"]);

        assert_eq!(kept.len(), 1);
        assert!(kept.contains("test_tool"));
    }
}