1use std::collections::HashMap;
11use std::sync::RwLock;
12
13use crate::errors::{BitrouterError, Result};
14
15use super::admin::{AdminToolRegistry, ParamRestrictions, ToolFilter, ToolUpstreamEntry};
16use super::registry::{ToolEntry, ToolRegistry};
17
18pub struct DynamicToolRegistry<T> {
26 inner: T,
27 filters: RwLock<HashMap<String, ToolFilter>>,
28 restrictions: RwLock<HashMap<String, ParamRestrictions>>,
29 groups: HashMap<String, Vec<String>>,
30}
31
32impl<T> DynamicToolRegistry<T> {
33 pub fn new(
35 inner: T,
36 filters: HashMap<String, ToolFilter>,
37 restrictions: HashMap<String, ParamRestrictions>,
38 groups: HashMap<String, Vec<String>>,
39 ) -> Self {
40 Self {
41 inner,
42 filters: RwLock::new(filters),
43 restrictions: RwLock::new(restrictions),
44 groups,
45 }
46 }
47
48 pub fn inner(&self) -> &T {
50 &self.inner
51 }
52
53 pub fn get_param_restrictions(&self, server: &str) -> Option<ParamRestrictions> {
57 self.restrictions
58 .read()
59 .ok()
60 .and_then(|r| r.get(server).cloned())
61 }
62
63 pub fn get_filter(&self, server: &str) -> Option<ToolFilter> {
65 self.filters
66 .read()
67 .ok()
68 .and_then(|f| f.get(server).cloned())
69 }
70
71 fn known_servers(&self) -> Vec<String>
73 where
74 T: ToolRegistry,
75 {
76 let mut servers: std::collections::HashSet<String> = std::collections::HashSet::new();
79 if let Ok(f) = self.filters.read() {
80 servers.extend(f.keys().cloned());
81 }
82 if let Ok(r) = self.restrictions.read() {
83 servers.extend(r.keys().cloned());
84 }
85 for members in self.groups.values() {
86 servers.extend(members.iter().cloned());
87 }
88 servers.into_iter().collect()
89 }
90}
91
/// Extracts the server part of a `server/tool` identifier: the text before
/// the first `/`, or the entire id when it contains no `/`.
fn server_of(tool_id: &str) -> &str {
    match tool_id.find('/') {
        Some(slash) => &tool_id[..slash],
        None => tool_id,
    }
}
99
/// Extracts the tool part of a `server/tool` identifier: the text after the
/// first `/` (which may itself contain further `/`), or the entire id when it
/// contains no `/`.
fn tool_name_of(tool_id: &str) -> &str {
    // '/' is a single byte, so `slash + 1` is always a valid char boundary.
    tool_id
        .find('/')
        .map_or(tool_id, |slash| &tool_id[slash + 1..])
}
104
105impl<T: ToolRegistry> ToolRegistry for DynamicToolRegistry<T> {
106 async fn list_tools(&self) -> Vec<ToolEntry> {
107 let all = self.inner.list_tools().await;
108 let filters = match self.filters.read() {
109 Ok(f) => f,
110 Err(_) => return all,
111 };
112
113 all.into_iter()
114 .filter(|entry| {
115 let server = server_of(&entry.id);
116 match filters.get(server) {
117 Some(filter) => filter.accepts(tool_name_of(&entry.id)),
118 None => true,
119 }
120 })
121 .collect()
122 }
123}
124
125impl<T: ToolRegistry> AdminToolRegistry for DynamicToolRegistry<T> {
126 async fn list_upstreams(&self) -> Vec<ToolUpstreamEntry> {
127 let tools = <Self as ToolRegistry>::list_tools(self).await;
129 let mut counts: HashMap<String, usize> = HashMap::new();
130 for tool in &tools {
131 let server = server_of(&tool.id);
132 *counts.entry(server.to_owned()).or_default() += 1;
133 }
134
135 let servers = self.known_servers();
137 for s in &servers {
138 counts.entry(s.clone()).or_default();
139 }
140
141 let filters = self.filters.read().ok();
142 let restrictions = self.restrictions.read().ok();
143
144 let mut entries: Vec<ToolUpstreamEntry> = counts
145 .into_iter()
146 .map(|(name, tool_count)| {
147 let filter = filters.as_ref().and_then(|f| f.get(&name).cloned());
148 let param_restrictions = restrictions
149 .as_ref()
150 .and_then(|r| r.get(&name).cloned())
151 .filter(|r| !r.rules.is_empty());
152 ToolUpstreamEntry {
153 name,
154 tool_count,
155 filter,
156 param_restrictions,
157 }
158 })
159 .collect();
160 entries.sort_by(|a, b| a.name.cmp(&b.name));
161 entries
162 }
163
164 async fn list_groups(&self) -> HashMap<String, Vec<String>> {
165 self.groups.clone()
166 }
167
168 async fn update_filter(&self, server: &str, filter: Option<ToolFilter>) -> Result<()> {
169 let mut filters = self
170 .filters
171 .write()
172 .map_err(|_| BitrouterError::transport(None, "tool registry lock poisoned"))?;
173 match filter {
174 Some(f) => {
175 filters.insert(server.to_owned(), f);
176 }
177 None => {
178 filters.remove(server);
179 }
180 }
181 Ok(())
182 }
183
184 async fn update_param_restrictions(
185 &self,
186 server: &str,
187 restrictions: ParamRestrictions,
188 ) -> Result<()> {
189 let mut r = self
190 .restrictions
191 .write()
192 .map_err(|_| BitrouterError::transport(None, "tool registry lock poisoned"))?;
193 if restrictions.rules.is_empty() {
194 r.remove(server);
195 } else {
196 r.insert(server.to_owned(), restrictions);
197 }
198 Ok(())
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Inner registry that always returns the same fixed tool list.
    struct StaticToolSource {
        tools: Vec<ToolEntry>,
    }

    impl ToolRegistry for StaticToolSource {
        async fn list_tools(&self) -> Vec<ToolEntry> {
            self.tools.clone()
        }
    }

    /// Three tools across two servers: `github` (two tools) and `jira` (one).
    fn test_tools() -> Vec<ToolEntry> {
        vec![
            ToolEntry {
                id: "github/search".to_owned(),
                name: Some("Search".to_owned()),
                provider: "github".to_owned(),
                description: Some("Search GitHub".to_owned()),
                input_schema: None,
            },
            ToolEntry {
                id: "github/create_issue".to_owned(),
                name: Some("Create Issue".to_owned()),
                provider: "github".to_owned(),
                description: Some("Create an issue".to_owned()),
                input_schema: None,
            },
            ToolEntry {
                id: "jira/search".to_owned(),
                name: Some("Search".to_owned()),
                provider: "jira".to_owned(),
                description: Some("Search Jira".to_owned()),
                input_schema: None,
            },
        ]
    }

    /// Registry with no filters or restrictions and a single `dev_tools`
    /// group containing both servers.
    fn test_registry() -> DynamicToolRegistry<StaticToolSource> {
        DynamicToolRegistry::new(
            StaticToolSource {
                tools: test_tools(),
            },
            HashMap::new(),
            HashMap::new(),
            HashMap::from([(
                "dev_tools".to_owned(),
                vec!["github".to_owned(), "jira".to_owned()],
            )]),
        )
    }

    /// Applies a filter update and fails the test if it is rejected, instead
    /// of silently continuing with stale state.
    async fn set_filter(
        reg: &DynamicToolRegistry<StaticToolSource>,
        server: &str,
        filter: Option<ToolFilter>,
    ) {
        assert!(
            reg.update_filter(server, filter).await.is_ok(),
            "update_filter({server}) failed"
        );
    }

    /// Applies a restrictions update and fails the test if it is rejected.
    async fn set_restrictions(
        reg: &DynamicToolRegistry<StaticToolSource>,
        server: &str,
        restrictions: ParamRestrictions,
    ) {
        assert!(
            reg.update_param_restrictions(server, restrictions)
                .await
                .is_ok(),
            "update_param_restrictions({server}) failed"
        );
    }

    /// Restrictions that reject the `force` parameter on the `search` tool.
    fn deny_force_on_search() -> ParamRestrictions {
        ParamRestrictions {
            rules: HashMap::from([(
                "search".to_owned(),
                super::super::admin::ParamRule {
                    deny: Some(vec!["force".to_owned()]),
                    allow: None,
                    action: super::super::admin::ParamViolationAction::Reject,
                },
            )]),
        }
    }

    #[tokio::test]
    async fn no_filter_returns_all_tools() {
        let reg = test_registry();
        let tools = reg.list_tools().await;
        assert_eq!(tools.len(), 3);
    }

    #[tokio::test]
    async fn deny_filter_hides_tools() {
        let reg = test_registry();
        set_filter(
            &reg,
            "github",
            Some(ToolFilter {
                allow: None,
                deny: Some(vec!["search".to_owned()]),
            }),
        )
        .await;

        let tools = reg.list_tools().await;
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().all(|t| t.id != "github/search"));
    }

    #[tokio::test]
    async fn allow_filter_restricts_tools() {
        let reg = test_registry();
        set_filter(
            &reg,
            "github",
            Some(ToolFilter {
                allow: Some(vec!["search".to_owned()]),
                deny: None,
            }),
        )
        .await;

        let tools = reg.list_tools().await;
        // github/search survives the allow-list; jira is unfiltered.
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t.id == "github/search"));
        assert!(!tools.iter().any(|t| t.id == "github/create_issue"));
    }

    #[tokio::test]
    async fn clear_filter_restores_all() {
        let reg = test_registry();
        set_filter(
            &reg,
            "github",
            Some(ToolFilter {
                deny: Some(vec!["search".to_owned()]),
                ..Default::default()
            }),
        )
        .await;
        assert_eq!(reg.list_tools().await.len(), 2);

        set_filter(&reg, "github", None).await;
        assert_eq!(reg.list_tools().await.len(), 3);
    }

    #[tokio::test]
    async fn list_upstreams_reflects_state() {
        let reg = test_registry();
        set_filter(
            &reg,
            "github",
            Some(ToolFilter {
                deny: Some(vec!["search".to_owned()]),
                ..Default::default()
            }),
        )
        .await;

        let upstreams = reg.list_upstreams().await;
        assert_eq!(upstreams.len(), 2);
        assert!(upstreams.windows(2).all(|w| w[0].name <= w[1].name));

        let github = upstreams
            .iter()
            .find(|u| u.name == "github")
            .expect("github upstream should be listed");
        assert_eq!(github.tool_count, 1);
        assert!(github.filter.is_some());

        let jira = upstreams
            .iter()
            .find(|u| u.name == "jira")
            .expect("jira upstream should be listed");
        assert_eq!(jira.tool_count, 1);
        assert!(jira.filter.is_none());
    }

    #[tokio::test]
    async fn list_groups_returns_configured() {
        let reg = test_registry();
        let groups = reg.list_groups().await;
        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key("dev_tools"));
    }

    #[tokio::test]
    async fn param_restrictions_roundtrip() {
        let reg = test_registry();
        assert!(reg.get_param_restrictions("github").is_none());

        set_restrictions(&reg, "github", deny_force_on_search()).await;

        let stored = reg
            .get_param_restrictions("github")
            .expect("restrictions should be stored");
        assert!(stored.rules.contains_key("search"));
    }

    #[tokio::test]
    async fn empty_param_restrictions_clear_entry() {
        let reg = test_registry();
        set_restrictions(&reg, "github", deny_force_on_search()).await;
        assert!(reg.get_param_restrictions("github").is_some());

        set_restrictions(
            &reg,
            "github",
            ParamRestrictions {
                rules: HashMap::new(),
            },
        )
        .await;
        assert!(reg.get_param_restrictions("github").is_none());
    }

    #[test]
    fn tool_id_splits_on_first_slash() {
        assert_eq!(server_of("github/search"), "github");
        assert_eq!(tool_name_of("github/search"), "search");
        assert_eq!(server_of("a/b/c"), "a");
        assert_eq!(tool_name_of("a/b/c"), "b/c");
        assert_eq!(server_of("bare"), "bare");
        assert_eq!(tool_name_of("bare"), "bare");
    }

    #[test]
    fn tool_filter_accepts_logic() {
        let filter = ToolFilter {
            allow: Some(vec!["search".to_owned()]),
            deny: Some(vec!["delete".to_owned()]),
        };
        assert!(filter.accepts("search"));
        assert!(!filter.accepts("delete"));
        // Not on the allow-list, so rejected even though it is not denied.
        assert!(!filter.accepts("create"));
    }

    #[test]
    fn tool_filter_deny_takes_precedence() {
        let filter = ToolFilter {
            allow: Some(vec!["search".to_owned()]),
            deny: Some(vec!["search".to_owned()]),
        };
        assert!(!filter.accepts("search"));
    }

    #[test]
    fn tool_filter_empty_accepts_all() {
        let filter = ToolFilter::default();
        assert!(filter.accepts("anything"));
    }
}