Skip to main content

cel_core/
unparser.rs

//! CEL expression unparser (AST to source text).
//!
//! This module converts a CEL AST back into source text. The output is
//! semantically equivalent to the original expression but may differ in
//! formatting (whitespace, parenthesization, etc.).
//!
//! # Example
//!
//! ```
//! use cel_core::unparser::ast_to_string;
//! use cel_core::parser::parse;
//!
//! let ast = parse("x + 1").ast.unwrap();
//! let source = ast_to_string(&ast);
//! assert_eq!(source, "x + 1");
//! ```
17
18use crate::types::{BinaryOp, Expr, ListElement, MapEntry, SpannedExpr, StructField, UnaryOp};
19
20/// Convert a CEL AST to source text.
21///
22/// The output is a valid CEL expression that is semantically equivalent
23/// to the input AST. Formatting may differ from the original source.
24pub fn ast_to_string(expr: &SpannedExpr) -> String {
25    unparse(&expr.node)
26}
27
28/// Returns the precedence of a binary operator (higher = binds tighter).
29fn precedence(op: BinaryOp) -> u8 {
30    match op {
31        BinaryOp::Or => 1,
32        BinaryOp::And => 2,
33        BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge | BinaryOp::In => 3,
34        BinaryOp::Add | BinaryOp::Sub => 4,
35        BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => 5,
36    }
37}
38
39/// Returns the operator symbol for a binary operator.
40fn binary_op_symbol(op: BinaryOp) -> &'static str {
41    match op {
42        BinaryOp::Add => "+",
43        BinaryOp::Sub => "-",
44        BinaryOp::Mul => "*",
45        BinaryOp::Div => "/",
46        BinaryOp::Mod => "%",
47        BinaryOp::Eq => "==",
48        BinaryOp::Ne => "!=",
49        BinaryOp::Lt => "<",
50        BinaryOp::Le => "<=",
51        BinaryOp::Gt => ">",
52        BinaryOp::Ge => ">=",
53        BinaryOp::In => "in",
54        BinaryOp::And => "&&",
55        BinaryOp::Or => "||",
56    }
57}
58
59/// Returns the operator symbol for a unary operator.
60fn unary_op_symbol(op: UnaryOp) -> &'static str {
61    match op {
62        UnaryOp::Neg => "-",
63        UnaryOp::Not => "!",
64    }
65}
66
67/// Unparse an expression to a string.
68fn unparse(expr: &Expr) -> String {
69    match expr {
70        // Literals
71        Expr::Null => "null".to_string(),
72        Expr::Bool(b) => b.to_string(),
73        Expr::Int(n) => n.to_string(),
74        Expr::UInt(n) => format!("{}u", n),
75        Expr::Float(f) => format_float(*f),
76        Expr::String(s) => format!("\"{}\"", escape_string(s)),
77        Expr::Bytes(b) => format!("b\"{}\"", escape_bytes(b)),
78
79        // Identifiers
80        Expr::Ident(name) => name.clone(),
81        Expr::RootIdent(name) => format!(".{}", name),
82
83        // Collections
84        Expr::List(elements) => unparse_list(elements),
85        Expr::Map(entries) => unparse_map(entries),
86
87        // Operations
88        Expr::Unary { op, expr } => unparse_unary(*op, expr),
89        Expr::Binary { op, left, right } => unparse_binary(*op, left, right),
90        Expr::Ternary { cond, then_expr, else_expr } => {
91            format!(
92                "{} ? {} : {}",
93                unparse_with_parens_if_needed(&cond.node, Some(0)),
94                unparse(&then_expr.node),
95                unparse(&else_expr.node)
96            )
97        }
98
99        // Access
100        Expr::Member { expr, field, optional } => {
101            let op = if *optional { ".?" } else { "." };
102            format!("{}{}{}", unparse_primary(&expr.node), op, field)
103        }
104        Expr::Index { expr, index, optional } => {
105            let brackets = if *optional { "[?" } else { "[" };
106            format!("{}{}{}]", unparse_primary(&expr.node), brackets, unparse(&index.node))
107        }
108        Expr::Call { expr, args } => unparse_call(expr, args),
109        Expr::Struct { type_name, fields } => unparse_struct(type_name, fields),
110
111        // Macro expansions - unparse back to macro syntax where possible
112        Expr::Comprehension { iter_var, iter_var2, iter_range, accu_var, accu_init, loop_condition, loop_step, result } => {
113            unparse_comprehension(iter_var, iter_var2, iter_range, accu_var, accu_init, loop_condition, loop_step, result)
114        }
115        Expr::MemberTestOnly { expr, field } => {
116            format!("has({}.{})", unparse(&expr.node), field)
117        }
118        Expr::Bind { var_name, init, body } => {
119            format!("cel.bind({}, {}, {})", var_name, unparse(&init.node), unparse(&body.node))
120        }
121
122        Expr::Error => "<error>".to_string(),
123    }
124}
125
/// Render a double so that it cannot be mistaken for an integer literal.
///
/// Non-finite values have no literal form and are emitted as conversion
/// calls (`double("NaN")`, `double("Infinity")`, `double("-Infinity")`).
/// Finite values use Rust's default `Display` output, with `.0` appended
/// when that output contains neither a decimal point nor an exponent marker.
fn format_float(f: f64) -> String {
    if f.is_nan() {
        return String::from("double(\"NaN\")");
    }
    if f.is_infinite() {
        let spelled = if f.is_sign_positive() {
            "double(\"Infinity\")"
        } else {
            "double(\"-Infinity\")"
        };
        return String::from(spelled);
    }

    let mut text = f.to_string();
    let looks_like_double = text.contains(|c: char| matches!(c, '.' | 'e' | 'E'));
    if !looks_like_double {
        text.push_str(".0");
    }
    text
}
147
/// Escape the contents of a string literal for CEL output.
///
/// Backslash, double quote, newline, carriage return and tab get their
/// short escapes. Any other control character is written as `\u` followed
/// by four lowercase hex digits. Everything else is copied through as-is.
/// The surrounding quotes are the caller's responsibility.
fn escape_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        let short_escape = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };

        if let Some(escaped) = short_escape {
            out.push_str(escaped);
        } else if ch.is_control() {
            // Remaining control characters have no short form.
            out.push_str(&format!("\\u{:04x}", u32::from(ch)));
        } else {
            out.push(ch);
        }
    }
    out
}
167
/// Escape the contents of a bytes literal for CEL output.
///
/// Backslash, double quote, newline, carriage return and tab get their
/// short escapes. Other printable ASCII bytes (and the space) are emitted
/// verbatim. Every remaining byte becomes `\x` plus two lowercase hex
/// digits. The `b"` prefix and closing quote are added by the caller.
fn escape_bytes(bytes: &[u8]) -> String {
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        let short_escape = match byte {
            b'\\' => Some("\\\\"),
            b'"' => Some("\\\""),
            b'\n' => Some("\\n"),
            b'\r' => Some("\\r"),
            b'\t' => Some("\\t"),
            _ => None,
        };

        if let Some(escaped) = short_escape {
            out.push_str(escaped);
        } else if byte == b' ' || byte.is_ascii_graphic() {
            out.push(char::from(byte));
        } else {
            out.push_str(&format!("\\x{:02x}", byte));
        }
    }
    out
}
184
185/// Unparse a list expression.
186fn unparse_list(elements: &[ListElement]) -> String {
187    let items: Vec<String> = elements
188        .iter()
189        .map(|elem| {
190            if elem.optional {
191                format!("?{}", unparse(&elem.expr.node))
192            } else {
193                unparse(&elem.expr.node)
194            }
195        })
196        .collect();
197    format!("[{}]", items.join(", "))
198}
199
200/// Unparse a map expression.
201fn unparse_map(entries: &[MapEntry]) -> String {
202    let items: Vec<String> = entries
203        .iter()
204        .map(|entry| {
205            let key = unparse(&entry.key.node);
206            let value = unparse(&entry.value.node);
207            if entry.optional {
208                format!("?{}: {}", key, value)
209            } else {
210                format!("{}: {}", key, value)
211            }
212        })
213        .collect();
214    format!("{{{}}}", items.join(", "))
215}
216
217/// Unparse a unary expression.
218fn unparse_unary(op: UnaryOp, expr: &SpannedExpr) -> String {
219    let op_str = unary_op_symbol(op);
220    // Need parens around binary expressions and ternaries
221    match &expr.node {
222        Expr::Binary { .. } | Expr::Ternary { .. } => {
223            format!("{}({})", op_str, unparse(&expr.node))
224        }
225        _ => format!("{}{}", op_str, unparse(&expr.node)),
226    }
227}
228
229/// Unparse a binary expression with proper precedence handling.
230fn unparse_binary(op: BinaryOp, left: &SpannedExpr, right: &SpannedExpr) -> String {
231    let op_prec = precedence(op);
232    let left_str = unparse_with_parens_if_needed(&left.node, Some(op_prec));
233    let right_str = unparse_with_parens_if_needed(&right.node, Some(op_prec));
234    format!("{} {} {}", left_str, binary_op_symbol(op), right_str)
235}
236
237/// Unparse an expression, adding parentheses if needed based on precedence.
238fn unparse_with_parens_if_needed(expr: &Expr, parent_prec: Option<u8>) -> String {
239    match expr {
240        Expr::Binary { op, .. } => {
241            let expr_prec = precedence(*op);
242            if let Some(p) = parent_prec {
243                if expr_prec < p {
244                    return format!("({})", unparse(expr));
245                }
246            }
247            unparse(expr)
248        }
249        Expr::Ternary { .. } => {
250            // Ternary always needs parens when nested in binary
251            if parent_prec.is_some() {
252                format!("({})", unparse(expr))
253            } else {
254                unparse(expr)
255            }
256        }
257        _ => unparse(expr),
258    }
259}
260
261/// Unparse a primary expression (may need parens for member/index access).
262fn unparse_primary(expr: &Expr) -> String {
263    match expr {
264        Expr::Binary { .. } | Expr::Ternary { .. } | Expr::Unary { .. } => {
265            format!("({})", unparse(expr))
266        }
267        _ => unparse(expr),
268    }
269}
270
271/// Unparse a function call.
272fn unparse_call(expr: &SpannedExpr, args: &[SpannedExpr]) -> String {
273    let args_str: Vec<String> = args.iter().map(|a| unparse(&a.node)).collect();
274
275    // Check if this is a method call (expr is Member) or function call (expr is Ident)
276    match &expr.node {
277        Expr::Ident(name) => {
278            // Global function call
279            format!("{}({})", name, args_str.join(", "))
280        }
281        Expr::Member { expr: receiver, field, optional: false } => {
282            // Method call: receiver.method(args)
283            format!("{}.{}({})", unparse_primary(&receiver.node), field, args_str.join(", "))
284        }
285        Expr::Member { expr: receiver, field, optional: true } => {
286            // Optional method call: receiver.?method(args)
287            format!("{}.?{}({})", unparse_primary(&receiver.node), field, args_str.join(", "))
288        }
289        _ => {
290            // Generic callable expression
291            format!("{}({})", unparse_primary(&expr.node), args_str.join(", "))
292        }
293    }
294}
295
296/// Unparse a struct literal.
297fn unparse_struct(type_name: &SpannedExpr, fields: &[StructField]) -> String {
298    let type_str = unparse(&type_name.node);
299    let fields_str: Vec<String> = fields
300        .iter()
301        .map(|f| {
302            if f.optional {
303                format!("?{}: {}", f.name, unparse(&f.value.node))
304            } else {
305                format!("{}: {}", f.name, unparse(&f.value.node))
306            }
307        })
308        .collect();
309    format!("{}{{{}}}", type_str, fields_str.join(", "))
310}
311
312/// Unparse a comprehension back to macro syntax.
313///
314/// This attempts to recognize common macro patterns and unparse them appropriately.
315/// For unrecognized patterns, it falls back to a generic representation.
316fn unparse_comprehension(
317    iter_var: &str,
318    iter_var2: &str,
319    iter_range: &SpannedExpr,
320    accu_var: &str,
321    accu_init: &SpannedExpr,
322    loop_condition: &SpannedExpr,
323    loop_step: &SpannedExpr,
324    result: &SpannedExpr,
325) -> String {
326    // Try to recognize common macro patterns
327
328    // Check for `all` pattern: accu_init=true, loop_step=accu && condition, result=accu
329    if matches!(&accu_init.node, Expr::Bool(true)) {
330        if let Expr::Binary { op: BinaryOp::And, left, right } = &loop_step.node {
331            if matches!(&left.node, Expr::Ident(name) if name == accu_var) {
332                if matches!(&result.node, Expr::Ident(name) if name == accu_var) {
333                    let range_str = unparse(&iter_range.node);
334                    let condition_str = unparse(&right.node);
335                    return format!("{}.all({}, {})", range_str, iter_var, condition_str);
336                }
337            }
338        }
339    }
340
341    // Check for `exists` pattern: accu_init=false, loop_step=accu || condition, result=accu
342    if matches!(&accu_init.node, Expr::Bool(false)) {
343        if let Expr::Binary { op: BinaryOp::Or, left, right } = &loop_step.node {
344            if matches!(&left.node, Expr::Ident(name) if name == accu_var) {
345                if matches!(&result.node, Expr::Ident(name) if name == accu_var) {
346                    let range_str = unparse(&iter_range.node);
347                    let condition_str = unparse(&right.node);
348                    return format!("{}.exists({}, {})", range_str, iter_var, condition_str);
349                }
350            }
351        }
352    }
353
354    // Check for `map` pattern: accu_init=[], loop_step=accu + [transform], result=accu
355    if matches!(&accu_init.node, Expr::List(elems) if elems.is_empty()) {
356        if let Expr::Binary { op: BinaryOp::Add, left, right } = &loop_step.node {
357            if matches!(&left.node, Expr::Ident(name) if name == accu_var) {
358                if let Expr::List(elems) = &right.node {
359                    if elems.len() == 1 && !elems[0].optional {
360                        if matches!(&result.node, Expr::Ident(name) if name == accu_var) {
361                            let range_str = unparse(&iter_range.node);
362                            let transform_str = unparse(&elems[0].expr.node);
363                            return format!("{}.map({}, {})", range_str, iter_var, transform_str);
364                        }
365                    }
366                }
367            }
368        }
369    }
370
371    // Check for `filter` pattern: accu_init=[], loop_step=accu + ([iter_var] if condition else []), result=accu
372    // This is more complex, so we'll use a simplified check
373    if matches!(&accu_init.node, Expr::List(elems) if elems.is_empty()) {
374        if let Expr::Ternary { cond, then_expr, else_expr } = &loop_step.node {
375            if let Expr::List(then_elems) = &then_expr.node {
376                if then_elems.len() == 1 {
377                    if let Expr::Ident(elem_name) = &then_elems[0].expr.node {
378                        if elem_name == iter_var {
379                            if let Expr::List(else_elems) = &else_expr.node {
380                                if else_elems.is_empty() {
381                                    if matches!(&result.node, Expr::Ident(name) if name == accu_var) {
382                                        let range_str = unparse(&iter_range.node);
383                                        let condition_str = unparse(&cond.node);
384                                        return format!("{}.filter({}, {})", range_str, iter_var, condition_str);
385                                    }
386                                }
387                            }
388                        }
389                    }
390                }
391            }
392        }
393    }
394
395    // Fallback: generic comprehension representation
396    // This is not valid CEL syntax but provides a readable representation
397    let iter_vars = if iter_var2.is_empty() {
398        iter_var.to_string()
399    } else {
400        format!("{}, {}", iter_var, iter_var2)
401    };
402
403    format!(
404        "__comprehension__({}, {}, {}, {}, {}, {}, {})",
405        unparse(&iter_range.node),
406        iter_vars,
407        accu_var,
408        unparse(&accu_init.node),
409        unparse(&loop_condition.node),
410        unparse(&loop_step.node),
411        unparse(&result.node)
412    )
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;

    /// Parse `source`, then unparse the resulting AST.
    fn roundtrip(source: &str) -> String {
        let ast = parse(source).ast.expect("parse failed");
        ast_to_string(&ast)
    }

    /// Assert that every source in `sources` is reproduced verbatim by a
    /// parse/unparse round trip.
    fn assert_stable(sources: &[&str]) {
        for &source in sources {
            assert_eq!(roundtrip(source), source);
        }
    }

    #[test]
    fn test_literals() {
        assert_stable(&[
            "null",
            "true",
            "false",
            "42",
            "42u",
            "3.14",
            "\"hello\"",
            "b\"bytes\"",
        ]);
    }

    #[test]
    fn test_string_escaping() {
        assert_stable(&[
            r#""hello\nworld""#,
            r#""tab\there""#,
            r#""quote\"here""#,
        ]);
    }

    #[test]
    fn test_identifiers() {
        assert_stable(&["x", "foo_bar"]);
    }

    #[test]
    fn test_collections() {
        assert_stable(&["[]", "[1, 2, 3]", "{}", "{\"a\": 1, \"b\": 2}"]);
    }

    #[test]
    fn test_unary() {
        assert_stable(&["-x", "!x", "--x"]);
    }

    #[test]
    fn test_binary() {
        assert_stable(&[
            "x + y",
            "x - y",
            "x * y",
            "x / y",
            "x % y",
            "x == y",
            "x != y",
            "x < y",
            "x <= y",
            "x > y",
            "x >= y",
            "x in y",
            "x && y",
            "x || y",
        ]);
    }

    #[test]
    fn test_precedence() {
        // Parentheses must survive where they change the grouping...
        assert_stable(&["(x + y) * z", "x * (y + z)"]);
        // ...while a left-nested chain at one precedence level needs none.
        assert_stable(&["x + y + z"]);
    }

    #[test]
    fn test_ternary() {
        assert_stable(&["x ? y : z", "a ? b ? c : d : e"]);
    }

    #[test]
    fn test_member_access() {
        assert_stable(&["x.y", "x.y.z"]);
    }

    #[test]
    fn test_index_access() {
        assert_stable(&["x[0]", "x[\"key\"]"]);
    }

    #[test]
    fn test_function_calls() {
        assert_stable(&["f()", "f(x)", "f(x, y, z)"]);
    }

    #[test]
    fn test_method_calls() {
        assert_stable(&[
            "x.size()",
            "x.contains(y)",
            "\"hello\".startsWith(\"h\")",
        ]);
    }

    #[test]
    fn test_complex_expression() {
        assert_stable(&["x > 0 && y < 10 || z == \"test\""]);
    }

    #[test]
    fn test_has_macro() {
        // The has() macro expands to MemberTestOnly.
        assert_stable(&["has(x.y)"]);
    }
}
537}