Skip to main content

rig_compose/
registry.rs

//! [`ToolRegistry`] and [`SkillRegistry`] — composition surfaces.
//!
//! Both registries are indexed by name. Entries are never removed at runtime
//! (this module provides no removal API), though registering an existing name
//! replaces the previous entry. Agents reference entries by name and never own
//! them; the same skill or tool instance can be shared across any number of
//! agents.
6
7use std::sync::Arc;
8
9use dashmap::DashMap;
10use thiserror::Error;
11
12use crate::skill::{Skill, SkillId};
13use crate::tool::{Tool, ToolName};
14
/// Error type shared by the registries in this module and by tool/skill
/// execution.
///
/// Every variant except [`Self::Serde`] carries a detail string that is
/// embedded in its `Display` output; `Serde` wraps the underlying
/// `serde_json` error.
#[derive(Debug, Error)]
pub enum KernelError {
    /// No tool with this name is registered. Returned by
    /// [`ToolRegistry::get`] (and therefore [`ToolRegistry::invoke`]);
    /// carries the requested name.
    #[error("tool `{0}` not found in registry")]
    ToolNotFound(String),

    /// The tool name is outside the whitelist of a scoped [`ToolRegistry`]
    /// view. [`ToolRegistry::get`] checks the whitelist before checking
    /// existence, so this is also returned for unregistered names that are
    /// not whitelisted. Carries the requested name.
    #[error("tool `{0}` is not authorised for this agent")]
    ToolNotAuthorised(String),

    /// No skill with this id is registered. Returned by
    /// [`SkillRegistry::get`] and [`SkillRegistry::resolve_chain`]; carries
    /// the requested id.
    #[error("skill `{0}` not found in registry")]
    SkillNotFound(String),

    /// A tool invocation failed; carries a description of the failure.
    #[error("tool invocation failed: {0}")]
    ToolFailed(String),

    /// Soft failure: the tool ran without infrastructure error but the
    /// requested operation was inapplicable to its current state (e.g.
    /// expanding around an entity the graph has never seen). Callers
    /// can treat this as a no-op rather than propagating an error.
    #[error("tool not applicable: {0}")]
    ToolNotApplicable(String),

    /// A skill failed while executing; carries a description of the failure.
    #[error("skill execution failed: {0}")]
    SkillFailed(String),

    /// A supplied argument was rejected as invalid; carries a description.
    #[error("invalid argument: {0}")]
    InvalidArgument(String),

    /// Failed to parse tool-call markers in raw model output. Distinct from
    /// [`Self::ToolFailed`] (which signals an invocation-time error) so
    /// callers can distinguish a parse/normalizer failure from a runtime one.
    #[error("tool-call normalizer failed: {0}")]
    NormalizerFailed(String),

    /// A dispatch hook intentionally stopped a normalized tool dispatch loop.
    #[error("tool dispatch terminated: {0}")]
    ToolDispatchTerminated(String),

    /// A budget or accounting hook failed while evaluating dispatch policy.
    #[error("budget failed: {0}")]
    BudgetFailed(String),

