Skip to main content

rig_compose/
registry.rs

//! [`ToolRegistry`] and [`SkillRegistry`] — composition surfaces.
//!
//! Both registries are indexed by name and only grow at runtime: there is no
//! removal API, although registering an existing name replaces its entry.
//! Agents reference entries by name and never own them; the same skill or
//! tool instance can be shared across any number of agents.
6
7use std::sync::Arc;
8
9use dashmap::DashMap;
10use thiserror::Error;
11
12use crate::skill::{Skill, SkillId};
13use crate::tool::{Tool, ToolName};
14
/// Error type returned by the registries in this module.
///
/// `ToolNotFound`, `ToolNotAuthorised` and `SkillNotFound` are produced by
/// the registry lookups below; the remaining variants are not constructed in
/// this file and are presumably raised by tool/skill implementations —
/// NOTE(review): confirm against their call sites.
#[derive(Debug, Error)]
pub enum KernelError {
    /// No tool is registered under this name (the name is visible to the
    /// view, but the backing map has no entry for it).
    #[error("tool `{0}` not found in registry")]
    ToolNotFound(String),

    /// The registry view's whitelist does not include this tool name.
    #[error("tool `{0}` is not authorised for this agent")]
    ToolNotAuthorised(String),

    /// No skill is registered under this id.
    #[error("skill `{0}` not found in registry")]
    SkillNotFound(String),

    /// A tool invocation failed; the payload describes the failure.
    #[error("tool invocation failed: {0}")]
    ToolFailed(String),

    /// Soft failure: the tool ran without infrastructure error but the
    /// requested operation was inapplicable to its current state (e.g.
    /// expanding around an entity the graph has never seen). Callers
    /// can treat this as a no-op rather than propagating an error.
    #[error("tool not applicable: {0}")]
    ToolNotApplicable(String),

    /// A skill's execution failed; the payload describes the failure.
    #[error("skill execution failed: {0}")]
    SkillFailed(String),

