claude_api/tool_dispatch/tool.rs
//! The [`Tool`] async trait, its companion [`ToolError`], and the
//! [`ToolApprover`] approval gate.
//!
//! [`Tool`] is the foundation for every tool-dispatch shape: direct
//! trait implementations, closure adapters via
//! [`FnTool`](crate::tool_dispatch::FnTool), and schemars-driven typed
//! handlers ([`crate::tool_dispatch::TypedTool`]) all reduce to a
//! `dyn Tool` stored in the registry.
//!
//! [`ToolApprover`] is the mid-stream approval gate: each pending
//! tool call is passed to an approver before it executes. Returning
//! [`ApprovalDecision::Deny`] skips the tool and injects an error
//! result in its place.

13
14use async_trait::async_trait;
15
16/// A tool the model can invoke during generation.
17///
18/// # Contract
19///
20/// - **`name()`** returns a stable identifier. Names must be unique within
21/// a single registry; the model uses them to refer to specific tools.
22/// - **`schema()`** returns a JSON Schema describing the tool's input. The
23/// schema is sent to the model as part of the request and is used to
24/// guide what `invoke` will receive. Must be valid JSON Schema.
25/// - **`invoke(input)`** runs the tool. `input` is the raw `Value` the
26/// model produced and is *not* validated against the schema by the SDK
27/// -- impls should validate themselves and return [`ToolError::InvalidInput`]
28/// on failure. `invoke` may take arbitrary time; the agent-loop runner
29/// in #20 supports per-iteration timeouts.
30///
31/// All methods take `&self` so a single instance can be shared via `Arc`.
32/// The trait is `Send + Sync + 'static` so tools can live in concurrent
33/// contexts.
34///
35/// # Example
36///
37/// ```
38/// use async_trait::async_trait;
39/// use claude_api::tool_dispatch::{Tool, ToolError};
40/// use serde_json::{json, Value};
41///
42/// struct AddTool;
43///
44/// #[async_trait]
45/// impl Tool for AddTool {
46/// fn name(&self) -> &str { "add" }
47/// fn schema(&self) -> Value {
48/// json!({
49/// "type": "object",
50/// "properties": {
51/// "a": {"type": "number"},
52/// "b": {"type": "number"}
53/// },
54/// "required": ["a", "b"]
55/// })
56/// }
57/// async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
58/// let a = input.get("a").and_then(Value::as_f64)
59/// .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
60/// let b = input.get("b").and_then(Value::as_f64)
61/// .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
62/// Ok(json!({"sum": a + b}))
63/// }
64/// }
65/// ```
66#[async_trait]
67pub trait Tool: Send + Sync + 'static {
68 /// Stable identifier the model uses to refer to this tool.
69 fn name(&self) -> &str;
70
71 /// JSON Schema describing the tool's expected input.
72 fn schema(&self) -> serde_json::Value;
73
74 /// Optional human-readable description; helps the model decide when to
75 /// invoke. Default returns `None`.
76 fn description(&self) -> Option<&str> {
77 None
78 }
79
80 /// Run the tool with the model-supplied input and return its result.
81 async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError>;
82}
83
84/// Errors a [`Tool`] implementation can return.
85///
86/// Construct via [`Self::invalid_input`] for caller-side validation
87/// failures (string message, surfaced back to the model) or
88/// [`Self::execution`] to wrap any underlying error type that implements
89/// [`std::error::Error`].
90#[derive(Debug, thiserror::Error)]
91#[non_exhaustive]
92pub enum ToolError {
93 /// The model-supplied input did not satisfy the tool's schema or
94 /// other validation rules. The string is surfaced back to the model
95 /// as the `tool_result` content.
96 #[error("invalid tool input: {0}")]
97 InvalidInput(String),
98
99 /// The tool ran but its underlying operation failed.
100 #[error("tool execution failed: {0}")]
101 Execution(Box<dyn std::error::Error + Send + Sync>),
102
103 /// A registry was asked to dispatch a tool name it doesn't know.
104 /// Surfaced by `ToolRegistry::dispatch` (lands in #19).
105 #[error("no tool registered with name '{name}'")]
106 Unknown {
107 /// Tool name the registry was asked to dispatch.
108 name: String,
109 },
110}
111
112impl ToolError {
113 /// Build an [`InvalidInput`](Self::InvalidInput) error from a message.
114 pub fn invalid_input(msg: impl Into<String>) -> Self {
115 Self::InvalidInput(msg.into())
116 }
117
118 /// Build an [`Execution`](Self::Execution) error wrapping any error type.
119 pub fn execution<E>(err: E) -> Self
120 where
121 E: std::error::Error + Send + Sync + 'static,
122 {
123 Self::Execution(Box::new(err))
124 }
125}
126
127/// Verdict from a [`ToolApprover`] for a single `tool_use` invocation.
128///
129/// Approvers are consulted by [`crate::Client::run`] *before* each tool
130/// dispatch, so users can gate side-effecting tools behind an interactive
131/// confirmation, a policy check, an input rewriter, or a static
132/// allowlist.
133#[non_exhaustive]
134#[derive(Debug, Clone)]
135pub enum ApprovalDecision {
136 /// Proceed with the tool dispatch unchanged.
137 Approve,
138 /// Proceed, but substitute a different `input` (the model's original
139 /// payload is discarded). Useful for sanitizing arguments before the
140 /// tool runs (path scrubbing, scope clamping, etc.).
141 ApproveWithInput(serde_json::Value),
142 /// Skip the tool dispatch entirely and return `value` as the
143 /// `tool_result` content (with no `is_error` flag). Useful for
144 /// stubbing tools in tests or short-circuiting expensive calls when
145 /// the answer is already known.
146 Substitute(serde_json::Value),
147 /// Skip the tool dispatch. The supplied `reason` is returned to the
148 /// model as the `tool_result` content (with `is_error = true`) so
149 /// the model can choose how to recover.
150 Deny(String),
151 /// Abort the entire agent loop. Surfaces as
152 /// [`crate::Error::ToolApprovalStopped`] from `Client::run`.
153 Stop(String),
154}
155
156/// Async-callable predicate consulted before each tool dispatch.
157///
158/// Implement this trait for stateful approvers, or use the closure
159/// adapter [`fn_approver`] /
160/// [`RunOptions::with_approver_fn`](crate::tool_dispatch::runner::RunOptions::with_approver_fn).
161#[async_trait]
162pub trait ToolApprover: Send + Sync + 'static {
163 /// Inspect a pending tool dispatch and return a verdict.
164 async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision;
165}
166
167/// Wrap an async closure into a [`ToolApprover`].
168#[must_use]
169pub fn fn_approver<F, Fut>(handler: F) -> std::sync::Arc<dyn ToolApprover>
170where
171 F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
172 Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
173{
174 std::sync::Arc::new(FnApprover { handler })
175}
176
177struct FnApprover<F> {
178 handler: F,
179}
180
181#[async_trait]
182impl<F, Fut> ToolApprover for FnApprover<F>
183where
184 F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
185 Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
186{
187 async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision {
188 (self.handler)(tool_name, input).await
189 }
190}
191
192#[cfg(test)]
193mod tests {
194 use super::*;
195 use serde_json::{Value, json};
196 use std::sync::Arc;
197
198 struct AddTool;
199
200 #[async_trait]
201 impl Tool for AddTool {
202 // Trait dictates the return type; clippy can't tell we're returning
203 // a literal versus a stored String, so allow the lint locally.
204 #[allow(clippy::unnecessary_literal_bound)]
205 fn name(&self) -> &str {
206 "add"
207 }
208 #[allow(clippy::unnecessary_literal_bound)]
209 fn description(&self) -> Option<&str> {
210 Some("Add two numbers and return the sum.")
211 }
212 fn schema(&self) -> Value {
213 json!({
214 "type": "object",
215 "properties": {
216 "a": {"type": "number"},
217 "b": {"type": "number"}
218 },
219 "required": ["a", "b"]
220 })
221 }
222 async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
223 let a = input
224 .get("a")
225 .and_then(Value::as_f64)
226 .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
227 let b = input
228 .get("b")
229 .and_then(Value::as_f64)
230 .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
231 Ok(json!({"sum": a + b}))
232 }
233 }
234
235 #[tokio::test]
236 async fn manual_impl_round_trips_a_value() {
237 let tool = AddTool;
238 let result = tool.invoke(json!({"a": 2, "b": 3})).await.unwrap();
239 assert_eq!(result, json!({"sum": 5.0}));
240 }
241
242 #[tokio::test]
243 async fn trait_object_dispatch_works() {
244 // Critical: dyn Tool must work for ToolRegistry to hold heterogeneous tools.
245 let tool: Arc<dyn Tool> = Arc::new(AddTool);
246 assert_eq!(tool.name(), "add");
247 assert_eq!(
248 tool.description(),
249 Some("Add two numbers and return the sum.")
250 );
251 assert!(tool.schema().is_object());
252 let result = tool.invoke(json!({"a": 4, "b": 1})).await.unwrap();
253 assert_eq!(result["sum"], 5.0);
254 }
255
256 #[tokio::test]
257 async fn invalid_input_propagates_message() {
258 let tool = AddTool;
259 let err = tool.invoke(json!({"a": 1})).await.unwrap_err();
260 let ToolError::InvalidInput(msg) = err else {
261 panic!("expected InvalidInput");
262 };
263 assert!(msg.contains("'b'"), "{msg}");
264 }
265
266 #[test]
267 fn invalid_input_constructor_takes_string_or_str() {
268 let _ = ToolError::invalid_input("plain str");
269 let _ = ToolError::invalid_input(String::from("owned"));
270 }
271
272 #[test]
273 fn execution_wraps_any_std_error() {
274 let inner = std::io::Error::other("disk on fire");
275 let err = ToolError::execution(inner);
276 let display = format!("{err}");
277 assert!(display.contains("disk on fire"), "{display}");
278 let ToolError::Execution(_) = err else {
279 panic!("expected Execution");
280 };
281 }
282
283 #[test]
284 fn tool_is_send_and_sync() {
285 // Compile-time check: dyn Tool must be Send + Sync to live in async tasks.
286 fn assert_send_sync<T: Send + Sync + ?Sized>() {}
287 assert_send_sync::<dyn Tool>();
288 }
289}