use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use crate::traits::tool::ErasedTool;
use crate::traits::tool_registry::ToolRegistry;
/// A registered tool together with its enabled flag (used by `DynamicRegistry`).
struct ToolEntry {
    /// The tool itself; shared via `Arc` so lookups hand out refcount clones.
    tool: Arc<dyn ErasedTool>,
    /// Disabled entries stay registered (their name stays reserved) but are
    /// hidden from `get_tools`, `find_tool`, `is_enabled`, and `len`.
    enabled: bool,
}
pub struct DynamicRegistry {
tools: RwLock<Vec<ToolEntry>>,
}
impl DynamicRegistry {
#[must_use]
pub fn new() -> Self {
Self {
tools: RwLock::new(Vec::new()),
}
}
#[must_use]
pub fn with_tools(tools: Vec<Arc<dyn ErasedTool>>) -> Self {
let entries = tools
.into_iter()
.map(|tool| ToolEntry {
tool,
enabled: true,
})
.collect();
Self {
tools: RwLock::new(entries),
}
}
}
impl Default for DynamicRegistry {
fn default() -> Self {
Self::new()
}
}
// Every method panics with "DynamicRegistry lock poisoned" if the lock is poisoned.
impl ToolRegistry for DynamicRegistry {
    /// Returns every enabled tool, in registration order.
    fn get_tools(&self) -> Vec<Arc<dyn ErasedTool>> {
        let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
        tools
            .iter()
            .filter(|e| e.enabled)
            .map(|e| Arc::clone(&e.tool))
            .collect()
    }
    /// Looks up an enabled tool by name; disabled tools are treated as absent.
    fn find_tool(&self, name: &str) -> Option<Arc<dyn ErasedTool>> {
        let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
        tools
            .iter()
            .find(|e| e.enabled && e.tool.name() == name)
            .map(|e| Arc::clone(&e.tool))
    }
    /// Adds `tool` as enabled, unless a tool with the same name (enabled or
    /// disabled) is already registered.
    ///
    /// Returns `true` if the tool was added, `false` if the name was taken.
    fn register(&self, tool: Arc<dyn ErasedTool>) -> bool {
        let mut tools = self.tools.write().expect("DynamicRegistry lock poisoned");
        // Compare borrowed names directly; no `String` allocation under the lock.
        if tools.iter().any(|e| e.tool.name() == tool.name()) {
            return false;
        }
        tools.push(ToolEntry {
            tool,
            enabled: true,
        });
        true
    }
    /// Removes the tool named `name`, whether enabled or not.
    ///
    /// Returns `true` if something was removed.
    fn unregister(&self, name: &str) -> bool {
        let mut tools = self.tools.write().expect("DynamicRegistry lock poisoned");
        let len_before = tools.len();
        tools.retain(|e| e.tool.name() != name);
        tools.len() < len_before
    }
    /// Sets the enabled flag of the tool named `name`.
    ///
    /// Returns `true` only if the tool exists and its flag actually changed.
    fn set_enabled(&self, name: &str, enabled: bool) -> bool {
        let mut tools = self.tools.write().expect("DynamicRegistry lock poisoned");
        match tools.iter_mut().find(|e| e.tool.name() == name) {
            Some(entry) if entry.enabled != enabled => {
                entry.enabled = enabled;
                true
            }
            _ => false,
        }
    }
    /// Returns `true` if a tool named `name` is registered and enabled.
    fn is_enabled(&self, name: &str) -> bool {
        let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
        tools.iter().any(|e| e.tool.name() == name && e.enabled)
    }
    /// Number of *enabled* tools (matches `get_tools().len()`).
    fn len(&self) -> usize {
        let tools = self.tools.read().expect("DynamicRegistry lock poisoned");
        tools.iter().filter(|e| e.enabled).count()
    }
    /// `true` when no tool is enabled.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A tool registry that organises tools into named groups and exposes only
/// the tools of the currently *active* groups through `get_tools`/`len`.
///
/// Both collections sit behind `RwLock`s so groups can be (de)activated
/// through `&self`. Code that needs both locks takes `groups` first and
/// `active_groups` second; keep that order to avoid lock-order deadlocks.
///
/// # Panics
///
/// Methods panic if a lock has been poisoned by a panicking holder.
pub struct GroupedRegistry {
    /// Group name -> tools belonging to that group.
    groups: RwLock<HashMap<String, Vec<Arc<dyn ErasedTool>>>>,
    /// Names of the groups whose tools are currently exposed.
    active_groups: RwLock<HashSet<String>>,
}
impl GroupedRegistry {
    /// Creates a registry with no groups and nothing active.
    #[must_use]
    pub fn new() -> Self {
        Self {
            groups: RwLock::new(HashMap::new()),
            active_groups: RwLock::new(HashSet::new()),
        }
    }
    /// Builder: adds group `name` containing `tools`, replacing any existing
    /// group with the same name.
    ///
    /// `self` is owned here, so the map is mutated through
    /// [`RwLock::get_mut`] with no runtime locking.
    #[must_use]
    pub fn group(mut self, name: impl Into<String>, tools: Vec<Arc<dyn ErasedTool>>) -> Self {
        self.groups
            .get_mut()
            .expect("GroupedRegistry lock poisoned")
            .insert(name.into(), tools);
        self
    }
    /// Builder: marks group `name` as active.
    ///
    /// Unlike [`GroupedRegistry::activate_group`], this does not check that
    /// the group exists, so it may be chained before the matching `group`
    /// call. An active name with no group contributes no tools.
    #[must_use]
    pub fn activate(mut self, name: impl Into<String>) -> Self {
        self.active_groups
            .get_mut()
            .expect("GroupedRegistry lock poisoned")
            .insert(name.into());
        self
    }
    /// Activates an existing group at runtime.
    ///
    /// Returns `false` (changing nothing) if no group named `name` exists,
    /// `true` otherwise, including when the group was already active.
    pub fn activate_group(&self, name: &str) -> bool {
        // The `groups` read guard stays alive until the insert below, so the
        // existence check and the activation happen under the same lock.
        let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
        if !groups.contains_key(name) {
            return false;
        }
        let mut active = self
            .active_groups
            .write()
            .expect("GroupedRegistry lock poisoned");
        active.insert(name.to_string());
        true
    }
    /// Deactivates group `name`; returns `true` if it was active.
    pub fn deactivate_group(&self, name: &str) -> bool {
        let mut active = self
            .active_groups
            .write()
            .expect("GroupedRegistry lock poisoned");
        active.remove(name)
    }
    /// Names of all registered groups, sorted alphabetically.
    ///
    /// Sorted so the result does not depend on `HashMap` iteration order.
    #[must_use]
    pub fn group_names(&self) -> Vec<String> {
        let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
        let mut names: Vec<String> = groups.keys().cloned().collect();
        names.sort_unstable();
        names
    }
    /// Names of the currently active groups, sorted alphabetically.
    ///
    /// May include names activated via the `activate` builder that have no
    /// registered group.
    #[must_use]
    pub fn active_group_names(&self) -> Vec<String> {
        let active = self
            .active_groups
            .read()
            .expect("GroupedRegistry lock poisoned");
        let mut names: Vec<String> = active.iter().cloned().collect();
        names.sort_unstable();
        names
    }
    /// Returns `true` if `name` is currently marked active.
    #[must_use]
    pub fn is_group_active(&self, name: &str) -> bool {
        let active = self
            .active_groups
            .read()
            .expect("GroupedRegistry lock poisoned");
        active.contains(name)
    }
}
impl Default for GroupedRegistry {
    /// Equivalent to [`GroupedRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl ToolRegistry for GroupedRegistry {
fn get_tools(&self) -> Vec<Arc<dyn ErasedTool>> {
let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
let active = self
.active_groups
.read()
.expect("GroupedRegistry lock poisoned");
let mut tools = Vec::new();
for group_name in active.iter() {
if let Some(group_tools) = groups.get(group_name) {
for tool in group_tools {
tools.push(Arc::clone(tool));
}
}
}
tools
}
fn find_tool(&self, name: &str) -> Option<Arc<dyn ErasedTool>> {
let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
for tools in groups.values() {
if let Some(tool) = tools.iter().find(|t| t.name() == name) {
return Some(Arc::clone(tool));
}
}
None
}
fn len(&self) -> usize {
let groups = self.groups.read().expect("GroupedRegistry lock poisoned");
let active = self
.active_groups
.read()
.expect("GroupedRegistry lock poisoned");
let mut count = 0;
for group_name in active.iter() {
if let Some(group_tools) = groups.get(group_name) {
count += group_tools.len();
}
}
count
}
fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Maximum number of tools exposed for each model tier.
///
/// The default caps are 5 (small), 15 (medium), and `usize::MAX`
/// (effectively unlimited) for large models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TierLimits {
    /// Tool cap for small-tier models.
    pub small: usize,
    /// Tool cap for medium-tier models.
    pub medium: usize,
    /// Tool cap for large-tier models.
    pub large: usize,
}
impl Default for TierLimits {
    fn default() -> Self {
        Self {
            small: 5,
            medium: 15,
            large: usize::MAX,
        }
    }
}
/// A read-only registry that caps how many tools it exposes according to the
/// model tier it was built for.
pub struct AdaptiveRegistry {
    tools: Vec<Arc<dyn ErasedTool>>,
    limits: TierLimits,
    tier: crate::types::model_info::ModelTier,
}
impl AdaptiveRegistry {
    /// Builds a registry over `tools` for `tier`, starting from
    /// `TierLimits::default()`.
    #[must_use]
    pub fn new(
        tools: Vec<Arc<dyn ErasedTool>>,
        tier: crate::types::model_info::ModelTier,
    ) -> Self {
        let limits = TierLimits::default();
        Self {
            tools,
            limits,
            tier,
        }
    }
    /// Builder: replaces the per-tier caps, keeping tools and tier.
    #[must_use]
    pub fn with_limits(self, small: usize, medium: usize, large: usize) -> Self {
        Self {
            limits: TierLimits {
                small,
                medium,
                large,
            },
            ..self
        }
    }
    /// The per-tier caps currently in effect.
    #[must_use]
    pub fn limits(&self) -> TierLimits {
        self.limits
    }
    /// The model tier this registry was built for.
    #[must_use]
    pub fn tier(&self) -> crate::types::model_info::ModelTier {
        self.tier
    }
    /// Picks the cap that applies to `self.tier`.
    fn effective_limit(&self) -> usize {
        use crate::types::model_info::ModelTier;
        let TierLimits {
            small,
            medium,
            large,
        } = self.limits;
        match self.tier {
            ModelTier::Small => small,
            ModelTier::Medium => medium,
            ModelTier::Large => large,
        }
    }
}
impl ToolRegistry for AdaptiveRegistry {
    /// Returns at most `limit` tools for the configured tier, taken from the
    /// front of the list in the order given to `AdaptiveRegistry::new`.
    fn get_tools(&self) -> Vec<Arc<dyn ErasedTool>> {
        let limit = self.effective_limit();
        self.tools.iter().take(limit).map(Arc::clone).collect()
    }
    /// Looks up a tool by name across *all* tools, including those beyond the
    /// tier cap (intentional; the tests assert this).
    fn find_tool(&self, name: &str) -> Option<Arc<dyn ErasedTool>> {
        self.tools.iter().find(|t| t.name() == name).map(Arc::clone)
    }
    /// Number of tools `get_tools` returns: the tool count, capped at the tier limit.
    fn len(&self) -> usize {
        self.tools.len().min(self.effective_limit())
    }
    /// `true` when `get_tools` would return nothing.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal `ErasedTool` whose only distinguishing feature is its name.
    struct FakeTool {
        tool_name: String,
    }

