use std::collections::HashMap;
use super::{
rust_tool::{ErasedTool, RustTool, definition_of},
types::{ToolContext, ToolDefinition, ToolError, ToolOutput},
};
/// A registry of tools that can be advertised to a model and dispatched by name.
///
/// Tools are keyed by their `RustTool::NAME`; registering a second tool with a
/// name that is already present replaces the earlier one.
pub struct ToolRegistry {
    /// Registered tools, keyed by their static name.
    tools: HashMap<&'static str, RegisteredTool>,
}

/// One entry in the registry: the type-erased tool handle paired with the
/// definition that was built for it when it was registered.
struct RegisteredTool {
    /// Object-safe handle that `ToolRegistry::dispatch` invokes with raw JSON arguments.
    erased: Box<dyn ErasedTool>,
    /// Cached definition (name, description, parameter schema) for this tool.
    definition: ToolDefinition,
}
impl std::fmt::Debug for ToolRegistry {
    /// Formats the registry as its tool count plus the names of registered tools.
    ///
    /// The names are sorted so the output is stable between runs; iterating the
    /// underlying `HashMap` directly would yield them in an unspecified order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut names: Vec<&str> = self
            .tools
            .values()
            .map(|entry| entry.definition.name.as_str())
            .collect();
        names.sort_unstable();
        f.debug_struct("ToolRegistry")
            .field("tool_count", &self.tools.len())
            .field("tool_names", &names)
            .finish()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    /// Registers `tool` under [`RustTool::NAME`] and returns `&mut self` for chaining.
    ///
    /// The tool's [`ToolDefinition`] is built once here and cached for later
    /// calls to [`definitions`](Self::definitions) and [`iter`](Self::iter).
    /// If a tool with the same name is already registered, it is replaced.
    ///
    /// # Panics
    ///
    /// Panics if the tool's definition cannot be built. This is treated as a
    /// programming error in the tool, not a runtime condition.
    pub fn register<T: RustTool + 'static>(&mut self, tool: T) -> &mut Self {
        let definition = definition_of(&tool)
            .unwrap_or_else(|e| panic!("Failed to build definition for tool '{}': {e}", T::NAME));
        self.tools.insert(
            T::NAME,
            RegisteredTool {
                definition,
                erased: Box::new(tool),
            },
        );
        self
    }

    /// By-value builder form of [`register`](Self::register).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`register`](Self::register).
    #[must_use]
    pub fn with_tool<T: RustTool + 'static>(mut self, tool: T) -> Self {
        self.register(tool);
        self
    }

    /// Returns a clone of every registered tool's definition, sorted by tool name.
    ///
    /// Sorting makes the result deterministic. The backing `HashMap` iterates
    /// in a per-process randomized order, which would otherwise reorder the
    /// tool list handed to callers from run to run.
    #[must_use]
    pub fn definitions(&self) -> Vec<ToolDefinition> {
        let mut defs: Vec<ToolDefinition> = self
            .tools
            .values()
            .map(|entry| entry.definition.clone())
            .collect();
        defs.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        defs
    }

    /// Looks up the tool called `name` and invokes it with the raw JSON `args`.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] with the message `Unknown tool: {name}` if no tool
    /// with that name is registered. Otherwise, returns whatever error the
    /// erased tool call produces, including failures to deserialize `args`
    /// into the tool's parameter type.
    pub async fn dispatch(
        &self,
        name: &str,
        args: serde_json::Value,
        ctx: &ToolContext,
    ) -> Result<ToolOutput, ToolError> {
        let entry = self
            .tools
            .get(name)
            .ok_or_else(|| ToolError::new(format!("Unknown tool: {name}")))?;
        entry.erased.call_erased(args, ctx).await
    }

    /// Number of registered tools.
    #[must_use]
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` when no tools are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Iterates `(name, definition)` pairs in unspecified order.
    ///
    /// Each definition is cloned as the iterator yields it.
    pub fn iter(&self) -> impl Iterator<Item = (&'static str, ToolDefinition)> + '_ {
        self.tools
            .iter()
            .map(|(name, entry)| (*name, entry.definition.clone()))
    }
}
/// Enables `for (name, def) in &registry`.
///
/// Yields exactly the `(name, definition)` pairs that [`ToolRegistry::iter`] yields.
impl<'a> IntoIterator for &'a ToolRegistry {
    type Item = (&'static str, ToolDefinition);
    type IntoIter = Box<dyn Iterator<Item = (&'static str, ToolDefinition)> + 'a>;

    fn into_iter(self) -> Self::IntoIter {
        // Delegate so the pairing/cloning logic lives in one place (`iter`).
        Box::new(self.iter())
    }
}
#[cfg(test)]
mod tests {
    // Tests for registry dispatch, schema generation, the `#[llm_tool]` macro,
    // `ToolContext` state, and `ToolOutput` / `ToolError` metadata.
    use serde::Deserialize;

    use super::{
        super::{EmptyParams, definition_of},
        *,
    };
    use crate::llm_tool;

    /// Context with no conversation id, for tests that do not inspect context.
    fn test_ctx() -> ToolContext {
        ToolContext::new(None)
    }

    #[derive(Deserialize, schemars::JsonSchema)]
    struct PathParams {
        path: String,
    }

    /// Echoes back its `path` parameter.
    struct SampleTool;

    impl RustTool for SampleTool {
        type Params = PathParams;
        const NAME: &'static str = "sample";
        const DESCRIPTION: &'static str = "A sample tool";

        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(params.path.into())
        }
    }

    #[derive(Deserialize, schemars::JsonSchema)]
    struct RunCommandParams {
        command: String,
        #[serde(default)]
        timeout: Option<i64>,
        #[serde(default)]
        env: Option<std::collections::HashMap<String, String>>,
    }

    /// Echoes the command; asserts that the optional parameters were omitted.
    struct RunCommandTool;

    impl RustTool for RunCommandTool {
        type Params = RunCommandParams;
        const NAME: &'static str = "run_command";
        const DESCRIPTION: &'static str = "Runs a command.";

        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            assert!(params.timeout.is_none());
            assert!(params.env.is_none());
            Ok(format!("Ran: {}", params.command).into())
        }
    }

    #[test]
    fn tool_definition_serde_roundtrip() {
        let def = definition_of(&SampleTool).expect("schema");
        let json = serde_json::to_string(&def).expect("serialize");
        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, def.name);
        assert_eq!(parsed.description, def.description);
        assert_eq!(parsed.parameter_schema, def.parameter_schema);
    }

    /// Takes no parameters and always returns "ok".
    struct EmptyParamTool;

    impl RustTool for EmptyParamTool {
        type Params = EmptyParams;
        const NAME: &'static str = "empty";
        const DESCRIPTION: &'static str = "No params";

        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok("ok".into())
        }
    }

    #[test]
    fn tool_definition_with_empty_schema() {
        let tool = definition_of(&EmptyParamTool).expect("schema");
        let json = serde_json::to_string(&tool).expect("serialize");
        let parsed: ToolDefinition = serde_json::from_str(&json).expect("deserialize");
        let orig_json = serde_json::to_value(&tool.parameter_schema).unwrap();
        let parsed_json = serde_json::to_value(&parsed.parameter_schema).unwrap();
        assert_eq!(orig_json, parsed_json);
    }

    #[test]
    fn tool_definition_with_complex_schema() {
        let tool = definition_of(&RunCommandTool).expect("schema");
        let schema_json = serde_json::to_value(&tool.parameter_schema).expect("schema to json");
        let required = schema_json["required"]
            .as_array()
            .expect("required should be an array");
        assert!(
            required.iter().any(|v| v == "command"),
            "'command' should be required, got: {required:?}"
        );
    }

    #[tokio::test]
    async fn registry_dispatch_valid_tool() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let result = d
            .dispatch(
                "sample",
                serde_json::json!({"path": "/tmp/foo"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "/tmp/foo");
    }

    #[tokio::test]
    async fn registry_dispatch_unknown_tool() {
        let d = ToolRegistry::new();
        let result = d
            .dispatch("nonexistent", serde_json::json!({}), &test_ctx())
            .await;
        assert_eq!(
            result.unwrap_err(),
            ToolError::new("Unknown tool: nonexistent")
        );
    }

    #[tokio::test]
    async fn registry_dispatch_invalid_args() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let result = d
            .dispatch("sample", serde_json::json!({"path": 42}), &test_ctx())
            .await;
        let err = result.unwrap_err();
        assert!(
            err.message.contains("deserialize"),
            "Error should mention deserialization, got: {err}"
        );
    }

    #[tokio::test]
    async fn registry_dispatch_missing_required_field() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let err = d
            .dispatch("sample", serde_json::json!({}), &test_ctx())
            .await
            .expect_err("Expected error for missing required field");
        assert!(
            err.message.contains("missing field"),
            "Error should mention missing field, got: {err}"
        );
    }

    #[test]
    fn registry_definitions_returns_all() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(RunCommandTool);
        let defs = d.definitions();
        assert_eq!(defs.len(), 2);
        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["run_command", "sample"]);
    }

    #[test]
    fn registry_register_chaining() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool).register(RunCommandTool);
        assert_eq!(d.len(), 2);
        assert!(!d.is_empty());
    }

    #[test]
    fn registry_with_tool_owned_chaining() {
        let d = ToolRegistry::new()
            .with_tool(SampleTool)
            .with_tool(RunCommandTool);
        assert_eq!(d.len(), 2);
        assert!(!d.is_empty());
        let defs = d.definitions();
        let mut names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["run_command", "sample"]);
    }

    #[test]
    fn registry_default_is_empty() {
        let d = ToolRegistry::default();
        assert!(d.is_empty());
        assert_eq!(d.len(), 0);
    }

    #[tokio::test]
    async fn registry_replaces_on_duplicate_name() {
        // Shares NAME with `SampleTool`, so registering it second must win.
        struct AlternateSample;

        impl RustTool for AlternateSample {
            type Params = PathParams;
            const NAME: &'static str = "sample";
            const DESCRIPTION: &'static str = "Alternate sample";

            async fn call(
                &self,
                params: Self::Params,
                _ctx: &ToolContext,
            ) -> Result<ToolOutput, ToolError> {
                Ok(format!("alt: {}", params.path).into())
            }
        }

        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(AlternateSample);
        assert_eq!(d.len(), 1);
        let result = d
            .dispatch("sample", serde_json::json!({"path": "x"}), &test_ctx())
            .await;
        assert_eq!(result.unwrap().content(), "alt: x");
    }

    #[tokio::test]
    async fn registry_tool_returning_error() {
        struct FailingTool;

        impl RustTool for FailingTool {
            type Params = EmptyParams;
            const NAME: &'static str = "fail";
            const DESCRIPTION: &'static str = "Always fails";

            async fn call(
                &self,
                _params: Self::Params,
                _ctx: &ToolContext,
            ) -> Result<ToolOutput, ToolError> {
                Err(ToolError::new("intentional failure"))
            }
        }

        let mut d = ToolRegistry::new();
        d.register(FailingTool);
        let result = d.dispatch("fail", serde_json::json!({}), &test_ctx()).await;
        assert_eq!(result.unwrap_err(), ToolError::new("intentional failure"));
    }

    #[test]
    fn registry_debug_shows_tool_names() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let dbg = format!("{d:?}");
        assert!(dbg.contains("ToolRegistry"));
        assert!(dbg.contains("sample"));
        assert!(dbg.contains("tool_count: 1"));
    }

    /// Awaits a short tokio sleep before returning, to exercise real suspension.
    struct AsyncSleepTool;

    impl RustTool for AsyncSleepTool {
        type Params = EmptyParams;
        const NAME: &'static str = "async_sleep";
        const DESCRIPTION: &'static str = "Sleeps briefly then returns.";

        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok("slept".into())
        }
    }

    #[tokio::test]
    async fn async_tool_with_tokio_sleep() {
        let mut d = ToolRegistry::new();
        d.register(AsyncSleepTool);
        let result = d
            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
            .await;
        assert_eq!(result.unwrap().content(), "slept");
    }

    /// Reads a file with `tokio::fs`, mapping I/O failures to an "IO error" message.
    struct AsyncReadFileTool;

    #[derive(Deserialize, schemars::JsonSchema)]
    struct ReadFileParams {
        path: String,
    }

    impl RustTool for AsyncReadFileTool {
        type Params = ReadFileParams;
        const NAME: &'static str = "read_file";
        const DESCRIPTION: &'static str = "Reads a file asynchronously.";

        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            tokio::fs::read_to_string(&params.path)
                .await
                .map(ToolOutput::from)
                .map_err(|e| ToolError::new(format!("IO error: {e}")))
        }
    }

    #[tokio::test]
    async fn async_tool_with_tokio_fs() {
        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
        std::fs::write(tmp.path(), "hello async").expect("write tempfile");
        let mut d = ToolRegistry::new();
        d.register(AsyncReadFileTool);
        let path_str = tmp.path().to_str().expect("path to str").to_owned();
        let result = d
            .dispatch(
                "read_file",
                serde_json::json!({"path": path_str}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "hello async");
    }

    #[tokio::test]
    async fn async_tool_tokio_fs_missing_file() {
        let mut d = ToolRegistry::new();
        d.register(AsyncReadFileTool);
        let result = d
            .dispatch(
                "read_file",
                serde_json::json!({"path": "/nonexistent/file.txt"}),
                &test_ctx(),
            )
            .await;
        let err = result.unwrap_err();
        assert!(
            err.message.contains("IO error"),
            "Expected IO error, got: {err}"
        );
    }

    /// Awaits one value from an mpsc channel; the receiver can be taken only once.
    struct ChannelTool {
        tx: tokio::sync::mpsc::Sender<String>,
        rx: std::sync::Mutex<Option<tokio::sync::mpsc::Receiver<String>>>,
    }

    impl ChannelTool {
        fn new() -> Self {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            Self {
                tx,
                rx: std::sync::Mutex::new(Some(rx)),
            }
        }
    }

    impl RustTool for ChannelTool {
        type Params = EmptyParams;
        const NAME: &'static str = "channel_tool";
        const DESCRIPTION: &'static str = "Awaits a value from a channel.";

        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            // The std mutex guard is a temporary of this statement, so it is
            // released before the `.await` below.
            let mut rx = self
                .rx
                .lock()
                .unwrap()
                .take()
                .ok_or_else(|| ToolError::new("channel already consumed"))?;
            rx.recv()
                .await
                .map(ToolOutput::from)
                .ok_or_else(|| ToolError::new("channel closed"))
        }
    }

    #[tokio::test]
    async fn async_tool_awaits_channel() {
        let tool = ChannelTool::new();
        let tx = tool.tx.clone();
        let mut d = ToolRegistry::new();
        d.register(tool);
        let ctx = test_ctx();
        let dispatch_future = d.dispatch("channel_tool", serde_json::json!({}), &ctx);
        let send_future = async move {
            tx.send("from_channel".to_string()).await.unwrap();
        };
        let (result, ()) = tokio::join!(dispatch_future, send_future);
        assert_eq!(result.unwrap().content(), "from_channel");
    }

    #[tokio::test]
    async fn concurrent_dispatches_to_different_tools() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(AsyncSleepTool);
        d.register(RunCommandTool);
        let ctx = test_ctx();
        let (r1, r2, r3) = tokio::join!(
            d.dispatch("sample", serde_json::json!({"path": "a"}), &ctx),
            d.dispatch("async_sleep", serde_json::json!({}), &ctx),
            d.dispatch("run_command", serde_json::json!({"command": "ls"}), &ctx),
        );
        assert_eq!(r1.unwrap().content(), "a");
        assert_eq!(r2.unwrap().content(), "slept");
        assert_eq!(r3.unwrap().content(), "Ran: ls");
    }

    #[tokio::test]
    async fn concurrent_dispatches_to_same_tool() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let ctx = test_ctx();
        let futs: Vec<_> = (0..10)
            .map(|i| d.dispatch("sample", serde_json::json!({"path": format!("p{i}")}), &ctx))
            .collect();
        let results = futures::future::join_all(futs).await;
        for (i, r) in results.into_iter().enumerate() {
            assert_eq!(r.unwrap().content(), format!("p{i}"));
        }
    }

    // The `///` field docs below are load-bearing: schemars turns them into
    // the per-property `description` strings asserted in
    // `schema_contains_field_descriptions`.
    #[derive(Deserialize, schemars::JsonSchema)]
    struct DocumentedParams {
        /// The hostname of the remote server to connect to.
        hostname: String,
        /// The TCP port to connect on (1-65535).
        port: u16,
        /// Optional connection timeout, in seconds.
        #[serde(default)]
        timeout: Option<f64>,
    }

    /// Formats its parameters as `hostname:port:timeout`.
    struct DocumentedTool;

    impl RustTool for DocumentedTool {
        type Params = DocumentedParams;
        const NAME: &'static str = "connect";
        const DESCRIPTION: &'static str = "Connects to a remote host.";

        async fn call(&self, p: Self::Params, _ctx: &ToolContext) -> Result<ToolOutput, ToolError> {
            Ok(format!("{}:{}:{:?}", p.hostname, p.port, p.timeout).into())
        }
    }

    #[test]
    fn schema_contains_field_descriptions() {
        let def = definition_of(&DocumentedTool).expect("schema");
        let schema = &def.parameter_schema;
        let props = schema["properties"].as_object().expect("properties object");
        assert!(props.contains_key("hostname"), "missing hostname");
        assert!(props.contains_key("port"), "missing port");
        assert!(props.contains_key("timeout"), "missing timeout");
        let hostname_desc = props["hostname"]["description"]
            .as_str()
            .expect("hostname description");
        assert!(
            hostname_desc.contains("hostname"),
            "hostname description should mention 'hostname', got: {hostname_desc}"
        );
        let port_desc = props["port"]["description"]
            .as_str()
            .expect("port description");
        assert!(
            port_desc.contains("1-65535"),
            "port description should mention range, got: {port_desc}"
        );
    }

    #[test]
    fn schema_required_vs_optional_fields() {
        let def = definition_of(&DocumentedTool).expect("schema");
        let schema = &def.parameter_schema;
        let required = schema["required"]
            .as_array()
            .expect("required should be an array");
        assert!(
            required.iter().any(|v| v == "hostname"),
            "hostname required"
        );
        assert!(required.iter().any(|v| v == "port"), "port required");
        assert!(
            !required.iter().any(|v| v == "timeout"),
            "timeout should NOT be required"
        );
    }

    #[tokio::test]
    async fn dispatch_with_optional_field_missing() {
        let mut d = ToolRegistry::new();
        d.register(DocumentedTool);
        let result = d
            .dispatch(
                "connect",
                serde_json::json!({"hostname": "example.com", "port": 443}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "example.com:443:None");
    }

    #[tokio::test]
    async fn dispatch_with_optional_field_present() {
        let mut d = ToolRegistry::new();
        d.register(DocumentedTool);
        let result = d
            .dispatch(
                "connect",
                serde_json::json!({"hostname": "localhost", "port": 8080, "timeout": 30.0}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "localhost:8080:Some(30.0)");
    }

    #[tokio::test]
    async fn dispatch_with_extra_fields_ignored() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let result = d
            .dispatch(
                "sample",
                serde_json::json!({"path": "/tmp/x", "unknown_field": 42}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "/tmp/x");
    }

    #[tokio::test]
    async fn erased_dispatch_preserves_borrow_lifetime() {
        let mut d = ToolRegistry::new();
        d.register(AsyncSleepTool);
        d.register(SampleTool);
        let r1 = d
            .dispatch("async_sleep", serde_json::json!({}), &test_ctx())
            .await;
        let r2 = d
            .dispatch("sample", serde_json::json!({"path": "test"}), &test_ctx())
            .await;
        assert_eq!(r1.unwrap().content(), "slept");
        assert_eq!(r2.unwrap().content(), "test");
    }

    #[tokio::test]
    async fn dispatch_returns_meaningful_error_for_wrong_type() {
        let mut d = ToolRegistry::new();
        d.register(RunCommandTool);
        let result = d
            .dispatch(
                "run_command",
                serde_json::json!({"command": {"nested": "object"}}),
                &test_ctx(),
            )
            .await;
        let err = result.unwrap_err();
        assert!(
            err.message
                .contains("Failed to deserialize tool parameters"),
            "Error should mention deserialization failure, got: {err}"
        );
    }

    /// Echoes `message` back after a brief async sleep.
    #[llm_tool]
    async fn async_delayed_echo(message: String) -> Result<String, ToolError> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        Ok(format!("echo: {message}"))
    }

    #[tokio::test]
    async fn tool_macro_async_fn_dispatches_with_await() {
        let mut d = ToolRegistry::new();
        d.register(AsyncDelayedEcho);
        let result = d
            .dispatch(
                "async_delayed_echo",
                serde_json::json!({"message": "hello async"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "echo: hello async");
    }

    /// Reads the file at `path` asynchronously and returns its contents.
    #[llm_tool]
    async fn async_file_reader(path: String) -> Result<String, ToolError> {
        tokio::fs::read_to_string(&path)
            .await
            .map_err(|e| ToolError::new(format!("IO error: {e}")))
    }

    #[tokio::test]
    async fn tool_macro_async_fn_reads_file() {
        let tmp = tempfile::NamedTempFile::new().expect("create tempfile");
        std::fs::write(tmp.path(), "async macro content").expect("write");
        let mut d = ToolRegistry::new();
        d.register(AsyncFileReader);
        let path_str = tmp.path().to_str().expect("path").to_owned();
        let result = d
            .dispatch(
                "async_file_reader",
                serde_json::json!({"path": path_str}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "async macro content");
    }

    /// Greets `name`, using `greeting` when provided and "Hello" otherwise.
    #[llm_tool]
    fn greet_optional(name: String, greeting: Option<String>) -> Result<String, ToolError> {
        let g = greeting.unwrap_or_else(|| "Hello".to_string());
        Ok(format!("{g}, {name}!"))
    }

    #[test]
    fn tool_macro_option_param_not_in_required() {
        let def = definition_of(&GreetOptional).expect("schema");
        let schema = &def.parameter_schema;
        let required = schema["required"]
            .as_array()
            .expect("required should be an array");
        assert!(
            required.iter().any(|v| v == "name"),
            "'name' should be required, got: {required:?}"
        );
        assert!(
            !required.iter().any(|v| v == "greeting"),
            "'greeting' (Option<String>) should NOT be required, got: {required:?}"
        );
    }

    #[tokio::test]
    async fn tool_macro_option_param_missing_from_json() {
        let mut d = ToolRegistry::new();
        d.register(GreetOptional);
        let result = d
            .dispatch(
                "greet_optional",
                serde_json::json!({"name": "World"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "Hello, World!");
    }

    #[tokio::test]
    async fn tool_macro_option_param_provided_in_json() {
        let mut d = ToolRegistry::new();
        d.register(GreetOptional);
        let result = d
            .dispatch(
                "greet_optional",
                serde_json::json!({"name": "World", "greeting": "Hi"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(result.unwrap().content(), "Hi, World!");
    }

    /// Returns `input` with the optional `suffix` appended.
    #[llm_tool]
    async fn async_optional_tool(input: String, suffix: Option<String>) -> Result<String, ToolError> {
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        let s = suffix.unwrap_or_default();
        Ok(format!("{input}{s}"))
    }

    #[tokio::test]
    async fn tool_macro_async_with_optional_param() {
        let mut d = ToolRegistry::new();
        d.register(AsyncOptionalTool);
        let r1 = d
            .dispatch(
                "async_optional_tool",
                serde_json::json!({"input": "base"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(r1.unwrap().content(), "base");
        let r2 = d
            .dispatch(
                "async_optional_tool",
                serde_json::json!({"input": "base", "suffix": "_ext"}),
                &test_ctx(),
            )
            .await;
        assert_eq!(r2.unwrap().content(), "base_ext");
    }

    #[test]
    fn tool_macro_async_optional_schema_correctness() {
        let def = definition_of(&AsyncOptionalTool).expect("schema");
        let schema = &def.parameter_schema;
        let required = schema["required"].as_array().expect("required array");
        assert!(required.iter().any(|v| v == "input"), "'input' required");
        assert!(
            !required.iter().any(|v| v == "suffix"),
            "'suffix' (Option) should NOT be required"
        );
    }

    #[test]
    fn into_iter_yields_all_tool_name_definition_pairs() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        d.register(RunCommandTool);
        let mut pairs: Vec<(&str, String)> = (&d)
            .into_iter()
            .map(|(name, def)| (name, def.name))
            .collect();
        pairs.sort();
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].0, "run_command");
        assert_eq!(pairs[0].1, "run_command");
        assert_eq!(pairs[1].0, "sample");
        assert_eq!(pairs[1].1, "sample");
    }

    #[test]
    fn into_iter_empty_registry_yields_nothing() {
        let d = ToolRegistry::new();
        let count = (&d).into_iter().count();
        assert_eq!(count, 0);
    }

    #[test]
    fn into_iter_for_loop_syntax() {
        let mut d = ToolRegistry::new();
        d.register(SampleTool);
        let mut found = false;
        for (name, def) in &d {
            if name == "sample" {
                assert_eq!(def.description, "A sample tool");
                found = true;
            }
        }
        assert!(found, "Expected to find 'sample' tool via for-in loop");
    }

    #[test]
    fn tool_context_conversation_id_none_by_default() {
        let ctx = ToolContext::new(None);
        assert!(ctx.conversation_id().is_none());
        assert!(!ctx.is_idle());
    }

    #[test]
    fn tool_context_conversation_id_returns_value() {
        let ctx = ToolContext::new(Some("conv-123".to_owned()));
        assert_eq!(ctx.conversation_id(), Some("conv-123"));
    }

    #[test]
    fn tool_context_get_set_state_roundtrip() {
        let ctx = ToolContext::new(None);
        let val = ctx.get_state("missing", serde_json::json!("fallback"));
        assert_eq!(val, serde_json::json!("fallback"));
        ctx.set_state("counter", serde_json::json!(42))
            .expect("set_state");
        let val = ctx.get_state("counter", serde_json::json!(0));
        assert_eq!(val, serde_json::json!(42));
        ctx.set_state("counter", serde_json::json!(99))
            .expect("set_state");
        let val = ctx.get_state("counter", serde_json::json!(0));
        assert_eq!(val, serde_json::json!(99));
    }

    #[test]
    fn tool_context_state_persists_across_reads() {
        let ctx = ToolContext::new(None);
        ctx.set_state("key", serde_json::json!({"nested": true}))
            .expect("set_state");
        let v1 = ctx.get_state("key", serde_json::json!(null));
        let v2 = ctx.get_state("key", serde_json::json!(null));
        assert_eq!(v1, v2);
        assert_eq!(v1, serde_json::json!({"nested": true}));
    }

    #[tokio::test]
    async fn dispatch_passes_context_to_tool() {
        // Reports the conversation id and bumps a per-context call counter.
        struct ContextAwareTool;

        impl RustTool for ContextAwareTool {
            type Params = EmptyParams;
            const NAME: &'static str = "ctx_tool";
            const DESCRIPTION: &'static str = "Reads conversation_id from context.";

            async fn call(
                &self,
                _params: Self::Params,
                ctx: &ToolContext,
            ) -> Result<ToolOutput, ToolError> {
                let conv = ctx.conversation_id().unwrap_or("none");
                let count = ctx.get_state("call_count", serde_json::json!(0));
                let n = count.as_i64().unwrap_or(0);
                ctx.set_state("call_count", serde_json::json!(n + 1))
                    .map_err(|e| ToolError::new(format!("set_state failed: {e}")))?;
                Ok(format!("conv={conv}, call={n}").into())
            }
        }

        let mut d = ToolRegistry::new();
        d.register(ContextAwareTool);
        let ctx = ToolContext::new(Some("test-conv".to_owned()));
        let r1 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
        assert_eq!(r1.unwrap().content(), "conv=test-conv, call=0");
        let r2 = d.dispatch("ctx_tool", serde_json::json!({}), &ctx).await;
        assert_eq!(r2.unwrap().content(), "conv=test-conv, call=1");
    }

    /// Metadata payload attached by `MetadataTool`.
    #[derive(serde::Serialize)]
    struct ProcessMeta {
        bytes_read: usize,
        source: String,
    }

    /// Returns content plus structured metadata built from `ProcessMeta`.
    struct MetadataTool;

    impl RustTool for MetadataTool {
        type Params = PathParams;
        const NAME: &'static str = "metadata_tool";
        const DESCRIPTION: &'static str = "Returns output with metadata.";

        async fn call(
            &self,
            params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            ToolOutput::new(format!("processed: {}", params.path)).with_metadata(&ProcessMeta {
                bytes_read: 1024,
                source: params.path,
            })
        }
    }

    #[tokio::test]
    async fn dispatch_preserves_tool_output_metadata() {
        let mut d = ToolRegistry::new();
        d.register(MetadataTool);
        let result = d
            .dispatch(
                "metadata_tool",
                serde_json::json!({"path": "/etc/hosts"}),
                &test_ctx(),
            )
            .await
            .unwrap();
        assert_eq!(result.content(), "processed: /etc/hosts");
        assert_eq!(result.metadata()["bytes_read"], 1024);
        assert_eq!(result.metadata()["source"], "/etc/hosts");
        assert_eq!(result.metadata().len(), 2);
    }

    // Pure `ToolOutput` tests: nothing is awaited, so no async runtime is needed.
    #[test]
    fn dispatch_tool_output_display_uses_content() {
        let output = ToolOutput::new("hello world").with_meta("ignored", serde_json::json!(true));
        assert_eq!(output.to_string(), "hello world");
    }

    #[test]
    fn dispatch_tool_output_into_content_consumes() {
        let output = ToolOutput::new("owned").with_meta("key", serde_json::json!("val"));
        let content: String = output.into_content();
        assert_eq!(content, "owned");
    }

    #[test]
    fn tool_output_from_str_has_empty_metadata() {
        let output: ToolOutput = "plain".into();
        assert_eq!(output.content(), "plain");
        assert!(output.metadata().is_empty());
    }

    #[test]
    fn tool_output_from_string_has_empty_metadata() {
        let output: ToolOutput = "owned".to_string().into();
        assert_eq!(output.content(), "owned");
        assert!(output.metadata().is_empty());
    }

    #[test]
    fn tool_error_with_metadata() {
        let err = ToolError::new("HTTP request failed")
            .with_meta("status_code", serde_json::json!(503))
            .with_meta("url", serde_json::json!("https://example.com"));
        assert_eq!(err.message, "HTTP request failed");
        assert_eq!(err.metadata()["status_code"], 503);
        assert_eq!(err.metadata()["url"], "https://example.com");
        assert_eq!(err.metadata().len(), 2);
    }

    #[test]
    fn tool_error_without_metadata_is_empty() {
        let err = ToolError::new("simple error");
        assert!(err.metadata().is_empty());
    }

    #[test]
    fn tool_error_display_ignores_metadata() {
        let err = ToolError::new("visible").with_meta("hidden", serde_json::json!(true));
        assert_eq!(err.to_string(), "visible");
    }

    #[test]
    fn tool_error_equality_includes_metadata() {
        let a = ToolError::new("err").with_meta("k", serde_json::json!(1));
        let b = ToolError::new("err").with_meta("k", serde_json::json!(1));
        let c = ToolError::new("err").with_meta("k", serde_json::json!(2));
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    /// Always fails with an error that carries `retry_after_secs` metadata.
    struct MetadataErrorTool;

    impl RustTool for MetadataErrorTool {
        type Params = EmptyParams;
        const NAME: &'static str = "metadata_error_tool";
        const DESCRIPTION: &'static str = "Always fails with metadata.";

        async fn call(
            &self,
            _params: Self::Params,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Err(ToolError::new("service unavailable")
                .with_meta("retry_after_secs", serde_json::json!(30)))
        }
    }

    #[tokio::test]
    async fn dispatch_preserves_tool_error_metadata() {
        let mut d = ToolRegistry::new();
        d.register(MetadataErrorTool);
        let err = d
            .dispatch("metadata_error_tool", serde_json::json!({}), &test_ctx())
            .await
            .unwrap_err();
        assert_eq!(err.message, "service unavailable");
        assert_eq!(err.metadata()["retry_after_secs"], 30);
    }

    /// Echoes `input` and records its byte length as `input_len` metadata.
    #[llm_tool]
    fn tool_with_metadata(input: String) -> Result<ToolOutput, ToolError> {
        Ok(ToolOutput::new(format!("echoed: {input}"))
            .with_meta("input_len", serde_json::json!(input.len())))
    }

    #[tokio::test]
    async fn macro_tool_returning_tool_output_preserves_metadata() {
        let mut d = ToolRegistry::new();
        d.register(ToolWithMetadata);
        let result = d
            .dispatch(
                "tool_with_metadata",
                serde_json::json!({"input": "hello"}),
                &test_ctx(),
            )
            .await
            .unwrap();
        assert_eq!(result.content(), "echoed: hello");
        assert_eq!(result.metadata()["input_len"], 5);
    }

    #[test]
    fn tool_output_with_metadata_struct() {
        #[derive(serde::Serialize)]
        struct Meta {
            status: String,
            count: u32,
        }

        let out = ToolOutput::new("done")
            .with_metadata(&Meta {
                status: "ok".into(),
                count: 42,
            })
            .unwrap();
        assert_eq!(out.metadata()["status"], "ok");
        assert_eq!(out.metadata()["count"], 42);
        assert_eq!(out.metadata().len(), 2);
    }

    #[test]
    fn tool_output_with_metadata_merges_with_existing() {
        #[derive(serde::Serialize)]
        struct Extra {
            source: String,
        }

        let out = ToolOutput::new("data")
            .with_meta("version", serde_json::json!(1))
            .with_metadata(&Extra {
                source: "cache".into(),
            })
            .unwrap();
        assert_eq!(out.metadata()["version"], 1);
        assert_eq!(out.metadata()["source"], "cache");
        assert_eq!(out.metadata().len(), 2);
    }

    #[test]
    fn tool_output_with_metadata_rejects_non_object() {
        let err = ToolOutput::new("x").with_metadata(&42_i32).unwrap_err();
        assert!(
            err.message.contains("JSON object"),
            "Expected object error, got: {err}"
        );
    }

    #[test]
    fn tool_error_with_metadata_struct() {
        #[derive(serde::Serialize)]
        struct ErrorMeta {
            status_code: u16,
            url: String,
        }

        let err = ToolError::new("HTTP request failed")
            .with_metadata(&ErrorMeta {
                status_code: 503,
                url: "https://example.com".into(),
            })
            .unwrap();
        assert_eq!(err.message, "HTTP request failed");
        assert_eq!(err.metadata()["status_code"], 503);
        assert_eq!(err.metadata()["url"], "https://example.com");
        assert_eq!(err.metadata().len(), 2);
    }

    #[test]
    fn tool_output_from_metadata_populates_both() {
        #[derive(serde::Serialize)]
        struct Weather {
            location: String,
            temp_f: i32,
        }

        let out = ToolOutput::from_metadata(&Weather {
            location: "Seattle".into(),
            temp_f: 72,
        })
        .unwrap();
        assert!(out.content().contains("Seattle"));
        assert!(out.content().contains("72"));
        assert_eq!(out.metadata()["location"], "Seattle");
        assert_eq!(out.metadata()["temp_f"], 72);
        assert_eq!(out.metadata().len(), 2);
    }

    #[test]
    fn tool_output_from_metadata_rejects_non_object() {
        let err = ToolOutput::from_metadata(&"just a string").unwrap_err();
        assert!(
            err.message.contains("JSON object"),
            "Expected object error, got: {err}"
        );
    }
}