use car_ir::ToolSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::RwLock;
/// Permission level attached to a registered tool.
///
/// Serialized in `snake_case`: `"allow"`, `"ask_user"`, `"deny"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolPermission {
    /// The tool is permitted.
    Allow,
    /// The tool requires user approval. This is the default, so entries built with
    /// `ToolEntry::new` ask before running unless a permission is set explicitly.
    /// NOTE(review): the prompt itself is enforced by callers, not by the registry.
    #[default]
    AskUser,
    /// The tool is denied; `ToolRegistry::allowed_schemas` omits it.
    Deny,
}
/// Origin of a registered tool.
///
/// Serialized in `snake_case`; the `Mcp` variant carries the name of the server that
/// exposes the tool.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolSource {
    /// Built-in tool; this is the source `ToolEntry::builtin` assigns.
    Builtin,
    /// User-defined tool; this is the source `ToolEntry::new` assigns.
    UserDefined,
    /// Tool backed by a subprocess (presumably; not constructed in this file).
    Subprocess,
    /// Tool exposed by an MCP server.
    Mcp {
        /// Name of the server exposing the tool.
        server_name: String,
    },
}
/// A tool schema together with the registry metadata attached to it.
#[derive(Clone, Debug)]
pub struct ToolEntry {
    /// The tool's schema; its `name` becomes the registry key on registration.
    pub schema: ToolSchema,
    /// Permission level; `Deny` entries are left out of the allowed-schema listing.
    pub permission: ToolPermission,
    /// Where the tool comes from.
    pub source: ToolSource,
    /// Whether invoking the tool may have side effects (`true` for `ToolEntry::new`).
    pub side_effects: bool,
    /// Optional free-form grouping label used by category lookups.
    pub category: Option<String>,
}
impl ToolEntry {
    /// Creates an entry for a user-defined tool with conservative defaults.
    ///
    /// The permission is `ToolPermission::default()` (`AskUser`), the source is
    /// `ToolSource::UserDefined`, the tool is assumed to have side effects, and no
    /// category is set.
    #[must_use]
    pub fn new(schema: ToolSchema) -> Self {
        Self {
            schema,
            permission: ToolPermission::default(),
            source: ToolSource::UserDefined,
            side_effects: true,
            category: None,
        }
    }

    /// Creates an entry for a built-in tool: allowed, side-effect free, no category.
    #[must_use]
    pub fn builtin(schema: ToolSchema) -> Self {
        Self {
            permission: ToolPermission::Allow,
            source: ToolSource::Builtin,
            side_effects: false,
            ..Self::new(schema)
        }
    }

    /// Returns the entry with its permission replaced by `perm`.
    #[must_use]
    pub fn with_permission(mut self, perm: ToolPermission) -> Self {
        self.permission = perm;
        self
    }

    /// Returns the entry with its source replaced by `source`.
    #[must_use]
    pub fn with_source(mut self, source: ToolSource) -> Self {
        self.source = source;
        self
    }

    /// Returns the entry with its side-effect flag set to `side_effects`.
    #[must_use]
    pub fn with_side_effects(mut self, side_effects: bool) -> Self {
        self.side_effects = side_effects;
        self
    }

    /// Returns the entry tagged with `category`.
    ///
    /// Accepts anything convertible into a `String`, so an owned `String` is moved in
    /// rather than copied; `&str` arguments work exactly as before.
    #[must_use]
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
        self.category = Some(category.into());
        self
    }
}
/// One problem found by `ToolRegistry::validate` in a registered tool.
#[derive(Debug, Clone)]
pub struct RegistryValidationError {
    /// Registry key of the offending tool.
    pub tool_name: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for RegistryValidationError {
    /// Formats as `tool '<tool_name>': <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "tool '{}': {}", self.tool_name, self.message)
    }
}

