Skip to main content

wasm_agent/
tools.rs

1// SPDX-License-Identifier: MIT
2//! Tool registry and dispatch for WASM-safe function signatures.
3
4use std::collections::HashMap;
5use crate::error::AgentError;
6use crate::types::ToolResult;
7
8/// A tool handler: takes a JSON string input, returns a [`ToolResult`].
9///
10/// Uses `Box<dyn Fn>` for WASM-compatible dispatch — no async, no `Send`
11/// requirement in single-threaded WASM environments.
12pub type ToolFn = Box<dyn Fn(&str) -> ToolResult>;
13
/// Metadata describing a tool for system-prompt generation and introspection.
///
/// Derives `PartialEq`, `Eq` and `Hash` in addition to `Debug`/`Clone` so specs
/// can be compared directly (e.g. in tests) and stored in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ToolSpec {
    /// Unique name used to invoke the tool from the ReAct loop.
    pub name: String,
    /// Human-readable description included in the agent's system prompt.
    pub description: String,
    /// JSON Schema string describing the tool's input format.
    pub input_schema: String,
}
24
25impl ToolSpec {
26    /// Creates a new [`ToolSpec`].
27    ///
28    /// # Arguments
29    /// * `name` — Unique tool name (must be non-empty).
30    /// * `description` — Human-readable description.
31    /// * `schema` — JSON Schema for the tool's input.
32    pub fn new(name: impl Into<String>, description: impl Into<String>, schema: impl Into<String>) -> Self {
33        Self { name: name.into(), description: description.into(), input_schema: schema.into() }
34    }
35}
36
37/// Registry of available tools, keyed by name.
38pub struct ToolRegistry {
39    specs: HashMap<String, ToolSpec>,
40    handlers: HashMap<String, ToolFn>,
41}
42
43impl ToolRegistry {
44    /// Creates an empty registry.
45    pub fn new() -> Self {
46        Self { specs: HashMap::new(), handlers: HashMap::new() }
47    }
48
49    /// Registers a tool with its specification and handler.
50    ///
51    /// # Errors
52    /// Returns [`AgentError::InvalidToolSignature`] if the tool name is empty.
53    pub fn register(&mut self, spec: ToolSpec, handler: ToolFn) -> Result<(), AgentError> {
54        if spec.name.is_empty() {
55            return Err(AgentError::InvalidToolSignature("tool name cannot be empty".into()));
56        }
57        self.specs.insert(spec.name.clone(), spec.clone());
58        self.handlers.insert(spec.name, handler);
59        Ok(())
60    }
61
62    /// Dispatches a tool call by name with the given JSON input string.
63    ///
64    /// # Errors
65    /// Returns [`AgentError::ToolNotFound`] if no tool with that name is registered.
66    pub fn dispatch(&self, tool_name: &str, input: &str) -> Result<ToolResult, AgentError> {
67        let handler = self.handlers.get(tool_name)
68            .ok_or_else(|| AgentError::ToolNotFound { name: tool_name.to_string() })?;
69        Ok(handler(input))
70    }
71
72    /// Returns the spec for a registered tool, or `None` if not found.
73    pub fn spec(&self, name: &str) -> Option<&ToolSpec> { self.specs.get(name) }
74
75    /// Returns the number of registered tools.
76    pub fn tool_count(&self) -> usize { self.specs.len() }
77
78    /// Returns the names of all registered tools.
79    pub fn tool_names(&self) -> Vec<&str> { self.specs.keys().map(|s| s.as_str()).collect() }
80
81    /// Generates a system-prompt snippet listing all registered tools.
82    pub fn tools_prompt(&self) -> String {
83        let mut lines = vec!["Available tools:".to_string()];
84        for spec in self.specs.values() {
85            lines.push(format!("- {}: {}", spec.name, spec.description));
86        }
87        lines.join("\n")
88    }
89}
90
91impl Default for ToolRegistry {
92    fn default() -> Self { Self::new() }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a spec/handler pair for a tool named `echo` that prefixes its
    /// input with `"echo: "`.
    fn echo_tool() -> (ToolSpec, ToolFn) {
        let spec = ToolSpec::new("echo", "Echoes input back", r#"{"type":"string"}"#);
        let handler: ToolFn = Box::new(|input: &str| ToolResult {
            tool_name: "echo".into(),
            output: format!("echo: {input}"),
            success: true,
        });
        (spec, handler)
    }

    #[test]
    fn test_registry_register_and_dispatch_ok() {
        let mut reg = ToolRegistry::new();
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        let result = reg.dispatch("echo", "hello").unwrap();
        assert!(result.success);
        assert_eq!(result.output, "echo: hello");
    }

    #[test]
    fn test_registry_dispatch_unknown_tool_returns_error() {
        let reg = ToolRegistry::new();
        let err = reg.dispatch("nonexistent", "").unwrap_err();
        assert!(matches!(err, AgentError::ToolNotFound { .. }));
    }

    #[test]
    fn test_registry_register_empty_name_returns_error() {
        let mut reg = ToolRegistry::new();
        let spec = ToolSpec::new("", "bad", "{}");
        let err = reg
            .register(
                spec,
                Box::new(|_| ToolResult { tool_name: "".into(), output: "".into(), success: false }),
            )
            .unwrap_err();
        assert!(matches!(err, AgentError::InvalidToolSignature(_)));
        // A rejected registration must not leave anything behind.
        assert_eq!(reg.tool_count(), 0);
    }

    #[test]
    fn test_registry_tool_count_increments() {
        let mut reg = ToolRegistry::new();
        assert_eq!(reg.tool_count(), 0);
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        assert_eq!(reg.tool_count(), 1);
    }

    #[test]
    fn test_registry_reregister_replaces_spec_and_handler() {
        let mut reg = ToolRegistry::new();
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        let replacement: ToolFn = Box::new(|input: &str| ToolResult {
            tool_name: "echo".into(),
            output: format!("v2: {input}"),
            success: true,
        });
        reg.register(ToolSpec::new("echo", "Echoes input, v2", "{}"), replacement).unwrap();
        assert_eq!(reg.tool_count(), 1);
        assert_eq!(reg.spec("echo").map(|s| s.description.as_str()), Some("Echoes input, v2"));
        assert_eq!(reg.dispatch("echo", "hi").unwrap().output, "v2: hi");
    }

    #[test]
    fn test_registry_tools_prompt_contains_tool_name() {
        let mut reg = ToolRegistry::new();
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        assert!(reg.tools_prompt().contains("echo"));
    }

    #[test]
    fn test_registry_tools_prompt_lists_name_and_description() {
        let mut reg = ToolRegistry::new();
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        assert_eq!(reg.tools_prompt(), "Available tools:\n- echo: Echoes input back");
    }

    #[test]
    fn test_registry_tools_prompt_empty_has_only_header() {
        assert_eq!(ToolRegistry::new().tools_prompt(), "Available tools:");
    }

    #[test]
    fn test_registry_spec_retrieval_present_and_absent() {
        let mut reg = ToolRegistry::new();
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        assert!(reg.spec("echo").is_some());
        assert!(reg.spec("missing").is_none());
    }

    #[test]
    fn test_registry_tool_names_lists_all() {
        let mut reg = ToolRegistry::new();
        let (spec, handler) = echo_tool();
        reg.register(spec, handler).unwrap();
        let names = reg.tool_names();
        assert!(names.contains(&"echo"));
    }

    #[test]
    fn test_registry_default_is_empty() {
        let reg = ToolRegistry::default();
        assert_eq!(reg.tool_count(), 0);
    }
}