Skip to main content

cel_core/
unparser.rs

//! CEL expression unparser (AST to source text).
//!
//! This module converts a CEL AST back into source text. The output is
//! semantically equivalent to the original expression but may differ in
//! formatting (whitespace, parenthesization, etc.).
//!
//! # Example
//!
//! ```
//! use cel_core::ast_to_string;
//! use cel_core::parse;
//!
//! let ast = parse("x + 1").ast.unwrap();
//! let source = ast_to_string(&ast);
//! assert_eq!(source, "x + 1");
//! ```
17
18use crate::types::{
19    BinaryOp, ComprehensionData, Expr, ListElement, MapEntry, SpannedExpr, StructField, UnaryOp,
20};
21
22/// Convert a CEL AST to source text.
23///
24/// The output is a valid CEL expression that is semantically equivalent
25/// to the input AST. Formatting may differ from the original source.
26pub fn ast_to_string(expr: &SpannedExpr) -> String {
27    unparse(&expr.node)
28}
29
30/// Returns the precedence of a binary operator (higher = binds tighter).
31fn precedence(op: BinaryOp) -> u8 {
32    match op {
33        BinaryOp::Or => 1,
34        BinaryOp::And => 2,
35        BinaryOp::Eq
36        | BinaryOp::Ne
37        | BinaryOp::Lt
38        | BinaryOp::Le
39        | BinaryOp::Gt
40        | BinaryOp::Ge
41        | BinaryOp::In => 3,
42        BinaryOp::Add | BinaryOp::Sub => 4,
43        BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => 5,
44    }
45}
46
47/// Returns the operator symbol for a binary operator.
48fn binary_op_symbol(op: BinaryOp) -> &'static str {
49    match op {
50        BinaryOp::Add => "+",
51        BinaryOp::Sub => "-",
52        BinaryOp::Mul => "*",
53        BinaryOp::Div => "/",
54        BinaryOp::Mod => "%",
55        BinaryOp::Eq => "==",
56        BinaryOp::Ne => "!=",
57        BinaryOp::Lt => "<",
58        BinaryOp::Le => "<=",
59        BinaryOp::Gt => ">",
60        BinaryOp::Ge => ">=",
61        BinaryOp::In => "in",
62        BinaryOp::And => "&&",
63        BinaryOp::Or => "||",
64    }
65}
66
67/// Returns the operator symbol for a unary operator.
68fn unary_op_symbol(op: UnaryOp) -> &'static str {
69    match op {
70        UnaryOp::Neg => "-",
71        UnaryOp::Not => "!",
72    }
73}
74
75/// Unparse an expression to a string.
76fn unparse(expr: &Expr) -> String {
77    match expr {
78        // Literals
79        Expr::Null => "null".to_string(),
80        Expr::Bool(b) => b.to_string(),
81        Expr::Int(n) => n.to_string(),
82        Expr::UInt(n) => format!("{}u", n),
83        Expr::Float(f) => format_float(*f),
84        Expr::String(s) => format!("\"{}\"", escape_string(s)),
85        Expr::Bytes(b) => format!("b\"{}\"", escape_bytes(b)),
86
87        // Identifiers
88        Expr::Ident(name) => name.clone(),
89        Expr::RootIdent(name) => format!(".{}", name),
90
91        // Collections
92        Expr::List(elements) => unparse_list(elements),
93        Expr::Map(entries) => unparse_map(entries),
94
95        // Operations
96        Expr::Unary { op, expr } => unparse_unary(*op, expr),
97        Expr::Binary { op, left, right } => unparse_binary(*op, left, right),
98        Expr::Ternary {
99            cond,
100            then_expr,
101            else_expr,
102        } => {
103            format!(
104                "{} ? {} : {}",
105                unparse_with_parens_if_needed(&cond.node, Some(0)),
106                unparse(&then_expr.node),
107                unparse(&else_expr.node)
108            )
109        }
110
111        // Access
112        Expr::Member {
113            expr,
114            field,
115            optional,
116        } => {
117            let op = if *optional { ".?" } else { "." };
118            format!("{}{}{}", unparse_primary(&expr.node), op, field)
119        }
120        Expr::Index {
121            expr,
122            index,
123            optional,
124        } => {
125            let brackets = if *optional { "[?" } else { "[" };
126            format!(
127                "{}{}{}]",
128                unparse_primary(&expr.node),
129                brackets,
130                unparse(&index.node)
131            )
132        }
133        Expr::Call { expr, args } => unparse_call(expr, args),
134        Expr::Struct { type_name, fields } => unparse_struct(type_name, fields),
135
136        // Macro expansions - unparse back to macro syntax where possible
137        Expr::Comprehension(comp) => unparse_comprehension(comp),
138        Expr::MemberTestOnly { expr, field } => {
139            format!("has({}.{})", unparse(&expr.node), field)
140        }
141        Expr::Bind {
142            var_name,
143            init,
144            body,
145        } => {
146            format!(
147                "cel.bind({}, {}, {})",
148                var_name,
149                unparse(&init.node),
150                unparse(&body.node)
151            )
152        }
153
154        Expr::Error => "<error>".to_string(),
155    }
156}
157
/// Render an `f64` as a CEL double literal.
///
/// Non-finite values have no literal syntax, so they are written as the
/// conversions `double("NaN")`, `double("Infinity")` or `double("-Infinity")`.
/// Finite values use Rust's `Display` output. `.0` is appended when that
/// output has neither a decimal point nor an exponent, so the text cannot be
/// re-read as an integer literal.
fn format_float(f: f64) -> String {
    match f {
        v if v.is_nan() => "double(\"NaN\")".to_string(),
        v if v == f64::INFINITY => "double(\"Infinity\")".to_string(),
        v if v == f64::NEG_INFINITY => "double(\"-Infinity\")".to_string(),
        v => {
            let mut text = v.to_string();
            let looks_like_float = text.chars().any(|ch| matches!(ch, '.' | 'e' | 'E'));
            if !looks_like_float {
                text.push_str(".0");
            }
            text
        }
    }
}
179
/// Escape `s` for inclusion between double quotes in a CEL string literal.
///
/// Backslash, double quote, newline, carriage return and tab use their short
/// backslash escapes. Any other character for which `char::is_control` holds
/// is written as `\u` followed by its code point in lowercase hex, zero-padded
/// to four digits. All remaining characters are copied through unchanged.
fn escape_string(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut out, ch| {
        let short_escape = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(esc) = short_escape {
            out.push_str(esc);
        } else if ch.is_control() {
            out.push_str(&format!("\\u{:04x}", ch as u32));
        } else {
            out.push(ch);
        }
        out
    })
}
199
/// Escape raw bytes for inclusion in a CEL `b"..."` literal.
///
/// Backslash, double quote, newline, carriage return and tab use their short
/// backslash escapes. Other printable ASCII bytes (space through `~`) are
/// emitted as-is. Every remaining byte, including all bytes >= 0x80, becomes a
/// two-digit lowercase `\xNN` escape.
fn escape_bytes(bytes: &[u8]) -> String {
    let mut out = String::with_capacity(bytes.len() * 2);
    for byte in bytes.iter().copied() {
        match byte {
            b'\\' => out.push_str("\\\\"),
            b'"' => out.push_str("\\\""),
            b'\n' => out.push_str("\\n"),
            b'\r' => out.push_str("\\r"),
            b'\t' => out.push_str("\\t"),
            // Space (0x20) through tilde (0x7E): printable ASCII.
            printable @ b' '..=b'~' => out.push(char::from(printable)),
            other => out.push_str(&format!("\\x{:02x}", other)),
        }
    }
    out
}
216
217/// Unparse a list expression.
218fn unparse_list(elements: &[ListElement]) -> String {
219    let items: Vec<String> = elements
220        .iter()
221        .map(|elem| {
222            if elem.optional {
223                format!("?{}", unparse(&elem.expr.node))
224            } else {
225                unparse(&elem.expr.node)
226            }
227        })
228        .collect();
229    format!("[{}]", items.join(", "))
230}
231
232/// Unparse a map expression.
233fn unparse_map(entries: &[MapEntry]) -> String {
234    let items: Vec<String> = entries
235        .iter()
236        .map(|entry| {
237            let key = unparse(&entry.key.node);
238            let value = unparse(&entry.value.node);
239            if entry.optional {
240                format!("?{}: {}", key, value)
241            } else {
242                format!("{}: {}", key, value)
243            }
244        })
245        .collect();
246    format!("{{{}}}", items.join(", "))
247}
248
249/// Unparse a unary expression.
250fn unparse_unary(op: UnaryOp, expr: &SpannedExpr) -> String {
251    let op_str = unary_op_symbol(op);
252    // Need parens around binary expressions and ternaries
253    match &expr.node {
254        Expr::Binary { .. } | Expr::Ternary { .. } => {
255            format!("{}({})", op_str, unparse(&expr.node))
256        }
257        _ => format!("{}{}", op_str, unparse(&expr.node)),
258    }
259}
260
261/// Unparse a binary expression with proper precedence handling.
262fn unparse_binary(op: BinaryOp, left: &SpannedExpr, right: &SpannedExpr) -> String {
263    let op_prec = precedence(op);
264    let left_str = unparse_with_parens_if_needed(&left.node, Some(op_prec));
265    let right_str = unparse_with_parens_if_needed(&right.node, Some(op_prec));
266    format!("{} {} {}", left_str, binary_op_symbol(op), right_str)
267}
268
269/// Unparse an expression, adding parentheses if needed based on precedence.
270fn unparse_with_parens_if_needed(expr: &Expr, parent_prec: Option<u8>) -> String {
271    match expr {
272        Expr::Binary { op, .. } => {
273            let expr_prec = precedence(*op);
274            if let Some(p) = parent_prec {
275                if expr_prec < p {
276                    return format!("({})", unparse(expr));
277                }
278            }
279            unparse(expr)
280        }
281        Expr::Ternary { .. } => {
282            // Ternary always needs parens when nested in binary
283            if parent_prec.is_some() {
284                format!("({})", unparse(expr))
285            } else {
286                unparse(expr)
287            }
288        }
289        _ => unparse(expr),
290    }
291}
292
293/// Unparse a primary expression (may need parens for member/index access).
294fn unparse_primary(expr: &Expr) -> String {
295    match expr {
296        Expr::Binary { .. } | Expr::Ternary { .. } | Expr::Unary { .. } => {
297            format!("({})", unparse(expr))
298        }
299        _ => unparse(expr),
300    }
301}
302
303/// Unparse a function call.
304fn unparse_call(expr: &SpannedExpr, args: &[SpannedExpr]) -> String {
305    let args_str: Vec<String> = args.iter().map(|a| unparse(&a.node)).collect();
306
307    // Check if this is a method call (expr is Member) or function call (expr is Ident)
308    match &expr.node {
309        Expr::Ident(name) => {
310            // Global function call
311            format!("{}({})", name, args_str.join(", "))
312        }
313        Expr::Member {
314            expr: receiver,
315            field,
316            optional: false,
317        } => {
318            // Method call: receiver.method(args)
319            format!(
320                "{}.{}({})",
321                unparse_primary(&receiver.node),
322                field,
323                args_str.join(", ")
324            )
325        }
326        Expr::Member {
327            expr: receiver,
328            field,
329            optional: true,
330        } => {
331            // Optional method call: receiver.?method(args)
332            format!(
333                "{}.?{}({})",
334                unparse_primary(&receiver.node),
335                field,
336                args_str.join(", ")
337            )
338        }
339        _ => {
340            // Generic callable expression
341            format!("{}({})", unparse_primary(&expr.node), args_str.join(", "))
342        }
343    }
344}
345
346/// Unparse a struct literal.
347fn unparse_struct(type_name: &SpannedExpr, fields: &[StructField]) -> String {
348    let type_str = unparse(&type_name.node);
349    let fields_str: Vec<String> = fields
350        .iter()
351        .map(|f| {
352            if f.optional {
353                format!("?{}: {}", f.name, unparse(&f.value.node))
354            } else {
355                format!("{}: {}", f.name, unparse(&f.value.node))
356            }
357        })
358        .collect();
359    format!("{}{{{}}}", type_str, fields_str.join(", "))
360}
361
362/// Unparse a comprehension back to macro syntax.
363///
364/// This attempts to recognize common macro patterns and unparse them appropriately.
365/// For unrecognized patterns, it falls back to a generic representation.
366fn unparse_comprehension(comp: &ComprehensionData) -> String {
367    // Try to recognize common macro patterns
368
369    // Check for `all` pattern: accu_init=true, loop_step=accu && condition, result=accu
370    if matches!(&comp.accu_init.node, Expr::Bool(true)) {
371        if let Expr::Binary {
372            op: BinaryOp::And,
373            left,
374            right,
375        } = &comp.loop_step.node
376        {
377            if matches!(&left.node, Expr::Ident(name) if name == &comp.accu_var)
378                && matches!(&comp.result.node, Expr::Ident(name) if name == &comp.accu_var)
379            {
380                let range_str = unparse(&comp.iter_range.node);
381                let condition_str = unparse(&right.node);
382                return format!("{}.all({}, {})", range_str, comp.iter_var, condition_str);
383            }
384        }
385    }
386
387    // Check for `exists` pattern: accu_init=false, loop_step=accu || condition, result=accu
388    if matches!(&comp.accu_init.node, Expr::Bool(false)) {
389        if let Expr::Binary {
390            op: BinaryOp::Or,
391            left,
392            right,
393        } = &comp.loop_step.node
394        {
395            if matches!(&left.node, Expr::Ident(name) if name == &comp.accu_var)
396                && matches!(&comp.result.node, Expr::Ident(name) if name == &comp.accu_var)
397            {
398                let range_str = unparse(&comp.iter_range.node);
399                let condition_str = unparse(&right.node);
400                return format!("{}.exists({}, {})", range_str, comp.iter_var, condition_str);
401            }
402        }
403    }
404
405    // Check for `map` pattern: accu_init=[], loop_step=accu + [transform], result=accu
406    if matches!(&comp.accu_init.node, Expr::List(elems) if elems.is_empty()) {
407        if let Expr::Binary {
408            op: BinaryOp::Add,
409            left,
410            right,
411        } = &comp.loop_step.node
412        {
413            if matches!(&left.node, Expr::Ident(name) if name == &comp.accu_var) {
414                if let Expr::List(elems) = &right.node {
415                    if elems.len() == 1
416                        && !elems[0].optional
417                        && matches!(&comp.result.node, Expr::Ident(name) if name == &comp.accu_var)
418                    {
419                        let range_str = unparse(&comp.iter_range.node);
420                        let transform_str = unparse(&elems[0].expr.node);
421                        return format!("{}.map({}, {})", range_str, comp.iter_var, transform_str);
422                    }
423                }
424            }
425        }
426    }
427
428    // Check for `filter` pattern: accu_init=[], loop_step=accu + ([iter_var] if condition else []), result=accu
429    // This is more complex, so we'll use a simplified check
430    if matches!(&comp.accu_init.node, Expr::List(elems) if elems.is_empty()) {
431        if let Expr::Ternary {
432            cond,
433            then_expr,
434            else_expr,
435        } = &comp.loop_step.node
436        {
437            if let Expr::List(then_elems) = &then_expr.node {
438                if then_elems.len() == 1 {
439                    if let Expr::Ident(elem_name) = &then_elems[0].expr.node {
440                        if elem_name == &comp.iter_var {
441                            if let Expr::List(else_elems) = &else_expr.node {
442                                if else_elems.is_empty()
443                                    && matches!(&comp.result.node, Expr::Ident(name) if name == &comp.accu_var)
444                                {
445                                    let range_str = unparse(&comp.iter_range.node);
446                                    let condition_str = unparse(&cond.node);
447                                    return format!(
448                                        "{}.filter({}, {})",
449                                        range_str, comp.iter_var, condition_str
450                                    );
451                                }
452                            }
453                        }
454                    }
455                }
456            }
457        }
458    }
459
460    // Fallback: generic comprehension representation
461    // This is not valid CEL syntax but provides a readable representation
462    let iter_vars = if comp.iter_var2.is_empty() {
463        comp.iter_var.to_string()
464    } else {
465        format!("{}, {}", comp.iter_var, comp.iter_var2)
466    };
467
468    format!(
469        "__comprehension__({}, {}, {}, {}, {}, {}, {})",
470        unparse(&comp.iter_range.node),
471        iter_vars,
472        comp.accu_var,
473        unparse(&comp.accu_init.node),
474        unparse(&comp.loop_condition.node),
475        unparse(&comp.loop_step.node),
476        unparse(&comp.result.node)
477    )
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;

    /// Parse `source` and unparse the resulting AST; panics if parsing fails.
    fn roundtrip(source: &str) -> String {
        let parsed = parse(source).ast.expect("parse failed");
        ast_to_string(&parsed)
    }

    /// Assert that each source string unparses back to exactly itself.
    fn assert_stable(sources: &[&str]) {
        for &src in sources {
            assert_eq!(roundtrip(src), src, "round-trip changed {:?}", src);
        }
    }

    #[test]
    fn test_literals() {
        assert_stable(&[
            "null",
            "true",
            "false",
            "42",
            "42u",
            "3.14",
            r#""hello""#,
            r#"b"bytes""#,
        ]);
    }

    #[test]
    fn test_string_escaping() {
        assert_stable(&[
            r#""hello\nworld""#,
            r#""tab\there""#,
            r#""quote\"here""#,
        ]);
    }

    #[test]
    fn test_identifiers() {
        assert_stable(&["x", "foo_bar"]);
    }

    #[test]
    fn test_collections() {
        assert_stable(&["[]", "[1, 2, 3]", "{}", r#"{"a": 1, "b": 2}"#]);
    }

    #[test]
    fn test_unary() {
        assert_stable(&["-x", "!x", "--x"]);
    }

    #[test]
    fn test_binary() {
        assert_stable(&[
            "x + y", "x - y", "x * y", "x / y", "x % y", "x == y", "x != y", "x < y", "x <= y",
            "x > y", "x >= y", "x in y", "x && y", "x || y",
        ]);
    }

    #[test]
    fn test_precedence() {
        assert_stable(&[
            // Parentheses are kept where they change the grouping.
            "(x + y) * z",
            "x * (y + z)",
            // A left-nested chain of equal precedence needs no parentheses.
            "x + y + z",
        ]);
    }

    #[test]
    fn test_ternary() {
        assert_stable(&["x ? y : z", "a ? b ? c : d : e"]);
    }

    #[test]
    fn test_member_access() {
        assert_stable(&["x.y", "x.y.z"]);
    }

    #[test]
    fn test_index_access() {
        assert_stable(&["x[0]", r#"x["key"]"#]);
    }

    #[test]
    fn test_function_calls() {
        assert_stable(&["f()", "f(x)", "f(x, y, z)"]);
    }

    #[test]
    fn test_method_calls() {
        assert_stable(&[
            "x.size()",
            "x.contains(y)",
            r#""hello".startsWith("h")"#,
        ]);
    }

    #[test]
    fn test_complex_expression() {
        assert_stable(&[r#"x > 0 && y < 10 || z == "test""#]);
    }

    #[test]
    fn test_has_macro() {
        // The has() macro expands to MemberTestOnly.
        assert_stable(&["has(x.y)"]);
    }
}