    /// An argument was rejected as invalid.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// A `serde_json` error, forwarded unchanged (`#[from]` lets `?` convert
    /// it automatically).
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
45
/// Registry of named [`Tool`]s. Cheap to clone (Arc-backed).
///
/// Clones and views produced by [`Self::scoped`] share one backing map, so a
/// tool registered through any of them is stored for all of them; views
/// differ only in which names they expose. The `Default` value has no
/// whitelist, i.e. it is unrestricted.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Shared name → tool map backing every clone and scoped view.
    inner: Arc<DashMap<ToolName, Arc<dyn Tool>>>,
    /// Optional whitelist applied at lookup time. `None` = unrestricted;
    /// `Some(set)` = only tools whose name appears in the set are visible
    /// through [`Self::get`], [`Self::invoke`], [`Self::len`] and
    /// [`Self::schemas`]. Used to scope a registry down to an agent's
    /// authorised tool surface without copying the underlying map.
    allowed: Option<Arc<std::collections::HashSet<ToolName>>>,
}
57
58impl ToolRegistry {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    pub fn register(&self, tool: Arc<dyn Tool>) {
64        let name = tool.schema().name;
65        self.inner.insert(name, tool);
66    }
67
68    /// Return a new registry view restricted to `names`. The underlying
69    /// tools are shared; only the whitelist differs.
70    pub fn scoped<I, S>(&self, names: I) -> Self
71    where
72        I: IntoIterator<Item = S>,
73        S: Into<String>,
74    {
75        let allowed: std::collections::HashSet<String> =
76            names.into_iter().map(Into::into).collect();
77        Self {
78            inner: self.inner.clone(),
79            allowed: Some(Arc::new(allowed)),
80        }
81    }
82
83    fn is_authorised(&self, name: &str) -> bool {
84        match &self.allowed {
85            None => true,
86            Some(set) => set.contains(name),
87        }
88    }
89
90    pub fn get(&self, name: &str) -> Result<Arc<dyn Tool>, KernelError> {
91        if !self.is_authorised(name) {
92            return Err(KernelError::ToolNotAuthorised(name.to_string()));
93        }
94        self.inner
95            .get(name)
96            .map(|t| t.clone())
97            .ok_or_else(|| KernelError::ToolNotFound(name.to_string()))
98    }
99
100    /// Convenience: look up `name` and invoke it.
101    pub async fn invoke(
102        &self,
103        name: &str,
104        args: serde_json::Value,
105    ) -> Result<serde_json::Value, KernelError> {
106        let tool = self.get(name)?;
107        tool.invoke(args).await
108    }
109
110    pub fn len(&self) -> usize {
111        match &self.allowed {
112            None => self.inner.len(),
113            Some(set) => self.inner.iter().filter(|e| set.contains(e.key())).count(),
114        }
115    }
116
117    pub fn is_empty(&self) -> bool {
118        self.len() == 0
119    }
120
121    /// Snapshot of every visible tool's schema. Honours the `allowed`
122    /// whitelist when present. Used by the MCP loopback transport to
123    /// surface a server-side registry to a client.
124    pub fn schemas(&self) -> Vec<crate::tool::ToolSchema> {
125        self.inner
126            .iter()
127            .filter(|e| self.is_authorised(e.key()))
128            .map(|e| e.value().schema())
129            .collect()
130    }
131}
132
/// Registry of named [`Skill`]s, keyed by [`SkillId`].
///
/// Uses the same `Arc`-backed concurrent map layout as [`ToolRegistry`], so
/// clones are cheap and share entries. Unlike `ToolRegistry` it has no
/// whitelist and no scoped views: every registered skill is visible through
/// every clone.
#[derive(Clone, Default)]
pub struct SkillRegistry {
    /// Shared id → skill map backing every clone.
    inner: Arc<DashMap<SkillId, Arc<dyn Skill>>>,
}
138
139impl SkillRegistry {
140    pub fn new() -> Self {
141        Self::default()
142    }
143
144    pub fn register(&self, skill: Arc<dyn Skill>) {
145        let id = skill.id().to_string();
146        self.inner.insert(id, skill);
147    }
148
149    pub fn get(&self, id: &str) -> Result<Arc<dyn Skill>, KernelError> {
150        self.inner
151            .get(id)
152            .map(|s| s.clone())
153            .ok_or_else(|| KernelError::SkillNotFound(id.to_string()))
154    }
155
156    /// Resolve a list of skill ids in declared order. Errors on the first
157    /// missing id. Used by `GenericAgent` to build its skill chain at
158    /// construction.
159    pub fn resolve_chain<I, S>(&self, ids: I) -> Result<Vec<Arc<dyn Skill>>, KernelError>
160    where
161        I: IntoIterator<Item = S>,
162        S: AsRef<str>,
163    {
164        ids.into_iter().map(|id| self.get(id.as_ref())).collect()
165    }
166
167    pub fn len(&self) -> usize {
168        self.inner.len()
169    }
170
171    pub fn is_empty(&self) -> bool {
172        self.inner.is_empty()
173    }
174}
175
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::tool::{LocalTool, ToolSchema};
    use crate::{KernelError, Tool, ToolRegistry};
    use serde_json::json;

    /// Build a tool named `name` that returns its arguments unchanged.
    fn echo_tool(name: &str) -> Arc<dyn Tool> {
        let schema = ToolSchema {
            name: name.into(),
            description: "echo".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        Arc::new(LocalTool::new(schema, |v| async move { Ok(v) }))
    }

    #[tokio::test]
    async fn tool_registry_authorisation() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));

        // Unrestricted view sees both.
        assert!(reg.get("a").is_ok());
        assert!(reg.get("b").is_ok());

        // Scoped view only sees `a`.
        let scoped = reg.scoped(["a"]);
        assert!(scoped.get("a").is_ok());
        assert!(
            matches!(scoped.get("b"), Err(KernelError::ToolNotAuthorised(name)) if name == "b"),
            "expected ToolNotAuthorised for `b`"
        );

        // Invocation works through the scoped view for authorised tools.
        let out = scoped.invoke("a", json!({"x": 1})).await.unwrap();
        assert_eq!(out, json!({"x": 1}));
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn tool_registry_missing() {
        let reg = ToolRegistry::new();
        assert!(
            matches!(reg.get("missing"), Err(KernelError::ToolNotFound(name)) if name == "missing"),
            "expected ToolNotFound"
        );
    }

    #[test]
    fn scoped_view_counts_and_schemas() {
        let reg = ToolRegistry::new();
        reg.register(echo_tool("a"));
        reg.register(echo_tool("b"));
        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());

        // Whitelisted names with no registered tool are not counted.
        let scoped = reg.scoped(["a", "not-registered"]);
        assert_eq!(scoped.len(), 1);
        assert!(!scoped.is_empty());

        let names: Vec<_> = scoped.schemas().into_iter().map(|s| s.name).collect();
        assert_eq!(names, ["a"]);

        // A view whose whitelist matches nothing is empty.
        assert!(reg.scoped(["zzz"]).is_empty());
    }
}
225}