Skip to main content

car_engine/
registry.rs

1//! Canonical tool registry — single source of truth for tool identity.
2//!
3//! A tool is defined once via `ToolEntry` and the registry derives all runtime
4//! behavior: schema for models, executor dispatch, capability classification,
5//! permission defaults, and validation.
6
7use car_ir::ToolSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tokio::sync::RwLock;
11
12/// Permission classification for a tool.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum ToolPermission {
16    /// Always allowed without user approval.
17    Allow,
18    /// Requires explicit user approval before execution.
19    AskUser,
20    /// Always denied.
21    Deny,
22}
23
24impl Default for ToolPermission {
25    fn default() -> Self {
26        Self::AskUser
27    }
28}
29
30/// Source/origin of a tool (for debugging and routing).
31#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum ToolSource {
34    /// Built into the runtime (infer, embed, classify, etc.).
35    Builtin,
36    /// Registered by the caller via in-process executor.
37    UserDefined,
38    /// Subprocess tool via stdin/stdout JSON-RPC.
39    Subprocess,
40    /// Discovered from an MCP server.
41    Mcp { server_name: String },
42}
43
44/// A complete tool definition — canonical identity for a tool.
45#[derive(Debug, Clone)]
46pub struct ToolEntry {
47    /// The tool schema (name, description, parameters, etc.).
48    pub schema: ToolSchema,
49    /// Default permission classification.
50    pub permission: ToolPermission,
51    /// Where the tool comes from.
52    pub source: ToolSource,
53    /// Whether the tool modifies external state (for safety classification).
54    pub side_effects: bool,
55    /// Human-readable category for grouping (e.g., "filesystem", "network", "memory").
56    pub category: Option<String>,
57}
58
59impl ToolEntry {
60    pub fn new(schema: ToolSchema) -> Self {
61        Self {
62            schema,
63            permission: ToolPermission::default(),
64            source: ToolSource::UserDefined,
65            side_effects: true,
66            category: None,
67        }
68    }
69
70    pub fn builtin(schema: ToolSchema) -> Self {
71        Self {
72            permission: ToolPermission::Allow,
73            source: ToolSource::Builtin,
74            side_effects: false,
75            category: None,
76            schema,
77        }
78    }
79
80    pub fn with_permission(mut self, perm: ToolPermission) -> Self {
81        self.permission = perm;
82        self
83    }
84
85    pub fn with_source(mut self, source: ToolSource) -> Self {
86        self.source = source;
87        self
88    }
89
90    pub fn with_side_effects(mut self, side_effects: bool) -> Self {
91        self.side_effects = side_effects;
92        self
93    }
94
95    pub fn with_category(mut self, category: &str) -> Self {
96        self.category = Some(category.to_string());
97        self
98    }
99}
100
/// Validation error when a tool registration is incomplete or inconsistent.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// formatted with `{}`, boxed as `dyn Error`, or propagated with `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RegistryValidationError {
    /// Name of the tool the problem was found on.
    pub tool_name: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for RegistryValidationError {
    /// Formats the error as `tool '<tool_name>': <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "tool '{}': {}", self.tool_name, self.message)
    }
}

