Skip to main content

bitrouter_core/routers/
dynamic_tool.rs

1//! A tool registry wrapper that adds runtime filter and restriction management.
2//!
3//! [`DynamicToolRegistry`] wraps any [`ToolRegistry`] and layers per-server
4//! filters and parameter restrictions on top. Filters affect which tools are
5//! visible; restrictions are stored for protocol-level call-time enforcement.
6//!
7//! Parallel to [`DynamicRoutingTable`](super::dynamic::DynamicRoutingTable)
8//! for models.
9
10use std::collections::HashMap;
11use std::sync::RwLock;
12
13use crate::errors::{BitrouterError, Result};
14
15use super::admin::{AdminToolRegistry, ParamRestrictions, ToolFilter, ToolUpstreamEntry};
16use super::registry::{ToolEntry, ToolRegistry};
17
18/// A tool registry wrapper that adds runtime filter and restriction management.
19///
20/// Wraps any `T: ToolRegistry` and layers per-server state on top:
21/// - **Filters** control which tools are visible in `list_tools()`.
22/// - **Restrictions** are stored here and exposed for protocol crates to
23///   read at call time via [`get_param_restrictions`](Self::get_param_restrictions).
24/// - **Groups** map group names to sets of server names for access control.
25pub struct DynamicToolRegistry<T> {
26    inner: T,
27    filters: RwLock<HashMap<String, ToolFilter>>,
28    restrictions: RwLock<HashMap<String, ParamRestrictions>>,
29    groups: HashMap<String, Vec<String>>,
30}
31
32impl<T> DynamicToolRegistry<T> {
33    /// Create a new dynamic tool registry wrapping the given inner registry.
34    pub fn new(
35        inner: T,
36        filters: HashMap<String, ToolFilter>,
37        restrictions: HashMap<String, ParamRestrictions>,
38        groups: HashMap<String, Vec<String>>,
39    ) -> Self {
40        Self {
41            inner,
42            filters: RwLock::new(filters),
43            restrictions: RwLock::new(restrictions),
44            groups,
45        }
46    }
47
48    /// Access the inner registry for protocol-specific operations.
49    pub fn inner(&self) -> &T {
50        &self.inner
51    }
52
53    /// Read the current parameter restrictions for a server.
54    ///
55    /// Protocol crates call this at tool-call time to enforce restrictions.
56    pub fn get_param_restrictions(&self, server: &str) -> Option<ParamRestrictions> {
57        self.restrictions
58            .read()
59            .ok()
60            .and_then(|r| r.get(server).cloned())
61    }
62
63    /// Read the current filter for a server.
64    pub fn get_filter(&self, server: &str) -> Option<ToolFilter> {
65        self.filters
66            .read()
67            .ok()
68            .and_then(|f| f.get(server).cloned())
69    }
70
71    /// Check whether a server name exists in the known set.
72    fn known_servers(&self) -> Vec<String>
73    where
74        T: ToolRegistry,
75    {
76        // We cannot call async list_tools here, so we derive from filters/restrictions/groups.
77        // The authoritative set is populated at construction time from config.
78        let mut servers: std::collections::HashSet<String> = std::collections::HashSet::new();
79        if let Ok(f) = self.filters.read() {
80            servers.extend(f.keys().cloned());
81        }
82        if let Ok(r) = self.restrictions.read() {
83            servers.extend(r.keys().cloned());
84        }
85        for members in self.groups.values() {
86            servers.extend(members.iter().cloned());
87        }
88        servers.into_iter().collect()
89    }
90}
91
/// Return the server (provider) segment of a namespaced tool ID.
///
/// IDs look like `"server/tool_name"`; everything before the first `/` is the
/// server. An ID without any `/` is returned unchanged.
fn server_of(tool_id: &str) -> &str {
    match tool_id.find('/') {
        Some(slash) => &tool_id[..slash],
        None => tool_id,
    }
}
99
/// Return the tool segment of a namespaced tool ID.
///
/// Everything after the first `/` is the tool name; any later `/` characters
/// stay in the result. An ID without any `/` is returned unchanged.
fn tool_name_of(tool_id: &str) -> &str {
    match tool_id.find('/') {
        Some(slash) => &tool_id[slash + 1..],
        None => tool_id,
    }
}
104
105impl<T: ToolRegistry> ToolRegistry for DynamicToolRegistry<T> {
106    async fn list_tools(&self) -> Vec<ToolEntry> {
107        let all = self.inner.list_tools().await;
108        let filters = match self.filters.read() {
109            Ok(f) => f,
110            Err(_) => return all,
111        };
112
113        all.into_iter()
114            .filter(|entry| {
115                let server = server_of(&entry.id);
116                match filters.get(server) {
117                    Some(filter) => filter.accepts(tool_name_of(&entry.id)),
118                    None => true,
119                }
120            })
121            .collect()
122    }
123}
124
125impl<T: ToolRegistry> AdminToolRegistry for DynamicToolRegistry<T> {
126    async fn list_upstreams(&self) -> Vec<ToolUpstreamEntry> {
127        // Get the filtered tool list to compute per-server tool counts.
128        let tools = <Self as ToolRegistry>::list_tools(self).await;
129        let mut counts: HashMap<String, usize> = HashMap::new();
130        for tool in &tools {
131            let server = server_of(&tool.id);
132            *counts.entry(server.to_owned()).or_default() += 1;
133        }
134
135        // Also include servers that have filters/restrictions but no visible tools.
136        let servers = self.known_servers();
137        for s in &servers {
138            counts.entry(s.clone()).or_default();
139        }
140
141        let filters = self.filters.read().ok();
142        let restrictions = self.restrictions.read().ok();
143
144        let mut entries: Vec<ToolUpstreamEntry> = counts
145            .into_iter()
146            .map(|(name, tool_count)| {
147                let filter = filters.as_ref().and_then(|f| f.get(&name).cloned());
148                let param_restrictions = restrictions
149                    .as_ref()
150                    .and_then(|r| r.get(&name).cloned())
151                    .filter(|r| !r.rules.is_empty());
152                ToolUpstreamEntry {
153                    name,
154                    tool_count,
155                    filter,
156                    param_restrictions,
157                }
158            })
159            .collect();
160        entries.sort_by(|a, b| a.name.cmp(&b.name));
161        entries
162    }
163
164    async fn list_groups(&self) -> HashMap<String, Vec<String>> {
165        self.groups.clone()
166    }
167
168    async fn update_filter(&self, server: &str, filter: Option<ToolFilter>) -> Result<()> {
169        let mut filters = self
170            .filters
171            .write()
172            .map_err(|_| BitrouterError::transport(None, "tool registry lock poisoned"))?;
173        match filter {
174            Some(f) => {
175                filters.insert(server.to_owned(), f);
176            }
177            None => {
178                filters.remove(server);
179            }
180        }
181        Ok(())
182    }
183
184    async fn update_param_restrictions(
185        &self,
186        server: &str,
187        restrictions: ParamRestrictions,
188    ) -> Result<()> {
189        let mut r = self
190            .restrictions
191            .write()
192            .map_err(|_| BitrouterError::transport(None, "tool registry lock poisoned"))?;
193        if restrictions.rules.is_empty() {
194            r.remove(server);
195        } else {
196            r.insert(server.to_owned(), restrictions);
197        }
198        Ok(())
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::super::admin::{ParamRule, ParamViolationAction};
    use super::*;

    /// Inner registry that always returns a fixed tool list.
    struct StaticToolSource {
        tools: Vec<ToolEntry>,
    }

    impl ToolRegistry for StaticToolSource {
        async fn list_tools(&self) -> Vec<ToolEntry> {
            self.tools.clone()
        }
    }

    /// Build a tool entry whose provider is derived from its namespaced ID.
    fn tool(id: &str, name: &str, description: &str) -> ToolEntry {
        ToolEntry {
            id: id.to_owned(),
            name: Some(name.to_owned()),
            provider: server_of(id).to_owned(),
            description: Some(description.to_owned()),
            input_schema: None,
        }
    }

    fn test_tools() -> Vec<ToolEntry> {
        vec![
            tool("github/search", "Search", "Search GitHub"),
            tool("github/create_issue", "Create Issue", "Create an issue"),
            tool("jira/search", "Search", "Search Jira"),
        ]
    }

    fn test_registry() -> DynamicToolRegistry<StaticToolSource> {
        DynamicToolRegistry::new(
            StaticToolSource {
                tools: test_tools(),
            },
            HashMap::new(),
            HashMap::new(),
            HashMap::from([(
                "dev_tools".to_owned(),
                vec!["github".to_owned(), "jira".to_owned()],
            )]),
        )
    }

    /// Filter that hides the `search` tool.
    fn deny_search() -> ToolFilter {
        ToolFilter {
            deny: Some(vec!["search".to_owned()]),
            ..Default::default()
        }
    }

    /// Restriction set rejecting the `force` parameter on `search`.
    fn reject_force_on_search() -> ParamRestrictions {
        ParamRestrictions {
            rules: HashMap::from([(
                "search".to_owned(),
                ParamRule {
                    deny: Some(vec!["force".to_owned()]),
                    allow: None,
                    action: ParamViolationAction::Reject,
                },
            )]),
        }
    }

    #[tokio::test]
    async fn no_filter_returns_all_tools() {
        let reg = test_registry();
        let tools = reg.list_tools().await;
        assert_eq!(tools.len(), 3);
    }

    #[tokio::test]
    async fn deny_filter_hides_tools() {
        let reg = test_registry();
        reg.update_filter("github", Some(deny_search()))
            .await
            .expect("update_filter should succeed");

        let tools = reg.list_tools().await;
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().all(|t| t.id != "github/search"));
    }

    #[tokio::test]
    async fn allow_filter_restricts_tools() {
        let reg = test_registry();
        reg.update_filter(
            "github",
            Some(ToolFilter {
                allow: Some(vec!["search".to_owned()]),
                deny: None,
            }),
        )
        .await
        .expect("update_filter should succeed");

        let tools = reg.list_tools().await;
        assert_eq!(tools.len(), 2); // github/search + jira/search
        assert!(tools.iter().any(|t| t.id == "github/search"));
        assert!(!tools.iter().any(|t| t.id == "github/create_issue"));
    }

    #[tokio::test]
    async fn clear_filter_restores_all() {
        let reg = test_registry();
        reg.update_filter("github", Some(deny_search()))
            .await
            .expect("update_filter should succeed");
        assert_eq!(reg.list_tools().await.len(), 2);

        reg.update_filter("github", None)
            .await
            .expect("clearing the filter should succeed");
        assert_eq!(reg.list_tools().await.len(), 3);
    }

    #[tokio::test]
    async fn list_upstreams_reflects_state() {
        let reg = test_registry();
        reg.update_filter("github", Some(deny_search()))
            .await
            .expect("update_filter should succeed");

        let upstreams = reg.list_upstreams().await;
        let github = upstreams
            .iter()
            .find(|u| u.name == "github")
            .expect("github should be listed");
        assert_eq!(github.tool_count, 1); // only create_issue visible
        assert!(github.filter.is_some());
    }

    #[tokio::test]
    async fn list_upstreams_includes_servers_without_tools() {
        let reg = test_registry();
        reg.update_param_restrictions("slack", reject_force_on_search())
            .await
            .expect("update_param_restrictions should succeed");

        let upstreams = reg.list_upstreams().await;
        let slack = upstreams
            .iter()
            .find(|u| u.name == "slack")
            .expect("slack should be listed");
        assert_eq!(slack.tool_count, 0);
        assert!(slack.param_restrictions.is_some());

        let names: Vec<&str> = upstreams.iter().map(|u| u.name.as_str()).collect();
        assert_eq!(names, ["github", "jira", "slack"]);
    }

    #[tokio::test]
    async fn list_groups_returns_configured() {
        let reg = test_registry();
        let groups = reg.list_groups().await;
        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key("dev_tools"));
    }

    #[tokio::test]
    async fn param_restrictions_roundtrip() {
        let reg = test_registry();
        assert!(reg.get_param_restrictions("github").is_none());

        reg.update_param_restrictions("github", reject_force_on_search())
            .await
            .expect("update_param_restrictions should succeed");

        let stored = reg
            .get_param_restrictions("github")
            .expect("restrictions should be stored");
        assert!(stored.rules.contains_key("search"));
    }

    #[tokio::test]
    async fn empty_restrictions_clear_entry() {
        let reg = test_registry();
        reg.update_param_restrictions("github", reject_force_on_search())
            .await
            .expect("update_param_restrictions should succeed");
        assert!(reg.get_param_restrictions("github").is_some());

        reg.update_param_restrictions(
            "github",
            ParamRestrictions {
                rules: HashMap::new(),
            },
        )
        .await
        .expect("clearing restrictions should succeed");
        assert!(reg.get_param_restrictions("github").is_none());
    }

    #[test]
    fn tool_filter_accepts_logic() {
        let filter = ToolFilter {
            allow: Some(vec!["search".to_owned()]),
            deny: Some(vec!["delete".to_owned()]),
        };
        assert!(filter.accepts("search"));
        assert!(!filter.accepts("delete"));
        assert!(!filter.accepts("create")); // not in allow list
    }

    #[test]
    fn tool_filter_deny_takes_precedence() {
        let filter = ToolFilter {
            allow: Some(vec!["search".to_owned()]),
            deny: Some(vec!["search".to_owned()]),
        };
        assert!(!filter.accepts("search")); // deny wins
    }

    #[test]
    fn tool_filter_empty_accepts_all() {
        let filter = ToolFilter::default();
        assert!(filter.accepts("anything"));
    }
}
392    fn tool_filter_empty_accepts_all() {
393        let filter = ToolFilter::default();
394        assert!(filter.accepts("anything"));
395    }
396}