use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use core::fmt;
use serde::{Deserialize, Serialize};
use crate::error::{PolicyError, Result};
/// Maximum `depth` at which [`ContextExpr::evaluate`] will evaluate a node;
/// evaluating at a greater depth yields `PolicyError::ExpressionTooDeep`.
pub const MAX_EXPR_DEPTH: usize = 32;
/// Maximum length, in bytes, of the text accepted by [`ContextExpr::parse`];
/// longer input yields `PolicyError::ExpressionTooLong`.
pub const MAX_EXPR_LENGTH: usize = 1024;
/// A boolean condition over a string-keyed, string-valued attribute context.
///
/// Produced by [`ContextExpr::parse`] (or by deserialization) and evaluated
/// with [`ContextExpr::evaluate`].
///
/// NOTE(review): the serde derives identify variants by name or by index
/// depending on the data format; do not reorder variants without checking
/// which formats persist these values.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ContextExpr {
    /// Logical conjunction; evaluation skips the right operand when the left
    /// one is false.
    And(Box<ContextExpr>, Box<ContextExpr>),
    /// Logical disjunction; evaluation skips the right operand when the left
    /// one is true.
    Or(Box<ContextExpr>, Box<ContextExpr>),
    /// Logical negation of the inner expression.
    Not(Box<ContextExpr>),
    /// True when the context contains this key, whatever its value.
    HasAttribute(String),
    /// Compares the context value stored under `key` with the literal `value`.
    /// Evaluates to false when `key` is absent, for every operator.
    Compare {
        /// Context attribute to look up.
        key: String,
        /// Operator applied as `context[key] op value`.
        op: CompareOp,
        /// Literal right-hand side.
        value: String,
    },
    /// Constant true.
    True,
    /// Constant false.
    False,
}
/// Comparison operator used by [`ContextExpr::Compare`].
///
/// NOTE(review): serde-derived; as with `ContextExpr`, reordering variants may
/// affect formats that encode variants by index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompareOp {
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `<`
    LessThan,
    /// `<=`
    LessThanOrEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanOrEqual,
}
impl fmt::Display for CompareOp {
    /// Writes the operator as the same symbol the expression tokenizer accepts
    /// (`==`, `!=`, `<`, `<=`, `>`, `>=`). Width and alignment flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            CompareOp::Equal => "==",
            CompareOp::NotEqual => "!=",
            CompareOp::LessThan => "<",
            CompareOp::LessThanOrEqual => "<=",
            CompareOp::GreaterThan => ">",
            CompareOp::GreaterThanOrEqual => ">=",
        };
        f.write_str(symbol)
    }
}
impl ContextExpr {
    /// Evaluates the expression against `context`.
    ///
    /// `depth` is the nesting level of `self` within the expression being
    /// evaluated; a root expression is normally evaluated with `0`, and every
    /// child is evaluated at `depth + 1`.
    ///
    /// `And` and `Or` short-circuit: the right operand is not evaluated (and so
    /// cannot trip the depth check) when the left operand already decides the
    /// result. A `Compare` whose key is missing from `context` is `false` for
    /// every operator, so wrapping it in `Not` yields `true`.
    ///
    /// # Errors
    ///
    /// Returns `PolicyError::ExpressionTooDeep` when a node is evaluated at a
    /// depth greater than [`MAX_EXPR_DEPTH`].
    pub fn evaluate(&self, context: &BTreeMap<String, String>, depth: usize) -> Result<bool> {
        if depth > MAX_EXPR_DEPTH {
            return Err(PolicyError::ExpressionTooDeep {
                max: MAX_EXPR_DEPTH,
            });
        }
        // Cannot overflow: depth <= MAX_EXPR_DEPTH at this point.
        let child_depth = depth + 1;
        match self {
            ContextExpr::True => Ok(true),
            ContextExpr::False => Ok(false),
            // `&&` / `||` only evaluate (and `?`-propagate) the right-hand
            // call when the left-hand result does not already decide it.
            ContextExpr::And(left, right) => Ok(left.evaluate(context, child_depth)?
                && right.evaluate(context, child_depth)?),
            ContextExpr::Or(left, right) => Ok(left.evaluate(context, child_depth)?
                || right.evaluate(context, child_depth)?),
            ContextExpr::Not(inner) => inner.evaluate(context, child_depth).map(|b| !b),
            ContextExpr::HasAttribute(key) => Ok(context.contains_key(key)),
            ContextExpr::Compare { key, op, value } => Ok(context
                .get(key)
                .map_or(false, |actual| compare_values(actual, value, *op))),
        }
    }

