Skip to main content

forge_sandbox/
groups.rs

//! Server group enforcement for cross-server data flow policies.
//!
//! Groups define isolation boundaries between sets of MCP servers. When a group
//! uses "strict" isolation, the first access to one of its servers locks the
//! execution to that group — subsequent calls to servers in a different strict
//! group are denied. Servers in open groups, and servers in no group, are never
//! restricted.
//!
//! Tool calls and resource reads share the same group lock within a single
//! execution, so isolation is enforced consistently across both kinds of access.
10
11use std::collections::HashMap;
12use std::sync::Arc;
13
14use forge_error::DispatchError;
15use serde_json::Value;
16use tokio::sync::Mutex;
17
18use crate::{ResourceDispatcher, ToolDispatcher};
19
20/// Shared lock tracking which strict group (if any) has been accessed
21/// in the current execution. Used by both [`GroupEnforcingDispatcher`]
22/// and [`GroupEnforcingResourceDispatcher`] so that tool calls and
23/// resource reads enforce the same isolation boundary.
24pub type SharedGroupLock = Arc<Mutex<Option<String>>>;
25
/// Isolation mode for a server group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum IsolationMode {
    /// Strict: once an execution calls a server in this group, it cannot call
    /// servers in a different strict group.
    Strict,
    /// Open: servers in this group can be called from any execution regardless
    /// of which groups have been accessed.
    Open,
}

impl IsolationMode {
    /// Parse the isolation string of a group definition.
    ///
    /// `"strict"` is recognised case-insensitively and with surrounding
    /// whitespace ignored, so a value such as `"Strict"` cannot silently weaken a
    /// group to [`IsolationMode::Open`]. Every other value maps to `Open`.
    // NOTE(review): unknown values still fall back to `Open` as before; presumably
    // config validation upstream rejects them — confirm.
    fn from_config_str(value: &str) -> Self {
        if value.trim().eq_ignore_ascii_case("strict") {
            Self::Strict
        } else {
            Self::Open
        }
    }
}

/// Compiled group policy (immutable, shared across executions).
#[derive(Debug, Clone)]
pub struct GroupPolicy {
    /// Server name -> name of the group the server is assigned to.
    server_to_group: HashMap<String, String>,
    /// Group name -> isolation mode of that group.
    group_isolation: HashMap<String, IsolationMode>,
}

impl GroupPolicy {
    /// Build a group policy from config group definitions.
    ///
    /// Each entry in `groups` maps a group name to (server list, isolation mode
    /// string). The isolation string is matched against `"strict"`
    /// case-insensitively (ignoring surrounding whitespace); anything else is
    /// treated as open.
    ///
    /// A server listed in more than one group is resolved deterministically:
    /// membership in a strict group always wins over membership in an open group
    /// (fail closed), and among groups with the same mode the group whose name
    /// sorts first wins.
    pub fn from_config(groups: &HashMap<String, (Vec<String>, String)>) -> Self {
        let mut server_to_group: HashMap<String, String> = HashMap::new();
        let mut group_isolation = HashMap::with_capacity(groups.len());

        // `HashMap` iteration order is randomized, so visit groups in name order
        // to make overlapping memberships resolve identically on every run.
        let mut ordered: Vec<(&String, &(Vec<String>, String))> = groups.iter().collect();
        ordered.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        for (group_name, (servers, isolation)) in ordered {
            let mode = IsolationMode::from_config_str(isolation);
            group_isolation.insert(group_name.clone(), mode);

            for server in servers {
                let current_mode = server_to_group
                    .get(server)
                    .and_then(|existing| group_isolation.get(existing))
                    .copied();
                let take_over = match current_mode {
                    None => true,
                    // A strict membership must never be shadowed by an open one.
                    Some(IsolationMode::Open) => mode == IsolationMode::Strict,
                    Some(IsolationMode::Strict) => false,
                };
                if take_over {
                    server_to_group.insert(server.clone(), group_name.clone());
                }
            }
        }

        Self {
            server_to_group,
            group_isolation,
        }
    }

    /// Returns true if no groups are configured.
    pub fn is_empty(&self) -> bool {
        self.group_isolation.is_empty()
    }

