use crate::ast::Ast;
/// Optimizer entry point: repeatedly constant-folds `ast` until a pass makes
/// no change, running at most a small fixed number of passes.
pub(crate) fn optimize(ast: Ast) -> Ast {
    // Upper bound on folding passes; `fixpoint` stops earlier as soon as a
    // pass leaves the tree unchanged.
    const MAX_PASSES: usize = 2;
    fixpoint(ast, MAX_PASSES)
}
/// Applies `fold` to `ast` repeatedly, at most `max_rounds` times.
///
/// Returns as soon as one pass produces a tree equal to its input (a
/// fixpoint). Otherwise it returns the result of the final pass. With
/// `max_rounds == 0` the input is returned untouched.
fn fixpoint(mut ast: Ast, max_rounds: usize) -> Ast {
    let mut remaining = max_rounds;
    while remaining > 0 {
        remaining -= 1;
        // `fold` consumes its argument, so keep a copy to detect "no change".
        let before = ast.clone();
        ast = fold(ast);
        if ast == before {
            break;
        }
    }
    ast
}
/// Performs one bottom-up constant-folding and simplification pass over `ast`.
///
/// Every arm folds its children before inspecting them, so one call collapses
/// an entire constant subtree. Simplifications applied in this pass:
/// - arithmetic, comparison, logical and `Max`/`Min` nodes whose operands are
///   both `Num` are evaluated with `f64` arithmetic (e.g. division by zero
///   yields ±inf/NaN, with no diagnostic);
/// - `x / 1`, `x ^ 1`, `1 ^ y` and double negation are removed. The
///   `x + 0` and `x * 1` identities are handled by `fold_add` / `fold_mul`;
/// - `a - b` is rewritten as `a + (-b)` and simplified through `fold_add`;
/// - `If` / `Ifs` nodes with constant conditions are reduced to the selected
///   branch.
///
/// Truthiness: `Not`, `And` and `Or` treat any non-zero number as true.
/// `If` / `Ifs` only fold a condition that is `>= 1.0` (taken) or `== 0.0`
/// (not taken). Any other constant condition, including NaN, is left in place.
fn fold(ast: Ast) -> Ast {
    match ast {
        Ast::Num(_) | Ast::Var(_) => ast,
        Ast::Neg(x) => match fold(*x) {
            Ast::Num(v) => Ast::Num(-v),
            // Double negation cancels: -(-e) => e.
            Ast::Neg(inner) => *inner,
            other => Ast::Neg(Box::new(other)),
        },
        Ast::Not(x) => match fold(*x) {
            Ast::Num(v) => Ast::Num((v == 0.0) as i32 as f64),
            other => Ast::Not(Box::new(other)),
        },
        Ast::Add(a, b) => fold_add(*a, *b),
        // `fold_add` folds both of its operands itself, and folding `Neg(b)`
        // folds `b`. Pass the raw subtrees so each one is traversed once rather
        // than being folded here and then again inside `fold_add`.
        Ast::Sub(a, b) => fold_add(*a, Ast::Neg(b)),
        Ast::Mul(a, b) => fold_mul(*a, *b),
        Ast::Div(a, b) => match (fold(*a), fold(*b)) {
            (Ast::Num(x), Ast::Num(y)) => Ast::Num(x / y),
            (x, Ast::Num(1.0)) => x,
            (x, y) => Ast::Div(Box::new(x), Box::new(y)),
        },
        Ast::Pow(a, b) => match (fold(*a), fold(*b)) {
            (Ast::Num(x), Ast::Num(y)) => Ast::Num(x.powf(y)),
            (x, Ast::Num(1.0)) => x,
            // NOTE(review): this drops `y` entirely; only safe if `y` cannot
            // have side effects (e.g. an impure `Call`) — confirm.
            (Ast::Num(1.0), _) => Ast::Num(1.0),
            (x, y) => Ast::Pow(Box::new(x), Box::new(y)),
        },
        Ast::Eq(a, b) => cmp_fold(
            *a,
            *b,
            |x, y| (x == y) as i32 as f64,
            |a, b| Ast::Eq(Box::new(a), Box::new(b)),
        ),
        Ast::Ne(a, b) => cmp_fold(
            *a,
            *b,
            |x, y| (x != y) as i32 as f64,
            |a, b| Ast::Ne(Box::new(a), Box::new(b)),
        ),
        Ast::Lt(a, b) => cmp_fold(
            *a,
            *b,
            |x, y| (x < y) as i32 as f64,
            |a, b| Ast::Lt(Box::new(a), Box::new(b)),
        ),
        Ast::Le(a, b) => cmp_fold(
            *a,
            *b,
            |x, y| (x <= y) as i32 as f64,
            |a, b| Ast::Le(Box::new(a), Box::new(b)),
        ),
        Ast::Gt(a, b) => cmp_fold(
            *a,
            *b,
            |x, y| (x > y) as i32 as f64,
            |a, b| Ast::Gt(Box::new(a), Box::new(b)),
        ),
        Ast::Ge(a, b) => cmp_fold(
            *a,
            *b,
            |x, y| (x >= y) as i32 as f64,
            |a, b| Ast::Ge(Box::new(a), Box::new(b)),
        ),
        Ast::And(a, b) => match (fold(*a), fold(*b)) {
            (Ast::Num(x), Ast::Num(y)) => Ast::Num(((x != 0.0) && (y != 0.0)) as i32 as f64),
            // A constant-false operand makes the conjunction false.
            // NOTE(review): the right-hand arm also discards a non-constant
            // left operand; only safe if it has no side effects — confirm.
            (Ast::Num(0.0), _) => Ast::Num(0.0),
            (_, Ast::Num(0.0)) => Ast::Num(0.0),
            (x, y) => Ast::And(Box::new(x), Box::new(y)),
        },
        Ast::Or(a, b) => match (fold(*a), fold(*b)) {
            (Ast::Num(x), Ast::Num(y)) => Ast::Num(((x != 0.0) || (y != 0.0)) as i32 as f64),
            // A constant-true (non-zero) operand makes the disjunction true.
            // NOTE(review): the right-hand arm also discards a non-constant
            // left operand; only safe if it has no side effects — confirm.
            (Ast::Num(x), _) if x != 0.0 => Ast::Num(1.0),
            (_, Ast::Num(y)) if y != 0.0 => Ast::Num(1.0),
            (x, y) => Ast::Or(Box::new(x), Box::new(y)),
        },
        Ast::If(c, t, e) => {
            let c = fold(*c);
            let t = fold(*t);
            let e = fold(*e);
            match c {
                Ast::Num(x) if x >= 1.0 => t,
                Ast::Num(0.0) => e,
                c => Ast::If(Box::new(c), Box::new(t), Box::new(e)),
            }
        }
        Ast::Ifs(args) => {
            // Arguments are laid out as (condition, value) pairs, optionally
            // followed by a trailing default.
            let mut args: Vec<Ast> = args.into_iter().map(|a| fold(*a)).collect();
            // Cursor over leading constant-false pairs. Advancing an index
            // avoids shifting the vector once per skipped pair.
            let mut start = 0;
            // Only inspect a pair when at least one more argument follows it,
            // so skipping a false pair can never leave the list empty.
            while args.len() - start >= 3 {
                let cond = match &args[start] {
                    Ast::Num(x) => *x,
                    _ => break,
                };
                if cond >= 1.0 {
                    // First live branch is statically taken; the rest is dead.
                    return args.swap_remove(start + 1);
                }
                if cond != 0.0 {
                    // Constant, but neither clearly true nor false (e.g. 0.5, NaN).
                    break;
                }
                start += 2;
            }
            // Drop every skipped pair in a single pass.
            args.drain(..start);
            if args.len() == 1 {
                // Only the default remains.
                args.remove(0)
            } else {
                Ast::Ifs(args.into_iter().map(Box::new).collect())
            }
        }
        Ast::Max(a, b) => match (fold(*a), fold(*b)) {
            (Ast::Num(x), Ast::Num(y)) => Ast::Num(x.max(y)),
            (x, y) => Ast::Max(Box::new(x), Box::new(y)),
        },
        Ast::Min(a, b) => match (fold(*a), fold(*b)) {
            (Ast::Num(x), Ast::Num(y)) => Ast::Num(x.min(y)),
            (x, y) => Ast::Min(Box::new(x), Box::new(y)),
        },
        // Calls are never evaluated here; only their arguments are folded.
        Ast::Call { name, args } => {
            let args = args.into_iter().map(fold).collect();
            Ast::Call { name, args }
        }
    }
}
/// Folds both operands of a comparison node.
///
/// If both operands become constants, the result is `Num(eval(lhs, rhs))`.
/// Otherwise the node is reconstructed from the folded operands via `rebuild`.
fn cmp_fold<F, R>(a: Ast, b: Ast, eval: F, rebuild: R) -> Ast
where
    F: Fn(f64, f64) -> f64,
    R: Fn(Ast, Ast) -> Ast,
{
    match (fold(a), fold(b)) {
        (Ast::Num(lhs), Ast::Num(rhs)) => Ast::Num(eval(lhs, rhs)),
        (lhs, rhs) => rebuild(lhs, rhs),
    }
}
/// Folds both operands and simplifies their sum.
///
/// Two constants are added. A constant zero on either side (the `0.0`
/// pattern also matches `-0.0`) is dropped, leaving the other operand.
/// Otherwise an `Add` node is rebuilt from the folded operands.
fn fold_add(a: Ast, b: Ast) -> Ast {
    let lhs = fold(a);
    let rhs = fold(b);
    match (lhs, rhs) {
        (Ast::Num(l), Ast::Num(r)) => Ast::Num(l + r),
        (kept, Ast::Num(0.0)) | (Ast::Num(0.0), kept) => kept,
        (l, r) => Ast::Add(Box::new(l), Box::new(r)),
    }
}
/// Folds both operands and simplifies their product.
///
/// Two constants are multiplied. A constant one on either side is dropped,
/// leaving the other operand. Otherwise a `Mul` node is rebuilt from the
/// folded operands.
fn fold_mul(a: Ast, b: Ast) -> Ast {
    let lhs = fold(a);
    let rhs = fold(b);
    match (lhs, rhs) {
        (Ast::Num(l), Ast::Num(r)) => Ast::Num(l * r),
        (kept, Ast::Num(1.0)) | (Ast::Num(1.0), kept) => kept,
        (l, r) => Ast::Mul(Box::new(l), Box::new(r)),
    }
}