Skip to main content

hiko_syntax/
constfold.rs

1use crate::ast::*;
2
3pub fn fold_program(mut program: Program) -> Program {
4    program.decls = program.decls.into_iter().map(fold_decl).collect();
5    program
6}
7
8fn fold_decl(decl: Decl) -> Decl {
9    let span = decl.span;
10    let kind = match decl.kind {
11        DeclKind::Val(pat, expr) => DeclKind::Val(pat, fold_expr(expr)),
12        DeclKind::ValRec(name, expr) => DeclKind::ValRec(name, fold_expr(expr)),
13        DeclKind::Fun(bindings) => DeclKind::Fun(
14            bindings
15                .into_iter()
16                .map(|mut b| {
17                    b.clauses = b
18                        .clauses
19                        .into_iter()
20                        .map(|mut c| {
21                            c.body = fold_expr(c.body);
22                            c
23                        })
24                        .collect();
25                    b
26                })
27                .collect(),
28        ),
29        DeclKind::Local(locals, body) => DeclKind::Local(
30            locals.into_iter().map(fold_decl).collect(),
31            body.into_iter().map(fold_decl).collect(),
32        ),
33        other => other,
34    };
35    Decl { kind, span }
36}
37
38fn fold_expr(expr: Expr) -> Expr {
39    let span = expr.span;
40    let kind = match expr.kind {
41        // Constant fold: if true/false
42        ExprKind::If(cond, then_br, else_br) => {
43            let cond = fold_expr(*cond);
44            let then_br = fold_expr(*then_br);
45            let else_br = fold_expr(*else_br);
46            match &cond.kind {
47                ExprKind::BoolLit(true) => return then_br,
48                ExprKind::BoolLit(false) => return else_br,
49                _ => ExprKind::If(Box::new(cond), Box::new(then_br), Box::new(else_br)),
50            }
51        }
52
53        // Constant fold: integer arithmetic
54        ExprKind::BinOp(op, lhs, rhs) => {
55            let lhs = fold_expr(*lhs);
56            let rhs = fold_expr(*rhs);
57            if let Some(result) = try_fold_binop(op, &lhs, &rhs) {
58                return Expr { kind: result, span };
59            }
60            ExprKind::BinOp(op, Box::new(lhs), Box::new(rhs))
61        }
62
63        // Constant fold: negation
64        ExprKind::UnaryNeg(e) => {
65            let e = fold_expr(*e);
66            match &e.kind {
67                ExprKind::IntLit(n) => match n.checked_neg() {
68                    Some(neg) => ExprKind::IntLit(neg),
69                    None => ExprKind::UnaryNeg(Box::new(e)),
70                },
71                ExprKind::FloatLit(f) => ExprKind::FloatLit(-f),
72                _ => ExprKind::UnaryNeg(Box::new(e)),
73            }
74        }
75
76        // Recurse into sub-expressions
77        ExprKind::App(f, arg) => ExprKind::App(Box::new(fold_expr(*f)), Box::new(fold_expr(*arg))),
78        ExprKind::Fn(pat, body) => ExprKind::Fn(pat, Box::new(fold_expr(*body))),
79        ExprKind::Let(decls, body) => ExprKind::Let(
80            decls.into_iter().map(fold_decl).collect(),
81            Box::new(fold_expr(*body)),
82        ),
83        ExprKind::Case(scrutinee, arms) => ExprKind::Case(
84            Box::new(fold_expr(*scrutinee)),
85            arms.into_iter()
86                .map(|(pat, body)| (pat, fold_expr(body)))
87                .collect(),
88        ),
89        ExprKind::Tuple(elems) => ExprKind::Tuple(elems.into_iter().map(fold_expr).collect()),
90        ExprKind::Cons(hd, tl) => {
91            ExprKind::Cons(Box::new(fold_expr(*hd)), Box::new(fold_expr(*tl)))
92        }
93        ExprKind::Ann(e, ty) => ExprKind::Ann(Box::new(fold_expr(*e)), ty),
94        ExprKind::Perform(name, arg) => ExprKind::Perform(name, Box::new(fold_expr(*arg))),
95        ExprKind::Handle {
96            body,
97            return_var,
98            return_body,
99            handlers,
100        } => ExprKind::Handle {
101            body: Box::new(fold_expr(*body)),
102            return_var,
103            return_body: Box::new(fold_expr(*return_body)),
104            handlers: handlers
105                .into_iter()
106                .map(|h| EffectHandler {
107                    body: fold_expr(h.body),
108                    ..h
109                })
110                .collect(),
111        },
112        ExprKind::Resume(cont, arg) => {
113            ExprKind::Resume(Box::new(fold_expr(*cont)), Box::new(fold_expr(*arg)))
114        }
115
116        // Leaves
117        other => other,
118    };
119    Expr { kind, span }
120}
121
122fn try_fold_binop(op: BinOp, lhs: &Expr, rhs: &Expr) -> Option<ExprKind> {
123    match (op, &lhs.kind, &rhs.kind) {
124        // Int arithmetic
125        (BinOp::AddInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
126            a.checked_add(*b).map(ExprKind::IntLit)
127        }
128        (BinOp::SubInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
129            a.checked_sub(*b).map(ExprKind::IntLit)
130        }
131        (BinOp::MulInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
132            a.checked_mul(*b).map(ExprKind::IntLit)
133        }
134        (BinOp::DivInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) if *b != 0 => {
135            a.checked_div(*b).map(ExprKind::IntLit)
136        }
137        (BinOp::ModInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) if *b != 0 => {
138            a.checked_rem(*b).map(ExprKind::IntLit)
139        }
140
141        // Float arithmetic
142        (BinOp::AddFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
143            Some(ExprKind::FloatLit(a + b))
144        }
145        (BinOp::SubFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
146            Some(ExprKind::FloatLit(a - b))
147        }
148        (BinOp::MulFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
149            Some(ExprKind::FloatLit(a * b))
150        }
151        (BinOp::DivFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
152            Some(ExprKind::FloatLit(a / b))
153        }
154
155        // String concat
156        (BinOp::ConcatStr, ExprKind::StringLit(a), ExprKind::StringLit(b)) => {
157            let mut s = a.clone();
158            s.push_str(b);
159            Some(ExprKind::StringLit(s))
160        }
161
162        // Int comparison
163        (BinOp::Eq, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a == b)),
164        (BinOp::Ne, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a != b)),
165        (BinOp::LtInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a < b)),
166        (BinOp::GtInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a > b)),
167        (BinOp::LeInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a <= b)),
168        (BinOp::GeInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a >= b)),
169
170        _ => None,
171    }
172}