use super::CoefficientField;
use super::coeff_expr::{CoeffExpr, CoeffExprKind};
use abstalg::{
AbelianGroup, CommuntativeMonoid, Domain, Field, Monoid, SemiRing, Semigroup, UnitaryRing,
};
use egg::{
AstSize, EGraph, Extractor, Id, PatternAst, RecExpr, Rewrite, Runner, Subst, Symbol, Var,
define_language, rewrite,
};
use num_bigint::BigInt;
use num_rational::BigRational;
use num_traits::{One, ToPrimitive, Zero};
use quote::ToTokens;
use std::{fmt, str::FromStr};
use syn::{Expr as SynExpr, Type, parse_quote};
// The term language egg uses to simplify scalar coefficients.
//
// Leaves:
// - `Rational`: an exact rational constant (pushed by `push_rational`).
//   It is declared before `Symbol`/`Opaque`. Numeric leaves in rewrite patterns
//   (`0`, `1`, `-1`, ...) are presumably parsed as rationals because egg tries leaf
//   variants in declaration order. NOTE(review): confirm against `define_language!`.
// - `Symbol`: a named scalar variable (from `CoeffExprKind::Symbol`).
// - `Opaque`: the token-string rendering of a sub-expression egg cannot look
//   inside (see `fallback_opaque`). `rec_to_coeff` re-parses it when decoding.
//
// Operators: `neg`, binary `+`, `*`, `/`, and `pow` (base, exponent).
// `mul_add` (a, b, c) stands for `a * b + c` (see `mul_add_pass`).
// The remaining operators are unary elementary functions.
// `CoeffExpr` quotients and square roots are lowered by `coeff_to_rec_inner`
// to `pow` with exponents `-1` and `1/2`, not to `/`.
define_language! {
pub enum EggScalarLang {
Rational(BigRational),
Symbol(Symbol),
Opaque(Symbol),
"neg" = Neg(Id),
"+" = Add([Id; 2]),
"*" = Mul([Id; 2]),
"/" = Div([Id; 2]),
"pow" = Pow([Id; 2]),
"mul_add" = MulAdd([Id; 3]),
"sin" = Sin(Id),
"cos" = Cos(Id),
"tan" = Tan(Id),
"sinh" = Sinh(Id),
"cosh" = Cosh(Id),
"tanh" = Tanh(Id),
"asin" = Asin(Id),
"acos" = Acos(Id),
"atan" = Atan(Id),
"asinh" = Asinh(Id),
"acosh" = Acosh(Id),
"atanh" = Atanh(Id),
"exp" = Exp(Id),
"ln" = Ln(Id),
}
}
/// An element of [`EggField`]: a coefficient expression together with the
/// scalar type used when rendering it as Rust tokens (see the `Display` and
/// `ToTokens` impls).
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct EggExpr {
/// The symbolic expression, as produced (and simplified) by `EggField`.
coeff: CoeffExpr,
/// Scalar type passed to `CoeffExpr::to_syn` when rendering.
scalar_ty: Type,
}
impl fmt::Display for EggExpr {
    /// Prints the Rust token rendering of the coefficient, using the element's
    /// own scalar type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tokens = CoeffExpr::to_syn(&self.coeff, Some(&self.scalar_ty)).to_token_stream();
        write!(f, "{}", tokens)
    }
}
/// A symbolic scalar field (for `abstalg`). Its arithmetic builds `CoeffExpr`
/// trees and simplifies each result with egg.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct EggField {
/// Scalar type attached to every element; also used by `to_expr` for codegen.
pub scalar_ty: Type,
/// Rewrite rules used for every simplification run.
pub rules: Vec<Rewrite<EggScalarLang, ()>>,
}
/// A named group of egg rewrite rules. `flatten_passes` concatenates passes
/// into one rule list before they are run.
#[derive(Clone, Debug)]
pub struct RewritePass {
/// Short identifier, shown by the `Display` impl.
pub name: &'static str,
/// The rules belonging to this pass.
pub rules: Vec<Rewrite<EggScalarLang, ()>>,
}
impl RewritePass {
pub fn new(name: &'static str, rules: Vec<Rewrite<EggScalarLang, ()>>) -> Self {
Self { name, rules }
}
pub fn rule_count(&self) -> usize {
self.rules.len()
}
}
impl fmt::Display for RewritePass {
    /// Formats as `RewritePass(<name>, <count> rules)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.rule_count();
        write!(f, "RewritePass({}, {} rules)", self.name, count)
    }
}
fn flatten_passes(passes: &[RewritePass]) -> Vec<Rewrite<EggScalarLang, ()>> {
passes
.iter()
.flat_map(|pass| pass.rules.iter().cloned())
.collect()
}
/// The standard pass list used by `EggField::new` and `scalar_rewrites`.
///
/// In list order: algebraic core, numeric folding, division, powers, exp/ln,
/// trig/hyperbolic, inverse trig, and `mul_add` formation.
pub fn default_rewrite_passes() -> Vec<RewritePass> {
    Vec::from([
        algebraic_core_pass(),
        numeric_folding_pass(),
        structural_division_pass(),
        pow_rules_pass(),
        exp_ln_pass(),
        trig_hyperbolic_pass(),
        inverse_trig_pass(),
        mul_add_pass(),
    ])
}
pub fn aggressive_rewrite_passes() -> Vec<RewritePass> {
vec![
algebraic_core_pass(),
numeric_folding_pass(),
structural_division_pass(),
pow_rules_pass(),
exp_ln_pass(),
trig_hyperbolic_pass(),
inverse_trig_pass(),
mul_add_pass(),
]
}
impl EggField {
pub fn new(scalar_ty: Type) -> Self {
Self::new_with_passes(scalar_ty, default_rewrite_passes())
}
pub fn new_with_passes(scalar_ty: Type, passes: Vec<RewritePass>) -> Self {
let rules = flatten_passes(&passes);
Self::new_with_rules(scalar_ty, rules)
}
pub fn new_with_rules(scalar_ty: Type, rules: Vec<Rewrite<EggScalarLang, ()>>) -> Self {
Self { scalar_ty, rules }
}
fn wrap_coeff(&self, coeff: CoeffExpr) -> EggExpr {
EggExpr {
coeff: self.simplify_with_egg(coeff),
scalar_ty: self.scalar_ty.clone(),
}
}
pub fn wrap_expr(&self, expr: SynExpr) -> EggExpr {
let coeff = CoeffExpr::from_syn(&expr);
let mut wrapped = self.wrap_coeff(coeff);
wrapped.coeff = self.simplify_with_egg(wrapped.coeff);
wrapped
}
pub fn wrap_numeric_expr(&self, expr: SynExpr) -> EggExpr {
self.wrap_expr(expr)
}
fn zero_literal(&self) -> EggExpr {
self.wrap_coeff(CoeffExpr::zero())
}
fn one_literal(&self) -> EggExpr {
self.wrap_coeff(CoeffExpr::one())
}
fn scale_by_usize(&self, num: usize, expr: &EggExpr) -> EggExpr {
match num {
0 => self.zero_literal(),
1 => expr.clone(),
_ => {
let factors = vec![
CoeffExpr::literal(BigRational::from_integer(BigInt::from(num as i128))),
expr.coeff.clone(),
];
self.wrap_coeff(CoeffExpr::new(CoeffExprKind::Product(factors)))
}
}
}
fn scale_by_isize(&self, num: isize, expr: &EggExpr) -> EggExpr {
match num {
0 => self.zero_literal(),
1 => expr.clone(),
-1 => self.wrap_coeff(CoeffExpr::neg(expr.coeff.clone())),
_ => {
let factors = vec![
CoeffExpr::literal(BigRational::from_integer(BigInt::from(num as i128))),
expr.coeff.clone(),
];
self.wrap_coeff(CoeffExpr::new(CoeffExprKind::Product(factors)))
}
}
}
fn simplify_with_egg(&self, coeff: CoeffExpr) -> CoeffExpr {
let rec = coeff_to_rec(&coeff);
let runner = Runner::default().with_expr(&rec).run(&self.rules);
let root = runner
.roots
.first()
.copied()
.expect("egg runner should retain root for initial expression");
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_cost, best) = extractor.find_best(root);
rec_to_coeff(&best)
}
}
impl Domain for EggField {
    type Elem = EggExpr;

