Skip to main content

varpulis_parser/
optimize.rs

1//! AST-level constant folding optimization pass.
2//!
3//! Evaluates constant arithmetic expressions at compile time, reducing
4//! runtime computation. Applied after parsing, before the AST reaches
5//! the runtime engine.
6
7use varpulis_core::ast::*;
8use varpulis_core::span::Spanned;
9
10/// Fold constants in an entire program.
11pub fn fold_program(program: Program) -> Program {
12    Program {
13        statements: program
14            .statements
15            .into_iter()
16            .map(fold_spanned_stmt)
17            .collect(),
18    }
19}
20
21fn fold_spanned_stmt(s: Spanned<Stmt>) -> Spanned<Stmt> {
22    Spanned::new(fold_stmt(s.node), s.span)
23}
24
25fn fold_stmt(stmt: Stmt) -> Stmt {
26    match stmt {
27        Stmt::VarDecl {
28            mutable,
29            name,
30            ty,
31            value,
32        } => Stmt::VarDecl {
33            mutable,
34            name,
35            ty,
36            value: fold_expr(value),
37        },
38        Stmt::ConstDecl { name, ty, value } => Stmt::ConstDecl {
39            name,
40            ty,
41            value: fold_expr(value),
42        },
43        Stmt::FnDecl {
44            name,
45            params,
46            ret,
47            body,
48        } => Stmt::FnDecl {
49            name,
50            params,
51            ret,
52            body: body.into_iter().map(fold_spanned_stmt).collect(),
53        },
54        Stmt::StreamDecl {
55            name,
56            type_annotation,
57            source,
58            ops,
59            op_spans,
60        } => Stmt::StreamDecl {
61            name,
62            type_annotation,
63            source,
64            ops: ops.into_iter().map(fold_stream_op).collect(),
65            op_spans,
66        },
67        Stmt::If {
68            cond,
69            then_branch,
70            elif_branches,
71            else_branch,
72        } => Stmt::If {
73            cond: fold_expr(cond),
74            then_branch: then_branch.into_iter().map(fold_spanned_stmt).collect(),
75            elif_branches: elif_branches
76                .into_iter()
77                .map(|(c, b)| (fold_expr(c), b.into_iter().map(fold_spanned_stmt).collect()))
78                .collect(),
79            else_branch: else_branch.map(|b| b.into_iter().map(fold_spanned_stmt).collect()),
80        },
81        Stmt::For { var, iter, body } => Stmt::For {
82            var,
83            iter: fold_expr(iter),
84            body: body.into_iter().map(fold_spanned_stmt).collect(),
85        },
86        Stmt::While { cond, body } => Stmt::While {
87            cond: fold_expr(cond),
88            body: body.into_iter().map(fold_spanned_stmt).collect(),
89        },
90        Stmt::Return(Some(expr)) => Stmt::Return(Some(fold_expr(expr))),
91        Stmt::Expr(expr) => Stmt::Expr(fold_expr(expr)),
92        Stmt::Assignment { name, value } => Stmt::Assignment {
93            name,
94            value: fold_expr(value),
95        },
96        Stmt::Emit { event_type, fields } => Stmt::Emit {
97            event_type,
98            fields: fields.into_iter().map(fold_named_arg).collect(),
99        },
100        // Pass through unchanged
101        other => other,
102    }
103}
104
105fn fold_named_arg(arg: NamedArg) -> NamedArg {
106    NamedArg {
107        name: arg.name,
108        value: fold_expr(arg.value),
109    }
110}
111
112fn fold_stream_op(op: StreamOp) -> StreamOp {
113    match op {
114        StreamOp::Where(expr) => StreamOp::Where(fold_expr(expr)),
115        StreamOp::Filter(expr) => StreamOp::Filter(fold_expr(expr)),
116        StreamOp::Map(expr) => StreamOp::Map(fold_expr(expr)),
117        StreamOp::Process(expr) => StreamOp::Process(fold_expr(expr)),
118        StreamOp::OnError(expr) => StreamOp::OnError(fold_expr(expr)),
119        StreamOp::Having(expr) => StreamOp::Having(fold_expr(expr)),
120        StreamOp::PartitionBy(expr) => StreamOp::PartitionBy(fold_expr(expr)),
121        StreamOp::Limit(expr) => StreamOp::Limit(fold_expr(expr)),
122        StreamOp::Distinct(opt) => StreamOp::Distinct(opt.map(fold_expr)),
123        StreamOp::On(expr) => StreamOp::On(fold_expr(expr)),
124        StreamOp::Within(expr) => StreamOp::Within(fold_expr(expr)),
125        StreamOp::AllowedLateness(expr) => StreamOp::AllowedLateness(fold_expr(expr)),
126        StreamOp::ToExpr(expr) => StreamOp::ToExpr(fold_expr(expr)),
127        StreamOp::Emit {
128            output_type,
129            fields,
130            target_context,
131        } => StreamOp::Emit {
132            output_type,
133            fields: fields.into_iter().map(fold_named_arg).collect(),
134            target_context,
135        },
136        StreamOp::Print(exprs) => StreamOp::Print(exprs.into_iter().map(fold_expr).collect()),
137        StreamOp::Log(args) => StreamOp::Log(args.into_iter().map(fold_named_arg).collect()),
138        StreamOp::Tap(args) => StreamOp::Tap(args.into_iter().map(fold_named_arg).collect()),
139        // Pass through unchanged
140        other => other,
141    }
142}
143
144/// Recursively fold constant expressions.
145fn fold_expr(expr: Expr) -> Expr {
146    match expr {
147        // Recurse into binary expressions
148        Expr::Binary { op, left, right } => {
149            let left = fold_expr(*left);
150            let right = fold_expr(*right);
151            fold_binary(op, left, right)
152        }
153        // Recurse into unary expressions
154        Expr::Unary { op, expr } => {
155            let inner = fold_expr(*expr);
156            fold_unary(op, inner)
157        }
158        // Recurse into function call arguments
159        Expr::Call { func, args } => Expr::Call {
160            func: Box::new(fold_expr(*func)),
161            args: args.into_iter().map(fold_arg).collect(),
162        },
163        // Recurse into array elements
164        Expr::Array(elems) => Expr::Array(elems.into_iter().map(fold_expr).collect()),
165        // Recurse into map values
166        Expr::Map(entries) => Expr::Map(
167            entries
168                .into_iter()
169                .map(|(k, v)| (k, fold_expr(v)))
170                .collect(),
171        ),
172        // Recurse into lambda body
173        Expr::Lambda { params, body } => Expr::Lambda {
174            params,
175            body: Box::new(fold_expr(*body)),
176        },
177        // Recurse into if expression
178        Expr::If {
179            cond,
180            then_branch,
181            else_branch,
182        } => Expr::If {
183            cond: Box::new(fold_expr(*cond)),
184            then_branch: Box::new(fold_expr(*then_branch)),
185            else_branch: Box::new(fold_expr(*else_branch)),
186        },
187        // Recurse into coalesce
188        Expr::Coalesce { expr, default } => Expr::Coalesce {
189            expr: Box::new(fold_expr(*expr)),
190            default: Box::new(fold_expr(*default)),
191        },
192        // Recurse into range
193        Expr::Range {
194            start,
195            end,
196            inclusive,
197        } => Expr::Range {
198            start: Box::new(fold_expr(*start)),
199            end: Box::new(fold_expr(*end)),
200            inclusive,
201        },
202        // Recurse into member access
203        Expr::Member { expr, member } => Expr::Member {
204            expr: Box::new(fold_expr(*expr)),
205            member,
206        },
207        // Recurse into optional member access
208        Expr::OptionalMember { expr, member } => Expr::OptionalMember {
209            expr: Box::new(fold_expr(*expr)),
210            member,
211        },
212        // Recurse into index access
213        Expr::Index { expr, index } => Expr::Index {
214            expr: Box::new(fold_expr(*expr)),
215            index: Box::new(fold_expr(*index)),
216        },
217        // Recurse into slice
218        Expr::Slice { expr, start, end } => Expr::Slice {
219            expr: Box::new(fold_expr(*expr)),
220            start: start.map(|s| Box::new(fold_expr(*s))),
221            end: end.map(|e| Box::new(fold_expr(*e))),
222        },
223        // Recurse into block expression
224        Expr::Block { stmts, result } => Expr::Block {
225            stmts: stmts
226                .into_iter()
227                .map(|(name, ty, val, mutable)| (name, ty, fold_expr(val), mutable))
228                .collect(),
229            result: Box::new(fold_expr(*result)),
230        },
231        // Literals and identifiers are already folded
232        other => other,
233    }
234}
235
236fn fold_arg(arg: Arg) -> Arg {
237    match arg {
238        Arg::Positional(expr) => Arg::Positional(fold_expr(expr)),
239        Arg::Named(name, expr) => Arg::Named(name, fold_expr(expr)),
240    }
241}
242
243/// Try to fold a binary operation on two (possibly constant) operands.
244fn fold_binary(op: BinOp, left: Expr, right: Expr) -> Expr {
245    // First: try full constant folding (both operands are literals)
246    match (&op, &left, &right) {
247        // Int OP Int
248        (BinOp::Add, Expr::Int(a), Expr::Int(b)) => return Expr::Int(a.wrapping_add(*b)),
249        (BinOp::Sub, Expr::Int(a), Expr::Int(b)) => return Expr::Int(a.wrapping_sub(*b)),
250        (BinOp::Mul, Expr::Int(a), Expr::Int(b)) => return Expr::Int(a.wrapping_mul(*b)),
251        (BinOp::Div, Expr::Int(a), Expr::Int(b)) if *b != 0 => return Expr::Int(a / b),
252        (BinOp::Mod, Expr::Int(a), Expr::Int(b)) if *b != 0 => return Expr::Int(a % b),
253        (BinOp::Pow, Expr::Int(a), Expr::Int(b)) if *b >= 0 => {
254            return Expr::Int(a.wrapping_pow(*b as u32));
255        }
256
257        // Float OP Float
258        (BinOp::Add, Expr::Float(a), Expr::Float(b)) => return Expr::Float(a + b),
259        (BinOp::Sub, Expr::Float(a), Expr::Float(b)) => return Expr::Float(a - b),
260        (BinOp::Mul, Expr::Float(a), Expr::Float(b)) => return Expr::Float(a * b),
261        (BinOp::Div, Expr::Float(a), Expr::Float(b)) if *b != 0.0 => {
262            return Expr::Float(a / b);
263        }
264
265        _ => {}
266    }
267
268    // Second: identity folding (one operand is a known identity value)
269    match (&op, &left, &right) {
270        // x * 0 or 0 * x → 0 (integer only)
271        (BinOp::Mul, _, Expr::Int(0)) | (BinOp::Mul, Expr::Int(0), _) => {
272            return Expr::Int(0);
273        }
274        // x * 1 → x
275        (BinOp::Mul, _, Expr::Int(1)) => return left,
276        // 1 * x → x
277        (BinOp::Mul, Expr::Int(1), _) => return right,
278        // x + 0 → x
279        (BinOp::Add, _, Expr::Int(0)) => return left,
280        // 0 + x → x
281        (BinOp::Add, Expr::Int(0), _) => return right,
282        // x - 0 → x
283        (BinOp::Sub, _, Expr::Int(0)) => return left,
284        // x / 1 → x
285        (BinOp::Div, _, Expr::Int(1)) => return left,
286
287        _ => {}
288    }
289
290    // Not foldable — reconstruct
291    Expr::Binary {
292        op,
293        left: Box::new(left),
294        right: Box::new(right),
295    }
296}
297
298/// Try to fold a unary operation on a constant operand.
299fn fold_unary(op: UnaryOp, inner: Expr) -> Expr {
300    match (&op, &inner) {
301        (UnaryOp::Neg, Expr::Int(a)) => Expr::Int(-a),
302        (UnaryOp::Neg, Expr::Float(a)) => Expr::Float(-a),
303        _ => Expr::Unary {
304            op,
305            expr: Box::new(inner),
306        },
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `Expr::Binary` node from an operator and its two operands.
    fn binop(op: BinOp, lhs: Expr, rhs: Expr) -> Expr {
        Expr::Binary {
            op,
            left: Box::new(lhs),
            right: Box::new(rhs),
        }
    }

    /// Build an `Expr::Unary` node from an operator and its operand.
    fn unop(op: UnaryOp, operand: Expr) -> Expr {
        Expr::Unary {
            op,
            expr: Box::new(operand),
        }
    }

    /// The identifier `x`, used as the non-constant operand in these tests.
    fn x() -> Expr {
        Expr::Ident("x".into())
    }

    /// Assert that folding `input` produces exactly `expected`.
    #[track_caller]
    fn assert_folds_to(input: Expr, expected: Expr) {
        assert_eq!(fold_expr(input), expected);
    }

    // ---- constant folding: integers ----

    #[test]
    fn fold_int_addition() {
        assert_folds_to(binop(BinOp::Add, Expr::Int(1), Expr::Int(2)), Expr::Int(3));
    }

    #[test]
    fn fold_int_subtraction() {
        assert_folds_to(binop(BinOp::Sub, Expr::Int(5), Expr::Int(3)), Expr::Int(2));
    }

    #[test]
    fn fold_int_multiplication() {
        assert_folds_to(binop(BinOp::Mul, Expr::Int(2), Expr::Int(3)), Expr::Int(6));
    }

    #[test]
    fn fold_int_division() {
        // Integer division truncates: 10 / 3 → 3
        assert_folds_to(binop(BinOp::Div, Expr::Int(10), Expr::Int(3)), Expr::Int(3));
    }

    #[test]
    fn fold_int_pow() {
        assert_folds_to(
            binop(BinOp::Pow, Expr::Int(2), Expr::Int(10)),
            Expr::Int(1024),
        );
    }

    // ---- constant folding: floats ----

    #[test]
    fn fold_float_arithmetic() {
        assert_folds_to(
            binop(BinOp::Mul, Expr::Float(2.0), Expr::Float(3.0)),
            Expr::Float(6.0),
        );
    }

    // ---- identity folding ----

    #[test]
    fn fold_identity_mul_zero() {
        // x * 0 → 0
        assert_folds_to(binop(BinOp::Mul, x(), Expr::Int(0)), Expr::Int(0));
    }

    #[test]
    fn fold_identity_mul_one() {
        // x * 1 → x
        assert_folds_to(binop(BinOp::Mul, x(), Expr::Int(1)), x());
    }

    #[test]
    fn fold_identity_add_zero() {
        // x + 0 → x
        assert_folds_to(binop(BinOp::Add, x(), Expr::Int(0)), x());
    }

    // ---- recursion ----

    #[test]
    fn fold_nested_expression() {
        // (2 + 3) * 4 → 20
        let sum = binop(BinOp::Add, Expr::Int(2), Expr::Int(3));
        assert_folds_to(binop(BinOp::Mul, sum, Expr::Int(4)), Expr::Int(20));
    }

    // ---- unary folding ----

    #[test]
    fn fold_unary_neg_int() {
        assert_folds_to(unop(UnaryOp::Neg, Expr::Int(5)), Expr::Int(-5));
    }

    #[test]
    fn fold_unary_neg_float() {
        assert_folds_to(unop(UnaryOp::Neg, Expr::Float(2.75)), Expr::Float(-2.75));
    }

    // ---- things that must not be folded away ----

    #[test]
    fn preserves_non_constant() {
        // x + 1 stays as Binary
        let original = binop(BinOp::Add, x(), Expr::Int(1));
        assert_folds_to(original.clone(), original);
    }

    #[test]
    fn fold_in_call_args() {
        // f(2 * 3) → f(6)
        let call_f = |argument: Expr| Expr::Call {
            func: Box::new(Expr::Ident("f".into())),
            args: vec![Arg::Positional(argument)],
        };
        assert_folds_to(
            call_f(binop(BinOp::Mul, Expr::Int(2), Expr::Int(3))),
            call_f(Expr::Int(6)),
        );
    }

    #[test]
    fn div_by_zero_not_folded() {
        // 1 / 0 stays as Binary
        let folded = fold_expr(binop(BinOp::Div, Expr::Int(1), Expr::Int(0)));
        assert!(matches!(folded, Expr::Binary { .. }));
    }

    #[test]
    fn fold_identity_sub_zero() {
        // x - 0 → x
        assert_folds_to(binop(BinOp::Sub, x(), Expr::Int(0)), x());
    }

    #[test]
    fn fold_identity_div_one() {
        // x / 1 → x
        assert_folds_to(binop(BinOp::Div, x(), Expr::Int(1)), x());
    }
}