claude_api/tool_dispatch/
registry.rs1use std::collections::HashMap;
20use std::future::Future;
21use std::marker::PhantomData;
22use std::sync::Arc;
23
24use async_trait::async_trait;
25
26use crate::messages::tools::{CustomTool, Tool as MessagesTool};
27use crate::tool_dispatch::tool::{Tool, ToolError};
28
29#[derive(Default)]
35pub struct ToolRegistry {
36 tools: HashMap<String, Arc<dyn Tool>>,
37}
38
39impl ToolRegistry {
40 #[must_use]
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn register_tool<T: Tool>(&mut self, tool: T) -> &mut Self {
50 let name = tool.name().to_owned();
51 self.tools.insert(name, Arc::new(tool));
52 self
53 }
54
55 pub fn register<F, Fut>(
75 &mut self,
76 name: impl Into<String>,
77 schema: serde_json::Value,
78 handler: F,
79 ) -> &mut Self
80 where
81 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
82 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
83 {
84 let name = name.into();
85 let tool = FnTool::new(name.clone(), schema, handler);
86 self.tools.insert(name, Arc::new(tool));
87 self
88 }
89
90 pub fn register_described<F, Fut>(
92 &mut self,
93 name: impl Into<String>,
94 description: impl Into<String>,
95 schema: serde_json::Value,
96 handler: F,
97 ) -> &mut Self
98 where
99 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
100 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
101 {
102 let name = name.into();
103 let mut tool = FnTool::new(name.clone(), schema, handler);
104 tool.description = Some(description.into());
105 self.tools.insert(name, Arc::new(tool));
106 self
107 }
108
109 #[must_use]
111 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
112 self.tools.get(name)
113 }
114
115 #[must_use]
117 pub fn contains(&self, name: &str) -> bool {
118 self.tools.contains_key(name)
119 }
120
121 #[must_use]
123 pub fn len(&self) -> usize {
124 self.tools.len()
125 }
126
127 #[must_use]
129 pub fn is_empty(&self) -> bool {
130 self.tools.is_empty()
131 }
132
133 pub fn names(&self) -> impl Iterator<Item = &str> {
135 self.tools.keys().map(String::as_str)
136 }
137
138 #[must_use]
142 pub fn to_messages_tools(&self) -> Vec<MessagesTool> {
143 self.tools
144 .values()
145 .map(|t| {
146 let mut ct = CustomTool::new(t.name(), t.schema());
147 if let Some(desc) = t.description() {
148 ct = ct.description(desc);
149 }
150 MessagesTool::Custom(ct)
151 })
152 .collect()
153 }
154
155 pub async fn dispatch(
160 &self,
161 name: &str,
162 input: serde_json::Value,
163 ) -> Result<serde_json::Value, ToolError> {
164 let tool = self.tools.get(name).ok_or_else(|| ToolError::Unknown {
165 name: name.to_owned(),
166 })?;
167 tool.invoke(input).await
168 }
169}
170
171impl std::fmt::Debug for ToolRegistry {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("ToolRegistry")
175 .field("tools", &self.tools.keys().collect::<Vec<_>>())
176 .finish()
177 }
178}
179
180pub struct FnTool<F, Fut>
184where
185 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
186 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
187{
188 name: String,
189 schema: serde_json::Value,
190 description: Option<String>,
191 handler: F,
192 _phantom: PhantomData<fn() -> Fut>,
193}
194
195impl<F, Fut> FnTool<F, Fut>
196where
197 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
198 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
199{
200 pub fn new(name: impl Into<String>, schema: serde_json::Value, handler: F) -> Self {
202 Self {
203 name: name.into(),
204 schema,
205 description: None,
206 handler,
207 _phantom: PhantomData,
208 }
209 }
210
211 #[must_use]
213 pub fn with_description(mut self, description: impl Into<String>) -> Self {
214 self.description = Some(description.into());
215 self
216 }
217}
218
219#[async_trait]
220impl<F, Fut> Tool for FnTool<F, Fut>
221where
222 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
223 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
224{
225 fn name(&self) -> &str {
226 &self.name
227 }
228
229 fn description(&self) -> Option<&str> {
230 self.description.as_deref()
231 }
232
233 fn schema(&self) -> serde_json::Value {
234 self.schema.clone()
235 }
236
237 async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
238 (self.handler)(input).await
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::tools::Tool as MessagesTool;
    use pretty_assertions::assert_eq;
    use serde_json::{Value, json};

    /// Schema used by the test tools: an object with a string `text` property.
    fn echo_schema() -> Value {
        json!({"type": "object", "properties": {"text": {"type": "string"}}})
    }

    /// Hand-written `Tool` implementation that upper-cases the `text` field.
    struct UpperTool;

    #[async_trait]
    impl Tool for UpperTool {
        // NOTE(review): presumably allowed because the trait fixes the return
        // type as `&str` tied to `&self`, so the lint's `&'static str`
        // suggestion cannot be applied here — confirm.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "upper"
        }

        fn schema(&self) -> Value {
            json!({"type": "object", "properties": {"text": {"type": "string"}}})
        }

        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let Some(text) = input.get("text").and_then(Value::as_str) else {
                return Err(ToolError::invalid_input("missing 'text'"));
            };
            Ok(json!({"upper": text.to_uppercase()}))
        }
    }

    /// A closure registered by name can be found and dispatched to.
    #[tokio::test]
    async fn register_and_dispatch_closure_tool() {
        let mut registry = ToolRegistry::new();
        registry.register("echo", echo_schema(), |input| async move { Ok(input) });

        assert!(registry.contains("echo"));
        assert_eq!(registry.len(), 1);

        let echoed = registry
            .dispatch("echo", json!({"text": "hi"}))
            .await
            .unwrap();
        assert_eq!(echoed, json!({"text": "hi"}));
    }

    /// A trait-based tool is registered under its own `name()`.
    #[tokio::test]
    async fn register_and_dispatch_trait_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);

        let output = registry
            .dispatch("upper", json!({"text": "rust"}))
            .await
            .unwrap();
        assert_eq!(output, json!({"upper": "RUST"}));
    }

    /// Both registration styles can be chained into the same registry.
    #[tokio::test]
    async fn closure_and_trait_tools_coexist() {
        let mut registry = ToolRegistry::new();
        registry
            .register_tool(UpperTool)
            .register("echo", echo_schema(), |input| async move { Ok(input) });

        assert_eq!(registry.len(), 2);
        let names = registry
            .names()
            .collect::<std::collections::HashSet<_>>();
        assert!(names.contains("upper"));
        assert!(names.contains("echo"));

        let from_upper = registry
            .dispatch("upper", json!({"text": "ok"}))
            .await
            .unwrap();
        let from_echo = registry
            .dispatch("echo", json!({"text": "ok"}))
            .await
            .unwrap();
        assert_eq!(from_upper, json!({"upper": "OK"}));
        assert_eq!(from_echo, json!({"text": "ok"}));
    }

    /// Dispatching to an unregistered name reports that name back.
    #[tokio::test]
    async fn dispatch_unknown_returns_unknown_error() {
        let registry = ToolRegistry::new();
        match registry.dispatch("nope", json!({})).await.unwrap_err() {
            ToolError::Unknown { name } => assert_eq!(name, "nope"),
            _ => panic!("expected Unknown variant"),
        }
    }

    /// Errors raised inside a tool reach the caller unchanged.
    #[tokio::test]
    async fn dispatch_propagates_invalid_input_error_from_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        match registry.dispatch("upper", json!({})).await.unwrap_err() {
            ToolError::InvalidInput(msg) => assert!(msg.contains("'text'")),
            _ => panic!("expected InvalidInput"),
        }
    }

    /// Registering the same name twice keeps only the later tool.
    #[tokio::test]
    async fn duplicate_register_replaces_previous_entry() {
        let mut registry = ToolRegistry::new();
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "first"}))
        });
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "second"}))
        });

        assert_eq!(registry.len(), 1);
        let winner = registry.dispatch("dup", json!({})).await.unwrap();
        assert_eq!(winner, json!({"version": "second"}));
    }

    /// The exported definitions carry each tool's name, schema and description.
    #[test]
    fn to_messages_tools_includes_name_schema_and_description() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool).register_described(
            "echo",
            "Returns its input verbatim.",
            echo_schema(),
            |input| async move { Ok(input) },
        );

        let tools = registry.to_messages_tools();
        assert_eq!(tools.len(), 2);

        // Index the definitions by name so the checks do not depend on order.
        let by_name: std::collections::HashMap<String, MessagesTool> = tools
            .into_iter()
            .map(|tool| {
                let MessagesTool::Custom(custom) = &tool else {
                    panic!("expected custom variant");
                };
                (custom.name.clone(), tool)
            })
            .collect();

        let MessagesTool::Custom(echo) = by_name.get("echo").unwrap() else {
            panic!("expected echo Custom");
        };
        assert_eq!(
            echo.description.as_deref(),
            Some("Returns its input verbatim.")
        );
        assert!(echo.input_schema.is_object());

        let MessagesTool::Custom(upper) = by_name.get("upper").unwrap() else {
            panic!("expected upper Custom");
        };
        assert_eq!(upper.description, None);
    }

    /// A handle obtained via `get` can be invoked through the trait object.
    #[tokio::test]
    async fn registry_works_through_dyn_dispatch() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);

        let handle: &Arc<dyn Tool> = registry.get("upper").unwrap();
        let output = handle.invoke(json!({"text": "abc"})).await.unwrap();
        assert_eq!(output, json!({"upper": "ABC"}));
    }

    /// The `Debug` rendering mentions the registered names.
    #[test]
    fn debug_impl_lists_tool_names() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let dbg = format!("{registry:?}");
        assert!(dbg.contains("upper"), "{dbg}");
    }

    /// Compile-time check that the registry can be shared across threads.
    #[test]
    fn registry_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ToolRegistry>();
    }
}