    /// Look up which group a server belongs to and that group's isolation mode.
    ///
    /// Returns `None` for servers that are not listed in any group.
    pub fn server_group(&self, server: &str) -> Option<(&str, IsolationMode)> {
        self.server_to_group.get(server).map(|group| {
            let mode = self
                .group_isolation
                .get(group)
                .copied()
                .unwrap_or(IsolationMode::Open);
            (group.as_str(), mode)
        })
    }
}
87
88/// Check a server against the group policy and shared lock.
89///
90/// Returns `Ok(())` if the access is allowed, or an error if a cross-group
91/// violation is detected. Locks the group on first strict access.
92async fn check_group_access(
93    policy: &GroupPolicy,
94    locked_group: &SharedGroupLock,
95    server: &str,
96) -> Result<(), DispatchError> {
97    if let Some((group, mode)) = policy.server_group(server) {
98        if mode == IsolationMode::Strict {
99            let mut locked = locked_group.lock().await;
100            match &*locked {
101                None => {
102                    *locked = Some(group.to_string());
103                }
104                Some(existing) if existing == group => {
105                    // Same strict group: allowed
106                }
107                Some(existing) => {
108                    return Err(DispatchError::GroupPolicyDenied {
109                        reason: format!(
110                            "cross-group call denied: server '{}' is in strict group '{}', \
111                             but this execution is locked to strict group '{}'",
112                            server, group, existing,
113                        ),
114                    });
115                }
116            }
117        }
118        // Open-group servers always pass through
119    }
120    // Ungrouped servers always pass through
121    Ok(())
122}
123
124/// A [`ToolDispatcher`] that enforces group isolation policies.
125///
126/// Created fresh for each execution. The first call to a strict-group server
127/// "locks" this dispatcher to that group for the duration of the execution.
128pub struct GroupEnforcingDispatcher {
129    inner: Arc<dyn ToolDispatcher>,
130    policy: Arc<GroupPolicy>,
131    locked_group: SharedGroupLock,
132}
133
134impl GroupEnforcingDispatcher {
135    /// Create a new group-enforcing dispatcher for a single execution.
136    pub fn new(inner: Arc<dyn ToolDispatcher>, policy: Arc<GroupPolicy>) -> Self {
137        Self {
138            inner,
139            policy,
140            locked_group: Arc::new(Mutex::new(None)),
141        }
142    }
143
144    /// Create a group-enforcing dispatcher that shares a lock with another dispatcher.
145    ///
146    /// Use this to ensure that tool calls and resource reads within the same
147    /// execution enforce the same group isolation boundary.
148    pub fn with_shared_lock(
149        inner: Arc<dyn ToolDispatcher>,
150        policy: Arc<GroupPolicy>,
151        lock: SharedGroupLock,
152    ) -> Self {
153        Self {
154            inner,
155            policy,
156            locked_group: lock,
157        }
158    }
159
160    /// Get the shared lock for use with a paired [`GroupEnforcingResourceDispatcher`].
161    pub fn shared_lock(&self) -> SharedGroupLock {
162        self.locked_group.clone()
163    }
164}
165
166/// A [`ResourceDispatcher`] that enforces group isolation policies.
167///
168/// Shares a [`SharedGroupLock`] with a [`GroupEnforcingDispatcher`] so that
169/// tool calls and resource reads within the same execution enforce the
170/// same group isolation boundary.
171pub struct GroupEnforcingResourceDispatcher {
172    inner: Arc<dyn ResourceDispatcher>,
173    policy: Arc<GroupPolicy>,
174    locked_group: SharedGroupLock,
175}
176
177impl GroupEnforcingResourceDispatcher {
178    /// Create a group-enforcing resource dispatcher with an externally provided lock.
179    pub fn new(
180        inner: Arc<dyn ResourceDispatcher>,
181        policy: Arc<GroupPolicy>,
182        lock: SharedGroupLock,
183    ) -> Self {
184        Self {
185            inner,
186            policy,
187            locked_group: lock,
188        }
189    }
190}
191
192#[async_trait::async_trait]
193impl ToolDispatcher for GroupEnforcingDispatcher {
194    async fn call_tool(
195        &self,
196        server: &str,
197        tool: &str,
198        args: Value,
199    ) -> Result<Value, DispatchError> {
200        check_group_access(&self.policy, &self.locked_group, server).await?;
201        self.inner.call_tool(server, tool, args).await
202    }
203}
204
205#[async_trait::async_trait]
206impl ResourceDispatcher for GroupEnforcingResourceDispatcher {
207    async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
208        check_group_access(&self.policy, &self.locked_group, server).await?;
209        self.inner.read_resource(server, uri).await
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Tool dispatcher stub that echoes the target server and tool name.
    struct MockDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for MockDispatcher {
        async fn call_tool(
            &self,
            server: &str,
            tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Ok(json!({"server": server, "tool": tool}))
        }
    }

    /// Resource dispatcher stub that echoes the target server and URI.
    struct MockResourceDispatcher;

    #[async_trait::async_trait]
    impl ResourceDispatcher for MockResourceDispatcher {
        async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
            Ok(json!({"server": server, "uri": uri}))
        }
    }

    /// Build a policy from `(group, servers, isolation)` triples.
    fn make_policy(groups: &[(&str, &[&str], &str)]) -> Arc<GroupPolicy> {
        let config: HashMap<String, (Vec<String>, String)> = groups
            .iter()
            .map(|(name, servers, isolation)| {
                let members = servers.iter().map(|s| s.to_string()).collect();
                (name.to_string(), (members, isolation.to_string()))
            })
            .collect();
        Arc::new(GroupPolicy::from_config(&config))
    }

    /// Tool dispatcher over the mock, with its own fresh lock.
    fn tool_dispatcher(policy: Arc<GroupPolicy>) -> GroupEnforcingDispatcher {
        GroupEnforcingDispatcher::new(Arc::new(MockDispatcher), policy)
    }

    /// Tool and resource dispatchers over the mocks, sharing a single lock.
    fn paired_dispatchers(
        policy: Arc<GroupPolicy>,
    ) -> (GroupEnforcingDispatcher, GroupEnforcingResourceDispatcher) {
        let lock: SharedGroupLock = Arc::new(Mutex::new(None));
        let tools = GroupEnforcingDispatcher::with_shared_lock(
            Arc::new(MockDispatcher),
            Arc::clone(&policy),
            Arc::clone(&lock),
        );
        let resources =
            GroupEnforcingResourceDispatcher::new(Arc::new(MockResourceDispatcher), policy, lock);
        (tools, resources)
    }

    /// Invoke `tool` on `server` with an empty argument object.
    async fn call(
        dispatcher: &GroupEnforcingDispatcher,
        server: &str,
        tool: &str,
    ) -> Result<Value, DispatchError> {
        dispatcher.call_tool(server, tool, json!({})).await
    }

    /// Assert the result is a `GroupPolicyDenied` error and hand the error back.
    #[track_caller]
    fn expect_denied(result: Result<Value, DispatchError>) -> DispatchError {
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::GroupPolicyDenied { .. }),
            "expected GroupPolicyDenied, got: {err}"
        );
        err
    }

    #[tokio::test]
    async fn ungrouped_server_always_allowed() {
        let dispatcher = tool_dispatcher(make_policy(&[("internal", &["vault"], "strict")]));

        assert!(call(&dispatcher, "ungrouped", "tool").await.is_ok());
    }

    #[tokio::test]
    async fn open_group_always_allowed() {
        let dispatcher = tool_dispatcher(make_policy(&[
            ("internal", &["vault"], "strict"),
            ("analysis", &["narsil"], "open"),
        ]));

        // Bind the execution to the strict group first.
        call(&dispatcher, "vault", "secrets.list").await.unwrap();

        // The open group must remain reachable.
        assert!(
            call(&dispatcher, "narsil", "scan").await.is_ok(),
            "open group should be allowed after strict"
        );
    }

    #[tokio::test]
    async fn strict_group_locks_execution() {
        let dispatcher = tool_dispatcher(make_policy(&[
            ("internal", &["vault", "database"], "strict"),
            ("external", &["slack"], "strict"),
        ]));

        // First strict call binds the execution to "internal".
        assert!(call(&dispatcher, "vault", "secrets.list").await.is_ok());

        // Another server in the same strict group is still reachable.
        assert!(call(&dispatcher, "database", "query").await.is_ok());
    }

    #[tokio::test]
    async fn cross_strict_group_denied() {
        let dispatcher = tool_dispatcher(make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]));

        // Bind to "internal", then try to reach "external".
        call(&dispatcher, "vault", "secrets.list").await.unwrap();
        let err = expect_denied(call(&dispatcher, "slack", "messages.send").await);

        let msg = err.to_string();
        for (needle, what) in [
            ("slack", "server"),
            ("external", "target group"),
            ("internal", "locked group"),
        ] {
            assert!(msg.contains(needle), "should mention {what}: {msg}");
        }
    }

    #[tokio::test]
    async fn open_group_after_strict_allowed() {
        let dispatcher = tool_dispatcher(make_policy(&[
            ("internal", &["vault"], "strict"),
            ("tools", &["narsil"], "open"),
        ]));

        call(&dispatcher, "vault", "secrets.list").await.unwrap();

        assert!(call(&dispatcher, "narsil", "scan").await.is_ok());
    }

    #[tokio::test]
    async fn ungrouped_after_strict_allowed() {
        let dispatcher = tool_dispatcher(make_policy(&[("internal", &["vault"], "strict")]));

        call(&dispatcher, "vault", "secrets.list").await.unwrap();

        assert!(
            call(&dispatcher, "random", "tool").await.is_ok(),
            "ungrouped server should be allowed"
        );
    }

    #[tokio::test]
    async fn fresh_dispatcher_is_unlocked() {
        let policy = make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]);

        // First execution binds itself to "internal".
        let first = tool_dispatcher(Arc::clone(&policy));
        call(&first, "vault", "secrets.list").await.unwrap();

        // A second execution starts unbound and may use "external".
        let second = tool_dispatcher(policy);
        assert!(
            call(&second, "slack", "messages.send").await.is_ok(),
            "fresh dispatcher should be unlocked"
        );
    }

    #[tokio::test]
    async fn empty_policy_allows_everything() {
        let policy = Arc::new(GroupPolicy::from_config(&HashMap::new()));
        assert!(policy.is_empty());

        let dispatcher = tool_dispatcher(policy);
        assert!(call(&dispatcher, "any", "tool").await.is_ok());
    }

    #[test]
    fn policy_server_group_lookup() {
        let policy = make_policy(&[
            ("internal", &["vault", "db"], "strict"),
            ("external", &["slack"], "open"),
        ]);

        assert_eq!(
            policy.server_group("vault"),
            Some(("internal", IsolationMode::Strict))
        );
        assert_eq!(
            policy.server_group("slack"),
            Some(("external", IsolationMode::Open))
        );
        assert_eq!(policy.server_group("unknown"), None);
    }

    #[test]
    fn policy_from_config_handles_empty() {
        assert!(GroupPolicy::from_config(&HashMap::new()).is_empty());
    }

    // --- RS-S01: resource read through strict group locks execution ---
    #[tokio::test]
    async fn rs_s01_resource_read_locks_strict_group() {
        let (tools, resources) = paired_dispatchers(make_policy(&[
            ("internal", &["vault", "database"], "strict"),
            ("external", &["slack"], "strict"),
        ]));

        // Reading from vault binds the shared lock to "internal".
        assert!(resources.read_resource("vault", "file:///logs").await.is_ok());

        // A tool call within the same strict group is allowed.
        assert!(
            call(&tools, "database", "query").await.is_ok(),
            "same strict group should be allowed"
        );

        // A tool call into another strict group is refused.
        assert!(
            call(&tools, "slack", "send").await.is_err(),
            "cross-group should be denied"
        );
    }

    // --- RS-S02: resource read after tool call to different strict group is denied ---
    #[tokio::test]
    async fn rs_s02_resource_read_after_tool_to_different_group_denied() {
        let (tools, resources) = paired_dispatchers(make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]));

        // A tool call to vault binds the shared lock to "internal".
        call(&tools, "vault", "secrets.list").await.unwrap();

        // Reading from slack ("external") must now be refused.
        expect_denied(resources.read_resource("slack", "file:///messages").await);
    }

    // --- RS-S03: tool call after resource read to different strict group is denied ---
    #[tokio::test]
    async fn rs_s03_tool_after_resource_read_to_different_group_denied() {
        let (tools, resources) = paired_dispatchers(make_policy(&[
            ("internal", &["vault"], "strict"),
            ("external", &["slack"], "strict"),
        ]));

        // Reading from slack binds the shared lock to "external".
        resources
            .read_resource("slack", "file:///messages")
            .await
            .unwrap();

        // A tool call to vault ("internal") must now be refused.
        expect_denied(call(&tools, "vault", "secrets.list").await);
    }

    #[tokio::test]
    async fn error_message_is_actionable() {
        let dispatcher = tool_dispatcher(make_policy(&[
            ("secrets", &["vault"], "strict"),
            ("comms", &["slack"], "strict"),
        ]));

        call(&dispatcher, "vault", "read").await.unwrap();
        let err = call(&dispatcher, "slack", "send").await.unwrap_err();

        // Match on the typed variant rather than on the rendered error type name.
        assert!(
            matches!(
                err,
                DispatchError::GroupPolicyDenied { ref reason }
                    if ["slack", "comms", "secrets"]
                        .iter()
                        .all(|part| reason.contains(*part))
            ),
            "expected GroupPolicyDenied mentioning server/groups, got: {err}"
        );
    }
}