    /// Parses an expression from its textual form.
    ///
    /// The whole input must form a single expression; any tokens left over
    /// after it are rejected.
    ///
    /// # Errors
    ///
    /// - `PolicyError::ExpressionTooLong` when `input` exceeds
    ///   [`MAX_EXPR_LENGTH`] bytes.
    /// - `PolicyError::InvalidExpression` when unconsumed tokens follow the
    ///   expression.
    /// - Any error produced by the tokenizer or parser for malformed input.
    pub fn parse(input: &str) -> Result<Self> {
        if input.len() > MAX_EXPR_LENGTH {
            return Err(PolicyError::ExpressionTooLong {
                max: MAX_EXPR_LENGTH,
                length: input.len(),
            });
        }
        let tokens = tokenize(input)?;
        let mut parser = Parser::new(&tokens);
        let expr = parser.parse_expr()?;
        // The parser stops at the first token it cannot use. Without this check,
        // input such as `TRUE FALSE` or `a == "x" b == "y"` would be accepted and
        // everything after the first complete expression silently ignored.
        if let Some(extra) = parser.current() {
            return Err(PolicyError::InvalidExpression(format!(
                "Unexpected token after expression: {:?}",
                extra
            )));
        }
        Ok(expr)
    }
}
/// Applies `op` to a context value (`left`) and a policy literal (`right`).
///
/// `==` and `!=` are exact string comparisons. The ordering operators compare
/// numerically when both sides parse as `i64`, and otherwise fall back to
/// byte-wise lexicographic order.
///
/// The numeric path matters because both operands are strings. Under
/// lexicographic order `"10" < "9"`, so a rule such as `level >= "5"` would
/// reject a level of `"10"`. Consequently `"05" <= "5"` and `"05" >= "5"` both
/// hold, while `"05" == "5"` does not.
fn compare_values(left: &str, right: &str, op: CompareOp) -> bool {
    // Evaluated lazily so the equality operators never pay for integer parsing.
    let ordering = || match (left.parse::<i64>(), right.parse::<i64>()) {
        (Ok(l), Ok(r)) => l.cmp(&r),
        _ => left.cmp(right),
    };
    match op {
        CompareOp::Equal => left == right,
        CompareOp::NotEqual => left != right,
        CompareOp::LessThan => ordering().is_lt(),
        CompareOp::LessThanOrEqual => ordering().is_le(),
        CompareOp::GreaterThan => ordering().is_gt(),
        CompareOp::GreaterThanOrEqual => ordering().is_ge(),
    }
}
/// A lexical unit of the expression language, produced by [`tokenize`] and
/// consumed by [`Parser`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
    // Words that carry their source text.
    /// A bare word that is not one of the keywords.
    Identifier(String),
    /// The contents of a `"..."` literal, with backslash escapes resolved.
    StringLiteral(String),

    // Keywords (matched case-sensitively).
    True,
    False,
    Has,
    Not,
    And,
    Or,

    // Grouping.
    LeftParen,
    RightParen,

    // Comparison operators.
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
}
/// Splits an expression string into [`Token`]s.
///
/// Characters are processed left to right:
/// - space, tab, `\n` and `\r` separate tokens and are otherwise skipped;
/// - `(` and `)` become parenthesis tokens;
/// - `==` and `!=` must be written with both characters, while `<`, `<=`, `>`
///   and `>=` are recognised greedily;
/// - `"..."` is a string literal in which a backslash makes the following
///   character literal (so `\"` embeds a quote and `\\` a backslash);
/// - a word that starts with a letter or `_` and continues with letters,
///   digits or `_` is a keyword when it is exactly `AND`, `OR`, `NOT`, `HAS`,
///   `TRUE` or `FALSE` (case-sensitive), and an identifier otherwise.
///
/// # Errors
///
/// `PolicyError::InvalidExpression` for a lone `=` or `!`, an unterminated
/// string literal (including one ending in a bare backslash), or any other
/// character. Digits cannot start a bare word, so numeric values must be quoted.
fn tokenize(input: &str) -> Result<Vec<Token>> {
    let mut out = Vec::new();
    let mut rest = input.chars().peekable();

    while let Some(c) = rest.next() {
        let token = match c {
            ' ' | '\t' | '\n' | '\r' => continue,
            '(' => Token::LeftParen,
            ')' => Token::RightParen,
            '=' => {
                if rest.next_if_eq(&'=').is_none() {
                    return Err(PolicyError::InvalidExpression(
                        "Single '=' not allowed, use '=='".into(),
                    ));
                }
                Token::Equal
            }
            '!' => {
                if rest.next_if_eq(&'=').is_none() {
                    return Err(PolicyError::InvalidExpression(
                        "Single '!' not allowed, use '!=' or 'NOT'".into(),
                    ));
                }
                Token::NotEqual
            }
            '<' => match rest.next_if_eq(&'=') {
                Some(_) => Token::LessThanOrEqual,
                None => Token::LessThan,
            },
            '>' => match rest.next_if_eq(&'=') {
                Some(_) => Token::GreaterThanOrEqual,
                None => Token::GreaterThan,
            },
            '"' => {
                // Capture-free closure, hence `Copy`: it can be handed to
                // `ok_or_else` at both places where input may run out.
                let unterminated =
                    || PolicyError::InvalidExpression("Unterminated string literal".into());
                let mut literal = String::new();
                loop {
                    match rest.next().ok_or_else(unterminated)? {
                        '"' => break,
                        '\\' => literal.push(rest.next().ok_or_else(unterminated)?),
                        other => literal.push(other),
                    }
                }
                Token::StringLiteral(literal)
            }
            _ if c == '_' || c.is_alphabetic() => {
                let mut word = String::from(c);
                while let Some(next) = rest.next_if(|n| *n == '_' || n.is_alphanumeric()) {
                    word.push(next);
                }
                match word.as_str() {
                    "AND" => Token::And,
                    "OR" => Token::Or,
                    "NOT" => Token::Not,
                    "HAS" => Token::Has,
                    "TRUE" => Token::True,
                    "FALSE" => Token::False,
                    _ => Token::Identifier(word),
                }
            }
            _ => {
                return Err(PolicyError::InvalidExpression(format!(
                    "Unexpected character: '{}'",
                    c
                )))
            }
        };
        out.push(token);
    }

    Ok(out)
}
struct Parser<'a> {
tokens: &'a [Token],
pos: usize,
}
impl<'a> Parser<'a> {
fn new(tokens: &'a [Token]) -> Self {
Self { tokens, pos: 0 }
}
fn current(&self) -> Option<&Token> {
self.tokens.get(self.pos)
}
fn advance(&mut self) -> Option<&Token> {
let token = self.tokens.get(self.pos);
self.pos += 1;
token
}
fn expect(&mut self, expected: Token) -> Result<()> {
match self.advance() {
Some(token) if token == &expected => Ok(()),
Some(token) => Err(PolicyError::InvalidExpression(format!(
"Expected {:?}, got {:?}",
expected, token
))),
None => Err(PolicyError::InvalidExpression(format!(
"Expected {:?}, got EOF",
expected
))),
}
}
fn parse_expr(&mut self) -> Result<ContextExpr> {
self.parse_or()
}
fn parse_or(&mut self) -> Result<ContextExpr> {
let mut left = self.parse_and()?;
while matches!(self.current(), Some(Token::Or)) {
self.advance();
let right = self.parse_and()?;
left = ContextExpr::Or(Box::new(left), Box::new(right));
}
Ok(left)
}
fn parse_and(&mut self) -> Result<ContextExpr> {
let mut left = self.parse_not()?;
while matches!(self.current(), Some(Token::And)) {
self.advance();
let right = self.parse_not()?;
left = ContextExpr::And(Box::new(left), Box::new(right));
}
Ok(left)
}
fn parse_not(&mut self) -> Result<ContextExpr> {
if matches!(self.current(), Some(Token::Not)) {
self.advance();
let expr = self.parse_primary()?;
Ok(ContextExpr::Not(Box::new(expr)))
} else {
self.parse_primary()
}
}
fn parse_primary(&mut self) -> Result<ContextExpr> {
match self.current() {
Some(Token::True) => {
self.advance();
Ok(ContextExpr::True)
}
Some(Token::False) => {
self.advance();
Ok(ContextExpr::False)
}
Some(Token::Has) => {
self.advance();
match self.advance() {
Some(Token::Identifier(key)) => Ok(ContextExpr::HasAttribute(key.clone())),
_ => Err(PolicyError::InvalidExpression(
"Expected identifier after HAS".into(),
)),
}
}
Some(Token::LeftParen) => {
self.advance();
let expr = self.parse_expr()?;
self.expect(Token::RightParen)?;
Ok(expr)
}
Some(Token::Identifier(key)) => {
let key = key.clone();
self.advance();
let op = match self.current() {
Some(Token::Equal) => CompareOp::Equal,
Some(Token::NotEqual) => CompareOp::NotEqual,
Some(Token::LessThan) => CompareOp::LessThan,
Some(Token::LessThanOrEqual) => CompareOp::LessThanOrEqual,
Some(Token::GreaterThan) => CompareOp::GreaterThan,
Some(Token::GreaterThanOrEqual) => CompareOp::GreaterThanOrEqual,
_ => {
return Err(PolicyError::InvalidExpression(
"Expected comparison operator".into(),
))
}
};
self.advance();
let value = match self.advance() {
Some(Token::StringLiteral(v)) => v.clone(),
Some(Token::Identifier(v)) => v.clone(), _ => {
return Err(PolicyError::InvalidExpression(
"Expected value after comparison operator".into(),
))
}
};
Ok(ContextExpr::Compare { key, op, value })
}
_ => Err(PolicyError::InvalidExpression("Expected expression".into())),
}
}
}
/// Kani verification harnesses for [`ContextExpr::evaluate`]; compiled only
/// when the `kani` cfg is set.
#[cfg(kani)]
mod kani_proofs {
    use super::*;

