use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use aonyx_core::ToolHandler;
use crate::bash::Bash;
use crate::fs::{FsEdit, FsGlob, FsGrep, FsRead, FsWrite};
use crate::git::{GitDiff, GitLog, GitShow, GitStatus};
use crate::web::{WebFetch, WebSearch};
/// Name-keyed collection of tool handlers with a runtime enable/disable switch.
///
/// Cloning a registry copies the name→handler map (the handlers themselves are
/// `Arc`s, so the handler objects are shared), which means tools registered on one
/// clone are not visible to the others. The disabled-name set, by contrast, lives
/// behind an `Arc<Mutex<_>>`, so every clone observes and mutates the same set.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Registered handlers, keyed by the name each handler reports for itself.
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
    /// Names of tools hidden from `get`; shared between all clones.
    disabled: Arc<Mutex<HashSet<String>>>,
}
impl ToolRegistry {
    /// Creates an empty registry: no tools registered, nothing disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `handler` under the name returned by [`ToolHandler::name`].
    ///
    /// Registering a second handler under an existing name replaces the first.
    /// Only this registry value is affected; clones made earlier keep their own
    /// handler map (only the disabled set is shared between clones).
    pub fn register(&mut self, handler: Arc<dyn ToolHandler>) {
        self.handlers.insert(handler.name().to_string(), handler);
    }

    /// Looks up an *enabled* tool by name.
    ///
    /// Returns `None` when no tool is registered under `name` or when it has been
    /// disabled via [`disable`](Self::disable) / [`toggle`](Self::toggle).
    /// Use [`get_raw`](Self::get_raw) to bypass the disabled check.
    pub fn get(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
        if self.is_disabled(name) {
            return None;
        }
        self.handlers.get(name).cloned()
    }

    /// Looks up a tool by name regardless of whether it is disabled.
    pub fn get_raw(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
        self.handlers.get(name).cloned()
    }

    /// Iterates over the names of all registered tools, including disabled ones.
    ///
    /// Iteration order is unspecified (the handlers are stored in a `HashMap`).
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.handlers.keys().map(String::as_str)
    }

    /// Number of registered tools; disabled tools are still counted.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Returns `true` when no tools are registered (disabling does not unregister).
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }

    /// Locks the disabled-name set shared by all clones of this registry.
    ///
    /// A poisoned mutex is recovered instead of being treated as "nothing is
    /// disabled": every critical section in this type is a single
    /// `contains`/`insert`/`remove` on the set, so a panic while the guard was held
    /// cannot leave it half-updated. Failing open here would silently re-enable
    /// every tool the user had disabled.
    fn disabled_set(&self) -> std::sync::MutexGuard<'_, HashSet<String>> {
        self.disabled
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns `true` if `name` is currently in the disabled set.
    ///
    /// The name does not have to correspond to a registered tool.
    pub fn is_disabled(&self, name: &str) -> bool {
        self.disabled_set().contains(name)
    }

    /// Hides `name` from [`get`](Self::get) for this registry and all its clones.
    ///
    /// Disabling an already-disabled or unregistered name is allowed and harmless.
    pub fn disable(&self, name: &str) {
        self.disabled_set().insert(name.to_string());
    }

    /// Removes `name` from the disabled set, making it visible to
    /// [`get`](Self::get) again (for this registry and all its clones).
    pub fn enable(&self, name: &str) {
        self.disabled_set().remove(name);
    }

    /// Flips the disabled state of `name` and returns the new state:
    /// `true` if the tool is now disabled, `false` if it is now enabled.
    pub fn toggle(&self, name: &str) -> bool {
        let mut disabled = self.disabled_set();
        // `remove` reports whether the name was present (i.e. disabled); if it was
        // not, disable it now. One lookup instead of `contains` + `remove`.
        if disabled.remove(name) {
            false
        } else {
            disabled.insert(name.to_string());
            true
        }
    }

    /// Builds a registry containing every built-in tool (filesystem, bash, git,
    /// web), with none disabled.
    pub fn default_set() -> Self {
        let mut r = Self::new();
        r.register(Arc::new(FsRead));
        r.register(Arc::new(FsWrite));
        r.register(Arc::new(FsEdit));
        r.register(Arc::new(FsGlob));
        r.register(Arc::new(FsGrep));
        r.register(Arc::new(Bash));
        r.register(Arc::new(GitStatus));
        r.register(Arc::new(GitDiff));
        r.register(Arc::new(GitLog));
        r.register(Arc::new(GitShow));
        r.register(Arc::new(WebFetch));
        r.register(Arc::new(WebSearch));
        r
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in tool name, in sorted order.
    const V1_TOOLS: [&str; 12] = [
        "bash",
        "fs_edit",
        "fs_glob",
        "fs_grep",
        "fs_read",
        "fs_write",
        "git_diff",
        "git_log",
        "git_show",
        "git_status",
        "web_fetch",
        "web_search",
    ];

    #[test]
    fn default_set_registers_every_v1_tool() {
        let registry = ToolRegistry::default_set();
        let mut registered: Vec<&str> = registry.names().collect();
        registered.sort_unstable();
        assert_eq!(registered, V1_TOOLS);
        assert_eq!(registry.len(), V1_TOOLS.len());
    }

    #[test]
    fn get_returns_none_for_unknown_tool() {
        let registry = ToolRegistry::default_set();
        assert!(registry.get("bash").is_some());
        assert!(registry.get("does_not_exist").is_none());
    }

    #[test]
    fn disable_hides_tool_from_get_but_not_from_names() {
        let registry = ToolRegistry::default_set();
        registry.disable("bash");

        assert!(registry.is_disabled("bash"));
        assert!(registry.get("bash").is_none());
        assert!(registry.get_raw("bash").is_some());
        assert!(registry.names().any(|name| name == "bash"));
    }

    #[test]
    fn enable_after_disable_restores_visibility() {
        let registry = ToolRegistry::default_set();
        registry.disable("bash");
        registry.enable("bash");

        assert!(!registry.is_disabled("bash"));
        assert!(registry.get("bash").is_some());
    }

    #[test]
    fn toggle_flips_and_returns_new_state() {
        let registry = ToolRegistry::default_set();

        // First toggle disables the tool and reports `true` ...
        assert!(registry.toggle("bash"));
        assert!(registry.is_disabled("bash"));

        // ... the second re-enables it and reports `false`.
        assert!(!registry.toggle("bash"));
        assert!(!registry.is_disabled("bash"));
    }

    #[test]
    fn disabled_state_is_shared_across_clones() {
        let original = ToolRegistry::default_set();
        let copy = original.clone();
        original.disable("bash");

        assert!(copy.is_disabled("bash"));
        assert!(copy.get("bash").is_none());
    }
}