1use car_ir::ToolSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum ToolPermission {
16 Allow,
18 AskUser,
20 Deny,
22}
23
24impl Default for ToolPermission {
25 fn default() -> Self {
26 Self::AskUser
27 }
28}
29
30#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum ToolSource {
34 Builtin,
36 UserDefined,
38 Subprocess,
40 Mcp { server_name: String },
42}
43
44#[derive(Debug, Clone)]
46pub struct ToolEntry {
47 pub schema: ToolSchema,
49 pub permission: ToolPermission,
51 pub source: ToolSource,
53 pub side_effects: bool,
55 pub category: Option<String>,
57}
58
59impl ToolEntry {
60 pub fn new(schema: ToolSchema) -> Self {
61 Self {
62 schema,
63 permission: ToolPermission::default(),
64 source: ToolSource::UserDefined,
65 side_effects: true,
66 category: None,
67 }
68 }
69
70 pub fn builtin(schema: ToolSchema) -> Self {
71 Self {
72 permission: ToolPermission::Allow,
73 source: ToolSource::Builtin,
74 side_effects: false,
75 category: None,
76 schema,
77 }
78 }
79
80 pub fn with_permission(mut self, perm: ToolPermission) -> Self {
81 self.permission = perm;
82 self
83 }
84
85 pub fn with_source(mut self, source: ToolSource) -> Self {
86 self.source = source;
87 self
88 }
89
90 pub fn with_side_effects(mut self, side_effects: bool) -> Self {
91 self.side_effects = side_effects;
92 self
93 }
94
95 pub fn with_category(mut self, category: &str) -> Self {
96 self.category = Some(category.to_string());
97 self
98 }
99}
100
/// A single problem found while validating the registry.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// printed, boxed as `dyn Error`, or propagated with `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RegistryValidationError {
    /// Registry key of the offending tool.
    pub tool_name: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for RegistryValidationError {
    /// Formats the error as `tool '<tool_name>': <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "tool '{}': {}", self.tool_name, self.message)
    }
}

