Skip to main content

rig_core/tool/
mod.rs

1//! Module defining tool related structs and traits.
2//!
3//! The [Tool] trait defines a simple interface for creating tools that can be used
4//! by [Agents](crate::agent::Agent).
5//!
6//! The [ToolEmbedding] trait extends the [Tool] trait to allow for tools that can be
7//! stored in a vector store and RAGged.
8//!
9//! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent)
10//! and optionally RAGged.
11
12pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use indexmap::IndexMap;
19use serde::{Deserialize, Serialize};
20
21use crate::{
22    completion::{self, ToolDefinition},
23    embeddings::{embed::EmbedError, tool::ToolSchema},
24    wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
25};
26
27#[derive(Debug, thiserror::Error)]
28pub enum ToolError {
29    #[cfg(not(target_family = "wasm"))]
30    /// Error returned by the tool
31    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
32
33    #[cfg(target_family = "wasm")]
34    /// Error returned by the tool
35    ToolCallError(#[from] Box<dyn std::error::Error>),
36    /// Error caused by a de/serialization fail
37    JsonError(#[from] serde_json::Error),
38}
39
40impl fmt::Display for ToolError {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            ToolError::ToolCallError(e) => {
44                let error_str = e.to_string();
45                // This is required due to being able to use agents as tools
46                // which means it is possible to get recursive tool call errors
47                if error_str.starts_with("ToolCallError: ") {
48                    write!(f, "{}", error_str)
49                } else {
50                    write!(f, "ToolCallError: {}", error_str)
51                }
52            }
53            ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
54        }
55    }
56}
57
58/// Trait that represents a simple LLM tool
59///
60/// # Example
61/// ```
62/// use rig_core::{
63///     completion::ToolDefinition,
64///     tool::{ToolSet, Tool},
65/// };
66///
67/// #[derive(serde::Deserialize)]
68/// struct AddArgs {
69///     x: i32,
70///     y: i32,
71/// }
72///
73/// #[derive(Debug, thiserror::Error)]
74/// #[error("Math error")]
75/// struct MathError;
76///
77/// #[derive(serde::Deserialize, serde::Serialize)]
78/// struct Adder;
79///
80/// impl Tool for Adder {
81///     const NAME: &'static str = "add";
82///
83///     type Error = MathError;
84///     type Args = AddArgs;
85///     type Output = i32;
86///
87///     async fn definition(&self, _prompt: String) -> ToolDefinition {
88///         ToolDefinition {
89///             name: "add".to_string(),
90///             description: "Add x and y together".to_string(),
91///             parameters: serde_json::json!({
92///                 "type": "object",
93///                 "properties": {
94///                     "x": {
95///                         "type": "number",
96///                         "description": "The first number to add"
97///                     },
98///                     "y": {
99///                         "type": "number",
100///                         "description": "The second number to add"
101///                     }
102///                 }
103///             })
104///         }
105///     }
106///
107///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
108///         let result = args.x + args.y;
109///         Ok(result)
110///     }
111/// }
112/// ```
113pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
114    /// The name of the tool. This name should be unique within a single
115    /// [`ToolSet`] or other registration scope that dispatches tools by name.
116    const NAME: &'static str;
117
118    /// The error type of the tool.
119    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
120    /// The arguments type of the tool.
121    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
122    /// The output type of the tool.
123    type Output: Serialize;
124
125    /// A method returning the name of the tool.
126    fn name(&self) -> String {
127        Self::NAME.to_string()
128    }
129
130    /// A method returning the tool definition. The user prompt can be used to
131    /// tailor the definition to the specific use case.
132    fn definition(
133        &self,
134        _prompt: String,
135    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
136
137    /// The tool execution method.
138    /// Both the arguments and return value are a String since these values are meant to
139    /// be the output and input of LLM models (respectively)
140    fn call(
141        &self,
142        args: Self::Args,
143    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
144}
145
146/// Trait that represents an LLM tool that can be stored in a vector store and RAGged
147pub trait ToolEmbedding: Tool {
148    /// Error returned when reconstructing a dynamic tool from stored context.
149    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
150
151    /// Type of the tool' context. This context will be saved and loaded from the
152    /// vector store when ragging the tool.
153    /// This context can be used to store the tool's static configuration and local
154    /// context.
155    type Context: for<'a> Deserialize<'a> + Serialize;
156
157    /// Type of the tool's state. This state will be passed to the tool when initializing it.
158    /// This state can be used to pass runtime arguments to the tool such as clients,
159    /// API keys and other configuration.
160    type State: WasmCompatSend;
161
162    /// A method returning the documents that will be used as embeddings for the tool.
163    /// This allows for a tool to be retrieved from multiple embedding "directions".
164    /// If the tool will not be RAGged, this method should return an empty vector.
165    fn embedding_docs(&self) -> Vec<String>;
166
167    /// A method returning the context of the tool.
168    fn context(&self) -> Self::Context;
169
170    /// A method to initialize the tool from the context, and a state.
171    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
172}
173
174/// Wrapper trait to allow for dynamic dispatch of simple tools
175pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
176    /// Returns the tool name used for dispatch.
177    fn name(&self) -> String;
178
179    /// Returns the provider-facing tool schema.
180    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
181
182    /// Calls the tool with JSON-encoded arguments and returns model-facing text.
183    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
184}
185
186fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
187    match serde_json::to_value(output)? {
188        serde_json::Value::String(text) => Ok(text),
189        value => Ok(value.to_string()),
190    }
191}
192
193impl<T: Tool> ToolDyn for T {
194    fn name(&self) -> String {
195        self.name()
196    }
197
198    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
199        Box::pin(<Self as Tool>::definition(self, prompt))
200    }
201
202    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
203        Box::pin(async move {
204            // LLMs frequently send `null` for tools whose arguments are all optional.
205            // `serde_json::from_str::<T>("null")` fails for struct types even when
206            // every field is `Option<_>`, because JSON null does not deserialize to an
207            // empty object. Preserve any args type that already accepts `null` (such as
208            // `()` or `Option<T>`) and fall back to `{}` only after the original parse
209            // fails.
210            let args = match serde_json::from_str(&args) {
211                Ok(args) => Ok(args),
212                Err(err) if args.trim() == "null" => serde_json::from_str("{}").map_err(|_| err),
213                Err(err) => Err(err),
214            };
215            match args {
216                Ok(args) => <Self as Tool>::call(self, args)
217                    .await
218                    .map_err(|e| ToolError::ToolCallError(Box::new(e)))
219                    .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
220                Err(e) => Err(ToolError::JsonError(e)),
221            }
222        })
223    }
224}
225
226#[cfg(feature = "rmcp")]
227#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
228pub mod rmcp;
229
230/// Wrapper trait to allow for dynamic dispatch of raggable tools
231pub trait ToolEmbeddingDyn: ToolDyn {
232    /// Serializes context needed to reconstruct this dynamic tool.
233    fn context(&self) -> serde_json::Result<serde_json::Value>;
234
235    /// Returns text fragments used to retrieve this tool from a vector store.
236    fn embedding_docs(&self) -> Vec<String>;
237}
238
239impl<T> ToolEmbeddingDyn for T
240where
241    T: ToolEmbedding + 'static,
242{
243    fn context(&self) -> serde_json::Result<serde_json::Value> {
244        serde_json::to_value(self.context())
245    }
246
247    fn embedding_docs(&self) -> Vec<String> {
248        self.embedding_docs()
249    }
250}
251
252#[derive(Clone)]
253pub(crate) enum ToolType {
254    Simple(Arc<dyn ToolDyn>),
255    Embedding(Arc<dyn ToolEmbeddingDyn>),
256}
257
258impl ToolType {
259    pub fn name(&self) -> String {
260        match self {
261            ToolType::Simple(tool) => tool.name(),
262            ToolType::Embedding(tool) => tool.name(),
263        }
264    }
265
266    pub async fn definition(&self, prompt: String) -> ToolDefinition {
267        match self {
268            ToolType::Simple(tool) => tool.definition(prompt).await,
269            ToolType::Embedding(tool) => tool.definition(prompt).await,
270        }
271    }
272
273    pub async fn call(&self, args: String) -> Result<String, ToolError> {
274        match self {
275            ToolType::Simple(tool) => tool.call(args).await,
276            ToolType::Embedding(tool) => tool.call(args).await,
277        }
278    }
279}
280
281#[derive(Debug, thiserror::Error)]
282pub enum ToolSetError {
283    /// Error returned by the tool
284    #[error("ToolCallError: {0}")]
285    ToolCallError(#[from] ToolError),
286
287    /// Could not find a tool
288    #[error("ToolNotFoundError: {0}")]
289    ToolNotFoundError(String),
290
291    /// JSON serialization or deserialization failed while preparing tool data.
292    #[error("JsonError: {0}")]
293    JsonError(#[from] serde_json::Error),
294
295    /// Tool call was interrupted. Primarily useful for agent multi-step/turn prompting.
296    #[error("Tool call interrupted")]
297    Interrupted,
298}
299
300/// A struct that holds a set of tools.
301///
302/// Tools are stored in an [`IndexMap`] keyed by name, so iteration
303/// (definitions, documents, schemas) follows registration order and the tool
304/// list sent to providers is deterministic across processes. Re-registering an
305/// existing name replaces the implementation but keeps its original position.
306#[derive(Default)]
307pub struct ToolSet {
308    pub(crate) tools: IndexMap<String, ToolType>,
309}
310
311impl ToolSet {
312    /// Create a new ToolSet from a list of tools
313    pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
314        let mut toolset = Self::default();
315        tools.into_iter().for_each(|tool| {
316            toolset.add_tool(tool);
317        });
318        toolset
319    }
320
321    /// Create a new `ToolSet` from boxed dynamically-dispatched tools.
322    pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
323        let mut toolset = Self::default();
324        tools.into_iter().for_each(|tool| {
325            toolset.add_tool_boxed(tool);
326        });
327        toolset
328    }
329
330    /// Create a toolset builder
331    pub fn builder() -> ToolSetBuilder {
332        ToolSetBuilder::default()
333    }
334
335    /// Check if the toolset contains a tool with the given name
336    pub fn contains(&self, toolname: &str) -> bool {
337        self.tools.contains_key(toolname)
338    }
339
340    /// Add a tool to the toolset
341    pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
342        self.insert(ToolType::Simple(Arc::new(tool)));
343    }
344
345    /// Adds a boxed tool to the toolset. Useful for situations when dynamic dispatch is required.
346    pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
347        self.insert(ToolType::Simple(Arc::from(tool)));
348    }
349
350    pub(crate) fn insert(&mut self, tool: ToolType) {
351        let name = tool.name();
352        // `IndexMap::insert` replaces the value while keeping the existing
353        // slot position, and returns the previous value when the name was
354        // already registered.
355        if self.tools.insert(name.clone(), tool).is_some() {
356            tracing::warn!(
357                tool_name = %name,
358                "a tool named {name:?} was already registered; replacing it with the new registration"
359            );
360        }
361    }
362
363    /// Remove a tool by name. Missing tools are ignored.
364    pub fn delete_tool(&mut self, tool_name: &str) {
365        // `shift_remove` preserves the order of the remaining tools;
366        // `swap_remove` would not.
367        self.tools.shift_remove(tool_name);
368    }
369
370    /// Merge another toolset into this one. Tools keep `toolset`'s
371    /// registration order; names that already exist are replaced in place.
372    pub fn add_tools(&mut self, toolset: ToolSet) {
373        for (_, tool) in toolset.tools {
374            self.insert(tool);
375        }
376    }
377
378    pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
379        self.tools.get(toolname)
380    }
381
382    /// Tool names in registration order.
383    pub(crate) fn ordered_names(&self) -> impl Iterator<Item = &String> {
384        self.tools.keys()
385    }
386
387    /// Tools in registration order.
388    fn ordered_tools(&self) -> impl Iterator<Item = &ToolType> {
389        self.tools.values()
390    }
391
392    /// Return definitions for all tools currently registered in the set, in
393    /// registration order.
394    pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
395        let mut defs = Vec::new();
396        for tool in self.ordered_tools() {
397            let def = tool.definition(String::new()).await;
398            defs.push(def);
399        }
400        Ok(defs)
401    }
402
403    /// Call a tool with the given name and arguments
404    pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
405        if let Some(tool) = self.tools.get(toolname) {
406            tracing::debug!(target: "rig",
407                "Calling tool {toolname} with args:\n{}",
408                args
409            );
410            Ok(tool.call(args).await?)
411        } else {
412            Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
413        }
414    }
415
416    /// Get the documents of all the tools in the toolset
417    pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
418        let mut docs = Vec::new();
419        for tool in self.ordered_tools() {
420            match tool {
421                ToolType::Simple(tool) => {
422                    docs.push(completion::Document {
423                        id: tool.name(),
424                        text: format!(
425                            "\
426                            Tool: {}\n\
427                            Definition: \n\
428                            {}\
429                        ",
430                            tool.name(),
431                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
432                        ),
433                        additional_props: HashMap::new(),
434                    });
435                }
436                ToolType::Embedding(tool) => {
437                    docs.push(completion::Document {
438                        id: tool.name(),
439                        text: format!(
440                            "\
441                            Tool: {}\n\
442                            Definition: \n\
443                            {}\
444                        ",
445                            tool.name(),
446                            serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
447                        ),
448                        additional_props: HashMap::new(),
449                    });
450                }
451            }
452        }
453        Ok(docs)
454    }
455
456    /// Convert tools in self to objects of type ToolSchema.
457    /// This is necessary because when adding tools to the EmbeddingBuilder because all
458    /// documents added to the builder must all be of the same type.
459    pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
460        self.ordered_tools()
461            .filter_map(|tool_type| {
462                if let ToolType::Embedding(tool) = tool_type {
463                    Some(ToolSchema::try_from(&**tool))
464                } else {
465                    None
466                }
467            })
468            .collect::<Result<Vec<_>, _>>()
469    }
470}
471
472#[derive(Default)]
473/// Builder for constructing a [`ToolSet`] with static and dynamic tools.
474pub struct ToolSetBuilder {
475    tools: Vec<ToolType>,
476}
477
478impl ToolSetBuilder {
479    /// Add a regular tool that is always available when the set is used.
480    pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
481        self.tools.push(ToolType::Simple(Arc::new(tool)));
482        self
483    }
484
485    /// Add a tool that can be represented as embeddings for dynamic retrieval.
486    pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
487        self.tools.push(ToolType::Embedding(Arc::new(tool)));
488        self
489    }
490
491    /// Build the tool set, keyed by each tool's name.
492    pub fn build(self) -> ToolSet {
493        let mut toolset = ToolSet::default();
494        for tool in self.tools {
495            toolset.insert(tool);
496        }
497        toolset
498    }
499}
500
#[cfg(test)]
mod tests {
    // Unit tests for `ToolSet`: registration order, replacement and deletion
    // semantics, tool-output serialization, and `null` argument handling.
    // Most mock tools come from `crate::test_utils`.
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use crate::test_utils::{
        MockExampleTool, MockImageOutputTool, MockObjectOutputTool, MockStringOutputTool,
        mock_math_toolset,
    };
    use serde_json::json;

