use cel::Program;
use std::collections::HashMap;
/// Errors reported by this module when working with CEL expressions.
#[derive(Debug)]
pub enum CelError {
    /// The expression text could not be compiled; holds the underlying message.
    CompileError(String),
    /// Evaluation failed or did not yield a usable result; holds the message.
    EvalError(String),
}

impl std::fmt::Display for CelError {
    /// Renders the error as a human-readable line prefixed with its category.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CompileError(detail) => write!(f, "CEL compile error: {}", detail),
            Self::EvalError(detail) => write!(f, "CEL eval error: {}", detail),
        }
    }
}

// No `source()` override: the variants only carry message text.
impl std::error::Error for CelError {}
/// Successful outcome of `evaluate`.
#[derive(Debug, Clone, PartialEq)]
pub enum CelResult {
/// The expression produced this boolean value.
Met(bool),
/// Evaluation failed on an unresolved name; holds the name(s) extracted
/// from the error message on a best-effort basis.
MissingParameters(Vec<String>),
}
/// A value that can be supplied in the evaluation context passed to `evaluate`.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
/// A boolean value.
Bool(bool),
/// A signed 64-bit integer value.
Int(i64),
/// An owned string value.
String(String),
/// A list of values.
List(Vec<Value>),
}
/// Compiles a CEL expression into a reusable [`Program`].
///
/// # Errors
///
/// Returns [`CelError::CompileError`] carrying the compiler's message when
/// `expr` cannot be compiled.
pub fn compile(expr: &str) -> Result<Program, CelError> {
    match Program::compile(expr) {
        Ok(program) => Ok(program),
        Err(parse_err) => Err(CelError::CompileError(parse_err.to_string())),
    }
}
/// Converts a context [`Value`] into the equivalent `cel::Value`.
///
/// Lists are converted recursively, so a list nested inside another list
/// keeps its contents instead of being replaced by a null.
fn to_cel_value(value: &Value) -> cel::Value {
    match value {
        Value::Bool(b) => cel::Value::Bool(*b),
        Value::Int(i) => cel::Value::Int(*i),
        Value::String(s) => cel::Value::String(s.clone().into()),
        Value::List(items) => {
            let cel_items: Vec<cel::Value> = items.iter().map(to_cel_value).collect();
            cel::Value::List(cel_items.into())
        }
    }
}

