rig/
tool.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18    completion::{self, ToolDefinition},
19    embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
/// Errors produced while invoking a tool through [ToolDyn::call].
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// Error returned by the tool itself, boxed so that any error type can be carried.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// The tool arguments could not be deserialized from JSON, or the tool output could not
    /// be serialized to JSON.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
31
/// Trait that represents a simple LLM tool.
///
/// A tool advertises itself to the model through [Tool::definition] and does its work in
/// [Tool::call]. Through the blanket [ToolDyn] implementation, the model's JSON arguments are
/// deserialized into [Tool::Args] and the returned [Tool::Output] is serialized back to JSON.
///
/// # Example
/// ```
/// use rig::{
///     completion::ToolDefinition,
///     tool::{ToolSet, Tool},
/// };
///
/// #[derive(serde::Deserialize)]
/// struct AddArgs {
///     x: i32,
///     y: i32,
/// }
///
/// #[derive(Debug, thiserror::Error)]
/// #[error("Math error")]
/// struct MathError;
///
/// #[derive(serde::Deserialize, serde::Serialize)]
/// struct Adder;
///
/// impl Tool for Adder {
///     const NAME: &'static str = "add";
///
///     type Error = MathError;
///     type Args = AddArgs;
///     type Output = i32;
///
///     async fn definition(&self, _prompt: String) -> ToolDefinition {
///         ToolDefinition {
///             name: "add".to_string(),
///             description: "Add x and y together".to_string(),
///             parameters: serde_json::json!({
///                 "type": "object",
///                 "properties": {
///                     "x": {
///                         "type": "number",
///                         "description": "The first number to add"
///                     },
///                     "y": {
///                         "type": "number",
///                         "description": "The second number to add"
///                     }
///                 }
///             })
///         }
///     }
///
///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
///         let result = args.x + args.y;
///         Ok(result)
///     }
/// }
/// ```
pub trait Tool: Sized + Send + Sync {
    /// The name of the tool. This name should be unique: a [ToolSet] keys its tools by
    /// name, so adding a second tool with the same name replaces the first.
    const NAME: &'static str;

    /// The error type of the tool.
    type Error: std::error::Error + Send + Sync + 'static;
    /// The arguments type of the tool, deserialized from the JSON arguments string by the
    /// [ToolDyn] wrapper.
    type Args: for<'a> Deserialize<'a> + Send + Sync;
    /// The output type of the tool, serialized to a JSON string by the [ToolDyn] wrapper.
    type Output: Serialize;

    /// A method returning the name of the tool. Defaults to [Tool::NAME].
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// A method returning the tool definition. The user prompt can be used to
    /// tailor the definition to the specific use case.
    fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;

    /// The tool execution method.
    ///
    /// Receives the already-deserialized [Tool::Args] and returns either the typed
    /// [Tool::Output] or a [Tool::Error]. When the tool is invoked through [ToolDyn::call],
    /// the arguments are parsed from a JSON string and the output is serialized back into a
    /// JSON string; errors are wrapped in [ToolError::ToolCallError].
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
}
115
/// Trait that represents an LLM tool that can be stored in a vector store and RAGged.
///
/// Such a tool is saved as its [ToolEmbedding::Context] (retrievable through the documents
/// returned by [ToolEmbedding::embedding_docs]) and rebuilt with [ToolEmbedding::init].
pub trait ToolEmbedding: Tool {
    /// Error returned by [ToolEmbedding::init] when the tool cannot be rebuilt from its
    /// state and context.
    type InitError: std::error::Error + Send + Sync + 'static;

    /// Type of the tool's context. This context will be saved and loaded from the
    /// vector store when ragging the tool.
    /// This context can be used to store the tool's static configuration and local
    /// context.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Type of the tool's state. This state will be passed to the tool when initializing it.
    /// This state can be used to pass runtime arguments to the tool such as clients,
    /// API keys and other configuration.
    type State: Send;

    /// A method returning the documents that will be used as embeddings for the tool.
    /// This allows for a tool to be retrieved from multiple embedding "directions".
    /// If the tool will not be RAGged, this method should return an empty vector.
    fn embedding_docs(&self) -> Vec<String>;

    /// A method returning the context of the tool.
    fn context(&self) -> Self::Context;

    /// A method to initialize the tool from a runtime state and a previously saved context.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
142
/// Wrapper trait to allow for dynamic dispatch of simple tools.
///
/// Every [Tool] implements this trait through a blanket implementation, which erases the
/// typed arguments and output into strings so that heterogeneous tools can be stored as
/// `Box<dyn ToolDyn>`.
pub trait ToolDyn: Send + Sync {
    /// The name the tool is registered and invoked under.
    fn name(&self) -> String;

