1use std::sync::Arc;
8
9use dashmap::DashMap;
10use thiserror::Error;
11
12use crate::skill::{Skill, SkillId};
13use crate::tool::{Tool, ToolName};
14
/// Error type shared by [`ToolRegistry`] and [`SkillRegistry`] and returned from tool
/// invocation.
///
/// The `String` payloads are rendered into the messages given by each `#[error(...)]`
/// attribute.
#[derive(Debug, Error)]
pub enum KernelError {
    /// No tool is registered under the given name.
    #[error("tool `{0}` not found in registry")]
    ToolNotFound(String),

    /// The tool name is outside the allow-list of a scoped [`ToolRegistry`] view.
    #[error("tool `{0}` is not authorised for this agent")]
    ToolNotAuthorised(String),

    /// No skill is registered under the given id.
    #[error("skill `{0}` not found in registry")]
    SkillNotFound(String),

    /// A tool invocation failed; the payload describes the failure.
    #[error("tool invocation failed: {0}")]
    ToolFailed(String),

    /// A tool reported that it does not apply to the request.
    /// NOTE(review): not constructed in this file; confirm the intended use with tool
    /// implementations.
    #[error("tool not applicable: {0}")]
    ToolNotApplicable(String),

    /// A skill failed while executing; the payload describes the failure.
    #[error("skill execution failed: {0}")]
    SkillFailed(String),

    /// A caller-supplied argument was rejected; the payload describes why.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// A `serde_json` error. Its message and source are forwarded unchanged
    /// (`transparent`), and `#[from]` lets `?` convert it automatically.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
45
46#[derive(Clone, Default)]
48pub struct ToolRegistry {
49 inner: Arc<DashMap<ToolName, Arc<dyn Tool>>>,
50 allowed: Option<Arc<std::collections::HashSet<ToolName>>>,
56}
57
58impl ToolRegistry {
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn register(&self, tool: Arc<dyn Tool>) {
64 let name = tool.schema().name;
65 self.inner.insert(name, tool);
66 }
67
68 pub fn scoped<I, S>(&self, names: I) -> Self
71 where
72 I: IntoIterator<Item = S>,
73 S: Into<String>,
74 {
75 let allowed: std::collections::HashSet<String> =
76 names.into_iter().map(Into::into).collect();
77 Self {
78 inner: self.inner.clone(),
79 allowed: Some(Arc::new(allowed)),
80 }
81 }
82
83 fn is_authorised(&self, name: &str) -> bool {
84 match &self.allowed {
85 None => true,
86 Some(set) => set.contains(name),
87 }
88 }
89
90 pub fn get(&self, name: &str) -> Result<Arc<dyn Tool>, KernelError> {
91 if !self.is_authorised(name) {
92 return Err(KernelError::ToolNotAuthorised(name.to_string()));
93 }
94 self.inner
95 .get(name)
96 .map(|t| t.clone())
97 .ok_or_else(|| KernelError::ToolNotFound(name.to_string()))
98 }
99
100 pub async fn invoke(
102 &self,
103 name: &str,
104 args: serde_json::Value,
105 ) -> Result<serde_json::Value, KernelError> {
106 let tool = self.get(name)?;
107 tool.invoke(args).await
108 }
109
110 pub fn len(&self) -> usize {
111 match &self.allowed {
112 None => self.inner.len(),
113 Some(set) => self.inner.iter().filter(|e| set.contains(e.key())).count(),
114 }
115 }
116
117 pub fn is_empty(&self) -> bool {
118 self.len() == 0
119 }
120
121 pub fn schemas(&self) -> Vec<crate::tool::ToolSchema> {
125 self.inner
126 .iter()
127 .filter(|e| self.is_authorised(e.key()))
128 .map(|e| e.value().schema())
129 .collect()
130 }
131}
132
133#[derive(Clone, Default)]
135pub struct SkillRegistry {
136 inner: Arc<DashMap<SkillId, Arc<dyn Skill>>>,
137}
138
139impl SkillRegistry {
140 pub fn new() -> Self {
141 Self::default()
142 }
143
144 pub fn register(&self, skill: Arc<dyn Skill>) {
145 let id = skill.id().to_string();
146 self.inner.insert(id, skill);
147 }
148
149 pub fn get(&self, id: &str) -> Result<Arc<dyn Skill>, KernelError> {
150 self.inner
151 .get(id)
152 .map(|s| s.clone())
153 .ok_or_else(|| KernelError::SkillNotFound(id.to_string()))
154 }
155
156 pub fn resolve_chain<I, S>(&self, ids: I) -> Result<Vec<Arc<dyn Skill>>, KernelError>
160 where
161 I: IntoIterator<Item = S>,
162 S: AsRef<str>,
163 {
164 ids.into_iter().map(|id| self.get(id.as_ref())).collect()
165 }
166
167 pub fn len(&self) -> usize {
168 self.inner.len()
169 }
170
171 pub fn is_empty(&self) -> bool {
172 self.inner.is_empty()
173 }
174}
175
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::SkillRegistry;
    use crate::tool::{LocalTool, ToolSchema};
    use crate::{KernelError, Tool, ToolRegistry};
    use serde_json::json;

