Skip to main content

rig_core/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21    completion::{self, ToolDefinition},
22    embeddings::{embed::EmbedError, tool::ToolSchema},
23    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
26#[derive(Debug, thiserror::Error)]
27pub enum ToolError {
28    #[cfg(not(target_family = "wasm"))]
29    /// Error returned by the tool
30    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
31
32    #[cfg(target_family = "wasm")]
33    /// Error returned by the tool
34    ToolCallError(#[from] Box<dyn std::error::Error>),
35    /// Error caused by a de/serialization fail
36    JsonError(#[from] serde_json::Error),
37}
38
39impl fmt::Display for ToolError {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            ToolError::ToolCallError(e) => {
43                let error_str = e.to_string();
44                // This is required due to being able to use agents as tools
45                // which means it is possible to get recursive tool call errors
46                if error_str.starts_with("ToolCallError: ") {
47                    write!(f, "{}", error_str)
48                } else {
49                    write!(f, "ToolCallError: {}", error_str)
50                }
51            }
52            ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53        }
54    }
55}
56
57/// Trait that represents a simple LLM tool
58///
59/// # Example
60/// ```
61/// use rig_core::{
62///     completion::ToolDefinition,
63///     tool::{ToolSet, Tool},
64/// };
65///
66/// #[derive(serde::Deserialize)]
67/// struct AddArgs {
68///     x: i32,
69///     y: i32,
70/// }
71///
72/// #[derive(Debug, thiserror::Error)]
73/// #[error("Math error")]
74/// struct MathError;
75///
76/// #[derive(serde::Deserialize, serde::Serialize)]
77/// struct Adder;
78///
79/// impl Tool for Adder {
80///     const NAME: &'static str = "add";
81///
82///     type Error = MathError;
83///     type Args = AddArgs;
84///     type Output = i32;
85///
86///     async fn definition(&self, _prompt: String) -> ToolDefinition {
87///         ToolDefinition {
88///             name: "add".to_string(),
89///             description: "Add x and y together".to_string(),
90///             parameters: serde_json::json!({
91///                 "type": "object",
92///                 "properties": {
93///                     "x": {
94///                         "type": "number",
95///                         "description": "The first number to add"
96///                     },
97///                     "y": {
98///                         "type": "number",
99///                         "description": "The second number to add"
100///                     }
101///                 }
102///             })
103///         }
104///     }
105///
106///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
107///         let result = args.x + args.y;
108///         Ok(result)
109///     }
110/// }
111/// ```
112pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
113    /// The name of the tool. This name should be unique within a single
114    /// [`ToolSet`] or other registration scope that dispatches tools by name.
115    const NAME: &'static str;
116
117    /// The error type of the tool.
118    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
119    /// The arguments type of the tool.
120    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
121    /// The output type of the tool.
122    type Output: Serialize;
123
124    /// A method returning the name of the tool.
125    fn name(&self) -> String {
126        Self::NAME.to_string()
127    }
128
129    /// A method returning the tool definition. The user prompt can be used to
130    /// tailor the definition to the specific use case.
131    fn definition(
132        &self,
133        _prompt: String,
134    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
135
136    /// The tool execution method.
137    /// Both the arguments and return value are a String since these values are meant to
138    /// be the output and input of LLM models (respectively)
139    fn call(
140        &self,
141        args: Self::Args,
142    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
143}
144
145/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
146pub trait ToolEmbedding: Tool {
147    /// Error returned when reconstructing a dynamic tool from stored context.
148    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
149
150    /// Type of the tool' context. This context will be saved and loaded from the
151    /// vector store when ragging the tool.
152    /// This context can be used to store the tool's static configuration and local
153    /// context.
154    type Context: for<'a> Deserialize<'a> + Serialize;
155
156    /// Type of the tool's state. This state will be passed to the tool when initializing it.
157    /// This state can be used to pass runtime arguments to the tool such as clients,
158    /// API keys and other configuration.
159    type State: WasmCompatSend;
160
161    /// A method returning the documents that will be used as embeddings for the tool.
162    /// This allows for a tool to be retrieved from multiple embedding "directions".
163    /// If the tool will not be RAGged, this method should return an empty vector.
164    fn embedding_docs(&self) -> Vec<String>;
165
166    /// A method returning the context of the tool.
167    fn context(&self) -> Self::Context;
168
169    /// A method to initialize the tool from the context, and a state.
170    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
171}
172
173/// Wrapper trait to allow for dynamic dispatch of simple tools
174pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
175    /// Returns the tool name used for dispatch.
176    fn name(&self) -> String;
177
178    /// Returns the provider-facing tool schema.
179    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
180
181    /// Calls the tool with JSON-encoded arguments and returns model-facing text.
182    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
183}
184
185fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
186    match serde_json::to_value(output)? {
187        serde_json::Value::String(text) => Ok(text),
188        value => Ok(value.to_string()),
189    }
190}
191
192impl<T: Tool> ToolDyn for T {
193    fn name(&self) -> String {
194        self.name()
195    }
196
197    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
198        Box::pin(<Self as Tool>::definition(self, prompt))
199    }
200
201    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
202        Box::pin(async move {
203            // LLMs frequently send `null` for tools whose arguments are all optional.
204            // `serde_json::from_str::<T>("null")` fails for struct types even when
205            // every field is `Option<_>`, because JSON null does not deserialize to an
206            // empty object. Preserve any args type that already accepts `null` (such as
207            // `()` or `Option<T>`) and fall back to `{}` only after the original parse
208            // fails.
209            let args = match serde_json::from_str(&args) {
210                Ok(args) => Ok(args),
211                Err(err) if args.trim() == "null" => serde_json::from_str("{}").map_err(|_| err),
212                Err(err) => Err(err),
213            };
214            match args {
215                Ok(args) => <Self as Tool>::call(self, args)
216                    .await
217                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
218                    .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
219                Err(e) => Err(ToolError::JsonError(e)),
220            }
221        })
222    }
223}
224
225#[cfg(feature = "rmcp")]
226#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
227pub mod rmcp;
228
229/// Wrapper trait to allow for dynamic dispatch of raggable tools
230pub trait ToolEmbeddingDyn: ToolDyn {
231    /// Serializes context needed to reconstruct this dynamic tool.
232    fn context(&self) -> serde_json::Result<serde_json::Value>;
233
234    /// Returns text fragments used to retrieve this tool from a vector store.
235    fn embedding_docs(&self) -> Vec<String>;
236}
237
238impl<T> ToolEmbeddingDyn for T
239where
240    T: ToolEmbedding + 'static,
241{
242    fn context(&self) -> serde_json::Result<serde_json::Value> {
243        serde_json::to_value(self.context())
244    }
245
246    fn embedding_docs(&self) -> Vec<String> {
247        self.embedding_docs()
248    }
249}
250
251#[derive(Clone)]
252pub(crate) enum ToolType {
253    Simple(Arc<dyn ToolDyn>),
254    Embedding(Arc<dyn ToolEmbeddingDyn>),
255}
256
257impl ToolType {
258    pub fn name(&self) -> String {
259        match self {
260            ToolType::Simple(tool) => tool.name(),
261            ToolType::Embedding(tool) => tool.name(),
262        }
263    }
264
265    pub async fn definition(&self, prompt: String) -> ToolDefinition {
266        match self {
267            ToolType::Simple(tool) => tool.definition(prompt).await,
268            ToolType::Embedding(tool) => tool.definition(prompt).await,
269        }
270    }
271
272    pub async fn call(&self, args: String) -> Result<String, ToolError> {
273        match self {
274            ToolType::Simple(tool) => tool.call(args).await,
275            ToolType::Embedding(tool) => tool.call(args).await,
276        }
277    }
278}
279
/// Errors returned by [`ToolSet`] operations such as [`ToolSet::call`].
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// Error returned by the tool
    ///
    /// NOTE(review): `ToolError`'s own `Display` already prefixes tool-call failures
    /// with `ToolCallError: `, so this variant can render as
    /// `ToolCallError: ToolCallError: ...`. Confirm whether the doubled prefix is
    /// intended before relying on the exact message text.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool is registered under the requested name; carries that name.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON serialization or deserialization failed while preparing tool data.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
    #[error("Tool call interrupted")]
    Interrupted,
}
298
299/// A struct that holds a set of tools
300#[derive(Default)]
301pub struct ToolSet {
302    pub(crate) tools: HashMap<String, ToolType>,
303}
304
305impl ToolSet {
306    /// Create a new ToolSet from a list of tools
307    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
308        let mut toolset = Self::default();
309        tools.into_iter().for_each(|tool| {
310            toolset.add_tool(tool);
311        });
312        toolset
313    }
314
315    /// Create a new `ToolSet` from boxed dynamically-dispatched tools.
316    pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
317        let mut toolset = Self::default();
318        tools.into_iter().for_each(|tool| {
319            toolset.add_tool_boxed(tool);
320        });
321        toolset
322    }
323
324    /// Create a toolset builder
325    pub fn builder() -> ToolSetBuilder {
326        ToolSetBuilder::default()
327    }
328
329    /// Check if the toolset contains a tool with the given name
330    pub fn contains(&self, toolname: &str) -> bool {
331        self.tools.contains_key(toolname)
332    }
333
334    /// Add a tool to the toolset
335    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
336        self.tools
337            .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
338    }
339
340    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
341    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
342        self.tools
343            .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
344    }
345
346    /// Remove a tool by name. Missing tools are ignored.
347    pub fn delete_tool(&mut self, tool_name: &str) {
348        let _ = self.tools.remove(tool_name);
349    }
350
351    /// Merge another toolset into this one
352    pub fn add_tools(&mut self, toolset: ToolSet) {
353        self.tools.extend(toolset.tools);
354    }
355
356    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
357        self.tools.get(toolname)
358    }
359
360    /// Return definitions for all tools currently registered in the set.
361    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
362        let mut defs = Vec::new();
363        for tool in self.tools.values() {
364            let def = tool.definition(String::new()).await;
365            defs.push(def);
366        }
367        Ok(defs)
368    }
369
370    /// Call a tool with the given name and arguments
371    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
372        if let Some(tool) = self.tools.get(toolname) {
373            tracing::debug!(target: "rig",
374                "Calling tool {toolname} with args:\n{}",
375                args
376            );
377            Ok(tool.call(args).await?)
378        } else {
379            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
380        }
381    }
382
383    /// Get the documents of all the tools in the toolset
384    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
385        let mut docs = Vec::new();
386        for tool in self.tools.values() {
387            match tool {
388                ToolType::Simple(tool) => {
389                    docs.push(completion::Document {
390                        id: tool.name(),
391                        text: format!(
392                            "\
393                            Tool: {}\n\
394                            Definition: \n\
395                            {}\
396                        ",
397                            tool.name(),
398                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
399                        ),
400                        additional_props: HashMap::new(),
401                    });
402                }
403                ToolType::Embedding(tool) => {
404                    docs.push(completion::Document {
405                        id: tool.name(),
406                        text: format!(
407                            "\
408                            Tool: {}\n\
409                            Definition: \n\
410                            {}\
411                        ",
412                            tool.name(),
413                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
414                        ),
415                        additional_props: HashMap::new(),
416                    });
417                }
418            }
419        }
420        Ok(docs)
421    }
422
423    /// Convert tools in self to objects of type ToolSchema.
424    /// This is necessary because when adding tools to the EmbeddingBuilder because all
425    /// documents added to the builder must all be of the same type.
426    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
427        self.tools
428            .values()
429            .filter_map(|tool_type| {
430                if let ToolType::Embedding(tool) = tool_type {
431                    Some(ToolSchema::try_from(&**tool))
432                } else {
433                    None
434                }
435            })
436            .collect::<Result<Vec<_>, _>>()
437    }
438}
439
440#[derive(Default)]
441/// Builder for constructing a [`ToolSet`] with static and dynamic tools.
442pub struct ToolSetBuilder {
443    tools: Vec<ToolType>,
444}
445
446impl ToolSetBuilder {
447    /// Add a regular tool that is always available when the set is used.
448    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
449        self.tools.push(ToolType::Simple(Arc::new(tool)));
450        self
451    }
452
453    /// Add a tool that can be represented as embeddings for dynamic retrieval.
454    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
455        self.tools.push(ToolType::Embedding(Arc::new(tool)));
456        self
457    }
458
459    /// Build the tool set, keyed by each tool's name.
460    pub fn build(self) -> ToolSet {
461        ToolSet {
462            tools: self
463                .tools
464                .into_iter()
465                .map(|tool| (tool.name(), tool))
466                .collect(),
467        }
468    }
469}
470
// Unit tests for `ToolSet` bookkeeping, `ToolDyn` argument parsing, and tool output
// serialization. Mock tools and fixtures come from `crate::test_utils`.
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use crate::test_utils::{
        MockExampleTool, MockImageOutputTool, MockObjectOutputTool, MockStringOutputTool,
        mock_math_toolset,
    };
    use serde_json::json;

    use super::*;

    // Shared fixture. The tests below assume it holds exactly two tools, one of them
    // named "add" (see `mock_math_toolset` in `test_utils` to confirm).
    fn get_test_toolset() -> ToolSet {
        mock_math_toolset()
    }

    // Every registered tool should yield exactly one definition.
    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    // Deleting by name removes only that tool.
    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    // A string output must come back verbatim (real newline, no JSON quoting/escaping).
    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockStringOutputTool);

        let output = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(output, "Hello\nWorld");
    }

    // A string output that itself encodes structured content (here an image payload)
    // must stay unquoted so `ToolResultContent::from_tool_output` can still parse it.
    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockImageOutputTool);

        let output = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    // Non-string outputs are serialized as JSON objects.
    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockObjectOutputTool);

        let output = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(output.starts_with('{'));
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&output).unwrap(),
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }

    // Args types that accept `null` natively (presumably `()` here, per the test name)
    // must receive it directly rather than via the `{}` fallback.
    #[tokio::test]
    async fn null_args_are_preserved_for_unit_args() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockExampleTool);

        let output = toolset
            .call("example_tool", "null".to_string())
            .await
            .expect("unit args should accept null without object fallback");

        assert_eq!(output, "Example answer");
    }

    // Struct-typed args with all-optional fields — serde rejects `null` for these
    // even though the fields are optional. The normalization in `ToolDyn::call`
    // falls back from `null` to `{}` so callers can omit the
    // wrapping `Option<Args>` workaround.
    #[tokio::test]
    async fn null_args_are_normalized_to_empty_object() {
        use crate::test_utils::MockToolError;

        #[derive(serde::Deserialize, serde::Serialize)]
        struct NoRequiredArgs {
            label: Option<String>,
        }

        // Minimal tool whose only argument is optional; echoes it or "default".
        struct NoArgTool;

        impl Tool for NoArgTool {
            const NAME: &'static str = "no_arg_tool";
            type Error = MockToolError;
            type Args = NoRequiredArgs;
            type Output = String;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: Self::NAME.to_string(),
                    description: "Tool with no required arguments".to_string(),
                    parameters: json!({"type": "object", "properties": {}}),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                Ok(args.label.unwrap_or_else(|| "default".to_string()))
            }
        }

        let mut toolset = ToolSet::default();
        toolset.add_tool(NoArgTool);

        // `null` is what LLMs send when no arguments are provided; without the
        // normalization this would return `ToolError::JsonError`.
        let output = toolset
            .call("no_arg_tool", "null".to_string())
            .await
            .expect("null args should succeed after normalisation");

        assert_eq!(output, "default");
    }
}