1use std::sync::Arc;
8
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12
13use crate::skill::{Skill, SkillId};
14use crate::tool::{Tool, ToolName};
15
/// Serializable summary of a registered skill, as produced by
/// [`SkillRegistry::descriptors`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SkillDescriptor {
    /// Identifier the skill is registered under (the value of its `Skill::id`).
    pub id: SkillId,
    /// Description text reported by the skill's `Skill::description`.
    pub description: String,
}
28
/// Errors produced by the tool/skill registries and by tool and skill execution
/// (both `Tool::invoke` and `Skill::execute` report failures with this type).
#[derive(Debug, Error)]
pub enum KernelError {
    /// Returned by `ToolRegistry::get` when no tool is registered under the name.
    #[error("tool `{0}` not found in registry")]
    ToolNotFound(String),

    /// Returned by `ToolRegistry::get` when the name is outside the registry's
    /// allow-list (checked before existence).
    #[error("tool `{0}` is not authorised for this agent")]
    ToolNotAuthorised(String),

    /// Returned by `SkillRegistry::get` (and so `resolve_chain`) for an unknown id.
    #[error("skill `{0}` not found in registry")]
    SkillNotFound(String),

    /// A tool invocation failed; the payload describes the failure.
    /// NOTE(review): constructed outside this module.
    #[error("tool invocation failed: {0}")]
    ToolFailed(String),

    /// A tool declined the request as not applicable.
    /// NOTE(review): constructed outside this module — confirm exact semantics.
    #[error("tool not applicable: {0}")]
    ToolNotApplicable(String),

    /// Skill execution failed; the payload describes the failure.
    #[error("skill execution failed: {0}")]
    SkillFailed(String),

    /// An argument was rejected as invalid.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// The tool-call normalizer failed.
    /// NOTE(review): constructed outside this module.
    #[error("tool-call normalizer failed: {0}")]
    NormalizerFailed(String),

    /// Tool dispatch was terminated before completion.
    /// NOTE(review): constructed outside this module — confirm who terminates it.
    #[error("tool dispatch terminated: {0}")]
    ToolDispatchTerminated(String),

    /// A budget check or update failed.
    /// NOTE(review): constructed outside this module.
    #[error("budget failed: {0}")]
    BudgetFailed(String),

