Skip to main content

llmy_agent/
tool.rs

//! Tool definitions and the [`ToolBox`] registry used by agents.
//!
//! This module exposes two traits for describing tools that a language model
//! can invoke:
//!
//! * [`Tool`] — the typed, ergonomic trait that user code implements (or has
//!   generated via the [`llmy_agent_derive::tool`] attribute macro, re-exported
//!   from `llmy_agent` as `llmy_agent::tool`). Each `Tool` declares a
//!   strongly-typed `ARGUMENTS` type, a `NAME`, an optional `DESCRIPTION`,
//!   and an `invoke` method that receives already-deserialized arguments.
//! * [`ToolDyn`] — the object-safe counterpart, automatically implemented for
//!   every `Tool`. Agents store tools as `dyn ToolDyn` so that a heterogeneous
//!   set of tools can be kept in a single collection.
//!
//! Tools are grouped together in a [`ToolBox`], which exposes them to the
//! model (via [`ToolBox::openai_objects`]) and dispatches incoming tool calls
//! to the matching implementation.

19use std::collections::BTreeMap;
20use std::fmt::Debug;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::Arc;
24
25use dyn_clone::DynClone;
26use llmy_client::req::{
27    ChatCompletionRequestMessageRaw, ChatCompletionRequestToolMessageContent,
28    ChatCompletionRequestToolMessageRaw, ChatCompletionTool, ChatCompletionToolRaw,
29    ChatCompletionTools, ChatCompletionToolsRaw, FunctionObjectRaw,
30};
31use llmy_types::error::{GeneralToolCall, LLMYError};
32use llmy_types::other::WithOtherFields;
33use schemars::schema_for;
34use serde::de::DeserializeOwned;
35use tokio::task::JoinSet;
36use tracing::debug;
37
38/// Object-safe view of a [`Tool`].
39///
40/// `ToolDyn` erases the `ARGUMENTS` associated type so that tools of different
41/// shapes can be stored together (for example inside a [`ToolBox`]). It is
42/// implemented automatically for every `T: Tool + 'static`, so library users
43/// rarely need to implement it directly — implement [`Tool`] instead.
44///
45/// All methods take `&self` and the trait is `Send + Sync + Clone` (via
46/// [`dyn_clone`]), which lets a tool be cheaply cloned into background tasks.
47pub trait ToolDyn: DynClone + Debug + Send + Sync + std::any::Any {
48    /// Returns the tool's name as advertised to the model. Must be unique
49    /// within a [`ToolBox`].
50    fn name(&self) -> String;
51    /// Returns the human-readable description shown to the model, if any.
52    fn description(&self) -> Option<String>;
53    /// Returns the JSON Schema describing this tool's expected arguments.
54    fn schema(&self) -> schemars::Schema;
55    /// Whether the model should honour the JSON schema strictly.
56    fn strict(&self) -> bool {
57        false
58    }
59    /// Renders the tool as an OpenAI [`ChatCompletionTool`] descriptor,
60    /// including its JSON schema, ready to be sent in a chat completion
61    /// request.
62    fn to_openai_obejct(&self) -> ChatCompletionTool {
63        WithOtherFields::new(ChatCompletionToolRaw {
64            function: WithOtherFields::new(FunctionObjectRaw {
65                name: self.name(),
66                description: self.description(),
67                parameters: Some(
68                    serde_json::to_value(self.schema()).expect("Fail to serialize schema"),
69                ),
70                strict: Some(self.strict()),
71            }),
72        })
73    }
74    /// Renders the tool as an MCP [`rmcp::model::Tool`] descriptor.
75    fn to_mcp_tool(&self) -> rmcp::model::Tool {
76        let input_schema = serde_json::to_value(self.schema()).expect("Fail to serialize schema");
77        let input_schema = input_schema.as_object().cloned().unwrap_or_default();
78        rmcp::model::Tool::new_with_raw(
79            self.name(),
80            self.description().map(Into::into),
81            Arc::new(input_schema),
82        )
83    }
84    /// Invokes the tool with raw JSON-encoded `arguments`. The string is
85    /// deserialized into the tool's `ARGUMENTS` type by the blanket impl on
86    /// top of [`Tool`].
87    fn call(
88        &self,
89        arguments: String,
90    ) -> Pin<Box<dyn Future<Output = Result<String, LLMYError>> + Send + '_>> {
91        Box::pin(async move {
92            match serde_json::from_str::<serde_json::Value>(&arguments) {
93                Ok(value) => self.run(value).await,
94                Err(_) => Err(LLMYError::IncorrectToolCall(
95                    self.name(),
96                    arguments,
97                    self.schema(),
98                )),
99            }
100        })
101    }
102    /// Invokes the tool with a [`serde_json::Value`] as arguments.
103    fn run(
104        &self,
105        arguments: serde_json::Value,
106    ) -> Pin<Box<dyn Future<Output = Result<String, LLMYError>> + Send + '_>>;
107}
108
109/// Downcasts a `&dyn ToolDyn` to a concrete tool type.
110///
111/// # Panics
112///
113/// Panics if `tool` is not actually an instance of `T`. Use this only when the
114/// concrete type is known by construction — for general dispatch, prefer the
115/// trait methods on [`ToolDyn`].
116pub fn downcast_tool<T: 'static>(tool: &dyn ToolDyn) -> &T {
117    (tool as &dyn std::any::Any)
118        .downcast_ref::<T>()
119        .expect("can not downcast")
120}
121
122dyn_clone::clone_trait_object!(ToolDyn);
123
124/// A typed tool that an agent can call.
125///
126/// Implementors describe the tool with associated constants and an
127/// [`Self::invoke`] method that receives already-deserialized arguments. The
128/// blanket `impl<T: Tool> ToolDyn for T` takes care of JSON deserialization,
129/// schema generation and OpenAI-shaped serialization, so most call sites only
130/// ever interact with [`ToolDyn`].
131///
132/// # Deriving an implementation
133///
134/// The companion [`llmy_agent_derive::tool`] attribute macro (re-exported as
135/// `llmy_agent::tool`, and also reachable through the umbrella crate as
136/// `llmy::agent::tool`) can generate this trait for a struct, wiring the
137/// associated constants and forwarding `invoke` to a method on the struct:
138///
139/// ```ignore
140/// use llmy_agent::tool;
141/// use llmy_types::error::LLMYError;
142/// use schemars::JsonSchema;
143/// use serde::Deserialize;
144///
145/// #[derive(Deserialize, JsonSchema)]
146/// struct EchoArgs { message: String }
147///
148/// #[derive(Clone, Debug)]
149/// #[tool(
150///     description = "Echo a message back",
151///     arguments = EchoArgs,
152///     invoke = run,
153/// )]
154/// struct EchoTool;
155///
156/// impl EchoTool {
157///     async fn run(&self, args: EchoArgs) -> Result<String, LLMYError> {
158///         Ok(args.message)
159///     }
160/// }
161/// ```
162///
163/// The macro accepts `description`, `arguments`, `invoke` (required) and an
164/// optional `name` (defaulting to the struct identifier in `snake_case`).
165pub trait Tool: Send + Sync + DynClone + Debug {
166    /// The strongly-typed argument struct. It must implement
167    /// [`serde::de::DeserializeOwned`] (to be parsed from the model's JSON
168    /// payload) and [`schemars::JsonSchema`] (to generate the schema sent to
169    /// the model).
170    type ARGUMENTS: DeserializeOwned + schemars::JsonSchema + Sized + Send;
171    /// Unique name advertised to the model.
172    const NAME: &str;
173    /// Optional human-readable description shown to the model.
174    const DESCRIPTION: Option<&str>;
175    /// Whether the model should be asked to honour the JSON schema strictly.
176    /// Maps to OpenAI's `strict` field on the function descriptor.
177    const STRICT: bool = false;
178
179    /// Performs the tool's actual work on already-deserialized `arguments`
180    /// and returns the textual result that will be sent back to the model.
181    fn invoke(
182        &self,
183        arguments: Self::ARGUMENTS,
184    ) -> impl Future<Output = Result<String, LLMYError>> + Send;
185}
186
187impl<T: Tool + DynClone + 'static> ToolDyn for T {
188    fn name(&self) -> String {
189        Self::NAME.to_string()
190    }
191    fn description(&self) -> Option<String> {
192        Self::DESCRIPTION.map(|v| v.to_string())
193    }
194    fn schema(&self) -> schemars::Schema {
195        schema_for!(T::ARGUMENTS)
196    }
197    fn strict(&self) -> bool {
198        T::STRICT
199    }
200
201    fn run(
202        &self,
203        arguments: serde_json::Value,
204    ) -> Pin<Box<dyn Future<Output = Result<String, LLMYError>> + Send + '_>> {
205        Box::pin(async move {
206            match serde_json::from_value::<T::ARGUMENTS>(arguments.clone()) {
207                Ok(args) => self.invoke(args).await,
208                Err(_) => Err(LLMYError::IncorrectToolCall(
209                    T::NAME.to_string(),
210                    arguments.to_string(),
211                    schema_for!(T::ARGUMENTS),
212                )),
213            }
214        })
215    }
216}
217
218/// A name-keyed registry of tools available to an agent.
219///
220/// `ToolBox` owns its tools behind `Arc<Box<dyn ToolDyn>>`, so cloning the
221/// box is cheap and the same set of tools can be shared across concurrent
222/// invocations. Tools are stored in a [`BTreeMap`], so iteration order is
223/// stable and sorted by name.
224#[derive(Default, Clone, Debug)]
225pub struct ToolBox {
226    tools: BTreeMap<String, Arc<Box<dyn ToolDyn>>>,
227}
228
229impl ToolBox {
230    /// Creates an empty `ToolBox`.
231    pub fn new() -> Self {
232        Self::default()
233    }
234
235    /// Returns the number of registered tools.
236    pub fn len(&self) -> usize {
237        self.tools.len()
238    }
239
240    /// Renders the registered tool names, optionally with their descriptions.
241    ///
242    /// When `details` is `true` each entry is formatted as
243    /// `` `name`: "description" ``; otherwise only the bare name is returned.
244    /// Useful when surfacing the tool list inside a system prompt.
245    pub fn render_tools(&self, details: bool) -> Vec<String> {
246        self.tools
247            .iter()
248            .map(|(name, tool)| {
249                if details {
250                    format!(
251                        "`{}`: {:?}", // description may contain new lines
252                        name,
253                        tool.description()
254                            .unwrap_or_else(|| "no description is provided".to_string())
255                    )
256                } else {
257                    name.clone()
258                }
259            })
260            .collect()
261    }
262
263    /// Merges another `ToolBox` into `self`. Tools in `rhs` overwrite any
264    /// existing entries that share a name.
265    pub fn extend(&mut self, rhs: Self) {
266        self.tools.extend(rhs.tools.into_iter());
267    }
268
269    /// Returns whether a tool with the given name is registered.
270    pub fn has_tool(&self, tool: &String) -> bool {
271        self.tools.contains_key(tool)
272    }
273
274    /// Renders every registered tool as an MCP [`rmcp::model::Tool`]
275    /// descriptor.
276    pub fn mcp_tools(&self) -> Vec<rmcp::model::Tool> {
277        self.tools.values().map(|t| t.to_mcp_tool()).collect()
278    }
279
280    /// Renders every registered tool as an OpenAI `ChatCompletionTools`
281    /// entry, ready to be attached to a chat completion request.
282    pub fn openai_objects(&self) -> Vec<ChatCompletionTools> {
283        self.tools
284            .iter()
285            .map(|t| WithOtherFields::new(ChatCompletionToolsRaw::Function(t.1.to_openai_obejct())))
286            .collect()
287    }
288
289    /// Registers a typed [`Tool`]. Equivalent to boxing it and calling
290    /// [`Self::add_dyn_tool`].
291    pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
292        self.add_dyn_tool(Box::new(tool) as _);
293    }
294
295    /// Registers an already-erased [`ToolDyn`]. The tool's
296    /// [`ToolDyn::name`] is used as the registry key, so adding a tool whose
297    /// name collides with an existing one will replace the previous entry.
298    pub fn add_dyn_tool(&mut self, tool: Box<dyn ToolDyn>) {
299        self.tools.insert(tool.name(), Arc::new(tool));
300    }
301
302    /// Invokes a single tool by name with the given JSON-encoded arguments.
303    ///
304    /// Returns `None` if no tool with that name is registered. Otherwise
305    /// returns `Some` with the tool's result (or an [`LLMYError`] from
306    /// argument parsing or the tool itself).
307    pub async fn invoke(
308        &self,
309        tool_name: String,
310        arguments: String,
311    ) -> Option<Result<String, LLMYError>> {
312        if let Some(tool) = self.tools.get(&tool_name) {
313            debug!("Invoking tool {} with arguments {}", &tool_name, &arguments);
314            Some(tool.call(arguments).await)
315        } else {
316            None
317        }
318    }
319
320    pub async fn invoke_value(
321        &self,
322        tool_name: String,
323        arguments: serde_json::Value,
324    ) -> Option<Result<String, LLMYError>> {
325        if let Some(tool) = self.tools.get(&tool_name) {
326            debug!("Invoking tool {} with arguments {}", &tool_name, &arguments);
327            Some(tool.run(arguments).await)
328        } else {
329            None
330        }
331    }
332
333    /// Concurrently invokes every call in `calls`, spawning each one onto a
334    /// [`tokio::task::JoinSet`].
335    ///
336    /// Each result is paired with the original [`GeneralToolCall`] so the
337    /// caller can correlate it back to a specific invocation. Use
338    /// [`Self::invoke_many_sequential`] when ordering matters or when tools
339    /// must not run in parallel.
340    pub async fn invoke_many(
341        &self,
342        calls: Vec<GeneralToolCall>,
343    ) -> Vec<(GeneralToolCall, Option<Result<String, LLMYError>>)> {
344        let mut js = JoinSet::new();
345        for call in calls {
346            let tb = self.clone();
347            js.spawn(async move {
348                let tc: GeneralToolCall = call.clone();
349                tracing::info!("Calling {}", &tc);
350                (tc, tb.invoke(call.tool_name, call.tool_args).await)
351            });
352        }
353
354        js.join_all().await
355    }
356
357    /// Sequentially invokes every call in `calls`, awaiting each one before
358    /// starting the next. Preserves input order and avoids any concurrency
359    /// between tools — pick this over [`Self::invoke_many`] when tools share
360    /// non-`Sync` state or must observe one another's side effects.
361    pub async fn invoke_many_sequential(
362        &self,
363        calls: Vec<GeneralToolCall>,
364    ) -> Vec<(GeneralToolCall, Option<Result<String, LLMYError>>)> {
365        let mut out = Vec::with_capacity(calls.len());
366
367        for call in calls {
368            let tc: GeneralToolCall = call.clone();
369            tracing::info!("Calling {}", &tc);
370            out.push((tc, self.invoke(call.tool_name, call.tool_args).await));
371        }
372
373        out
374    }
375
376    /// Concurrent variant of [`Self::invoke_many`] that wraps each successful
377    /// result in a [`ChatCompletionRequestMessage`] (a tool message tagged
378    /// with the originating `tool_id`), ready to be appended to a
379    /// conversation history.
380    pub async fn agent_invoke_many(
381        &self,
382        calls: Vec<GeneralToolCall>,
383    ) -> Vec<(
384        GeneralToolCall,
385        Option<Result<ChatCompletionRequestMessageRaw, LLMYError>>,
386    )> {
387        let invokes = self.invoke_many(calls).await;
388        Self::agent_messages_from_invokes(invokes)
389    }
390
391    /// Sequential variant of [`Self::agent_invoke_many`].
392    pub async fn agent_invoke_many_sequential(
393        &self,
394        calls: Vec<GeneralToolCall>,
395    ) -> Vec<(
396        GeneralToolCall,
397        Option<Result<ChatCompletionRequestMessageRaw, LLMYError>>,
398    )> {
399        let invokes = self.invoke_many_sequential(calls).await;
400        Self::agent_messages_from_invokes(invokes)
401    }
402
403    fn agent_messages_from_invokes(
404        invokes: Vec<(GeneralToolCall, Option<Result<String, LLMYError>>)>,
405    ) -> Vec<(
406        GeneralToolCall,
407        Option<Result<ChatCompletionRequestMessageRaw, LLMYError>>,
408    )> {
409        let mut out = vec![];
410        for (call, result) in invokes {
411            let id = call.tool_id.clone();
412            let result = result.map(|v| {
413                v.map(|s| {
414                    let tool_msg = ChatCompletionRequestToolMessageRaw {
415                        content: ChatCompletionRequestToolMessageContent::Text(s),
416                        tool_call_id: id,
417                    };
418                    ChatCompletionRequestMessageRaw::Tool(WithOtherFields::new(tool_msg))
419                })
420            });
421
422            out.push((call, result));
423        }
424
425        out
426    }
427}