use crate::ast::*;
/// Constant-folds every top-level declaration of `program` via `fold_decl`
/// and returns the rewritten program; no other field of `program` is touched.
pub fn fold_program(mut program: Program) -> Program {
    let folded_decls = program.decls.into_iter().map(fold_decl).collect();
    program.decls = folded_decls;
    program
}
/// Constant-folds the expressions contained in a single declaration.
///
/// - `Val` and `ValRec`: the right-hand side is passed through `fold_expr`.
/// - `Fun`: the body of every clause of every binding is passed through `fold_expr`.
/// - `Local`: both of its declaration lists are folded recursively.
///
/// Every other declaration kind is returned unchanged. The span is always kept.
fn fold_decl(decl: Decl) -> Decl {
    let Decl { kind, span } = decl;

    let new_kind = match kind {
        DeclKind::Val(pat, rhs) => DeclKind::Val(pat, fold_expr(rhs)),
        DeclKind::ValRec(name, rhs) => DeclKind::ValRec(name, fold_expr(rhs)),
        DeclKind::Fun(bindings) => {
            let folded_bindings = bindings
                .into_iter()
                .map(|mut binding| {
                    // Rebuild the clause list with each clause body folded.
                    let folded_clauses = binding
                        .clauses
                        .into_iter()
                        .map(|mut clause| {
                            clause.body = fold_expr(clause.body);
                            clause
                        })
                        .collect();
                    binding.clauses = folded_clauses;
                    binding
                })
                .collect();
            DeclKind::Fun(folded_bindings)
        }
        DeclKind::Local(local_decls, body_decls) => DeclKind::Local(
            local_decls.into_iter().map(fold_decl).collect(),
            body_decls.into_iter().map(fold_decl).collect(),
        ),
        untouched => untouched,
    };

    Decl { kind: new_kind, span }
}
/// Constant-folds one expression, working bottom-up.
///
/// Every child expression is folded before its parent is examined. Literal
/// operands produced by inner folds can therefore trigger folds further up:
///
/// * An `If` whose folded condition is a `BoolLit` is replaced by the selected,
///   already folded branch. The result carries that branch's span, not the `If`'s.
/// * A `BinOp` is replaced by the literal from `try_fold_binop` when that returns
///   `Some`. The literal keeps the operator's span.
/// * A `UnaryNeg` of an `IntLit` becomes the negated literal unless `checked_neg`
///   reports overflow. A `UnaryNeg` of a `FloatLit` always becomes a literal.
///
/// The other variants matched below are rebuilt with folded children. Any
/// variant not matched explicitly is returned as-is, and its children are not
/// visited.
fn fold_expr(expr: Expr) -> Expr {
    let Expr { kind, span } = expr;

    // Fold a boxed child expression and box the result again.
    let fold_boxed = |child: Box<Expr>| Box::new(fold_expr(*child));

    let new_kind = match kind {
        ExprKind::If(cond, then_br, else_br) => {
            let cond = fold_expr(*cond);
            let then_br = fold_expr(*then_br);
            let else_br = fold_expr(*else_br);
            // A statically known condition selects one branch outright.
            if let ExprKind::BoolLit(taken) = cond.kind {
                return if taken { then_br } else { else_br };
            }
            ExprKind::If(Box::new(cond), Box::new(then_br), Box::new(else_br))
        }
        ExprKind::BinOp(op, lhs, rhs) => {
            let lhs = fold_expr(*lhs);
            let rhs = fold_expr(*rhs);
            match try_fold_binop(op, &lhs, &rhs) {
                Some(literal) => literal,
                None => ExprKind::BinOp(op, Box::new(lhs), Box::new(rhs)),
            }
        }
        ExprKind::UnaryNeg(operand) => {
            let operand = fold_expr(*operand);
            let negated = match operand.kind {
                // `checked_neg` yields `None` on overflow; such a negation stays unfolded.
                ExprKind::IntLit(n) => n.checked_neg().map(ExprKind::IntLit),
                ExprKind::FloatLit(x) => Some(ExprKind::FloatLit(-x)),
                _ => None,
            };
            negated.unwrap_or_else(|| ExprKind::UnaryNeg(Box::new(operand)))
        }
        ExprKind::App(callee, arg) => ExprKind::App(fold_boxed(callee), fold_boxed(arg)),
        ExprKind::Fn(pat, body) => ExprKind::Fn(pat, fold_boxed(body)),
        ExprKind::Let(decls, body) => {
            ExprKind::Let(decls.into_iter().map(fold_decl).collect(), fold_boxed(body))
        }
        ExprKind::Case(scrutinee, arms) => ExprKind::Case(
            fold_boxed(scrutinee),
            arms.into_iter().map(|(pat, rhs)| (pat, fold_expr(rhs))).collect(),
        ),
        ExprKind::Tuple(items) => ExprKind::Tuple(items.into_iter().map(fold_expr).collect()),
        ExprKind::Cons(head, tail) => ExprKind::Cons(fold_boxed(head), fold_boxed(tail)),
        ExprKind::Ann(inner, ty) => ExprKind::Ann(fold_boxed(inner), ty),
        ExprKind::Perform(effect, arg) => ExprKind::Perform(effect, fold_boxed(arg)),
        ExprKind::Handle { body, return_var, return_body, handlers } => ExprKind::Handle {
            body: fold_boxed(body),
            return_var,
            return_body: fold_boxed(return_body),
            handlers: handlers
                .into_iter()
                .map(|mut handler| {
                    handler.body = fold_expr(handler.body);
                    handler
                })
                .collect(),
        },
        ExprKind::Resume(cont, arg) => ExprKind::Resume(fold_boxed(cont), fold_boxed(arg)),
        unchanged => unchanged,
    };

    Expr { kind: new_kind, span }
}
/// Attempts to evaluate `lhs op rhs` at compile time.
///
/// Both operands must already be literals of the kind the operator expects.
/// Otherwise `None` is returned and the caller keeps the original `BinOp`.
///
/// * Integer arithmetic uses the `checked_*` methods. Overflow, a zero divisor
///   and `MIN / -1` therefore yield `None` instead of panicking, and the
///   expression is left for the runtime to evaluate.
/// * Float arithmetic uses the host operators unguarded. Division by zero thus
///   folds to an infinity or NaN rather than returning `None`.
/// * `ConcatStr` concatenates two string literals.
/// * `Eq`/`Ne` fold for pairs of integer, boolean or string literals.
/// * The ordered comparisons fold for integer literals only.
///
/// NOTE(review): `DivInt`/`ModInt` fold with Rust's truncating `/` and `%`.
/// This must agree with the runtime's semantics for negative operands — confirm.
fn try_fold_binop(op: BinOp, lhs: &Expr, rhs: &Expr) -> Option<ExprKind> {
    match (op, &lhs.kind, &rhs.kind) {
        // Integer arithmetic: the checked ops already return `None` on overflow
        // and on a zero divisor, so no separate `b != 0` guard is needed.
        (BinOp::AddInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
            a.checked_add(*b).map(ExprKind::IntLit)
        }
        (BinOp::SubInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
            a.checked_sub(*b).map(ExprKind::IntLit)
        }
        (BinOp::MulInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
            a.checked_mul(*b).map(ExprKind::IntLit)
        }
        (BinOp::DivInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
            a.checked_div(*b).map(ExprKind::IntLit)
        }
        (BinOp::ModInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => {
            a.checked_rem(*b).map(ExprKind::IntLit)
        }
        // Float arithmetic.
        (BinOp::AddFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
            Some(ExprKind::FloatLit(a + b))
        }
        (BinOp::SubFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
            Some(ExprKind::FloatLit(a - b))
        }
        (BinOp::MulFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
            Some(ExprKind::FloatLit(a * b))
        }
        (BinOp::DivFloat, ExprKind::FloatLit(a), ExprKind::FloatLit(b)) => {
            Some(ExprKind::FloatLit(a / b))
        }
        // String concatenation, allocating the result buffer once.
        (BinOp::ConcatStr, ExprKind::StringLit(a), ExprKind::StringLit(b)) => {
            let mut s = String::with_capacity(a.len() + b.len());
            s.push_str(a);
            s.push_str(b);
            Some(ExprKind::StringLit(s))
        }
        // Equality: fold when both operands are literals of the same kind.
        (BinOp::Eq | BinOp::Ne, l, r) => {
            let equal = match (l, r) {
                (ExprKind::IntLit(a), ExprKind::IntLit(b)) => a == b,
                (ExprKind::BoolLit(a), ExprKind::BoolLit(b)) => a == b,
                (ExprKind::StringLit(a), ExprKind::StringLit(b)) => a == b,
                _ => return None,
            };
            let result = if matches!(op, BinOp::Eq) { equal } else { !equal };
            Some(ExprKind::BoolLit(result))
        }
        // Ordered integer comparisons.
        (BinOp::LtInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a < b)),
        (BinOp::GtInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a > b)),
        (BinOp::LeInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a <= b)),
        (BinOp::GeInt, ExprKind::IntLit(a), ExprKind::IntLit(b)) => Some(ExprKind::BoolLit(a >= b)),
        _ => None,
    }
}