1use crate::context::ToolContext;
4use crate::error::ToolError;
5use crate::fmt::TextFormat;
6use crate::fmt::TextOptions;
7use crate::schema::mcp_schema;
8use crate::tool::Tool;
9use crate::tool::ToolCodec;
10use futures::future::BoxFuture;
11use schemars::Schema;
12use serde_json::Value;
13use std::any::TypeId;
14use std::collections::HashMap;
15use std::collections::HashSet;
16use std::marker::PhantomData;
17use std::sync::Arc;
18
19#[derive(Debug, Clone)]
21pub struct FormattedResult {
22 pub data: Value,
24 pub text: Option<String>,
27}
28
29pub trait ErasedTool: Send + Sync {
31 fn name(&self) -> &'static str;
33
34 fn description(&self) -> &'static str;
36
37 fn input_schema(&self) -> Schema;
39
40 fn output_schema(&self) -> Option<Schema>;
42
43 fn call_json(
45 &self,
46 args: Value,
47 ctx: &ToolContext,
48 ) -> BoxFuture<'static, Result<Value, ToolError>>;
49
50 fn call_json_formatted(
56 &self,
57 args: Value,
58 ctx: &ToolContext,
59 text_opts: &TextOptions,
60 ) -> BoxFuture<'static, Result<FormattedResult, ToolError>>;
61
62 fn type_id(&self) -> TypeId;
64}
65
66pub struct ToolRegistry {
68 map: HashMap<String, Arc<dyn ErasedTool>>,
69 by_type: HashMap<TypeId, String>,
70}
71
72impl ToolRegistry {
73 pub fn builder() -> ToolRegistryBuilder {
75 ToolRegistryBuilder::default()
76 }
77
78 pub fn list_names(&self) -> Vec<String> {
80 self.map.keys().cloned().collect()
81 }
82
83 pub fn get(&self, name: &str) -> Option<&Arc<dyn ErasedTool>> {
85 self.map.get(name)
86 }
87
88 pub fn subset<'a>(&self, names: impl IntoIterator<Item = &'a str>) -> ToolRegistry {
92 let allowed: HashSet<&str> = names.into_iter().collect();
93
94 let mut map = HashMap::new();
96 for (k, v) in &self.map {
97 if allowed.contains(k.as_str()) {
98 map.insert(k.clone(), v.clone());
99 }
100 }
101
102 let mut by_type = HashMap::new();
105 for (type_id, name) in &self.by_type {
106 if allowed.contains(name.as_str()) {
107 by_type.insert(*type_id, name.clone());
108 }
109 }
110
111 ToolRegistry { map, by_type }
112 }
113
114 pub async fn dispatch_json(
116 &self,
117 name: &str,
118 args: Value,
119 ctx: &ToolContext,
120 ) -> Result<Value, ToolError> {
121 let entry = self
122 .map
123 .get(name)
124 .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
125 entry.call_json(args, ctx).await
126 }
127
128 pub async fn dispatch_json_formatted(
134 &self,
135 name: &str,
136 args: Value,
137 ctx: &ToolContext,
138 text_opts: &TextOptions,
139 ) -> Result<FormattedResult, ToolError> {
140 let entry = self
141 .map
142 .get(name)
143 .ok_or_else(|| ToolError::invalid_input(format!("Unknown tool: {}", name)))?;
144 entry.call_json_formatted(args, ctx, text_opts).await
145 }
146
147 pub fn handle<T: Tool>(&self) -> Result<ToolHandle<T>, ToolError> {
151 let type_id = TypeId::of::<T>();
152 self.by_type.get(&type_id).ok_or_else(|| {
153 ToolError::invalid_input(format!(
154 "Tool type not registered: {}",
155 std::any::type_name::<T>()
156 ))
157 })?;
158 Ok(ToolHandle {
159 _marker: PhantomData,
160 })
161 }
162
163 pub fn contains(&self, name: &str) -> bool {
165 self.map.contains_key(name)
166 }
167
168 pub fn len(&self) -> usize {
170 self.map.len()
171 }
172
173 pub fn is_empty(&self) -> bool {
175 self.map.is_empty()
176 }
177
178 pub fn iter_erased(&self) -> Vec<Arc<dyn ErasedTool>> {
183 self.map.values().cloned().collect()
184 }
185
186 pub fn merge_all(regs: impl IntoIterator<Item = ToolRegistry>) -> ToolRegistry {
192 let mut builder = ToolRegistry::builder();
193 for reg in regs {
194 for erased in reg.iter_erased() {
195 builder = builder.register_erased(erased);
196 }
197 }
198 builder.finish()
199 }
200}
201
202#[derive(Default)]
204pub struct ToolRegistryBuilder {
205 items: Vec<(String, TypeId, Arc<dyn ErasedTool>)>,
206}
207
208impl ToolRegistryBuilder {
209 pub fn register<T, C>(mut self, tool: T) -> Self
218 where
219 T: Tool + Clone + 'static,
220 C: ToolCodec<T> + 'static,
221 T::Output: TextFormat,
222 {
223 struct Impl<T: Tool + Clone, C: ToolCodec<T>> {
224 tool: T,
225 _codec: PhantomData<C>,
226 }
227
228 impl<T: Tool + Clone, C: ToolCodec<T>> ErasedTool for Impl<T, C>
229 where
230 T::Output: TextFormat,
231 {
232 fn name(&self) -> &'static str {
233 T::NAME
234 }
235
236 fn description(&self) -> &'static str {
237 T::DESCRIPTION
238 }
239
240 fn input_schema(&self) -> Schema {
241 mcp_schema::cached_schema_for::<C::WireIn>()
242 .as_ref()
243 .clone()
244 }
245
246 fn output_schema(&self) -> Option<Schema> {
247 match mcp_schema::cached_output_schema_for::<C::WireOut>() {
248 Ok(arc) => Some(arc.as_ref().clone()),
249 Err(_) => None,
250 }
251 }
252
253 fn call_json(
254 &self,
255 args: Value,
256 ctx: &ToolContext,
257 ) -> BoxFuture<'static, Result<Value, ToolError>> {
258 let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
259 let ctx = ctx.clone();
260 let tool = self.tool.clone();
261
262 match wire_in {
263 Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
264 Ok(wire) => match C::decode(wire) {
265 Err(e) => Box::pin(async move { Err(e) }),
266 Ok(native_in) => {
267 let fut = tool.call(native_in, &ctx);
268 Box::pin(async move {
269 let out = fut.await?;
270 let wired = C::encode(out)?;
271 serde_json::to_value(wired)
272 .map_err(|e| ToolError::internal(e.to_string()))
273 })
274 }
275 },
276 }
277 }
278
279 fn call_json_formatted(
280 &self,
281 args: Value,
282 ctx: &ToolContext,
283 text_opts: &TextOptions,
284 ) -> BoxFuture<'static, Result<FormattedResult, ToolError>> {
285 let wire_in: Result<C::WireIn, _> = serde_json::from_value(args);
286 let ctx = ctx.clone();
287 let tool = self.tool.clone();
288 let text_opts = text_opts.clone();
289
290 match wire_in {
291 Err(e) => Box::pin(async move { Err(ToolError::invalid_input(e.to_string())) }),
292 Ok(wire) => match C::decode(wire) {
293 Err(e) => Box::pin(async move { Err(e) }),
294 Ok(native_in) => {
295 let fut = tool.call(native_in, &ctx);
296 Box::pin(async move {
297 let out = fut.await?;
298 let text = out.fmt_text(&text_opts);
300 let wired = C::encode(out)?;
302 let data = serde_json::to_value(&wired)
303 .map_err(|e| ToolError::internal(e.to_string()))?;
304 Ok(FormattedResult {
305 data,
306 text: Some(text),
307 })
308 })
309 }
310 },
311 }
312 }
313
314 fn type_id(&self) -> TypeId {
315 TypeId::of::<T>()
316 }
317 }
318
319 let erased: Arc<dyn ErasedTool> = Arc::new(Impl::<T, C> {
320 tool,
321 _codec: PhantomData,
322 });
323 self.items
324 .push((T::NAME.to_string(), TypeId::of::<T>(), erased));
325 self
326 }
327
328 pub fn register_erased(mut self, erased: Arc<dyn ErasedTool>) -> Self {
333 let name = erased.name().to_string();
334 let type_id = erased.type_id();
335 self.items.push((name, type_id, erased));
336 self
337 }
338
339 pub fn finish(self) -> ToolRegistry {
341 let mut map = HashMap::new();
342 let mut by_type = HashMap::new();
343 for (name, type_id, erased) in self.items {
344 by_type.insert(type_id, name.clone());
345 map.insert(name, erased);
346 }
347 ToolRegistry { map, by_type }
348 }
349}
350
351pub struct ToolHandle<T: Tool> {
355 _marker: PhantomData<T>,
356}
357
358impl<T: Tool> ToolHandle<T> {
359 pub async fn call(
361 &self,
362 tool: &T,
363 input: T::Input,
364 ctx: &ToolContext,
365 ) -> Result<T::Output, ToolError> {
366 tool.call(input, ctx).await
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that greets its string input.
    #[derive(Clone)]
    struct TestTool;

    impl Tool for TestTool {
        type Input = String;
        type Output = String;
        const NAME: &'static str = "test_tool";
        const DESCRIPTION: &'static str = "A test tool";

        fn call(
            &self,
            input: Self::Input,
            _ctx: &ToolContext,
        ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
            Box::pin(async move { Ok(format!("Hello, {}!", input)) })
        }
    }

    /// Builds a registry holding only [`TestTool`] with the unit codec.
    fn single_tool_registry() -> ToolRegistry {
        ToolRegistry::builder()
            .register::<TestTool, ()>(TestTool)
            .finish()
    }

    #[test]
    fn test_registry_builder() {
        let registry = single_tool_registry();

        assert!(registry.contains("test_tool"));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_registry_list_names() {
        let names = single_tool_registry().list_names();
        assert_eq!(names, vec!["test_tool"]);
    }

    #[test]
    fn test_registry_subset() {
        let registry = single_tool_registry();

        // A known name is carried over into the subset.
        let kept = registry.subset(["test_tool"]);
        assert!(kept.contains("test_tool"));

        // An unknown name selects nothing.
        let none = registry.subset(["nonexistent"]);
        assert!(none.is_empty());
    }

    #[test]
    fn test_tool_handle() {
        let registry = single_tool_registry();
        assert!(registry.handle::<TestTool>().is_ok());
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted() {
        let registry = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let result = registry
            .dispatch_json_formatted("test_tool", serde_json::json!("World"), &ctx, &opts)
            .await;

        assert!(result.is_ok());
        let FormattedResult { data, text } = result.unwrap();
        assert_eq!(data, serde_json::json!("Hello, World!"));
        assert!(text.is_some());
        assert!(text.unwrap().contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_dispatch_json_formatted_unknown_tool() {
        let registry = single_tool_registry();
        let ctx = ToolContext::default();
        let opts = TextOptions::default();

        let result = registry
            .dispatch_json_formatted("nonexistent", serde_json::json!("test"), &ctx, &opts)
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_iter_erased() {
        let tools = single_tool_registry().iter_erased();

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name(), "test_tool");
    }

    #[test]
    fn test_register_erased_roundtrip() {
        // Pull the erased tool out of one registry and feed it into another.
        let source = single_tool_registry();
        let erased = source.iter_erased().into_iter().next().unwrap();

        let rebuilt = ToolRegistry::builder().register_erased(erased).finish();

        assert_eq!(rebuilt.len(), 1);
        assert!(rebuilt.contains("test_tool"));
        assert_eq!(rebuilt.get("test_tool").unwrap().name(), "test_tool");
    }

    #[test]
    fn test_merge_all_combines_registries() {
        // Both registries hold the same tool name, so the merge collapses to one entry.
        let first = single_tool_registry();
        let second = single_tool_registry();

        let merged = ToolRegistry::merge_all(vec![first, second]);

        assert_eq!(merged.len(), 1);
        assert!(merged.contains("test_tool"));
    }

    #[test]
    fn test_merge_all_empty() {
        let merged = ToolRegistry::merge_all(Vec::<ToolRegistry>::new());
        assert!(merged.is_empty());
    }

    #[test]
    fn test_merge_all_preserves_subset() {
        let merged = ToolRegistry::merge_all(vec![single_tool_registry()]);
        let narrowed = merged.subset(["test_tool"]);

        assert_eq!(narrowed.len(), 1);
        assert!(narrowed.contains("test_tool"));
    }
}