    impl FakeTool {
        fn new(name: &str) -> Self {
            Self {
                tool_name: name.to_string(),
            }
        }
    }

    #[async_trait]
    impl ErasedTool for FakeTool {
        fn name(&self) -> &str {
            &self.tool_name
        }
        fn description(&self) -> &str {
            "fake"
        }
        fn schema(&self) -> crate::traits::tool::ToolSchema {
            crate::traits::tool::ToolSchema {
                name: self.tool_name.clone(),
                description: "fake".to_string(),
                parameters: serde_json::json!({}),
            }
        }
        async fn execute_json(
            &self,
            _input: serde_json::Value,
        ) -> crate::Result<serde_json::Value> {
            Ok(serde_json::json!("ok"))
        }
    }

    #[test]
    fn test_dynamic_registry_empty() {
        let reg = DynamicRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn test_dynamic_registry_register_and_find() {
        let reg = DynamicRegistry::new();
        assert!(reg.register(Arc::new(FakeTool::new("search"))));
        assert_eq!(reg.len(), 1);
        assert!(reg.find_tool("search").is_some());
        assert!(reg.find_tool("calc").is_none());
    }

    #[test]
    fn test_dynamic_registry_no_duplicates() {
        let reg = DynamicRegistry::new();
        assert!(reg.register(Arc::new(FakeTool::new("search"))));
        assert!(!reg.register(Arc::new(FakeTool::new("search"))));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_dynamic_registry_unregister() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("search")));
        assert!(reg.unregister("search"));
        assert!(reg.is_empty());
        // A second unregister finds nothing to remove.
        assert!(!reg.unregister("search"));
    }