    /// Evaluating the constant `True` against a one-entry context must not panic.
    ///
    /// NOTE(review): only the `True` leaf is exercised. `And`/`Or`/`Not`,
    /// `HasAttribute` and `Compare` are never reached by this harness.
    #[kani::proof]
    #[kani::unwind(10)]
    fn proof_evaluate_never_panics() {
        let expr = ContextExpr::True;
        let mut ctx = BTreeMap::new();
        ctx.insert(String::from("key"), String::from("value"));
        let _ = expr.evaluate(&ctx, 0);
    }

    /// A depth one past `MAX_EXPR_DEPTH` is rejected even for a leaf, because
    /// `evaluate` checks the depth before matching on the variant.
    #[kani::proof]
    #[kani::unwind(0)]
    fn proof_depth_limit_enforced() {
        let expr = ContextExpr::True;
        let ctx = BTreeMap::new();
        let result = expr.evaluate(&ctx, MAX_EXPR_DEPTH + 1);
        kani::assert(result.is_err(), "Depth > MAX must always fail");
    }

    /// Checks that `MAX_EXPR_DEPTH + 1` stays within the expected bound.
    ///
    /// NOTE(review): this works on constants only and never calls `evaluate`.
    /// Overflow of `evaluate`'s own `depth + 1` is instead ruled out by its
    /// early return for `depth > MAX_EXPR_DEPTH`; consider a harness with a
    /// symbolic `depth` to cover that path.
    #[kani::proof]
    #[kani::unwind(0)]
    fn proof_depth_no_overflow() {
        let depth = MAX_EXPR_DEPTH;
        let new_depth = depth + 1;
        kani::assert(new_depth <= MAX_EXPR_DEPTH + 1, "Depth increment safe");
    }
}