// claude_api/tool_dispatch/registry.rs
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;

use async_trait::async_trait;

use crate::messages::tools::{CustomTool, Tool as MessagesTool};
use crate::tool_dispatch::tool::{Tool, ToolError};
25#[derive(Default)]
31pub struct ToolRegistry {
32 tools: HashMap<String, Arc<dyn Tool>>,
33}
34
35impl ToolRegistry {
36 #[must_use]
38 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn register_tool<T: Tool>(&mut self, tool: T) -> &mut Self {
46 let name = tool.name().to_owned();
47 self.tools.insert(name, Arc::new(tool));
48 self
49 }
50
51 pub fn register<F, Fut>(
71 &mut self,
72 name: impl Into<String>,
73 schema: serde_json::Value,
74 handler: F,
75 ) -> &mut Self
76 where
77 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
78 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
79 {
80 let name = name.into();
81 let tool = FnTool::new(name.clone(), schema, handler);
82 self.tools.insert(name, Arc::new(tool));
83 self
84 }
85
86 pub fn register_described<F, Fut>(
88 &mut self,
89 name: impl Into<String>,
90 description: impl Into<String>,
91 schema: serde_json::Value,
92 handler: F,
93 ) -> &mut Self
94 where
95 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
96 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
97 {
98 let name = name.into();
99 let mut tool = FnTool::new(name.clone(), schema, handler);
100 tool.description = Some(description.into());
101 self.tools.insert(name, Arc::new(tool));
102 self
103 }
104
105 #[must_use]
107 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
108 self.tools.get(name)
109 }
110
111 #[must_use]
113 pub fn contains(&self, name: &str) -> bool {
114 self.tools.contains_key(name)
115 }
116
117 #[must_use]
119 pub fn len(&self) -> usize {
120 self.tools.len()
121 }
122
123 #[must_use]
125 pub fn is_empty(&self) -> bool {
126 self.tools.is_empty()
127 }
128
129 pub fn names(&self) -> impl Iterator<Item = &str> {
131 self.tools.keys().map(String::as_str)
132 }
133
134 #[must_use]
138 pub fn to_messages_tools(&self) -> Vec<MessagesTool> {
139 self.tools
140 .values()
141 .map(|t| {
142 let mut ct = CustomTool::new(t.name(), t.schema());
143 if let Some(desc) = t.description() {
144 ct = ct.description(desc);
145 }
146 MessagesTool::Custom(ct)
147 })
148 .collect()
149 }
150
151 pub async fn dispatch(
156 &self,
157 name: &str,
158 input: serde_json::Value,
159 ) -> Result<serde_json::Value, ToolError> {
160 let tool = self.tools.get(name).ok_or_else(|| ToolError::Unknown {
161 name: name.to_owned(),
162 })?;
163 tool.invoke(input).await
164 }
165}
166
167impl std::fmt::Debug for ToolRegistry {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("ToolRegistry")
171 .field("tools", &self.tools.keys().collect::<Vec<_>>())
172 .finish()
173 }
174}
175
176pub struct FnTool<F, Fut>
180where
181 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
182 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
183{
184 name: String,
185 schema: serde_json::Value,
186 description: Option<String>,
187 handler: F,
188 _phantom: PhantomData<fn() -> Fut>,
189}
190
191impl<F, Fut> FnTool<F, Fut>
192where
193 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
194 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
195{
196 pub fn new(name: impl Into<String>, schema: serde_json::Value, handler: F) -> Self {
198 Self {
199 name: name.into(),
200 schema,
201 description: None,
202 handler,
203 _phantom: PhantomData,
204 }
205 }
206
207 #[must_use]
209 pub fn with_description(mut self, description: impl Into<String>) -> Self {
210 self.description = Some(description.into());
211 self
212 }
213}
214
215#[async_trait]
216impl<F, Fut> Tool for FnTool<F, Fut>
217where
218 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
219 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
220{
221 fn name(&self) -> &str {
222 &self.name
223 }
224
225 fn description(&self) -> Option<&str> {
226 self.description.as_deref()
227 }
228
229 fn schema(&self) -> serde_json::Value {
230 self.schema.clone()
231 }
232
233 async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
234 (self.handler)(input).await
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use pretty_assertions::assert_eq;
    use serde_json::{Value, json};

    // Also brings in the parent module's imports (`HashMap`, `Arc`, `Tool`,
    // `MessagesTool`, ...), so they are not re-imported here.
    use super::*;

    /// Schema shared by the test tools: an object with a string `text` property.
    fn echo_schema() -> Value {
        json!({"type": "object", "properties": {"text": {"type": "string"}}})
    }

    /// Trait-based test tool that upper-cases its `text` input.
    struct UpperTool;

    #[async_trait]
    impl Tool for UpperTool {
        // Clippy suggests `&'static str`; kept as `&str` to mirror the
        // `Tool::name` signature exactly.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "upper"
        }

        fn schema(&self) -> Value {
            echo_schema()
        }

        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let text = input
                .get("text")
                .and_then(Value::as_str)
                .ok_or_else(|| ToolError::invalid_input("missing 'text'"))?;
            Ok(json!({"upper": text.to_uppercase()}))
        }
    }

    #[tokio::test]
    async fn register_and_dispatch_closure_tool() {
        let mut registry = ToolRegistry::new();
        registry.register("echo", echo_schema(), |input| async move { Ok(input) });

        assert!(registry.contains("echo"));
        assert_eq!(registry.len(), 1);

        let result = registry
            .dispatch("echo", json!({"text": "hi"}))
            .await
            .unwrap();
        assert_eq!(result, json!({"text": "hi"}));
    }

    #[tokio::test]
    async fn register_and_dispatch_trait_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);

        let result = registry
            .dispatch("upper", json!({"text": "rust"}))
            .await
            .unwrap();
        assert_eq!(result, json!({"upper": "RUST"}));
    }

    #[tokio::test]
    async fn closure_and_trait_tools_coexist() {
        let mut registry = ToolRegistry::new();
        registry
            .register_tool(UpperTool)
            .register("echo", echo_schema(), |input| async move { Ok(input) });

        assert_eq!(registry.len(), 2);
        let names: HashSet<_> = registry.names().collect();
        assert!(names.contains("upper"));
        assert!(names.contains("echo"));

        let r1 = registry
            .dispatch("upper", json!({"text": "ok"}))
            .await
            .unwrap();
        let r2 = registry
            .dispatch("echo", json!({"text": "ok"}))
            .await
            .unwrap();
        assert_eq!(r1, json!({"upper": "OK"}));
        assert_eq!(r2, json!({"text": "ok"}));
    }

    #[tokio::test]
    async fn dispatch_unknown_returns_unknown_error() {
        let registry = ToolRegistry::new();
        let err = registry.dispatch("nope", json!({})).await.unwrap_err();
        let ToolError::Unknown { name } = err else {
            panic!("expected Unknown variant");
        };
        assert_eq!(name, "nope");
    }

    #[tokio::test]
    async fn dispatch_propagates_invalid_input_error_from_tool() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let err = registry.dispatch("upper", json!({})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'text'"));
    }

    #[tokio::test]
    async fn duplicate_register_replaces_previous_entry() {
        let mut registry = ToolRegistry::new();
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "first"}))
        });
        registry.register("dup", echo_schema(), |_| async move {
            Ok(json!({"version": "second"}))
        });
        assert_eq!(registry.len(), 1);
        let r = registry.dispatch("dup", json!({})).await.unwrap();
        assert_eq!(r, json!({"version": "second"}));
    }

    #[test]
    fn to_messages_tools_includes_name_schema_and_description() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool).register_described(
            "echo",
            "Returns its input verbatim.",
            echo_schema(),
            |input| async move { Ok(input) },
        );

        let tools = registry.to_messages_tools();
        assert_eq!(tools.len(), 2);

        // Index by name so the assertions do not depend on list order.
        let mut by_name: HashMap<String, MessagesTool> = HashMap::new();
        for t in tools {
            let MessagesTool::Custom(ct) = &t else {
                panic!("expected custom variant");
            };
            by_name.insert(ct.name.clone(), t);
        }

        let MessagesTool::Custom(echo) = by_name.get("echo").unwrap() else {
            panic!("expected echo Custom");
        };
        assert_eq!(
            echo.description.as_deref(),
            Some("Returns its input verbatim.")
        );
        assert!(echo.input_schema.is_object());

        let MessagesTool::Custom(upper) = by_name.get("upper").unwrap() else {
            panic!("expected upper Custom");
        };
        assert_eq!(upper.description, None);
    }

    #[tokio::test]
    async fn registry_works_through_dyn_dispatch() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);

        let tool: &Arc<dyn Tool> = registry.get("upper").unwrap();
        let r = tool.invoke(json!({"text": "abc"})).await.unwrap();
        assert_eq!(r, json!({"upper": "ABC"}));
    }

    #[test]
    fn debug_impl_lists_tool_names() {
        let mut registry = ToolRegistry::new();
        registry.register_tool(UpperTool);
        let dbg = format!("{registry:?}");
        assert!(dbg.contains("upper"), "{dbg}");
    }

    #[test]
    fn registry_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ToolRegistry>();
    }
}