Skip to main content

claude_api/tool_dispatch/
typed.rs

1//! Schemars-driven typed tool registration.
2//!
3//! Layers on top of [`ToolRegistry`] and the
4//! [`Tool`] trait: instead of a raw `Fn(Value) -> Future` handler, you
5//! supply a typed `Fn(Args) -> Future` where `Args: JsonSchema +
6//! DeserializeOwned`. The schema is derived automatically; the registry
7//! deserializes the model's input into `Args` before invoking, returning
8//! [`ToolError::InvalidInput`] when the shape doesn't match.
9//!
10//! Gated on the `schemars-tools` feature.
11//!
12//! ```
13//! use claude_api::tool_dispatch::ToolRegistry;
14//! use schemars::JsonSchema;
15//! use serde::Deserialize;
16//! use serde_json::json;
17//!
18//! #[derive(JsonSchema, Deserialize)]
19//! struct WeatherArgs {
20//!     city: String,
21//!     #[serde(default)]
22//!     units: Option<String>,
23//! }
24//!
25//! let mut registry = ToolRegistry::new();
26//! registry.register_typed::<WeatherArgs, _, _>(
27//!     "get_weather",
28//!     |args| async move {
29//!         Ok(json!({"city": args.city, "units": args.units.unwrap_or_default()}))
30//!     },
31//! );
32//! ```
33
34#![cfg(feature = "schemars-tools")]
35
36use std::future::Future;
37use std::marker::PhantomData;
38
39use async_trait::async_trait;
40
41use crate::tool_dispatch::registry::ToolRegistry;
42use crate::tool_dispatch::tool::{Tool, ToolError};
43
44/// Generate a JSON Schema for the given type via [`schemars`].
45fn generate_schema_for<A: schemars::JsonSchema>() -> serde_json::Value {
46    let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<A>();
47    serde_json::to_value(schema).expect("RootSchema is always JSON-serializable")
48}
49
50/// Typed-input adapter that implements [`Tool`] for a handler taking a
51/// `JsonSchema`-deriving struct.
52///
53/// Constructed implicitly by [`ToolRegistry::register_typed`] /
54/// [`ToolRegistry::register_typed_described`]; rarely instantiated directly.
55pub struct TypedTool<A, F, Fut>
56where
57    A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
58    F: Fn(A) -> Fut + Send + Sync + 'static,
59    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
60{
61    name: String,
62    schema: serde_json::Value,
63    description: Option<String>,
64    handler: F,
65    _phantom: PhantomData<fn(A) -> Fut>,
66}
67
68impl<A, F, Fut> TypedTool<A, F, Fut>
69where
70    A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
71    F: Fn(A) -> Fut + Send + Sync + 'static,
72    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
73{
74    /// Build a typed tool. The schema is derived from `A` automatically.
75    pub fn new(name: impl Into<String>, handler: F) -> Self {
76        Self {
77            name: name.into(),
78            schema: generate_schema_for::<A>(),
79            description: None,
80            handler,
81            _phantom: PhantomData,
82        }
83    }
84
85    /// Attach a description.
86    #[must_use]
87    pub fn with_description(mut self, description: impl Into<String>) -> Self {
88        self.description = Some(description.into());
89        self
90    }
91}
92
93#[async_trait]
94impl<A, F, Fut> Tool for TypedTool<A, F, Fut>
95where
96    A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
97    F: Fn(A) -> Fut + Send + Sync + 'static,
98    Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
99{
100    fn name(&self) -> &str {
101        &self.name
102    }
103
104    fn description(&self) -> Option<&str> {
105        self.description.as_deref()
106    }
107
108    fn schema(&self) -> serde_json::Value {
109        self.schema.clone()
110    }
111
112    async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError> {
113        let args = serde_json::from_value::<A>(input).map_err(|e| {
114            ToolError::invalid_input(format!("input did not match schema for {}: {e}", self.name))
115        })?;
116        (self.handler)(args).await
117    }
118}
119
120impl ToolRegistry {
121    /// Register a tool with a typed input struct.
122    ///
123    /// The schema is generated from `A` via [`schemars`], and the model's
124    /// raw `Value` input is deserialized into `A` before the handler runs.
125    /// Deserialization failures surface as [`ToolError::InvalidInput`] so
126    /// the model can self-correct.
127    pub fn register_typed<A, F, Fut>(&mut self, name: impl Into<String>, handler: F) -> &mut Self
128    where
129        A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
130        F: Fn(A) -> Fut + Send + Sync + 'static,
131        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
132    {
133        let tool = TypedTool::<A, F, Fut>::new(name, handler);
134        self.register_tool(tool)
135    }
136
137    /// Like [`Self::register_typed`] but also attaches a description.
138    pub fn register_typed_described<A, F, Fut>(
139        &mut self,
140        name: impl Into<String>,
141        description: impl Into<String>,
142        handler: F,
143    ) -> &mut Self
144    where
145        A: schemars::JsonSchema + serde::de::DeserializeOwned + Send + Sync + 'static,
146        F: Fn(A) -> Fut + Send + Sync + 'static,
147        Fut: Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
148    {
149        let tool = TypedTool::<A, F, Fut>::new(name, handler).with_description(description);
150        self.register_tool(tool)
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::tools::Tool as MessagesTool;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use serde_json::json;

    /// Argument type shared by every test: one required field and one
    /// optional field that defaults to `None` when absent.
    #[derive(JsonSchema, Deserialize)]
    struct WeatherArgs {
        city: String,
        #[serde(default)]
        units: Option<String>,
    }

    #[tokio::test]
    async fn register_typed_dispatches_with_typed_args() {
        let mut reg = ToolRegistry::new();
        reg.register_typed::<WeatherArgs, _, _>("weather", |a| async move {
            let units = a.units.unwrap_or_else(|| "F".into());
            Ok(json!({ "city": a.city, "units": units }))
        });

        let out = reg
            .dispatch("weather", json!({ "city": "Paris" }))
            .await
            .unwrap();
        assert_eq!(out["city"], "Paris");
        assert_eq!(out["units"], "F");
    }

    #[tokio::test]
    async fn register_typed_passes_optional_fields_through() {
        let mut reg = ToolRegistry::new();
        reg.register_typed::<WeatherArgs, _, _>("weather", |a| async move {
            Ok(json!({ "city": a.city, "units": a.units }))
        });

        let out = reg
            .dispatch("weather", json!({ "city": "Tokyo", "units": "C" }))
            .await
            .unwrap();
        assert_eq!(out["units"], "C");
    }

    #[tokio::test]
    async fn register_typed_returns_invalid_input_on_missing_required_field() {
        let mut reg = ToolRegistry::new();
        reg.register_typed::<WeatherArgs, _, _>("weather", |a| async move {
            Ok(json!({ "city": a.city }))
        });

        match reg.dispatch("weather", json!({})).await.unwrap_err() {
            ToolError::InvalidInput(msg) => assert!(msg.contains("city"), "{msg}"),
            _ => panic!("expected InvalidInput"),
        }
    }

    #[tokio::test]
    async fn register_typed_returns_invalid_input_on_wrong_field_type() {
        let mut reg = ToolRegistry::new();
        reg.register_typed::<WeatherArgs, _, _>("weather", |a| async move {
            Ok(json!({ "city": a.city }))
        });

        // `city` is declared as a string, so a number has to be rejected.
        let outcome = reg.dispatch("weather", json!({ "city": 42 })).await;
        assert!(matches!(outcome, Err(ToolError::InvalidInput(_))));
    }

    #[test]
    fn register_typed_generates_schema_from_args_type() {
        let mut reg = ToolRegistry::new();
        reg.register_typed::<WeatherArgs, _, _>("weather", |_a| async move {
            Ok(serde_json::Value::Null)
        });

        let tools = reg.to_messages_tools();
        match &tools[0] {
            MessagesTool::Custom(ct) => {
                assert!(ct.input_schema.is_object(), "schema must be a JSON object");
                // The derived schema has to describe the `city` field.
                let rendered = ct.input_schema.to_string();
                assert!(
                    rendered.contains("\"city\""),
                    "schema must mention city: {rendered}"
                );
            }
            _ => panic!("expected Custom"),
        }
    }

    #[test]
    fn register_typed_described_attaches_description_to_messages_tools() {
        let mut reg = ToolRegistry::new();
        reg.register_typed_described::<WeatherArgs, _, _>(
            "weather",
            "Get the weather for a city.",
            |a| async move { Ok(json!({ "city": a.city })) },
        );

        let tools = reg.to_messages_tools();
        match &tools[0] {
            MessagesTool::Custom(ct) => assert_eq!(
                ct.description.as_deref(),
                Some("Get the weather for a city.")
            ),
            _ => panic!("expected Custom"),
        }
    }

    #[tokio::test]
    async fn typed_and_closure_tools_coexist_in_one_registry() {
        let mut reg = ToolRegistry::new();
        reg.register_typed::<WeatherArgs, _, _>("weather", |a| async move {
            Ok(json!({ "city": a.city }))
        });
        reg.register("echo", json!({ "type": "object" }), |input| async move {
            Ok(input)
        });

        assert_eq!(reg.len(), 2);
        let weather = reg
            .dispatch("weather", json!({ "city": "Berlin" }))
            .await
            .unwrap();
        let echoed = reg.dispatch("echo", json!({ "x": 1 })).await.unwrap();
        assert_eq!(weather["city"], "Berlin");
        assert_eq!(echoed["x"], 1);
    }
}