1use std::sync::Arc;
8
9use dashmap::DashMap;
10use thiserror::Error;
11
12use crate::skill::{Skill, SkillId};
13use crate::tool::{Tool, ToolName};
14
15#[derive(Debug, Error)]
16pub enum KernelError {
17 #[error("tool `{0}` not found in registry")]
18 ToolNotFound(String),
19
20 #[error("tool `{0}` is not authorised for this agent")]
21 ToolNotAuthorised(String),
22
23 #[error("skill `{0}` not found in registry")]
24 SkillNotFound(String),
25
26 #[error("tool invocation failed: {0}")]
27 ToolFailed(String),
28
29 #[error("skill execution failed: {0}")]
30 SkillFailed(String),
31
32 #[error("invalid argument: {0}")]
33 InvalidArgument(String),
34
35 #[error(transparent)]
36 Serde(#[from] serde_json::Error),
37}
38
39#[derive(Clone, Default)]
41pub struct ToolRegistry {
42 inner: Arc<DashMap<ToolName, Arc<dyn Tool>>>,
43 allowed: Option<Arc<std::collections::HashSet<ToolName>>>,
49}
50
51impl ToolRegistry {
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn register(&self, tool: Arc<dyn Tool>) {
57 let name = tool.schema().name;
58 self.inner.insert(name, tool);
59 }
60
61 pub fn scoped<I, S>(&self, names: I) -> Self
64 where
65 I: IntoIterator<Item = S>,
66 S: Into<String>,
67 {
68 let allowed: std::collections::HashSet<String> =
69 names.into_iter().map(Into::into).collect();
70 Self {
71 inner: self.inner.clone(),
72 allowed: Some(Arc::new(allowed)),
73 }
74 }
75
76 fn is_authorised(&self, name: &str) -> bool {
77 match &self.allowed {
78 None => true,
79 Some(set) => set.contains(name),
80 }
81 }
82
83 pub fn get(&self, name: &str) -> Result<Arc<dyn Tool>, KernelError> {
84 if !self.is_authorised(name) {
85 return Err(KernelError::ToolNotAuthorised(name.to_string()));
86 }
87 self.inner
88 .get(name)
89 .map(|t| t.clone())
90 .ok_or_else(|| KernelError::ToolNotFound(name.to_string()))
91 }
92
93 pub async fn invoke(
95 &self,
96 name: &str,
97 args: serde_json::Value,
98 ) -> Result<serde_json::Value, KernelError> {
99 let tool = self.get(name)?;
100 tool.invoke(args).await
101 }
102
103 pub fn len(&self) -> usize {
104 match &self.allowed {
105 None => self.inner.len(),
106 Some(set) => self.inner.iter().filter(|e| set.contains(e.key())).count(),
107 }
108 }
109
110 pub fn is_empty(&self) -> bool {
111 self.len() == 0
112 }
113
114 pub fn schemas(&self) -> Vec<crate::tool::ToolSchema> {
118 self.inner
119 .iter()
120 .filter(|e| self.is_authorised(e.key()))
121 .map(|e| e.value().schema())
122 .collect()
123 }
124}
125
126#[derive(Clone, Default)]
128pub struct SkillRegistry {
129 inner: Arc<DashMap<SkillId, Arc<dyn Skill>>>,
130}
131
132impl SkillRegistry {
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn register(&self, skill: Arc<dyn Skill>) {
138 let id = skill.id().to_string();
139 self.inner.insert(id, skill);
140 }
141
142 pub fn get(&self, id: &str) -> Result<Arc<dyn Skill>, KernelError> {
143 self.inner
144 .get(id)
145 .map(|s| s.clone())
146 .ok_or_else(|| KernelError::SkillNotFound(id.to_string()))
147 }
148
149 pub fn resolve_chain<I, S>(&self, ids: I) -> Result<Vec<Arc<dyn Skill>>, KernelError>
153 where
154 I: IntoIterator<Item = S>,
155 S: AsRef<str>,
156 {
157 ids.into_iter().map(|id| self.get(id.as_ref())).collect()
158 }
159
160 pub fn len(&self) -> usize {
161 self.inner.len()
162 }
163
164 pub fn is_empty(&self) -> bool {
165 self.inner.is_empty()
166 }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::tool::{LocalTool, ToolSchema};
    use crate::{KernelError, Tool, ToolRegistry};
    use serde_json::json;

    /// Builds a tool called `name` that returns its arguments unchanged.
    fn echo_tool(name: &str) -> Arc<dyn Tool> {
        let schema = ToolSchema {
            name: name.into(),
            description: "echo".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        Arc::new(LocalTool::new(schema, |v| async move { Ok(v) }))
    }

    #[tokio::test]
    async fn tool_registry_authorisation() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));

        assert!(reg.get("a").is_ok());
        assert!(reg.get("b").is_ok());

        let scoped = reg.scoped(["a"]);
        assert!(scoped.get("a").is_ok());
        match scoped.get("b") {
            Err(KernelError::ToolNotAuthorised(name)) => assert_eq!(name, "b"),
            _ => panic!("expected ToolNotAuthorised"),
        }

        let out = scoped.invoke("a", json!({"x": 1})).await.unwrap();
        assert_eq!(out, json!({"x": 1}));
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn tool_registry_missing() {
        let reg = ToolRegistry::new();
        match reg.get("missing") {
            Err(KernelError::ToolNotFound(name)) => assert_eq!(name, "missing"),
            _ => panic!("expected ToolNotFound"),
        }
    }

    #[test]
    fn scoped_registry_len_and_schemas() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));
        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());

        // "c" is in scope but never registered, so it must not be counted or listed.
        let scoped = reg.scoped(["a", "c"]);
        assert_eq!(scoped.len(), 1);
        assert!(!scoped.is_empty());

        let names: Vec<String> = scoped.schemas().into_iter().map(|s| s.name).collect();
        assert_eq!(names, vec!["a".to_string()]);

        match scoped.get("c") {
            Err(KernelError::ToolNotFound(name)) => assert_eq!(name, "c"),
            _ => panic!("expected ToolNotFound"),
        }

        let nothing = reg.scoped(Vec::<String>::new());
        assert_eq!(nothing.len(), 0);
        assert!(nothing.is_empty());
        assert!(nothing.schemas().is_empty());
    }
}