1use std::collections::HashMap;
12use std::sync::Arc;
13
14use forge_error::DispatchError;
15use serde_json::Value;
16use tokio::sync::Mutex;
17
18use crate::{ResourceDispatcher, ToolDispatcher};
19
/// Shared handle to the strict group an execution is currently locked to.
///
/// Holds `None` until the first call to a server in a strict group. After that
/// it holds that group's name. Handing clones of the same `Arc` to several
/// dispatchers makes them all enforce one common lock.
pub type SharedGroupLock = Arc<Mutex<Option<String>>>;
25
/// Isolation level of a server group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum IsolationMode {
    /// The first call to a server in a strict group locks the execution to
    /// that group. Calls to servers in any other strict group are then denied.
    Strict,
    /// Servers in an open group are never checked and never lock the execution.
    Open,
}
37
/// Resolved server-group membership and per-group isolation settings.
///
/// Built with [`GroupPolicy::from_config`] and queried per call through
/// [`GroupPolicy::server_group`].
#[derive(Debug, Clone)]
pub struct GroupPolicy {
    // Server name -> name of the group the server belongs to.
    server_to_group: HashMap<String, String>,
    // Group name -> isolation mode of that group.
    group_isolation: HashMap<String, IsolationMode>,
}
44
45impl GroupPolicy {
46 pub fn from_config(groups: &HashMap<String, (Vec<String>, String)>) -> Self {
50 let mut server_to_group = HashMap::new();
51 let mut group_isolation = HashMap::new();
52
53 for (group_name, (servers, isolation)) in groups {
54 let mode = match isolation.as_str() {
55 "strict" => IsolationMode::Strict,
56 _ => IsolationMode::Open,
57 };
58 group_isolation.insert(group_name.clone(), mode);
59 for server in servers {
60 server_to_group.insert(server.clone(), group_name.clone());
61 }
62 }
63
64 Self {
65 server_to_group,
66 group_isolation,
67 }
68 }
69
70 pub fn is_empty(&self) -> bool {
72 self.group_isolation.is_empty()
73 }
74
75 pub fn server_group(&self, server: &str) -> Option<(&str, IsolationMode)> {
77 self.server_to_group.get(server).map(|group| {
78 let mode = self
79 .group_isolation
80 .get(group)
81 .copied()
82 .unwrap_or(IsolationMode::Open);
83 (group.as_str(), mode)
84 })
85 }
86}
87
88async fn check_group_access(
93 policy: &GroupPolicy,
94 locked_group: &SharedGroupLock,
95 server: &str,
96) -> Result<(), DispatchError> {
97 if let Some((group, mode)) = policy.server_group(server) {
98 if mode == IsolationMode::Strict {
99 let mut locked = locked_group.lock().await;
100 match &*locked {
101 None => {
102 *locked = Some(group.to_string());
103 }
104 Some(existing) if existing == group => {
105 }
107 Some(existing) => {
108 return Err(DispatchError::GroupPolicyDenied {
109 reason: format!(
110 "cross-group call denied: server '{}' is in strict group '{}', \
111 but this execution is locked to strict group '{}'",
112 server, group, existing,
113 ),
114 });
115 }
116 }
117 }
118 }
120 Ok(())
122}
123
/// [`ToolDispatcher`] decorator that checks each call against a
/// [`GroupPolicy`] before forwarding it to the wrapped dispatcher.
pub struct GroupEnforcingDispatcher {
    // Dispatcher that performs the actual tool call once access is granted.
    inner: Arc<dyn ToolDispatcher>,
    // Group membership and isolation rules to enforce.
    policy: Arc<GroupPolicy>,
    // Strict group this execution is locked to (`None` until the first
    // strict-group call); may be shared with other dispatchers.
    locked_group: SharedGroupLock,
}
133
134impl GroupEnforcingDispatcher {
135 pub fn new(inner: Arc<dyn ToolDispatcher>, policy: Arc<GroupPolicy>) -> Self {
137 Self {
138 inner,
139 policy,
140 locked_group: Arc::new(Mutex::new(None)),
141 }
142 }
143
144 pub fn with_shared_lock(
149 inner: Arc<dyn ToolDispatcher>,
150 policy: Arc<GroupPolicy>,
151 lock: SharedGroupLock,
152 ) -> Self {
153 Self {
154 inner,
155 policy,
156 locked_group: lock,
157 }
158 }
159
160 pub fn shared_lock(&self) -> SharedGroupLock {
162 self.locked_group.clone()
163 }
164}
165
/// [`ResourceDispatcher`] decorator that checks each resource read against a
/// [`GroupPolicy`] before forwarding it to the wrapped dispatcher.
pub struct GroupEnforcingResourceDispatcher {
    // Dispatcher that performs the actual read once access is granted.
    inner: Arc<dyn ResourceDispatcher>,
    // Group membership and isolation rules to enforce.
    policy: Arc<GroupPolicy>,
    // Strict group this execution is locked to; supplied by the caller so it
    // can be shared with a tool dispatcher.
    locked_group: SharedGroupLock,
}
176
177impl GroupEnforcingResourceDispatcher {
178 pub fn new(
180 inner: Arc<dyn ResourceDispatcher>,
181 policy: Arc<GroupPolicy>,
182 lock: SharedGroupLock,
183 ) -> Self {
184 Self {
185 inner,
186 policy,
187 locked_group: lock,
188 }
189 }
190}
191
192#[async_trait::async_trait]
193impl ToolDispatcher for GroupEnforcingDispatcher {
194 async fn call_tool(
195 &self,
196 server: &str,
197 tool: &str,
198 args: Value,
199 ) -> Result<Value, DispatchError> {
200 check_group_access(&self.policy, &self.locked_group, server).await?;
201 self.inner.call_tool(server, tool, args).await
202 }
203}
204
205#[async_trait::async_trait]
206impl ResourceDispatcher for GroupEnforcingResourceDispatcher {
207 async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
208 check_group_access(&self.policy, &self.locked_group, server).await?;
209 self.inner.read_resource(server, uri).await
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool dispatcher stub that echoes the server and tool names back.
    struct MockDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for MockDispatcher {
        async fn call_tool(
            &self,
            server: &str,
            tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Ok(serde_json::json!({"server": server, "tool": tool}))
        }
    }

    /// Builds a shared policy from `(group name, member servers, isolation)` triples.
    fn make_policy(groups: &[(&str, &[&str], &str)]) -> Arc<GroupPolicy> {
        let mut map = HashMap::new();
        for (name, servers, isolation) in groups {
            map.insert(
                name.to_string(),
                (
                    servers.iter().map(|s| s.to_string()).collect(),
                    isolation.to_string(),
                ),
            );
        }
        Arc::new(GroupPolicy::from_config(&map))
    }

    // A server that is not listed in any group is never checked.
    #[tokio::test]
    async fn ungrouped_server_always_allowed() {
        let policy = make_policy(&[("internal", &["vault"], "strict")]);
        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);

        let result = dispatcher
            .call_tool("ungrouped", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
    }

    // After the execution is locked to a strict group, servers in an open
    // group remain reachable.
    #[tokio::test]
    async fn open_group_always_allowed() {
        let policy = make_policy(&[
            ("internal", &["vault"], "strict"),
            ("analysis", &["narsil"], "open"),
        ]);
        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);

        // Locks the execution to the strict "internal" group.
        let _ = dispatcher
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await
            .unwrap();

        // "narsil" is in an open group, so the lock does not apply.
        let result = dispatcher
            .call_tool("narsil", "scan", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "open group should be allowed after strict");
    }

    // Several servers in the same strict group can all be called once the
    // execution is locked to that group.
    #[tokio::test]
    async fn strict_group_locks_execution() {
        let policy = make_policy(&[
            ("internal", &["vault", "database"], "strict"),
            ("external", &["slack"], "strict"),
        ]);
        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);

        // First strict call: locks to "internal".
        let result = dispatcher
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await;
        assert!(result.is_ok());

        // Same strict group: allowed.
        let result = dispatcher
            .call_tool("database", "query", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
    }

    // Once locked to one strict group, a call into a different strict group is
    // denied, and the message names the server, its group and the locked group.
    #[tokio::test]
    async fn cross_strict_group_denied() {
        let policy = make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]);
        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);

        // Locks the execution to "internal".
        let _ = dispatcher
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await
            .unwrap();

        // "slack" is in the strict "external" group: denied.
        let result = dispatcher
            .call_tool("slack", "messages.send", serde_json::json!({}))
            .await;
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::GroupPolicyDenied { .. }),
            "expected GroupPolicyDenied, got: {err}"
        );
        let msg = err.to_string();
        assert!(msg.contains("slack"), "should mention server: {msg}");
        assert!(
            msg.contains("external"),
            "should mention target group: {msg}"
        );
        assert!(
            msg.contains("internal"),
            "should mention locked group: {msg}"
        );
    }

    // NOTE(review): same scenario as `open_group_always_allowed`, with a
    // different open-group name.
    #[tokio::test]
    async fn open_group_after_strict_allowed() {
        let policy = make_policy(&[
            ("internal", &["vault"], "strict"),
            ("tools", &["narsil"], "open"),
        ]);
        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);

        let _ = dispatcher
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await
            .unwrap();

        let result = dispatcher
            .call_tool("narsil", "scan", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
    }

    // An ungrouped server stays reachable after the execution is locked.
    #[tokio::test]
    async fn ungrouped_after_strict_allowed() {
        let policy = make_policy(&[("internal", &["vault"], "strict")]);
        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);

        let _ = dispatcher
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await
            .unwrap();

        let result = dispatcher
            .call_tool("random", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "ungrouped server should be allowed");
    }

    // Dispatchers built with `new` get independent locks: locking `d1` does
    // not affect `d2`, even though they share one policy.
    #[tokio::test]
    async fn fresh_dispatcher_is_unlocked() {
        let policy = make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]);

        // Locks d1 to "internal".
        let d1 = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy.clone());
        let _ = d1
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await
            .unwrap();

        // d2 has its own unlocked lock, so a different strict group is fine.
        let d2 = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);
        let result = d2
            .call_tool("slack", "messages.send", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "fresh dispatcher should be unlocked");
    }

    // With no groups configured, every server is ungrouped and allowed.
    #[tokio::test]
    async fn empty_policy_allows_everything() {
        let policy = Arc::new(GroupPolicy::from_config(&HashMap::new()));
        assert!(policy.is_empty());

        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);
        let result = dispatcher
            .call_tool("any", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
    }

    // `server_group` reports the owning group and its mode, and `None` for
    // unknown servers.
    #[test]
    fn policy_server_group_lookup() {
        let policy = make_policy(&[
            ("internal", &["vault", "db"], "strict"),
            ("external", &["slack"], "open"),
        ]);

        let (group, mode) = policy.server_group("vault").unwrap();
        assert_eq!(group, "internal");
        assert_eq!(mode, IsolationMode::Strict);

        let (group, mode) = policy.server_group("slack").unwrap();
        assert_eq!(group, "external");
        assert_eq!(mode, IsolationMode::Open);

        assert!(policy.server_group("unknown").is_none());
    }

    #[test]
    fn policy_from_config_handles_empty() {
        let policy = GroupPolicy::from_config(&HashMap::new());
        assert!(policy.is_empty());
    }

    /// Resource dispatcher stub that echoes the server and URI back.
    struct MockResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for MockResourceDispatcher {
        async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
            Ok(serde_json::json!({"server": server, "uri": uri}))
        }
    }

    // A resource read locks the shared lock. Later tool calls through a
    // dispatcher using the same lock then obey that lock.
    #[tokio::test]
    async fn rs_s01_resource_read_locks_strict_group() {
        let policy = make_policy(&[
            ("internal", &["vault", "database"], "strict"),
            ("external", &["slack"], "strict"),
        ]);
        let shared_lock: SharedGroupLock = Arc::new(Mutex::new(None));

        let resource_dispatcher = GroupEnforcingResourceDispatcher::new(
            Arc::new(MockResourceDispatcher),
            policy.clone(),
            shared_lock.clone(),
        );
        let tool_dispatcher = GroupEnforcingDispatcher::with_shared_lock(
            Arc::new(MockDispatcher),
            policy,
            shared_lock,
        );

        // Resource read locks the execution to "internal".
        let result = resource_dispatcher
            .read_resource("vault", "file:///logs")
            .await;
        assert!(result.is_ok());

        let result = tool_dispatcher
            .call_tool("database", "query", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "same strict group should be allowed");

        let result = tool_dispatcher
            .call_tool("slack", "send", serde_json::json!({}))
            .await;
        assert!(result.is_err(), "cross-group should be denied");
    }

    // A tool call locks the shared lock, so a resource read into another
    // strict group is denied.
    #[tokio::test]
    async fn rs_s02_resource_read_after_tool_to_different_group_denied() {
        let policy = make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]);
        let shared_lock: SharedGroupLock = Arc::new(Mutex::new(None));

        let tool_dispatcher = GroupEnforcingDispatcher::with_shared_lock(
            Arc::new(MockDispatcher),
            policy.clone(),
            shared_lock.clone(),
        );
        let resource_dispatcher = GroupEnforcingResourceDispatcher::new(
            Arc::new(MockResourceDispatcher),
            policy,
            shared_lock,
        );

        // Tool call locks the execution to "internal".
        let _ = tool_dispatcher
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await
            .unwrap();

        let result = resource_dispatcher
            .read_resource("slack", "file:///messages")
            .await;
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::GroupPolicyDenied { .. }),
            "expected GroupPolicyDenied, got: {err}"
        );
    }

    // A resource read locks the shared lock, so a tool call into another
    // strict group is denied.
    #[tokio::test]
    async fn rs_s03_tool_after_resource_read_to_different_group_denied() {
        let policy = make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]);
        let shared_lock: SharedGroupLock = Arc::new(Mutex::new(None));

        let resource_dispatcher = GroupEnforcingResourceDispatcher::new(
            Arc::new(MockResourceDispatcher),
            policy.clone(),
            shared_lock.clone(),
        );
        let tool_dispatcher = GroupEnforcingDispatcher::with_shared_lock(
            Arc::new(MockDispatcher),
            policy,
            shared_lock,
        );

        // Resource read locks the execution to "external".
        let _ = resource_dispatcher
            .read_resource("slack", "file:///messages")
            .await
            .unwrap();

        let result = tool_dispatcher
            .call_tool("vault", "secrets.list", serde_json::json!({}))
            .await;
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::GroupPolicyDenied { .. }),
            "expected GroupPolicyDenied, got: {err}"
        );
    }

    // The denial reason names the rejected server, its group and the group
    // the execution is locked to.
    #[tokio::test]
    async fn error_message_is_actionable() {
        let policy = make_policy(&[
            ("secrets", &["vault"], "strict"),
            ("comms", &["slack"], "strict"),
        ]);
        let dispatcher = GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy);

        let _ = dispatcher
            .call_tool("vault", "read", serde_json::json!({}))
            .await
            .unwrap();

        let err = dispatcher
            .call_tool("slack", "send", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                DispatchError::GroupPolicyDenied { ref reason }
                if reason.contains("slack")
                    && reason.contains("comms")
                    && reason.contains("secrets")
            ),
            "expected GroupPolicyDenied mentioning server/groups, got: {err}"
        );
    }
}