Skip to main content

mur_common/skill/
capability.rs

1use super::types::TrustLevel;
2use serde::{Deserialize, Serialize};
3use std::str::FromStr;
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum Capability {
8    FsReadAgentHome,
9    FsWriteAgentHome,
10    FsReadHost,
11    FsWriteHost,
12    NetworkOutbound,
13    NetworkOutboundAllowlisted,
14    SpawnAllowlisted,
15    Spawn,
16    Mcp,
17    SkillReadOthers,
18}
19
20impl FromStr for Capability {
21    type Err = ();
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        let v: serde_yaml_ng::Value = serde_yaml_ng::Value::String(s.to_string());
24        serde_yaml_ng::from_value(v).map_err(|_| ())
25    }
26}
27
28pub fn allowed_for(level: TrustLevel) -> &'static [Capability] {
29    use Capability::*;
30    match level {
31        TrustLevel::Sandboxed => &[FsReadAgentHome, Mcp],
32        TrustLevel::Verified => &[
33            FsReadAgentHome,
34            FsWriteAgentHome,
35            NetworkOutboundAllowlisted,
36            SpawnAllowlisted,
37            Mcp,
38        ],
39        TrustLevel::Trusted => &[
40            FsReadAgentHome,
41            FsWriteAgentHome,
42            FsReadHost,
43            FsWriteHost,
44            NetworkOutbound,
45            NetworkOutboundAllowlisted,
46            Spawn,
47            SpawnAllowlisted,
48            Mcp,
49            SkillReadOthers,
50        ],
51    }
52}
53
54#[derive(Debug, PartialEq, Eq)]
55pub struct CapabilityViolation {
56    pub capability: Capability,
57    pub trust_level: TrustLevel,
58}
59
60pub fn check_capabilities(
61    declared: &[String],
62    level: TrustLevel,
63) -> Result<(), CapabilityViolation> {
64    let allowed = allowed_for(level);
65    for s in declared {
66        let Ok(cap) = Capability::from_str(s) else {
67            return Err(CapabilityViolation {
68                capability: Capability::Mcp,
69                trust_level: level,
70            });
71        };
72        if !allowed.contains(&cap) {
73            return Err(CapabilityViolation {
74                capability: cap,
75                trust_level: level,
76            });
77        }
78    }
79    Ok(())
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sandboxed_blocks_network() {
        let r = check_capabilities(&["network_outbound".into()], TrustLevel::Sandboxed);
        assert!(matches!(
            r,
            Err(CapabilityViolation {
                capability: Capability::NetworkOutbound,
                ..
            })
        ));
    }

    #[test]
    fn verified_allows_allowlisted_net() {
        let r = check_capabilities(
            &[
                "network_outbound_allowlisted".into(),
                "fs_write_agent_home".into(),
            ],
            TrustLevel::Verified,
        );
        assert!(r.is_ok());
    }

    #[test]
    fn trusted_allows_everything_declared() {
        let r = check_capabilities(
            &["spawn".into(), "fs_write_host".into()],
            TrustLevel::Trusted,
        );
        assert!(r.is_ok());
    }

    #[test]
    fn unknown_capability_rejected() {
        let r = check_capabilities(&["nuke_from_orbit".into()], TrustLevel::Trusted);
        assert!(r.is_err());
    }

    #[test]
    fn empty_declarations_always_ok() {
        assert!(check_capabilities(&[], TrustLevel::Sandboxed).is_ok());
    }

    /// Pins the snake_case wire name of every variant accepted by `FromStr`.
    #[test]
    fn every_snake_case_name_parses() {
        let table = [
            ("fs_read_agent_home", Capability::FsReadAgentHome),
            ("fs_write_agent_home", Capability::FsWriteAgentHome),
            ("fs_read_host", Capability::FsReadHost),
            ("fs_write_host", Capability::FsWriteHost),
            ("network_outbound", Capability::NetworkOutbound),
            (
                "network_outbound_allowlisted",
                Capability::NetworkOutboundAllowlisted,
            ),
            ("spawn_allowlisted", Capability::SpawnAllowlisted),
            ("spawn", Capability::Spawn),
            ("mcp", Capability::Mcp),
            ("skill_read_others", Capability::SkillReadOthers),
        ];
        for (name, expected) in table {
            assert_eq!(name.parse::<Capability>(), Ok(expected), "parsing {name:?}");
        }
    }

    /// Parsing is exact: no case folding, trimming, or alternate spellings.
    #[test]
    fn parsing_is_exact_and_case_sensitive() {
        for bad in ["Mcp", "MCP", " mcp", "mcp ", "FsReadHost", "fs-read-host", ""] {
            assert_eq!(bad.parse::<Capability>(), Err(()), "accepted {bad:?}");
        }
    }

    /// Pins current behavior: an unparseable name is reported with the
    /// `Capability::Mcp` placeholder, even at a level that grants `Mcp`.
    /// Changing this contract should be a deliberate decision.
    #[test]
    fn unknown_capability_uses_mcp_placeholder() {
        let r = check_capabilities(&["nuke_from_orbit".into()], TrustLevel::Sandboxed);
        assert_eq!(
            r,
            Err(CapabilityViolation {
                capability: Capability::Mcp,
                trust_level: TrustLevel::Sandboxed,
            })
        );
    }

    #[test]
    fn first_violation_in_order_is_reported() {
        let r = check_capabilities(
            &["mcp".into(), "spawn".into(), "fs_write_host".into()],
            TrustLevel::Verified,
        );
        assert!(matches!(
            r,
            Err(CapabilityViolation {
                capability: Capability::Spawn,
                ..
            })
        ));
    }

    #[test]
    fn trust_levels_are_nested() {
        let sandboxed = allowed_for(TrustLevel::Sandboxed);
        let verified = allowed_for(TrustLevel::Verified);
        let trusted = allowed_for(TrustLevel::Trusted);
        assert!(sandboxed.iter().all(|c| verified.contains(c)));
        assert!(verified.iter().all(|c| trusted.contains(c)));
    }
}