rig/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14
15use futures::Future;
16use serde::{Deserialize, Serialize};
17
18use crate::{
19    completion::{self, ToolDefinition},
20    embeddings::{embed::EmbedError, tool::ToolSchema},
21    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
22};
23
24#[derive(Debug, thiserror::Error)]
25pub enum ToolError {
26    #[cfg(not(target_family = "wasm"))]
27    /// Error returned by the tool
28    #[error("ToolCallError: {0}")]
29    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
30
31    #[cfg(target_family = "wasm")]
32    /// Error returned by the tool
33    #[error("ToolCallError: {0}")]
34    ToolCallError(#[from] Box<dyn std::error::Error>),
35
36    #[error("JsonError: {0}")]
37    JsonError(#[from] serde_json::Error),
38}
39
40/// Trait that represents a simple LLM tool
41///
42/// # Example
43/// ```
44/// use rig::{
45///     completion::ToolDefinition,
46///     tool::{ToolSet, Tool},
47/// };
48///
49/// #[derive(serde::Deserialize)]
50/// struct AddArgs {
51///     x: i32,
52///     y: i32,
53/// }
54///
55/// #[derive(Debug, thiserror::Error)]
56/// #[error("Math error")]
57/// struct MathError;
58///
59/// #[derive(serde::Deserialize, serde::Serialize)]
60/// struct Adder;
61///
62/// impl Tool for Adder {
63///     const NAME: &'static str = "add";
64///
65///     type Error = MathError;
66///     type Args = AddArgs;
67///     type Output = i32;
68///
69///     async fn definition(&self, _prompt: String) -> ToolDefinition {
70///         ToolDefinition {
71///             name: "add".to_string(),
72///             description: "Add x and y together".to_string(),
73///             parameters: serde_json::json!({
74///                 "type": "object",
75///                 "properties": {
76///                     "x": {
77///                         "type": "number",
78///                         "description": "The first number to add"
79///                     },
80///                     "y": {
81///                         "type": "number",
82///                         "description": "The second number to add"
83///                     }
84///                 }
85///             })
86///         }
87///     }
88///
89///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
90///         let result = args.x + args.y;
91///         Ok(result)
92///     }
93/// }
94/// ```
95pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
96    /// The name of the tool. This name should be unique.
97    const NAME: &'static str;
98
99    /// The error type of the tool.
100    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
101    /// The arguments type of the tool.
102    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
103    /// The output type of the tool.
104    type Output: Serialize;
105
106    /// A method returning the name of the tool.
107    fn name(&self) -> String {
108        Self::NAME.to_string()
109    }
110
111    /// A method returning the tool definition. The user prompt can be used to
112    /// tailor the definition to the specific use case.
113    fn definition(
114        &self,
115        _prompt: String,
116    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
117
118    /// The tool execution method.
119    /// Both the arguments and return value are a String since these values are meant to
120    /// be the output and input of LLM models (respectively)
121    fn call(
122        &self,
123        args: Self::Args,
124    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
125}
126
127/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
128pub trait ToolEmbedding: Tool {
129    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
130
131    /// Type of the tool' context. This context will be saved and loaded from the
132    /// vector store when ragging the tool.
133    /// This context can be used to store the tool's static configuration and local
134    /// context.
135    type Context: for<'a> Deserialize<'a> + Serialize;
136
137    /// Type of the tool's state. This state will be passed to the tool when initializing it.
138    /// This state can be used to pass runtime arguments to the tool such as clients,
139    /// API keys and other configuration.
140    type State: WasmCompatSend;
141
142    /// A method returning the documents that will be used as embeddings for the tool.
143    /// This allows for a tool to be retrieved from multiple embedding "directions".
144    /// If the tool will not be RAGged, this method should return an empty vector.
145    fn embedding_docs(&self) -> Vec<String>;
146
147    /// A method returning the context of the tool.
148    fn context(&self) -> Self::Context;
149
150    /// A method to initialize the tool from the context, and a state.
151    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
152}
153
154/// Wrapper trait to allow for dynamic dispatch of simple tools
155pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
156    fn name(&self) -> String;
157
158    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
159
160    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
161}
162
163impl<T: Tool> ToolDyn for T {
164    fn name(&self) -> String {
165        self.name()
166    }
167
168    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
169        Box::pin(<Self as Tool>::definition(self, prompt))
170    }
171
172    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
173        Box::pin(async move {
174            match serde_json::from_str(&args) {
175                Ok(args) => <Self as Tool>::call(self, args)
176                    .await
177                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
178                    .and_then(|output| {
179                        serde_json::to_string(&output).map_err(ToolError::JsonError)
180                    }),
181                Err(e) => Err(ToolError::JsonError(e)),
182            }
183        })
184    }
185}
186
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter exposing tools hosted on an MCP server (via the `rmcp` crate) as [`ToolDyn`].

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use crate::wasm_compat::WasmBoxedFuture;
    use rmcp::model::RawContent;

    /// A tool provided by an MCP server, called through the given server sink.
    #[derive(Clone)]
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wrap an MCP tool definition together with the client used to call it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                // A missing description becomes an empty string.
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            Self::from(&val)
        }
    }

    /// Error raised when an MCP tool call fails or returns content that cannot be handled.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Join the text parts of an MCP error result into one message.
    ///
    /// Non-text parts are skipped. If there is no text at all, a fixed placeholder is
    /// returned so the caller never sees an empty error message.
    fn error_message(contents: impl IntoIterator<Item = RawContent>) -> String {
        let parts: Vec<String> = contents
            .into_iter()
            .filter_map(|content| match content {
                RawContent::Text(raw) => Some(raw.text),
                _ => None,
            })
            .collect();

        if parts.is_empty() {
            "No message returned".to_string()
        } else {
            parts.join("\n")
        }
    }

    /// `data:<mime>;` prefix used when rendering resources, or empty when no MIME type is given.
    fn data_uri_prefix(mime_type: Option<impl std::fmt::Display>) -> String {
        mime_type.map(|m| format!("data:{m};")).unwrap_or_default()
    }

    /// Render one piece of MCP result content as text.
    ///
    /// # Errors
    /// Returns an [`McpToolError`] for content kinds that are not supported (audio and any
    /// other variant not handled here) instead of panicking on server-provided data.
    fn content_to_string(content: RawContent) -> Result<String, McpToolError> {
        let text = match content {
            RawContent::Text(raw) => raw.text,
            RawContent::Image(raw) => format!("data:{};base64,{}", raw.mime_type, raw.data),
            RawContent::Resource(raw) => match raw.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!("{}{uri}:{text}", data_uri_prefix(mime_type)),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!("{}{uri}:{blob}", data_uri_prefix(mime_type)),
            },
            RawContent::Audio(_) => {
                return Err(McpToolError(
                    "Support for audio results from an MCP tool is currently unimplemented"
                        .to_string(),
                ));
            }
            other => {
                return Err(McpToolError(format!("Unsupported type found: {other:?}")));
            }
        };
        Ok(text)
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            // Reuse the `From` conversion so both paths render the schema identically.
            let definition = ToolDefinition::from(&self.definition);
            Box::pin(async move { definition })
        }

        fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                // An empty argument string means "no arguments". Anything else must be valid
                // JSON; malformed input is reported instead of calling the tool without arguments.
                let arguments = if args.trim().is_empty() {
                    None
                } else {
                    serde_json::from_str(&args).map_err(ToolError::JsonError)?
                };

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if let Some(true) = result.is_error {
                    let message = error_message(result.content.into_iter().map(|c| c.raw));
                    return Err(McpToolError(message).into());
                }

                let mut output = String::new();
                for content in result.content {
                    output.push_str(&content_to_string(content.raw)?);
                }
                Ok(output)
            })
        }
    }
}
336
337/// Wrapper trait to allow for dynamic dispatch of raggable tools
338pub trait ToolEmbeddingDyn: ToolDyn {
339    fn context(&self) -> serde_json::Result<serde_json::Value>;
340
341    fn embedding_docs(&self) -> Vec<String>;
342}
343
344impl<T> ToolEmbeddingDyn for T
345where
346    T: ToolEmbedding + 'static,
347{
348    fn context(&self) -> serde_json::Result<serde_json::Value> {
349        serde_json::to_value(self.context())
350    }
351
352    fn embedding_docs(&self) -> Vec<String> {
353        self.embedding_docs()
354    }
355}
356
357pub(crate) enum ToolType {
358    Simple(Box<dyn ToolDyn>),
359    Embedding(Box<dyn ToolEmbeddingDyn>),
360}
361
362impl ToolType {
363    pub fn name(&self) -> String {
364        match self {
365            ToolType::Simple(tool) => tool.name(),
366            ToolType::Embedding(tool) => tool.name(),
367        }
368    }
369
370    pub async fn definition(&self, prompt: String) -> ToolDefinition {
371        match self {
372            ToolType::Simple(tool) => tool.definition(prompt).await,
373            ToolType::Embedding(tool) => tool.definition(prompt).await,
374        }
375    }
376
377    pub async fn call(&self, args: String) -> Result<String, ToolError> {
378        match self {
379            ToolType::Simple(tool) => tool.call(args).await,
380            ToolType::Embedding(tool) => tool.call(args).await,
381        }
382    }
383}
384
385#[derive(Debug, thiserror::Error)]
386pub enum ToolSetError {
387    /// Error returned by the tool
388    #[error("ToolCallError: {0}")]
389    ToolCallError(#[from] ToolError),
390
391    /// Could not find a tool
392    #[error("ToolNotFoundError: {0}")]
393    ToolNotFoundError(String),
394
395    // TODO: Revisit this
396    #[error("JsonError: {0}")]
397    JsonError(#[from] serde_json::Error),
398
399    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
400    #[error("Tool call interrupted")]
401    Interrupted,
402}
403
404/// A struct that holds a set of tools
405#[derive(Default)]
406pub struct ToolSet {
407    pub(crate) tools: HashMap<String, ToolType>,
408}
409
410impl ToolSet {
411    /// Create a new ToolSet from a list of tools
412    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
413        let mut toolset = Self::default();
414        tools.into_iter().for_each(|tool| {
415            toolset.add_tool(tool);
416        });
417        toolset
418    }
419
420    /// Create a toolset builder
421    pub fn builder() -> ToolSetBuilder {
422        ToolSetBuilder::default()
423    }
424
425    /// Check if the toolset contains a tool with the given name
426    pub fn contains(&self, toolname: &str) -> bool {
427        self.tools.contains_key(toolname)
428    }
429
430    /// Add a tool to the toolset
431    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
432        self.tools
433            .insert(tool.name(), ToolType::Simple(Box::new(tool)));
434    }
435
436    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
437    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
438        self.tools.insert(tool.name(), ToolType::Simple(tool));
439    }
440
441    pub fn delete_tool(&mut self, tool_name: &str) {
442        let _ = self.tools.remove(tool_name);
443    }
444
445    /// Merge another toolset into this one
446    pub fn add_tools(&mut self, toolset: ToolSet) {
447        self.tools.extend(toolset.tools);
448    }
449
450    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
451        self.tools.get(toolname)
452    }
453
454    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
455        let mut defs = Vec::new();
456        for tool in self.tools.values() {
457            let def = tool.definition(String::new()).await;
458            defs.push(def);
459        }
460        Ok(defs)
461    }
462
463    /// Call a tool with the given name and arguments
464    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
465        if let Some(tool) = self.tools.get(toolname) {
466            tracing::debug!(target: "rig",
467                "Calling tool {toolname} with args:\n{}",
468                serde_json::to_string_pretty(&args).unwrap()
469            );
470            Ok(tool.call(args).await?)
471        } else {
472            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
473        }
474    }
475
476    /// Get the documents of all the tools in the toolset
477    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
478        let mut docs = Vec::new();
479        for tool in self.tools.values() {
480            match tool {
481                ToolType::Simple(tool) => {
482                    docs.push(completion::Document {
483                        id: tool.name(),
484                        text: format!(
485                            "\
486                            Tool: {}\n\
487                            Definition: \n\
488                            {}\
489                        ",
490                            tool.name(),
491                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
492                        ),
493                        additional_props: HashMap::new(),
494                    });
495                }
496                ToolType::Embedding(tool) => {
497                    docs.push(completion::Document {
498                        id: tool.name(),
499                        text: format!(
500                            "\
501                            Tool: {}\n\
502                            Definition: \n\
503                            {}\
504                        ",
505                            tool.name(),
506                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
507                        ),
508                        additional_props: HashMap::new(),
509                    });
510                }
511            }
512        }
513        Ok(docs)
514    }
515
516    /// Convert tools in self to objects of type ToolSchema.
517    /// This is necessary because when adding tools to the EmbeddingBuilder because all
518    /// documents added to the builder must all be of the same type.
519    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
520        self.tools
521            .values()
522            .filter_map(|tool_type| {
523                if let ToolType::Embedding(tool) = tool_type {
524                    Some(ToolSchema::try_from(&**tool))
525                } else {
526                    None
527                }
528            })
529            .collect::<Result<Vec<_>, _>>()
530    }
531}
532
533#[derive(Default)]
534pub struct ToolSetBuilder {
535    tools: Vec<ToolType>,
536}
537
538impl ToolSetBuilder {
539    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
540        self.tools.push(ToolType::Simple(Box::new(tool)));
541        self
542    }
543
544    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
545        self.tools.push(ToolType::Embedding(Box::new(tool)));
546        self
547    }
548
549    pub fn build(self) -> ToolSet {
550        ToolSet {
551            tools: self
552                .tools
553                .into_iter()
554                .map(|tool| (tool.name(), tool))
555                .collect(),
556        }
557    }
558}
559
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Build a toolset with two arithmetic tools: `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_call_dispatches_by_name() {
        let toolset = get_test_toolset();
        let sum = toolset
            .call("add", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(ref name) if name == "multiply"));
    }

    #[tokio::test]
    async fn test_call_with_malformed_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn test_documents() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);

        let add_doc = docs.iter().find(|doc| doc.id == "add").unwrap();
        assert!(add_doc.text.starts_with("Tool: add\nDefinition: \n"));
    }
}