use std::collections::HashMap;
use std::sync::Arc;
use rucora_core::tool::Tool;
use rucora_core::tool::ToolCategory;
use rucora_core::tool::types::ToolDefinition;
/// Origin of a tool registered in a `ToolRegistry`.
///
/// Each variant has a stable lowercase identifier, available through [`ToolSource::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolSource {
    /// Identifier `"builtin"`.
    BuiltIn,
    /// Identifier `"skill"`.
    Skill,
    /// Identifier `"mcp"`.
    Mcp,
    /// Identifier `"a2a"`.
    A2A,
    /// Identifier `"custom"`.
    Custom,
}

impl ToolSource {
    /// Returns the lowercase identifier of this source, e.g. `"mcp"` for [`ToolSource::Mcp`].
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::BuiltIn => "builtin",
            Self::Skill => "skill",
            Self::Mcp => "mcp",
            Self::A2A => "a2a",
            Self::Custom => "custom",
        }
    }
}
/// Registry-side metadata attached to every registered tool.
#[derive(Debug, Clone)]
pub struct ToolMetadata {
    /// Origin of the tool; see [`ToolSource`].
    pub source: ToolSource,
    /// Categories the tool belongs to; matched by `ToolRegistry::filter_by_category`.
    pub categories: Vec<ToolCategory>,
    /// When `false` the tool stays registered but is skipped by `ToolRegistry::get`,
    /// `definitions`, `enabled_tools` and the `filter_by_*` helpers.
    pub enabled: bool,
    /// Free-form labels; matched by `ToolRegistry::filter_by_tags`.
    pub tags: Vec<String>,
}

impl Default for ToolMetadata {
    /// An enabled, untagged, [`ToolSource::Custom`] tool in the `Basic` category only.
    fn default() -> Self {
        let categories = vec![ToolCategory::Basic];
        Self {
            enabled: true,
            source: ToolSource::Custom,
            tags: Vec::new(),
            categories,
        }
    }
}
/// A registered tool together with the metadata the registry uses to filter it.
#[derive(Clone)]
pub struct ToolWrapper {
    /// The shared tool implementation.
    pub tool: Arc<dyn Tool>,
    /// Source, categories, enablement flag and tags of the tool.
    pub metadata: ToolMetadata,
}

impl ToolWrapper {
    /// Wraps an owned tool; equivalent to `ToolWrapper::new_arc(Arc::new(tool))`.
    #[must_use]
    pub fn new<T: Tool + 'static>(tool: T) -> Self {
        // Delegate so both constructors share a single source of truth for the
        // initial metadata (they previously duplicated it field by field).
        Self::new_arc(Arc::new(tool))
    }

    /// Wraps an already shared tool.
    ///
    /// Categories are copied from the tool's own `categories()`. The wrapper starts
    /// enabled and untagged, with source [`ToolSource::Custom`].
    #[must_use]
    pub fn new_arc(tool: Arc<dyn Tool>) -> Self {
        let categories = tool.categories().to_vec();
        Self {
            tool,
            metadata: ToolMetadata {
                source: ToolSource::Custom,
                categories,
                enabled: true,
                tags: Vec::new(),
            },
        }
    }

    /// Replaces the recorded source of the tool.
    #[must_use]
    pub fn with_source(mut self, source: ToolSource) -> Self {
        self.metadata.source = source;
        self
    }

    /// Replaces (does not extend) the tool's tags.
    #[must_use]
    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
        self.metadata.tags = tags;
        self
    }

    /// Sets whether the tool starts enabled.
    #[must_use]
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.metadata.enabled = enabled;
        self
    }
}
/// A name-keyed collection of tools plus per-tool metadata.
///
/// A tool is keyed by `"{namespace}::{tool_name}"` if a namespace was set with
/// [`ToolRegistry::with_namespace`] *before* it was registered. Otherwise it is keyed by its
/// plain `Tool::name()`. Registering another tool under an existing key replaces the previous
/// entry. Methods that return collections yield items in unspecified (hash-map) order.
#[derive(Default, Clone)]
pub struct ToolRegistry {
    /// Registered tools keyed by their (possibly namespaced) registry name.
    tools: HashMap<String, ToolWrapper>,
    /// Prefix applied to tool names at registration time, if any.
    namespace_prefix: Option<String>,
}

