claude_api/tool_dispatch/
typed.rs1#![cfg(feature = "schemars-tools")]
35
36use std::future::Future;
37use std::marker::PhantomData;
38
39use async_trait::async_trait;
40
41use crate::tool_dispatch::registry::ToolRegistry;
42use crate::tool_dispatch::tool::{Tool, ToolError};
43
44fn generate_schema_for<A: schemars::JsonSchema>() -> serde_json::Value {
46 let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<A>();
47 serde_json::to_value(schema).expect("RootSchema is always JSON-serializable")
48}
49
50pub struct TypedTool<A, F, Fut>
56where
57 A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
58 F: Fn(A) -> Fut + Send + Sync + 'static,
59 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
60{
61 name: String,
62 schema: serde_json::Value,
63 description: Option<String>,
64 handler: F,
65 _phantom: PhantomData<fn(A) -> Fut>,
66}
67
68impl<A, F, Fut> TypedTool<A, F, Fut>
69where
70 A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
71 F: Fn(A) -> Fut + Send + Sync + 'static,
72 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
73{
74 pub fn new(name: impl Into<String>, handler: F) -> Self {
76 Self {
77 name: name.into(),
78 schema: generate_schema_for::<A>(),
79 description: None,
80 handler,
81 _phantom: PhantomData,
82 }
83 }
84
85 #[must_use]
87 pub fn with_description(mut self, description: impl Into<String>) -> Self {
88 self.description = Some(description.into());
89 self
90 }
91}
92
93#[async_trait]
94impl<A, F, Fut> Tool for TypedTool<A, F, Fut>
95where
96 A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
97 F: Fn(A) -> Fut + Send + Sync + 'static,
98 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
99{
100 fn name(&self) -> &str {
101 &self.name
102 }
103
104 fn description(&self) -> Option<&str> {
105 self.description.as_deref()
106 }
107
108 fn schema(&self) -> serde_json::Value {
109 self.schema.clone()
110 }
111
112 async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
113 let args = serde_json::from_value::<A>(input).map_err(|e| {
114 ToolError::invalid_input(format!("input did not match schema for {}: {e}", self.name))
115 })?;
116 (self.handler)(args).await
117 }
118}
119
120impl ToolRegistry {
121 pub fn register_typed<A, F, Fut>(&mut self, name: impl Into<String>, handler: F) -> &mut Self
128 where
129 A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
130 F: Fn(A) -> Fut + Send + Sync + 'static,
131 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
132 {
133 let tool = TypedTool::<A, F, Fut>::new(name, handler);
134 self.register_tool(tool)
135 }
136
137 pub fn register_typed_described<A, F, Fut>(
139 &mut self,
140 name: impl Into<String>,
141 description: impl Into<String>,
142 handler: F,
143 ) -> &mut Self
144 where
145 A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
146 F: Fn(A) -> Fut + Send + Sync + 'static,
147 Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
148 {
149 let tool = TypedTool::<A, F, Fut>::new(name, handler).with_description(description);
150 self.register_tool(tool)
151 }
152}
153
#[cfg(test)]
mod tests {
    //! Tests for typed tool registration: argument decoding, error reporting,
    //! schema generation, description plumbing, and coexistence with
    //! closure-based tools in the same registry.

    use super::*;
    use crate::messages::tools::Tool as MessagesTool;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use serde_json::json;

    /// Argument type with one required field and one optional field.
    #[derive(JsonSchema, Deserialize)]
    struct WeatherArgs {
        city: String,
        #[serde(default)]
        units: Option<String>,
    }

    #[tokio::test]
    async fn register_typed_dispatches_with_typed_args() {
        let mut registry = ToolRegistry::new();
        registry.register_typed::<WeatherArgs, _, _>("weather", |args| async move {
            Ok(json!({
                "city": args.city,
                "units": args.units.unwrap_or_else(|| "F".into())
            }))
        });

        let result = registry
            .dispatch("weather", json!({"city": "Paris"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "Paris");
        assert_eq!(result["units"], "F");
    }

    #[tokio::test]
    async fn register_typed_passes_optional_fields_through() {
        let mut registry = ToolRegistry::new();
        registry.register_typed::<WeatherArgs, _, _>("weather", |args| async move {
            Ok(json!({"city": args.city, "units": args.units}))
        });
        let result = registry
            .dispatch("weather", json!({"city": "Tokyo", "units": "C"}))
            .await
            .unwrap();
        assert_eq!(result["units"], "C");
    }

    #[tokio::test]
    async fn register_typed_returns_invalid_input_on_missing_required_field() {
        let mut registry = ToolRegistry::new();
        registry.register_typed::<WeatherArgs, _, _>("weather", |args| async move {
            Ok(json!({"city": args.city}))
        });
        let err = registry.dispatch("weather", json!({})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("city"), "{msg}");
        // The error message identifies which tool rejected the input.
        assert!(msg.contains("weather"), "{msg}");
    }

    #[tokio::test]
    async fn register_typed_returns_invalid_input_on_wrong_field_type() {
        let mut registry = ToolRegistry::new();
        registry.register_typed::<WeatherArgs, _, _>("weather", |args| async move {
            Ok(json!({"city": args.city}))
        });
        let err = registry
            .dispatch("weather", json!({"city": 42}))
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::InvalidInput(_)));
    }

    #[test]
    fn register_typed_generates_schema_from_args_type() {
        let mut registry = ToolRegistry::new();
        registry.register_typed::<WeatherArgs, _, _>("weather", |_args| async move {
            Ok(serde_json::Value::Null)
        });

        let tools = registry.to_messages_tools();
        let MessagesTool::Custom(ct) = &tools[0] else {
            panic!("expected Custom");
        };
        assert!(ct.input_schema.is_object(), "schema must be a JSON object");
        let serialized = ct.input_schema.to_string();
        assert!(
            serialized.contains("\"city\""),
            "schema must mention city: {serialized}"
        );
    }

    #[test]
    fn typed_tool_reports_name_schema_and_description() {
        let tool = TypedTool::new("weather", |args: WeatherArgs| async move {
            Ok(json!({"city": args.city}))
        });
        assert_eq!(tool.name(), "weather");
        assert_eq!(tool.description(), None);
        // Only the non-optional field should be required by the generated schema.
        assert_eq!(tool.schema()["required"], json!(["city"]));

        let tool = tool.with_description("Get the weather for a city.");
        assert_eq!(tool.description(), Some("Get the weather for a city."));
    }

    #[test]
    fn register_typed_described_attaches_description_to_messages_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_typed_described::<WeatherArgs, _, _>(
            "weather",
            "Get the weather for a city.",
            |args| async move { Ok(json!({"city": args.city})) },
        );
        let tools = registry.to_messages_tools();
        let MessagesTool::Custom(ct) = &tools[0] else {
            panic!("expected Custom");
        };
        assert_eq!(
            ct.description.as_deref(),
            Some("Get the weather for a city.")
        );
    }

    #[tokio::test]
    async fn typed_and_closure_tools_coexist_in_one_registry() {
        let mut registry = ToolRegistry::new();
        registry
            .register_typed::<WeatherArgs, _, _>("weather", |args| async move {
                Ok(json!({"city": args.city}))
            })
            .register("echo", json!({"type": "object"}), |input| async move {
                Ok(input)
            });

        assert_eq!(registry.len(), 2);
        let r1 = registry
            .dispatch("weather", json!({"city": "Berlin"}))
            .await
            .unwrap();
        let r2 = registry.dispatch("echo", json!({"x": 1})).await.unwrap();
        assert_eq!(r1["city"], "Berlin");
        assert_eq!(r2["x"], 1);
    }
}