use std::collections::HashMap;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Limits that apply to tool execution.
///
/// The struct is (de)serializable with serde. `tool_timeout` is routed through
/// the `humantime_serde_compat` module, which encodes it as a whole number of
/// seconds rather than serde's default `Duration` representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct ToolExecutionConfig {
    /// Cap on tool calls per turn (enforcement lives outside this definition).
    pub max_tool_calls_per_turn: u32,
    /// Timeout for a tool call; serialized as integer seconds.
    #[serde(with = "humantime_serde_compat")]
    pub tool_timeout: Duration,
}
impl Default for ToolExecutionConfig {
fn default() -> Self {
Self {
max_tool_calls_per_turn: 20,
tool_timeout: Duration::from_secs(30),
}
}
}
/// Serde adapter for `Duration` fields, meant for `#[serde(with = "...")]`.
///
/// The wire format is a bare `u64` holding a number of whole seconds.
#[allow(dead_code)]
mod humantime_serde_compat {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Writes `d` as whole seconds; `as_secs` discards any sub-second part.
    pub fn serialize<S>(d: &Duration, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let whole_seconds = d.as_secs();
        s.serialize_u64(whole_seconds)
    }

    /// Reads a `u64` and interprets it as a number of seconds.
    pub fn deserialize<'de, D>(d: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        u64::deserialize(d).map(Duration::from_secs)
    }
}
/// A tool record: a name plus a description and optional JSON configuration.
#[derive(Debug, Clone)]
pub struct ToolEntry {
    /// Name the tool is stored and looked up under.
    pub name: String,
    /// Free-form description text for the tool.
    #[allow(dead_code)]
    pub description: String,
    /// Optional JSON configuration attached to this entry.
    #[allow(dead_code)]
    pub config: Option<Value>,
}
impl ToolEntry {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
config: Option<Value>,
) -> Self {
Self {
name: name.into(),
description: description.into(),
config,
}
}
}
/// Boxed, thread-safe constructor closure: takes an optional JSON config and
/// yields a `ToolEntry` or an `anyhow` error.
type ToolConstructor =
    Box<dyn Fn(Option<&Value>) -> anyhow::Result<ToolEntry> + Send + Sync>;
