Skip to main content

rig/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15
16use futures::Future;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20    completion::{self, ToolDefinition},
21    embeddings::{embed::EmbedError, tool::ToolSchema},
22    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
23};
24
25#[derive(Debug, thiserror::Error)]
26pub enum ToolError {
27    #[cfg(not(target_family = "wasm"))]
28    /// Error returned by the tool
29    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
30
31    #[cfg(target_family = "wasm")]
32    /// Error returned by the tool
33    ToolCallError(#[from] Box<dyn std::error::Error>),
34    /// Error caused by a de/serialization fail
35    JsonError(#[from] serde_json::Error),
36}
37
38impl fmt::Display for ToolError {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            ToolError::ToolCallError(e) => {
42                let error_str = e.to_string();
43                // This is required due to being able to use agents as tools
44                // which means it is possible to get recursive tool call errors
45                if error_str.starts_with("ToolCallError: ") {
46                    write!(f, "{}", error_str)
47                } else {
48                    write!(f, "ToolCallError: {}", error_str)
49                }
50            }
51            ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
52        }
53    }
54}
55
56/// Trait that represents a simple LLM tool
57///
58/// # Example
59/// ```
60/// use rig::{
61///     completion::ToolDefinition,
62///     tool::{ToolSet, Tool},
63/// };
64///
65/// #[derive(serde::Deserialize)]
66/// struct AddArgs {
67///     x: i32,
68///     y: i32,
69/// }
70///
71/// #[derive(Debug, thiserror::Error)]
72/// #[error("Math error")]
73/// struct MathError;
74///
75/// #[derive(serde::Deserialize, serde::Serialize)]
76/// struct Adder;
77///
78/// impl Tool for Adder {
79///     const NAME: &'static str = "add";
80///
81///     type Error = MathError;
82///     type Args = AddArgs;
83///     type Output = i32;
84///
85///     async fn definition(&self, _prompt: String) -> ToolDefinition {
86///         ToolDefinition {
87///             name: "add".to_string(),
88///             description: "Add x and y together".to_string(),
89///             parameters: serde_json::json!({
90///                 "type": "object",
91///                 "properties": {
92///                     "x": {
93///                         "type": "number",
94///                         "description": "The first number to add"
95///                     },
96///                     "y": {
97///                         "type": "number",
98///                         "description": "The second number to add"
99///                     }
100///                 }
101///             })
102///         }
103///     }
104///
105///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
106///         let result = args.x + args.y;
107///         Ok(result)
108///     }
109/// }
110/// ```
111pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
112    /// The name of the tool. This name should be unique.
113    const NAME: &'static str;
114
115    /// The error type of the tool.
116    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
117    /// The arguments type of the tool.
118    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
119    /// The output type of the tool.
120    type Output: Serialize;
121
122    /// A method returning the name of the tool.
123    fn name(&self) -> String {
124        Self::NAME.to_string()
125    }
126
127    /// A method returning the tool definition. The user prompt can be used to
128    /// tailor the definition to the specific use case.
129    fn definition(
130        &self,
131        _prompt: String,
132    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
133
134    /// The tool execution method.
135    /// Both the arguments and return value are a String since these values are meant to
136    /// be the output and input of LLM models (respectively)
137    fn call(
138        &self,
139        args: Self::Args,
140    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
141}
142
143/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
144pub trait ToolEmbedding: Tool {
145    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
146
147    /// Type of the tool' context. This context will be saved and loaded from the
148    /// vector store when ragging the tool.
149    /// This context can be used to store the tool's static configuration and local
150    /// context.
151    type Context: for<'a> Deserialize<'a> + Serialize;
152
153    /// Type of the tool's state. This state will be passed to the tool when initializing it.
154    /// This state can be used to pass runtime arguments to the tool such as clients,
155    /// API keys and other configuration.
156    type State: WasmCompatSend;
157
158    /// A method returning the documents that will be used as embeddings for the tool.
159    /// This allows for a tool to be retrieved from multiple embedding "directions".
160    /// If the tool will not be RAGged, this method should return an empty vector.
161    fn embedding_docs(&self) -> Vec<String>;
162
163    /// A method returning the context of the tool.
164    fn context(&self) -> Self::Context;
165
166    /// A method to initialize the tool from the context, and a state.
167    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
168}
169
170/// Wrapper trait to allow for dynamic dispatch of simple tools
171pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
172    fn name(&self) -> String;
173
174    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
175
176    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
177}
178
179impl<T: Tool> ToolDyn for T {
180    fn name(&self) -> String {
181        self.name()
182    }
183
184    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
185        Box::pin(<Self as Tool>::definition(self, prompt))
186    }
187
188    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
189        Box::pin(async move {
190            match serde_json::from_str(&args) {
191                Ok(args) => <Self as Tool>::call(self, args)
192                    .await
193                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
194                    .and_then(|output| {
195                        serde_json::to_string(&output).map_err(ToolError::JsonError)
196                    }),
197                Err(e) => Err(ToolError::JsonError(e)),
198            }
199        })
200    }
201}
202
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter exposing tools served by an MCP server (through the `rmcp` crate) as
    //! [`ToolDyn`] tools.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use crate::wasm_compat::WasmBoxedFuture;
    use rmcp::model::RawContent;

    /// A tool advertised by an MCP server, invoked through the server's client sink.
    #[derive(Clone)]
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wrap a tool `definition` advertised by an MCP server; calls are sent
        /// through `client`.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                // A missing description is rendered as an empty string.
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            Self::from(&val)
        }
    }

    /// Error reported by (or while talking to) an MCP tool.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Parse the model-provided JSON `args` for an MCP call.
    ///
    /// Empty (or whitespace-only) input means "no arguments" and yields `T::default()`;
    /// any other input must be valid JSON for `T`, otherwise `ToolError::JsonError` is
    /// returned (consistent with typed [`crate::tool::Tool`]s) instead of silently
    /// calling the server without arguments.
    fn parse_arguments<T>(args: &str) -> Result<T, ToolError>
    where
        T: serde::de::DeserializeOwned + Default,
    {
        if args.trim().is_empty() {
            return Ok(T::default());
        }
        Ok(serde_json::from_str(args)?)
    }

    /// Render the `data:<mime>;` prefix used for resource contents, or nothing when the
    /// MIME type is unknown.
    fn mime_prefix(mime_type: Option<impl std::fmt::Display>) -> String {
        mime_type.map(|m| format!("data:{m};")).unwrap_or_default()
    }

    /// Render one piece of MCP tool-result content as text.
    ///
    /// Unsupported content (audio, or any other variant) is reported as an error
    /// rather than panicking, since it comes from an external server.
    fn raw_content_to_string(content: RawContent) -> Result<String, McpToolError> {
        let rendered = match content {
            RawContent::Text(raw) => raw.text,
            RawContent::Image(raw) => format!("data:{};base64,{}", raw.mime_type, raw.data),
            RawContent::Resource(raw) => match raw.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!("{}{uri}:{text}", mime_prefix(mime_type)),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!("{}{uri}:{blob}", mime_prefix(mime_type)),
            },
            RawContent::Audio(_) => {
                return Err(McpToolError(
                    "Support for audio results from an MCP tool is currently unimplemented. Come back later!"
                        .to_string(),
                ));
            }
            other => return Err(McpToolError(format!("Unsupported type found: {other:?}"))),
        };
        Ok(rendered)
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            // Reuse the `From` conversion so both paths produce the same definition.
            let definition = ToolDefinition::from(&self.definition);
            Box::pin(async move { definition })
        }

        fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                let arguments = parse_arguments(&args)?;

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParams {
                        name,
                        arguments,
                        meta: None,
                        task: None,
                    })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Keep every textual part of the error; non-text parts are skipped
                    // instead of discarding the whole message.
                    let messages: Vec<String> = result
                        .content
                        .into_iter()
                        .filter_map(|content| match content.raw {
                            RawContent::Text(raw) => Some(raw.text),
                            _ => None,
                        })
                        .collect();
                    let message = if messages.is_empty() {
                        "No message returned".to_string()
                    } else {
                        messages.join("\n")
                    };
                    return Err(McpToolError(message).into());
                }

                let output = result
                    .content
                    .into_iter()
                    .map(|content| raw_content_to_string(content.raw))
                    .collect::<Result<String, McpToolError>>()?;
                Ok::<String, ToolError>(output)
            })
        }
    }
}
357
358/// Wrapper trait to allow for dynamic dispatch of raggable tools
359pub trait ToolEmbeddingDyn: ToolDyn {
360    fn context(&self) -> serde_json::Result<serde_json::Value>;
361
362    fn embedding_docs(&self) -> Vec<String>;
363}
364
365impl<T> ToolEmbeddingDyn for T
366where
367    T: ToolEmbedding + 'static,
368{
369    fn context(&self) -> serde_json::Result<serde_json::Value> {
370        serde_json::to_value(self.context())
371    }
372
373    fn embedding_docs(&self) -> Vec<String> {
374        self.embedding_docs()
375    }
376}
377
378pub(crate) enum ToolType {
379    Simple(Box<dyn ToolDyn>),
380    Embedding(Box<dyn ToolEmbeddingDyn>),
381}
382
383impl ToolType {
384    pub fn name(&self) -> String {
385        match self {
386            ToolType::Simple(tool) => tool.name(),
387            ToolType::Embedding(tool) => tool.name(),
388        }
389    }
390
391    pub async fn definition(&self, prompt: String) -> ToolDefinition {
392        match self {
393            ToolType::Simple(tool) => tool.definition(prompt).await,
394            ToolType::Embedding(tool) => tool.definition(prompt).await,
395        }
396    }
397
398    pub async fn call(&self, args: String) -> Result<String, ToolError> {
399        match self {
400            ToolType::Simple(tool) => tool.call(args).await,
401            ToolType::Embedding(tool) => tool.call(args).await,
402        }
403    }
404}
405
406#[derive(Debug, thiserror::Error)]
407pub enum ToolSetError {
408    /// Error returned by the tool
409    #[error("ToolCallError: {0}")]
410    ToolCallError(#[from] ToolError),
411
412    /// Could not find a tool
413    #[error("ToolNotFoundError: {0}")]
414    ToolNotFoundError(String),
415
416    // TODO: Revisit this
417    #[error("JsonError: {0}")]
418    JsonError(#[from] serde_json::Error),
419
420    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
421    #[error("Tool call interrupted")]
422    Interrupted,
423}
424
425/// A struct that holds a set of tools
426#[derive(Default)]
427pub struct ToolSet {
428    pub(crate) tools: HashMap<String, ToolType>,
429}
430
431impl ToolSet {
432    /// Create a new ToolSet from a list of tools
433    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
434        let mut toolset = Self::default();
435        tools.into_iter().for_each(|tool| {
436            toolset.add_tool(tool);
437        });
438        toolset
439    }
440
441    pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
442        let mut toolset = Self::default();
443        tools.into_iter().for_each(|tool| {
444            toolset.add_tool_boxed(tool);
445        });
446        toolset
447    }
448
449    /// Create a toolset builder
450    pub fn builder() -> ToolSetBuilder {
451        ToolSetBuilder::default()
452    }
453
454    /// Check if the toolset contains a tool with the given name
455    pub fn contains(&self, toolname: &str) -> bool {
456        self.tools.contains_key(toolname)
457    }
458
459    /// Add a tool to the toolset
460    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
461        self.tools
462            .insert(tool.name(), ToolType::Simple(Box::new(tool)));
463    }
464
465    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
466    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
467        self.tools.insert(tool.name(), ToolType::Simple(tool));
468    }
469
470    pub fn delete_tool(&mut self, tool_name: &str) {
471        let _ = self.tools.remove(tool_name);
472    }
473
474    /// Merge another toolset into this one
475    pub fn add_tools(&mut self, toolset: ToolSet) {
476        self.tools.extend(toolset.tools);
477    }
478
479    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
480        self.tools.get(toolname)
481    }
482
483    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
484        let mut defs = Vec::new();
485        for tool in self.tools.values() {
486            let def = tool.definition(String::new()).await;
487            defs.push(def);
488        }
489        Ok(defs)
490    }
491
492    /// Call a tool with the given name and arguments
493    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
494        if let Some(tool) = self.tools.get(toolname) {
495            tracing::debug!(target: "rig",
496                "Calling tool {toolname} with args:\n{}",
497                serde_json::to_string_pretty(&args).unwrap()
498            );
499            Ok(tool.call(args).await?)
500        } else {
501            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
502        }
503    }
504
505    /// Get the documents of all the tools in the toolset
506    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
507        let mut docs = Vec::new();
508        for tool in self.tools.values() {
509            match tool {
510                ToolType::Simple(tool) => {
511                    docs.push(completion::Document {
512                        id: tool.name(),
513                        text: format!(
514                            "\
515                            Tool: {}\n\
516                            Definition: \n\
517                            {}\
518                        ",
519                            tool.name(),
520                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
521                        ),
522                        additional_props: HashMap::new(),
523                    });
524                }
525                ToolType::Embedding(tool) => {
526                    docs.push(completion::Document {
527                        id: tool.name(),
528                        text: format!(
529                            "\
530                            Tool: {}\n\
531                            Definition: \n\
532                            {}\
533                        ",
534                            tool.name(),
535                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
536                        ),
537                        additional_props: HashMap::new(),
538                    });
539                }
540            }
541        }
542        Ok(docs)
543    }
544
545    /// Convert tools in self to objects of type ToolSchema.
546    /// This is necessary because when adding tools to the EmbeddingBuilder because all
547    /// documents added to the builder must all be of the same type.
548    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
549        self.tools
550            .values()
551            .filter_map(|tool_type| {
552                if let ToolType::Embedding(tool) = tool_type {
553                    Some(ToolSchema::try_from(&**tool))
554                } else {
555                    None
556                }
557            })
558            .collect::<Result<Vec<_>, _>>()
559    }
560}
561
562#[derive(Default)]
563pub struct ToolSetBuilder {
564    tools: Vec<ToolType>,
565}
566
567impl ToolSetBuilder {
568    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
569        self.tools.push(ToolType::Simple(Box::new(tool)));
570        self
571    }
572
573    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
574        self.tools.push(ToolType::Embedding(Box::new(tool)));
575        self
576    }
577
578    pub fn build(self) -> ToolSet {
579        ToolSet {
580            tools: self
581                .tools
582                .into_iter()
583                .map(|tool| (tool.name(), tool))
584                .collect(),
585        }
586    }
587}
588
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Build a toolset containing an `add` and a `subtract` tool.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_call_tool() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", r#"{"x": 1, "y": 2}"#.to_string())
            .await
            .unwrap();
        assert_eq!(sum, "3");

        let difference = toolset
            .call("subtract", json!({"x": 5, "y": 7}).to_string())
            .await
            .unwrap();
        assert_eq!(difference, "-2");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(ref name) if name == "multiply"));
    }

    #[tokio::test]
    async fn test_call_with_invalid_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn test_documents() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);
        for doc in &docs {
            // The definition is pretty-printed JSON, so it starts with an opening brace.
            let expected_prefix = format!("Tool: {}\nDefinition: \n{{", doc.id);
            assert!(doc.text.starts_with(&expected_prefix));
        }
    }
}