use crate::error::PolicyError;
/// A boolean policy expression over named claims.
///
/// `And` and `Or` are n-ary. Building them through [`Expr::and`] / [`Expr::or`]
/// splices same-operator children in, so `a and b and c` becomes one `And` with
/// three children rather than a nested chain.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Expr {
    /// A named claim, resolved by the caller-supplied lookup in [`evaluate`].
    Claim(String),
    /// Logical negation of the inner expression.
    Not(Box<Expr>),
    /// Conjunction of all children.
    And(Vec<Expr>),
    /// Disjunction of all children.
    Or(Vec<Expr>),
}

impl Expr {
    /// Builds the conjunction of `parts`.
    ///
    /// Any part that is itself an `And` has its children spliced in (one level
    /// deep). If exactly one child remains it is returned unwrapped; an empty
    /// `parts` yields `Expr::And(vec![])`.
    #[must_use]
    pub fn and(parts: Vec<Expr>) -> Expr {
        Self::flatten(parts, Expr::And, |part| match part {
            Expr::And(children) => Ok(children),
            other => Err(other),
        })
    }

    /// Builds the disjunction of `parts`.
    ///
    /// Mirror image of [`Expr::and`]: any part that is itself an `Or` has its
    /// children spliced in (one level deep), a lone child is returned
    /// unwrapped, and an empty `parts` yields `Expr::Or(vec![])`.
    #[must_use]
    pub fn or(parts: Vec<Expr>) -> Expr {
        Self::flatten(parts, Expr::Or, |part| match part {
            Expr::Or(children) => Ok(children),
            other => Err(other),
        })
    }

    /// Builds the negation of `inner`.
    // `not` is an associated constructor (it takes no `self`) named to match
    // `and`/`or`; it is deliberately not an implementation of `std::ops::Not`.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub fn not(inner: Expr) -> Expr {
        Expr::Not(Box::new(inner))
    }

    /// Shared body of [`Expr::and`] and [`Expr::or`].
    ///
    /// `split` returns `Ok(children)` for a part of the same operator (those
    /// children are spliced in) and `Err(part)` for anything else (kept as a
    /// single child). A lone resulting child is returned as-is; otherwise the
    /// collected children are wrapped with `wrap`.
    fn flatten(
        parts: Vec<Expr>,
        wrap: fn(Vec<Expr>) -> Expr,
        split: fn(Expr) -> Result<Vec<Expr>, Expr>,
    ) -> Expr {
        let mut flat = Vec::with_capacity(parts.len());
        for part in parts {
            match split(part) {
                Ok(children) => flat.extend(children),
                Err(single) => flat.push(single),
            }
        }
        if flat.len() == 1 {
            flat.pop().expect("length checked to be exactly one")
        } else {
            wrap(flat)
        }
    }
}
/// Three-valued truth value produced by [`evaluate`].
///
/// `Unknown` stands for a value that could not be determined (for example, a
/// claim the lookup could not resolve).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tri {
    True,
    False,
    Unknown,
}

/// Negation swaps `True` and `False` and leaves `Unknown` untouched.
impl std::ops::Not for Tri {
    type Output = Tri;

    fn not(self) -> Self::Output {
        match self {
            Self::Unknown => Self::Unknown,
            Self::True => Self::False,
            Self::False => Self::True,
        }
    }
}
/// Evaluates `expr` in three-valued logic, resolving each claim via `lookup`.
///
/// * `Not` swaps `True` and `False` and leaves `Unknown` alone.
/// * `And` is `False` if any child is `False`, `True` if every child is `True`,
///   and `Unknown` otherwise (an empty `And` is `True`).
/// * `Or` is `True` if any child is `True`, `False` if every child is `False`,
///   and `Unknown` otherwise (an empty `Or` is `False`).
///
/// Children are evaluated left to right and evaluation stops at the first
/// child that decides the result, so `lookup` is not called for claims after it.
pub fn evaluate<F>(expr: &Expr, lookup: &F) -> Tri
where
    F: Fn(&str) -> Tri,
{
    match expr {
        Expr::Claim(name) => lookup(name),
        Expr::Not(inner) => !evaluate(inner, lookup),
        Expr::And(children) => evaluate_nary(children, lookup, Tri::False, Tri::True),
        Expr::Or(children) => evaluate_nary(children, lookup, Tri::True, Tri::False),
    }
}

