Skip to main content

rig/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15
16use futures::Future;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20    completion::{self, ToolDefinition},
21    embeddings::{embed::EmbedError, tool::ToolSchema},
22    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
23};
24
25#[derive(Debug, thiserror::Error)]
26pub enum ToolError {
27    #[cfg(not(target_family = "wasm"))]
28    /// Error returned by the tool
29    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
30
31    #[cfg(target_family = "wasm")]
32    /// Error returned by the tool
33    ToolCallError(#[from] Box<dyn std::error::Error>),
34    /// Error caused by a de/serialization fail
35    JsonError(#[from] serde_json::Error),
36}
37
38impl fmt::Display for ToolError {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            ToolError::ToolCallError(e) => {
42                let error_str = e.to_string();
43                // This is required due to being able to use agents as tools
44                // which means it is possible to get recursive tool call errors
45                if error_str.starts_with("ToolCallError: ") {
46                    write!(f, "{}", error_str)
47                } else {
48                    write!(f, "ToolCallError: {}", error_str)
49                }
50            }
51            ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
52        }
53    }
54}
55
56/// Trait that represents a simple LLM tool
57///
58/// # Example
59/// ```
60/// use rig::{
61///     completion::ToolDefinition,
62///     tool::{ToolSet, Tool},
63/// };
64///
65/// #[derive(serde::Deserialize)]
66/// struct AddArgs {
67///     x: i32,
68///     y: i32,
69/// }
70///
71/// #[derive(Debug, thiserror::Error)]
72/// #[error("Math error")]
73/// struct MathError;
74///
75/// #[derive(serde::Deserialize, serde::Serialize)]
76/// struct Adder;
77///
78/// impl Tool for Adder {
79///     const NAME: &'static str = "add";
80///
81///     type Error = MathError;
82///     type Args = AddArgs;
83///     type Output = i32;
84///
85///     async fn definition(&self, _prompt: String) -> ToolDefinition {
86///         ToolDefinition {
87///             name: "add".to_string(),
88///             description: "Add x and y together".to_string(),
89///             parameters: serde_json::json!({
90///                 "type": "object",
91///                 "properties": {
92///                     "x": {
93///                         "type": "number",
94///                         "description": "The first number to add"
95///                     },
96///                     "y": {
97///                         "type": "number",
98///                         "description": "The second number to add"
99///                     }
100///                 }
101///             })
102///         }
103///     }
104///
105///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
106///         let result = args.x + args.y;
107///         Ok(result)
108///     }
109/// }
110/// ```
111pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
112    /// The name of the tool. This name should be unique.
113    const NAME: &'static str;
114
115    /// The error type of the tool.
116    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
117    /// The arguments type of the tool.
118    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
119    /// The output type of the tool.
120    type Output: Serialize;
121
122    /// A method returning the name of the tool.
123    fn name(&self) -> String {
124        Self::NAME.to_string()
125    }
126
127    /// A method returning the tool definition. The user prompt can be used to
128    /// tailor the definition to the specific use case.
129    fn definition(
130        &self,
131        _prompt: String,
132    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
133
134    /// The tool execution method.
135    /// Both the arguments and return value are a String since these values are meant to
136    /// be the output and input of LLM models (respectively)
137    fn call(
138        &self,
139        args: Self::Args,
140    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
141}
142
143/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
144pub trait ToolEmbedding: Tool {
145    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
146
147    /// Type of the tool' context. This context will be saved and loaded from the
148    /// vector store when ragging the tool.
149    /// This context can be used to store the tool's static configuration and local
150    /// context.
151    type Context: for<'a> Deserialize<'a> + Serialize;
152
153    /// Type of the tool's state. This state will be passed to the tool when initializing it.
154    /// This state can be used to pass runtime arguments to the tool such as clients,
155    /// API keys and other configuration.
156    type State: WasmCompatSend;
157
158    /// A method returning the documents that will be used as embeddings for the tool.
159    /// This allows for a tool to be retrieved from multiple embedding "directions".
160    /// If the tool will not be RAGged, this method should return an empty vector.
161    fn embedding_docs(&self) -> Vec<String>;
162
163    /// A method returning the context of the tool.
164    fn context(&self) -> Self::Context;
165
166    /// A method to initialize the tool from the context, and a state.
167    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
168}
169
170/// Wrapper trait to allow for dynamic dispatch of simple tools
171pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
172    fn name(&self) -> String;
173
174    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
175
176    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
177}
178
179impl<T: Tool> ToolDyn for T {
180    fn name(&self) -> String {
181        self.name()
182    }
183
184    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
185        Box::pin(<Self as Tool>::definition(self, prompt))
186    }
187
188    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
189        Box::pin(async move {
190            match serde_json::from_str(&args) {
191                Ok(args) => <Self as Tool>::call(self, args)
192                    .await
193                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
194                    .and_then(|output| {
195                        serde_json::to_string(&output).map_err(ToolError::JsonError)
196                    }),
197                Err(e) => Err(ToolError::JsonError(e)),
198            }
199        })
200    }
201}
202
203#[cfg(feature = "rmcp")]
204#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
205pub mod rmcp;
206
207/// Wrapper trait to allow for dynamic dispatch of raggable tools
208pub trait ToolEmbeddingDyn: ToolDyn {
209    fn context(&self) -> serde_json::Result<serde_json::Value>;
210
211    fn embedding_docs(&self) -> Vec<String>;
212}
213
214impl<T> ToolEmbeddingDyn for T
215where
216    T: ToolEmbedding + 'static,
217{
218    fn context(&self) -> serde_json::Result<serde_json::Value> {
219        serde_json::to_value(self.context())
220    }
221
222    fn embedding_docs(&self) -> Vec<String> {
223        self.embedding_docs()
224    }
225}
226
227pub(crate) enum ToolType {
228    Simple(Box<dyn ToolDyn>),
229    Embedding(Box<dyn ToolEmbeddingDyn>),
230}
231
232impl ToolType {
233    pub fn name(&self) -> String {
234        match self {
235            ToolType::Simple(tool) => tool.name(),
236            ToolType::Embedding(tool) => tool.name(),
237        }
238    }
239
240    pub async fn definition(&self, prompt: String) -> ToolDefinition {
241        match self {
242            ToolType::Simple(tool) => tool.definition(prompt).await,
243            ToolType::Embedding(tool) => tool.definition(prompt).await,
244        }
245    }
246
247    pub async fn call(&self, args: String) -> Result<String, ToolError> {
248        match self {
249            ToolType::Simple(tool) => tool.call(args).await,
250            ToolType::Embedding(tool) => tool.call(args).await,
251        }
252    }
253}
254
255#[derive(Debug, thiserror::Error)]
256pub enum ToolSetError {
257    /// Error returned by the tool
258    #[error("ToolCallError: {0}")]
259    ToolCallError(#[from] ToolError),
260
261    /// Could not find a tool
262    #[error("ToolNotFoundError: {0}")]
263    ToolNotFoundError(String),
264
265    // TODO: Revisit this
266    #[error("JsonError: {0}")]
267    JsonError(#[from] serde_json::Error),
268
269    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
270    #[error("Tool call interrupted")]
271    Interrupted,
272}
273
274/// A struct that holds a set of tools
275#[derive(Default)]
276pub struct ToolSet {
277    pub(crate) tools: HashMap<String, ToolType>,
278}
279
280impl ToolSet {
281    /// Create a new ToolSet from a list of tools
282    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
283        let mut toolset = Self::default();
284        tools.into_iter().for_each(|tool| {
285            toolset.add_tool(tool);
286        });
287        toolset
288    }
289
290    pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
291        let mut toolset = Self::default();
292        tools.into_iter().for_each(|tool| {
293            toolset.add_tool_boxed(tool);
294        });
295        toolset
296    }
297
298    /// Create a toolset builder
299    pub fn builder() -> ToolSetBuilder {
300        ToolSetBuilder::default()
301    }
302
303    /// Check if the toolset contains a tool with the given name
304    pub fn contains(&self, toolname: &str) -> bool {
305        self.tools.contains_key(toolname)
306    }
307
308    /// Add a tool to the toolset
309    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
310        self.tools
311            .insert(tool.name(), ToolType::Simple(Box::new(tool)));
312    }
313
314    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
315    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
316        self.tools.insert(tool.name(), ToolType::Simple(tool));
317    }
318
319    pub fn delete_tool(&mut self, tool_name: &str) {
320        let _ = self.tools.remove(tool_name);
321    }
322
323    /// Merge another toolset into this one
324    pub fn add_tools(&mut self, toolset: ToolSet) {
325        self.tools.extend(toolset.tools);
326    }
327
328    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
329        self.tools.get(toolname)
330    }
331
332    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
333        let mut defs = Vec::new();
334        for tool in self.tools.values() {
335            let def = tool.definition(String::new()).await;
336            defs.push(def);
337        }
338        Ok(defs)
339    }
340
341    /// Call a tool with the given name and arguments
342    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
343        if let Some(tool) = self.tools.get(toolname) {
344            tracing::debug!(target: "rig",
345                "Calling tool {toolname} with args:\n{}",
346                serde_json::to_string_pretty(&args).unwrap()
347            );
348            Ok(tool.call(args).await?)
349        } else {
350            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
351        }
352    }
353
354    /// Get the documents of all the tools in the toolset
355    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
356        let mut docs = Vec::new();
357        for tool in self.tools.values() {
358            match tool {
359                ToolType::Simple(tool) => {
360                    docs.push(completion::Document {
361                        id: tool.name(),
362                        text: format!(
363                            "\
364                            Tool: {}\n\
365                            Definition: \n\
366                            {}\
367                        ",
368                            tool.name(),
369                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
370                        ),
371                        additional_props: HashMap::new(),
372                    });
373                }
374                ToolType::Embedding(tool) => {
375                    docs.push(completion::Document {
376                        id: tool.name(),
377                        text: format!(
378                            "\
379                            Tool: {}\n\
380                            Definition: \n\
381                            {}\
382                        ",
383                            tool.name(),
384                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
385                        ),
386                        additional_props: HashMap::new(),
387                    });
388                }
389            }
390        }
391        Ok(docs)
392    }
393
394    /// Convert tools in self to objects of type ToolSchema.
395    /// This is necessary because when adding tools to the EmbeddingBuilder because all
396    /// documents added to the builder must all be of the same type.
397    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
398        self.tools
399            .values()
400            .filter_map(|tool_type| {
401                if let ToolType::Embedding(tool) = tool_type {
402                    Some(ToolSchema::try_from(&**tool))
403                } else {
404                    None
405                }
406            })
407            .collect::<Result<Vec<_>, _>>()
408    }
409}
410
411#[derive(Default)]
412pub struct ToolSetBuilder {
413    tools: Vec<ToolType>,
414}
415
416impl ToolSetBuilder {
417    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
418        self.tools.push(ToolType::Simple(Box::new(tool)));
419        self
420    }
421
422    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
423        self.tools.push(ToolType::Embedding(Box::new(tool)));
424        self
425    }
426
427    pub fn build(self) -> ToolSet {
428        ToolSet {
429            tools: self
430                .tools
431                .into_iter()
432                .map(|tool| (tool.name(), tool))
433                .collect(),
434        }
435    }
436}
437
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Build a toolset containing an `add` and a `subtract` tool over `i32` operands.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_call_tool() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(&err, ToolSetError::ToolNotFoundError(name) if name == "multiply"));
    }

    #[tokio::test]
    async fn test_call_with_invalid_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn test_documents() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);
        for doc in docs {
            let expected_prefix = format!("Tool: {}\nDefinition: \n", doc.id);
            assert!(doc.text.starts_with(&expected_prefix));
        }
    }
}