    /// A `serde_json` error. `#[from]` generates `From<serde_json::Error>`,
    /// so `?` converts it automatically; `transparent` forwards `Display`
    /// and `source` to the wrapped error.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
59
/// Registry of named [`Tool`]s. Cheap to clone (Arc-backed).
///
/// Clones and views produced by [`Self::scoped`] share the same underlying
/// map, so a tool registered through any of them is stored for all of them;
/// views differ only in their optional whitelist. The default value is an
/// empty, unrestricted registry.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    /// Shared map from tool name (as reported by the tool's schema) to the
    /// tool instance.
    inner: Arc<DashMap<ToolName, Arc<dyn Tool>>>,
    /// Optional whitelist applied at lookup time. `None` = unrestricted;
    /// `Some(set)` = only tools whose name appears in the set are visible
    /// through [`Self::get`]/[`Self::invoke`], counted by [`Self::len`] and
    /// listed by [`Self::schemas`]. Used to scope a registry
    /// down to an agent's authorised tool surface without copying the
    /// underlying map. Note: [`Self::register`] does not consult it.
    allowed: Option<Arc<std::collections::HashSet<ToolName>>>,
}
71
72impl ToolRegistry {
73    pub fn new() -> Self {
74        Self::default()
75    }
76
77    pub fn register(&self, tool: Arc<dyn Tool>) {
78        let name = tool.schema().name;
79        self.inner.insert(name, tool);
80    }
81
82    /// Return a new registry view restricted to `names`. The underlying
83    /// tools are shared; only the whitelist differs.
84    pub fn scoped<I, S>(&self, names: I) -> Self
85    where
86        I: IntoIterator<Item = S>,
87        S: Into<String>,
88    {
89        let allowed: std::collections::HashSet<String> =
90            names.into_iter().map(Into::into).collect();
91        Self {
92            inner: self.inner.clone(),
93            allowed: Some(Arc::new(allowed)),
94        }
95    }
96
97    fn is_authorised(&self, name: &str) -> bool {
98        match &self.allowed {
99            None => true,
100            Some(set) => set.contains(name),
101        }
102    }
103
104    pub fn get(&self, name: &str) -> Result<Arc<dyn Tool>, KernelError> {
105        if !self.is_authorised(name) {
106            return Err(KernelError::ToolNotAuthorised(name.to_string()));
107        }
108        self.inner
109            .get(name)
110            .map(|t| t.clone())
111            .ok_or_else(|| KernelError::ToolNotFound(name.to_string()))
112    }
113
114    /// Convenience: look up `name` and invoke it.
115    pub async fn invoke(
116        &self,
117        name: &str,
118        args: serde_json::Value,
119    ) -> Result<serde_json::Value, KernelError> {
120        let tool = self.get(name)?;
121        tool.invoke(args).await
122    }
123
124    pub fn len(&self) -> usize {
125        match &self.allowed {
126            None => self.inner.len(),
127            Some(set) => self.inner.iter().filter(|e| set.contains(e.key())).count(),
128        }
129    }
130
131    pub fn is_empty(&self) -> bool {
132        self.len() == 0
133    }
134
135    /// Snapshot of every visible tool's schema. Honours the `allowed`
136    /// whitelist when present. Used by the MCP loopback transport to
137    /// surface a server-side registry to a client.
138    pub fn schemas(&self) -> Vec<crate::tool::ToolSchema> {
139        self.inner
140            .iter()
141            .filter(|e| self.is_authorised(e.key()))
142            .map(|e| e.value().schema())
143            .collect()
144    }
145}
146
/// Registry of named [`Skill`]s, keyed by [`SkillId`]. Cheap to clone
/// (Arc-backed): clones share the same underlying map.
///
/// Mirrors the map half of [`ToolRegistry`] but, unlike it, has no
/// whitelist/scoping: every registered skill is visible through every clone.
#[derive(Clone, Default)]
pub struct SkillRegistry {
    /// Shared map from skill id to skill instance.
    inner: Arc<DashMap<SkillId, Arc<dyn Skill>>>,
}
152
153impl SkillRegistry {
154    pub fn new() -> Self {
155        Self::default()
156    }
157
158    pub fn register(&self, skill: Arc<dyn Skill>) {
159        let id = skill.id().to_string();
160        self.inner.insert(id, skill);
161    }
162
163    pub fn get(&self, id: &str) -> Result<Arc<dyn Skill>, KernelError> {
164        self.inner
165            .get(id)
166            .map(|s| s.clone())
167            .ok_or_else(|| KernelError::SkillNotFound(id.to_string()))
168    }
169
170    /// Resolve a list of skill ids in declared order. Errors on the first
171    /// missing id. Used by `GenericAgent` to build its skill chain at
172    /// construction.
173    pub fn resolve_chain<I, S>(&self, ids: I) -> Result<Vec<Arc<dyn Skill>>, KernelError>
174    where
175        I: IntoIterator<Item = S>,
176        S: AsRef<str>,
177    {
178        ids.into_iter().map(|id| self.get(id.as_ref())).collect()
179    }
180
181    pub fn len(&self) -> usize {
182        self.inner.len()
183    }
184
185    pub fn is_empty(&self) -> bool {
186        self.inner.is_empty()
187    }
188}
189
190#[cfg(test)]
191mod tests {
192    use std::sync::Arc;
193
194    use crate::tool::{LocalTool, ToolSchema};
195    use crate::{KernelError, Tool, ToolRegistry};
196    use serde_json::json;
197
198    fn echo_tool(name: &str) -> Arc<dyn Tool> {
199        let schema = ToolSchema {
200            name: name.into(),
201            description: "echo".into(),
202            args_schema: json!({}),
203            result_schema: json!({}),
204        };
205        Arc::new(LocalTool::new(schema, |v| async move { Ok(v) }))
206    }
207
208    #[tokio::test]
209    async fn tool_registry_authorisation() {
210        let reg = ToolRegistry::new();
211        reg.register(echo_tool("a"));
212        reg.register(echo_tool("b"));
213
214        // Unrestricted view sees both.
215        assert!(reg.get("a").is_ok());
216        assert!(reg.get("b").is_ok());
217
218        // Scoped view only sees `a`.
219        let scoped = reg.scoped(["a"]);
220        assert!(scoped.get("a").is_ok());
221        match scoped.get("b") {
222            Err(KernelError::ToolNotAuthorised(name)) => assert_eq!(name, "b"),
223            _ => panic!("expected ToolNotAuthorised"),
224        }
225
226        // Invocation works through the scoped view for authorised tools.
227        let out = scoped.invoke("a", json!({"x": 1})).await.unwrap();
228        assert_eq!(out, json!({"x": 1}));
229    }
230
231    #[tokio::test]
232    async fn tool_registry_missing() {
233        let reg = ToolRegistry::new();
234        match reg.get("missing") {
235            Err(KernelError::ToolNotFound(name)) => assert_eq!(name, "missing"),
236            _ => panic!("expected ToolNotFound"),
237        }
238    }
239}