/// Combines the children of an n-ary node.
///
/// `dominant` decides the whole node on its own (`False` for `And`, `True` for
/// `Or`) and is returned as soon as a child produces it. Otherwise the result
/// is `Unknown` if any child was `Unknown`, and `neutral` (`True` for `And`,
/// `False` for `Or`) if none was.
fn evaluate_nary<F>(children: &[Expr], lookup: &F, dominant: Tri, neutral: Tri) -> Tri
where
    F: Fn(&str) -> Tri,
{
    let mut undecided = false;
    for child in children {
        let value = evaluate(child, lookup);
        if value == dominant {
            return dominant;
        }
        undecided |= value == Tri::Unknown;
    }
    if undecided {
        Tri::Unknown
    } else {
        neutral
    }
}
/// A lexical token of the policy expression language.
///
/// The `Debug` form of a token is embedded in parse error messages.
#[derive(PartialEq, Eq, Debug)]
enum Token {
    /// A claim name (any identifier that is not a keyword).
    Ident(String),
    /// The `and` keyword.
    And,
    /// The `or` keyword.
    Or,
    /// The `not` keyword.
    Not,
    /// The `implies` keyword.
    Implies,
    /// An opening parenthesis `(`.
    LParen,
    /// A closing parenthesis `)`.
    RParen,
}
/// Splits `input` into tokens.
///
/// Whitespace separates tokens and is otherwise ignored. Identifiers start with
/// a character accepted by `is_ident_start` and continue while
/// `is_ident_continue` holds. The words `and`, `or`, `not` and `implies`, in
/// any ASCII letter case, become keyword tokens; every other identifier keeps
/// its original spelling as `Token::Ident`.
///
/// # Errors
///
/// Returns `PolicyError::ExprParse` at the first character that cannot begin a
/// token.
fn tokenize(input: &str) -> Result<Vec<Token>, PolicyError> {
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();
    while let Some(&first) = chars.peek() {
        match first {
            ws if ws.is_whitespace() => {
                chars.next();
            }
            '(' => {
                chars.next();
                tokens.push(Token::LParen);
            }
            ')' => {
                chars.next();
                tokens.push(Token::RParen);
            }
            start if is_ident_start(start) => {
                // The start character also satisfies `is_ident_continue`, so
                // this loop consumes it along with the rest of the word.
                let mut word = String::new();
                while let Some(next) = chars.next_if(|&ch| is_ident_continue(ch)) {
                    word.push(next);
                }
                let token = match word.to_ascii_lowercase().as_str() {
                    "and" => Token::And,
                    "or" => Token::Or,
                    "not" => Token::Not,
                    "implies" => Token::Implies,
                    _ => Token::Ident(word),
                };
                tokens.push(token);
            }
            c => {
                return Err(PolicyError::ExprParse(format!(
                    "unexpected character {c:?} in expression"
                )));
            }
        }
    }
    Ok(tokens)
}
/// Reports whether `c` may begin an identifier: an ASCII letter or `_`.
fn is_ident_start(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '_')
}

/// Reports whether `c` may follow the first character of an identifier:
/// anything [`is_ident_start`] accepts, plus ASCII digits and `-`.
fn is_ident_continue(c: char) -> bool {
    is_ident_start(c) || matches!(c, '0'..='9' | '-')
}
/// Recursive-descent parser over a token stream.
///
/// Grammar, from loosest to tightest binding:
///
/// ```text
/// expr    := implies
/// implies := or ( "implies" implies )?     -- right-associative
/// or      := and ( "or" and )*
/// and     := not ( "and" not )*
/// not     := "not" not | atom
/// atom    := IDENT | "(" expr ")"
/// ```
///
/// `a implies b` is desugared to `(not a) or b`.
///
/// NOTE(review): nesting depth (parentheses, repeated `not`, chained `implies`)
/// is bounded only by the call stack; consider a depth limit if expressions can
/// come from untrusted sources.
struct Parser {
    tokens: std::iter::Peekable<std::vec::IntoIter<Token>>,
}