impl std::error::Error for RegistryValidationError {}
107
108pub struct ToolRegistry {
110 entries: RwLock<HashMap<String, ToolEntry>>,
111}
112
113impl ToolRegistry {
114 pub fn new() -> Self {
115 Self {
116 entries: RwLock::new(HashMap::new()),
117 }
118 }
119
120 pub async fn register(&self, entry: ToolEntry) {
122 let name = entry.schema.name.clone();
123 self.entries.write().await.insert(name, entry);
124 }
125
126 pub async fn get(&self, name: &str) -> Option<ToolEntry> {
128 self.entries.read().await.get(name).cloned()
129 }
130
131 pub async fn contains(&self, name: &str) -> bool {
133 self.entries.read().await.contains_key(name)
134 }
135
136 pub async fn remove(&self, name: &str) -> Option<ToolEntry> {
138 self.entries.write().await.remove(name)
139 }
140
141 pub async fn names(&self) -> Vec<String> {
143 self.entries.read().await.keys().cloned().collect()
144 }
145
146 pub async fn entries(&self) -> Vec<ToolEntry> {
148 self.entries.read().await.values().cloned().collect()
149 }
150
151 pub async fn schemas(&self) -> Vec<ToolSchema> {
153 self.entries
154 .read()
155 .await
156 .values()
157 .map(|e| e.schema.clone())
158 .collect()
159 }
160
161 pub async fn allowed_schemas(&self) -> Vec<ToolSchema> {
163 self.entries
164 .read()
165 .await
166 .values()
167 .filter(|e| e.permission != ToolPermission::Deny)
168 .map(|e| e.schema.clone())
169 .collect()
170 }
171
172 pub async fn by_source(&self, source_match: &ToolSource) -> Vec<ToolEntry> {
174 self.entries
175 .read()
176 .await
177 .values()
178 .filter(|e| std::mem::discriminant(&e.source) == std::mem::discriminant(source_match))
179 .cloned()
180 .collect()
181 }
182
183 pub async fn by_category(&self, category: &str) -> Vec<ToolEntry> {
185 self.entries
186 .read()
187 .await
188 .values()
189 .filter(|e| e.category.as_deref() == Some(category))
190 .cloned()
191 .collect()
192 }
193
194 pub async fn validate(&self) -> Vec<RegistryValidationError> {
196 let entries = self.entries.read().await;
197 let mut errors = Vec::new();
198 for (name, entry) in entries.iter() {
199 if entry.schema.name != *name {
200 errors.push(RegistryValidationError {
201 tool_name: name.clone(),
202 message: format!(
203 "schema name '{}' doesn't match registry key '{}'",
204 entry.schema.name, name
205 ),
206 });
207 }
208 if entry.schema.description.is_empty() {
209 errors.push(RegistryValidationError {
210 tool_name: name.clone(),
211 message: "missing description".to_string(),
212 });
213 }
214 }
215 errors
216 }
217
218 pub async fn to_schema_map(&self) -> HashMap<String, ToolSchema> {
220 self.entries
221 .read()
222 .await
223 .iter()
224 .map(|(k, v)| (k.clone(), v.schema.clone()))
225 .collect()
226 }
227
228 pub async fn len(&self) -> usize {
230 self.entries.read().await.len()
231 }
232
233 pub async fn is_empty(&self) -> bool {
234 self.entries.read().await.is_empty()
235 }
236}
237
238impl Default for ToolRegistry {
239 fn default() -> Self {
240 Self::new()
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal schema called `name` with a non-empty description.
    fn test_schema(name: &str) -> ToolSchema {
        let description = format!("{} tool", name);
        ToolSchema {
            name: name.to_string(),
            description,
            parameters: serde_json::json!({"type": "object"}),
            returns: None,
            idempotent: false,
            cache_ttl_secs: None,
            rate_limit: None,
        }
    }

    /// A registered entry can be fetched back with its builder settings intact.
    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry
            .register(
                ToolEntry::new(test_schema("search"))
                    .with_permission(ToolPermission::Allow)
                    .with_category("network"),
            )
            .await;

        let fetched = registry.get("search").await.unwrap();
        assert_eq!(fetched.schema.name, "search");
        assert_eq!(fetched.permission, ToolPermission::Allow);
        assert_eq!(fetched.category.as_deref(), Some("network"));
    }

    /// `allowed_schemas` keeps `Allow` and `AskUser` tools and drops `Deny` ones.
    #[tokio::test]
    async fn test_allowed_schemas_excludes_denied() {
        let registry = ToolRegistry::new();

        let readable = ToolEntry::new(test_schema("read")).with_permission(ToolPermission::Allow);
        registry.register(readable).await;

        let forbidden = ToolEntry::new(test_schema("delete")).with_permission(ToolPermission::Deny);
        registry.register(forbidden).await;

        let guarded = ToolEntry::new(test_schema("write")).with_permission(ToolPermission::AskUser);
        registry.register(guarded).await;

        let visible = registry.allowed_schemas().await;
        assert_eq!(visible.len(), 2);
        assert!(!visible.iter().any(|schema| schema.name == "delete"));
    }

    /// An empty description is reported as exactly one validation error.
    #[tokio::test]
    async fn test_validation() {
        let registry = ToolRegistry::new();
        let mut undocumented = test_schema("good");
        undocumented.description = String::new();
        registry.register(ToolEntry::new(undocumented)).await;

        let problems = registry.validate().await;
        assert_eq!(problems.len(), 1);
        assert!(problems[0].message.contains("missing description"));
    }

    /// `by_source` returns only the entries of the requested source variant.
    #[tokio::test]
    async fn test_by_source() {
        let registry = ToolRegistry::new();
        registry
            .register(ToolEntry::builtin(test_schema("infer")))
            .await;
        registry
            .register(ToolEntry::new(test_schema("search")))
            .await;

        let builtin_only = registry.by_source(&ToolSource::Builtin).await;
        assert_eq!(builtin_only.len(), 1);
        assert_eq!(builtin_only[0].schema.name, "infer");
    }
}