Skip to main content

rig/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21    completion::{self, ToolDefinition},
22    embeddings::{embed::EmbedError, tool::ToolSchema},
23    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
/// Error produced when a tool is invoked through [`ToolDyn::call`].
///
/// No `#[error(...)]` attributes are used here: `Display` is implemented by hand so
/// that errors coming back from nested tool calls (e.g. an agent used as a tool) are
/// not prefixed with `ToolCallError: ` more than once.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    #[cfg(not(target_family = "wasm"))]
    /// Error returned by the tool
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    #[cfg(target_family = "wasm")]
    /// Error returned by the tool.
    ///
    /// On wasm targets the boxed error is not required to be `Send + Sync`.
    ToolCallError(#[from] Box<dyn std::error::Error>),
    /// Error caused by a de/serialization failure, e.g. tool-call arguments that do not
    /// deserialize into the tool's `Args`, or an output that cannot be serialized.
    JsonError(#[from] serde_json::Error),
}
38
39impl fmt::Display for ToolError {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            ToolError::ToolCallError(e) => {
43                let error_str = e.to_string();
44                // This is required due to being able to use agents as tools
45                // which means it is possible to get recursive tool call errors
46                if error_str.starts_with("ToolCallError: ") {
47                    write!(f, "{}", error_str)
48                } else {
49                    write!(f, "ToolCallError: {}", error_str)
50                }
51            }
52            ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53        }
54    }
55}
56
/// Trait that represents a simple LLM tool
///
/// Every type implementing [`Tool`] also implements [`ToolDyn`] through a blanket impl,
/// so it can be added to a [`ToolSet`].
///
/// # Example
/// ```
/// use rig::{
///     completion::ToolDefinition,
///     tool::{ToolSet, Tool},
/// };
///
/// #[derive(serde::Deserialize)]
/// struct AddArgs {
///     x: i32,
///     y: i32,
/// }
///
/// #[derive(Debug, thiserror::Error)]
/// #[error("Math error")]
/// struct MathError;
///
/// #[derive(serde::Deserialize, serde::Serialize)]
/// struct Adder;
///
/// impl Tool for Adder {
///     const NAME: &'static str = "add";
///
///     type Error = MathError;
///     type Args = AddArgs;
///     type Output = i32;
///
///     async fn definition(&self, _prompt: String) -> ToolDefinition {
///         ToolDefinition {
///             name: "add".to_string(),
///             description: "Add x and y together".to_string(),
///             parameters: serde_json::json!({
///                 "type": "object",
///                 "properties": {
///                     "x": {
///                         "type": "number",
///                         "description": "The first number to add"
///                     },
///                     "y": {
///                         "type": "number",
///                         "description": "The second number to add"
///                     }
///                 }
///             })
///         }
///     }
///
///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
///         let result = args.x + args.y;
///         Ok(result)
///     }
/// }
/// ```
pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
    /// The name of the tool. This name should be unique: a [`ToolSet`] keys its tools
    /// by name, so adding a second tool with the same name replaces the first.
    const NAME: &'static str;

    /// The error type of the tool.
    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
    /// The arguments type of the tool. It is deserialized from the JSON argument
    /// string when the tool is called through [`ToolDyn`].
    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
    /// The output type of the tool.
    type Output: Serialize;

    /// A method returning the name of the tool. Defaults to [`Self::NAME`].
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// A method returning the tool definition. The user prompt can be used to
    /// tailor the definition to the specific use case.
    fn definition(
        &self,
        _prompt: String,
    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;

    /// The tool execution method.
    ///
    /// Takes the typed [`Self::Args`] and returns the typed [`Self::Output`]. When the tool
    /// is invoked through [`ToolDyn::call`], the raw JSON argument string is deserialized
    /// into `Args` first. The `Output` is then serialized back to a `String`: string
    /// outputs are passed through without JSON quoting, and other values become JSON text.
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
}
143
/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`] when the tool cannot be built from
    /// its state and context.
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Type of the tool's context. This context will be saved and loaded from the
    /// vector store when ragging the tool.
    /// This context can be used to store the tool's static configuration and local
    /// context.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Type of the tool's state. This state will be passed to the tool when initializing it.
    /// This state can be used to pass runtime arguments to the tool such as clients,
    /// API keys and other configuration.
    type State: WasmCompatSend;

    /// A method returning the documents that will be used as embeddings for the tool.
    /// This allows for a tool to be retrieved from multiple embedding "directions".
    /// If the tool will not be RAGged, this method should return an empty vector.
    fn embedding_docs(&self) -> Vec<String>;

    /// A method returning the context of the tool.
    fn context(&self) -> Self::Context;

    /// A method to initialize the tool from the context, and a state.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
170
/// Wrapper trait to allow for dynamic dispatch of simple tools
///
/// Every [`Tool`] implements this trait through a blanket impl. The methods return boxed
/// futures (instead of `impl Future`) so the trait can be used as `dyn ToolDyn`, which is
/// how [`ToolSet`] stores tools.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Name of the tool. [`ToolSet`] uses it as the tool's key.
    fn name(&self) -> String;

    /// Returns the tool's definition, which may be tailored to `prompt`.
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Invokes the tool with its raw `args` string and returns the output as a `String`.
    ///
    /// For [`Tool`]-based tools, `args` is parsed as JSON into the tool's `Args`, and a
    /// parse failure is reported as [`ToolError::JsonError`].
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
179
180fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
181    match serde_json::to_value(output)? {
182        serde_json::Value::String(text) => Ok(text),
183        value => Ok(value.to_string()),
184    }
185}
186
187impl<T: Tool> ToolDyn for T {
188    fn name(&self) -> String {
189        self.name()
190    }
191
192    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
193        Box::pin(<Self as Tool>::definition(self, prompt))
194    }
195
196    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
197        Box::pin(async move {
198            match serde_json::from_str(&args) {
199                Ok(args) => <Self as Tool>::call(self, args)
200                    .await
201                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
202                    .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
203                Err(e) => Err(ToolError::JsonError(e)),
204            }
205        })
206    }
207}
208
209#[cfg(feature = "rmcp")]
210#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
211pub mod rmcp;
212
/// Wrapper trait to allow for dynamic dispatch of raggable tools
///
/// Every `'static` [`ToolEmbedding`] implements this trait through a blanket impl.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// The tool's context, serialized to a JSON value so it can be stored alongside
    /// the tool's embeddings.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// The documents used to embed (and later retrieve) the tool.
    fn embedding_docs(&self) -> Vec<String>;
}
219
220impl<T> ToolEmbeddingDyn for T
221where
222    T: ToolEmbedding + 'static,
223{
224    fn context(&self) -> serde_json::Result<serde_json::Value> {
225        serde_json::to_value(self.context())
226    }
227
228    fn embedding_docs(&self) -> Vec<String> {
229        self.embedding_docs()
230    }
231}
232
233#[derive(Clone)]
234pub(crate) enum ToolType {
235    Simple(Arc<dyn ToolDyn>),
236    Embedding(Arc<dyn ToolEmbeddingDyn>),
237}
238
239impl ToolType {
240    pub fn name(&self) -> String {
241        match self {
242            ToolType::Simple(tool) => tool.name(),
243            ToolType::Embedding(tool) => tool.name(),
244        }
245    }
246
247    pub async fn definition(&self, prompt: String) -> ToolDefinition {
248        match self {
249            ToolType::Simple(tool) => tool.definition(prompt).await,
250            ToolType::Embedding(tool) => tool.definition(prompt).await,
251        }
252    }
253
254    pub async fn call(&self, args: String) -> Result<String, ToolError> {
255        match self {
256            ToolType::Simple(tool) => tool.call(args).await,
257            ToolType::Embedding(tool) => tool.call(args).await,
258        }
259    }
260}
261
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// Error returned by the tool
    ///
    /// Wraps the [`ToolError`] from the invocation. For [`Tool`]-based tools this also
    /// covers call arguments that fail to deserialize.
    // NOTE(review): `ToolError`'s own `Display` already starts with "ToolCallError: " for
    // tool failures, so this variant renders that prefix twice. Confirm whether that is
    // intended before changing it, since callers may match on the rendered text.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// Could not find a tool
    ///
    /// Holds the tool name that was looked up.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    // TODO: Revisit this
    /// JSON (de)serialization error, e.g. when pretty-printing tool definitions in
    /// [`ToolSet::documents`].
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
    #[error("Tool call interrupted")]
    Interrupted,
}
280
281/// A struct that holds a set of tools
282#[derive(Default)]
283pub struct ToolSet {
284    pub(crate) tools: HashMap<String, ToolType>,
285}
286
287impl ToolSet {
288    /// Create a new ToolSet from a list of tools
289    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
290        let mut toolset = Self::default();
291        tools.into_iter().for_each(|tool| {
292            toolset.add_tool(tool);
293        });
294        toolset
295    }
296
297    pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
298        let mut toolset = Self::default();
299        tools.into_iter().for_each(|tool| {
300            toolset.add_tool_boxed(tool);
301        });
302        toolset
303    }
304
305    /// Create a toolset builder
306    pub fn builder() -> ToolSetBuilder {
307        ToolSetBuilder::default()
308    }
309
310    /// Check if the toolset contains a tool with the given name
311    pub fn contains(&self, toolname: &str) -> bool {
312        self.tools.contains_key(toolname)
313    }
314
315    /// Add a tool to the toolset
316    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
317        self.tools
318            .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
319    }
320
321    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
322    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
323        self.tools
324            .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
325    }
326
327    pub fn delete_tool(&mut self, tool_name: &str) {
328        let _ = self.tools.remove(tool_name);
329    }
330
331    /// Merge another toolset into this one
332    pub fn add_tools(&mut self, toolset: ToolSet) {
333        self.tools.extend(toolset.tools);
334    }
335
336    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
337        self.tools.get(toolname)
338    }
339
340    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
341        let mut defs = Vec::new();
342        for tool in self.tools.values() {
343            let def = tool.definition(String::new()).await;
344            defs.push(def);
345        }
346        Ok(defs)
347    }
348
349    /// Call a tool with the given name and arguments
350    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
351        if let Some(tool) = self.tools.get(toolname) {
352            tracing::debug!(target: "rig",
353                "Calling tool {toolname} with args:\n{}",
354                serde_json::to_string_pretty(&args).unwrap()
355            );
356            Ok(tool.call(args).await?)
357        } else {
358            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
359        }
360    }
361
362    /// Get the documents of all the tools in the toolset
363    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
364        let mut docs = Vec::new();
365        for tool in self.tools.values() {
366            match tool {
367                ToolType::Simple(tool) => {
368                    docs.push(completion::Document {
369                        id: tool.name(),
370                        text: format!(
371                            "\
372                            Tool: {}\n\
373                            Definition: \n\
374                            {}\
375                        ",
376                            tool.name(),
377                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
378                        ),
379                        additional_props: HashMap::new(),
380                    });
381                }
382                ToolType::Embedding(tool) => {
383                    docs.push(completion::Document {
384                        id: tool.name(),
385                        text: format!(
386                            "\
387                            Tool: {}\n\
388                            Definition: \n\
389                            {}\
390                        ",
391                            tool.name(),
392                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
393                        ),
394                        additional_props: HashMap::new(),
395                    });
396                }
397            }
398        }
399        Ok(docs)
400    }
401
402    /// Convert tools in self to objects of type ToolSchema.
403    /// This is necessary because when adding tools to the EmbeddingBuilder because all
404    /// documents added to the builder must all be of the same type.
405    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
406        self.tools
407            .values()
408            .filter_map(|tool_type| {
409                if let ToolType::Embedding(tool) = tool_type {
410                    Some(ToolSchema::try_from(&**tool))
411                } else {
412                    None
413                }
414            })
415            .collect::<Result<Vec<_>, _>>()
416    }
417}
418
419#[derive(Default)]
420pub struct ToolSetBuilder {
421    tools: Vec<ToolType>,
422}
423
424impl ToolSetBuilder {
425    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
426        self.tools.push(ToolType::Simple(Arc::new(tool)));
427        self
428    }
429
430    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
431        self.tools.push(ToolType::Embedding(Arc::new(tool)));
432        self
433    }
434
435    pub fn build(self) -> ToolSet {
436        ToolSet {
437            tools: self
438                .tools
439                .into_iter()
440                .map(|tool| (tool.name(), tool))
441                .collect(),
442        }
443    }
444}
445
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use serde_json::json;

    use super::*;

    /// Builds a toolset with two arithmetic tools: `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);

        // Compare names order-independently so this test does not depend on iteration order.
        let mut names: Vec<String> = tools.into_iter().map(|def| def.name).collect();
        names.sort();
        assert_eq!(names, vec!["add", "subtract"]);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn calling_an_unknown_tool_returns_not_found() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(ref name) if name == "multiply"));
    }

    #[tokio::test]
    async fn malformed_arguments_surface_as_json_errors() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn well_formed_arguments_reach_the_tool() {
        let toolset = get_test_toolset();
        let output = toolset
            .call("subtract", r#"{"x": 5, "y": 3}"#.to_string())
            .await
            .expect("tool should succeed");
        assert_eq!(output, "2");
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Test tool error")]
    struct TestToolError;

    #[test]
    fn nested_tool_call_errors_are_not_double_prefixed() {
        let inner = ToolError::ToolCallError(Box::new(TestToolError));
        assert_eq!(inner.to_string(), "ToolCallError: Test tool error");

        let outer = ToolError::ToolCallError(Box::new(inner));
        assert_eq!(outer.to_string(), "ToolCallError: Test tool error");
    }

    #[derive(Deserialize, Serialize)]
    struct StringOutputTool;

    impl Tool for StringOutputTool {
        const NAME: &'static str = "string_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns a multiline string".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok("Hello\nWorld".to_string())
        }
    }

    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(StringOutputTool);

        let output = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(output, "Hello\nWorld");
    }

    #[derive(Deserialize, Serialize)]
    struct ImageOutputTool;

    impl Tool for ImageOutputTool {
        const NAME: &'static str = "image_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns image JSON".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "type": "image",
                "data": "base64data==",
                "mimeType": "image/png"
            })
            .to_string())
        }
    }

    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ImageOutputTool);

        let output = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    #[derive(Deserialize, Serialize)]
    struct ObjectOutputTool;

    impl Tool for ObjectOutputTool {
        const NAME: &'static str = "object_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = serde_json::Value;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns an object".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "status": "ok",
                "count": 42
            }))
        }
    }

    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ObjectOutputTool);

        let output = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(output.starts_with('{'));
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&output).unwrap(),
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }
}