use super::device::DeviceRegistry;
use super::gpio::gpio_tools;
use super::loader::scan_plugin_dir;
use crate::tools::traits::{Tool, ToolResult};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
/// Errors returned by [`ToolRegistry::dispatch`].
#[derive(Debug, Error)]
pub enum ToolError {
/// No tool is registered under the requested name; carries that name.
#[error("unknown tool: '{0}'")]
UnknownTool(String),
/// The tool's `execute` call returned an error; carries that error's
/// string form.
#[error("tool execution failed: {0}")]
ExecutionFailed(String),
}
/// Name-indexed collection of tools, together with the shared device
/// registry handle it was loaded with.
pub struct ToolRegistry {
/// Registered tools, keyed by the name they were registered under.
tools: HashMap<String, Box<dyn Tool>>,
/// Shared handle to the device registry passed to [`ToolRegistry::load`].
device_registry: Arc<RwLock<DeviceRegistry>>,
}
impl ToolRegistry {
    /// Builds the registry: built-in tools first, then every plugin returned
    /// by [`scan_plugin_dir`]. Each registration is logged to stdout, followed
    /// by a summary of the devices known to `devices`.
    ///
    /// The GPIO tools are always registered. With the `hardware` feature
    /// enabled, the Pico flash tool and the device-code tools are registered
    /// as well; the Aardvark tools and the datasheet tool are added only when
    /// `DeviceRegistry::has_aardvark` returns true.
    ///
    /// # Errors
    ///
    /// Returns an error if two built-in tools share a name, or if a plugin's
    /// name collides with a tool that is already registered.
    pub async fn load(devices: Arc<RwLock<DeviceRegistry>>) -> anyhow::Result<Self> {
        let mut tools: HashMap<String, Box<dyn Tool>> = HashMap::new();

        for tool in gpio_tools(devices.clone()) {
            Self::register_builtin(&mut tools, tool)?;
        }

        #[cfg(feature = "hardware")]
        {
            Self::register_builtin(
                &mut tools,
                Box::new(super::pico_flash::PicoFlashTool::new(devices.clone())),
            )?;

            for tool in super::pico_code::device_code_tools(devices.clone()) {
                Self::register_builtin(&mut tools, tool)?;
            }

            // The read guard is a temporary dropped at the end of this
            // statement, so the lock is not held while tools are registered.
            let has_aardvark = devices.read().await.has_aardvark();
            if has_aardvark {
                for tool in super::aardvark_tools::aardvark_tools(devices.clone()) {
                    Self::register_builtin(&mut tools, tool)?;
                }
                Self::register_builtin(
                    &mut tools,
                    Box::new(super::datasheet::DatasheetTool::new()),
                )?;
            }
        }

        // Plugins are registered after built-ins, so a name clash is always
        // reported against the plugin.
        for plugin in scan_plugin_dir() {
            if tools.contains_key(&plugin.name) {
                anyhow::bail!(
                    "duplicate tool name: plugin '{}' conflicts with an existing tool",
                    plugin.name
                );
            }
            println!(
                "[registry] loaded plugin: {} (v{})",
                plugin.name, plugin.version
            );
            tools.insert(plugin.name, plugin.tool);
        }
        println!("[registry] {} tools available", tools.len());

        // Print one line per device alias, in sorted order so the output is
        // stable between runs.
        {
            let reg = devices.read().await;
            let mut aliases = reg.aliases();
            aliases.sort_unstable();
            for alias in aliases {
                if let Some(device) = reg.get_device(alias) {
                    let port = device.port().unwrap_or("(native)");
                    println!("[registry] {} ready → {}", alias, port);
                }
            }
        }

        Ok(Self {
            tools,
            device_registry: devices,
        })
    }

    /// Registers one built-in tool under its own `name()` and logs it.
    ///
    /// Fails, leaving `tools` untouched, if that name is already taken.
    fn register_builtin(
        tools: &mut HashMap<String, Box<dyn Tool>>,
        tool: Box<dyn Tool>,
    ) -> anyhow::Result<()> {
        let name = tool.name().to_string();
        if tools.contains_key(&name) {
            anyhow::bail!("duplicate built-in tool name: '{}'", name);
        }
        println!("[registry] loaded built-in: {}", name);
        tools.insert(name, tool);
        Ok(())
    }

    /// Returns one JSON object per tool with the keys `name`, `description`
    /// and `parameters`, sorted by `name`.
    pub fn schemas(&self) -> Vec<serde_json::Value> {
        let mut schemas: Vec<serde_json::Value> = self
            .tools
            .values()
            .map(|tool| {
                serde_json::json!({
                    "name": tool.name(),
                    "description": tool.description(),
                    "parameters": tool.parameters_schema(),
                })
            })
            .collect();
        // HashMap iteration order is arbitrary; sort so callers see a stable
        // order.
        schemas.sort_by(|a, b| {
            a["name"]
                .as_str()
                .unwrap_or("")
                .cmp(b["name"].as_str().unwrap_or(""))
        });
        schemas
    }

