rig/
tool.rs

//! Module defining tool-related structs and traits.
//!
//! The [Tool] trait defines a simple interface for creating tools that can be used
//! by [Agents](crate::agent::Agent).
//!
//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
//! stored in a vector store and RAGged.
//!
//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
//! and optionally RAGged.
11
12use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18    completion::{self, ToolDefinition},
19    embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
22#[derive(Debug, thiserror::Error)]
23pub enum ToolError {
24    /// Error returned by the tool
25    #[error("ToolCallError: {0}")]
26    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
27
28    #[error("JsonError: {0}")]
29    JsonError(#[from] serde_json::Error),
30}
31
32/// Trait that represents a simple LLM tool
33///
34/// # Example
35/// ```
36/// use rig::{
37///     completion::ToolDefinition,
38///     tool::{ToolSet, Tool},
39/// };
40///
41/// #[derive(serde::Deserialize)]
42/// struct AddArgs {
43///     x: i32,
44///     y: i32,
45/// }
46///
47/// #[derive(Debug, thiserror::Error)]
48/// #[error("Math error")]
49/// struct MathError;
50///
51/// #[derive(serde::Deserialize, serde::Serialize)]
52/// struct Adder;
53///
54/// impl Tool for Adder {
55///     const NAME: &'static str = "add";
56///
57///     type Error = MathError;
58///     type Args = AddArgs;
59///     type Output = i32;
60///
61///     async fn definition(&self, _prompt: String) -> ToolDefinition {
62///         ToolDefinition {
63///             name: "add".to_string(),
64///             description: "Add x and y together".to_string(),
65///             parameters: serde_json::json!({
66///                 "type": "object",
67///                 "properties": {
68///                     "x": {
69///                         "type": "number",
70///                         "description": "The first number to add"
71///                     },
72///                     "y": {
73///                         "type": "number",
74///                         "description": "The second number to add"
75///                     }
76///                 }
77///             })
78///         }
79///     }
80///
81///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
82///         let result = args.x + args.y;
83///         Ok(result)
84///     }
85/// }
86/// ```
87pub trait Tool: Sized + Send + Sync {
88    /// The name of the tool. This name should be unique.
89    const NAME: &'static str;
90
91    /// The error type of the tool.
92    type Error: std::error::Error + Send + Sync + 'static;
93    /// The arguments type of the tool.
94    type Args: for<'a> Deserialize<'a> + Send + Sync;
95    /// The output type of the tool.
96    type Output: Serialize;
97
98    /// A method returning the name of the tool.
99    fn name(&self) -> String {
100        Self::NAME.to_string()
101    }
102
103    /// A method returning the tool definition. The user prompt can be used to
104    /// tailor the definition to the specific use case.
105    fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;
106
107    /// The tool execution method.
108    /// Both the arguments and return value are a String since these values are meant to
109    /// be the output and input of LLM models (respectively)
110    fn call(
111        &self,
112        args: Self::Args,
113    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + Sync;
114}
115
116/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
117pub trait ToolEmbedding: Tool {
118    type InitError: std::error::Error + Send + Sync + 'static;
119
120    /// Type of the tool' context. This context will be saved and loaded from the
121    /// vector store when ragging the tool.
122    /// This context can be used to store the tool's static configuration and local
123    /// context.
124    type Context: for<'a> Deserialize<'a> + Serialize;
125
126    /// Type of the tool's state. This state will be passed to the tool when initializing it.
127    /// This state can be used to pass runtime arguments to the tool such as clients,
128    /// API keys and other configuration.
129    type State: Send;
130
131    /// A method returning the documents that will be used as embeddings for the tool.
132    /// This allows for a tool to be retrieved from multiple embedding "directions".
133    /// If the tool will not be RAGged, this method should return an empty vector.
134    fn embedding_docs(&self) -> Vec<String>;
135
136    /// A method returning the context of the tool.
137    fn context(&self) -> Self::Context;
138
139    /// A method to initialize the tool from the context, and a state.
140    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
141}
142
143/// Wrapper trait to allow for dynamic dispatch of simple tools
144pub trait ToolDyn: Send + Sync {
145    fn name(&self) -> String;
146
147    fn definition(
148        &self,
149        prompt: String,
150    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;
151
152    fn call(
153        &self,
154        args: String,
155    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>>;
156}
157
158impl<T: Tool> ToolDyn for T {
159    fn name(&self) -> String {
160        self.name()
161    }
162
163    fn definition(
164        &self,
165        prompt: String,
166    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167        Box::pin(<Self as Tool>::definition(self, prompt))
168    }
169
170    fn call(
171        &self,
172        args: String,
173    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
174        Box::pin(async move {
175            match serde_json::from_str(&args) {
176                Ok(args) => <Self as Tool>::call(self, args)
177                    .await
178                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179                    .and_then(|output| {
180                        serde_json::to_string(&output).map_err(ToolError::JsonError)
181                    }),
182                Err(e) => Err(ToolError::JsonError(e)),
183            }
184        })
185    }
186}
187
#[deprecated(
    since = "0.16.0",
    note = "Since the official Rust MCP SDK (`rmcp`) has been added, the original Rig MCP integration with `mcp-core` has now been deprecated.
Please migrate over to the new integration - you can use the `rmcp` feature flag to do so.
This integration will be fully removed (and replaced with the `rmcp` one) by 0.18.0 at the earliest.
A guide can be found at `http://docs.rig.rs`, Rig's official docsite."
)]
#[cfg(feature = "mcp")]
pub mod mcp {
    //! Tool adapter for MCP servers reached through the (deprecated) `mcp-core` client.
    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use std::pin::Pin;

    /// A tool advertised by an MCP server, callable through [ToolDyn].
    pub struct McpTool<T: mcp_core::transport::Transport> {
        definition: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    }

    impl<T> McpTool<T>
    where
        T: mcp_core::transport::Transport,
    {
        /// Pairs a tool `definition` reported by an MCP server with the `client`
        /// used to invoke it.
        pub fn from_mcp_server(
            definition: mcp_core::types::Tool,
            client: mcp_core::client::Client<T>,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&mcp_core::types::Tool> for ToolDefinition {
        fn from(val: &mcp_core::types::Tool) -> Self {
            Self {
                name: val.name.clone(),
                description: val.description.clone().unwrap_or_default(),
                parameters: val.input_schema.clone(),
            }
        }
    }

    impl From<mcp_core::types::Tool> for ToolDefinition {
        fn from(val: mcp_core::types::Tool) -> Self {
            Self {
                name: val.name,
                description: val.description.unwrap_or_default(),
                parameters: val.input_schema,
            }
        }
    }

    /// Error produced when an MCP tool call fails or the server reports an error result.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Flattens one piece of MCP tool output into text.
    ///
    /// Text is passed through unchanged. Images and audio become base64 `data:`
    /// URIs. Embedded resources become an optional `data:<mime>;` prefix
    /// followed by the resource URI.
    fn content_to_string(content: mcp_core::types::ToolResponseContent) -> String {
        match content {
            mcp_core::types::ToolResponseContent::Text(text_content) => text_content.text,
            mcp_core::types::ToolResponseContent::Image(image_content) => format!(
                "data:{};base64,{}",
                image_content.mime_type, image_content.data
            ),
            mcp_core::types::ToolResponseContent::Audio(audio_content) => format!(
                "data:{};base64,{}",
                audio_content.mime_type, audio_content.data
            ),
            mcp_core::types::ToolResponseContent::Resource(embedded_resource) => format!(
                "{}{}",
                embedded_resource
                    .resource
                    .mime_type
                    .map(|m| format!("data:{m};"))
                    .unwrap_or_default(),
                embedded_resource.resource.uri
            ),
        }
    }

    impl<T> ToolDyn for McpTool<T>
    where
        T: mcp_core::transport::Transport,
    {
        fn name(&self) -> String {
            self.definition.name.clone()
        }

        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            // The server-provided definition is static; the prompt is not used.
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
            Box::pin(async move {
                // An empty argument string means "no arguments" and is sent as `null`,
                // as before. Any other input must be valid JSON: malformed model output
                // is reported as `ToolError::JsonError` instead of silently becoming `null`.
                let arguments: serde_json::Value = if args.trim().is_empty() {
                    serde_json::Value::Null
                } else {
                    serde_json::from_str(&args)?
                };

                let result = self
                    .client
                    .call_tool(&self.definition.name, Some(arguments))
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error.unwrap_or(false) {
                    // Only the first content item is inspected for the error message.
                    let message = match result.content.first() {
                        Some(mcp_core::types::ToolResponseContent::Text(text_content)) => {
                            text_content.text.clone()
                        }
                        Some(_) => "Unsupported error type".to_string(),
                        None => "No error message returned".to_string(),
                    };
                    return Err(McpToolError(message).into());
                }

                Ok(result
                    .content
                    .into_iter()
                    .map(content_to_string)
                    .collect::<Vec<_>>()
                    .join(""))
            })
        }
    }
}
343
#[cfg(feature = "rmcp")]
pub mod rmcp {
    //! Tool adapter for MCP servers reached through the official `rmcp` SDK.
    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use rmcp::model::{RawContent, ResourceContents};
    use std::pin::Pin;

    /// A tool advertised by an MCP server, callable through [ToolDyn].
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Pairs a tool `definition` reported by an MCP server with the `client`
        /// sink used to invoke it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            // Delegate to the borrowed conversion. The schema is read through a
            // `&self` method, so moving fields out first would not compile anyway.
            Self::from(&val)
        }
    }

    /// Error produced when an MCP tool call fails or the server reports an error result.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Flattens one piece of MCP tool output into text.
    ///
    /// Text is passed through unchanged. Images and audio become base64 `data:`
    /// URIs. Embedded resources are rendered as an optional `data:<mime>;`
    /// prefix, the URI, `:`, and the text or blob payload.
    fn content_to_string(content: RawContent) -> String {
        match content {
            RawContent::Text(raw) => raw.text,
            RawContent::Image(raw) => format!("data:{};base64,{}", raw.mime_type, raw.data),
            // Previously `unimplemented!()`, which panicked whenever a server returned
            // audio. It is now rendered the same way as images.
            RawContent::Audio(audio) => {
                format!("data:{};base64,{}", audio.mime_type, audio.data)
            }
            RawContent::Resource(raw) => match raw.resource {
                ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            },
        }
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            // The server-provided definition is static; the prompt is not used.
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
            Box::pin(async move {
                // An empty argument string means "no arguments" (`None`). Any other input
                // must be a JSON object or `null`. Malformed or non-object input is
                // reported as `ToolError::JsonError` instead of being silently dropped.
                let arguments = if args.trim().is_empty() {
                    None
                } else {
                    serde_json::from_str(&args)?
                };

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam {
                        name: self.definition.name.clone(),
                        arguments,
                    })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error.unwrap_or(false) {
                    // Only the first content item is inspected for the error message.
                    let message = match result.content.first() {
                        Some(content) => match content.as_text() {
                            Some(raw) => raw.text.clone(),
                            None => "Unsupported error type".to_string(),
                        },
                        None => "No error message returned".to_string(),
                    };
                    return Err(McpToolError(message).into());
                }

                Ok(result
                    .content
                    .into_iter()
                    .map(|content| content_to_string(content.raw))
                    .collect::<Vec<_>>()
                    .join(""))
            })
        }
    }
}
489
490/// Wrapper trait to allow for dynamic dispatch of raggable tools
491pub trait ToolEmbeddingDyn: ToolDyn {
492    fn context(&self) -> serde_json::Result<serde_json::Value>;
493
494    fn embedding_docs(&self) -> Vec<String>;
495}
496
497impl<T: ToolEmbedding> ToolEmbeddingDyn for T {
498    fn context(&self) -> serde_json::Result<serde_json::Value> {
499        serde_json::to_value(self.context())
500    }
501
502    fn embedding_docs(&self) -> Vec<String> {
503        self.embedding_docs()
504    }
505}
506
507pub(crate) enum ToolType {
508    Simple(Box<dyn ToolDyn>),
509    Embedding(Box<dyn ToolEmbeddingDyn>),
510}
511
512impl ToolType {
513    pub fn name(&self) -> String {
514        match self {
515            ToolType::Simple(tool) => tool.name(),
516            ToolType::Embedding(tool) => tool.name(),
517        }
518    }
519
520    pub async fn definition(&self, prompt: String) -> ToolDefinition {
521        match self {
522            ToolType::Simple(tool) => tool.definition(prompt).await,
523            ToolType::Embedding(tool) => tool.definition(prompt).await,
524        }
525    }
526
527    pub async fn call(&self, args: String) -> Result<String, ToolError> {
528        match self {
529            ToolType::Simple(tool) => tool.call(args).await,
530            ToolType::Embedding(tool) => tool.call(args).await,
531        }
532    }
533}
534
535#[derive(Debug, thiserror::Error)]
536pub enum ToolSetError {
537    /// Error returned by the tool
538    #[error("ToolCallError: {0}")]
539    ToolCallError(#[from] ToolError),
540
541    #[error("ToolNotFoundError: {0}")]
542    ToolNotFoundError(String),
543
544    // TODO: Revisit this
545    #[error("JsonError: {0}")]
546    JsonError(#[from] serde_json::Error),
547}
548
549/// A struct that holds a set of tools
550#[derive(Default)]
551pub struct ToolSet {
552    pub(crate) tools: HashMap<String, ToolType>,
553}
554
555impl ToolSet {
556    /// Create a new ToolSet from a list of tools
557    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
558        let mut toolset = Self::default();
559        tools.into_iter().for_each(|tool| {
560            toolset.add_tool(tool);
561        });
562        toolset
563    }
564
565    /// Create a toolset builder
566    pub fn builder() -> ToolSetBuilder {
567        ToolSetBuilder::default()
568    }
569
570    /// Check if the toolset contains a tool with the given name
571    pub fn contains(&self, toolname: &str) -> bool {
572        self.tools.contains_key(toolname)
573    }
574
575    /// Add a tool to the toolset
576    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
577        self.tools
578            .insert(tool.name(), ToolType::Simple(Box::new(tool)));
579    }
580
581    /// Merge another toolset into this one
582    pub fn add_tools(&mut self, toolset: ToolSet) {
583        self.tools.extend(toolset.tools);
584    }
585
586    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
587        self.tools.get(toolname)
588    }
589
590    /// Call a tool with the given name and arguments
591    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
592        if let Some(tool) = self.tools.get(toolname) {
593            tracing::info!(target: "rig",
594                "Calling tool {toolname} with args:\n{}",
595                serde_json::to_string_pretty(&args).unwrap()
596            );
597            Ok(tool.call(args).await?)
598        } else {
599            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
600        }
601    }
602
603    /// Get the documents of all the tools in the toolset
604    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
605        let mut docs = Vec::new();
606        for tool in self.tools.values() {
607            match tool {
608                ToolType::Simple(tool) => {
609                    docs.push(completion::Document {
610                        id: tool.name(),
611                        text: format!(
612                            "\
613                            Tool: {}\n\
614                            Definition: \n\
615                            {}\
616                        ",
617                            tool.name(),
618                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
619                        ),
620                        additional_props: HashMap::new(),
621                    });
622                }
623                ToolType::Embedding(tool) => {
624                    docs.push(completion::Document {
625                        id: tool.name(),
626                        text: format!(
627                            "\
628                            Tool: {}\n\
629                            Definition: \n\
630                            {}\
631                        ",
632                            tool.name(),
633                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
634                        ),
635                        additional_props: HashMap::new(),
636                    });
637                }
638            }
639        }
640        Ok(docs)
641    }
642
643    /// Convert tools in self to objects of type ToolSchema.
644    /// This is necessary because when adding tools to the EmbeddingBuilder because all
645    /// documents added to the builder must all be of the same type.
646    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
647        self.tools
648            .values()
649            .filter_map(|tool_type| {
650                if let ToolType::Embedding(tool) = tool_type {
651                    Some(ToolSchema::try_from(&**tool))
652                } else {
653                    None
654                }
655            })
656            .collect::<Result<Vec<_>, _>>()
657    }
658}
659
660#[derive(Default)]
661pub struct ToolSetBuilder {
662    tools: Vec<ToolType>,
663}
664
665impl ToolSetBuilder {
666    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
667        self.tools.push(ToolType::Simple(Box::new(tool)));
668        self
669    }
670
671    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
672        self.tools.push(ToolType::Embedding(Box::new(tool)));
673        self
674    }
675
676    pub fn build(self) -> ToolSet {
677        ToolSet {
678            tools: self
679                .tools
680                .into_iter()
681                .map(|tool| (tool.name(), tool))
682                .collect(),
683        }
684    }
685}