/// Evaluates a compiled CEL `program` against the variables in `context`.
///
/// Every entry of `context` is bound as a CEL variable under its key before
/// the program is executed.
///
/// # Returns
///
/// * `Ok(CelResult::Met(b))` when the expression evaluates to the boolean `b`.
/// * `Ok(CelResult::MissingParameters(names))` when execution fails with a
///   message containing "undeclared" or "not found" (case-insensitive);
///   `names` holds the variable name extracted from that message.
///
/// # Errors
///
/// Returns [`CelError::EvalError`] when a variable cannot be bound, when the
/// expression evaluates to a non-boolean value, or when execution fails for
/// any other reason.
pub fn evaluate(
    program: &Program,
    context: &HashMap<String, Value>,
) -> Result<CelResult, CelError> {
    let mut cel_context = cel::Context::default();
    for (key, value) in context {
        // Surface binding failures directly: silently dropping the variable
        // would later be misreported as a missing parameter.
        cel_context
            .add_variable(key, to_cel_value(value))
            .map_err(|e| {
                CelError::EvalError(format!("failed to bind variable '{key}': {e}"))
            })?;
    }

    match program.execute(&cel_context) {
        Ok(cel::Value::Bool(b)) => Ok(CelResult::Met(b)),
        Ok(other) => Err(CelError::EvalError(format!(
            "CEL expression must evaluate to boolean, got: {:?}",
            other
        ))),
        Err(e) => {
            let err_msg = e.to_string();
            let err_lower = err_msg.to_lowercase();
            // Unresolved names are recognised by the error text and reported
            // as a normal result so callers can ask for the missing input.
            if err_lower.contains("undeclared") || err_lower.contains("not found") {
                let missing = extract_missing_variable(&err_msg);
                Ok(CelResult::MissingParameters(vec![missing]))
            } else {
                Err(CelError::EvalError(err_msg))
            }
        }
    }
}
/// Best-effort extraction of the unresolved variable's name from a CEL
/// execution error message.
///
/// Two heuristics are tried in order:
///
/// 1. the text between the first pair of single quotes in the message;
/// 2. the last whitespace-separated word following the keyword `undeclared`,
///    matched ASCII case-insensitively so that it agrees with the
///    case-insensitive classification done by `evaluate`.
///
/// Returns `"unknown"` when neither heuristic yields a name.
fn extract_missing_variable(err_msg: &str) -> String {
    // Heuristic 1: a name wrapped in single quotes, e.g. `... 'x'`.
    let quoted = err_msg
        .split_once('\'')
        .and_then(|(_, after_open)| after_open.split_once('\''))
        .map(|(name, _)| name);
    if let Some(name) = quoted {
        return name.to_string();
    }

    // Heuristic 2: last word after "undeclared". ASCII lowercasing keeps byte
    // offsets identical, so the index found here is valid for `err_msg` too.
    // `split_whitespace` ignores trailing whitespace, so the result is never
    // an empty name; `skip(1)` steps over the keyword itself.
    let lowered = err_msg.to_ascii_lowercase();
    lowered
        .find("undeclared")
        .and_then(|idx| err_msg[idx..].split_whitespace().skip(1).last())
        .unwrap_or("unknown")
        .to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn context_of(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(name, value)| ((*name).to_string(), value.clone()))
            .collect()
    }

    /// Compiles `expr` and evaluates it against `pairs`, panicking if either
    /// step returns an error.
    fn eval(expr: &str, pairs: &[(&str, Value)]) -> CelResult {
        let program = compile(expr).unwrap();
        evaluate(&program, &context_of(pairs)).unwrap()
    }

    #[test]
    fn test_compile_valid_expression() {
        assert!(compile("x == 42").is_ok());
    }

    #[test]
    fn test_compile_invalid_expression() {
        assert!(compile("x ==").is_err());
    }

    #[test]
    fn test_eval_true() {
        let result = eval("x == 42", &[("x", Value::Int(42))]);
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_false() {
        let result = eval("x == 42", &[("x", Value::Int(99))]);
        assert_eq!(result, CelResult::Met(false));
    }

    #[test]
    fn test_eval_missing_params() {
        // An empty context leaves `x` unresolved.
        match eval("x == 42", &[]) {
            CelResult::MissingParameters(params) => {
                assert!(!params.is_empty());
            }
            _ => panic!("Expected MissingParameters"),
        }
    }

    #[test]
    fn test_eval_string_comparison() {
        let result = eval(
            "name == \"alice\"",
            &[("name", Value::String("alice".to_string()))],
        );
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_list_contains() {
        let result = eval("x in [1, 2, 3]", &[("x", Value::Int(2))]);
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_boolean_logic() {
        let result = eval(
            "x > 0 && y < 10",
            &[("x", Value::Int(5)), ("y", Value::Int(3))],
        );
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_boolean_logic_edge() {
        let result = eval(
            "x > 0 && y < 10",
            &[("x", Value::Int(5)), ("y", Value::Int(8))],
        );
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_boolean_logic_boundary() {
        // `y < 10` is strict, so the boundary value itself must not satisfy it.
        let result = eval(
            "x > 0 && y < 10",
            &[("x", Value::Int(5)), ("y", Value::Int(10))],
        );
        assert_eq!(result, CelResult::Met(false));
    }

    #[test]
    fn test_eval_nested_logic() {
        let result = eval(
            "(x > 0 && y < 10) || z == true",
            &[
                ("x", Value::Int(-1)),
                ("y", Value::Int(5)),
                ("z", Value::Bool(true)),
            ],
        );
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_non_boolean_result_is_error() {
        // An integer-valued expression is not a condition and must be rejected.
        let program = compile("x + 1").unwrap();
        let context = context_of(&[("x", Value::Int(1))]);
        let result = evaluate(&program, &context);
        assert!(matches!(result, Err(CelError::EvalError(_))));
    }

    #[test]
    fn test_eval_empty_expression() {
        assert!(compile("").is_err());
    }
}