    /// JSON (de)serialization error. Displays the inner error's message
    /// unchanged (`transparent`), and `#[from]` lets `?` convert
    /// `serde_json::Error` into this variant.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
73
/// Name-keyed collection of tools, optionally restricted to an allow-list.
///
/// Cloning is cheap and clones share the same underlying map (it sits behind
/// an `Arc`), so tools registered through one clone are visible through the
/// others. The default value is an empty, unscoped registry.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Registered tools keyed by their schema name; shared by all clones and
    /// scoped views derived from this registry.
    inner: Arc<DashMap<ToolName, Arc<dyn Tool>>>,
    /// Names this view may access. `None` means every registered tool is
    /// accessible; `Some(set)` limits lookups to names in `set`.
    allowed: Option<Arc<std::collections::HashSet<ToolName>>>,
}
85
86impl ToolRegistry {
87 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn register(&self, tool: Arc<dyn Tool>) {
92 let name = tool.schema().name;
93 self.inner.insert(name, tool);
94 }
95
96 pub fn scoped<I, S>(&self, names: I) -> Self
99 where
100 I: IntoIterator<Item = S>,
101 S: Into<String>,
102 {
103 let allowed: std::collections::HashSet<String> =
104 names.into_iter().map(Into::into).collect();
105 Self {
106 inner: self.inner.clone(),
107 allowed: Some(Arc::new(allowed)),
108 }
109 }
110
111 fn is_authorised(&self, name: &str) -> bool {
112 match &self.allowed {
113 None => true,
114 Some(set) => set.contains(name),
115 }
116 }
117
118 pub fn get(&self, name: &str) -> Result<Arc<dyn Tool>, KernelError> {
119 if !self.is_authorised(name) {
120 return Err(KernelError::ToolNotAuthorised(name.to_string()));
121 }
122 self.inner
123 .get(name)
124 .map(|t| t.clone())
125 .ok_or_else(|| KernelError::ToolNotFound(name.to_string()))
126 }
127
128 pub async fn invoke(
130 &self,
131 name: &str,
132 args: serde_json::Value,
133 ) -> Result<serde_json::Value, KernelError> {
134 let tool = self.get(name)?;
135 tool.invoke(args).await
136 }
137
138 pub fn len(&self) -> usize {
139 match &self.allowed {
140 None => self.inner.len(),
141 Some(set) => self.inner.iter().filter(|e| set.contains(e.key())).count(),
142 }
143 }
144
145 pub fn is_empty(&self) -> bool {
146 self.len() == 0
147 }
148
149 pub fn schemas(&self) -> Vec<crate::tool::ToolSchema> {
153 let mut schemas: Vec<_> = self
154 .inner
155 .iter()
156 .filter(|e| self.is_authorised(e.key()))
157 .map(|e| e.value().schema())
158 .collect();
159 schemas.sort_by(|left, right| left.name.cmp(&right.name));
160 schemas
161 }
162
163 pub fn descriptors(&self) -> Vec<crate::tool::ToolSchema> {
170 self.schemas()
171 }
172}
173
/// Id-keyed collection of skills.
///
/// Clones share the same underlying map (it sits behind an `Arc`), so a skill
/// registered through one clone is visible through the others.
#[derive(Clone, Default)]
pub struct SkillRegistry {
    /// Registered skills keyed by `Skill::id`.
    inner: Arc<DashMap<SkillId, Arc<dyn Skill>>>,
}
179
180impl SkillRegistry {
181 pub fn new() -> Self {
182 Self::default()
183 }
184
185 pub fn register(&self, skill: Arc<dyn Skill>) {
186 let id = skill.id().to_string();
187 self.inner.insert(id, skill);
188 }
189
190 pub fn get(&self, id: &str) -> Result<Arc<dyn Skill>, KernelError> {
191 self.inner
192 .get(id)
193 .map(|s| s.clone())
194 .ok_or_else(|| KernelError::SkillNotFound(id.to_string()))
195 }
196
197 pub fn resolve_chain<I, S>(&self, ids: I) -> Result<Vec<Arc<dyn Skill>>, KernelError>
201 where
202 I: IntoIterator<Item = S>,
203 S: AsRef<str>,
204 {
205 ids.into_iter().map(|id| self.get(id.as_ref())).collect()
206 }
207
208 pub fn len(&self) -> usize {
209 self.inner.len()
210 }
211
212 pub fn is_empty(&self) -> bool {
213 self.inner.is_empty()
214 }
215
216 pub fn descriptors(&self) -> Vec<SkillDescriptor> {
218 let mut descriptors: Vec<_> = self
219 .inner
220 .iter()
221 .map(|entry| SkillDescriptor {
222 id: entry.key().clone(),
223 description: entry.value().description().to_string(),
224 })
225 .collect();
226 descriptors.sort_by(|left, right| left.id.cmp(&right.id));
227 descriptors
228 }
229}
230
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use serde_json::json;

    use crate::tool::{LocalTool, ToolSchema};
    use crate::{
        InvestigationContext, KernelError, Skill, SkillOutcome, SkillRegistry, Tool, ToolRegistry,
    };

    /// Builds a tool called `name` whose handler hands its arguments straight back.
    fn echo_tool(name: &str) -> Arc<dyn Tool> {
        Arc::new(LocalTool::new(
            ToolSchema {
                name: name.into(),
                description: "echo".into(),
                args_schema: json!({}),
                result_schema: json!({}),
            },
            |v| async move { Ok(v) },
        ))
    }

    /// Tool names from `registry.descriptors()`, in the order returned.
    fn descriptor_names(registry: &ToolRegistry) -> Vec<String> {
        registry
            .descriptors()
            .into_iter()
            .map(|schema| schema.name)
            .collect()
    }

    #[tokio::test]
    async fn tool_registry_authorisation() {
        let registry = ToolRegistry::new();
        for name in ["a", "b"] {
            registry.register(echo_tool(name));
        }
        for name in ["a", "b"] {
            assert!(registry.get(name).is_ok());
        }

        let only_a = registry.scoped(["a"]);
        assert!(only_a.get("a").is_ok());
        assert!(
            matches!(only_a.get("b"), Err(KernelError::ToolNotAuthorised(name)) if name == "b"),
            "expected ToolNotAuthorised"
        );

        let echoed = only_a.invoke("a", json!({"x": 1})).await.unwrap();
        assert_eq!(echoed, json!({"x": 1}));
    }

    #[tokio::test]
    async fn tool_registry_missing() {
        let empty = ToolRegistry::new();
        assert!(
            matches!(empty.get("missing"), Err(KernelError::ToolNotFound(name)) if name == "missing"),
            "expected ToolNotFound"
        );
    }

    #[test]
    fn tool_registry_descriptors_are_sorted_and_scoped() {
        let registry = ToolRegistry::new();
        for name in ["zeta.tool", "alpha.tool", "middle.tool"] {
            registry.register(echo_tool(name));
        }
        assert_eq!(
            descriptor_names(&registry),
            vec!["alpha.tool", "middle.tool", "zeta.tool"]
        );

        let narrowed = registry.scoped(["zeta.tool", "alpha.tool"]);
        assert_eq!(descriptor_names(&narrowed), vec!["alpha.tool", "zeta.tool"]);
    }

    /// Minimal skill with a fixed id and description; execution is a no-op.
    struct DescribedSkill {
        id: &'static str,
        description: &'static str,
    }

    #[async_trait]
    impl Skill for DescribedSkill {
        fn description(&self) -> &str {
            self.description
        }

        fn id(&self) -> &str {
            self.id
        }

        async fn execute(
            &self,
            _ctx: &mut InvestigationContext,
            _tools: &ToolRegistry,
        ) -> Result<SkillOutcome, KernelError> {
            Ok(SkillOutcome::noop())
        }
    }

    #[test]
    fn skill_registry_descriptors_are_sorted() {
        let registry = SkillRegistry::new();
        for (id, description) in [("zeta.skill", "last"), ("alpha.skill", "first")] {
            registry.register(Arc::new(DescribedSkill { id, description }));
        }

        let listed: Vec<(String, String)> = registry
            .descriptors()
            .into_iter()
            .map(|d| (d.id, d.description))
            .collect();
        assert_eq!(
            listed,
            vec![
                (String::from("alpha.skill"), String::from("first")),
                (String::from("zeta.skill"), String::from("last")),
            ]
        );
    }
}