Skip to main content

authz_core/
cel.rs

1//! CEL condition compilation and evaluation.
2
3use cel::Program;
4use std::collections::HashMap;
5
/// Failure raised while turning a CEL condition into a program or running it.
#[derive(Debug)]
pub enum CelError {
    /// The expression was rejected by the CEL compiler; holds the compiler's message.
    CompileError(String),
    /// Running the program failed or produced an unusable result; holds a description.
    EvalError(String),
}

impl std::fmt::Display for CelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants share one message shape; only the phase word differs.
        let (phase, detail) = match self {
            Self::CompileError(detail) => ("compile", detail),
            Self::EvalError(detail) => ("eval", detail),
        };
        write!(f, "CEL {} error: {}", phase, detail)
    }
}

impl std::error::Error for CelError {}
22
/// Outcome of evaluating a CEL condition.
// `Eq` is derived because every payload (`bool`, `Vec<String>`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CelResult {
    /// The expression evaluated to this boolean.
    Met(bool),
    /// Evaluation could not complete because these variables were not supplied.
    MissingParameters(Vec<String>),
}
28
/// A value that can be bound to a named variable before evaluating a CEL condition.
// `Eq` is derived because every payload (`bool`, `i64`, `String`, `Vec<Value>`) has
// total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Int(i64),
    /// A string.
    String(String),
    /// A list of values.
    List(Vec<Value>),
}
36
37/// Compile a CEL expression into a Program.
38pub fn compile(expr: &str) -> Result<Program, CelError> {
39    Program::compile(expr).map_err(|e| CelError::CompileError(e.to_string()))
40}
41
42/// Evaluate a compiled CEL program with the given context.
43/// Returns CelResult::Met(bool) if evaluation succeeds,
44/// or CelResult::MissingParameters if required parameters are missing.
45pub fn evaluate(
46    program: &Program,
47    context: &HashMap<String, Value>,
48) -> Result<CelResult, CelError> {
49    // Convert our Value type to cel::Value
50    let mut cel_context = cel::Context::default();
51    for (key, value) in context {
52        let cel_value = match value {
53            Value::Bool(b) => cel::Value::Bool(*b),
54            Value::Int(i) => cel::Value::Int(*i),
55            Value::String(s) => cel::Value::String(s.clone().into()),
56            Value::List(items) => {
57                let cel_items: Vec<cel::Value> = items
58                    .iter()
59                    .map(|v| match v {
60                        Value::Bool(b) => cel::Value::Bool(*b),
61                        Value::Int(i) => cel::Value::Int(*i),
62                        Value::String(s) => cel::Value::String(s.clone().into()),
63                        Value::List(_) => cel::Value::Null, // nested lists not supported
64                    })
65                    .collect();
66                cel::Value::List(cel_items.into())
67            }
68        };
69        let _ = cel_context.add_variable(key, cel_value);
70    }
71
72    // Execute the program
73    match program.execute(&cel_context) {
74        Ok(value) => {
75            // Convert result to boolean
76            match value {
77                cel::Value::Bool(b) => Ok(CelResult::Met(b)),
78                _ => Err(CelError::EvalError(format!(
79                    "CEL expression must evaluate to boolean, got: {:?}",
80                    value
81                ))),
82            }
83        }
84        Err(e) => {
85            let err_msg = e.to_string();
86            // Check if error is due to missing variable (case-insensitive check)
87            let err_lower = err_msg.to_lowercase();
88            if err_lower.contains("undeclared") || err_lower.contains("not found") {
89                // Extract variable name from error message
90                let missing = extract_missing_variable(&err_msg);
91                Ok(CelResult::MissingParameters(vec![missing]))
92            } else {
93                Err(CelError::EvalError(err_msg))
94            }
95        }
96    }
97}
98
/// Best-effort extraction of a variable name from an interpreter error message.
///
/// Strategies, tried in order:
/// 1. The text between the first pair of single quotes, e.g.
///    `"Undeclared reference to 'x'"` -> `"x"`.
/// 2. The last whitespace-separated word after `"undeclared"`, e.g.
///    `"Undeclared reference to x"` -> `"x"`. The keyword is matched
///    case-insensitively, in line with how `evaluate` detects this error.
///
/// Returns `"unknown"` when neither strategy yields a non-empty name.
fn extract_missing_variable(err_msg: &str) -> String {
    // Strategy 1: quoted name. Requires both an opening and a closing quote.
    let quoted = err_msg
        .split_once('\'')
        .and_then(|(_, after_open)| after_open.split_once('\''))
        .map(|(name, _)| name)
        .filter(|name| !name.is_empty());
    if let Some(name) = quoted {
        return name.to_string();
    }

    // Strategy 2: unquoted name following the keyword. ASCII lowercasing leaves
    // every byte offset unchanged, so `idx` is a valid char boundary in `err_msg`.
    let lowered = err_msg.to_ascii_lowercase();
    if let Some(idx) = lowered.find("undeclared") {
        // Skip the token containing the keyword itself; take the final word.
        if let Some(word) = err_msg[idx..].split_whitespace().skip(1).last() {
            return word.to_string();
        }
    }

    "unknown".to_string()
}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_valid_expression() {
        let result = compile("x == 42");
        assert!(result.is_ok());
    }

    #[test]
    fn test_compile_invalid_expression() {
        let result = compile("x ==");
        assert!(result.is_err());
    }

    #[test]
    fn test_eval_true() {
        let program = compile("x == 42").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(42));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_false() {
        let program = compile("x == 42").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(99));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(false));
    }

    #[test]
    fn test_eval_missing_params() {
        let program = compile("x == 42").unwrap();
        let context = HashMap::new(); // Empty context
        let result = evaluate(&program, &context).unwrap();
        match result {
            CelResult::MissingParameters(params) => {
                assert!(!params.is_empty());
            }
            _ => panic!("Expected MissingParameters"),
        }
    }

    #[test]
    fn test_eval_string_comparison() {
        let program = compile("name == \"alice\"").unwrap();
        let mut context = HashMap::new();
        context.insert("name".to_string(), Value::String("alice".to_string()));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_list_contains() {
        let program = compile("x in [1, 2, 3]").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(2));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_list_variable() {
        // The list itself is supplied through the context, not as a literal.
        let program = compile("2 in xs").unwrap();
        let mut context = HashMap::new();
        context.insert(
            "xs".to_string(),
            Value::List(vec![Value::Int(1), Value::Int(2), Value::Int(3)]),
        );
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_boolean_logic() {
        let program = compile("x > 0 && y < 10").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(5));
        context.insert("y".to_string(), Value::Int(3));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_boolean_logic_edge() {
        // `y < 10` is strict, so the boundary value itself must fail the condition.
        let program = compile("x > 0 && y < 10").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(5));
        context.insert("y".to_string(), Value::Int(10));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(false));
    }

    #[test]
    fn test_eval_nested_logic() {
        let program = compile("(x > 0 && y < 10) || z == true").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(-1));
        context.insert("y".to_string(), Value::Int(5));
        context.insert("z".to_string(), Value::Bool(true));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_non_boolean_result() {
        let program = compile("x + 1").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(1));
        let result = evaluate(&program, &context);
        assert!(matches!(result, Err(CelError::EvalError(_))));
    }

    #[test]
    fn test_eval_empty_expression() {
        let result = compile("");
        assert!(result.is_err());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            CelError::CompileError("bad".to_string()).to_string(),
            "CEL compile error: bad"
        );
        assert_eq!(
            CelError::EvalError("boom".to_string()).to_string(),
            "CEL eval error: boom"
        );
    }

    #[test]
    fn test_extract_missing_variable_quoted() {
        assert_eq!(extract_missing_variable("Undeclared reference to 'x'"), "x");
        assert_eq!(extract_missing_variable("something else"), "unknown");
    }
}