1use crate::ast::*;
2
3pub fn fold_program(mut program: Program) -> Program {
4 program.decls = program.decls.into_iter().map(fold_decl).collect();
5 program
6}
7
8fn fold_decl(decl: Decl) -> Decl {
9 let span = decl.span;
10 let kind = match decl.kind {
11 DeclKind::Val(pat, expr) => DeclKind::Val(pat, fold_expr(expr)),
12 DeclKind::ValRec(name, expr) => DeclKind::ValRec(name, fold_expr(expr)),
13 DeclKind::Fun(bindings) => DeclKind::Fun(
14 bindings
15 .into_iter()
16 .map(|mut b| {
17 b.clauses = b
18 .clauses
19 .into_iter()
20 .map(|mut c| {
21 c.body = fold_expr(c.body);
22 c
23 })
24 .collect();
25 b
26 })
27 .collect(),
28 ),
29 DeclKind::Local(locals, body) => DeclKind::Local(
30 locals.into_iter().map(fold_decl).collect(),
31 body.into_iter().map(fold_decl).collect(),
32 ),
33 other => other,
34 };
35 Decl { kind, span }
36}
37
38fn fold_expr(expr: Expr) -> Expr {
39 let span = expr.span;
40 let kind = match expr.kind {
41 ExprKind::If(cond, then_br, else_br) => {
43 let cond = fold_expr(*cond);
44 let then_br = fold_expr(*then_br);
45 let else_br = fold_expr(*else_br);
46 match &cond.kind {
47 ExprKind::BoolLit(true) => return then_br,
48 ExprKind::BoolLit(false) => return else_br,
49 _ => ExprKind::If(Box::new(cond), Box::new(then_br), Box::new(else_br)),
50 }
51 }
52
53 ExprKind::BinOp(op, lhs, rhs) => {
55 let lhs = fold_expr(*lhs);
56 let rhs = fold_expr(*rhs);
57 if let Some(result) = try_fold_binop(op, &lhs, &rhs) {
58 return Expr { kind: result, span };
59 }
60 ExprKind::BinOp(op, Box::new(lhs), Box::new(rhs))
61 }
62
63 ExprKind::UnaryNeg(e) => {
65 let e = fold_expr(*e);
66 match &e.kind {
67 ExprKind::IntLit(n) => match n.checked_neg() {
68 Some(neg) => ExprKind::IntLit(neg),
69 None => ExprKind::UnaryNeg(Box::new(e)),
70 },
71 ExprKind::FloatLit(f) => ExprKind::FloatLit(-f),
72 _ => ExprKind::UnaryNeg(Box::new(e)),
73 }
74 }
75
76 ExprKind::App(f, arg) => ExprKind::App(Box::new(fold_expr(*f)), Box::new(fold_expr(*arg))),
78 ExprKind::Fn(pat, body) => ExprKind::Fn(pat, Box::new(fold_expr(*body))),
79 ExprKind::Let(decls, body) => ExprKind::Let(
80 decls.into_iter().map(fold_decl).collect(),
81 Box::new(fold_expr(*body)),
82 ),
83 ExprKind::Case(scrutinee, arms) => ExprKind::Case(
84 Box::new(fold_expr(*scrutinee)),
85 arms.into_iter()
86 .map(|(pat, body)| (pat, fold_expr(body)))
87 .collect(),
88 ),
89 ExprKind::Tuple(elems) => ExprKind::Tuple(elems.into_iter().map(fold_expr).collect()),
90 ExprKind::Cons(hd, tl) => {
91 ExprKind::Cons(Box::new(fold_expr(*hd)), Box::new(fold_expr(*tl)))
92 }
93 ExprKind::Ann(e, ty) => ExprKind::Ann(Box::new(fold_expr(*e)), ty),
94 ExprKind::Perform(name, arg) => ExprKind::Perform(name, Box::new(fold_expr(*arg))),
95 ExprKind::Handle {
96 body,
97 return_var,
98 return_body,
99 handlers,
100 } => ExprKind::Handle {
101 body: Box::new(fold_expr(*body)),
102 return_var,
103 return_body: Box::new(fold_expr(*return_body)),
104 handlers: handlers
105 .into_iter()
106 .map(|h| EffectHandler {
107 body: fold_expr(h.body),
108 ..h
109 })
110 .collect(),
111 },
112 ExprKind::Resume(cont, arg) => {
113 ExprKind::Resume(Box::new(fold_expr(*cont)), Box::new(fold_expr(*arg)))
114 }
115
116 other => other,
118 };
119 Expr { kind, span }
120}
121
122fn try_fold_binop(op: BinOp, lhs: &Expr, rhs: &Expr) -> Option<ExprKind> {
123 match (op, &lhs.kind, &rhs.kind) {
124 (BinOp::AddInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
126 a.checked_add(*b).map(ExprKind::IntLit)
127 }
128 (BinOp::SubInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
129 a.checked_sub(*b).map(ExprKind::IntLit)
130 }
131 (BinOp::MulInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
132 a.checked_mul(*b).map(ExprKind::IntLit)
133 }
134 (BinOp::DivInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) if *b != 0 => {
135 a.checked_div(*b).map(ExprKind::IntLit)
136 }
137 (BinOp::ModInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) if *b != 0 => {
138 a.checked_rem(*b).map(ExprKind::IntLit)
139 }
140
141 (BinOp::AddFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
143 Some(ExprKind::FloatLit(a + b))
144 }
145 (BinOp::SubFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
146 Some(ExprKind::FloatLit(a - b))
147 }
148 (BinOp::MulFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
149 Some(ExprKind::FloatLit(a * b))
150 }
151 (BinOp::DivFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
152 Some(ExprKind::FloatLit(a / b))
153 }
154
155 (BinOp::ConcatStr, ExprKind::StringLit(a), ExprKind::StringLit(b)) => {
157 let mut s = a.clone();
158 s.push_str(b);
159 Some(ExprKind::StringLit(s))
160 }
161
162 (BinOp::Eq, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a == b)),
164 (BinOp::Ne, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a != b)),
165 (BinOp::LtInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a < b)),
166 (BinOp::GtInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a > b)),
167 (BinOp::LeInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a <= b)),
168 (BinOp::GeInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a >= b)),
169
170 _ => None,
171 }
172}