use serde::{Deserialize, Serialize};
use std::default::Default;
/// Outcome assigned to a command when it is evaluated against a [`Policy`].
///
/// Serialized in lowercase (`"allow"`, `"prompt"`, `"forbidden"`). The
/// `Default` value is [`Decision::Prompt`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Decision {
    /// The command is permitted.
    Allow,
    /// The command is neither permitted nor forbidden outright; how to
    /// proceed is left to the caller (presumably asking the user — confirm).
    #[default]
    Prompt,
    /// The command is not permitted.
    Forbidden,
}
/// A rule that assigns `decision` to every command whose leading tokens are
/// exactly the tokens of `pattern`.
///
/// An empty `pattern` is a prefix of every command, so it matches all of them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PrefixRule {
    /// Tokens that must appear, in this order, at the start of a command.
    pub pattern: Vec<String>,
    /// Decision reported when this rule matches.
    pub decision: Decision,
}

impl PrefixRule {
    /// Builds a rule from its token prefix and the decision it yields.
    pub fn new(pattern: Vec<String>, decision: Decision) -> Self {
        PrefixRule { pattern, decision }
    }

    /// Returns `true` when `command` begins with every token of `pattern`.
    ///
    /// Trailing tokens beyond the pattern are allowed; a command shorter than
    /// the pattern never matches.
    pub fn matches(&self, command: &[String]) -> bool {
        command.starts_with(&self.pattern)
    }
}
/// Records what produced the decision for a single command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RuleMatch {
    /// A configured [`PrefixRule`] matched the command.
    PrefixRuleMatch {
        rule: PrefixRule,
        decision: Decision,
    },
    /// No prefix rule matched; the decision came from a non-rule source.
    HeuristicsRuleMatch { decision: Decision },
}

impl RuleMatch {
    /// Returns the decision stored in this match, whichever variant it is.
    pub fn decision(&self) -> Decision {
        match self {
            Self::PrefixRuleMatch { decision, .. } | Self::HeuristicsRuleMatch { decision } => {
                *decision
            }
        }
    }

    /// Returns `true` only when an explicit prefix rule produced this match.
    pub fn is_policy_match(&self) -> bool {
        match self {
            Self::PrefixRuleMatch { .. } => true,
            Self::HeuristicsRuleMatch { .. } => false,
        }
    }
}
/// Aggregate result of evaluating several commands with [`Policy::check_multiple`].
///
/// Derives `PartialEq`/`Eq` (both field types support them) so whole
/// evaluations can be compared, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PolicyEvaluation {
    /// Most restrictive decision over all evaluated commands
    /// (`Forbidden` > `Prompt` > `Allow`; `Allow` when there were no commands).
    pub decision: Decision,
    /// One entry per evaluated command, in input order.
    pub matched_rules: Vec<RuleMatch>,
}
/// An ordered list of [`PrefixRule`]s used to classify commands.
#[derive(Debug, Clone, Default)]
pub struct Policy {
    prefix_rules: Vec<PrefixRule>,
}

impl Policy {
    /// Creates a policy with no rules; identical to [`Policy::default`].
    pub fn empty() -> Self {
        Self::default()
    }

    /// Appends a prefix rule to the end of the rule list.
    ///
    /// Rules are consulted in insertion order by [`Policy::check`], so an
    /// earlier rule shadows any later rule that matches the same command.
    /// An empty `pattern` matches every command.
    ///
    /// # Errors
    ///
    /// Currently never returns an error; the `Result` is part of the public
    /// signature.
    pub fn add_prefix_rule(
        &mut self,
        pattern: &[String],
        decision: Decision,
    ) -> anyhow::Result<()> {
        self.prefix_rules
            .push(PrefixRule::new(pattern.to_vec(), decision));
        Ok(())
    }

    /// Classifies one command using the first matching prefix rule.
    ///
    /// Returns [`RuleMatch::PrefixRuleMatch`] for the first rule (in insertion
    /// order) whose pattern is a prefix of `command`. When no rule matches, it
    /// returns [`RuleMatch::HeuristicsRuleMatch`] with [`Decision::Prompt`].
    pub fn check(&self, command: &[String]) -> RuleMatch {
        match self.prefix_rules.iter().find(|rule| rule.matches(command)) {
            Some(rule) => RuleMatch::PrefixRuleMatch {
                rule: rule.clone(),
                decision: rule.decision,
            },
            None => RuleMatch::HeuristicsRuleMatch {
                decision: Decision::Prompt,
            },
        }
    }

    /// Evaluates several commands and merges their decisions.
    ///
    /// Each command is first checked with [`Policy::check`]. For commands that
    /// no prefix rule matches, `heuristics_fallback` decides. The recorded
    /// [`RuleMatch::HeuristicsRuleMatch`] carries that fallback decision, so
    /// every entry in `matched_rules` reports the decision actually applied.
    /// The overall decision is the most restrictive one seen
    /// (`Forbidden` > `Prompt` > `Allow`); an empty input yields `Allow`.
    ///
    /// `commands` accepts any `IntoIterator` (for example `vec.iter()` or `&vec`).
    pub fn check_multiple<'a, I, F>(&self, commands: I, heuristics_fallback: &F) -> PolicyEvaluation
    where
        I: IntoIterator<Item = &'a Vec<String>>,
        F: Fn(&[String]) -> Decision,
    {
        let mut matched_rules = Vec::new();
        let mut overall_decision = Decision::Allow;
        for command in commands {
            let rule_match = match self.check(command) {
                prefix_match @ RuleMatch::PrefixRuleMatch { .. } => prefix_match,
                // `check` always reports `Prompt` here; record the fallback's
                // real decision instead so `matched_rules` agrees with the result.
                RuleMatch::HeuristicsRuleMatch { .. } => RuleMatch::HeuristicsRuleMatch {
                    decision: heuristics_fallback(command),
                },
            };
            overall_decision = Self::most_restrictive(overall_decision, rule_match.decision());
            matched_rules.push(rule_match);
        }
        PolicyEvaluation {
            decision: overall_decision,
            matched_rules,
        }
    }

