Skip to main content

agnt_core/
tool.rs

1//! The [`Tool`] trait — agent-callable capabilities.
2//!
3//! v0.1 shipped a single erased [`Tool`] trait using `serde_json::Value` for
4//! args and `String` for output. v0.2 adds a typed variant [`TypedTool`] with
5//! associated `Args`/`Output`/`Error` types, plus an [`ErasedAdapter`] that
6//! implements the erased [`Tool`] trait on top of any `TypedTool`. Both
7//! paths coexist — existing `Tool` impls keep working unchanged.
8//!
9//! See v0.2 plan doc Work Item A1.
10
11use serde_json::Value;
12use std::marker::PhantomData;
13
14/// A tool the agent can invoke (erased form).
15pub trait Tool: Send + Sync {
16    /// The tool's name — used by the model to invoke it and for dispatch.
17    fn name(&self) -> &str;
18
19    /// Human-readable description sent to the model as part of the tool list.
20    /// This is the primary steering mechanism for tool selection.
21    fn description(&self) -> &str;
22
23    /// JSON Schema describing the tool's arguments.
24    fn schema(&self) -> Value;
25
26    /// Execute the tool synchronously. Return a string result or an error
27    /// message. Callers must enforce result-byte caps and envelope framing
28    /// before persisting or feeding back to the model.
29    fn call(&self, args: Value) -> Result<String, String>;
30}
31
32/// A typed tool — associated input/output/error types, schema as const.
33///
34/// Prefer this trait when writing new tools; wrap with [`ErasedAdapter`] to
35/// register into a [`Registry`].
36///
37/// ```ignore
38/// use agnt_core::tool::{TypedTool, ErasedAdapter, Registry};
39/// use serde::{Serialize, Deserialize};
40///
41/// #[derive(Deserialize)] struct Args { a: i64, b: i64 }
42/// #[derive(Serialize)] struct Out { sum: i64 }
43///
44/// struct Add;
45/// impl TypedTool for Add {
46///     type Args = Args;
47///     type Output = Out;
48///     type Error = String;
49///     const NAME: &'static str = "add";
50///     const DESCRIPTION: &'static str = "Add two integers.";
51///     fn schema() -> serde_json::Value {
52///         serde_json::json!({
53///             "type": "object",
54///             "properties": {
55///                 "a": { "type": "integer" },
56///                 "b": { "type": "integer" }
57///             },
58///             "required": ["a", "b"]
59///         })
60///     }
61///     fn call(&self, args: Args) -> Result<Out, String> {
62///         Ok(Out { sum: args.a + args.b })
63///     }
64///  }
65///
66/// let mut reg = Registry::new();
67/// reg.register(Box::new(ErasedAdapter::new(Add)));
68/// ```
69pub trait TypedTool: Send + Sync {
70    /// Argument type, deserialized from JSON.
71    type Args: serde::de::DeserializeOwned + Send;
72    /// Return type, serialized to JSON.
73    type Output: serde::Serialize + Send;
74    /// Error type (displayed as a string when bridged to the erased trait).
75    type Error: std::fmt::Display + Send + Sync;
76
77    /// The tool name exposed to the model.
78    const NAME: &'static str;
79    /// Human-readable description for model steering.
80    const DESCRIPTION: &'static str;
81
82    /// JSON Schema for the arguments object.
83    fn schema() -> Value;
84
85    /// Execute the tool.
86    fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error>;
87}
88
89/// Adapter that turns any [`TypedTool`] into an erased [`Tool`].
90///
91/// Deserializes the incoming `serde_json::Value` into `T::Args`, calls the
92/// typed impl, and serializes the output back to a JSON string. Errors at any
93/// stage are flattened to `Err(String)`.
94pub struct ErasedAdapter<T: TypedTool> {
95    inner: T,
96    _marker: PhantomData<fn() -> T>,
97}
98
99impl<T: TypedTool> ErasedAdapter<T> {
100    pub fn new(inner: T) -> Self {
101        Self {
102            inner,
103            _marker: PhantomData,
104        }
105    }
106
107    /// Access the underlying typed tool.
108    pub fn inner(&self) -> &T {
109        &self.inner
110    }
111}
112
113impl<T: TypedTool> Tool for ErasedAdapter<T> {
114    fn name(&self) -> &str {
115        T::NAME
116    }
117
118    fn description(&self) -> &str {
119        T::DESCRIPTION
120    }
121
122    fn schema(&self) -> Value {
123        T::schema()
124    }
125
126    fn call(&self, args: Value) -> Result<String, String> {
127        let typed: T::Args = serde_json::from_value(args)
128            .map_err(|e| format!("args deserialize: {}", e))?;
129        let out = self.inner.call(typed).map_err(|e| e.to_string())?;
130        serde_json::to_string(&out).map_err(|e| format!("output serialize: {}", e))
131    }
132}
133
134/// A collection of tools with name-based dispatch.
135///
136/// The [`Agent`](crate::Agent) holds a `Registry` and uses it to dispatch
137/// tool calls from the model. Tools can be registered at any time before
138/// or between calls to [`Agent::step`](crate::Agent::step).
139pub struct Registry {
140    tools: Vec<Box<dyn Tool>>,
141}
142
143impl Registry {
144    pub fn new() -> Self {
145        Self { tools: Vec::new() }
146    }
147
148    pub fn register(&mut self, tool: Box<dyn Tool>) {
149        self.tools.push(tool);
150    }
151
152    /// Register a [`TypedTool`] directly, wrapping it in an [`ErasedAdapter`].
153    pub fn register_typed<T: TypedTool + 'static>(&mut self, tool: T) {
154        self.tools.push(Box::new(ErasedAdapter::new(tool)));
155    }
156
157    pub fn dispatch(&self, name: &str, args: Value) -> Result<String, String> {
158        self.tools
159            .iter()
160            .find(|t| t.name() == name)
161            .ok_or_else(|| format!("unknown tool: {}", name))?
162            .call(args)
163    }
164
165    pub fn names(&self) -> Vec<&str> {
166        self.tools.iter().map(|t| t.name()).collect()
167    }
168
169    pub fn as_openai_tools(&self) -> Value {
170        Value::Array(
171            self.tools
172                .iter()
173                .map(|t| {
174                    serde_json::json!({
175                        "type": "function",
176                        "function": {
177                            "name": t.name(),
178                            "description": t.description(),
179                            "parameters": t.schema(),
180                        }
181                    })
182                })
183                .collect(),
184        )
185    }
186}
187
188impl Default for Registry {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Arguments accepted by the `add` fixture tool.
    #[derive(Deserialize)]
    struct AddArgs {
        a: i64,
        b: i64,
    }

    /// Output produced by the `add` fixture tool.
    #[derive(Serialize)]
    struct AddOut {
        sum: i64,
    }

    /// Minimal typed tool used to exercise the erased bridge.
    struct Add;

    impl TypedTool for Add {
        const NAME: &'static str = "add";
        const DESCRIPTION: &'static str = "Add two integers.";

        type Args = AddArgs;
        type Output = AddOut;
        type Error = String;

        fn schema() -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"}
                },
                "required": ["a", "b"]
            })
        }

        fn call(&self, AddArgs { a, b }: AddArgs) -> Result<AddOut, String> {
            Ok(AddOut { sum: a + b })
        }
    }

    /// Builds a registry containing only the `add` fixture.
    fn registry_with_add() -> Registry {
        let mut registry = Registry::new();
        registry.register_typed(Add);
        registry
    }

    #[test]
    fn typed_tool_roundtrips_through_erased_adapter() {
        let result = registry_with_add().dispatch("add", serde_json::json!({"a": 2, "b": 3}));
        assert_eq!(result.expect("dispatch"), r#"{"sum":5}"#);
    }

    #[test]
    fn typed_tool_args_deserialize_error_is_string() {
        let message = registry_with_add()
            .dispatch("add", serde_json::json!({"a": "not-a-number"}))
            .unwrap_err();
        assert!(message.contains("args deserialize"), "got: {}", message);
    }

    #[test]
    fn erased_adapter_name_and_description_are_const() {
        let adapter = ErasedAdapter::new(Add);
        assert_eq!(
            (adapter.name(), adapter.description()),
            ("add", "Add two integers.")
        );
    }
}