    #[test]
    fn test_dynamic_registry_set_enabled() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("search")));
        assert!(reg.is_enabled("search"));
        assert!(reg.set_enabled("search", false));
        assert!(!reg.is_enabled("search"));
        // Disabled tools are hidden from both `len` and `find_tool`.
        assert_eq!(reg.len(), 0);
        assert!(reg.find_tool("search").is_none());
        assert!(reg.set_enabled("search", true));
        assert!(reg.is_enabled("search"));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_dynamic_registry_set_enabled_no_change() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("search")));
        assert!(!reg.set_enabled("search", true));
    }

    #[test]
    fn test_dynamic_registry_with_tools() {
        let tools: Vec<Arc<dyn ErasedTool>> =
            vec![Arc::new(FakeTool::new("a")), Arc::new(FakeTool::new("b"))];
        let reg = DynamicRegistry::with_tools(tools);
        assert_eq!(reg.len(), 2);
        assert!(reg.is_enabled("a"));
        assert!(reg.is_enabled("b"));
    }

    #[test]
    fn test_dynamic_registry_get_tools_only_enabled() {
        let reg = DynamicRegistry::new();
        reg.register(Arc::new(FakeTool::new("a")));
        reg.register(Arc::new(FakeTool::new("b")));
        reg.set_enabled("a", false);
        let tools = reg.get_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name(), "b");
    }

    #[test]
    fn test_grouped_registry_empty() {
        let reg = GroupedRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        assert!(reg.get_tools().is_empty());
        assert!(reg.find_tool("anything").is_none());
    }

    #[test]
    fn test_grouped_registry_single_group() {
        let tools: Vec<Arc<dyn ErasedTool>> = vec![
            Arc::new(FakeTool::new("web_search")),
            Arc::new(FakeTool::new("deep_search")),
        ];
        let reg = GroupedRegistry::new()
            .group("search", tools)
            .activate("search");
        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());
        let active = reg.get_tools();
        assert_eq!(active.len(), 2);
    }

    #[test]
    fn test_grouped_registry_multiple_groups_activate_switch() {
        let search_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("web_search"))];
        let code_tools: Vec<Arc<dyn ErasedTool>> = vec![
            Arc::new(FakeTool::new("read_file")),
            Arc::new(FakeTool::new("write_file")),
        ];
        let reg = GroupedRegistry::new()
            .group("search", search_tools)
            .group("code", code_tools)
            .activate("search");
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.get_tools()[0].name(), "web_search");
        assert!(reg.deactivate_group("search"));
        assert!(reg.is_empty());
        assert!(reg.activate_group("code"));
        assert_eq!(reg.len(), 2);
        let names: Vec<String> = reg
            .get_tools()
            .iter()
            .map(|t| t.name().to_string())
            .collect();
        assert!(names.contains(&"read_file".to_string()));
        assert!(names.contains(&"write_file".to_string()));
    }

    #[test]
    fn test_grouped_registry_multiple_active_groups() {
        let search_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("web_search"))];
        let code_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("read_file"))];
        let reg = GroupedRegistry::new()
            .group("search", search_tools)
            .group("code", code_tools)
            .activate("search")
            .activate("code");
        assert_eq!(reg.len(), 2);
        let names: Vec<String> = reg
            .get_tools()
            .iter()
            .map(|t| t.name().to_string())
            .collect();
        assert!(names.contains(&"web_search".to_string()));
        assert!(names.contains(&"read_file".to_string()));
    }

    #[test]
    fn test_grouped_registry_find_tool_searches_all_groups() {
        let search_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("web_search"))];
        let code_tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(FakeTool::new("read_file"))];
        let reg = GroupedRegistry::new()
            .group("search", search_tools)
            .group("code", code_tools)
            .activate("search");
        assert_eq!(reg.get_tools().len(), 1);
        // Tool in the active group.
        assert!(reg.find_tool("web_search").is_some());
        // Tool in an inactive group is still findable.
        assert!(reg.find_tool("read_file").is_some());
        // Unknown names are not.
        assert!(reg.find_tool("nonexistent").is_none());
    }

    #[test]
    fn test_grouped_registry_activate_nonexistent_group() {
        let reg = GroupedRegistry::new().group("search", vec![Arc::new(FakeTool::new("a"))]);
        assert!(!reg.activate_group("nonexistent"));
        assert!(reg.activate_group("search"));
    }

    #[test]
    fn test_grouped_registry_deactivate_nonexistent() {
        let reg = GroupedRegistry::new();
        assert!(!reg.deactivate_group("nonexistent"));
    }

    #[test]
    fn test_grouped_registry_group_names() {
        let reg = GroupedRegistry::new()
            .group("search", vec![])
            .group("code", vec![])
            .activate("search");
        let mut names = reg.group_names();
        names.sort();
        assert_eq!(names, vec!["code", "search"]);
        let active = reg.active_group_names();
        assert_eq!(active.len(), 1);
        assert!(active.contains(&"search".to_string()));
    }

    #[test]
    fn test_grouped_registry_is_group_active() {
        let reg = GroupedRegistry::new()
            .group("search", vec![])
            .group("code", vec![])
            .activate("search");
        assert!(reg.is_group_active("search"));
        assert!(!reg.is_group_active("code"));
    }

    #[test]
    fn test_grouped_registry_concurrent_read() {
        use std::thread;
        let reg = Arc::new(
            GroupedRegistry::new()
                .group("a", vec![Arc::new(FakeTool::new("tool_a"))])
                .group("b", vec![Arc::new(FakeTool::new("tool_b"))])
                .activate("a"),
        );
        let mut handles = vec![];
        for _ in 0..10 {
            let reg_clone = Arc::clone(&reg);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    let tools = reg_clone.get_tools();
                    assert_eq!(tools.len(), 1);
                    assert!(reg_clone.find_tool("tool_a").is_some());
                    assert!(reg_clone.find_tool("tool_b").is_some());
                }
            }));
        }
        for h in handles {
            h.join().expect("thread panicked");
        }
    }

    #[test]
    fn test_grouped_registry_object_safe() {
        let reg = GroupedRegistry::new();
        let _: Arc<dyn ToolRegistry> = Arc::new(reg);
    }

    #[test]
    fn test_grouped_registry_replace_group() {
        let reg = GroupedRegistry::new()
            .group("search", vec![Arc::new(FakeTool::new("old_tool"))])
            .group("search", vec![Arc::new(FakeTool::new("new_tool"))])
            .activate("search");
        assert_eq!(reg.len(), 1);
        assert!(reg.find_tool("new_tool").is_some());
        assert!(reg.find_tool("old_tool").is_none());
    }

    /// Builds `n` fake tools named `tool_0` .. `tool_{n-1}`.
    fn make_tools(n: usize) -> Vec<Arc<dyn ErasedTool>> {
        (0..n)
            .map(|i| Arc::new(FakeTool::new(&format!("tool_{i}"))) as Arc<dyn ErasedTool>)
            .collect()
    }

    #[test]
    fn test_adaptive_registry_small_tier_limits() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Small);
        assert_eq!(reg.len(), 5);
        assert_eq!(reg.get_tools().len(), 5);
        assert_eq!(reg.get_tools()[0].name(), "tool_0");
        assert_eq!(reg.get_tools()[4].name(), "tool_4");
    }

    #[test]
    fn test_adaptive_registry_medium_tier_limits() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Medium);
        assert_eq!(reg.len(), 15);
        assert_eq!(reg.get_tools().len(), 15);
    }

    #[test]
    fn test_adaptive_registry_large_tier_all() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Large);
        assert_eq!(reg.len(), 30);
        assert_eq!(reg.get_tools().len(), 30);
    }

    #[test]
    fn test_adaptive_registry_custom_limits() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Small).with_limits(3, 10, 50);
        assert_eq!(reg.len(), 3);
        assert_eq!(reg.get_tools().len(), 3);
    }

    #[test]
    fn test_adaptive_registry_find_tool_beyond_limit() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(make_tools(30), ModelTier::Small);
        assert!(reg.find_tool("tool_29").is_some());
        assert!(reg.find_tool("tool_0").is_some());
        assert!(reg.find_tool("nonexistent").is_none());
    }

    #[test]
    fn test_adaptive_registry_empty() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(vec![], ModelTier::Large);
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn test_adaptive_registry_object_safe() {
        use crate::types::model_info::ModelTier;
        let reg = AdaptiveRegistry::new(vec![], ModelTier::Medium);
        let _: Arc<dyn ToolRegistry> = Arc::new(reg);
    }
}