Skip to main content

rig/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21    completion::{self, ToolDefinition},
22    embeddings::{embed::EmbedError, tool::ToolSchema},
23    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
/// Error returned when calling a tool through [`ToolDyn::call`].
///
/// `Display` is implemented by hand for this type so that recursively nested tool-call errors
/// (possible because agents can be used as tools) are not prefixed with `ToolCallError: `
/// over and over again.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    #[cfg(not(target_family = "wasm"))]
    /// Error returned by the tool
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    #[cfg(target_family = "wasm")]
    /// Error returned by the tool (on wasm targets the boxed error is not required to be
    /// `Send + Sync`)
    ToolCallError(#[from] Box<dyn std::error::Error>),

    /// Error caused by a de/serialization failure (in the blanket [`ToolDyn`] implementation:
    /// parsing the call arguments or serializing the tool output)
    JsonError(#[from] serde_json::Error),
}
38
39impl fmt::Display for ToolError {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            ToolError::ToolCallError(e) => {
43                let error_str = e.to_string();
44                // This is required due to being able to use agents as tools
45                // which means it is possible to get recursive tool call errors
46                if error_str.starts_with("ToolCallError: ") {
47                    write!(f, "{}", error_str)
48                } else {
49                    write!(f, "ToolCallError: {}", error_str)
50                }
51            }
52            ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53        }
54    }
55}
56
57/// Trait that represents a simple LLM tool
58///
59/// # Example
60/// ```
61/// use rig::{
62///     completion::ToolDefinition,
63///     tool::{ToolSet, Tool},
64/// };
65///
66/// #[derive(serde::Deserialize)]
67/// struct AddArgs {
68///     x: i32,
69///     y: i32,
70/// }
71///
72/// #[derive(Debug, thiserror::Error)]
73/// #[error("Math error")]
74/// struct MathError;
75///
76/// #[derive(serde::Deserialize, serde::Serialize)]
77/// struct Adder;
78///
79/// impl Tool for Adder {
80///     const NAME: &'static str = "add";
81///
82///     type Error = MathError;
83///     type Args = AddArgs;
84///     type Output = i32;
85///
86///     async fn definition(&self, _prompt: String) -> ToolDefinition {
87///         ToolDefinition {
88///             name: "add".to_string(),
89///             description: "Add x and y together".to_string(),
90///             parameters: serde_json::json!({
91///                 "type": "object",
92///                 "properties": {
93///                     "x": {
94///                         "type": "number",
95///                         "description": "The first number to add"
96///                     },
97///                     "y": {
98///                         "type": "number",
99///                         "description": "The second number to add"
100///                     }
101///                 }
102///             })
103///         }
104///     }
105///
106///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
107///         let result = args.x + args.y;
108///         Ok(result)
109///     }
110/// }
111/// ```
112pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
113    /// The name of the tool. This name should be unique within a single
114    /// [`ToolSet`] or other registration scope that dispatches tools by name.
115    const NAME: &'static str;
116
117    /// The error type of the tool.
118    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
119    /// The arguments type of the tool.
120    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
121    /// The output type of the tool.
122    type Output: Serialize;
123
124    /// A method returning the name of the tool.
125    fn name(&self) -> String {
126        Self::NAME.to_string()
127    }
128
129    /// A method returning the tool definition. The user prompt can be used to
130    /// tailor the definition to the specific use case.
131    fn definition(
132        &self,
133        _prompt: String,
134    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
135
136    /// The tool execution method.
137    /// Both the arguments and return value are a String since these values are meant to
138    /// be the output and input of LLM models (respectively)
139    fn call(
140        &self,
141        args: Self::Args,
142    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
143}
144
145/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
146pub trait ToolEmbedding: Tool {
147    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
148
149    /// Type of the tool' context. This context will be saved and loaded from the
150    /// vector store when ragging the tool.
151    /// This context can be used to store the tool's static configuration and local
152    /// context.
153    type Context: for<'a> Deserialize<'a> + Serialize;
154
155    /// Type of the tool's state. This state will be passed to the tool when initializing it.
156    /// This state can be used to pass runtime arguments to the tool such as clients,
157    /// API keys and other configuration.
158    type State: WasmCompatSend;
159
160    /// A method returning the documents that will be used as embeddings for the tool.
161    /// This allows for a tool to be retrieved from multiple embedding "directions".
162    /// If the tool will not be RAGged, this method should return an empty vector.
163    fn embedding_docs(&self) -> Vec<String>;
164
165    /// A method returning the context of the tool.
166    fn context(&self) -> Self::Context;
167
168    /// A method to initialize the tool from the context, and a state.
169    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
170}
171
/// Wrapper trait to allow for dynamic dispatch of simple tools.
///
/// Tools with different argument and output types can be stored together behind
/// `Arc<dyn ToolDyn>` and exchange plain strings with the caller. Every [`Tool`] implements
/// this trait through a blanket implementation.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Returns the name of the tool.
    fn name(&self) -> String;

    /// Returns the tool definition; `prompt` may be used to tailor it.
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Calls the tool with string-encoded `args` and returns its output as a string.
    /// (The blanket implementation for [`Tool`] expects `args` to be JSON.)
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
180
181fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
182    match serde_json::to_value(output)? {
183        serde_json::Value::String(text) => Ok(text),
184        value => Ok(value.to_string()),
185    }
186}
187
188impl<T: Tool> ToolDyn for T {
189    fn name(&self) -> String {
190        self.name()
191    }
192
193    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
194        Box::pin(<Self as Tool>::definition(self, prompt))
195    }
196
197    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
198        Box::pin(async move {
199            match serde_json::from_str(&args) {
200                Ok(args) => <Self as Tool>::call(self, args)
201                    .await
202                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
203                    .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
204                Err(e) => Err(ToolError::JsonError(e)),
205            }
206        })
207    }
208}
209
210#[cfg(feature = "rmcp")]
211#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
212pub mod rmcp;
213
/// Wrapper trait to allow for dynamic dispatch of raggable tools.
///
/// Every [`ToolEmbedding`] implements this trait through a blanket implementation.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized as a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the documents used to embed the tool.
    fn embedding_docs(&self) -> Vec<String>;
}
220
221impl<T> ToolEmbeddingDyn for T
222where
223    T: ToolEmbedding + 'static,
224{
225    fn context(&self) -> serde_json::Result<serde_json::Value> {
226        serde_json::to_value(self.context())
227    }
228
229    fn embedding_docs(&self) -> Vec<String> {
230        self.embedding_docs()
231    }
232}
233
/// A tool stored in a [`ToolSet`], tagged by whether it can be embedded.
///
/// Only [`ToolType::Embedding`] entries are converted by [`ToolSet::schemas`]. Cloning only
/// clones the inner `Arc`, not the tool itself.
#[derive(Clone)]
pub(crate) enum ToolType {
    /// A tool usable through [`ToolDyn`] only.
    Simple(Arc<dyn ToolDyn>),
    /// A tool that also implements [`ToolEmbeddingDyn`].
    Embedding(Arc<dyn ToolEmbeddingDyn>),
}
239
240impl ToolType {
241    pub fn name(&self) -> String {
242        match self {
243            ToolType::Simple(tool) => tool.name(),
244            ToolType::Embedding(tool) => tool.name(),
245        }
246    }
247
248    pub async fn definition(&self, prompt: String) -> ToolDefinition {
249        match self {
250            ToolType::Simple(tool) => tool.definition(prompt).await,
251            ToolType::Embedding(tool) => tool.definition(prompt).await,
252        }
253    }
254
255    pub async fn call(&self, args: String) -> Result<String, ToolError> {
256        match self {
257            ToolType::Simple(tool) => tool.call(args).await,
258            ToolType::Embedding(tool) => tool.call(args).await,
259        }
260    }
261}
262
263#[derive(Debug, thiserror::Error)]
264pub enum ToolSetError {
265    /// Error returned by the tool
266    #[error("ToolCallError: {0}")]
267    ToolCallError(#[from] ToolError),
268
269    /// Could not find a tool
270    #[error("ToolNotFoundError: {0}")]
271    ToolNotFoundError(String),
272
273    // TODO: Revisit this
274    #[error("JsonError: {0}")]
275    JsonError(#[from] serde_json::Error),
276
277    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
278    #[error("Tool call interrupted")]
279    Interrupted,
280}
281
/// A struct that holds a set of tools, keyed by tool name.
///
/// Registering a tool whose name is already present replaces the earlier tool.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools, keyed by [`ToolDyn::name`].
    pub(crate) tools: HashMap<String, ToolType>,
}
287
288impl ToolSet {
289    /// Create a new ToolSet from a list of tools
290    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
291        let mut toolset = Self::default();
292        tools.into_iter().for_each(|tool| {
293            toolset.add_tool(tool);
294        });
295        toolset
296    }
297
298    pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
299        let mut toolset = Self::default();
300        tools.into_iter().for_each(|tool| {
301            toolset.add_tool_boxed(tool);
302        });
303        toolset
304    }
305
306    /// Create a toolset builder
307    pub fn builder() -> ToolSetBuilder {
308        ToolSetBuilder::default()
309    }
310
311    /// Check if the toolset contains a tool with the given name
312    pub fn contains(&self, toolname: &str) -> bool {
313        self.tools.contains_key(toolname)
314    }
315
316    /// Add a tool to the toolset
317    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
318        self.tools
319            .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
320    }
321
322    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
323    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
324        self.tools
325            .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
326    }
327
328    pub fn delete_tool(&mut self, tool_name: &str) {
329        let _ = self.tools.remove(tool_name);
330    }
331
332    /// Merge another toolset into this one
333    pub fn add_tools(&mut self, toolset: ToolSet) {
334        self.tools.extend(toolset.tools);
335    }
336
337    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
338        self.tools.get(toolname)
339    }
340
341    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
342        let mut defs = Vec::new();
343        for tool in self.tools.values() {
344            let def = tool.definition(String::new()).await;
345            defs.push(def);
346        }
347        Ok(defs)
348    }
349
350    /// Call a tool with the given name and arguments
351    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
352        if let Some(tool) = self.tools.get(toolname) {
353            tracing::debug!(target: "rig",
354                "Calling tool {toolname} with args:\n{}",
355                args
356            );
357            Ok(tool.call(args).await?)
358        } else {
359            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
360        }
361    }
362
363    /// Get the documents of all the tools in the toolset
364    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
365        let mut docs = Vec::new();
366        for tool in self.tools.values() {
367            match tool {
368                ToolType::Simple(tool) => {
369                    docs.push(completion::Document {
370                        id: tool.name(),
371                        text: format!(
372                            "\
373                            Tool: {}\n\
374                            Definition: \n\
375                            {}\
376                        ",
377                            tool.name(),
378                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
379                        ),
380                        additional_props: HashMap::new(),
381                    });
382                }
383                ToolType::Embedding(tool) => {
384                    docs.push(completion::Document {
385                        id: tool.name(),
386                        text: format!(
387                            "\
388                            Tool: {}\n\
389                            Definition: \n\
390                            {}\
391                        ",
392                            tool.name(),
393                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
394                        ),
395                        additional_props: HashMap::new(),
396                    });
397                }
398            }
399        }
400        Ok(docs)
401    }
402
403    /// Convert tools in self to objects of type ToolSchema.
404    /// This is necessary because when adding tools to the EmbeddingBuilder because all
405    /// documents added to the builder must all be of the same type.
406    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
407        self.tools
408            .values()
409            .filter_map(|tool_type| {
410                if let ToolType::Embedding(tool) = tool_type {
411                    Some(ToolSchema::try_from(&**tool))
412                } else {
413                    None
414                }
415            })
416            .collect::<Result<Vec<_>, _>>()
417    }
418}
419
/// Builder for [`ToolSet`], usually created with [`ToolSet::builder`].
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools in registration order; keyed by name when the set is built.
    tools: Vec<ToolType>,
}
424
425impl ToolSetBuilder {
426    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
427        self.tools.push(ToolType::Simple(Arc::new(tool)));
428        self
429    }
430
431    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
432        self.tools.push(ToolType::Embedding(Arc::new(tool)));
433        self
434    }
435
436    pub fn build(self) -> ToolSet {
437        ToolSet {
438            tools: self
439                .tools
440                .into_iter()
441                .map(|tool| (tool.name(), tool))
442                .collect(),
443        }
444    }
445}
446
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use serde_json::json;

    use super::*;

    /// Builds a toolset with two typed tools, `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn typed_tools_return_serialized_results() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .expect("add should succeed");
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .expect("subtract should succeed");
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn calling_unknown_tool_returns_not_found() {
        let toolset = get_test_toolset();

        let err = toolset
            .call("multiply", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .expect_err("unknown tool should fail");

        assert!(matches!(err, ToolSetError::ToolNotFoundError(ref name) if name == "multiply"));
    }

    #[tokio::test]
    async fn malformed_arguments_are_reported_as_json_errors() {
        let toolset = get_test_toolset();

        let err = toolset
            .call("add", "not json".to_string())
            .await
            .expect_err("malformed arguments should fail");

        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[test]
    fn nested_tool_call_error_prefix_is_not_repeated() {
        let plain = ToolError::ToolCallError("boom".into());
        assert_eq!(plain.to_string(), "ToolCallError: boom");

        let nested = ToolError::ToolCallError("ToolCallError: boom".into());
        assert_eq!(nested.to_string(), "ToolCallError: boom");
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Test tool error")]
    struct TestToolError;

    #[derive(Deserialize, Serialize)]
    struct StringOutputTool;

    impl Tool for StringOutputTool {
        const NAME: &'static str = "string_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns a multiline string".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok("Hello\nWorld".to_string())
        }
    }

    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(StringOutputTool);

        let output = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(output, "Hello\nWorld");
    }

    #[derive(Deserialize, Serialize)]
    struct ImageOutputTool;

    impl Tool for ImageOutputTool {
        const NAME: &'static str = "image_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns image JSON".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "type": "image",
                "data": "base64data==",
                "mimeType": "image/png"
            })
            .to_string())
        }
    }

    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ImageOutputTool);

        let output = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    #[derive(Deserialize, Serialize)]
    struct ObjectOutputTool;

    impl Tool for ObjectOutputTool {
        const NAME: &'static str = "object_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = serde_json::Value;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns an object".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "status": "ok",
                "count": 42
            }))
        }
    }

    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ObjectOutputTool);

        let output = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(output.starts_with('{'));
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&output).unwrap(),
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }
}