impl Parser {
    fn new(tokens: Vec<Token>) -> Self {
        Self {
            tokens: tokens.into_iter().peekable(),
        }
    }

    /// Returns the next token without consuming it.
    fn peek(&mut self) -> Option<&Token> {
        self.tokens.peek()
    }

    /// Removes and returns the next token.
    fn consume(&mut self) -> Option<Token> {
        self.tokens.next()
    }

    /// Consumes the next token if it equals `expected`, reporting whether it did.
    fn eat(&mut self, expected: &Token) -> bool {
        self.tokens.next_if_eq(expected).is_some()
    }

    fn parse_expr(&mut self) -> Result<Expr, PolicyError> {
        self.parse_implies()
    }

    /// `a implies b implies c` parses as `a implies (b implies c)`.
    fn parse_implies(&mut self) -> Result<Expr, PolicyError> {
        let left = self.parse_or()?;
        if !self.eat(&Token::Implies) {
            return Ok(left);
        }
        let right = self.parse_implies()?;
        // a -> b == (not a) or b; `Expr::or` splices in an `Or` on the right.
        Ok(Expr::or(vec![Expr::not(left), right]))
    }

    fn parse_or(&mut self) -> Result<Expr, PolicyError> {
        let mut parts = vec![self.parse_and()?];
        while self.eat(&Token::Or) {
            parts.push(self.parse_and()?);
        }
        Ok(Expr::or(parts))
    }

    fn parse_and(&mut self) -> Result<Expr, PolicyError> {
        let mut parts = vec![self.parse_not()?];
        while self.eat(&Token::And) {
            parts.push(self.parse_not()?);
        }
        Ok(Expr::and(parts))
    }

    fn parse_not(&mut self) -> Result<Expr, PolicyError> {
        if self.eat(&Token::Not) {
            let inner = self.parse_not()?;
            return Ok(Expr::not(inner));
        }
        self.parse_atom()
    }

