Skip to main content

rig_cat/tool/
mod.rs

1//! Tool trait: functions that agents can invoke.
2
3use comp_cat_rs::effect::io::Io;
4use serde_json::Value;
5
6use crate::error::Error;
7
8/// A tool definition: metadata sent to the LLM so it knows
9/// what tools are available and how to call them.
10#[derive(Debug, Clone)]
11pub struct ToolDefinition {
12    name: String,
13    description: String,
14    parameters_schema: Value,
15}
16
17impl ToolDefinition {
18    #[must_use]
19    pub fn new(name: String, description: String, parameters_schema: Value) -> Self {
20        Self { name, description, parameters_schema }
21    }
22
23    #[must_use]
24    pub fn name(&self) -> &str { &self.name }
25
26    #[must_use]
27    pub fn description(&self) -> &str { &self.description }
28
29    #[must_use]
30    pub fn parameters_schema(&self) -> &Value { &self.parameters_schema }
31
32    /// Format as the JSON structure expected by LLM tool-calling APIs.
33    #[must_use]
34    pub fn to_api_json(&self) -> Value {
35        serde_json::json!({
36            "type": "function",
37            "function": {
38                "name": self.name,
39                "description": self.description,
40                "parameters": self.parameters_schema
41            }
42        })
43    }
44}
45
46/// A tool that an agent can call during reasoning.
47///
48/// Tools receive JSON arguments and return a JSON result,
49/// wrapped in `Io` for effect tracking.
50pub trait Tool {
51    /// The tool's definition (name, description, schema).
52    fn definition(&self) -> ToolDefinition;
53
54    /// Execute the tool with the given JSON arguments.
55    fn call(&self, args: Value) -> Io<Error, Value>;
56}
57
/// A collection of tools available to an agent.
///
/// Generic over the tool type `T` to avoid `dyn Trait`.
/// All tools in a toolbox must be the same concrete type.
/// For heterogeneous tools, use an enum that implements `Tool`.
///
/// The struct itself places no bound on `T`; the `impl` blocks that need
/// tool behaviour require `T: Tool` themselves.
///
/// ```rust,ignore
/// enum MyTools {
///     Calculator(Calculator),
///     WebSearch(WebSearch),
/// }
///
/// impl Tool for MyTools {
///     fn definition(&self) -> ToolDefinition {
///         match self {
///             Self::Calculator(t) => t.definition(),
///             Self::WebSearch(t) => t.definition(),
///         }
///     }
///     fn call(&self, args: Value) -> Io<Error, Value> {
///         match self {
///             Self::Calculator(t) => t.call(args),
///             Self::WebSearch(t) => t.call(args),
///         }
///     }
/// }
/// ```
pub struct Toolbox<T> {
    /// Registered tools, kept in insertion order.
    tools: Vec<T>,
}
88
89impl<T: Tool> Toolbox<T> {
90    /// Create an empty toolbox.
91    #[must_use]
92    pub fn new() -> Self { Self { tools: Vec::new() } }
93
94    /// Add a tool to the toolbox.
95    #[must_use]
96    pub fn with_tool(self, tool: T) -> Self {
97        Self {
98            tools: self.tools.into_iter().chain(std::iter::once(tool)).collect(),
99        }
100    }
101
102    /// Look up a tool by name and invoke it.
103    #[must_use]
104    pub fn invoke(&self, name: &str, args: Value) -> Io<Error, Value> {
105        let name_owned = name.to_owned();
106        self.tools.iter()
107            .find(|t| t.definition().name() == name_owned)
108            .map_or_else(
109                move || Io::suspend(move || {
110                    Err(Error::Config { field: format!("unknown tool: {name_owned}") })
111                }),
112                |t| t.call(args),
113            )
114    }
115
116    /// Get tool definitions for sending to the LLM.
117    #[must_use]
118    pub fn definitions(&self) -> Vec<Value> {
119        self.tools.iter().map(|t| t.definition().to_api_json()).collect()
120    }
121}
122
123impl<T: Tool> Default for Toolbox<T> {
124    fn default() -> Self { Self::new() }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double: reports a configurable name and always returns `{"result": "ok"}`.
    struct FakeTool {
        name: String,
    }

    impl Tool for FakeTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new(
                self.name.clone(),
                "a fake tool".into(),
                serde_json::json!({"type": "object"}),
            )
        }

        fn call(&self, _args: Value) -> Io<Error, Value> {
            Io::pure(serde_json::json!({"result": "ok"}))
        }
    }

    #[test]
    fn invoke_known_tool_succeeds() -> Result<(), Error> {
        let toolbox = Toolbox::new()
            .with_tool(FakeTool { name: "greet".into() });
        let result = toolbox.invoke("greet", serde_json::json!({})).run()?;
        assert_eq!(result, serde_json::json!({"result": "ok"}));
        Ok(())
    }

    #[test]
    fn invoke_finds_tool_that_is_not_first() -> Result<(), Error> {
        let toolbox = Toolbox::new()
            .with_tool(FakeTool { name: "a".into() })
            .with_tool(FakeTool { name: "b".into() });
        let result = toolbox.invoke("b", serde_json::json!({})).run()?;
        assert_eq!(result, serde_json::json!({"result": "ok"}));
        Ok(())
    }

    #[test]
    fn invoke_unknown_tool_returns_error() {
        let toolbox: Toolbox<FakeTool> = Toolbox::new();
        let result = toolbox.invoke("nonexistent", serde_json::json!({})).run();
        assert!(result.is_err());
    }

    #[test]
    fn definitions_lists_all_tools() {
        let toolbox = Toolbox::new()
            .with_tool(FakeTool { name: "a".into() })
            .with_tool(FakeTool { name: "b".into() });
        let defs = toolbox.definitions();
        assert_eq!(defs.len(), 2);
        // Names must appear in insertion order under `function.name`.
        let names = defs
            .iter()
            .map(|d| d["function"]["name"].as_str())
            .collect::<Option<Vec<_>>>();
        assert_eq!(names, Some(vec!["a", "b"]));
    }

    #[test]
    fn to_api_json_wraps_definition_as_function() {
        let def = ToolDefinition::new(
            "greet".into(),
            "says hi".into(),
            serde_json::json!({"type": "object"}),
        );
        assert_eq!(
            def.to_api_json(),
            serde_json::json!({
                "type": "function",
                "function": {
                    "name": "greet",
                    "description": "says hi",
                    "parameters": {"type": "object"}
                }
            })
        );
    }
}