1use cel::Program;
4use std::collections::HashMap;
5
/// Errors produced while compiling or evaluating a CEL expression.
#[derive(Debug)]
pub enum CelError {
    /// The expression could not be compiled; holds the underlying compiler message.
    CompileError(String),
    /// Evaluation of a compiled expression failed; holds a description of the failure.
    EvalError(String),
}

impl std::fmt::Display for CelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both variants share the same message layout; only the stage word differs.
        let (stage, detail) = match self {
            CelError::CompileError(detail) => ("compile", detail),
            CelError::EvalError(detail) => ("eval", detail),
        };
        write!(f, "CEL {} error: {}", stage, detail)
    }
}

impl std::error::Error for CelError {}
22
/// Outcome of evaluating a CEL condition.
#[derive(Clone, Debug, PartialEq)]
pub enum CelResult {
    /// The expression evaluated to this boolean.
    Met(bool),
    /// Evaluation could not finish because referenced variables were not supplied.
    /// Holds their names; an entry may be `"unknown"` when no name could be determined.
    MissingParameters(Vec<String>),
}
28
/// A value that can be bound as a variable for CEL evaluation.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Int(i64),
    /// A string.
    String(String),
    /// A list of values.
    List(Vec<Value>),
}
36
37pub fn compile(expr: &str) -> Result<Program, CelError> {
39 Program::compile(expr).map_err(|e| CelError::CompileError(e.to_string()))
40}
41
42pub fn evaluate(
46 program: &Program,
47 context: &HashMap<String, Value>,
48) -> Result<CelResult, CelError> {
49 let mut cel_context = cel::Context::default();
51 for (key, value) in context {
52 let cel_value = match value {
53 Value::Bool(b) => cel::Value::Bool(*b),
54 Value::Int(i) => cel::Value::Int(*i),
55 Value::String(s) => cel::Value::String(s.clone().into()),
56 Value::List(items) => {
57 let cel_items: Vec<cel::Value> = items
58 .iter()
59 .map(|v| match v {
60 Value::Bool(b) => cel::Value::Bool(*b),
61 Value::Int(i) => cel::Value::Int(*i),
62 Value::String(s) => cel::Value::String(s.clone().into()),
63 Value::List(_) => cel::Value::Null, })
65 .collect();
66 cel::Value::List(cel_items.into())
67 }
68 };
69 let _ = cel_context.add_variable(key, cel_value);
70 }
71
72 match program.execute(&cel_context) {
74 Ok(value) => {
75 match value {
77 cel::Value::Bool(b) => Ok(CelResult::Met(b)),
78 _ => Err(CelError::EvalError(format!(
79 "CEL expression must evaluate to boolean, got: {:?}",
80 value
81 ))),
82 }
83 }
84 Err(e) => {
85 let err_msg = e.to_string();
86 let err_lower = err_msg.to_lowercase();
88 if err_lower.contains("undeclared") || err_lower.contains("not found") {
89 let missing = extract_missing_variable(&err_msg);
91 Ok(CelResult::MissingParameters(vec![missing]))
92 } else {
93 Err(CelError::EvalError(err_msg))
94 }
95 }
96 }
97}
98
/// Best-effort extraction of a variable name from a "missing variable" error message.
///
/// The strategies are tried in order:
/// 1. The first non-empty single-quoted span, e.g. `Undeclared reference to 'x'` -> `x`.
/// 2. The last whitespace-separated token after the word "undeclared" (matched
///    case-insensitively), with leading/trailing characters that are neither
///    alphanumeric nor `_` stripped, e.g. `Undeclared reference to foo.` -> `foo`.
///
/// Returns `"unknown"` when neither strategy yields a non-empty name.
fn extract_missing_variable(err_msg: &str) -> String {
    const MARKER: &str = "undeclared";

    if let Some(start) = err_msg.find('\'') {
        let after_open = &err_msg[start + 1..];
        if let Some(len) = after_open.find('\'') {
            let quoted = &after_open[..len];
            if !quoted.is_empty() {
                return quoted.to_string();
            }
        }
    }

    // `to_ascii_lowercase` preserves byte length and offsets, so an index found in
    // `lower` is also a valid char boundary in `err_msg`.
    let lower = err_msg.to_ascii_lowercase();
    if let Some(idx) = lower.find(MARKER) {
        let tail = &err_msg[idx + MARKER.len()..];
        if let Some(token) = tail.split_whitespace().next_back() {
            let name = token.trim_matches(|c: char| !(c.is_alphanumeric() || c == '_'));
            if !name.is_empty() {
                return name.to_string();
            }
        }
    }

    "unknown".to_string()
}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_valid_expression() {
        let result = compile("x == 42");
        assert!(result.is_ok());
    }

    #[test]
    fn test_compile_invalid_expression() {
        let result = compile("x ==");
        assert!(result.is_err());
    }

    #[test]
    fn test_eval_true() {
        let program = compile("x == 42").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(42));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_false() {
        let program = compile("x == 42").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(99));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(false));
    }

    #[test]
    fn test_eval_missing_params() {
        let program = compile("x == 42").unwrap();
        let context: HashMap<String, Value> = HashMap::new();
        let result = evaluate(&program, &context).unwrap();
        match result {
            CelResult::MissingParameters(params) => {
                assert!(!params.is_empty());
            }
            _ => panic!("Expected MissingParameters"),
        }
    }

    #[test]
    fn test_eval_string_comparison() {
        let program = compile("name == \"alice\"").unwrap();
        let mut context = HashMap::new();
        context.insert("name".to_string(), Value::String("alice".to_string()));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_list_contains() {
        let program = compile("x in [1, 2, 3]").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(2));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_list_variable() {
        let program = compile("x in xs").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(2));
        context.insert(
            "xs".to_string(),
            Value::List(vec![Value::Int(1), Value::Int(2), Value::Int(3)]),
        );
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_boolean_logic() {
        let program = compile("x > 0 && y < 10").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(5));
        context.insert("y".to_string(), Value::Int(3));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_boolean_logic_edge() {
        // `y < 10` is exclusive: the boundary value itself must fail the condition.
        let program = compile("x > 0 && y < 10").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(5));
        context.insert("y".to_string(), Value::Int(10));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(false));
    }

    #[test]
    fn test_eval_nested_logic() {
        let program = compile("(x > 0 && y < 10) || z == true").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(-1));
        context.insert("y".to_string(), Value::Int(5));
        context.insert("z".to_string(), Value::Bool(true));
        let result = evaluate(&program, &context).unwrap();
        assert_eq!(result, CelResult::Met(true));
    }

    #[test]
    fn test_eval_non_boolean_result_is_error() {
        let program = compile("x + 1").unwrap();
        let mut context = HashMap::new();
        context.insert("x".to_string(), Value::Int(1));
        let result = evaluate(&program, &context);
        assert!(matches!(result, Err(CelError::EvalError(_))));
    }

    #[test]
    fn test_eval_empty_expression() {
        let result = compile("");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_missing_variable_quoted() {
        assert_eq!(extract_missing_variable("Undeclared reference to 'foo'"), "foo");
    }
}