    /// Compares the coefficient expressions with `CoeffExpr`'s `PartialEq`.
    /// The scalar type is not compared.
    fn equals(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> bool {
        let (lhs, rhs) = (&elem1.coeff, &elem2.coeff);
        lhs == rhs
    }

    /// Every `EggExpr` is treated as a member of the field.
    fn contains(&self, _elem: &Self::Elem) -> bool {
        true
    }
}
impl CommuntativeMonoid for EggField {
    /// The simplified literal `0`.
    fn zero(&self) -> Self::Elem {
        self.zero_literal()
    }

    /// Symbolic sum `elem1 + elem2`, simplified.
    fn add(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        let terms = vec![elem1.coeff.clone(), elem2.coeff.clone()];
        self.wrap_coeff(CoeffExpr::new(CoeffExprKind::Sum(terms)))
    }

    /// Delegates to `CoeffExpr::is_zero` on the coefficient.
    fn is_zero(&self, elem: &Self::Elem) -> bool {
        elem.coeff.is_zero()
    }

    fn add_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
        let sum = self.add(elem1, elem2);
        *elem1 = sum;
    }

    fn double(&self, elem: &mut Self::Elem) {
        let twice = self.add(elem, elem);
        *elem = twice;
    }

    /// `num * elem` for a non-negative count.
    fn times(&self, num: usize, elem: &Self::Elem) -> Self::Elem {
        self.scale_by_usize(num, elem)
    }
}
impl AbelianGroup for EggField {
    /// Symbolic negation `-elem`, simplified.
    fn neg(&self, elem: &Self::Elem) -> Self::Elem {
        let negated = CoeffExpr::neg(elem.coeff.clone());
        self.wrap_coeff(negated)
    }

    fn neg_assign(&self, elem: &mut Self::Elem) {
        let negated = self.neg(elem);
        *elem = negated;
    }

    /// Symbolic difference, built as the sum `elem1 + (-elem2)` and simplified.
    fn sub(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        let negated = CoeffExpr::neg(elem2.coeff.clone());
        let difference = CoeffExpr::new(CoeffExprKind::Sum(vec![elem1.coeff.clone(), negated]));
        self.wrap_coeff(difference)
    }

    fn sub_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
        let difference = self.sub(elem1, elem2);
        *elem1 = difference;
    }

    /// `num * elem` for a signed count.
    fn times(&self, num: isize, elem: &Self::Elem) -> Self::Elem {
        self.scale_by_isize(num, elem)
    }
}
impl Semigroup for EggField {
    /// Symbolic product `elem1 * elem2`, simplified.
    fn mul(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        let factors = vec![elem1.coeff.clone(), elem2.coeff.clone()];
        self.wrap_coeff(CoeffExpr::new(CoeffExprKind::Product(factors)))
    }

    fn mul_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
        let product = self.mul(elem1, elem2);
        *elem1 = product;
    }

    fn square(&self, elem: &mut Self::Elem) {
        let squared = self.mul(elem, elem);
        *elem = squared;
    }
}
impl Monoid for EggField {
    /// The simplified literal `1`.
    fn one(&self) -> Self::Elem {
        self.one_literal()
    }

    /// Symbolic reciprocal `1 / elem`.
    /// Returns `None` when the coefficient reports `is_zero()`.
    fn try_inv(&self, elem: &Self::Elem) -> Option<Self::Elem> {
        (!elem.coeff.is_zero()).then(|| {
            let reciprocal = CoeffExpr::quotient(CoeffExpr::one(), elem.coeff.clone());
            self.wrap_coeff(reciprocal)
        })
    }
}
// `SemiRing` is implemented using only the trait's default items.
impl SemiRing for EggField {}
impl UnitaryRing for EggField {
    /// Embeds the integer `elem` as a simplified rational literal.
    fn int(&self, elem: isize) -> Self::Elem {
        let value = BigRational::from_integer(BigInt::from(elem));
        self.wrap_coeff(CoeffExpr::literal(value))
    }
}
impl Field for EggField {
    /// Multiplicative inverse.
    ///
    /// # Panics
    /// If `try_inv` returns `None`, i.e. the coefficient reports `is_zero()`.
    fn inv(&self, elem: &Self::Elem) -> Self::Elem {
        let inverse = self.try_inv(elem);
        inverse.expect("attempted to invert zero element in EggField")
    }

    /// Symbolic quotient `elem1 / elem2`, simplified.
    ///
    /// # Panics
    /// If `elem2`'s coefficient reports `is_zero()`.
    fn div(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        assert!(
            !elem2.coeff.is_zero(),
            "attempted to divide by zero element in EggField"
        );
        self.wrap_coeff(CoeffExpr::quotient(elem1.coeff.clone(), elem2.coeff.clone()))
    }
}
impl CoefficientField for EggField {
    /// Parses and simplifies `expr`. This implementation never returns an error.
    fn embed_expr(&self, expr: SynExpr) -> syn::Result<<Self as Domain>::Elem> {
        let wrapped = self.wrap_expr(expr);
        Ok(wrapped)
    }