    /// Looks up the tool registered as `name` and executes it with `args`.
    ///
    /// # Errors
    ///
    /// * [`ToolError::UnknownTool`] if no tool is registered under `name`.
    /// * [`ToolError::ExecutionFailed`] if the tool's `execute` call returns
    ///   an error; the payload is that error's string form.
    pub async fn dispatch(
        &self,
        name: &str,
        args: serde_json::Value,
    ) -> Result<ToolResult, ToolError> {
        let tool = self
            .tools
            .get(name)
            .ok_or_else(|| ToolError::UnknownTool(name.to_string()))?;
        tool.execute(args)
            .await
            .map_err(|e| ToolError::ExecutionFailed(e.to_string()))
    }

    /// Returns the registered tool names in sorted order.
    pub fn list(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.tools.keys().map(|s| s.as_str()).collect();
        names.sort_unstable();
        names
    }

    /// Returns the number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` when no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Returns another handle to the shared device registry.
    pub fn device_registry(&self) -> Arc<RwLock<DeviceRegistry>> {
        Arc::clone(&self.device_registry)
    }

    /// Consumes the registry and returns its tools ordered by the name they
    /// were registered under.
    pub fn into_tools(self) -> Vec<Box<dyn Tool>> {
        let mut pairs: Vec<(String, Box<dyn Tool>)> = self.tools.into_iter().collect();
        // Map keys are unique, so an unstable sort yields the same order as a
        // stable one.
        pairs.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        pairs.into_iter().map(|(_, tool)| tool).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh device registry wrapped the way `ToolRegistry::load`
    /// expects it.
    fn empty_device_registry() -> Arc<RwLock<DeviceRegistry>> {
        Arc::new(RwLock::new(DeviceRegistry::new()))
    }

    /// Loads a registry on top of a fresh device registry, panicking if
    /// loading fails.
    async fn loaded_registry() -> ToolRegistry {
        ToolRegistry::load(empty_device_registry())
            .await
            .expect("load failed")
    }

    #[tokio::test]
    async fn load_registers_builtin_gpio_tools() {
        let registry = loaded_registry().await;
        let tool_names = registry.list();

        assert!(
            tool_names.contains(&"gpio_write"),
            "gpio_write missing; got: {:?}",
            tool_names
        );
        assert!(
            tool_names.contains(&"gpio_read"),
            "gpio_read missing; got: {:?}",
            tool_names
        );
        assert!(registry.len() >= 2);
    }

    #[cfg(feature = "hardware")]
    #[tokio::test]
    async fn hardware_feature_registers_all_six_tools() {
        let registry = loaded_registry().await;
        let tool_names = registry.list();

        let expected = [
            "device_exec",
            "device_read_code",
            "device_write_code",
            "gpio_read",
            "gpio_write",
            "pico_flash",
        ];
        for wanted in expected.iter() {
            assert!(
                tool_names.contains(wanted),
                "expected tool '{}' missing; got: {:?}",
                wanted,
                tool_names
            );
        }

        let count = registry.len();
        assert_eq!(
            count, 6,
            "expected exactly 6 built-in tools, got {} (names: {:?})",
            count, tool_names
        );
    }

    #[tokio::test]
    async fn schemas_returns_valid_json_schema_array() {
        let registry = loaded_registry().await;
        let all_schemas = registry.schemas();

        assert!(!all_schemas.is_empty());
        for entry in all_schemas.iter() {
            assert!(entry["name"].is_string(), "name missing in schema");
            assert!(entry["description"].is_string(), "description missing");
            assert!(
                entry["parameters"]["type"] == "object",
                "parameters.type should be object"
            );
        }
    }

    #[tokio::test]
    async fn schemas_are_sorted_by_name() {
        let registry = loaded_registry().await;
        let all_schemas = registry.schemas();

        let mut actual: Vec<&str> = Vec::with_capacity(all_schemas.len());
        for entry in &all_schemas {
            actual.push(entry["name"].as_str().unwrap_or(""));
        }
        let mut expected = actual.clone();
        expected.sort_unstable();

        assert_eq!(actual, expected, "schemas not sorted by name");
    }

    #[tokio::test]
    async fn dispatch_unknown_tool_returns_error() {
        let registry = loaded_registry().await;
        let outcome = registry
            .dispatch("nonexistent_tool", serde_json::json!({}))
            .await;

        match outcome {
            Err(ToolError::UnknownTool(missing)) => assert_eq!(missing, "nonexistent_tool"),
            other => panic!("expected UnknownTool, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn list_returns_sorted_tool_names() {
        let registry = loaded_registry().await;
        let tool_names = registry.list();

        let mut expected = tool_names.clone();
        expected.sort_unstable();

        assert_eq!(
            tool_names, expected,
            "list() should return sorted names; got: {:?}",
            tool_names
        );
    }

    #[test]
    fn tool_error_display() {
        let unknown = ToolError::UnknownTool("bad_tool".to_string());
        assert_eq!(unknown.to_string(), "unknown tool: 'bad_tool'");

        let failed = ToolError::ExecutionFailed("oops".to_string());
        assert_eq!(failed.to_string(), "tool execution failed: oops");
    }
}