use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
use super::types::Executable;
/// Errors produced by [`ToolRegistry`] operations.
#[derive(Error, Debug)]
pub enum RegistryError {
/// A tool with the contained name is already present in the registry.
///
/// The name is rendered with `Debug` formatting (`{0:?}`), so it appears
/// quoted in the error message.
#[error("Tool with name {0:?} already exists")]
DuplicateTool(String),
}
impl From<RegistryError> for String {
fn from(err: RegistryError) -> Self {
err.to_string()
}
}
/// Registry that stores shared [`Executable`] tool handles by name.
pub struct ToolRegistry {
// Keyed by the string each tool reports from `name()` at registration time;
// wrapped in a tokio `RwLock` so the `&self` methods can read and mutate it.
tools: RwLock<HashMap<String, Arc<dyn Executable>>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `tool` under the name returned by its `name()` method.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::DuplicateTool`] when a tool with the same name
    /// is already registered; the existing entry is left untouched.
    pub async fn register(&self, tool: Arc<dyn Executable>) -> Result<(), RegistryError> {
        let name = tool.name().to_string();
        let mut tools = self.tools.write().await;
        // The entry API checks and inserts with a single hash lookup.
        match tools.entry(name) {
            std::collections::hash_map::Entry::Occupied(existing) => {
                return Err(RegistryError::DuplicateTool(existing.key().clone()));
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(tool);
            }
        }
        Ok(())
    }

    /// Returns a shared handle to the tool registered as `name`, or `None` if
    /// no such tool exists.
    pub async fn get(&self, name: &str) -> Option<Arc<dyn Executable>> {
        let tools = self.tools.read().await;
        tools.get(name).cloned()
    }

    /// Returns `true` if a tool is registered under `name`.
    pub async fn has(&self, name: &str) -> bool {
        let tools = self.tools.read().await;
        tools.contains_key(name)
    }

    /// Removes the tool registered under `name`; does nothing if it is absent.
    pub async fn remove(&self, name: &str) {
        let mut tools = self.tools.write().await;
        tools.remove(name);
    }

    /// Returns the names of all registered tools.
    ///
    /// The order follows `HashMap` iteration order and is therefore
    /// unspecified.
    pub async fn list(&self) -> Vec<String> {
        let tools = self.tools.read().await;
        tools.keys().cloned().collect()
    }

    /// Returns shared handles to all registered tools, in unspecified order.
    pub async fn get_all(&self) -> Vec<Arc<dyn Executable>> {
        let tools = self.tools.read().await;
        tools.values().cloned().collect()
    }

    /// Returns the number of registered tools.
    pub async fn len(&self) -> usize {
        let tools = self.tools.read().await;
        tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub async fn is_empty(&self) -> bool {
        let tools = self.tools.read().await;
        tools.is_empty()
    }

    /// Calls `cleanup_session(session_id)` on every registered tool, one after
    /// another.
    ///
    /// The set of tools is snapshotted first and the registry lock is released
    /// before any tool is awaited. Holding the read guard across those awaits
    /// would block `register`/`remove` for the whole cleanup and could
    /// deadlock if a tool's cleanup re-entered the registry for a write.
    pub async fn cleanup_session(&self, session_id: i64) {
        // `get_all` clones the `Arc` handles and drops its read guard on return.
        let snapshot = self.get_all().await;
        for tool in snapshot {
            tool.cleanup_session(session_id).await;
        }
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{ToolContext, ToolType};
    use std::future::Future;
    use std::pin::Pin;

    /// Minimal `Executable` whose only configurable property is its name.
    struct MockTool {
        name: String,
    }

    impl Executable for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async { Ok("mock result".to_string()) })
        }
    }

    /// Builds a shared mock tool registered under `name`.
    fn mock_tool(name: &str) -> Arc<dyn Executable> {
        Arc::new(MockTool {
            name: name.to_string(),
        })
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("test_tool")).await.unwrap();

        let retrieved = registry
            .get("test_tool")
            .await
            .expect("registered tool should be retrievable");
        assert_eq!(retrieved.name(), "test_tool");

        // Unknown names must not resolve to anything.
        assert!(registry.get("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("test_tool")).await.unwrap();

        // Pin the exact variant and payload, not just "some error".
        match registry.register(mock_tool("test_tool")).await {
            Err(RegistryError::DuplicateTool(name)) => assert_eq!(name, "test_tool"),
            other => panic!("expected DuplicateTool error, got {:?}", other),
        }

        // The rejected registration must not add a second entry.
        assert_eq!(registry.len().await, 1);
    }

    #[test]
    fn test_duplicate_error_message() {
        let err = RegistryError::DuplicateTool("test_tool".to_string());
        let expected = "Tool with name \"test_tool\" already exists";
        assert_eq!(err.to_string(), expected);
        assert_eq!(String::from(err), expected);
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let registry = ToolRegistry::new();
        registry.register(mock_tool("test_tool")).await.unwrap();
        assert!(registry.has("test_tool").await);

        let names = registry.list().await;
        assert_eq!(names, vec!["test_tool".to_string()]);

        registry.remove("test_tool").await;
        assert!(!registry.has("test_tool").await);
        assert!(registry.list().await.is_empty());

        // Removing an absent tool is a no-op.
        registry.remove("test_tool").await;
        assert!(registry.is_empty().await);
    }

    #[tokio::test]
    async fn test_len_is_empty_and_get_all() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty().await);
        assert_eq!(registry.len().await, 0);

        registry.register(mock_tool("alpha")).await.unwrap();
        registry.register(mock_tool("beta")).await.unwrap();
        assert!(!registry.is_empty().await);
        assert_eq!(registry.len().await, 2);

        // Map iteration order is unspecified, so compare sorted names.
        let mut names: Vec<String> = registry
            .get_all()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();
        names.sort();
        assert_eq!(names, vec!["alpha".to_string(), "beta".to_string()]);
    }
}