impl ToolRegistry {
    /// Creates an empty registry without a namespace.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the namespace prefix used for tools registered *after* this call.
    ///
    /// Tools that are already registered keep their existing keys.
    #[must_use]
    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
        self.namespace_prefix = Some(namespace.into());
        self
    }

    /// Returns `name` qualified with this registry's namespace prefix, if one is set.
    fn namespaced_name(&self, name: &str) -> String {
        match &self.namespace_prefix {
            Some(prefix) => format!("{prefix}::{name}"),
            None => name.to_string(),
        }
    }

    /// Inserts `wrapper` under its namespaced name, replacing any previous entry.
    fn insert_wrapper(&mut self, wrapper: ToolWrapper) {
        let key = self.namespaced_name(wrapper.tool.name());
        self.tools.insert(key, wrapper);
    }

    /// Resolves a caller-supplied name to the stored key and its entry.
    ///
    /// An exact key match always wins, even if that entry is disabled. Otherwise, when a
    /// namespace is set, `"{namespace}::{name}"` is tried.
    fn lookup(&self, name: &str) -> Option<(&str, &ToolWrapper)> {
        if let Some((key, wrapper)) = self.tools.get_key_value(name) {
            return Some((key.as_str(), wrapper));
        }
        let prefix = self.namespace_prefix.as_deref()?;
        self.tools
            .get_key_value(format!("{prefix}::{name}").as_str())
            .map(|(key, wrapper)| (key.as_str(), wrapper))
    }

    /// Iterates over the enabled entries only.
    fn enabled_wrappers(&self) -> impl Iterator<Item = &ToolWrapper> {
        self.tools.values().filter(|w| w.metadata.enabled)
    }

    /// Registers an owned tool with default metadata (see `ToolWrapper::new`).
    #[must_use]
    pub fn register<T: Tool + 'static>(self, tool: T) -> Self {
        self.register_wrapper(ToolWrapper::new(tool))
    }

    /// Registers a pre-built wrapper, keeping its metadata as given.
    #[must_use]
    pub fn register_wrapper(mut self, wrapper: ToolWrapper) -> Self {
        self.insert_wrapper(wrapper);
        self
    }

    /// Registers a shared tool.
    ///
    /// Like [`ToolRegistry::register`], the tool's own categories are recorded. Previously this
    /// path fell back to `ToolMetadata::default()`, which always used `[Basic]`.
    #[must_use]
    pub fn register_arc(self, tool: Arc<dyn Tool>) -> Self {
        self.register_wrapper(ToolWrapper::new_arc(tool))
    }

    /// Registers an owned tool and records where it came from.
    #[must_use]
    pub fn register_with_source<T: Tool + 'static>(self, tool: T, source: ToolSource) -> Self {
        self.register_wrapper(ToolWrapper::new(tool).with_source(source))
    }

    /// Registers every wrapper yielded by `tools`; later entries replace earlier ones on key clashes.
    #[must_use]
    pub fn register_all<I>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = ToolWrapper>,
    {
        for wrapper in tools {
            self.insert_wrapper(wrapper);
        }
        self
    }

    /// Moves all tools of `other` into `self` and keeps `self`'s namespace.
    ///
    /// An incoming key that already exists is renamed instead of overwriting: first to
    /// `"{other_namespace}::{key}"` (or `"merged::{key}"` when `other` has no namespace). If that
    /// name is also taken, `"merged::"` is prepended until a free key is found. No existing tool
    /// is ever replaced.
    #[must_use]
    pub fn merge(mut self, other: ToolRegistry) -> Self {
        let ToolRegistry {
            tools,
            namespace_prefix: other_prefix,
        } = other;
        for (name, wrapper) in tools {
            let key = if self.tools.contains_key(&name) {
                let mut candidate = match &other_prefix {
                    Some(prefix) => format!("{prefix}::{name}"),
                    None => format!("merged::{name}"),
                };
                // The fallback name itself may already exist (e.g. from an earlier merge).
                // Keep prefixing so that insert() never silently drops a tool.
                while self.tools.contains_key(&candidate) {
                    candidate = format!("merged::{candidate}");
                }
                candidate
            } else {
                name
            };
            self.tools.insert(key, wrapper);
        }
        self
    }

    /// Returns the enabled tools whose metadata lists `category`.
    pub fn filter_by_category(&self, category: ToolCategory) -> Vec<&ToolWrapper> {
        self.enabled_wrappers()
            .filter(|w| w.metadata.categories.contains(&category))
            .collect()
    }

    /// Returns the enabled tools registered with the given `source`.
    pub fn filter_by_source(&self, source: ToolSource) -> Vec<&ToolWrapper> {
        self.enabled_wrappers()
            .filter(|w| w.metadata.source == source)
            .collect()
    }

    /// Returns the enabled tools carrying at least one of `tags` (empty `tags` matches nothing).
    pub fn filter_by_tags(&self, tags: &[String]) -> Vec<&ToolWrapper> {
        self.enabled_wrappers()
            .filter(|w| tags.iter().any(|tag| w.metadata.tags.contains(tag)))
            .collect()
    }

    /// Builds a `ToolDefinition` for every enabled tool.
    ///
    /// The definition's `name` is the tool's own `Tool::name()`, not the (possibly
    /// namespaced) registry key.
    pub fn definitions(&self) -> Vec<ToolDefinition> {
        self.enabled_wrappers()
            .map(|w| ToolDefinition {
                name: w.tool.name().to_string(),
                description: w.tool.description().map(|s| s.to_string()),
                input_schema: w.tool.input_schema(),
            })
            .collect()
    }

    /// Returns shared handles to every enabled tool.
    pub fn enabled_tools(&self) -> Vec<Arc<dyn Tool>> {
        self.enabled_wrappers().map(|w| Arc::clone(&w.tool)).collect()
    }

    /// Looks up an enabled tool by exact key or, failing that, by its un-namespaced name.
    ///
    /// Returns `None` if nothing matches or if the matched entry is disabled.
    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        let (_, wrapper) = self.lookup(name)?;
        wrapper.metadata.enabled.then(|| Arc::clone(&wrapper.tool))
    }

    /// Enables or disables a tool and reports whether it was found.
    ///
    /// `name` is resolved the same way as in [`ToolRegistry::get`], so both the full key and
    /// the un-namespaced name are accepted.
    pub fn set_tool_enabled(&mut self, name: &str, enabled: bool) -> bool {
        let Some(key) = self.lookup(name).map(|(key, _)| key.to_owned()) else {
            return false;
        };
        match self.tools.get_mut(key.as_str()) {
            Some(wrapper) => {
                wrapper.metadata.enabled = enabled;
                true
            }
            None => false,
        }
    }

    /// Number of registered tools, enabled or not.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Number of enabled tools.
    pub fn enabled_len(&self) -> usize {
        self.enabled_wrappers().count()
    }

    /// `true` when no tools are registered (disabled tools count as registered).
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Registry keys of all tools, enabled or not, in unspecified order.
    pub fn tool_names(&self) -> Vec<&String> {
        self.tools.keys().collect()
    }

    /// Removes every tool; the namespace prefix is kept.
    pub fn clear(&mut self) {
        self.tools.clear();
    }

    /// Resolves `name` like [`ToolRegistry::get`] and invokes the tool with `input`.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::NotFound` when no enabled tool matches `name`. Otherwise returns
    /// whatever the tool's `call` returns.
    pub async fn call_tool(
        &self,
        name: &str,
        input: serde_json::Value,
    ) -> Result<serde_json::Value, rucora_core::error::ToolError> {
        let tool = self
            .get(name)
            .ok_or_else(|| rucora_core::error::ToolError::NotFound {
                name: name.to_string(),
            })?;
        tool.call(input).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rucora_core::error::ToolError;
    use serde_json::{Value, json};

    /// Minimal tool whose only configurable property is its name.
    struct TestTool {
        name: String,
    }

    impl TestTool {
        /// Convenience constructor so each test can build a tool in one expression.
        fn named(name: &str) -> Self {
            Self {
                name: name.to_owned(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn categories(&self) -> &'static [ToolCategory] {
            &[ToolCategory::Basic]
        }

        fn input_schema(&self) -> Value {
            json!({"type": "object"})
        }

        async fn call(&self, _input: Value) -> Result<Value, ToolError> {
            Ok(json!({"ok": true}))
        }
    }

    #[test]
    fn test_tool_registry_namespace() {
        let registry = ToolRegistry::new()
            .with_namespace("test")
            .register(TestTool::named("my_tool"));

        // Both the fully qualified key and the bare tool name must resolve.
        for lookup in ["test::my_tool", "my_tool"] {
            assert!(registry.get(lookup).is_some(), "lookup of {lookup:?} failed");
        }
    }

    #[test]
    fn test_tool_registry_filter_by_source() {
        let registry = ToolRegistry::new()
            .register_with_source(TestTool::named("builtin_tool"), ToolSource::BuiltIn)
            .register_with_source(TestTool::named("skill_tool"), ToolSource::Skill);

        let expectations = [
            (ToolSource::BuiltIn, "builtin_tool"),
            (ToolSource::Skill, "skill_tool"),
        ];
        for (source, expected_name) in expectations {
            let matches = registry.filter_by_source(source);
            assert_eq!(matches.len(), 1);
            assert_eq!(matches[0].tool.name(), expected_name);
        }
    }

    #[test]
    fn test_tool_registry_merge() {
        let namespaced = |ns: &str| {
            ToolRegistry::new()
                .with_namespace(ns)
                .register(TestTool::named("tool"))
        };

        let merged = namespaced("ns1").merge(namespaced("ns2"));

        assert!(merged.get("ns1::tool").is_some());
        assert!(merged.get("ns2::tool").is_some());
    }
}