    /// Returns the tool definition, which may be tailored to `prompt`.
    fn definition(
        &self,
        prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;

    /// Invokes the tool with `args` (a JSON-encoded argument string) and returns the tool's
    /// output rendered as a string.
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>>;
}
157
158impl<T: Tool> ToolDyn for T {
159    fn name(&self) -> String {
160        self.name()
161    }
162
163    fn definition(
164        &self,
165        prompt: String,
166    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167        Box::pin(<Self as Tool>::definition(self, prompt))
168    }
169
170    fn call(
171        &self,
172        args: String,
173    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
174        Box::pin(async move {
175            match serde_json::from_str(&args) {
176                Ok(args) => <Self as Tool>::call(self, args)
177                    .await
178                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179                    .and_then(|output| {
180                        serde_json::to_string(&output).map_err(ToolError::JsonError)
181                    }),
182                Err(e) => Err(ToolError::JsonError(e)),
183            }
184        })
185    }
186}
187
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter exposing tools served by an MCP server (through the `rmcp` crate) as
    //! [ToolDyn]s.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use rmcp::model::RawContent;
    use std::pin::Pin;

    /// A tool hosted on an MCP server.
    ///
    /// The tool's name, description and input schema come from `definition`; calls are
    /// forwarded to the server through `client`.
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wraps a tool `definition` obtained from an MCP server together with the `client`
        /// used to invoke it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                // A missing description is reported as an empty string.
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            Self::from(&val)
        }
    }

    /// Error raised when an MCP tool call fails, either because the request to the server
    /// failed or because the server flagged the result as an error.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Renders one piece of MCP tool output as text for the model.
    ///
    /// Unsupported content kinds are reported as errors instead of panicking, because the
    /// content is chosen by an external server and must not be able to crash the caller.
    fn render_content(raw: RawContent) -> Result<String, McpToolError> {
        let rendered = match raw {
            RawContent::Text(text) => text.text,
            RawContent::Image(image) => {
                format!("data:{};base64,{}", image.mime_type, image.data)
            }
            RawContent::Resource(embedded) => match embedded.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            },
            RawContent::Audio(_) => {
                return Err(McpToolError(
                    "Support for audio results from an MCP tool is currently unimplemented. Come back later!"
                        .to_string(),
                ));
            }
            other => {
                return Err(McpToolError(format!("Unsupported type found: {other:?}")));
            }
        };
        Ok(rendered)
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        /// Returns the MCP tool's definition; the prompt is ignored.
        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            let definition = ToolDefinition::from(&self.definition);
            Box::pin(async move { definition })
        }

        /// Forwards the call to the MCP server.
        ///
        /// An empty (or whitespace-only) `args` string is sent as "no arguments", which
        /// models commonly produce for parameterless tools. Any other string must be a JSON
        /// object (or `null`); malformed JSON yields [ToolError::JsonError] instead of being
        /// silently dropped.
        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
            let name = self.definition.name.clone();
            let arguments = if args.trim().is_empty() {
                Ok(None)
            } else {
                serde_json::from_str(&args)
            };

            Box::pin(async move {
                let arguments = arguments?;
                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Keep every textual part of the error; non-text parts cannot be shown
                    // meaningfully in an error message and are skipped.
                    let messages: Vec<String> = result
                        .content
                        .into_iter()
                        .filter_map(|content| match content.raw {
                            RawContent::Text(text) => Some(text.text),
                            _ => None,
                        })
                        .collect();

                    let message = if messages.is_empty() {
                        "No message returned".to_string()
                    } else {
                        messages.join("\n")
                    };
                    return Err(McpToolError(message).into());
                }

                result
                    .content
                    .into_iter()
                    .map(|content| render_content(content.raw))
                    .collect::<Result<String, McpToolError>>()
                    .map_err(ToolError::from)
            })
        }
    }
}
342
/// Wrapper trait to allow for dynamic dispatch of raggable tools.
///
/// Extends [ToolDyn] with what is needed to store the tool in a vector store: its context
/// as a JSON value and the documents to embed.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized as a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the documents that will be embedded to retrieve this tool.
    fn embedding_docs(&self) -> Vec<String>;
}
349
350impl<T> ToolEmbeddingDyn for T
351where
352    T: ToolEmbedding,
353{
354    fn context(&self) -> serde_json::Result<serde_json::Value> {
355        serde_json::to_value(self.context())
356    }
357
358    fn embedding_docs(&self) -> Vec<String> {
359        self.embedding_docs()
360    }
361}
362
/// A tool stored in a [ToolSet], tagged by whether it can also be embedded for RAG.
pub(crate) enum ToolType {
    /// A plain tool (added via [ToolSet::add_tool] or [ToolSetBuilder::static_tool]).
    Simple(Box<dyn ToolDyn>),
    /// A tool that also exposes a context and embedding documents (added via
    /// [ToolSetBuilder::dynamic_tool]).
    Embedding(Box<dyn ToolEmbeddingDyn>),
}
367
368impl ToolType {
369    pub fn name(&self) -> String {
370        match self {
371            ToolType::Simple(tool) => tool.name(),
372            ToolType::Embedding(tool) => tool.name(),
373        }
374    }
375
376    pub async fn definition(&self, prompt: String) -> ToolDefinition {
377        match self {
378            ToolType::Simple(tool) => tool.definition(prompt).await,
379            ToolType::Embedding(tool) => tool.definition(prompt).await,
380        }
381    }
382
383    pub async fn call(&self, args: String) -> Result<String, ToolError> {
384        match self {
385            ToolType::Simple(tool) => tool.call(args).await,
386            ToolType::Embedding(tool) => tool.call(args).await,
387        }
388    }
389}
390
/// Errors returned by [ToolSet] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// Error returned by the tool
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool with the given name is registered in the toolset.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    // TODO: Revisit this
    /// A tool definition could not be serialized to JSON (e.g. in [ToolSet::documents]).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
