Skip to main content

claude_api/tool_dispatch/
tool.rs

1//! The [`Tool`] async trait and its companion [`ToolError`].
2//!
3//! [`Tool`] is the foundation for every tool-dispatch shape: direct
4//! trait implementations, closure adapters via
5//! [`FnTool`](crate::tool_dispatch::FnTool), and schemars-driven typed
6//! handlers ([`crate::tool_dispatch::TypedTool`]) all reduce to a
7//! `dyn Tool` stored in the registry.
8//!
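//! [`ToolApprover`] is the mid-stream approval gate: each pending tool call
//! is passed to an approver before it executes, and the approver answers
//! with an [`ApprovalDecision`]. For example, [`ApprovalDecision::Deny`]
//! short-circuits the tool and injects an error result instead, while
//! [`ApprovalDecision::Stop`] aborts the whole agent loop.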
13
14use async_trait::async_trait;
15
16/// A tool the model can invoke during generation.
17///
18/// # Contract
19///
20/// - **`name()`** returns a stable identifier. Names must be unique within
21///   a single registry; the model uses them to refer to specific tools.
22/// - **`schema()`** returns a JSON Schema describing the tool's input. The
23///   schema is sent to the model as part of the request and is used to
24///   guide what `invoke` will receive. Must be valid JSON Schema.
25/// - **`invoke(input)`** runs the tool. `input` is the raw `Value` the
26///   model produced and is *not* validated against the schema by the SDK
27///   -- impls should validate themselves and return [`ToolError::InvalidInput`]
28///   on failure. `invoke` may take arbitrary time; the agent-loop runner
29///   in #20 supports per-iteration timeouts.
30///
31/// All methods take `&self` so a single instance can be shared via `Arc`.
32/// The trait is `Send + Sync + 'static` so tools can live in concurrent
33/// contexts.
34///
35/// # Example
36///
37/// ```
38/// use async_trait::async_trait;
39/// use claude_api::tool_dispatch::{Tool, ToolError};
40/// use serde_json::{json, Value};
41///
42/// struct AddTool;
43///
44/// #[async_trait]
45/// impl Tool for AddTool {
46///     fn name(&self) -> &str { "add" }
47///     fn schema(&self) -> Value {
48///         json!({
49///             "type": "object",
50///             "properties": {
51///                 "a": {"type": "number"},
52///                 "b": {"type": "number"}
53///             },
54///             "required": ["a", "b"]
55///         })
56///     }
57///     async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
58///         let a = input.get("a").and_then(Value::as_f64)
59///             .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
60///         let b = input.get("b").and_then(Value::as_f64)
61///             .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
62///         Ok(json!({"sum": a + b}))
63///     }
64/// }
65/// ```
66#[async_trait]
67pub trait Tool: Send + Sync + 'static {
68    /// Stable identifier the model uses to refer to this tool.
69    fn name(&self) -> &str;
70
71    /// JSON Schema describing the tool's expected input.
72    fn schema(&self) -> serde_json::Value;
73
74    /// Optional human-readable description; helps the model decide when to
75    /// invoke. Default returns `None`.
76    fn description(&self) -> Option<&str> {
77        None
78    }
79
80    /// Run the tool with the model-supplied input and return its result.
81    async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError>;
82}
83
84/// Errors a [`Tool`] implementation can return.
85///
86/// Construct via [`Self::invalid_input`] for caller-side validation
87/// failures (string message, surfaced back to the model) or
88/// [`Self::execution`] to wrap any underlying error type that implements
89/// [`std::error::Error`].
90#[derive(Debug, thiserror::Error)]
91#[non_exhaustive]
92pub enum ToolError {
93    /// The model-supplied input did not satisfy the tool's schema or
94    /// other validation rules. The string is surfaced back to the model
95    /// as the `tool_result` content.
96    #[error("invalid tool input: {0}")]
97    InvalidInput(String),
98
99    /// The tool ran but its underlying operation failed.
100    #[error("tool execution failed: {0}")]
101    Execution(Box<dyn std::error::Error + Send + Sync>),
102
103    /// A registry was asked to dispatch a tool name it doesn't know.
104    /// Surfaced by `ToolRegistry::dispatch` (lands in #19).
105    #[error("no tool registered with name '{name}'")]
106    Unknown {
107        /// Tool name the registry was asked to dispatch.
108        name: String,
109    },
110}
111
112impl ToolError {
113    /// Build an [`InvalidInput`](Self::InvalidInput) error from a message.
114    pub fn invalid_input(msg: impl Into<String>) -> Self {
115        Self::InvalidInput(msg.into())
116    }
117
118    /// Build an [`Execution`](Self::Execution) error wrapping any error type.
119    pub fn execution<E>(err: E) -> Self
120    where
121        E: std::error::Error + Send + Sync + 'static,
122    {
123        Self::Execution(Box::new(err))
124    }
125}
126
127/// Verdict from a [`ToolApprover`] for a single `tool_use` invocation.
128///
129/// Approvers are consulted by [`crate::Client::run`] *before* each tool
130/// dispatch, so users can gate side-effecting tools behind an interactive
131/// confirmation, a policy check, an input rewriter, or a static
132/// allowlist.
133#[non_exhaustive]
134#[derive(Debug, Clone)]
135pub enum ApprovalDecision {
136    /// Proceed with the tool dispatch unchanged.
137    Approve,
138    /// Proceed, but substitute a different `input` (the model's original
139    /// payload is discarded). Useful for sanitizing arguments before the
140    /// tool runs (path scrubbing, scope clamping, etc.).
141    ApproveWithInput(serde_json::Value),
142    /// Skip the tool dispatch entirely and return `value` as the
143    /// `tool_result` content (with no `is_error` flag). Useful for
144    /// stubbing tools in tests or short-circuiting expensive calls when
145    /// the answer is already known.
146    Substitute(serde_json::Value),
147    /// Skip the tool dispatch. The supplied `reason` is returned to the
148    /// model as the `tool_result` content (with `is_error = true`) so
149    /// the model can choose how to recover.
150    Deny(String),
151    /// Abort the entire agent loop. Surfaces as
152    /// [`crate::Error::ToolApprovalStopped`] from `Client::run`.
153    Stop(String),
154}
155
156/// Async-callable predicate consulted before each tool dispatch.
157///
158/// Implement this trait for stateful approvers, or use the closure
159/// adapter [`fn_approver`] /
160/// [`RunOptions::with_approver_fn`](crate::tool_dispatch::runner::RunOptions::with_approver_fn).
161#[async_trait]
162pub trait ToolApprover: Send + Sync + 'static {
163    /// Inspect a pending tool dispatch and return a verdict.
164    async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision;
165}
166
167/// Wrap an async closure into a [`ToolApprover`].
168#[must_use]
169pub fn fn_approver<F, Fut>(handler: F) -> std::sync::Arc<dyn ToolApprover>
170where
171    F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
172    Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
173{
174    std::sync::Arc::new(FnApprover { handler })
175}
176
177struct FnApprover<F> {
178    handler: F,
179}
180
181#[async_trait]
182impl<F, Fut> ToolApprover for FnApprover<F>
183where
184    F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
185    Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
186{
187    async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision {
188        (self.handler)(tool_name, input).await
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};
    use std::sync::Arc;

    /// Minimal hand-written tool used to exercise the `Tool` trait.
    struct AddTool;

    #[async_trait]
    impl Tool for AddTool {
        // Trait dictates the return type; clippy can't tell we're returning
        // a literal versus a stored String, so allow the lint locally.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "add"
        }
        #[allow(clippy::unnecessary_literal_bound)]
        fn description(&self) -> Option<&str> {
            Some("Add two numbers and return the sum.")
        }
        fn schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["a", "b"]
            })
        }
        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let a = input
                .get("a")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
            let b = input
                .get("b")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
            Ok(json!({"sum": a + b}))
        }
    }

    #[tokio::test]
    async fn manual_impl_round_trips_a_value() {
        let tool = AddTool;
        let result = tool.invoke(json!({"a": 2, "b": 3})).await.unwrap();
        assert_eq!(result, json!({"sum": 5.0}));
    }

    #[tokio::test]
    async fn trait_object_dispatch_works() {
        // Critical: dyn Tool must work for ToolRegistry to hold heterogeneous tools.
        let tool: Arc<dyn Tool> = Arc::new(AddTool);
        assert_eq!(tool.name(), "add");
        assert_eq!(
            tool.description(),
            Some("Add two numbers and return the sum.")
        );
        assert!(tool.schema().is_object());
        let result = tool.invoke(json!({"a": 4, "b": 1})).await.unwrap();
        assert_eq!(result["sum"], 5.0);
    }

    #[tokio::test]
    async fn invalid_input_propagates_message() {
        let tool = AddTool;
        let err = tool.invoke(json!({"a": 1})).await.unwrap_err();
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'b'"), "{msg}");
    }

    #[test]
    fn invalid_input_constructor_takes_string_or_str() {
        let _ = ToolError::invalid_input("plain str");
        let _ = ToolError::invalid_input(String::from("owned"));
    }

    #[test]
    fn execution_wraps_any_std_error() {
        let inner = std::io::Error::other("disk on fire");
        let err = ToolError::execution(inner);
        let display = format!("{err}");
        assert!(display.contains("disk on fire"), "{display}");
        let ToolError::Execution(_) = err else {
            panic!("expected Execution");
        };
    }

    #[test]
    fn unknown_error_names_the_missing_tool() {
        let err = ToolError::Unknown {
            name: String::from("frobnicate"),
        };
        assert_eq!(err.to_string(), "no tool registered with name 'frobnicate'");
    }

    #[tokio::test]
    async fn fn_approver_forwards_arguments_and_verdict() {
        let approver = fn_approver(|tool_name: &str, input: &Value| {
            // The returned future must be 'static, so derive everything it
            // needs from the borrowed arguments before building it.
            let dangerous = tool_name == "rm" || input.get("force").is_some();
            async move {
                if dangerous {
                    ApprovalDecision::Deny(String::from("needs confirmation"))
                } else {
                    ApprovalDecision::Approve
                }
            }
        });

        let verdict = approver.approve("add", &json!({"a": 1, "b": 2})).await;
        assert!(matches!(verdict, ApprovalDecision::Approve), "{verdict:?}");

        match approver.approve("rm", &json!({})).await {
            ApprovalDecision::Deny(reason) => assert_eq!(reason, "needs confirmation"),
            other => panic!("expected Deny, got {other:?}"),
        }
    }

    #[test]
    fn tool_is_send_and_sync() {
        // Compile-time check: dyn Tool must be Send + Sync to live in async tasks.
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn Tool>();
    }
}