    /// Renders `elem` with this field's scalar type. This implementation never returns an error.
    fn to_expr(&self, elem: &<Self as Domain>::Elem) -> syn::Result<SynExpr> {
        let rendered = CoeffExpr::to_syn(&elem.coeff, Some(&self.scalar_ty));
        Ok(rendered)
    }
}
impl ToTokens for EggExpr {
    /// Appends the `syn` rendering of the coefficient, using the element's scalar type.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        CoeffExpr::to_syn(&self.coeff, Some(&self.scalar_ty)).to_tokens(tokens);
    }
}
/// Lowers a whole `CoeffExpr` into an egg `RecExpr`.
/// The node for `expr` itself is the last node, which is the `RecExpr` root.
fn coeff_to_rec(expr: &CoeffExpr) -> RecExpr<EggScalarLang> {
    let mut nodes = Vec::new();
    let _root = coeff_to_rec_inner(expr, &mut nodes);
    nodes.into()
}
fn coeff_to_rec_inner(expr: &CoeffExpr, nodes: &mut Vec<EggScalarLang>) -> Id {
match expr.kind() {
CoeffExprKind::Literal(lit) => literal_to_lang(lit, nodes),
CoeffExprKind::Symbol(sym) => symbol_to_lang(sym, nodes),
CoeffExprKind::Sum(terms) => fold_sum(nodes, terms),
CoeffExprKind::Product(factors) => fold_product(nodes, factors),
CoeffExprKind::Neg(inner) => {
let child = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Neg(child))
}
CoeffExprKind::Quotient(numerator, denominator) => {
let lhs = coeff_to_rec_inner(numerator, nodes);
let rhs = coeff_to_rec_inner(denominator, nodes);
let minus_one = num_rational::BigRational::from_integer(num_bigint::BigInt::from(-1));
let exp_id = push_rational(nodes, minus_one);
let recip_id = push_node(nodes, EggScalarLang::Pow([rhs, exp_id]));
push_node(nodes, EggScalarLang::Mul([lhs, recip_id]))
}
CoeffExprKind::Pow(base, exponent) => {
let base_id = coeff_to_rec_inner(base, nodes);
let exponent_literal =
CoeffExpr::literal(BigRational::from_integer(BigInt::from(*exponent as i128)));
let exponent_id = coeff_to_rec_inner(&exponent_literal, nodes);
push_node(nodes, EggScalarLang::Pow([base_id, exponent_id]))
}
CoeffExprKind::Sin(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Sin(id))
}
CoeffExprKind::Asin(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Asin(id))
}
CoeffExprKind::Cos(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Cos(id))
}
CoeffExprKind::Acos(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Acos(id))
}
CoeffExprKind::Tan(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Tan(id))
}
CoeffExprKind::Atan(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Atan(id))
}
CoeffExprKind::Sinh(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Sinh(id))
}
CoeffExprKind::Asinh(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Asinh(id))
}
CoeffExprKind::Cosh(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Cosh(id))
}
CoeffExprKind::Acosh(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Acosh(id))
}
CoeffExprKind::Tanh(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Tanh(id))
}
CoeffExprKind::Atanh(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Atanh(id))
}
CoeffExprKind::Exp(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Exp(id))
}
CoeffExprKind::Ln(inner) => {
let id = coeff_to_rec_inner(inner, nodes);
push_node(nodes, EggScalarLang::Ln(id))
}
CoeffExprKind::Sqrt(inner) => {
let base_id = coeff_to_rec_inner(inner, nodes);
let half = num_rational::BigRational::new(
num_bigint::BigInt::from(1),
num_bigint::BigInt::from(2),
);
let exp_id = push_rational(nodes, half);
push_node(nodes, EggScalarLang::Pow([base_id, exp_id]))
}
CoeffExprKind::MulAdd(lhs, rhs, addend) => {
let lhs_id = coeff_to_rec_inner(lhs, nodes);
let rhs_id = coeff_to_rec_inner(rhs, nodes);
let addend_id = coeff_to_rec_inner(addend, nodes);
push_node(nodes, EggScalarLang::MulAdd([lhs_id, rhs_id, addend_id]))
}
CoeffExprKind::Wedge(_) | CoeffExprKind::Opaque(_) => fallback_opaque(expr, nodes),
}
}
/// Appends `node` and returns its id, which is its index in `nodes`.
fn push_node(nodes: &mut Vec<EggScalarLang>, node: EggScalarLang) -> Id {
    let index = nodes.len();
    nodes.push(node);
    Id::from(index)
}
/// Appends a `Rational` leaf holding `value`.
fn push_rational(nodes: &mut Vec<EggScalarLang>, value: BigRational) -> Id {
    let leaf = EggScalarLang::Rational(value);
    push_node(nodes, leaf)
}
/// Lowers `terms` as a left-nested chain of binary `+` nodes.
/// An empty sum lowers to the rational `0`.
fn fold_sum<'a, I>(nodes: &mut Vec<EggScalarLang>, terms: I) -> Id
where
    I: IntoIterator<Item = &'a CoeffExpr>,
{
    let mut remaining = terms.into_iter();
    let Some(head) = remaining.next() else {
        return push_rational(nodes, BigRational::zero());
    };
    let mut acc = coeff_to_rec_inner(head, nodes);
    for term in remaining {
        let rhs = coeff_to_rec_inner(term, nodes);
        acc = push_node(nodes, EggScalarLang::Add([acc, rhs]));
    }
    acc
}
/// Lowers `factors` as a left-nested chain of binary `*` nodes.
/// An empty product lowers to the rational `1`.
fn fold_product<'a, I>(nodes: &mut Vec<EggScalarLang>, factors: I) -> Id
where
    I: IntoIterator<Item = &'a CoeffExpr>,
{
    let mut remaining = factors.into_iter();
    let Some(head) = remaining.next() else {
        return push_rational(nodes, BigRational::one());
    };
    let mut acc = coeff_to_rec_inner(head, nodes);
    for factor in remaining {
        let rhs = coeff_to_rec_inner(factor, nodes);
        acc = push_node(nodes, EggScalarLang::Mul([acc, rhs]));
    }
    acc
}
/// Stores `expr` as an `Opaque` leaf keyed by its token-string rendering.
/// The rendering is produced without a scalar type.
fn fallback_opaque(expr: &CoeffExpr, nodes: &mut Vec<EggScalarLang>) -> Id {
    let repr = CoeffExpr::to_syn(expr, None).to_token_stream().to_string();
    push_node(nodes, EggScalarLang::Opaque(Symbol::from(repr.as_str())))
}
/// Lowers a rational literal to a `Rational` leaf.
fn literal_to_lang(lit: &BigRational, nodes: &mut Vec<EggScalarLang>) -> Id {
    let value = lit.clone();
    push_rational(nodes, value)
}
/// Lowers a named variable to a `Symbol` leaf.
fn symbol_to_lang(sym: &str, nodes: &mut Vec<EggScalarLang>) -> Id {
    push_node(nodes, EggScalarLang::Symbol(Symbol::from(sym)))
}
fn rec_to_coeff(rec: &RecExpr<EggScalarLang>) -> CoeffExpr {
let nodes = rec.as_ref();
if nodes.is_empty() {
return CoeffExpr::zero();
}
let mut cache: Vec<Option<CoeffExpr>> = vec![None; nodes.len()];
fn decode(i: usize, nodes: &[EggScalarLang], cache: &mut [Option<CoeffExpr>]) -> CoeffExpr {
if let Some(ref v) = cache[i] {
return v.clone();
}
let node = &nodes[i];
let expr = match node {
EggScalarLang::Rational(r) => CoeffExpr::literal(r.clone()),
EggScalarLang::Symbol(sym) => CoeffExpr::symbol(sym.to_string()),
EggScalarLang::Opaque(sym) => {
let repr = sym.to_string();
match syn::parse_str::<SynExpr>(&repr) {
Ok(expr) => CoeffExpr::opaque(expr),
Err(_) => CoeffExpr::symbol(repr),
}
}
EggScalarLang::Neg(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: NEG child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::neg(inner)
}
EggScalarLang::Sin(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: SIN child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::sin(inner)
}
EggScalarLang::Asin(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: ASIN child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::asin(inner)
}
EggScalarLang::Cos(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: COS child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::cos(inner)
}
EggScalarLang::Acos(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: ACOS child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::acos(inner)
}
EggScalarLang::Tan(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: TAN child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::tan(inner)
}
EggScalarLang::Atan(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: ATAN child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::atan(inner)
}
EggScalarLang::Sinh(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: SINH child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::sinh(inner)
}
EggScalarLang::Asinh(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: ASINH child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::asinh(inner)
}
EggScalarLang::Cosh(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: COSH child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::cosh(inner)
}
EggScalarLang::Acosh(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: ACOSH child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::acosh(inner)
}
EggScalarLang::Tanh(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: TANH child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::tanh(inner)
}
EggScalarLang::Atanh(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: ATANH child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::atanh(inner)
}
EggScalarLang::Exp(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: EXP child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::exp(inner)
}
EggScalarLang::Ln(child) => {
let child_idx = usize::from(*child);
if child_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: LN child_idx {} out of range (len={})",
child_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let inner = decode(child_idx, nodes, cache);
CoeffExpr::ln(inner)
}
EggScalarLang::Add([lhs, rhs]) => {
let lhs_idx = usize::from(*lhs);
let rhs_idx = usize::from(*rhs);
if lhs_idx >= nodes.len() || rhs_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: ADD child indices out of range lhs={} rhs={} len={}",
lhs_idx,
rhs_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let left = decode(lhs_idx, nodes, cache);
let right = decode(rhs_idx, nodes, cache);
CoeffExpr::sum(vec![left, right])
}
EggScalarLang::MulAdd([lhs, rhs, addend]) => {
let lhs_idx = usize::from(*lhs);
let rhs_idx = usize::from(*rhs);
let add_idx = usize::from(*addend);
if lhs_idx >= nodes.len() || rhs_idx >= nodes.len() || add_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: MULADD child indices out of range lhs={} rhs={} add={} len={}",
lhs_idx,
rhs_idx,
add_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let left = decode(lhs_idx, nodes, cache);
let right = decode(rhs_idx, nodes, cache);
let add = decode(add_idx, nodes, cache);
CoeffExpr::mul_add(left, right, add)
}
EggScalarLang::Pow([lhs, rhs]) => {
let lhs_idx = usize::from(*lhs);
let rhs_idx = usize::from(*rhs);
if lhs_idx >= nodes.len() || rhs_idx >= nodes.len() {
eprintln!(
"rec_to_coeff: POW child indices out of range lhs={} rhs={} len={}",
lhs_idx,
rhs_idx,
nodes.len()
);
return CoeffExpr::opaque(parse_quote!(0));
}
let base = decode(lhs_idx, nodes, cache);
let exponent = decode(rhs_idx, nodes, cache);
if let CoeffExprKind::Literal(r) = exponent.kind() {
if r.denom().is_one() {
if let Some(i) = r.numer().to_i32() {
if i == 2 {
if let CoeffExprKind::Sqrt(inner) = base.kind() {
return (**inner).clone();
}
return CoeffExpr::pow(base, 2);
}
if i > 0 {
return CoeffExpr::pow(base, i as u32);
} else {
let abs = (-i) as u32;
if abs == 1 {
return CoeffExpr::quotient(CoeffExpr::one(), base);
} else {
let positive = CoeffExpr::pow(base.clone(), abs);
return CoeffExpr::quotient(CoeffExpr::one(), positive);
}
}
}
} else if is_rational_half(r) {
if let CoeffExprKind::Pow(inner_base, exp) = base.kind() {
if *exp == 2 {
return (**inner_base).clone();
}
}
return CoeffExpr::sqrt(base);
}
}
pow_fallback(base, exponent)
}
EggScalarLang::Mul([lhs, rhs]) => {
let left = decode(usize::from(*lhs), nodes, cache);
let right = decode(usize::from(*rhs), nodes, cache);
CoeffExpr::product(vec![left, right])
}
EggScalarLang::Div([lhs, rhs]) => {
let left = decode(usize::from(*lhs), nodes, cache);
let right = decode(usize::from(*rhs), nodes, cache);
CoeffExpr::quotient(left, right)
}
};
cache[i] = Some(expr.clone());
expr
}
let root_idx = nodes.len() - 1;
decode(root_idx, nodes, &mut cache)
}
/// Builds an opaque `base.powf(exponent)` call for exponents that have no
/// dedicated `CoeffExpr` form. Both operands are rendered without a scalar type.
///
/// The base is parenthesised before the method call. Interpolating a bare
/// compound expression such as `a + b` or `-x` would otherwise produce
/// `a + b.powf(e)` / `-x.powf(e)`, which binds `powf` to the wrong operand.
fn pow_fallback(base: CoeffExpr, exponent: CoeffExpr) -> CoeffExpr {
    let base_syn = CoeffExpr::to_syn(&base, None);
    let exp_syn = CoeffExpr::to_syn(&exponent, None);
    let expr: SynExpr = parse_quote! { (#base_syn).powf(#exp_syn) };
    CoeffExpr::opaque(expr)
}
/// Returns `true` when `value` has numerator 1 and denominator 2.
///
/// This assumes `value` is in lowest terms, as `BigRational` arithmetic and
/// `BigRational::new` produce.
fn is_rational_half(value: &BigRational) -> bool {
    let one = BigInt::from(1);
    let two = BigInt::from(2);
    *value.numer() == one && *value.denom() == two
}
/// Returns a rational constant from e-class `id` if the class contains one.
/// The first `Rational` node found is used.
fn extract_rational(egraph: &EGraph<EggScalarLang, ()>, id: Id) -> Option<BigRational> {
    for node in &egraph[id].nodes {
        if let EggScalarLang::Rational(value) = node {
            return Some(value.clone());
        }
    }
    None
}
/// Applier that constant-folds `(+ ?a ?b)` when both operands are rational literals.
#[derive(Clone, Copy)]
struct FoldAddRational {
/// Pattern variable for the left operand (`?a`).
a: Var,
/// Pattern variable for the right operand (`?b`).
b: Var,
}
impl Default for FoldAddRational {
fn default() -> Self {
Self {
a: Var::from_str("?a").expect("valid var"),
b: Var::from_str("?b").expect("valid var"),
}
}
}
impl egg::Applier<EggScalarLang, ()> for FoldAddRational {
    /// If both `?a` and `?b` have a rational constant in their e-class, this adds
    /// the literal `a + b` and unions it with the matched class.
    ///
    /// The new id is returned only when `union` actually merged two distinct
    /// e-classes. egg counts returned ids as "applied" work, so reporting a
    /// no-op union would keep the runner from ever detecting saturation.
    fn apply_one(
        &self,
        egraph: &mut EGraph<EggScalarLang, ()>,
        eclass: Id,
        subst: &Subst,
        _pattern: Option<&PatternAst<EggScalarLang>>,
        _rule_name: Symbol,
    ) -> Vec<Id> {
        let (Some(lhs), Some(rhs)) = (
            extract_rational(egraph, subst[self.a]),
            extract_rational(egraph, subst[self.b]),
        ) else {
            return Vec::new();
        };
        let folded = egraph.add(EggScalarLang::Rational(lhs + rhs));
        if egraph.union(eclass, folded) {
            vec![folded]
        } else {
            Vec::new()
        }
    }

    fn vars(&self) -> Vec<Var> {
        vec![self.a, self.b]
    }
}
/// Applier that constant-folds `(* ?a ?b)` when both operands are rational literals.
#[derive(Clone, Copy)]
struct FoldMulRational {
/// Pattern variable for the left factor (`?a`).
a: Var,
/// Pattern variable for the right factor (`?b`).
b: Var,
}
impl Default for FoldMulRational {
fn default() -> Self {
Self {
a: Var::from_str("?a").expect("valid var"),
b: Var::from_str("?b").expect("valid var"),
}
}
}
impl egg::Applier<EggScalarLang, ()> for FoldMulRational {
    /// If both `?a` and `?b` have a rational constant in their e-class, this adds
    /// the literal `a * b` and unions it with the matched class.
    ///
    /// The new id is returned only when `union` actually merged two distinct
    /// e-classes. egg counts returned ids as "applied" work, so reporting a
    /// no-op union would keep the runner from ever detecting saturation.
    fn apply_one(
        &self,
        egraph: &mut EGraph<EggScalarLang, ()>,
        eclass: Id,
        subst: &Subst,
        _pattern: Option<&PatternAst<EggScalarLang>>,
        _rule_name: Symbol,
    ) -> Vec<Id> {
        let (Some(lhs), Some(rhs)) = (
            extract_rational(egraph, subst[self.a]),
            extract_rational(egraph, subst[self.b]),
        ) else {
            return Vec::new();
        };
        let folded = egraph.add(EggScalarLang::Rational(lhs * rhs));
        if egraph.union(eclass, folded) {
            vec![folded]
        } else {
            Vec::new()
        }
    }

    fn vars(&self) -> Vec<Var> {
        vec![self.a, self.b]
    }
}
/// All rules of `default_rewrite_passes`, flattened into one list.
pub fn scalar_rewrites() -> Vec<Rewrite<EggScalarLang, ()>> {
    let passes = default_rewrite_passes();
    flatten_passes(&passes)
}
/// All rules of `aggressive_rewrite_passes`, flattened into one list.
pub fn aggressive_scalar_rewrites() -> Vec<Rewrite<EggScalarLang, ()>> {
    let passes = aggressive_rewrite_passes();
    flatten_passes(&passes)
}
/// Core ring identities:
/// - commutativity and associativity of `+` and `*`;
/// - the identities `x + 0 = x` and `x * 1 = x`, and `x * 0 = 0`;
/// - double negation, `x + (-x) = 0`, and `x * -1 = -x`;
/// - distribution of `*` over `+` (left form; commutativity covers the right form).
///
/// NOTE(review): `zero-mul` and `cancel-add-neg` are exact only for finite
/// values. Under IEEE floats, `inf * 0` and `inf - inf` are NaN.
fn algebraic_core_pass() -> RewritePass {
RewritePass::new(
"algebraic-core",
vec![
rewrite!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
rewrite!("comm-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
rewrite!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
rewrite!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
rewrite!("zero-add"; "(+ ?a 0)" => "?a"),
rewrite!("zero-mul"; "(* ?a 0)" => "0"),
rewrite!("one-mul"; "(* ?a 1)" => "?a"),
rewrite!("neg-neg"; "(neg (neg ?a))" => "?a"),
rewrite!("cancel-add-neg"; "(+ ?a (neg ?a))" => "0"),
rewrite!("mul-neg-one"; "(* ?a -1)" => "(neg ?a)"),
rewrite!("distrib-mul-add"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
],
)
}
/// Constant folding for `+` and `*`.
///
/// Every binary sum or product is matched. The appliers (`FoldAddRational`,
/// `FoldMulRational`) act only when both operand e-classes contain a rational
/// constant. In that case they union the folded literal into the matched class.
fn numeric_folding_pass() -> RewritePass {
RewritePass::new(
"numeric-folding",
vec![
rewrite!("add-fold-rational"; "(+ ?a ?b)" => { FoldAddRational::default() }),
rewrite!("mul-fold-rational"; "(* ?a ?b)" => { FoldMulRational::default() }),
],
)
}
/// Division rewrites.
///
/// - Recovers `a / b` from the `a * b^(-1)` encoding used by `coeff_to_rec_inner`.
/// - Cancels paired negations: `(-x)/(-y) = x/y`.
/// - Simplifies `a * (b / a)` to `b`.
///
/// NOTE(review): `mul-div-cancel` assumes `a != 0`.
fn structural_division_pass() -> RewritePass {
RewritePass::new(
"division-structural",
vec![
rewrite!("mul-to-div"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)"),
rewrite!("neg-div-neg"; "(/ (neg ?x) (neg ?y))" => "(/ ?x ?y)"),
rewrite!("mul-div-cancel"; "(* ?a (/ ?b ?a))" => "?b"),
],
)
}
fn pow_rules_pass() -> RewritePass {
RewritePass::new(
"pow",
vec![
rewrite!("mul-pow-pow"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"),
rewrite!("pow-pow"; "(pow (pow ?a ?b) ?c)" => "(pow ?a (* ?b ?c))"),
rewrite!("mul-pow"; "(* (pow ?a ?b) ?a)" => "(pow ?a (+ ?b 1))"),
rewrite!("sq-pow"; "(* ?a ?a)" => "(pow ?a 2)"),
rewrite!("pow-zero"; "(pow ?a 0)" => "1"),
rewrite!("pow-one"; "(pow ?a 1)" => "?a"),
rewrite!("zero-pow"; "(pow 0 ?a)" => "0"),
rewrite!("one-pow"; "(pow 1 ?a)" => "1"),
rewrite!("pow-neg"; "(pow (neg ?a) ?b)" => { PowNegInteger::default() }),
],
)
}
/// Applier for `(pow (neg ?a) ?b)` that uses the parity of an integer exponent.
#[derive(Clone, Copy)]
struct PowNegInteger {
/// Pattern variable for the base under the negation (`?a`).
a: Var,
/// Pattern variable for the exponent (bound to `?b` by `Default`).
n: Var,
}
impl Default for PowNegInteger {
fn default() -> Self {
Self {
a: "?a".parse().unwrap(),
n: "?b".parse().unwrap(),
}
}
}
impl egg::Applier<EggScalarLang, ()> for PowNegInteger {
    /// Rewrites `(-a)^n` for an integer literal `n`:
    /// - to `a^n` when `n` is even;
    /// - to `-(a^n)` when `n` is odd.
    ///
    /// Non-literal or non-integer exponents are left untouched.
    ///
    /// The new id is returned only when `union` actually merged two distinct
    /// e-classes, so the runner can recognise saturation.
    fn apply_one(
        &self,
        egraph: &mut EGraph<EggScalarLang, ()>,
        eclass: Id,
        subst: &Subst,
        _pattern: Option<&PatternAst<EggScalarLang>>,
        _rule_name: Symbol,
    ) -> Vec<Id> {
        let a_id = subst[self.a];
        let n_id = subst[self.n];
        let Some(exp) = extract_rational(egraph, n_id) else {
            return Vec::new();
        };
        if !exp.denom().is_one() {
            return Vec::new();
        }
        let two = BigInt::from(2);
        let pow_id = egraph.add(EggScalarLang::Pow([a_id, n_id]));
        let new_id = if (exp.numer() % &two).is_zero() {
            pow_id
        } else {
            egraph.add(EggScalarLang::Neg(pow_id))
        };
        if egraph.union(eclass, new_id) {
            vec![new_id]
        } else {
            Vec::new()
        }
    }

    fn vars(&self) -> Vec<Var> {
        vec![self.a, self.n]
    }
}
/// Exponential/logarithm rewrites: `exp(0) = 1`, `ln(1) = 0`, and the two
/// inverse cancellations.
///
/// NOTE(review): `exp-ln` (`exp(ln a) = a`) holds only for `a > 0`.
fn exp_ln_pass() -> RewritePass {
RewritePass::new(
"exp-ln",
vec![
rewrite!("exp-zero"; "(exp 0)" => "1"),
rewrite!("ln-one"; "(ln 1)" => "0"),
rewrite!("exp-ln"; "(exp (ln ?a))" => "?a"),
rewrite!("ln-exp"; "(ln (exp ?a))" => "?a"),
],
)
}
/// Trigonometric and hyperbolic rewrites.
///
/// - Values at zero.
/// - Parity: odd functions pull out a negation, even functions drop it.
/// - The Pythagorean identities `sin² + cos² = 1` and `cosh² - sinh² = 1`.
/// - Expansions of `tan`/`tanh` as quotients.
///
/// The expansions only add equivalent forms to the e-graph. With `AstSize`
/// extraction, the shorter `tan`/`tanh` form is kept unless the expansion
/// leads to something smaller.
fn trig_hyperbolic_pass() -> RewritePass {
RewritePass::new(
"trig-hyperbolic",
vec![
rewrite!("sin-zero"; "(sin 0)" => "0"),
rewrite!("cos-zero"; "(cos 0)" => "1"),
rewrite!("tan-zero"; "(tan 0)" => "0"),
rewrite!("sin-neg"; "(sin (neg ?a))" => "(neg (sin ?a))"),
rewrite!("cos-neg"; "(cos (neg ?a))" => "(cos ?a)"),
rewrite!("tan-neg"; "(tan (neg ?a))" => "(neg (tan ?a))"),
rewrite!("sinh-zero"; "(sinh 0)" => "0"),
rewrite!("cosh-zero"; "(cosh 0)" => "1"),
rewrite!("tanh-zero"; "(tanh 0)" => "0"),
rewrite!("sinh-neg"; "(sinh (neg ?a))" => "(neg (sinh ?a))"),
rewrite!("cosh-neg"; "(cosh (neg ?a))" => "(cosh ?a)"),
rewrite!("tanh-neg"; "(tanh (neg ?a))" => "(neg (tanh ?a))"),
rewrite!("sin2+cos2->1"; "(+ (pow (sin ?a) 2) (pow (cos ?a) 2))" => "1"),
rewrite!("cosh2-sinh2->1"; "(+ (pow (cosh ?a) 2) (neg (pow (sinh ?a) 2)))" => "1"),
rewrite!("tan->sin-div-cos"; "(tan ?a)" => "(/ (sin ?a) (cos ?a))"),
rewrite!("tanh->sinh-div-cosh"; "(tanh ?a)" => "(/ (sinh ?a) (cosh ?a))"),
],
)
}
/// Inverse trigonometric/hyperbolic rewrites.
///
/// - Odd symmetry of `asin`, `atan`, `asinh`, and `atanh`.
/// - Their values at 0, plus `acos(1) = 0`.
fn inverse_trig_pass() -> RewritePass {
RewritePass::new(
"inverse-trig",
vec![
rewrite!("asin-neg"; "(asin (neg ?a))" => "(neg (asin ?a))"),
rewrite!("atan-neg"; "(atan (neg ?a))" => "(neg (atan ?a))"),
rewrite!("asinh-neg"; "(asinh (neg ?a))" => "(neg (asinh ?a))"),
rewrite!("atanh-neg"; "(atanh (neg ?a))" => "(neg (atanh ?a))"),
rewrite!("asin-zero"; "(asin 0)" => "0"),
rewrite!("acos-one"; "(acos 1)" => "0"),
rewrite!("atan-zero"; "(atan 0)" => "0"),
rewrite!("asinh-zero"; "(asinh 0)" => "0"),
rewrite!("atanh-zero"; "(atanh 0)" => "0"),
],
)
}
#[allow(dead_code)]
fn grobner_basis_pass() -> RewritePass {
RewritePass::new(
"grobner-basis",
vec![
],
)
}
/// Forms `mul_add(a, b, c)` from `a * b + c`, with the product on either side
/// of the sum.
fn mul_add_pass() -> RewritePass {
RewritePass::new(
"mul-add",
vec![
rewrite!("mul-add-forward"; "(+ (* ?a ?b) ?c)" => "(mul_add ?a ?b ?c)"),
rewrite!("mul-add-forward-comm"; "(+ ?c (* ?a ?b))" => "(mul_add ?a ?b ?c)"),
],
)
}
/// Simplifies several coefficient expressions in one shared e-graph.
///
/// All expressions are lowered into a single `RecExpr` and saturated together
/// with `scalar_rewrites()`. Each input's e-class is then extracted separately
/// by `AstSize`. Diagnostic progress is written to stderr.
///
/// Returns one simplified expression per input, in input order. An empty input
/// yields an empty vector: egg cannot add an empty `RecExpr`, because
/// `EGraph::add_expr` unwraps the last node.
#[expect(unused)]
pub fn simplify_coeffs_globally(exprs: &[CoeffExpr]) -> Vec<CoeffExpr> {
    eprintln!(
        "simplify_coeffs_globally: called with {} exprs",
        exprs.len()
    );
    for (i, e) in exprs.iter().enumerate() {
        eprintln!(" expr[{}] = {}", i, e);
    }
    if exprs.is_empty() {
        return Vec::new();
    }
    let mut nodes: Vec<EggScalarLang> = Vec::new();
    let mut roots: Vec<Id> = Vec::with_capacity(exprs.len());
    for expr in exprs {
        roots.push(coeff_to_rec_inner(expr, &mut nodes));
    }
    let rec = RecExpr::from(nodes);
    let runner = Runner::default().with_expr(&rec).run(&scalar_rewrites());
    eprintln!(
        "simplify_coeffs_globally: runner.roots.len() = {}",
        runner.roots.len()
    );
    let extractor = Extractor::new(&runner.egraph, AstSize);
    // Maps each node of `rec` (by index) to its e-class in the saturated graph.
    let node_to_eclass = runner
        .egraph
        .lookup_expr_ids(&rec)
        .expect("egraph should contain the inserted RecExpr nodes");
    eprintln!(
        "simplify_coeffs_globally: node_to_eclass.len() = {} (nodes in rec)",
        node_to_eclass.len()
    );
    let results: Vec<CoeffExpr> = roots
        .into_iter()
        .map(|rec_node_id| {
            let eclass = node_to_eclass[usize::from(rec_node_id)];
            let (cost, best) = extractor.find_best(eclass);
            eprintln!(
                "simplify_coeffs_globally: extractor.find_best eclass={:?} cost={} nodes={}",
                eclass,
                cost,
                best.as_ref().len()
            );
            eprintln!("simplify_coeffs_globally: extracted RecExpr=\n{:?}", best);
            rec_to_coeff(&best)
        })
        .collect();
    eprintln!(
        "simplify_coeffs_globally: returning {} exprs",
        results.len()
    );
    results
}
#[cfg(test)]
mod tests {
use super::*;
use num_bigint::BigInt;
use num_rational::BigRational;
use quote::quote;
use syn::{Expr as SynExpr, parse_quote};
fn rendered(expr: &EggExpr) -> SynExpr {
syn::parse_str(&expr.to_token_stream().to_string()).unwrap()
}
fn assert_render(expr: &EggExpr, expected: SynExpr) {
assert_eq!(rendered(expr), expected);
}
#[test]
fn egg_field_preserves_simple_arithmetic() {
let scalar_ty: Type = parse_quote!(f32);
let field = EggField::new(scalar_ty.clone());
let two = field.int(2);
let five = field.int(5);
let sum = field.add(&two, &five);
let actual: SynExpr = syn::parse_str(&sum.to_token_stream().to_string()).unwrap();
let expected: SynExpr = parse_quote! { 7 as f32 };
assert_eq!(
quote! { #actual }.to_string(),
quote! { #expected }.to_string()
);
}
#[test]
fn egg_field_wrap_expr_round_trip() {
let scalar_ty: Type = parse_quote!(f32);
let field = EggField::new(scalar_ty);
let wrapped = field.wrap_expr(parse_quote! { (x + y) * 2 });
let tokens = wrapped.to_token_stream().to_string();
assert!(
tokens.contains("mul") && tokens.contains("add"),
"got: {}",
tokens
);
}
#[test]
fn egg_field_simplifies_to_zero() {
let scalar_ty: Type = parse_quote!(f32);
let field = EggField::new(scalar_ty);
let expr = field.wrap_expr(parse_quote! { (a * b) - (a * b) });
let actual: SynExpr = syn::parse_str(&expr.to_token_stream().to_string()).unwrap();
let expected: SynExpr = parse_quote! { 0 as f32 };
assert_eq!(
quote! { #actual }.to_string(),
quote! { #expected }.to_string()
);
assert!(abstalg::CommuntativeMonoid::is_zero(&field, &expr));
}
#[test]
fn egg_field_rewrites_repeated_product_into_pow() {
let scalar_ty: Type = parse_quote!(f32);
let field = EggField::new(scalar_ty);
let expr = field.wrap_expr(parse_quote! { x * x * x * x * x });
assert_render(&expr, parse_quote! { (x).powi(5) });
}
#[test]
fn egg_field_prefers_mul_add_for_sum_of_product() {
let scalar_ty: Type = parse_quote!(f32);
let field = EggField::new(scalar_ty);
let expr = field.wrap_expr(parse_quote! { (x * y) + z });
assert_render(&expr, parse_quote! { (x).mul_add(y, z) });
}
#[test]
fn egg_field_prefers_mul_add_for_commuted_sum() {
    // Same as the sum-of-product case but with the addend first: addition commutes, so the
    // fused form should still be found.
    let ty: Type = parse_quote!(f32);
    let egg = EggField::new(ty);
    let candidate = egg.wrap_expr(parse_quote! { z + (x * y) });
    let fused: SynExpr = parse_quote! { (x).mul_add(y, z) };
    assert_render(&candidate, fused);
}
#[test]
fn egg_field_trig_exp_simplify_and_codegen() {
    /// Asserts that `f(-x)` for an odd elementary function still renders with both the
    /// function name `func` and a minus sign, i.e. the sign information is not lost.
    #[track_caller]
    fn assert_keeps_fn_and_sign(expr: &EggExpr, func: &str) {
        let actual = rendered(expr);
        let tokens = quote! { #actual }.to_string();
        assert!(
            tokens.contains(func) && tokens.contains("-"),
            "got: {}",
            tokens
        );
    }

    let scalar_ty: Type = parse_quote!(f32);
    let field = EggField::new(scalar_ty);

    // Elementary functions at their special points fold to literal constants.
    let s0 = field.wrap_expr(parse_quote! { (0 as f32).sin() });
    assert_render(&s0, parse_quote! { 0 as f32 });
    let c0 = field.wrap_expr(parse_quote! { (0 as f32).cos() });
    assert_render(&c0, parse_quote! { 1 as f32 });
    let e0 = field.wrap_expr(parse_quote! { (0 as f32).exp() });
    assert_render(&e0, parse_quote! { 1 as f32 });
    let l1 = field.wrap_expr(parse_quote! { (1 as f32).ln() });
    assert_render(&l1, parse_quote! { 0 as f32 });

    // sin is odd. The intended form is `-((x).sin())`, but only the presence of the function
    // and a sign is checked, so equivalent renderings are also accepted.
    let sneg = field.wrap_expr(parse_quote! { ( -x ).sin() });
    assert_keeps_fn_and_sign(&sneg, "sin");

    let sq0 = field.wrap_expr(parse_quote! { (0 as f32).sqrt() });
    assert_render(&sq0, parse_quote! { 0 as f32 });
    let sh0 = field.wrap_expr(parse_quote! { (0 as f32).sinh() });
    assert_render(&sh0, parse_quote! { 0 as f32 });
    let ch0 = field.wrap_expr(parse_quote! { (0 as f32).cosh() });
    assert_render(&ch0, parse_quote! { 1 as f32 });
    let th0 = field.wrap_expr(parse_quote! { (0 as f32).tanh() });
    assert_render(&th0, parse_quote! { 0 as f32 });

    // sinh and tanh are odd as well.
    let shneg = field.wrap_expr(parse_quote! { ( -x ).sinh() });
    assert_keeps_fn_and_sign(&shneg, "sinh");
    let thneg = field.wrap_expr(parse_quote! { ( -x ).tanh() });
    assert_keeps_fn_and_sign(&thneg, "tanh");
}
#[test]
fn trig_pythagorean_sin_cos() {
    // sin²(x) + cos²(x) is the Pythagorean identity and should fold to the constant one.
    let ty: Type = parse_quote!(f32);
    let egg = EggField::new(ty);
    let identity = egg.wrap_expr(parse_quote! { (x).sin().powi(2) + (x).cos().powi(2) });
    let one: SynExpr = parse_quote! { 1 as f32 };
    assert_render(&identity, one);
}
#[test]
fn tan_to_sin_over_cos() {
    // Either rendering is accepted: `tan` kept as-is, or rewritten in terms of `sin` and `cos`.
    let ty: Type = parse_quote!(f32);
    let egg = EggField::new(ty);
    let printed = egg
        .wrap_expr(parse_quote! { (x).tan() })
        .to_token_stream()
        .to_string();
    let kept_tan = printed.contains("tan");
    let as_ratio = printed.contains("sin") && printed.contains("cos");
    assert!(kept_tan || as_ratio, "got: {}", printed);
}
#[test]
fn egg_field_field_division() {
let scalar_ty: Type = parse_quote!(f32);
let field = EggField::new(scalar_ty);
let six = field.int(6);
let three = field.int(3);
let quotient = abstalg::Field::div(&field, &six, &three);
match quotient.coeff.kind() {
CoeffExprKind::Quotient(num, denom) => {
match num.kind() {
CoeffExprKind::Literal(r) => {
assert_eq!(r, &BigRational::from_integer(BigInt::from(6)));
}
other => panic!("expected literal numerator, got {other:?}"),
}
match denom.kind() {
CoeffExprKind::Literal(r) => {
assert_eq!(r, &BigRational::from_integer(BigInt::from(3)));
}
other => panic!("expected literal denominator, got {other:?}"),
}
}
other => panic!("expected quotient coefficient, got {other:?}"),
}
}
#[test]
fn egg_field_field_inverse() {
let scalar_ty: Type = parse_quote!(f32);
let field = EggField::new(scalar_ty.clone());
let five = field.int(5);
let inv_five = abstalg::Field::inv(&field, &five);
match inv_five.coeff.kind() {
CoeffExprKind::Quotient(num, denom) => {
match num.kind() {
CoeffExprKind::Literal(r) => {
assert_eq!(r, &BigRational::from_integer(BigInt::from(1)));
}
other => panic!("expected literal numerator, got {other:?}"),
}
match denom.kind() {
CoeffExprKind::Literal(r) => {
assert_eq!(r, &BigRational::from_integer(BigInt::from(5)));
}
other => panic!("expected literal denominator, got {other:?}"),
}
}
other => panic!("expected quotient coefficient, got {other:?}"),
}
}
#[test]
#[should_panic(expected = "attempted to divide by zero element in EggField")]
fn egg_field_division_by_zero_panics() {
    // Dividing by the additive identity must panic instead of producing a quotient.
    let ty: Type = parse_quote!(f32);
    let egg = EggField::new(ty);
    let one = egg.int(1);
    let zero = abstalg::CommuntativeMonoid::zero(&egg);
    let _ = abstalg::Field::div(&egg, &one, &zero);
}
#[test]
fn egg_field_pow_square_then_half_cancels() {
    // NOTE(review): for real `x`, `(x²)^(1/2)` is `|x|`, not `x`
    // (e.g. `(-2.0f32).powi(2).powf(0.5) == 2.0`). This expectation pins a rewrite that is
    // only sound for `x >= 0`; confirm the rule set intends that domain restriction.
    let scalar_ty: Type = parse_quote!(f32);
    let field = EggField::new(scalar_ty);
    let expr = field.wrap_expr(parse_quote! { ((x).powi(2)).powf(0.5) });
    assert_render(&expr, parse_quote! { x });
}
#[test]
fn egg_field_pow_half_then_square_cancels() {
    // NOTE(review): `sqrt(x)²` equals `x` only for `x >= 0`; for negative `x` the original
    // expression is NaN, so rewriting it to `x` changes results on that domain. Confirm this
    // is intended.
    let scalar_ty: Type = parse_quote!(f32);
    let field = EggField::new(scalar_ty);
    let expr = field.wrap_expr(parse_quote! { ((x).sqrt()).powi(2) });
    assert_render(&expr, parse_quote! { x });
}
#[test]
fn egg_field_pow_nested_integer_exponent_folds() {
    // Nested integer powers fold by multiplying exponents: (x³)² = x⁶.
    // The output is compared as a parsed syntax tree rather than a raw token string, so
    // incidental token grouping/spacing in the emitted stream cannot cause a false failure.
    let scalar_ty: Type = parse_quote!(f32);
    let field = EggField::new(scalar_ty);
    let expr = field.wrap_expr(parse_quote! { ((x).powi(3)).powi(2) });
    assert_render(&expr, parse_quote! { (x).powi(6) });
}
#[test]
fn egg_field_horner_simple_polynomial_to_mul_add() {
    // A dense cubic is expected to come out with at least one fused `mul_add` call.
    let ty: Type = parse_quote!(f32);
    let egg = EggField::new(ty);
    let cubic = egg.wrap_expr(parse_quote! { x * x * x + 2 * x * x + 3 * x + 4 });
    let printed = cubic.to_token_stream().to_string();
    assert!(printed.contains("mul_add"), "got: {}", printed);
}
#[test]
fn pow_neg_square_simplifies() {
    // Squaring a negated `sinh` inside a negated product: the output must still mention
    // "mul", the `sinh` call, a `powi`, and a minus sign.
    let ty: Type = parse_quote!(f32);
    let egg = EggField::new(ty);
    let expr = egg.wrap_expr(parse_quote! { -((ev.e2) * ((-((a._1).sinh())).powi(2))) });
    let printed = expr.to_token_stream().to_string();
    let required = ["mul", "sinh", "powi", "-"];
    assert!(
        required.iter().all(|&needle| printed.contains(needle)),
        "got: {}",
        printed
    );
}
#[test]
fn pow_neg_odd_simplifies() {
    // An odd power is an odd function, so the negation is pulled outside: (-x)³ = -(x³).
    let ty: Type = parse_quote!(f32);
    let egg = EggField::new(ty);
    let cube = egg.wrap_expr(parse_quote! { ((-x).powi(3)) });
    let want: SynExpr = parse_quote! { -((x).powi(3)) };
    assert_render(&cube, want);
}
}