Skip to main content

claude_api/tool_dispatch/
tool.rs

1//! The [`Tool`] async trait and its companion [`ToolError`].
2//!
3//! [`Tool`] is the foundation for every tool-dispatch shape v0.2 ships:
4//! direct trait implementations, closure adapters via `FnTool` (#19), and
5//! schemars-driven typed handlers (#21) all reduce to a `Tool` impl.
6
7use async_trait::async_trait;
8
9/// A tool the model can invoke during generation.
10///
11/// # Contract
12///
13/// - **`name()`** returns a stable identifier. Names must be unique within
14///   a single registry; the model uses them to refer to specific tools.
15/// - **`schema()`** returns a JSON Schema describing the tool's input. The
16///   schema is sent to the model as part of the request and is used to
17///   guide what `invoke` will receive. Must be valid JSON Schema.
18/// - **`invoke(input)`** runs the tool. `input` is the raw `Value` the
19///   model produced and is *not* validated against the schema by the SDK
20///   -- impls should validate themselves and return [`ToolError::InvalidInput`]
21///   on failure. `invoke` may take arbitrary time; the agent-loop runner
22///   in #20 supports per-iteration timeouts.
23///
24/// All methods take `&self` so a single instance can be shared via `Arc`.
25/// The trait is `Send + Sync + 'static` so tools can live in concurrent
26/// contexts.
27///
28/// # Example
29///
30/// ```
31/// use async_trait::async_trait;
32/// use claude_api::tool_dispatch::{Tool, ToolError};
33/// use serde_json::{json, Value};
34///
35/// struct AddTool;
36///
37/// #[async_trait]
38/// impl Tool for AddTool {
39///     fn name(&self) -> &str { "add" }
40///     fn schema(&self) -> Value {
41///         json!({
42///             "type": "object",
43///             "properties": {
44///                 "a": {"type": "number"},
45///                 "b": {"type": "number"}
46///             },
47///             "required": ["a", "b"]
48///         })
49///     }
50///     async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
51///         let a = input.get("a").and_then(Value::as_f64)
52///             .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
53///         let b = input.get("b").and_then(Value::as_f64)
54///             .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
55///         Ok(json!({"sum": a + b}))
56///     }
57/// }
58/// ```
59#[async_trait]
60pub trait Tool: Send + Sync + 'static {
61    /// Stable identifier the model uses to refer to this tool.
62    fn name(&self) -> &str;
63
64    /// JSON Schema describing the tool's expected input.
65    fn schema(&self) -> serde_json::Value;
66
67    /// Optional human-readable description; helps the model decide when to
68    /// invoke. Default returns `None`.
69    fn description(&self) -> Option<&str> {
70        None
71    }
72
73    /// Run the tool with the model-supplied input and return its result.
74    async fn invoke(&self, input: serde_json::Value) -> Result<serde_json::Value, ToolError>;
75}
76
77/// Errors a [`Tool`] implementation can return.
78///
79/// Construct via [`Self::invalid_input`] for caller-side validation
80/// failures (string message, surfaced back to the model) or
81/// [`Self::execution`] to wrap any underlying error type that implements
82/// [`std::error::Error`].
83#[derive(Debug, thiserror::Error)]
84#[non_exhaustive]
85pub enum ToolError {
86    /// The model-supplied input did not satisfy the tool's schema or
87    /// other validation rules. The string is surfaced back to the model
88    /// as the `tool_result` content.
89    #[error("invalid tool input: {0}")]
90    InvalidInput(String),
91
92    /// The tool ran but its underlying operation failed.
93    #[error("tool execution failed: {0}")]
94    Execution(Box<dyn std::error::Error + Send + Sync>),
95
96    /// A registry was asked to dispatch a tool name it doesn't know.
97    /// Surfaced by `ToolRegistry::dispatch` (lands in #19).
98    #[error("no tool registered with name '{name}'")]
99    Unknown {
100        /// Tool name the registry was asked to dispatch.
101        name: String,
102    },
103}
104
105impl ToolError {
106    /// Build an [`InvalidInput`](Self::InvalidInput) error from a message.
107    pub fn invalid_input(msg: impl Into<String>) -> Self {
108        Self::InvalidInput(msg.into())
109    }
110
111    /// Build an [`Execution`](Self::Execution) error wrapping any error type.
112    pub fn execution<E>(err: E) -> Self
113    where
114        E: std::error::Error + Send + Sync + 'static,
115    {
116        Self::Execution(Box::new(err))
117    }
118}
119
120/// Verdict from a [`ToolApprover`] for a single `tool_use` invocation.
121///
122/// Approvers are consulted by [`crate::Client::run`] *before* each tool
123/// dispatch, so users can gate side-effecting tools behind an interactive
124/// confirmation, a policy check, an input rewriter, or a static
125/// allowlist.
126#[non_exhaustive]
127#[derive(Debug, Clone)]
128pub enum ApprovalDecision {
129    /// Proceed with the tool dispatch unchanged.
130    Approve,
131    /// Proceed, but substitute a different `input` (the model's original
132    /// payload is discarded). Useful for sanitizing arguments before the
133    /// tool runs (path scrubbing, scope clamping, etc.).
134    ApproveWithInput(serde_json::Value),
135    /// Skip the tool dispatch entirely and return `value` as the
136    /// `tool_result` content (with no `is_error` flag). Useful for
137    /// stubbing tools in tests or short-circuiting expensive calls when
138    /// the answer is already known.
139    Substitute(serde_json::Value),
140    /// Skip the tool dispatch. The supplied `reason` is returned to the
141    /// model as the `tool_result` content (with `is_error = true`) so
142    /// the model can choose how to recover.
143    Deny(String),
144    /// Abort the entire agent loop. Surfaces as
145    /// [`crate::Error::ToolApprovalStopped`] from `Client::run`.
146    Stop(String),
147}
148
149/// Async-callable predicate consulted before each tool dispatch.
150///
151/// Implement this trait for stateful approvers, or use the closure
152/// adapter [`fn_approver`] /
153/// [`RunOptions::with_approver_fn`](crate::tool_dispatch::runner::RunOptions::with_approver_fn).
154#[async_trait]
155pub trait ToolApprover: Send + Sync + 'static {
156    /// Inspect a pending tool dispatch and return a verdict.
157    async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision;
158}
159
160/// Wrap an async closure into a [`ToolApprover`].
161#[must_use]
162pub fn fn_approver<F, Fut>(handler: F) -> std::sync::Arc<dyn ToolApprover>
163where
164    F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
165    Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
166{
167    std::sync::Arc::new(FnApprover { handler })
168}
169
170struct FnApprover<F> {
171    handler: F,
172}
173
174#[async_trait]
175impl<F, Fut> ToolApprover for FnApprover<F>
176where
177    F: Fn(&str, &serde_json::Value) -> Fut + Send + Sync + 'static,
178    Fut: std::future::Future<Output = ApprovalDecision> + Send + 'static,
179{
180    async fn approve(&self, tool_name: &str, input: &serde_json::Value) -> ApprovalDecision {
181        (self.handler)(tool_name, input).await
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};
    use std::sync::Arc;

    /// Fully-featured tool used by most tests: overrides `description`.
    struct AddTool;

    #[async_trait]
    impl Tool for AddTool {
        // Trait dictates the return type; clippy can't tell we're returning
        // a literal versus a stored String, so allow the lint locally.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "add"
        }
        #[allow(clippy::unnecessary_literal_bound)]
        fn description(&self) -> Option<&str> {
            Some("Add two numbers and return the sum.")
        }
        fn schema(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["a", "b"]
            })
        }
        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            let a = input
                .get("a")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'a'"))?;
            let b = input
                .get("b")
                .and_then(Value::as_f64)
                .ok_or_else(|| ToolError::invalid_input("missing 'b'"))?;
            Ok(json!({"sum": a + b}))
        }
    }

    /// Minimal tool that relies on every provided default (no `description`).
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        // Same reasoning as on `AddTool::name`.
        #[allow(clippy::unnecessary_literal_bound)]
        fn name(&self) -> &str {
            "echo"
        }
        fn schema(&self) -> Value {
            json!({"type": "object"})
        }
        async fn invoke(&self, input: Value) -> Result<Value, ToolError> {
            Ok(input)
        }
    }

    #[tokio::test]
    async fn manual_impl_round_trips_a_value() {
        let tool = AddTool;
        let result = tool.invoke(json!({"a": 2, "b": 3})).await.unwrap();
        assert_eq!(result, json!({"sum": 5.0}));
    }

    #[tokio::test]
    async fn trait_object_dispatch_works() {
        // Critical: dyn Tool must work for ToolRegistry to hold heterogeneous tools.
        let tool: Arc<dyn Tool> = Arc::new(AddTool);
        assert_eq!(tool.name(), "add");
        assert_eq!(
            tool.description(),
            Some("Add two numbers and return the sum.")
        );
        assert!(tool.schema().is_object());
        let result = tool.invoke(json!({"a": 4, "b": 1})).await.unwrap();
        assert_eq!(result["sum"], 5.0);
    }

    #[tokio::test]
    async fn default_description_is_none() {
        let tool: Arc<dyn Tool> = Arc::new(EchoTool);
        assert_eq!(tool.description(), None);
        let result = tool.invoke(json!({"x": 1})).await.unwrap();
        assert_eq!(result, json!({"x": 1}));
    }

    #[tokio::test]
    async fn invalid_input_propagates_message() {
        let tool = AddTool;
        let err = tool.invoke(json!({"a": 1})).await.unwrap_err();
        assert_eq!(err.to_string(), "invalid tool input: missing 'b'");
        let ToolError::InvalidInput(msg) = err else {
            panic!("expected InvalidInput");
        };
        assert!(msg.contains("'b'"), "{msg}");
    }

    #[test]
    fn invalid_input_constructor_takes_string_or_str() {
        let _ = ToolError::invalid_input("plain str");
        let _ = ToolError::invalid_input(String::from("owned"));
    }

    #[test]
    fn execution_wraps_any_std_error() {
        let inner = std::io::Error::other("disk on fire");
        let err = ToolError::execution(inner);
        let display = format!("{err}");
        assert!(display.contains("disk on fire"), "{display}");
        let ToolError::Execution(_) = err else {
            panic!("expected Execution");
        };
    }

    #[test]
    fn unknown_error_names_the_missing_tool() {
        let err = ToolError::Unknown {
            name: "nope".to_owned(),
        };
        assert_eq!(err.to_string(), "no tool registered with name 'nope'");
    }

    #[tokio::test]
    async fn fn_approver_forwards_name_and_input_to_closure() {
        // The closure inspects both borrowed arguments synchronously and moves
        // only owned data into the returned future, as `fn_approver` requires.
        let approver = fn_approver(|name, input| {
            let decision = if name == "add" && input.get("a").is_some() {
                ApprovalDecision::Approve
            } else {
                ApprovalDecision::Deny(format!("blocked {name}"))
            };
            async move { decision }
        });

        let verdict = approver.approve("add", &json!({"a": 1, "b": 2})).await;
        assert!(matches!(verdict, ApprovalDecision::Approve), "{verdict:?}");

        let verdict = approver.approve("rm", &json!({})).await;
        let ApprovalDecision::Deny(reason) = verdict else {
            panic!("expected Deny");
        };
        assert_eq!(reason, "blocked rm");
    }

    #[test]
    fn tool_is_send_and_sync() {
        // Compile-time check: dyn Tool must be Send + Sync to live in async tasks.
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn Tool>();
    }

    #[test]
    fn approver_is_send_and_sync() {
        // Compile-time check: approvers are shared as Arc<dyn ToolApprover>.
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn ToolApprover>();
    }
}