use crate::policy::match_tree::{Decision, Node, Observable, Pattern, PolicyManifest};
/// Outcome of `upsert_rule`.
///
/// Fieldless, so it is `Copy` and can be compared, hashed and passed by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpsertResult {
    /// No top-level rule had the same match chain; the new rule was placed at the
    /// front of the rule list.
    Inserted,
    /// A top-level rule with the same match chain already existed; its leaf
    /// decision(s) were overwritten with the new rule's decision.
    Replaced,
}
/// Adds `new_node` to the manifest's rule tree, or overwrites the decision of an
/// existing rule that has the same match chain.
///
/// * If a top-level entry matches `new_node`'s condition chain (decisions are not
///   compared), every `Decision` inside that entry is overwritten with the first
///   decision found in `new_node`, and `UpsertResult::Replaced` is returned.
/// * Otherwise `new_node` is inserted at index 0 of the top-level list and
///   `UpsertResult::Inserted` is returned.
///
/// In both cases the whole tree is passed through `Node::compact` afterwards.
///
/// NOTE(review): only top-level entries are searched. If `Node::compact` merges
/// rules that share a prefix into a single entry with several children, a rule
/// nested inside that entry will not be found here and a second rule is inserted
/// instead. Confirm against `Node::compact`.
pub fn upsert_rule(manifest: &mut PolicyManifest, new_node: Node) -> UpsertResult {
    let tree = &mut manifest.policy.tree;
    let outcome = match find_matching_chain(tree.as_slice(), &new_node) {
        Some(idx) => {
            let decision = leaf_decision(&new_node);
            replace_leaf_decision(&mut tree[idx], decision);
            UpsertResult::Replaced
        }
        None => {
            // Prepend so the newest rule is the first entry in the list.
            tree.insert(0, new_node);
            UpsertResult::Inserted
        }
    };
    *tree = Node::compact(std::mem::take(tree));
    outcome
}
/// Removes the first top-level rule whose match chain equals `target`'s.
///
/// The decision carried by `target` is not compared, so any rule with the same
/// conditions is removed regardless of its decision.
///
/// Returns `true` when a rule was removed; the tree is then re-compacted with
/// `Node::compact`. Returns `false` and leaves the tree untouched otherwise.
pub fn remove_rule(manifest: &mut PolicyManifest, target: &Node) -> bool {
    let tree = &mut manifest.policy.tree;
    let Some(idx) = find_matching_chain(tree.as_slice(), target) else {
        return false;
    };
    tree.remove(idx);
    *tree = Node::compact(std::mem::take(tree));
    true
}
/// Returns the index of the first top-level entry of `tree` whose condition chain
/// matches `target` according to `same_match_chain`, or `None` if none does.
fn find_matching_chain(tree: &[Node], target: &Node) -> Option<usize> {
    tree.iter()
        .enumerate()
        .find_map(|(idx, candidate)| same_match_chain(candidate, target).then_some(idx))
}
/// Returns `true` when `a` and `b` describe the same sequence of conditions,
/// ignoring the decisions at the leaves.
///
/// Two `Condition` nodes match when they observe equal `Observable`s with equal
/// patterns (see `patterns_equal`) and either
/// * each has exactly one child and those children match recursively, or
/// * both have a *non-empty* list of children consisting only of decisions.
///
/// Two bare `Decision` nodes always match; a condition never matches a decision.
/// The `doc`, `source` and `terminal` fields are not compared.
///
/// A condition with no children never matches. It carries no decision, so
/// treating it as "the same rule" would make `upsert_rule` report `Replaced`
/// while `replace_leaf_decision` has nothing to overwrite, silently dropping the
/// new rule.
fn same_match_chain(a: &Node, b: &Node) -> bool {
    match (a, b) {
        (
            Node::Condition {
                observe: obs_a,
                pattern: pat_a,
                children: ch_a,
                ..
            },
            Node::Condition {
                observe: obs_b,
                pattern: pat_b,
                children: ch_b,
                ..
            },
        ) => {
            if obs_a != obs_b || !patterns_equal(pat_a, pat_b) {
                return false;
            }
            match (ch_a.as_slice(), ch_b.as_slice()) {
                // Linear chains: keep walking down one level.
                ([only_a], [only_b]) => same_match_chain(only_a, only_b),
                // Leaf level: both sides must actually carry decisions.
                _ => {
                    !ch_a.is_empty()
                        && !ch_b.is_empty()
                        && children_are_all_decisions(ch_a)
                        && children_are_all_decisions(ch_b)
                }
            }
        }
        (Node::Decision(_), Node::Decision(_)) => true,
        (Node::Condition { .. }, Node::Decision(_)) | (Node::Decision(_), Node::Condition { .. }) => {
            false
        }
    }
}
/// Returns `true` when two patterns are equivalent.
///
/// Patterns of different variants are never equal. `Wildcard`, `Literal` and
/// `Prefix` are compared structurally. Any other variant is compared through its
/// serialized JSON form, and a serialization failure on either side counts as
/// "not equal".
fn patterns_equal(a: &Pattern, b: &Pattern) -> bool {
    // Decide variant mismatches directly. Relying on the serialized forms to
    // differ is fragile (e.g. two variants that serialize identically) and
    // costs two allocations per comparison.
    if std::mem::discriminant(a) != std::mem::discriminant(b) {
        return false;
    }
    match (a, b) {
        (Pattern::Wildcard, Pattern::Wildcard) => true,
        (Pattern::Literal(va), Pattern::Literal(vb)) => va == vb,
        (Pattern::Prefix(va), Pattern::Prefix(vb)) => va == vb,
        _ => match (serde_json::to_value(a), serde_json::to_value(b)) {
            (Ok(va), Ok(vb)) => va == vb,
            _ => false,
        },
    }
}
/// Returns `true` when no element of `children` is anything other than a
/// `Node::Decision`. Vacuously `true` for an empty slice.
fn children_are_all_decisions(children: &[Node]) -> bool {
    !children
        .iter()
        .any(|child| !matches!(child, Node::Decision(_)))
}
/// Returns the first `Decision` reached by a depth-first, left-to-right walk of
/// the subtree rooted at `node`, or `None` if the subtree contains no decision.
fn leaf_decision(node: &Node) -> Option<&Decision> {
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        match current {
            Node::Decision(decision) => return Some(decision),
            // Push children in reverse so the leftmost child is visited first,
            // matching a recursive preorder traversal.
            Node::Condition { children, .. } => pending.extend(children.iter().rev()),
        }
    }
    None
}
/// Overwrites every `Decision` in the subtree rooted at `node` with a clone of
/// `new_decision`.
///
/// Does nothing when `new_decision` is `None`.
fn replace_leaf_decision(node: &mut Node, new_decision: Option<&Decision>) {
    if let Some(replacement) = new_decision {
        // Every leaf receives the same value, so visiting order is irrelevant.
        let mut pending = vec![node];
        while let Some(current) = pending.pop() {
            match current {
                Node::Decision(slot) => *slot = replacement.clone(),
                Node::Condition { children, .. } => pending.extend(children.iter_mut()),
            }
        }
    }
}
pub fn build_exec_rule(bin: &str, args: &[&str], decision: Decision) -> Node {
let mut current = Node::Decision(decision);
for (i, arg) in args.iter().enumerate().rev() {
current = Node::Condition {
observe: Observable::PositionalArg((i + 1) as i32),
pattern: Pattern::Literal(crate::policy::match_tree::Value::Literal(
(*arg).to_string(),
)),
children: vec![current],
doc: None,
source: None,
terminal: false,
};
}
current = Node::Condition {
observe: Observable::PositionalArg(0),
pattern: Pattern::Literal(crate::policy::match_tree::Value::Literal(bin.into())),
children: vec![current],
doc: None,
source: None,
terminal: false,
};
Node::Condition {
observe: Observable::ToolName,
pattern: Pattern::Literal(crate::policy::match_tree::Value::Literal("Bash".into())),
children: vec![current],
doc: None,
source: None,
terminal: false,
}
}
/// Builds a single-level rule: `ToolName == tool_name` (literal match) → `decision`.
///
/// The condition has no doc or source and is not terminal.
pub fn build_tool_rule(tool_name: &str, decision: Decision) -> Node {
    let pattern =
        Pattern::Literal(crate::policy::match_tree::Value::Literal(tool_name.to_string()));
    let leaf = Node::Decision(decision);
    Node::Condition {
        observe: Observable::ToolName,
        pattern,
        children: vec![leaf],
        doc: None,
        source: None,
        terminal: false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::match_tree::*;
    use std::collections::HashMap;

    /// A manifest with no includes, no sandboxes, an empty rule tree and a
    /// default-deny effect.
    fn empty_manifest() -> PolicyManifest {
        PolicyManifest {
            includes: vec![],
            policy: CompiledPolicy {
                sandboxes: HashMap::new(),
                tree: vec![],
                default_effect: crate::policy::Effect::Deny,
                default_sandbox: None,
            },
        }
    }

    #[test]
    fn upsert_inserts_new_rule() {
        let mut manifest = empty_manifest();
        let node = build_exec_rule("grep", &[], Decision::Allow(None));
        let result = upsert_rule(&mut manifest, node);
        assert_eq!(result, UpsertResult::Inserted);
        assert_eq!(manifest.policy.tree.len(), 1);
    }

    #[test]
    fn upsert_replaces_same_chain() {
        let mut manifest = empty_manifest();
        let allow = build_exec_rule("grep", &[], Decision::Allow(None));
        assert_eq!(upsert_rule(&mut manifest, allow), UpsertResult::Inserted);
        assert_eq!(manifest.policy.tree.len(), 1);
        let deny = build_exec_rule("grep", &[], Decision::Deny);
        let result = upsert_rule(&mut manifest, deny);
        assert_eq!(result, UpsertResult::Replaced);
        assert_eq!(manifest.policy.tree.len(), 1);
        let leaf = leaf_decision(&manifest.policy.tree[0]);
        assert!(matches!(leaf, Some(Decision::Deny)));
    }

    #[test]
    fn upsert_different_bins_are_separate() {
        let mut manifest = empty_manifest();
        let first = upsert_rule(
            &mut manifest,
            build_exec_rule("grep", &[], Decision::Allow(None)),
        );
        let second = upsert_rule(&mut manifest, build_exec_rule("rm", &[], Decision::Deny));
        assert_eq!(first, UpsertResult::Inserted);
        assert_eq!(second, UpsertResult::Inserted);
        let total_rules = count_leaf_decisions(&manifest.policy.tree);
        assert_eq!(total_rules, 2);
    }

    #[test]
    fn remove_existing_rule() {
        let mut manifest = empty_manifest();
        upsert_rule(
            &mut manifest,
            build_exec_rule("grep", &[], Decision::Allow(None)),
        );
        let target = build_exec_rule("grep", &[], Decision::Allow(None));
        assert!(remove_rule(&mut manifest, &target));
        assert!(manifest.policy.tree.is_empty());
    }

    #[test]
    fn remove_matches_regardless_of_decision() {
        // Only the condition chain identifies a rule; the target's decision is ignored.
        let mut manifest = empty_manifest();
        upsert_rule(
            &mut manifest,
            build_exec_rule("grep", &[], Decision::Allow(None)),
        );
        let target = build_exec_rule("grep", &[], Decision::Deny);
        assert!(remove_rule(&mut manifest, &target));
        assert!(manifest.policy.tree.is_empty());
    }

    #[test]
    fn remove_nonexistent_returns_false() {
        let mut manifest = empty_manifest();
        upsert_rule(
            &mut manifest,
            build_exec_rule("grep", &[], Decision::Allow(None)),
        );
        let target = build_exec_rule("rm", &[], Decision::Allow(None));
        assert!(!remove_rule(&mut manifest, &target));
        // A failed removal must leave the existing rule in place.
        assert_eq!(count_leaf_decisions(&manifest.policy.tree), 1);
    }

    #[test]
    fn tool_rule_upsert() {
        let mut manifest = empty_manifest();
        let node = build_tool_rule("WebSearch", Decision::Deny);
        let result = upsert_rule(&mut manifest, node);
        assert_eq!(result, UpsertResult::Inserted);
        let node2 = build_tool_rule("WebSearch", Decision::Allow(None));
        let result2 = upsert_rule(&mut manifest, node2);
        assert_eq!(result2, UpsertResult::Replaced);
        let leaf = leaf_decision(&manifest.policy.tree[0]);
        assert!(matches!(leaf, Some(Decision::Allow(None))));
    }

    #[test]
    fn exec_rule_with_args() {
        let mut manifest = empty_manifest();
        let node = build_exec_rule("gh", &["pr", "create"], Decision::Allow(None));
        assert_eq!(upsert_rule(&mut manifest, node), UpsertResult::Inserted);
        let total = count_leaf_decisions(&manifest.policy.tree);
        assert_eq!(total, 1);
        let deny = build_exec_rule("gh", &["pr", "create"], Decision::Deny);
        let result = upsert_rule(&mut manifest, deny);
        assert_eq!(result, UpsertResult::Replaced);
        let leaf = leaf_decision(&manifest.policy.tree[0]);
        assert!(matches!(leaf, Some(Decision::Deny)));
    }

    #[test]
    fn exec_rule_different_args_are_separate() {
        let mut manifest = empty_manifest();
        upsert_rule(
            &mut manifest,
            build_exec_rule("gh", &["pr", "create"], Decision::Allow(None)),
        );
        let second = upsert_rule(
            &mut manifest,
            build_exec_rule("gh", &["pr", "merge"], Decision::Deny),
        );
        assert_eq!(second, UpsertResult::Inserted);
        let total = count_leaf_decisions(&manifest.policy.tree);
        assert_eq!(total, 2);
    }

    #[test]
    fn build_exec_rule_nests_conditions_in_order() {
        let node = build_exec_rule("gh", &["pr"], Decision::Deny);
        let Node::Condition {
            observe, children, ..
        } = &node
        else {
            panic!("outermost node should be a condition");
        };
        assert!(matches!(observe, Observable::ToolName));
        let Node::Condition {
            observe,
            pattern,
            children,
            ..
        } = &children[0]
        else {
            panic!("second level should be a condition");
        };
        assert!(matches!(observe, Observable::PositionalArg(0)));
        assert!(matches!(
            pattern,
            Pattern::Literal(crate::policy::match_tree::Value::Literal(bin)) if bin == "gh"
        ));
        let Node::Condition {
            observe, children, ..
        } = &children[0]
        else {
            panic!("third level should be a condition");
        };
        assert!(matches!(observe, Observable::PositionalArg(1)));
        assert!(matches!(children.as_slice(), [Node::Decision(Decision::Deny)]));
    }

    /// Counts every `Decision` leaf reachable from `nodes`.
    fn count_leaf_decisions(nodes: &[Node]) -> usize {
        nodes
            .iter()
            .map(|n| match n {
                Node::Decision(_) => 1,
                Node::Condition { children, .. } => count_leaf_decisions(children),
            })
            .sum()
    }
}