    /// Returns the registered prefix rules in insertion order.
    pub fn prefix_rules(&self) -> &[PrefixRule] {
        &self.prefix_rules
    }

    /// Returns the more restrictive of two decisions (`Forbidden` > `Prompt` > `Allow`).
    fn most_restrictive(a: Decision, b: Decision) -> Decision {
        match (a, b) {
            (Decision::Forbidden, _) | (_, Decision::Forbidden) => Decision::Forbidden,
            (Decision::Prompt, _) | (_, Decision::Prompt) => Decision::Prompt,
            (Decision::Allow, Decision::Allow) => Decision::Allow,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned token list from string slices.
    fn cmd(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    #[test]
    fn test_prefix_rule_matching() {
        let rule = PrefixRule::new(cmd(&["cargo", "build"]), Decision::Allow);
        assert!(rule.matches(&cmd(&["cargo", "build"])));
        assert!(rule.matches(&cmd(&["cargo", "build", "--release"])));
        assert!(!rule.matches(&cmd(&["cargo", "test"])));
        assert!(!rule.matches(&cmd(&["cargo"])));
    }

    #[test]
    fn test_empty_prefix_rule_matches_everything() {
        let rule = PrefixRule::new(Vec::new(), Decision::Forbidden);
        assert!(rule.matches(&[]));
        assert!(rule.matches(&cmd(&["ls", "-la"])));
    }

    #[test]
    fn test_policy_check() {
        let mut policy = Policy::empty();
        policy
            .add_prefix_rule(&cmd(&["cargo", "build"]), Decision::Allow)
            .unwrap();
        policy
            .add_prefix_rule(&cmd(&["rm"]), Decision::Forbidden)
            .unwrap();

        let allow = policy.check(&cmd(&["cargo", "build"]));
        assert_eq!(allow.decision(), Decision::Allow);
        assert!(allow.is_policy_match());

        let forbidden = policy.check(&cmd(&["rm", "-rf"]));
        assert_eq!(forbidden.decision(), Decision::Forbidden);

        let heuristics = policy.check(&cmd(&["unknown"]));
        assert!(!heuristics.is_policy_match());
        assert_eq!(heuristics.decision(), Decision::Prompt);
    }

    #[test]
    fn test_first_matching_rule_wins() {
        let mut policy = Policy::empty();
        policy
            .add_prefix_rule(&cmd(&["git"]), Decision::Allow)
            .unwrap();
        policy
            .add_prefix_rule(&cmd(&["git", "push"]), Decision::Forbidden)
            .unwrap();
        // Rules are consulted in insertion order, so the broader rule shadows the later one.
        assert_eq!(
            policy.check(&cmd(&["git", "push"])).decision(),
            Decision::Allow
        );
    }

    #[test]
    fn test_policy_evaluation() {
        let mut policy = Policy::empty();
        policy
            .add_prefix_rule(&cmd(&["echo"]), Decision::Allow)
            .unwrap();
        policy
            .add_prefix_rule(&cmd(&["rm"]), Decision::Forbidden)
            .unwrap();
        let commands = [cmd(&["echo", "hello"]), cmd(&["rm", "-rf"])];
        let evaluation = policy.check_multiple(commands.iter(), &|_| Decision::Prompt);
        assert_eq!(evaluation.decision, Decision::Forbidden);
        assert_eq!(evaluation.matched_rules.len(), 2);
    }

    #[test]
    fn test_policy_evaluation_uses_heuristics_fallback() {
        let mut policy = Policy::empty();
        policy
            .add_prefix_rule(&cmd(&["echo"]), Decision::Allow)
            .unwrap();
        let commands = [cmd(&["echo", "hello"]), cmd(&["ls"])];

        let allowed = policy.check_multiple(commands.iter(), &|_| Decision::Allow);
        assert_eq!(allowed.decision, Decision::Allow);
        assert!(allowed.matched_rules[0].is_policy_match());
        assert!(!allowed.matched_rules[1].is_policy_match());

        let prompted = policy.check_multiple(commands.iter(), &|_| Decision::Prompt);
        assert_eq!(prompted.decision, Decision::Prompt);
    }

    #[test]
    fn test_policy_evaluation_of_no_commands_is_allow() {
        let policy = Policy::empty();
        let commands: [Vec<String>; 0] = [];
        let evaluation = policy.check_multiple(commands.iter(), &|_| Decision::Forbidden);
        assert_eq!(evaluation.decision, Decision::Allow);
        assert!(evaluation.matched_rules.is_empty());
    }
}