impl std::error::Error for RegistryValidationError {}
107
108/// Canonical tool registry — single source of truth.
109pub struct ToolRegistry {
110    entries: RwLock<HashMap<String, ToolEntry>>,
111}
112
113impl ToolRegistry {
114    pub fn new() -> Self {
115        Self {
116            entries: RwLock::new(HashMap::new()),
117        }
118    }
119
120    /// Register a tool. Overwrites if already present.
121    pub async fn register(&self, entry: ToolEntry) {
122        let name = entry.schema.name.clone();
123        self.entries.write().await.insert(name, entry);
124    }
125
126    /// Get a tool entry by name.
127    pub async fn get(&self, name: &str) -> Option<ToolEntry> {
128        self.entries.read().await.get(name).cloned()
129    }
130
131    /// Check if a tool exists.
132    pub async fn contains(&self, name: &str) -> bool {
133        self.entries.read().await.contains_key(name)
134    }
135
136    /// Remove a tool.
137    pub async fn remove(&self, name: &str) -> Option<ToolEntry> {
138        self.entries.write().await.remove(name)
139    }
140
141    /// List all tool names.
142    pub async fn names(&self) -> Vec<String> {
143        self.entries.read().await.keys().cloned().collect()
144    }
145
146    /// List all entries.
147    pub async fn entries(&self) -> Vec<ToolEntry> {
148        self.entries.read().await.values().cloned().collect()
149    }
150
151    /// Get all tool schemas (for model prompt generation).
152    pub async fn schemas(&self) -> Vec<ToolSchema> {
153        self.entries
154            .read()
155            .await
156            .values()
157            .map(|e| e.schema.clone())
158            .collect()
159    }
160
161    /// Get schemas filtered by permission (e.g., only non-denied tools for model).
162    pub async fn allowed_schemas(&self) -> Vec<ToolSchema> {
163        self.entries
164            .read()
165            .await
166            .values()
167            .filter(|e| e.permission != ToolPermission::Deny)
168            .map(|e| e.schema.clone())
169            .collect()
170    }
171
172    /// Get tools by source type.
173    pub async fn by_source(&self, source_match: &ToolSource) -> Vec<ToolEntry> {
174        self.entries
175            .read()
176            .await
177            .values()
178            .filter(|e| std::mem::discriminant(&e.source) == std::mem::discriminant(source_match))
179            .cloned()
180            .collect()
181    }
182
183    /// Get tools by category.
184    pub async fn by_category(&self, category: &str) -> Vec<ToolEntry> {
185        self.entries
186            .read()
187            .await
188            .values()
189            .filter(|e| e.category.as_deref() == Some(category))
190            .cloned()
191            .collect()
192    }
193
194    /// Validate all entries — check for common issues.
195    pub async fn validate(&self) -> Vec<RegistryValidationError> {
196        let entries = self.entries.read().await;
197        let mut errors = Vec::new();
198        for (name, entry) in entries.iter() {
199            if entry.schema.name != *name {
200                errors.push(RegistryValidationError {
201                    tool_name: name.clone(),
202                    message: format!(
203                        "schema name '{}' doesn't match registry key '{}'",
204                        entry.schema.name, name
205                    ),
206                });
207            }
208            if entry.schema.description.is_empty() {
209                errors.push(RegistryValidationError {
210                    tool_name: name.clone(),
211                    message: "missing description".to_string(),
212                });
213            }
214        }
215        errors
216    }
217
218    /// Export the full HashMap of schemas (for backward compat with Runtime.tools).
219    pub async fn to_schema_map(&self) -> HashMap<String, ToolSchema> {
220        self.entries
221            .read()
222            .await
223            .iter()
224            .map(|(k, v)| (k.clone(), v.schema.clone()))
225            .collect()
226    }
227
228    /// Count of registered tools.
229    pub async fn len(&self) -> usize {
230        self.entries.read().await.len()
231    }
232
233    pub async fn is_empty(&self) -> bool {
234        self.entries.read().await.is_empty()
235    }
236}
237
238impl Default for ToolRegistry {
239    fn default() -> Self {
240        Self::new()
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal schema called `name` with a non-empty description.
    fn test_schema(name: &str) -> ToolSchema {
        let description = format!("{} tool", name);
        ToolSchema {
            name: name.to_string(),
            description,
            parameters: serde_json::json!({"type": "object"}),
            returns: None,
            idempotent: false,
            cache_ttl_secs: None,
            rate_limit: None,
        }
    }

    /// A registered entry can be fetched back with its builder overrides intact.
    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry
            .register(
                ToolEntry::new(test_schema("search"))
                    .with_permission(ToolPermission::Allow)
                    .with_category("network"),
            )
            .await;

        let fetched = registry.get("search").await.unwrap();
        assert_eq!(fetched.schema.name, "search");
        assert_eq!(fetched.permission, ToolPermission::Allow);
        assert_eq!(fetched.category.as_deref(), Some("network"));
    }

    /// Denied tools are filtered out; allowed and ask-user tools remain.
    #[tokio::test]
    async fn test_allowed_schemas_excludes_denied() {
        let registry = ToolRegistry::new();
        let fixtures = [
            ("read", ToolPermission::Allow),
            ("delete", ToolPermission::Deny),
            ("write", ToolPermission::AskUser),
        ];
        for &(name, permission) in fixtures.iter() {
            registry
                .register(ToolEntry::new(test_schema(name)).with_permission(permission))
                .await;
        }

        let visible = registry.allowed_schemas().await;
        assert_eq!(visible.len(), 2);
        assert!(!visible.iter().any(|schema| schema.name == "delete"));
    }

    /// An entry with an empty description yields exactly one validation error.
    #[tokio::test]
    async fn test_validation() {
        let registry = ToolRegistry::new();
        let undocumented = ToolSchema {
            description: String::new(),
            ..test_schema("good")
        };
        registry.register(ToolEntry::new(undocumented)).await;

        let problems = registry.validate().await;
        assert_eq!(problems.len(), 1);
        assert!(problems[0].message.contains("missing description"));
    }

    /// Filtering by source returns only the entries of the requested variant.
    #[tokio::test]
    async fn test_by_source() {
        let registry = ToolRegistry::new();
        registry
            .register(ToolEntry::builtin(test_schema("infer")))
            .await;
        registry.register(ToolEntry::new(test_schema("search"))).await;

        let found = registry.by_source(&ToolSource::Builtin).await;
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].schema.name, "infer");
    }
}