/// Factory for the built-in tools, keyed by their static names.
pub struct BuiltInToolFactory {
    /// Maps a built-in tool name to the closure that constructs its entry.
    registry: HashMap<&'static str, ToolConstructor>,
}
impl BuiltInToolFactory {
pub fn new() -> Self {
let mut factory = Self {
registry: HashMap::new(),
};
factory.register("web_search", "Web search via WebSearchTool");
factory.register("google_search", "Google search via GoogleSearchTool");
factory.register("google_maps", "Google Maps via GoogleMapsTool");
factory.register(
"url_context",
"Fetch and extract URL content via UrlContextTool",
);
factory.register("load_artifacts", "Load artifacts via LoadArtifactsTool");
factory.register("code_execution", "Code execution via CodeTool");
factory.register("python_code", "Python code execution via PythonCodeTool");
factory.register(
"javascript_code",
"JavaScript code execution via JavaScriptCodeTool",
);
factory.register(
"frontend_code",
"Frontend code execution via FrontendCodeTool",
);
factory.register("rust_code", "Rust code execution via RustCodeTool");
factory.register(
"openai_web_search",
"OpenAI web search via OpenAIWebSearchTool",
);
factory.register(
"openai_file_search",
"OpenAI file search via OpenAIFileSearchTool",
);
factory.register(
"openai_image_generation",
"OpenAI image generation via OpenAIImageGenerationTool",
);
factory.register(
"openai_code_interpreter",
"OpenAI code interpreter via OpenAICodeInterpreterTool",
);
factory.register(
"gemini_code_execution",
"Gemini code execution via GeminiCodeExecutionTool",
);
factory.register(
"gemini_file_search",
"Gemini file search via GeminiFileSearchTool",
);
factory
}
fn register(&mut self, name: &'static str, description: &'static str) {
self.registry.insert(
name,
Box::new(move |cfg: Option<&Value>| {
Ok(ToolEntry::new(name, description, cfg.cloned()))
}),
);
}
#[allow(dead_code)] pub fn create(&self, name: &str, config: Option<&Value>) -> Option<anyhow::Result<ToolEntry>> {
self.registry.get(name).map(|ctor| ctor(config))
}
pub fn known_names(&self) -> Vec<&'static str> {
let mut names: Vec<&'static str> = self.registry.keys().copied().collect();
names.sort_unstable();
names
}
}
impl Default for BuiltInToolFactory {
fn default() -> Self {
Self::new()
}
}
/// Record describing an agent that is exposed as a tool.
#[derive(Debug, Clone)]
pub struct AgentToolEntry {
    /// Identifier of the agent behind this tool.
    #[allow(dead_code)]
    pub agent_id: String,
    /// Description text for the agent tool.
    pub description: String,
}
impl AgentToolEntry {
pub fn new(agent_id: impl Into<String>, description: impl Into<String>) -> Self {
Self {
agent_id: agent_id.into(),
description: description.into(),
}
}
}
/// Registry combining built-in tools, custom tools and agent-backed tools.
pub struct ToolRegistry {
    /// Factory used to construct built-in tools by name.
    builtin_factory: BuiltInToolFactory,
    /// Custom tools, keyed by the entry's `name`.
    custom_tools: HashMap<String, ToolEntry>,
    /// Agent tools, keyed by agent id.
    agent_tools: HashMap<String, AgentToolEntry>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
builtin_factory: BuiltInToolFactory::new(),
custom_tools: HashMap::new(),
agent_tools: HashMap::new(),
}
}
pub fn register_custom(&mut self, entry: ToolEntry) {
self.custom_tools.insert(entry.name.clone(), entry);
}
pub fn register_agent_tool(&mut self, agent_id: &str, description: &str) {
let entry = AgentToolEntry::new(agent_id, description);
self.agent_tools.insert(agent_id.to_string(), entry);
tracing::info!(
agent_id = %agent_id,
"registered agent tool for inter-agent delegation"
);
}
#[allow(dead_code)] pub fn agent_tools(&self) -> &HashMap<String, AgentToolEntry> {
&self.agent_tools
}
#[allow(dead_code)] pub fn resolve_tools(&self, names: &[String], tool_config: Option<&Value>) -> Vec<ToolEntry> {
let mut resolved = Vec::new();
for name in names {
if let Some(entry) = self.custom_tools.get(name.as_str()) {
resolved.push(entry.clone());
continue;
}
if let Some(agent_entry) = self.agent_tools.get(name.as_str()) {
resolved.push(ToolEntry::new(name.clone(), &agent_entry.description, None));
continue;
}
let cfg = tool_config.and_then(|v| v.get(name.as_str()));
match self.builtin_factory.create(name, cfg) {
Some(Ok(entry)) => {
resolved.push(entry);
}
Some(Err(e)) => {
tracing::warn!(
tool = %name,
error = %e,
"failed to construct built-in tool, skipping"
);
}
None => {
tracing::warn!(
tool = %name,
"unknown tool name, skipping registration"
);
}
}
}
resolved
}
pub fn resolve_all(&self) -> Vec<ToolEntry> {
let mut resolved: Vec<ToolEntry> = self.custom_tools.values().cloned().collect();
for (name, agent_entry) in &self.agent_tools {
resolved.push(ToolEntry::new(name.clone(), &agent_entry.description, None));
}
resolved
}
pub fn known_names(&self) -> Vec<String> {
let mut names: Vec<String> = self
.builtin_factory
.known_names()
.into_iter()
.map(|s| s.to_string())
.collect();
names.extend(self.custom_tools.keys().cloned());
names.extend(self.agent_tools.keys().map(|k| format!("agent:{k}")));
names.sort_unstable();
names.dedup();
names
}
pub fn log_registered_tools(agent_name: &str, tools: &[ToolEntry]) {
let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
tracing::info!(
agent = %agent_name,
tools = ?names,
count = names.len(),
"registered tools for agent"
);
}
#[allow(dead_code)] pub fn build_tool_error_result(tool_name: &str, error: &str) -> serde_json::Value {
serde_json::json!({
"name": tool_name,
"error": true,
"content": error,
})
}
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned list that `resolve_tools` accepts.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| String::from(*name)).collect()
    }

    /// Collects the names of resolved entries, in order.
    fn names_of(tools: &[ToolEntry]) -> Vec<&str> {
        tools.iter().map(|tool| tool.name.as_str()).collect()
    }

    // ---- BuiltInToolFactory ----

    #[test]
    fn factory_creates_all_known_tools() {
        let factory = BuiltInToolFactory::new();
        let known = factory.known_names();
        assert_eq!(known.len(), 16);
        for name in known.iter().copied() {
            let entry = match factory.create(name, None) {
                Some(outcome) => outcome.expect("constructor should succeed"),
                None => panic!("factory should know about {name}"),
            };
            assert_eq!(entry.name, name);
            assert!(!entry.description.is_empty());
            assert!(entry.config.is_none());
        }
    }

    #[test]
    fn factory_returns_none_for_unknown() {
        let lookup = BuiltInToolFactory::new().create("nonexistent_tool", None);
        assert!(lookup.is_none());
    }

    #[test]
    fn factory_passes_config_to_entry() {
        let cfg = serde_json::json!({"api_key": "test123"});
        let entry = BuiltInToolFactory::new()
            .create("web_search", Some(&cfg))
            .unwrap()
            .unwrap();
        assert_eq!(entry.name, "web_search");
        assert_eq!(entry.config, Some(cfg));
    }

    // ---- ToolRegistry::resolve_tools with built-ins and custom tools ----

    #[test]
    fn registry_resolves_known_tools() {
        let requested = owned(&["web_search", "code_execution", "gemini_code_execution"]);
        let tools = ToolRegistry::new().resolve_tools(&requested, None);
        assert_eq!(
            names_of(&tools),
            ["web_search", "code_execution", "gemini_code_execution"]
        );
    }

    #[test]
    fn registry_skips_unknown_tools() {
        let requested = owned(&["web_search", "totally_fake_tool", "code_execution"]);
        let tools = ToolRegistry::new().resolve_tools(&requested, None);
        assert_eq!(names_of(&tools), ["web_search", "code_execution"]);
    }

    #[test]
    fn registry_passes_per_tool_config() {
        let tool_config = serde_json::json!({
            "web_search": {"api_key": "ws_key"},
            "code_execution": {"sandbox": true}
        });
        let requested = owned(&["web_search", "code_execution"]);
        let tools = ToolRegistry::new().resolve_tools(&requested, Some(&tool_config));
        assert_eq!(tools.len(), 2);
        assert_eq!(
            tools[0].config,
            Some(serde_json::json!({"api_key": "ws_key"}))
        );
        assert_eq!(tools[1].config, Some(serde_json::json!({"sandbox": true})));
    }

    #[test]
    fn registry_custom_tools_take_precedence() {
        let mut registry = ToolRegistry::new();
        registry.register_custom(ToolEntry::new(
            "web_search",
            "Custom web search override",
            None,
        ));
        let tools = registry.resolve_tools(&owned(&["web_search"]), None);
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].description, "Custom web search override");
    }

    #[test]
    fn registry_resolves_empty_list() {
        let tools = ToolRegistry::new().resolve_tools(&[], None);
        assert!(tools.is_empty());
    }

    #[test]
    fn registry_handles_all_unknown() {
        let tools = ToolRegistry::new().resolve_tools(&owned(&["fake1", "fake2"]), None);
        assert!(tools.is_empty());
    }

    #[test]
    fn known_names_are_sorted() {
        let names = BuiltInToolFactory::new().known_names();
        assert!(names.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    // ---- Agent tools ----

    #[test]
    fn register_agent_tool_stores_entry() {
        let mut registry = ToolRegistry::new();
        registry.register_agent_tool(
            "research-agent",
            "Delegates research tasks to the research agent",
        );
        let stored = registry.agent_tools();
        assert_eq!(stored.len(), 1);
        let entry = &stored["research-agent"];
        assert_eq!(entry.agent_id, "research-agent");
        assert_eq!(
            entry.description,
            "Delegates research tasks to the research agent"
        );
    }

    #[test]
    fn register_multiple_agent_tools() {
        let delegations = [
            ("research-agent", "Research delegation"),
            ("code-agent", "Code generation delegation"),
            ("review-agent", "Code review delegation"),
        ];
        let mut registry = ToolRegistry::new();
        for &(agent_id, description) in &delegations {
            registry.register_agent_tool(agent_id, description);
        }
        assert_eq!(registry.agent_tools().len(), 3);
        for &(agent_id, _) in &delegations {
            assert!(registry.agent_tools().contains_key(agent_id));
        }
    }

    #[test]
    fn resolve_tools_finds_agent_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_agent_tool("research-agent", "Delegates research tasks");
        let tools = registry.resolve_tools(&owned(&["research-agent"]), None);
        assert_eq!(tools.len(), 1);
        let only = &tools[0];
        assert_eq!(only.name, "research-agent");
        assert_eq!(only.description, "Delegates research tasks");
        assert!(only.config.is_none());
    }

    #[test]
    fn resolve_tools_mixes_builtin_and_agent_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_agent_tool("research-agent", "Research delegation");
        let requested = owned(&["web_search", "research-agent", "code_execution"]);
        let tools = registry.resolve_tools(&requested, None);
        assert_eq!(
            names_of(&tools),
            ["web_search", "research-agent", "code_execution"]
        );
    }

    #[test]
    fn custom_tools_take_precedence_over_agent_tools() {
        let mut registry = ToolRegistry::new();
        registry.register_agent_tool("my-tool", "Agent version");
        registry.register_custom(ToolEntry::new("my-tool", "Custom version", None));
        let tools = registry.resolve_tools(&owned(&["my-tool"]), None);
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].description, "Custom version");
    }

    #[test]
    fn agent_tools_take_precedence_over_builtin() {
        let mut registry = ToolRegistry::new();
        registry.register_agent_tool("web_search", "Agent-wrapped web search");
        let tools = registry.resolve_tools(&owned(&["web_search"]), None);
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].description, "Agent-wrapped web search");
    }

    #[test]
    fn register_agent_tool_overwrites_same_id() {
        let mut registry = ToolRegistry::new();
        registry.register_agent_tool("research-agent", "Old description");
        registry.register_agent_tool("research-agent", "New description");
        let stored = registry.agent_tools();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored["research-agent"].description, "New description");
    }

    // ---- build_tool_error_result ----

    #[test]
    fn build_tool_error_result_contains_name() {
        let result = ToolRegistry::build_tool_error_result("web_search", "API key expired");
        assert_eq!(result["name"].as_str(), Some("web_search"));
    }

    #[test]
    fn build_tool_error_result_contains_error_flag() {
        let result = ToolRegistry::build_tool_error_result("web_search", "timeout");
        assert_eq!(result["error"].as_bool(), Some(true));
    }

    #[test]
    fn build_tool_error_result_contains_error_message() {
        let msg = "Connection refused: could not reach search API";
        let result = ToolRegistry::build_tool_error_result("google_search", msg);
        assert_eq!(result["content"].as_str(), Some(msg));
    }

    #[test]
    fn build_tool_error_result_with_empty_error() {
        let result = ToolRegistry::build_tool_error_result("code_execution", "");
        assert_eq!(result["name"].as_str(), Some("code_execution"));
        assert_eq!(result["error"].as_bool(), Some(true));
        assert_eq!(result["content"].as_str(), Some(""));
    }

    #[test]
    fn build_tool_error_result_with_special_characters() {
        let msg = r#"Error: unexpected token '<' at line 1, col 1 — "<!DOCTYPE html>""#;
        let result = ToolRegistry::build_tool_error_result("url_context", msg);
        assert_eq!(result["content"].as_str(), Some(msg));
    }

    // ---- ToolExecutionConfig ----

    #[test]
    fn tool_execution_config_defaults() {
        let defaults = ToolExecutionConfig::default();
        assert_eq!(defaults.max_tool_calls_per_turn, 20);
        assert_eq!(defaults.tool_timeout, std::time::Duration::from_secs(30));
    }

    #[test]
    fn tool_execution_config_serde_roundtrip() {
        let original = ToolExecutionConfig {
            max_tool_calls_per_turn: 50,
            tool_timeout: std::time::Duration::from_secs(60),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolExecutionConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.max_tool_calls_per_turn, 50);
        assert_eq!(decoded.tool_timeout, std::time::Duration::from_secs(60));
    }
}