// Lets validation errors flow through `?`, `Box<dyn Error>`, and error-reporting crates.
impl std::error::Error for RegistryValidationError {}
/// Registry of tools keyed by name, guarded by a `tokio::sync::RwLock` so it can be
/// shared and queried from async code.
pub struct ToolRegistry {
    /// Registry key (the schema name at registration time) -> entry.
    entries: RwLock<HashMap<String, ToolEntry>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
        }
    }

    /// Iterates `map` in ascending key order.
    ///
    /// `HashMap` iteration order is unspecified and changes between runs, so every
    /// listing method goes through here. That keeps tool lists (e.g. schemas handed to
    /// a model) and validation reports reproducible.
    fn sorted(map: &HashMap<String, ToolEntry>) -> impl Iterator<Item = (&String, &ToolEntry)> {
        let mut pairs: Vec<(&String, &ToolEntry)> = map.iter().collect();
        // Keys are unique, so an unstable sort is still deterministic.
        pairs.sort_unstable_by(|a, b| a.0.cmp(b.0));
        pairs.into_iter()
    }

    /// Registers `entry` under its schema name, replacing any entry already registered
    /// under that name.
    pub async fn register(&self, entry: ToolEntry) {
        let name = entry.schema.name.clone();
        self.entries.write().await.insert(name, entry);
    }

    /// Returns a clone of the entry registered as `name`, if any.
    pub async fn get(&self, name: &str) -> Option<ToolEntry> {
        self.entries.read().await.get(name).cloned()
    }

    /// Returns `true` if a tool is registered as `name`.
    pub async fn contains(&self, name: &str) -> bool {
        self.entries.read().await.contains_key(name)
    }

    /// Removes and returns the entry registered as `name`, if any.
    pub async fn remove(&self, name: &str) -> Option<ToolEntry> {
        self.entries.write().await.remove(name)
    }

    /// Returns every registered tool name, sorted ascending.
    pub async fn names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.entries.read().await.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Returns clones of every entry, ordered by tool name.
    pub async fn entries(&self) -> Vec<ToolEntry> {
        let map = self.entries.read().await;
        Self::sorted(&map).map(|(_, entry)| entry.clone()).collect()
    }

    /// Returns clones of every schema, ordered by tool name.
    pub async fn schemas(&self) -> Vec<ToolSchema> {
        let map = self.entries.read().await;
        Self::sorted(&map)
            .map(|(_, entry)| entry.schema.clone())
            .collect()
    }

    /// Returns the schemas of every tool not marked `ToolPermission::Deny`, ordered by
    /// tool name. `AskUser` tools are included.
    pub async fn allowed_schemas(&self) -> Vec<ToolSchema> {
        let map = self.entries.read().await;
        Self::sorted(&map)
            .filter(|(_, entry)| entry.permission != ToolPermission::Deny)
            .map(|(_, entry)| entry.schema.clone())
            .collect()
    }

    /// Returns entries whose source is the same *kind* as `source_match`, ordered by name.
    ///
    /// Only the variant is compared: any `ToolSource::Mcp { .. }` matches every MCP tool,
    /// whatever its `server_name`.
    pub async fn by_source(&self, source_match: &ToolSource) -> Vec<ToolEntry> {
        let map = self.entries.read().await;
        let wanted = std::mem::discriminant(source_match);
        Self::sorted(&map)
            .filter(|(_, entry)| std::mem::discriminant(&entry.source) == wanted)
            .map(|(_, entry)| entry.clone())
            .collect()
    }

    /// Returns entries whose category equals `category`, ordered by name.
    pub async fn by_category(&self, category: &str) -> Vec<ToolEntry> {
        let map = self.entries.read().await;
        Self::sorted(&map)
            .filter(|(_, entry)| entry.category.as_deref() == Some(category))
            .map(|(_, entry)| entry.clone())
            .collect()
    }

    /// Checks every entry and returns one error per problem found, ordered by tool name.
    ///
    /// Reports a blank (empty or whitespace-only) name, a schema name that differs from
    /// its registry key, and a blank description. An empty result means the registry is
    /// consistent.
    pub async fn validate(&self) -> Vec<RegistryValidationError> {
        let map = self.entries.read().await;
        let mut errors = Vec::new();
        for (key, entry) in Self::sorted(&map) {
            if key.trim().is_empty() {
                errors.push(RegistryValidationError {
                    tool_name: key.clone(),
                    message: "missing name".to_string(),
                });
            }
            // `register` always keys by schema name, so this only fires if the two diverge
            // through some other path; kept as a defensive invariant check.
            if entry.schema.name != *key {
                errors.push(RegistryValidationError {
                    tool_name: key.clone(),
                    message: format!(
                        "schema name '{}' doesn't match registry key '{}'",
                        entry.schema.name, key
                    ),
                });
            }
            if entry.schema.description.trim().is_empty() {
                errors.push(RegistryValidationError {
                    tool_name: key.clone(),
                    message: "missing description".to_string(),
                });
            }
        }
        errors
    }

    /// Returns a map from registry key to a clone of each schema.
    pub async fn to_schema_map(&self) -> HashMap<String, ToolSchema> {
        self.entries
            .read()
            .await
            .iter()
            .map(|(k, v)| (k.clone(), v.schema.clone()))
            .collect()
    }

    /// Returns the number of registered tools.
    pub async fn len(&self) -> usize {
        self.entries.read().await.len()
    }

    /// Returns `true` if no tools are registered.
    pub async fn is_empty(&self) -> bool {
        self.entries.read().await.is_empty()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal schema named `name` with a non-empty description.
    fn test_schema(name: &str) -> ToolSchema {
        ToolSchema {
            name: name.to_string(),
            description: format!("{} tool", name),
            parameters: serde_json::json!({"type": "object"}),
            returns: None,
            idempotent: false,
            cache_ttl_secs: None,
            rate_limit: None,
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let reg = ToolRegistry::new();
        let entry = ToolEntry::new(test_schema("search"))
            .with_permission(ToolPermission::Allow)
            .with_category("network");
        reg.register(entry).await;
        let got = reg.get("search").await.unwrap();
        assert_eq!(got.schema.name, "search");
        assert_eq!(got.permission, ToolPermission::Allow);
        assert_eq!(got.category.as_deref(), Some("network"));
    }

    #[tokio::test]
    async fn test_allowed_schemas_excludes_denied() {
        let reg = ToolRegistry::new();
        reg.register(ToolEntry::new(test_schema("read")).with_permission(ToolPermission::Allow))
            .await;
        reg.register(ToolEntry::new(test_schema("delete")).with_permission(ToolPermission::Deny))
            .await;
        reg.register(ToolEntry::new(test_schema("write")).with_permission(ToolPermission::AskUser))
            .await;
        let allowed = reg.allowed_schemas().await;
        assert_eq!(allowed.len(), 2);
        assert!(allowed.iter().all(|s| s.name != "delete"));
    }

    #[tokio::test]
    async fn test_validation() {
        let reg = ToolRegistry::new();
        let mut bad_schema = test_schema("good");
        bad_schema.description = String::new();
        reg.register(ToolEntry::new(bad_schema)).await;
        let errors = reg.validate().await;
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("missing description"));
    }

    #[tokio::test]
    async fn test_validate_clean_registry() {
        let reg = ToolRegistry::new();
        reg.register(ToolEntry::new(test_schema("search"))).await;
        assert!(reg.validate().await.is_empty());
    }

    #[tokio::test]
    async fn test_by_source() {
        let reg = ToolRegistry::new();
        reg.register(ToolEntry::builtin(test_schema("infer"))).await;
        reg.register(ToolEntry::new(test_schema("search"))).await;
        let builtins = reg.by_source(&ToolSource::Builtin).await;
        assert_eq!(builtins.len(), 1);
        assert_eq!(builtins[0].schema.name, "infer");
    }

    #[tokio::test]
    async fn test_by_source_matches_any_mcp_server() {
        let reg = ToolRegistry::new();
        reg.register(ToolEntry::new(test_schema("a")).with_source(ToolSource::Mcp {
            server_name: "one".to_string(),
        }))
        .await;
        reg.register(ToolEntry::new(test_schema("b")).with_source(ToolSource::Mcp {
            server_name: "two".to_string(),
        }))
        .await;
        reg.register(ToolEntry::new(test_schema("c"))).await;
        // `by_source` compares variants only, so the probe's server name is irrelevant.
        let mcp = reg
            .by_source(&ToolSource::Mcp {
                server_name: "other".to_string(),
            })
            .await;
        assert_eq!(mcp.len(), 2);
    }

    #[tokio::test]
    async fn test_by_category() {
        let reg = ToolRegistry::new();
        reg.register(ToolEntry::new(test_schema("fetch")).with_category("network"))
            .await;
        reg.register(ToolEntry::new(test_schema("read")).with_category("fs"))
            .await;
        reg.register(ToolEntry::new(test_schema("plain"))).await;
        let network = reg.by_category("network").await;
        assert_eq!(network.len(), 1);
        assert_eq!(network[0].schema.name, "fetch");
        assert!(reg.by_category("missing").await.is_empty());
    }

    #[tokio::test]
    async fn test_register_replaces_and_remove() {
        let reg = ToolRegistry::new();
        assert!(reg.is_empty().await);
        reg.register(ToolEntry::new(test_schema("search"))).await;
        reg.register(ToolEntry::new(test_schema("search")).with_permission(ToolPermission::Deny))
            .await;
        assert_eq!(reg.len().await, 1);
        assert_eq!(
            reg.get("search").await.unwrap().permission,
            ToolPermission::Deny
        );
        assert!(reg.remove("search").await.is_some());
        assert!(!reg.contains("search").await);
        assert!(reg.remove("search").await.is_none());
        assert!(reg.is_empty().await);
    }

    #[test]
    fn test_new_entry_defaults() {
        let entry = ToolEntry::new(test_schema("search"));
        assert_eq!(entry.permission, ToolPermission::AskUser);
        assert_eq!(entry.source, ToolSource::UserDefined);
        assert!(entry.side_effects);
        assert!(entry.category.is_none());
    }

    #[test]
    fn test_builtin_entry_defaults() {
        let entry = ToolEntry::builtin(test_schema("infer"));
        assert_eq!(entry.permission, ToolPermission::Allow);
        assert_eq!(entry.source, ToolSource::Builtin);
        assert!(!entry.side_effects);
        assert!(entry.category.is_none());
    }
}