    fn parse_atom(&mut self) -> Result<Expr, PolicyError> {
        match self.consume() {
            Some(Token::Ident(s)) => Ok(Expr::Claim(s)),
            Some(Token::LParen) => {
                let inner = self.parse_expr()?;
                match self.consume() {
                    Some(Token::RParen) => Ok(inner),
                    _ => Err(PolicyError::ExprParse("missing closing ')'".into())),
                }
            }
            Some(t) => Err(PolicyError::ExprParse(format!(
                "unexpected token {t:?}; expected claim or '('"
            ))),
            None => Err(PolicyError::ExprParse(
                "unexpected end of expression".into(),
            )),
        }
    }
}
/// Parses `input` into an [`Expr`].
///
/// Keywords are case-insensitive. From loosest to tightest binding they are:
///
/// - `implies`: right-associative, desugared to `(not a) or b`
/// - `or`
/// - `and`
/// - `not`
///
/// Parentheses group sub-expressions.
///
/// # Errors
///
/// Returns `PolicyError::ExprParse` in three cases:
///
/// - an unexpected character;
/// - a malformed expression;
/// - tokens left over after a complete expression. The message names the
///   first leftover token.
pub fn parse(input: &str) -> Result<Expr, PolicyError> {
    let mut parser = Parser::new(tokenize(input)?);
    let expr = parser.parse_expr()?;
    match parser.peek() {
        None => Ok(expr),
        // Report the token itself, not the `Option` wrapping it.
        Some(extra) => Err(PolicyError::ExprParse(format!(
            "trailing tokens after expression: {extra:?}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses and evaluates `expr`. Claims listed in `claims` take the given
    /// value; every other claim evaluates to `Tri::Unknown`.
    fn ev(expr: &str, claims: &[(&str, bool)]) -> Tri {
        let e = parse(expr).unwrap();
        let lookup = |name: &str| match claims.iter().find(|(n, _)| *n == name) {
            Some((_, true)) => Tri::True,
            Some((_, false)) => Tri::False,
            None => Tri::Unknown,
        };
        evaluate(&e, &lookup)
    }

    #[test]
    fn single_claim() {
        assert_eq!(ev("safe-to-deploy", &[("safe-to-deploy", true)]), Tri::True);
        assert_eq!(
            ev("safe-to-deploy", &[("safe-to-deploy", false)]),
            Tri::False
        );
        assert_eq!(ev("safe-to-deploy", &[]), Tri::Unknown);
    }

    #[test]
    fn and_basic() {
        assert_eq!(ev("a and b", &[("a", true), ("b", true)]), Tri::True);
        assert_eq!(ev("a and b", &[("a", true), ("b", false)]), Tri::False);
        assert_eq!(ev("a and b", &[("a", true)]), Tri::Unknown);
    }

    #[test]
    fn and_short_circuits_on_false() {
        assert_eq!(ev("a and b", &[("a", false)]), Tri::False);
    }

    #[test]
    fn or_basic() {
        assert_eq!(ev("a or b", &[("a", false), ("b", true)]), Tri::True);
        assert_eq!(ev("a or b", &[("a", false), ("b", false)]), Tri::False);
        assert_eq!(ev("a or b", &[("a", false)]), Tri::Unknown);
    }

    #[test]
    fn or_short_circuits_on_true() {
        assert_eq!(ev("a or b", &[("a", true)]), Tri::True);
    }

    #[test]
    fn false_and_true_dominate_unknown() {
        // A deciding value wins even when an earlier child is Unknown.
        assert_eq!(ev("a and b", &[("b", false)]), Tri::False);
        assert_eq!(ev("a or b", &[("b", true)]), Tri::True);
    }

    #[test]
    fn not_inverts() {
        assert_eq!(ev("not a", &[("a", true)]), Tri::False);
        assert_eq!(ev("not a", &[("a", false)]), Tri::True);
        assert_eq!(ev("not a", &[]), Tri::Unknown);
    }

    #[test]
    fn double_negation_cancels() {
        assert_eq!(ev("not not a", &[("a", true)]), Tri::True);
        assert_eq!(ev("not not a", &[("a", false)]), Tri::False);
        assert_eq!(ev("not not a", &[]), Tri::Unknown);
    }

    #[test]
    fn precedence_not_binds_tightest() {
        assert_eq!(ev("not a and b", &[("a", false), ("b", true)]), Tri::True);
        assert_eq!(ev("not a and b", &[("a", true), ("b", true)]), Tri::False);
    }

    #[test]
    fn precedence_and_binds_tighter_than_or() {
        assert_eq!(
            ev("a or b and c", &[("a", false), ("b", true), ("c", true)]),
            Tri::True
        );
        assert_eq!(
            ev("a or b and c", &[("a", false), ("b", true), ("c", false)]),
            Tri::False
        );
    }

    #[test]
    fn parens_override_precedence() {
        assert_eq!(
            ev("(a or b) and c", &[("a", true), ("c", false)]),
            Tri::False
        );
    }

    #[test]
    fn nested_not_and_or() {
        assert_eq!(
            ev(
                "not (a and b) or c",
                &[("a", true), ("b", true), ("c", false)]
            ),
            Tri::False
        );
        assert_eq!(
            ev("not (a and b) or c", &[("a", true), ("b", false)]),
            Tri::True
        );
    }

    #[test]
    fn case_insensitive_keywords() {
        assert_eq!(ev("a AND b", &[("a", true), ("b", true)]), Tri::True);
        assert_eq!(ev("NOT a", &[("a", true)]), Tri::False);
        assert_eq!(ev("a IMPLIES b", &[("a", true), ("b", true)]), Tri::True);
    }

    #[test]
    fn keyword_prefixed_identifiers_are_claims() {
        assert_eq!(parse("android").unwrap(), Expr::Claim("android".into()));
        assert_eq!(parse("notify").unwrap(), Expr::Claim("notify".into()));
        assert_eq!(parse("or_else").unwrap(), Expr::Claim("or_else".into()));
    }

    #[test]
    fn whitespace_and_redundant_parens_are_insignificant() {
        assert_eq!(
            parse("  ((a)\tand\n b)  ").unwrap(),
            parse("a and b").unwrap()
        );
    }

    #[test]
    fn implies_truth_table() {
        assert_eq!(ev("a implies b", &[("a", true), ("b", true)]), Tri::True);
        assert_eq!(ev("a implies b", &[("a", true), ("b", false)]), Tri::False);
        assert_eq!(ev("a implies b", &[("a", false), ("b", false)]), Tri::True);
        assert_eq!(ev("a implies b", &[("a", false)]), Tri::True);
        assert_eq!(ev("a implies b", &[("a", true)]), Tri::Unknown);
        assert_eq!(ev("a implies b", &[("b", true)]), Tri::True);
        assert_eq!(ev("a implies b", &[("b", false)]), Tri::Unknown);
    }

    #[test]
    fn implies_lower_than_or_and_and() {
        assert_eq!(
            ev(
                "a implies b and c",
                &[("a", true), ("b", true), ("c", false)]
            ),
            Tri::False
        );
        assert_eq!(
            ev(
                "a or b implies c",
                &[("a", true), ("b", false), ("c", false)]
            ),
            Tri::False
        );
        assert_eq!(
            ev(
                "a or b implies c",
                &[("a", false), ("b", false), ("c", false)]
            ),
            Tri::True
        );
    }

    #[test]
    fn implies_right_associative() {
        assert_eq!(
            ev(
                "a implies b implies c",
                &[("a", false), ("b", true), ("c", false)]
            ),
            Tri::True
        );
    }

    #[test]
    fn implies_desugars_to_or_not() {
        let lhs = parse("a implies b").unwrap();
        let rhs = parse("(not a) or b").unwrap();
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn nary_and_chain_flattens() {
        let e = parse("a and b and c").unwrap();
        let Expr::And(children) = e else {
            panic!("expected top-level And");
        };
        assert_eq!(children.len(), 3);
        assert_eq!(children[0], Expr::Claim("a".into()));
        assert_eq!(children[1], Expr::Claim("b".into()));
        assert_eq!(children[2], Expr::Claim("c".into()));
    }

    #[test]
    fn nary_or_chain_flattens() {
        let e = parse("a or b or c").unwrap();
        let Expr::Or(children) = e else {
            panic!("expected top-level Or");
        };
        assert_eq!(children.len(), 3);
    }

    #[test]
    fn parens_with_same_op_get_spliced() {
        let e = parse("(a or b) or c").unwrap();
        let Expr::Or(children) = e else {
            panic!("expected flat Or");
        };
        assert_eq!(children.len(), 3);
        let e = parse("a and (b and c)").unwrap();
        let Expr::And(children) = e else {
            panic!("expected flat And");
        };
        assert_eq!(children.len(), 3);
    }

    #[test]
    fn implies_with_or_rhs_splices_into_one_or() {
        let e = parse("a implies b or c").unwrap();
        let Expr::Or(children) = e else {
            panic!("expected top-level Or");
        };
        assert_eq!(children.len(), 3);
        assert_eq!(children[0], Expr::not(Expr::Claim("a".into())));
        assert_eq!(children[1], Expr::Claim("b".into()));
        assert_eq!(children[2], Expr::Claim("c".into()));
    }

    #[test]
    fn mixed_op_nesting_is_preserved() {
        let e = parse("a and (b or c) and d").unwrap();
        let Expr::And(children) = e else {
            panic!("expected top-level And");
        };
        assert_eq!(children.len(), 3);
        assert!(matches!(&children[1], Expr::Or(inner) if inner.len() == 2));
    }

    #[test]
    fn single_element_collapses() {
        let x = Expr::Claim("x".into());
        assert_eq!(Expr::and(vec![x.clone()]), x);
        assert_eq!(Expr::or(vec![x.clone()]), x);
    }

    #[test]
    fn parse_errors() {
        assert!(parse("a and").is_err());
        assert!(parse("(a or b").is_err());
        assert!(parse("and a").is_err());
        assert!(parse("a b").is_err());
        assert!(parse("").is_err());
        assert!(parse("a #").is_err());
        assert!(parse("a implies").is_err());
        assert!(parse("implies b").is_err());
        assert!(parse("()").is_err());
        assert!(parse("a)").is_err());
        assert!(parse("not").is_err());
    }
}