404
405/// A struct that holds a set of tools
406#[derive(Default)]
407pub struct ToolSet {
408    pub(crate) tools: HashMap<String, ToolType>,
409}
410
411impl ToolSet {
412    /// Create a new ToolSet from a list of tools
413    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
414        let mut toolset = Self::default();
415        tools.into_iter().for_each(|tool| {
416            toolset.add_tool(tool);
417        });
418        toolset
419    }
420
421    /// Create a toolset builder
422    pub fn builder() -> ToolSetBuilder {
423        ToolSetBuilder::default()
424    }
425
426    /// Check if the toolset contains a tool with the given name
427    pub fn contains(&self, toolname: &str) -> bool {
428        self.tools.contains_key(toolname)
429    }
430
431    /// Add a tool to the toolset
432    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
433        self.tools
434            .insert(tool.name(), ToolType::Simple(Box::new(tool)));
435    }
436
437    pub fn delete_tool(&mut self, tool_name: &str) {
438        let _ = self.tools.remove(tool_name);
439    }
440
441    /// Merge another toolset into this one
442    pub fn add_tools(&mut self, toolset: ToolSet) {
443        self.tools.extend(toolset.tools);
444    }
445
446    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
447        self.tools.get(toolname)
448    }
449
450    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
451        let mut defs = Vec::new();
452        for tool in self.tools.values() {
453            let def = tool.definition(String::new()).await;
454            defs.push(def);
455        }
456        Ok(defs)
457    }
458
459    /// Call a tool with the given name and arguments
460    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
461        if let Some(tool) = self.tools.get(toolname) {
462            tracing::info!(target: "rig",
463                "Calling tool {toolname} with args:\n{}",
464                serde_json::to_string_pretty(&args).unwrap()
465            );
466            Ok(tool.call(args).await?)
467        } else {
468            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
469        }
470    }
471
472    /// Get the documents of all the tools in the toolset
473    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
474        let mut docs = Vec::new();
475        for tool in self.tools.values() {
476            match tool {
477                ToolType::Simple(tool) => {
478                    docs.push(completion::Document {
479                        id: tool.name(),
480                        text: format!(
481                            "\
482                            Tool: {}\n\
483                            Definition: \n\
484                            {}\
485                        ",
486                            tool.name(),
487                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
488                        ),
489                        additional_props: HashMap::new(),
490                    });
491                }
492                ToolType::Embedding(tool) => {
493                    docs.push(completion::Document {
494                        id: tool.name(),
495                        text: format!(
496                            "\
497                            Tool: {}\n\
498                            Definition: \n\
499                            {}\
500                        ",
501                            tool.name(),
502                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
503                        ),
504                        additional_props: HashMap::new(),
505                    });
506                }
507            }
508        }
509        Ok(docs)
510    }
511
512    /// Convert tools in self to objects of type ToolSchema.
513    /// This is necessary because when adding tools to the EmbeddingBuilder because all
514    /// documents added to the builder must all be of the same type.
515    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
516        self.tools
517            .values()
518            .filter_map(|tool_type| {
519                if let ToolType::Embedding(tool) = tool_type {
520                    Some(ToolSchema::try_from(&**tool))
521                } else {
522                    None
523                }
524            })
525            .collect::<Result<Vec<_>, _>>()
526    }
527}
528
529#[derive(Default)]
530pub struct ToolSetBuilder {
531    tools: Vec<ToolType>,
532}
533
534impl ToolSetBuilder {
535    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
536        self.tools.push(ToolType::Simple(Box::new(tool)));
537        self
538    }
539
540    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
541        self.tools.push(ToolType::Embedding(Box::new(tool)));
542        self
543    }
544
545    pub fn build(self) -> ToolSet {
546        ToolSet {
547            tools: self
548                .tools
549                .into_iter()
550                .map(|tool| (tool.name(), tool))
551                .collect(),
552        }
553    }
554}
555
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a toolset with two integer tools: `add` (x + y) and `subtract` (x - y).
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);

        // HashMap order is unspecified, so compare sorted names.
        let mut names: Vec<String> = tools.into_iter().map(|def| def.name).collect();
        names.sort();
        assert_eq!(names, vec!["add", "subtract"]);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_call_tool() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(name) if name == "multiply"));
    }

    #[tokio::test]
    async fn test_call_invalid_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn test_documents() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);
        assert!(docs
            .iter()
            .all(|doc| doc.text.starts_with(&format!("Tool: {}\nDefinition: \n", doc.id))));
    }
}