    use super::*;

    // Shared fixture. Judging by the assertions below, it holds exactly two
    // tools, named "add" and "subtract".
    fn get_test_toolset() -> ToolSet {
        mock_math_toolset()
    }

    // One definition is returned per registered tool.
    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    // Deleting by name removes exactly that tool and leaves the other one.
    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
        assert_eq!(
            toolset.ordered_names().cloned().collect::<Vec<_>>(),
            vec!["subtract".to_string()]
        );
    }

    #[test]
    fn deleting_a_middle_tool_preserves_order_of_survivors() {
        // Guards the `shift_remove` (not `swap_remove`) choice in `delete_tool`.
        // `swap_remove` would move the last tool into the deleted slot, so this
        // only catches a regression with 3+ tools and a non-last deletion: here
        // a `swap_remove("beta")` would yield [alpha, delta, gamma].
        let mut toolset = ToolSet::default();
        for name in ["alpha", "beta", "gamma", "delta"] {
            toolset.add_tool(named_tool(name, "test tool"));
        }

        toolset.delete_tool("beta");

        assert_eq!(
            toolset.ordered_names().cloned().collect::<Vec<_>>(),
            vec![
                "alpha".to_string(),
                "gamma".to_string(),
                "delta".to_string()
            ],
            "survivors must keep their registration order after a middle deletion"
        );
    }

    /// A tool whose name and definition are chosen at runtime, for ordering
    /// and duplicate-registration tests.
    struct NamedTool {
        name: String,
        description: String,
    }

    // Implements `ToolDyn` directly (not `Tool`), so the name can vary per
    // instance.
    impl ToolDyn for NamedTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            Box::pin(async move {
                ToolDefinition {
                    name: self.name.clone(),
                    description: self.description.clone(),
                    parameters: json!({ "type": "object", "properties": {} }),
                }
            })
        }

        // Ignores its arguments and echoes the description, so tests can tell
        // which registration actually handled a call.
        fn call(&self, _args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let output = format!("called {}", self.description);
            Box::pin(async move { Ok(output) })
        }
    }

    /// Convenience constructor for [`NamedTool`].
    fn named_tool(name: &str, description: &str) -> NamedTool {
        NamedTool {
            name: name.to_string(),
            description: description.to_string(),
        }
    }

    #[tokio::test]
    async fn tool_definitions_follow_registration_order() {
        // Enough names that any non-order-preserving storage would almost
        // surely surface a regression: its iteration order would differ from
        // insertion order.
        let names: Vec<String> = (0..32).map(|i| format!("tool_{i:02}")).collect();
        let mut toolset = ToolSet::default();
        for name in &names {
            toolset.add_tool(named_tool(name, "test tool"));
        }

        let defs = toolset.get_tool_definitions().await.unwrap();
        let def_names: Vec<String> = defs.into_iter().map(|def| def.name).collect();
        assert_eq!(def_names, names);

        // `documents()` must follow the same registration order.
        let docs = toolset.documents().await.unwrap();
        let doc_ids: Vec<String> = docs.into_iter().map(|doc| doc.id).collect();
        assert_eq!(doc_ids, names);
    }

    // Re-registering "alpha" keeps its first slot but swaps in the new
    // implementation, for both definitions and calls.
    #[tokio::test]
    async fn duplicate_registration_replaces_in_place() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(named_tool("alpha", "first alpha"));
        toolset.add_tool(named_tool("beta", "beta"));
        toolset.add_tool(named_tool("alpha", "second alpha"));

        let defs = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(
            defs.iter().map(|def| def.name.as_str()).collect::<Vec<_>>(),
            vec!["alpha", "beta"],
            "the duplicate should be deduped and keep its original position"
        );
        assert_eq!(
            defs[0].description, "second alpha",
            "the last registration should win"
        );

        let output = toolset.call("alpha", "{}".to_string()).await.unwrap();
        assert_eq!(output, "called second alpha");
    }

    // `add_tools` appends new names in the incoming set's order and replaces
    // existing names without moving them.
    #[tokio::test]
    async fn add_tools_merges_in_order_and_replaces_existing() {
        let mut base = ToolSet::default();
        base.add_tool(named_tool("alpha", "base alpha"));
        base.add_tool(named_tool("beta", "base beta"));

        let mut incoming = ToolSet::default();
        incoming.add_tool(named_tool("gamma", "incoming gamma"));
        incoming.add_tool(named_tool("alpha", "incoming alpha"));

        base.add_tools(incoming);

        let defs = base.get_tool_definitions().await.unwrap();
        assert_eq!(
            defs.iter().map(|def| def.name.as_str()).collect::<Vec<_>>(),
            vec!["alpha", "beta", "gamma"],
            "merged tools should follow registration order with replaced names keeping position"
        );
        assert_eq!(defs[0].description, "incoming alpha");
    }

    // A string output must come back as raw text (not JSON-quoted/escaped).
    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockStringOutputTool);

        let output = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(output, "Hello\nWorld");
    }

    // NOTE(review): presumably `MockImageOutputTool` returns an encoded image
    // payload as a `String`; it must pass through verbatim so that
    // `ToolResultContent::from_tool_output` can recognize it as an image.
    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockImageOutputTool);

        let output = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    // Non-string outputs are rendered as JSON text that parses back to the
    // same object.
    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockObjectOutputTool);

        let output = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(output.starts_with('{'));
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&output).unwrap(),
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }

    // `null` must reach args types that accept it directly (per the test name,
    // `MockExampleTool` uses unit args) without taking the `{}` fallback.
    #[tokio::test]
    async fn null_args_are_preserved_for_unit_args() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockExampleTool);

        let output = toolset
            .call("example_tool", "null".to_string())
            .await
            .expect("unit args should accept null without object fallback");

        assert_eq!(output, "Example answer");
    }

    // Struct-typed args with all-optional fields — serde rejects `null` for these
    // even though the fields are optional. The normalization in `ToolDyn::call`
    // falls back from `null` to `{}` so callers can omit the
    // wrapping `Option<Args>` workaround.
    #[tokio::test]
    async fn null_args_are_normalized_to_empty_object() {
        use crate::test_utils::MockToolError;

        #[derive(serde::Deserialize, serde::Serialize)]
        struct NoRequiredArgs {
            label: Option<String>,
        }

        struct NoArgTool;

        impl Tool for NoArgTool {
            const NAME: &'static str = "no_arg_tool";
            type Error = MockToolError;
            type Args = NoRequiredArgs;
            type Output = String;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: Self::NAME.to_string(),
                    description: "Tool with no required arguments".to_string(),
                    parameters: json!({"type": "object", "properties": {}}),
                }
            }

            // Returns the label if given, otherwise "default".
            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                Ok(args.label.unwrap_or_else(|| "default".to_string()))
            }
        }

        let mut toolset = ToolSet::default();
        toolset.add_tool(NoArgTool);

        // `null` is what LLMs send when no arguments are provided; without the
        // normalization this would return `ToolError::JsonError`.
        let output = toolset
            .call("no_arg_tool", "null".to_string())
            .await
            .expect("null args should succeed after normalisation");

        assert_eq!(output, "default");
    }
}