Skip to main content

rig_core/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21    completion::{self, ToolDefinition},
22    embeddings::{embed::EmbedError, tool::ToolSchema},
23    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
26#[derive(Debug, thiserror::Error)]
27pub enum ToolError {
28    #[cfg(not(target_family = "wasm"))]
29    /// Error returned by the tool
30    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
31
32    #[cfg(target_family = "wasm")]
33    /// Error returned by the tool
34    ToolCallError(#[from] Box<dyn std::error::Error>),
35    /// Error caused by a de/serialization fail
36    JsonError(#[from] serde_json::Error),
37}
38
39impl fmt::Display for ToolError {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            ToolError::ToolCallError(e) => {
43                let error_str = e.to_string();
44                // This is required due to being able to use agents as tools
45                // which means it is possible to get recursive tool call errors
46                if error_str.starts_with("ToolCallError: ") {
47                    write!(f, "{}", error_str)
48                } else {
49                    write!(f, "ToolCallError: {}", error_str)
50                }
51            }
52            ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53        }
54    }
55}
56
57/// Trait that represents a simple LLM tool
58///
59/// # Example
60/// ```
61/// use rig_core::{
62///     completion::ToolDefinition,
63///     tool::{ToolSet, Tool},
64/// };
65///
66/// #[derive(serde::Deserialize)]
67/// struct AddArgs {
68///     x: i32,
69///     y: i32,
70/// }
71///
72/// #[derive(Debug, thiserror::Error)]
73/// #[error("Math error")]
74/// struct MathError;
75///
76/// #[derive(serde::Deserialize, serde::Serialize)]
77/// struct Adder;
78///
79/// impl Tool for Adder {
80///     const NAME: &'static str = "add";
81///
82///     type Error = MathError;
83///     type Args = AddArgs;
84///     type Output = i32;
85///
86///     async fn definition(&self, _prompt: String) -> ToolDefinition {
87///         ToolDefinition {
88///             name: "add".to_string(),
89///             description: "Add x and y together".to_string(),
90///             parameters: serde_json::json!({
91///                 "type": "object",
92///                 "properties": {
93///                     "x": {
94///                         "type": "number",
95///                         "description": "The first number to add"
96///                     },
97///                     "y": {
98///                         "type": "number",
99///                         "description": "The second number to add"
100///                     }
101///                 }
102///             })
103///         }
104///     }
105///
106///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
107///         let result = args.x + args.y;
108///         Ok(result)
109///     }
110/// }
111/// ```
112pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
113    /// The name of the tool. This name should be unique within a single
114    /// [`ToolSet`] or other registration scope that dispatches tools by name.
115    const NAME: &'static str;
116
117    /// The error type of the tool.
118    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
119    /// The arguments type of the tool.
120    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
121    /// The output type of the tool.
122    type Output: Serialize;
123
124    /// A method returning the name of the tool.
125    fn name(&self) -> String {
126        Self::NAME.to_string()
127    }
128
129    /// A method returning the tool definition. The user prompt can be used to
130    /// tailor the definition to the specific use case.
131    fn definition(
132        &self,
133        _prompt: String,
134    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
135
136    /// The tool execution method.
137    /// Both the arguments and return value are a String since these values are meant to
138    /// be the output and input of LLM models (respectively)
139    fn call(
140        &self,
141        args: Self::Args,
142    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
143}
144
145/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
146pub trait ToolEmbedding: Tool {
147    /// Error returned when reconstructing a dynamic tool from stored context.
148    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
149
150    /// Type of the tool' context. This context will be saved and loaded from the
151    /// vector store when ragging the tool.
152    /// This context can be used to store the tool's static configuration and local
153    /// context.
154    type Context: for<'a> Deserialize<'a> + Serialize;
155
156    /// Type of the tool's state. This state will be passed to the tool when initializing it.
157    /// This state can be used to pass runtime arguments to the tool such as clients,
158    /// API keys and other configuration.
159    type State: WasmCompatSend;
160
161    /// A method returning the documents that will be used as embeddings for the tool.
162    /// This allows for a tool to be retrieved from multiple embedding "directions".
163    /// If the tool will not be RAGged, this method should return an empty vector.
164    fn embedding_docs(&self) -> Vec<String>;
165
166    /// A method returning the context of the tool.
167    fn context(&self) -> Self::Context;
168
169    /// A method to initialize the tool from the context, and a state.
170    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
171}
172
173/// Wrapper trait to allow for dynamic dispatch of simple tools
174pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
175    /// Returns the tool name used for dispatch.
176    fn name(&self) -> String;
177
178    /// Returns the provider-facing tool schema.
179    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
180
181    /// Calls the tool with JSON-encoded arguments and returns model-facing text.
182    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
183}
184
185fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
186    match serde_json::to_value(output)? {
187        serde_json::Value::String(text) => Ok(text),
188        value => Ok(value.to_string()),
189    }
190}
191
192impl<T: Tool> ToolDyn for T {
193    fn name(&self) -> String {
194        self.name()
195    }
196
197    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
198        Box::pin(<Self as Tool>::definition(self, prompt))
199    }
200
201    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
202        Box::pin(async move {
203            match serde_json::from_str(&args) {
204                Ok(args) => <Self as Tool>::call(self, args)
205                    .await
206                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
207                    .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
208                Err(e) => Err(ToolError::JsonError(e)),
209            }
210        })
211    }
212}
213
214#[cfg(feature = "rmcp")]
215#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
216pub mod rmcp;
217
218/// Wrapper trait to allow for dynamic dispatch of raggable tools
219pub trait ToolEmbeddingDyn: ToolDyn {
220    /// Serializes context needed to reconstruct this dynamic tool.
221    fn context(&self) -> serde_json::Result<serde_json::Value>;
222
223    /// Returns text fragments used to retrieve this tool from a vector store.
224    fn embedding_docs(&self) -> Vec<String>;
225}
226
227impl<T> ToolEmbeddingDyn for T
228where
229    T: ToolEmbedding + 'static,
230{
231    fn context(&self) -> serde_json::Result<serde_json::Value> {
232        serde_json::to_value(self.context())
233    }
234
235    fn embedding_docs(&self) -> Vec<String> {
236        self.embedding_docs()
237    }
238}
239
240#[derive(Clone)]
241pub(crate) enum ToolType {
242    Simple(Arc<dyn ToolDyn>),
243    Embedding(Arc<dyn ToolEmbeddingDyn>),
244}
245
246impl ToolType {
247    pub fn name(&self) -> String {
248        match self {
249            ToolType::Simple(tool) => tool.name(),
250            ToolType::Embedding(tool) => tool.name(),
251        }
252    }
253
254    pub async fn definition(&self, prompt: String) -> ToolDefinition {
255        match self {
256            ToolType::Simple(tool) => tool.definition(prompt).await,
257            ToolType::Embedding(tool) => tool.definition(prompt).await,
258        }
259    }
260
261    pub async fn call(&self, args: String) -> Result<String, ToolError> {
262        match self {
263            ToolType::Simple(tool) => tool.call(args).await,
264            ToolType::Embedding(tool) => tool.call(args).await,
265        }
266    }
267}
268
269#[derive(Debug, thiserror::Error)]
270pub enum ToolSetError {
271    /// Error returned by the tool
272    #[error("ToolCallError: {0}")]
273    ToolCallError(#[from] ToolError),
274
275    /// Could not find a tool
276    #[error("ToolNotFoundError: {0}")]
277    ToolNotFoundError(String),
278
279    /// JSON serialization or deserialization failed while preparing tool data.
280    #[error("JsonError: {0}")]
281    JsonError(#[from] serde_json::Error),
282
283    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
284    #[error("Tool call interrupted")]
285    Interrupted,
286}
287
288/// A struct that holds a set of tools
289#[derive(Default)]
290pub struct ToolSet {
291    pub(crate) tools: HashMap<String, ToolType>,
292}
293
294impl ToolSet {
295    /// Create a new ToolSet from a list of tools
296    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
297        let mut toolset = Self::default();
298        tools.into_iter().for_each(|tool| {
299            toolset.add_tool(tool);
300        });
301        toolset
302    }
303
304    /// Create a new `ToolSet` from boxed dynamically-dispatched tools.
305    pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
306        let mut toolset = Self::default();
307        tools.into_iter().for_each(|tool| {
308            toolset.add_tool_boxed(tool);
309        });
310        toolset
311    }
312
313    /// Create a toolset builder
314    pub fn builder() -> ToolSetBuilder {
315        ToolSetBuilder::default()
316    }
317
318    /// Check if the toolset contains a tool with the given name
319    pub fn contains(&self, toolname: &str) -> bool {
320        self.tools.contains_key(toolname)
321    }
322
323    /// Add a tool to the toolset
324    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
325        self.tools
326            .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
327    }
328
329    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
330    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
331        self.tools
332            .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
333    }
334
335    /// Remove a tool by name. Missing tools are ignored.
336    pub fn delete_tool(&mut self, tool_name: &str) {
337        let _ = self.tools.remove(tool_name);
338    }
339
340    /// Merge another toolset into this one
341    pub fn add_tools(&mut self, toolset: ToolSet) {
342        self.tools.extend(toolset.tools);
343    }
344
345    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
346        self.tools.get(toolname)
347    }
348
349    /// Return definitions for all tools currently registered in the set.
350    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
351        let mut defs = Vec::new();
352        for tool in self.tools.values() {
353            let def = tool.definition(String::new()).await;
354            defs.push(def);
355        }
356        Ok(defs)
357    }
358
359    /// Call a tool with the given name and arguments
360    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
361        if let Some(tool) = self.tools.get(toolname) {
362            tracing::debug!(target: "rig",
363                "Calling tool {toolname} with args:\n{}",
364                args
365            );
366            Ok(tool.call(args).await?)
367        } else {
368            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
369        }
370    }
371
372    /// Get the documents of all the tools in the toolset
373    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
374        let mut docs = Vec::new();
375        for tool in self.tools.values() {
376            match tool {
377                ToolType::Simple(tool) => {
378                    docs.push(completion::Document {
379                        id: tool.name(),
380                        text: format!(
381                            "\
382                            Tool: {}\n\
383                            Definition: \n\
384                            {}\
385                        ",
386                            tool.name(),
387                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
388                        ),
389                        additional_props: HashMap::new(),
390                    });
391                }
392                ToolType::Embedding(tool) => {
393                    docs.push(completion::Document {
394                        id: tool.name(),
395                        text: format!(
396                            "\
397                            Tool: {}\n\
398                            Definition: \n\
399                            {}\
400                        ",
401                            tool.name(),
402                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
403                        ),
404                        additional_props: HashMap::new(),
405                    });
406                }
407            }
408        }
409        Ok(docs)
410    }
411
412    /// Convert tools in self to objects of type ToolSchema.
413    /// This is necessary because when adding tools to the EmbeddingBuilder because all
414    /// documents added to the builder must all be of the same type.
415    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
416        self.tools
417            .values()
418            .filter_map(|tool_type| {
419                if let ToolType::Embedding(tool) = tool_type {
420                    Some(ToolSchema::try_from(&**tool))
421                } else {
422                    None
423                }
424            })
425            .collect::<Result<Vec<_>, _>>()
426    }
427}
428
429#[derive(Default)]
430/// Builder for constructing a [`ToolSet`] with static and dynamic tools.
431pub struct ToolSetBuilder {
432    tools: Vec<ToolType>,
433}
434
435impl ToolSetBuilder {
436    /// Add a regular tool that is always available when the set is used.
437    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
438        self.tools.push(ToolType::Simple(Arc::new(tool)));
439        self
440    }
441
442    /// Add a tool that can be represented as embeddings for dynamic retrieval.
443    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
444        self.tools.push(ToolType::Embedding(Arc::new(tool)));
445        self
446    }
447
448    /// Build the tool set, keyed by each tool's name.
449    pub fn build(self) -> ToolSet {
450        ToolSet {
451            tools: self
452                .tools
453                .into_iter()
454                .map(|tool| (tool.name(), tool))
455                .collect(),
456        }
457    }
458}
459
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use crate::test_utils::{
        MockImageOutputTool, MockObjectOutputTool, MockStringOutputTool, mock_math_toolset,
    };
    use serde_json::json;

    use super::*;

    /// Shared fixture: the mock math toolset (two tools, one of them named `add`).
    fn get_test_toolset() -> ToolSet {
        mock_math_toolset()
    }

    /// Registers `tool` in an otherwise empty set and calls it by `name` with `{}`
    /// as arguments, panicking if the call fails.
    async fn call_only_tool(tool: impl ToolDyn + 'static, name: &str) -> String {
        let mut toolset = ToolSet::default();
        toolset.add_tool(tool);
        toolset
            .call(name, "{}".to_string())
            .await
            .expect("tool should succeed")
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let definitions = get_test_toolset().get_tool_definitions().await.unwrap();
        assert_eq!(definitions.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);

        toolset.delete_tool("add");

        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        // A string output must come back without JSON quoting or escaping.
        let output = call_only_tool(MockStringOutputTool, "string_output").await;
        assert_eq!(output, "Hello\nWorld");
    }

    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let output = call_only_tool(MockImageOutputTool, "image_output").await;
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let output = call_only_tool(MockObjectOutputTool, "object_output").await;

        assert!(output.starts_with('{'));
        let parsed = serde_json::from_str::<serde_json::Value>(&output).unwrap();
        assert_eq!(parsed, json!({ "status": "ok", "count": 42 }));
    }
}