    /// Builds a tool called `name` whose invocation returns its arguments unchanged.
    fn echo_tool(name: &str) -> Arc<dyn Tool> {
        let schema = ToolSchema {
            name: name.into(),
            description: "echo".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        Arc::new(LocalTool::new(schema, |v| async move { Ok(v) }))
    }

    #[tokio::test]
    async fn tool_registry_authorisation() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));

        assert!(reg.get("a").is_ok());
        assert!(reg.get("b").is_ok());

        let scoped = reg.scoped(["a"]);
        assert!(scoped.get("a").is_ok());
        match scoped.get("b") {
            Err(KernelError::ToolNotAuthorised(name)) => assert_eq!(name, "b"),
            _ => panic!("expected ToolNotAuthorised"),
        }

        let out = scoped.invoke("a", json!({"x": 1})).await.unwrap();
        assert_eq!(out, json!({"x": 1}));
    }

    #[tokio::test]
    async fn tool_registry_missing() {
        let reg = ToolRegistry::new();
        match reg.get("missing") {
            Err(KernelError::ToolNotFound(name)) => assert_eq!(name, "missing"),
            _ => panic!("expected ToolNotFound"),
        }
    }

    #[test]
    fn scoped_view_filters_len_and_schemas() {
        let reg = ToolRegistry::new();
        assert!(reg.is_empty());
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));
        assert_eq!(reg.len(), 2);
        assert_eq!(reg.schemas().len(), 2);

        let scoped = reg.scoped(["a"]);
        assert_eq!(scoped.len(), 1);
        assert!(!scoped.is_empty());
        let names: Vec<String> = scoped.schemas().into_iter().map(|s| s.name).collect();
        assert_eq!(names, vec!["a".to_string()]);
    }

    #[test]
    fn scoped_view_sees_tools_registered_later() {
        let reg = ToolRegistry::new();
        let scoped = reg.scoped(["late"]);
        match scoped.get("late") {
            Err(KernelError::ToolNotFound(name)) => assert_eq!(name, "late"),
            _ => panic!("expected ToolNotFound before registration"),
        }
        // Registration goes into the map shared with every view.
        reg.register(echo_tool("late"));
        assert!(scoped.get("late").is_ok());
    }

    #[test]
    fn skill_registry_missing_and_empty_chain() {
        let reg = SkillRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);

        match reg.get("missing") {
            Err(KernelError::SkillNotFound(id)) => assert_eq!(id, "missing"),
            _ => panic!("expected SkillNotFound"),
        }
        match reg.resolve_chain(["missing"]) {
            Err(KernelError::SkillNotFound(id)) => assert_eq!(id, "missing"),
            _ => panic!("expected SkillNotFound"),
        }
        match reg.resolve_chain(Vec::<String>::new()) {
            Ok(chain) => assert!(chain.is_empty()),
            Err(e) => panic!("empty chain should resolve, got: {}", e),
        }
    }
}