claude_api/tool_dispatch/tool.rs
1//! The [`Tool`] async trait and its companion [`ToolError`].
2//!
3//! [`Tool`] is the foundation for every tool-dispatch shape v0.2 ships:
4//! direct trait implementations, closure adapters via `FnTool` (#19), and
5//! schemars-driven typed handlers (#21) all reduce to a `Tool` impl.
6
7use async_trait::async_trait;
8
9/// A tool the model can invoke during generation.
10///
11/// # Contract
12///
13/// - **`name()`** returns a stable identifier. Names must be unique within
14/// a single registry; the model uses them to refer to specific tools.
15/// - **`schema()`** returns a JSON Schema describing the tool's input. The
16/// schema is sent to the model as part of the request and is used to
17/// guide what `invoke` will receive. Must be valid JSON Schema.
18/// - **`invoke(input)`** runs the tool. `input` is the raw `Value` the
19/// model produced and is *not* validated against the schema by the SDK
20/// -- impls should validate themselves and return [`ToolError::InvalidInput`]
21/// on failure. `invoke` may take arbitrary time; the agent-loop runner
22/// in #20 supports per-iteration timeouts.
23///
24/// All methods take `&self` so a single instance can be shared via `Arc`.
25/// The trait is `Send + Sync + 'static` so tools can live in concurrent
26/// contexts.
27///
28/// # Example
29///
30/// ```
31/// use async_trait::async_trait;
32/// use claude_api::tool_dispatch::{Tool, ToolError};
33/// use serde_json::{json, Value};
34///
35/// struct AddTool;
36///
37/// #[async_trait]
38/// impl Tool for AddTool {
39/// fn name(&self) -> &str { "add" }
40/// fn schema(&self) -> Value {
41/// json!({
42/// "type": "object",
43/// "properties": {
44/// "a": {"type": "number"},
45/// "b": {"type": "number"}
46/// },
47/// "required": ["a", "b"]
48/// })
49/// }
50/// async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
51/// let a = input.get("a").and_then(Value::as_f64)
52/// .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
53/// let b = input.get("b").and_then(Value::as_f64)
54/// .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
55/// Ok(json!({"sum": a + b}))
56/// }
57/// }
58/// ```
59#[async_trait]
60pub trait Tool: Send + Sync + 'static {
61 /// Stable identifier the model uses to refer to this tool.
62 fn name(&self) -> &str;
63
64 /// JSON Schema describing the tool's expected input.
65 fn schema(&self) -> serde_json::Value;
66
67 /// Optional human-readable description; helps the model decide when to
68 /// invoke. Default returns `None`.
69 fn description(&self) -> Option<&str> {
70 None
71 }
72
73 /// Run the tool with the model-supplied input and return its result.
74 async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError>;
75}
76
77/// Errors a [`Tool`] implementation can return.
78///
79/// Construct via [`Self::invalid_input`] for caller-side validation
80/// failures (string message, surfaced back to the model) or
81/// [`Self::execution`] to wrap any underlying error type that implements
82/// [`std::error::Error`].
83#[derive(Debug, thiserror::Error)]
84#[non_exhaustive]
85pub enum ToolError {
86 /// The model-supplied input did not satisfy the tool's schema or
87 /// other validation rules. The string is surfaced back to the model
88 /// as the `tool_result` content.
89 #[error("invalid tool input: {0}")]
90 InvalidInput(String),
91
92 /// The tool ran but its underlying operation failed.
93 #[error("tool execution failed: {0}")]
94 Execution(Box<dyn std::error::Error + Send + Sync>),
95
96 /// A registry was asked to dispatch a tool name it doesn't know.
97 /// Surfaced by `ToolRegistry::dispatch` (lands in #19).
98 #[error("no tool registered with name '{name}'")]
99 Unknown {
100 /// Tool name the registry was asked to dispatch.
101 name: String,
102 },
103}
104
105impl ToolError {
106 /// Build an [`InvalidInput`](Self::InvalidInput) error from a message.
107 pub fn invalid_input(msg: impl Into<String>) -> Self {
108 Self::InvalidInput(msg.into())
109 }
110
111 /// Build an [`Execution`](Self::Execution) error wrapping any error type.
112 pub fn execution<E>(err: E) -> Self
113 where
114 E: std::error::Error + Send + Sync + 'static,
115 {
116 Self::Execution(Box::new(err))
117 }
118}
119
120/// Verdict from a [`ToolApprover`] for a single `tool_use` invocation.
121///
122/// Approvers are consulted by [`crate::Client::run`] *before* each tool
123/// dispatch, so users can gate side-effecting tools behind an interactive
124/// confirmation, a policy check, an input rewriter, or a static
125/// allowlist.
126#[non_exhaustive]
127#[derive(Debug, Clone)]
128pub enum ApprovalDecision {
129 /// Proceed with the tool dispatch unchanged.
130 Approve,
131 /// Proceed, but substitute a different `input` (the model's original
132 /// payload is discarded). Useful for sanitizing arguments before the
133 /// tool runs (path scrubbing, scope clamping, etc.).
134 ApproveWithInput(serde_json::Value),
135 /// Skip the tool dispatch entirely and return `value` as the
136 /// `tool_result` content (with no `is_error` flag). Useful for
137 /// stubbing tools in tests or short-circuiting expensive calls when
138 /// the answer is already known.
139 Substitute(serde_json::Value),
140 /// Skip the tool dispatch. The supplied `reason` is returned to the
141 /// model as the `tool_result` content (with `is_error = true`) so
142 /// the model can choose how to recover.
143 Deny(String),
144 /// Abort the entire agent loop. Surfaces as
145 /// [`crate::Error::ToolApprovalStopped`] from `Client::run`.
146 Stop(String),
147}
148
149/// Async-callable predicate consulted before each tool dispatch.
150///
151/// Implement this trait for stateful approvers, or use the closure
152/// adapter [`fn_approver`] /
153/// [`RunOptions::with_approver_fn`](crate::tool_dispatch::runner::RunOptions::with_approver_fn).
154#[async_trait]
155pub trait ToolApprover: Send + Sync + 'static {
156 /// Inspect a pending tool dispatch and return a verdict.
157 async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision;
158}
159
160/// Wrap an async closure into a [`ToolApprover`].
161#[must_use]
162pub fn fn_approver<F, Fut>(handler: F) -> std::sync::Arc<dyn ToolApprover>
163where
164 F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
165 Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
166{
167 std::sync::Arc::new(FnApprover { handler })
168}
169
/// Private adapter behind [`fn_approver`]: it owns the user's closure,
/// which its [`ToolApprover`] impl calls for each approval request.
struct FnApprover<F> {
    /// Closure invoked for each pending tool dispatch.
    handler: F,
}
173
174#[async_trait]
175impl<F, Fut> ToolApprover for FnApprover<F>
176where
177 F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
178 Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
179{
180 async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision {
181 (self.handler)(tool_name, input).await
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};
    use std::sync::Arc;

    struct AddTool;

    #[async_trait]
    impl Tool for AddTool {
        // Trait dictates the return type; clippy can't tell we're returning
        // a literal versus a stored String, so allow the lint locally.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "add"
        }
        #[allow(clippy::unnecessary_literal_bound)]
        fn description(&self) -> Option<&str> {
            Some("Add two numbers and return the sum.")
        }
        fn schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["a", "b"]
            })
        }
        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let a = input
                .get("a")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
            let b = input
                .get("b")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
            Ok(json!({"sum": a + b}))
        }
    }

    #[tokio::test]
    async fn manual_impl_round_trips_a_value() {
        let tool = AddTool;
        let result = tool.invoke(json!({"a": 2, "b": 3})).await.unwrap();
        assert_eq!(result, json!({"sum": 5.0}));
    }

    #[tokio::test]
    async fn trait_object_dispatch_works() {
        // Critical: dyn Tool must work for ToolRegistry to hold heterogeneous tools.
        let tool: Arc<dyn Tool> = Arc::new(AddTool);
        assert_eq!(tool.name(), "add");
        assert_eq!(
            tool.description(),
            Some("Add two numbers and return the sum.")
        );
        assert!(tool.schema().is_object());
        let result = tool.invoke(json!({"a": 4, "b": 1})).await.unwrap();
        assert_eq!(result["sum"], 5.0);
    }

    #[tokio::test]
    async fn invalid_input_propagates_message() {
        let tool = AddTool;
        let err = tool.invoke(json!({"a": 1})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'b'"), "{msg}");
    }

    #[test]
    fn invalid_input_constructor_takes_string_or_str() {
        let _ = ToolError::invalid_input("plain str");
        let _ = ToolError::invalid_input(String::from("owned"));
    }

    #[test]
    fn execution_wraps_any_std_error() {
        let inner = std::io::Error::other("disk on fire");
        let err = ToolError::execution(inner);
        let display = format!("{err}");
        assert!(display.contains("disk on fire"), "{display}");
        let ToolError::Execution(_) = err else {
            panic!("expected Execution");
        };
    }

    #[test]
    fn error_display_formats_are_stable() {
        // Pin the Display formats so any change to these messages is deliberate.
        assert_eq!(
            ToolError::invalid_input("missing 'a'").to_string(),
            "invalid tool input: missing 'a'"
        );
        let unknown = ToolError::Unknown {
            name: "nope".to_owned(),
        };
        assert_eq!(unknown.to_string(), "no tool registered with name 'nope'");
    }

    #[test]
    fn tool_is_send_and_sync() {
        // Compile-time check: dyn Tool must be Send + Sync to live in async tasks.
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn Tool>();
    }

    #[test]
    fn approver_is_send_and_sync() {
        // Compile-time check: approvers are shared across tasks behind an Arc.
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn ToolApprover>();
    }

    #[tokio::test]
    async fn fn_approver_forwards_name_and_input() {
        let approver = fn_approver(|name, input| {
            // The returned future must be 'static, so decide up front and
            // move only owned data into it.
            let decision = if name == "add" && input == &json!({"a": 1}) {
                ApprovalDecision::Approve
            } else {
                ApprovalDecision::Deny(format!("blocked {name}"))
            };
            async move { decision }
        });

        let approved = approver.approve("add", &json!({"a": 1})).await;
        assert!(
            matches!(approved, ApprovalDecision::Approve),
            "{approved:?}"
        );

        let denied = approver.approve("rm", &json!({})).await;
        let ApprovalDecision::Deny(reason) = denied else {
            panic!("expected Deny");
        };
        assert_eq!(reason, "blocked rm");
    }
}