Skip to main content

rig_compose/
registry.rs

//! [`ToolRegistry`] and [`SkillRegistry`] — composition surfaces.
//!
//! Both registries are indexed by name and expose no removal API;
//! registering an entry under a name that is already present replaces the
//! previous entry. Agents reference entries by name and never own them; the
//! same skill or tool instance can be shared across any number of agents.
6
7use std::sync::Arc;
8
9use dashmap::DashMap;
10use thiserror::Error;
11
12use crate::skill::{Skill, SkillId};
13use crate::tool::{Tool, ToolName};
14
15#[derive(Debug, Error)]
16pub enum KernelError {
17    #[error("tool `{0}` not found in registry")]
18    ToolNotFound(String),
19
20    #[error("tool `{0}` is not authorised for this agent")]
21    ToolNotAuthorised(String),
22
23    #[error("skill `{0}` not found in registry")]
24    SkillNotFound(String),
25
26    #[error("tool invocation failed: {0}")]
27    ToolFailed(String),
28
29    #[error("skill execution failed: {0}")]
30    SkillFailed(String),
31
32    #[error("invalid argument: {0}")]
33    InvalidArgument(String),
34
35    #[error(transparent)]
36    Serde(#[from] serde_json::Error),
37}
38
39/// Registry of named [`Tool`]s. Cheap to clone (Arc-backed).
40#[derive(Clone, Default)]
41pub struct ToolRegistry {
42    inner: Arc<DashMap<ToolName, Arc<dyn Tool>>>,
43    /// Optional whitelist applied at lookup time. `None` = unrestricted;
44    /// `Some(set)` = only tools whose name appears in the set are visible
45    /// through [`Self::get`]/[`Self::invoke`]. Used to scope a registry
46    /// down to an agent's authorised tool surface without copying the
47    /// underlying map.
48    allowed: Option<Arc<std::collections::HashSet<ToolName>>>,
49}
50
51impl ToolRegistry {
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    pub fn register(&self, tool: Arc<dyn Tool>) {
57        let name = tool.schema().name;
58        self.inner.insert(name, tool);
59    }
60
61    /// Return a new registry view restricted to `names`. The underlying
62    /// tools are shared; only the whitelist differs.
63    pub fn scoped<I, S>(&self, names: I) -> Self
64    where
65        I: IntoIterator<Item = S>,
66        S: Into<String>,
67    {
68        let allowed: std::collections::HashSet<String> =
69            names.into_iter().map(Into::into).collect();
70        Self {
71            inner: self.inner.clone(),
72            allowed: Some(Arc::new(allowed)),
73        }
74    }
75
76    fn is_authorised(&self, name: &str) -> bool {
77        match &self.allowed {
78            None => true,
79            Some(set) => set.contains(name),
80        }
81    }
82
83    pub fn get(&self, name: &str) -> Result<Arc<dyn Tool>, KernelError> {
84        if !self.is_authorised(name) {
85            return Err(KernelError::ToolNotAuthorised(name.to_string()));
86        }
87        self.inner
88            .get(name)
89            .map(|t| t.clone())
90            .ok_or_else(|| KernelError::ToolNotFound(name.to_string()))
91    }
92
93    /// Convenience: look up `name` and invoke it.
94    pub async fn invoke(
95        &self,
96        name: &str,
97        args: serde_json::Value,
98    ) -> Result<serde_json::Value, KernelError> {
99        let tool = self.get(name)?;
100        tool.invoke(args).await
101    }
102
103    pub fn len(&self) -> usize {
104        match &self.allowed {
105            None => self.inner.len(),
106            Some(set) => self.inner.iter().filter(|e| set.contains(e.key())).count(),
107        }
108    }
109
110    pub fn is_empty(&self) -> bool {
111        self.len() == 0
112    }
113
114    /// Snapshot of every visible tool's schema. Honours the `allowed`
115    /// whitelist when present. Used by the MCP loopback transport to
116    /// surface a server-side registry to a client.
117    pub fn schemas(&self) -> Vec<crate::tool::ToolSchema> {
118        self.inner
119            .iter()
120            .filter(|e| self.is_authorised(e.key()))
121            .map(|e| e.value().schema())
122            .collect()
123    }
124}
125
126/// Registry of named [`Skill`]s. Identical structure to [`ToolRegistry`].
127#[derive(Clone, Default)]
128pub struct SkillRegistry {
129    inner: Arc<DashMap<SkillId, Arc<dyn Skill>>>,
130}
131
132impl SkillRegistry {
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    pub fn register(&self, skill: Arc<dyn Skill>) {
138        let id = skill.id().to_string();
139        self.inner.insert(id, skill);
140    }
141
142    pub fn get(&self, id: &str) -> Result<Arc<dyn Skill>, KernelError> {
143        self.inner
144            .get(id)
145            .map(|s| s.clone())
146            .ok_or_else(|| KernelError::SkillNotFound(id.to_string()))
147    }
148
149    /// Resolve a list of skill ids in declared order. Errors on the first
150    /// missing id. Used by `GenericAgent` to build its skill chain at
151    /// construction.
152    pub fn resolve_chain<I, S>(&self, ids: I) -> Result<Vec<Arc<dyn Skill>>, KernelError>
153    where
154        I: IntoIterator<Item = S>,
155        S: AsRef<str>,
156    {
157        ids.into_iter().map(|id| self.get(id.as_ref())).collect()
158    }
159
160    pub fn len(&self) -> usize {
161        self.inner.len()
162    }
163
164    pub fn is_empty(&self) -> bool {
165        self.inner.is_empty()
166    }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::tool::{LocalTool, ToolSchema};
    use crate::{KernelError, Tool, ToolRegistry};
    use serde_json::json;

    /// Build a tool named `name` that returns its arguments unchanged.
    fn echo_tool(name: &str) -> Arc<dyn Tool> {
        let schema = ToolSchema {
            name: name.into(),
            description: "echo".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        Arc::new(LocalTool::new(schema, |v| async move { Ok(v) }))
    }

    #[tokio::test]
    async fn tool_registry_authorisation() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));

        // Unrestricted view sees both.
        assert!(reg.get("a").is_ok());
        assert!(reg.get("b").is_ok());

        // Scoped view only sees `a`.
        let scoped = reg.scoped(["a"]);
        assert!(scoped.get("a").is_ok());
        match scoped.get("b") {
            Err(KernelError::ToolNotAuthorised(name)) => assert_eq!(name, "b"),
            _ => panic!("expected ToolNotAuthorised"),
        }

        // Invocation works through the scoped view for authorised tools.
        let out = scoped.invoke("a", json!({"x": 1})).await.unwrap();
        assert_eq!(out, json!({"x": 1}));
    }

    // Nothing here is async, so a plain test avoids spinning up a runtime.
    #[test]
    fn tool_registry_missing() {
        let reg = ToolRegistry::new();
        match reg.get("missing") {
            Err(KernelError::ToolNotFound(name)) => assert_eq!(name, "missing"),
            _ => panic!("expected ToolNotFound"),
        }
    }

    #[test]
    fn scoped_view_counts_only_visible_registered_tools() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));
        assert_eq!(reg.len(), 2);

        // `ghost` is whitelisted but never registered, so it is not counted
        // and looking it up reports "not found" rather than "not authorised".
        let scoped = reg.scoped(["a", "ghost"]);
        assert_eq!(scoped.len(), 1);
        assert!(!scoped.is_empty());
        match scoped.get("ghost") {
            Err(KernelError::ToolNotFound(name)) => assert_eq!(name, "ghost"),
            _ => panic!("expected ToolNotFound"),
        }

        let names: Vec<String> = scoped.schemas().into_iter().map(|s| s.name).collect();
        assert_eq!(names, vec!["a".to_string()]);

        let mut all: Vec<String> = reg.schemas().into_iter().map(|s| s.name).collect();
        all.sort();
        assert_eq!(all, vec!["a".to_string(), "b".to_string()]);

        // An empty whitelist hides everything.
        assert!(reg.scoped(Vec::<String>::new()).is_empty());
    }

    #[test]
    fn register_replaces_same_name() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("a"));
        assert_eq!(reg.len(), 1);
    }
}