use super::*;
use crate::array::ArrayImpl;
use crate::types::DataValue;
/// Returns the rewrite rules used to simplify scalar expressions.
///
/// The rules cover arithmetic identities, comparison normalization, moving
/// addends across comparisons, boolean logic, `if` simplification and the
/// expansion of `avg` into `sum / count`.
///
/// NOTE(review): several identities (e.g. `mul-zero`, `sub-cancel`, `eq-eq`)
/// only hold for non-NULL operands — confirm that this is acceptable for callers.
#[rustfmt::skip]
pub fn rules() -> Vec<Rewrite> { vec![
    // addition
    rw!("add-zero"; "(+ ?a 0)" => "?a"),
    rw!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"),
    rw!("add-assoc"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
    rw!("add-same"; "(+ ?a ?a)" => "(* ?a 2)"),
    rw!("add-neg"; "(+ ?a (- ?b))" => "(- ?a ?b)"),

    // multiplication
    rw!("mul-zero"; "(* ?a 0)" => "0"),
    rw!("mul-one"; "(* ?a 1)" => "?a"),
    rw!("mul-minus"; "(* ?a -1)" => "(- ?a)"),
    rw!("mul-comm"; "(* ?a ?b)" => "(* ?b ?a)"),
    rw!("mul-assoc"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),

    // negation and subtraction
    rw!("neg-neg"; "(- (- ?a))" => "?a"),
    rw!("neg-sub"; "(- (- ?a ?b))" => "(- ?b ?a)"),
    rw!("sub-zero"; "(- ?a 0)" => "?a"),
    rw!("zero-sub"; "(- 0 ?a)" => "(- ?a)"),
    rw!("sub-cancel"; "(- ?a ?a)" => "0"),

    // division; guarded so that a zero divisor is never cancelled away
    rw!("div-cancel"; "(/ ?a ?a)" => "1" if is_not_zero("?a")),

    // distributivity
    rw!("mul-add-distri"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
    rw!("mul-add-factor"; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),

    rw!("recip-mul-div"; "(* ?x (/ 1 ?x))" => "1" if is_not_zero("?x")),

    // comparing a value with itself
    rw!("eq-eq"; "(= ?a ?a)" => "true"),
    rw!("ne-eq"; "(<> ?a ?a)" => "false"),
    rw!("gt-eq"; "(> ?a ?a)" => "false"),
    rw!("lt-eq"; "(< ?a ?a)" => "false"),
    rw!("ge-eq"; "(>= ?a ?a)" => "true"),
    rw!("le-eq"; "(<= ?a ?a)" => "true"),

    // swapping the operands of a comparison
    rw!("eq-comm"; "(= ?a ?b)" => "(= ?b ?a)"),
    rw!("ne-comm"; "(<> ?a ?b)" => "(<> ?b ?a)"),
    rw!("gt-comm"; "(> ?a ?b)" => "(< ?b ?a)"),
    rw!("lt-comm"; "(< ?a ?b)" => "(> ?b ?a)"),
    rw!("ge-comm"; "(>= ?a ?b)" => "(<= ?b ?a)"),
    rw!("le-comm"; "(<= ?a ?b)" => "(>= ?b ?a)"),

    // moving an addend to the other side of a comparison
    rw!("eq-add"; "(= (+ ?a ?b) ?c)" => "(= ?a (- ?c ?b))"),
    rw!("ne-add"; "(<> (+ ?a ?b) ?c)" => "(<> ?a (- ?c ?b))"),
    rw!("gt-add"; "(> (+ ?a ?b) ?c)" => "(> ?a (- ?c ?b))"),
    rw!("lt-add"; "(< (+ ?a ?b) ?c)" => "(< ?a (- ?c ?b))"),
    rw!("ge-add"; "(>= (+ ?a ?b) ?c)" => "(>= ?a (- ?c ?b))"),
    rw!("le-add"; "(<= (+ ?a ?b) ?c)" => "(<= ?a (- ?c ?b))"),

    // transitivity of equality inside a conjunction
    rw!("eq-trans"; "(and (= ?a ?b) (= ?b ?c))" => "(and (= ?a ?b) (= ?a ?c))"),

    // pushing `not` inwards
    rw!("not-eq"; "(not (= ?a ?b))" => "(<> ?a ?b)"),
    rw!("not-ne"; "(not (<> ?a ?b))" => "(= ?a ?b)"),
    rw!("not-gt"; "(not (> ?a ?b))" => "(<= ?a ?b)"),
    rw!("not-ge"; "(not (>= ?a ?b))" => "(< ?a ?b)"),
    rw!("not-lt"; "(not (< ?a ?b))" => "(>= ?a ?b)"),
    rw!("not-le"; "(not (<= ?a ?b))" => "(> ?a ?b)"),
    rw!("not-and"; "(not (and ?a ?b))" => "(or (not ?a) (not ?b))"),
    rw!("not-or"; "(not (or ?a ?b))" => "(and (not ?a) (not ?b))"),
    rw!("not-not"; "(not (not ?a))" => "?a"),

    // conjunction
    //
    // There is deliberately no `(and null ?a) => null` rule: under SQL
    // three-valued logic `NULL AND FALSE` is `FALSE`, and together with
    // `and-comm` and `and-false` such a rule would equate `null` with `false`.
    rw!("and-false"; "(and false ?a)" => "false"),
    rw!("and-true"; "(and true ?a)" => "?a"),
    rw!("and-comm"; "(and ?a ?b)" => "(and ?b ?a)"),
    rw!("and-assoc"; "(and ?a (and ?b ?c))" => "(and (and ?a ?b) ?c)"),

    // disjunction
    //
    // Likewise there is no `(or null ?a) => null` rule: `NULL OR TRUE` is
    // `TRUE`, so it would equate `null` with `true` via `or-comm` and `or-true`.
    rw!("or-false"; "(or false ?a)" => "?a"),
    rw!("or-true"; "(or true ?a)" => "true"),
    rw!("or-comm"; "(or ?a ?b)" => "(or ?b ?a)"),
    rw!("or-assoc"; "(or ?a (or ?b ?c))" => "(or (or ?a ?b) ?c)"),

    // conditionals
    rw!("if-false"; "(if false ?then ?else)" => "?else"),
    rw!("if-true"; "(if true ?then ?else)" => "?then"),
    rw!("if-not"; "(if (not ?cond) ?then ?else)" => "(if ?cond ?else ?then)"),

    // aggregates
    rw!("avg"; "(avg ?a)" => "(/ (sum ?a) (count ?a))"),
]}
/// Result of constant evaluation: `Some(value)` when an expression is known to
/// evaluate to `value`, `None` when no constant could be determined.
pub type ConstValue = Option<DataValue>;
/// Tries to fold `enode` into a constant, using the constants already recorded
/// on its child e-classes in `egraph`.
///
/// Returns `None` when a required child has no known constant, when the
/// underlying array operation or cast fails, or when the node kind is not
/// handled here.
///
/// NOTE(review): the binary-operator branch yields `NULL` as soon as either
/// operand is `NULL`. If `binary_op()` also reports logical `and`/`or`, that
/// would fold `false AND NULL` to `NULL` — confirm against the operator set.
pub fn eval_constant(egraph: &EGraph, enode: &Expr) -> ConstValue {
    use Expr::*;
    // Constant recorded on the e-class `id`, if there is one.
    let constant_of = |id: Id| egraph[id].data.constant.as_ref();

    // A literal is its own value.
    if let Constant(value) = enode {
        return Some(value.clone());
    }

    // A nested expression has the value of its inner expression.
    if let Nested(inner) = enode {
        return constant_of(*inner).cloned();
    }

    // Binary operators: both operands must be known constants.
    if let Some((op, lhs, rhs)) = enode.binary_op() {
        let lhs = constant_of(lhs)?;
        let rhs = constant_of(rhs)?;
        if lhs.is_null() || rhs.is_null() {
            return Some(DataValue::Null);
        }
        // Evaluate through the array kernels and read back element 0.
        let lhs_array = ArrayImpl::from(lhs);
        let rhs_array = ArrayImpl::from(rhs);
        let folded = lhs_array.binary_op(&op, &rhs_array).ok()?;
        return Some(folded.get(0));
    }

    // Unary operators: same approach with a single operand.
    if let Some((op, operand)) = enode.unary_op() {
        let operand = constant_of(operand)?;
        if operand.is_null() {
            return Some(DataValue::Null);
        }
        let operand_array = ArrayImpl::from(operand);
        let folded = operand_array.unary_op(&op).ok()?;
        return Some(folded.get(0));
    }

    // `IS NULL` over a known constant is a known boolean.
    if let &IsNull(child) = enode {
        return Some(DataValue::Bool(constant_of(child)?.is_null()));
    }

    // Casting NULL stays NULL; otherwise cast to the type named by the first
    // node of the type e-class. A failed cast means "not a constant".
    if let &Cast([ty, child]) = enode {
        let value = constant_of(child)?;
        if value.is_null() {
            return Some(DataValue::Null);
        }
        let target = egraph[ty].nodes[0].as_type();
        return value.cast(target).ok();
    }

    match enode {
        // These aggregates over a constant argument fold to that constant.
        &Max(child) | &Min(child) | &Avg(child) | &First(child) | &Last(child) => {
            constant_of(child).cloned()
        }
        _ => None,
    }
}
/// If the e-class `id` has a known constant, adds that constant as an e-node,
/// unions it with the class, and then removes every non-leaf node from the
/// class. Does nothing when the class has no known constant.
pub fn union_constant(egraph: &mut EGraph, id: Id) {
    let constant_node = match &egraph[id].data.constant {
        Some(value) => Expr::Constant(value.clone()),
        None => return,
    };
    let constant_id = egraph.add(constant_node);
    egraph.union(id, constant_id);
    // Once the class is known to be a constant, only leaf nodes are kept.
    egraph[id].nodes.retain(|node| node.is_leaf());
}
fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
value_is(var, |v| !v.is_zero())
}
/// Builds a rewrite condition that applies the predicate `f` to the known
/// constant of the e-class bound to the pattern variable `v`.
///
/// The condition is `false` when that e-class has no known constant.
fn value_is(v: &str, f: impl Fn(&DataValue) -> bool) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
    let target = var(v);
    move |egraph, _, subst| {
        egraph[subst[target]]
            .data
            .constant
            .as_ref()
            .map_or(false, |constant| f(constant))
    }
}
// Rewrite tests: each case checks that the start expression can be rewritten
// into the goal expression using `rules()`.
#[cfg(test)]
mod tests {
use super::rules;
// An equality with a constant is propagated to the other equality of the
// conjunction.
egg::test_fn! {
and_eq_const,
rules(),
"(and (= a 1) (= a b))" => "(and (= a 1) (= b 1))",
}
// Arithmetic over constants collapses to a single constant.
// Ignored when the `simd` feature is enabled.
egg::test_fn! {
#[cfg_attr(feature = "simd", ignore)] constant_folding,
rules(),
"(* (- (+ 1 2) 4) (/ 6 2))" => "-3",
}
// A constant addend is moved to the other side of the comparison and folded.
egg::test_fn! {
constant_moving,
rules(),
"(> (+ 100 a) 300)" => "(> a 200)",
}
// A cast of a constant is evaluated to the cast value.
egg::test_fn! {
constant_type_cast,
rules(),
"(cast BOOLEAN 1)" => "true",
}
}