use crate::error::EmlError;
use crate::eval::EvalCtx;
use crate::named_const::NamedConst;
use crate::tree::{EmlNode, EmlTree};
use std::fmt;
use std::sync::OnceLock;
/// Variable index reserved as the "match anything" placeholder inside the
/// sin/cos pattern templates. It sits halfway through the `usize` range so it
/// is not expected to collide with the index of a real input variable.
const WILDCARD_VAR: usize = usize::MAX >> 1;
/// Expression tree obtained by lowering an [`EmlTree`] into ordinary
/// arithmetic and elementary functions.
///
/// The numeric meaning of every variant is the one implemented by
/// [`LoweredOp::eval`]. With the `serde` feature enabled the type is
/// (de)serializable and variant names are written in `snake_case`.
// NOTE(review): non-self-describing serde formats may encode the variant
// index, so reordering variants could break previously stored data — confirm
// before reordering.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
pub enum LoweredOp {
/// Numeric literal.
Const(f64),
/// Input variable, identified by its index into the evaluation slice.
Var(usize),
/// `lhs + rhs`.
Add(Box<LoweredOp>, Box<LoweredOp>),
/// `lhs - rhs`.
Sub(Box<LoweredOp>, Box<LoweredOp>),
/// `lhs * rhs`.
Mul(Box<LoweredOp>, Box<LoweredOp>),
/// `lhs / rhs` (plain `f64` division, no zero check).
Div(Box<LoweredOp>, Box<LoweredOp>),
/// Natural exponential `e^x`.
Exp(Box<LoweredOp>),
/// Natural logarithm.
Ln(Box<LoweredOp>),
/// Sine of an angle in radians.
Sin(Box<LoweredOp>),
/// Cosine of an angle in radians.
Cos(Box<LoweredOp>),
/// `base ^ exponent`, evaluated with `f64::powf`.
Pow(Box<LoweredOp>, Box<LoweredOp>),
/// Arithmetic negation.
Neg(Box<LoweredOp>),
/// Tangent.
Tan(Box<LoweredOp>),
/// Hyperbolic sine.
Sinh(Box<LoweredOp>),
/// Hyperbolic cosine.
Cosh(Box<LoweredOp>),
/// Hyperbolic tangent.
Tanh(Box<LoweredOp>),
/// Inverse sine.
Arcsin(Box<LoweredOp>),
/// Inverse cosine.
Arccos(Box<LoweredOp>),
/// Inverse tangent.
Arctan(Box<LoweredOp>),
/// Inverse hyperbolic sine.
Arcsinh(Box<LoweredOp>),
/// Inverse hyperbolic cosine.
Arccosh(Box<LoweredOp>),
/// Inverse hyperbolic tangent.
Arctanh(Box<LoweredOp>),
/// Symbolic named constant; evaluates to `NamedConst::value()`.
NamedConst(NamedConst),
}
/// One instruction of the postfix (stack-machine) program produced by
/// [`LoweredOp::to_oxiblas_ops`] and executed by [`LoweredOp::eval_ops`].
///
/// `Const`/`Var` push a value. Unary operators pop one value and push the
/// result. Binary operators pop the right operand first, then the left one,
/// and push `lhs <op> rhs`.
#[derive(Clone, Debug, PartialEq)]
pub enum OxiOp {
/// Push a literal.
Const(f64),
/// Push the input variable at this index (NaN if the index is out of range).
Var(usize),
/// Push `lhs + rhs`.
Add,
/// Push `lhs - rhs`.
Sub,
/// Push `lhs * rhs`.
Mul,
/// Push `lhs / rhs`.
Div,
/// Push `-x`.
Neg,
/// Push `exp(x)`.
Exp,
/// Push `ln(x)`.
Ln,
/// Push `sin(x)`.
Sin,
/// Push `cos(x)`.
Cos,
/// Push `lhs.powf(rhs)`.
Pow,
/// Push `tan(x)`.
Tan,
/// Push `sinh(x)`.
Sinh,
/// Push `cosh(x)`.
Cosh,
/// Push `tanh(x)`.
Tanh,
/// Push `asin(x)`.
Arcsin,
/// Push `acos(x)`.
Arccos,
/// Push `atan(x)`.
Arctan,
/// Push `asinh(x)`.
Arcsinh,
/// Push `acosh(x)`.
Arccosh,
/// Push `atanh(x)`.
Arctanh,
}
pub use crate::lower_interval::IntervalLO;
impl EmlTree {
    /// Lowers the whole tree, starting at its root, into a [`LoweredOp`].
    pub fn lower(&self) -> LoweredOp {
        lower_node(&self.root)
    }

    /// Evaluates the tree on real inputs through its lowered, simplified form.
    ///
    /// The tree is lowered, simplified, compiled to a postfix [`OxiOp`]
    /// program and executed against the variables held in `ctx`. Because the
    /// stack machine yields NaN for a missing variable, a context that is too
    /// short also surfaces as the NaN error below.
    ///
    /// # Errors
    /// Returns [`EmlError::NanEncountered`] when the computed value is NaN.
    pub fn eval_real_lowered(&self, ctx: &EvalCtx) -> Result<f64, EmlError> {
        let program = self.lower().simplify().to_oxiblas_ops();
        match LoweredOp::eval_ops(&program, ctx.as_slice()) {
            value if value.is_nan() => Err(EmlError::NanEncountered),
            value => Ok(value),
        }
    }
}
/// Lowers a single EML node into an equivalent [`LoweredOp`] expression.
///
/// `One` is the constant `1` and `eml(l, r)` means `exp(l) - ln(r)`. Before
/// falling back to that literal expansion, the node is matched against known
/// compound shapes so that the result uses the named operation:
///
/// * canonical `sin` / `cos` templates            -> `Sin(x)` / `Cos(x)`
/// * `eml(1, 1) = e - ln 1`                       -> `Const(e)`
/// * `eml(ln x, 1) = exp(ln x)`                   -> `x`
/// * `eml(x, 1) = exp(x)`                         -> `Exp(x)`
/// * `eml(1, eml(eml(1, x), 1))`                  -> `Ln(x)`
/// * `eml(1, eml(x, 1)) = e - ln(exp x)`          -> `e - x`
/// * `eml(ln x, eml(y, 1)) = exp(ln x) - ln(exp y)` -> `x - y`
fn lower_node(node: &EmlNode) -> LoweredOp {
    let (left, right) = match node {
        EmlNode::One => return LoweredOp::Const(1.0),
        EmlNode::Var(i) => return LoweredOp::Var(*i),
        EmlNode::Eml { left, right } => (left, right),
    };

    if let Some(inner) = match_sin_structure(node) {
        return LoweredOp::Sin(Box::new(lower_node(&inner)));
    }
    if let Some(inner) = match_cos_structure(node) {
        return LoweredOp::Cos(Box::new(lower_node(&inner)));
    }

    let left_is_one = matches!(left.as_ref(), EmlNode::One);
    let right_is_one = matches!(right.as_ref(), EmlNode::One);

    // eml(1, 1) = e - ln(1) = e. This must be tested before the generic
    // `eml(x, 1)` rule below; otherwise that rule claims it as exp(1) and this
    // branch can never be reached.
    if left_is_one && right_is_one {
        return LoweredOp::Const(std::f64::consts::E);
    }

    if right_is_one {
        // exp(ln x) collapses to x.
        if let Some(inner) = match_ln_structure(left) {
            return lower_node(&inner);
        }
        return LoweredOp::Exp(Box::new(lower_node(left)));
    }

    if left_is_one {
        // eml(1, eml(eml(1, x), 1)) is the EML encoding of ln(x).
        if let Some(inner) = match_ln_of_right(right) {
            return LoweredOp::Ln(Box::new(lower_node(&inner)));
        }
        // eml(1, eml(x, 1)) = e - ln(exp(x)) = e - x.
        if let EmlNode::Eml {
            left: inner_l,
            right: inner_r,
        } = right.as_ref()
        {
            if matches!(inner_r.as_ref(), EmlNode::One) {
                return LoweredOp::Sub(
                    Box::new(LoweredOp::Const(std::f64::consts::E)),
                    Box::new(lower_node(inner_l)),
                );
            }
        }
    }

    // eml(ln x, eml(y, 1)) = exp(ln x) - ln(exp y) = x - y.
    if let Some(x_inner) = match_ln_structure(left) {
        if let EmlNode::Eml {
            left: y_node,
            right: y_one,
        } = right.as_ref()
        {
            if matches!(y_one.as_ref(), EmlNode::One) {
                return LoweredOp::Sub(
                    Box::new(lower_node(&x_inner)),
                    Box::new(lower_node(y_node)),
                );
            }
        }
    }

    // No shortcut applies: expand the primitive literally.
    LoweredOp::Sub(
        Box::new(LoweredOp::Exp(Box::new(lower_node(left)))),
        Box::new(LoweredOp::Ln(Box::new(lower_node(right)))),
    )
}
fn match_ln_structure(node: &EmlNode) -> Option<EmlNode> {
if let EmlNode::Eml { left, right } = node {
if !matches!(left.as_ref(), EmlNode::One) {
return None;
}
if let EmlNode::Eml {
left: mid_l,
right: mid_r,
} = right.as_ref()
{
if !matches!(mid_r.as_ref(), EmlNode::One) {
return None;
}
if let EmlNode::Eml {
left: inner_l,
right: inner_r,
} = mid_l.as_ref()
{
if matches!(inner_l.as_ref(), EmlNode::One) {
return Some(inner_r.as_ref().clone());
}
}
}
}
None
}
/// Recognises `eml(eml(1, x), 1)`, which is the right operand of the
/// `ln(x)` encoding `eml(1, eml(eml(1, x), 1))`, and returns a clone of `x`.
fn match_ln_of_right(right: &EmlNode) -> Option<EmlNode> {
    let EmlNode::Eml {
        left: outer_arg,
        right: outer_one,
    } = right
    else {
        return None;
    };
    if !matches!(outer_one.as_ref(), EmlNode::One) {
        return None;
    }
    match outer_arg.as_ref() {
        EmlNode::Eml { left: one, right: x } if matches!(one.as_ref(), EmlNode::One) => {
            Some(EmlNode::clone(x))
        }
        _ => None,
    }
}
fn sin_template() -> &'static EmlNode {
static TEMPLATE: OnceLock<EmlNode> = OnceLock::new();
TEMPLATE.get_or_init(|| {
let placeholder = EmlTree::var(WILDCARD_VAR);
let tree = crate::canonical::Canonical::sin(&placeholder);
(*tree.root).clone()
})
}
fn cos_template() -> &'static EmlNode {
static TEMPLATE: OnceLock<EmlNode> = OnceLock::new();
TEMPLATE.get_or_init(|| {
let placeholder = EmlTree::var(WILDCARD_VAR);
let tree = crate::canonical::Canonical::cos(&placeholder);
(*tree.root).clone()
})
}
/// Structurally matches `candidate` against `template`, in which
/// `Var(WILDCARD_VAR)` stands for an arbitrary subtree.
///
/// The first subtree met at a wildcard position is stored in `captured`.
/// Every later wildcard occurrence must be structurally equal to that stored
/// subtree. Children are visited left before right, with short-circuiting on
/// the first mismatch.
fn unify_with_wildcard<'a>(
    candidate: &'a EmlNode,
    template: &EmlNode,
    captured: &mut Option<&'a EmlNode>,
) -> bool {
    let is_wildcard = matches!(template, EmlNode::Var(idx) if *idx == WILDCARD_VAR);
    if is_wildcard {
        if let Some(earlier) = *captured {
            return nodes_structurally_equal(earlier, candidate);
        }
        *captured = Some(candidate);
        return true;
    }

    match (template, candidate) {
        (EmlNode::One, EmlNode::One) => true,
        (EmlNode::Var(expected), EmlNode::Var(actual)) => expected == actual,
        (
            EmlNode::Eml {
                left: t_left,
                right: t_right,
            },
            EmlNode::Eml {
                left: c_left,
                right: c_right,
            },
        ) => {
            unify_with_wildcard(c_left.as_ref(), t_left.as_ref(), captured)
                && unify_with_wildcard(c_right.as_ref(), t_right.as_ref(), captured)
        }
        _ => false,
    }
}
/// Returns `true` when `a` and `b` have identical shape and identical
/// variable indices.
///
/// The walk uses an explicit work list instead of recursion and stops at the
/// first mismatching pair.
fn nodes_structurally_equal(a: &EmlNode, b: &EmlNode) -> bool {
    let mut pending: Vec<(&EmlNode, &EmlNode)> = vec![(a, b)];
    while let Some(pair) = pending.pop() {
        match pair {
            (EmlNode::One, EmlNode::One) => {}
            (EmlNode::Var(i), EmlNode::Var(j)) if i == j => {}
            (
                EmlNode::Eml {
                    left: a_left,
                    right: a_right,
                },
                EmlNode::Eml {
                    left: b_left,
                    right: b_right,
                },
            ) => {
                pending.push((a_right.as_ref(), b_right.as_ref()));
                pending.push((a_left.as_ref(), b_left.as_ref()));
            }
            _ => return false,
        }
    }
    true
}
fn match_sin_structure(node: &EmlNode) -> Option<EmlNode> {
let mut captured: Option<&EmlNode> = None;
if unify_with_wildcard(node, sin_template(), &mut captured) {
captured.cloned()
} else {
None
}
}
fn match_cos_structure(node: &EmlNode) -> Option<EmlNode> {
let mut captured: Option<&EmlNode> = None;
if unify_with_wildcard(node, cos_template(), &mut captured) {
captured.cloned()
} else {
None
}
}
impl LoweredOp {
pub fn to_oxiblas_ops(&self) -> Vec<OxiOp> {
let mut ops = Vec::new();
self.collect_ops(&mut ops);
ops
}
fn collect_ops(&self, ops: &mut Vec<OxiOp>) {
match self {
Self::Const(c) => ops.push(OxiOp::Const(*c)),
Self::NamedConst(nc) => ops.push(OxiOp::Const(nc.value())),
Self::Var(i) => ops.push(OxiOp::Var(*i)),
Self::Add(a, b) => {
a.collect_ops(ops);
b.collect_ops(ops);
ops.push(OxiOp::Add);
}
Self::Sub(a, b) => {
a.collect_ops(ops);
b.collect_ops(ops);
ops.push(OxiOp::Sub);
}
Self::Mul(a, b) => {
a.collect_ops(ops);
b.collect_ops(ops);
ops.push(OxiOp::Mul);
}
Self::Div(a, b) => {
a.collect_ops(ops);
b.collect_ops(ops);
ops.push(OxiOp::Div);
}
Self::Exp(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Exp);
}
Self::Ln(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Ln);
}
Self::Sin(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Sin);
}
Self::Cos(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Cos);
}
Self::Pow(a, b) => {
a.collect_ops(ops);
b.collect_ops(ops);
ops.push(OxiOp::Pow);
}
Self::Neg(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Neg);
}
Self::Tan(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Tan);
}
Self::Sinh(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Sinh);
}
Self::Cosh(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Cosh);
}
Self::Tanh(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Tanh);
}
Self::Arcsin(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Arcsin);
}
Self::Arccos(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Arccos);
}
Self::Arctan(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Arctan);
}
Self::Arcsinh(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Arcsinh);
}
Self::Arccosh(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Arccosh);
}
Self::Arctanh(a) => {
a.collect_ops(ops);
ops.push(OxiOp::Arctanh);
}
}
}
pub fn eval_ops(ops: &[OxiOp], vars: &[f64]) -> f64 {
let mut stack: Vec<f64> = Vec::with_capacity(ops.len());
for op in ops {
match op {
OxiOp::Const(c) => stack.push(*c),
OxiOp::Var(i) => {
stack.push(vars.get(*i).copied().unwrap_or(f64::NAN));
}
OxiOp::Add => {
let b = stack.pop().unwrap_or(f64::NAN);
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a + b);
}
OxiOp::Sub => {
let b = stack.pop().unwrap_or(f64::NAN);
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a - b);
}
OxiOp::Mul => {
let b = stack.pop().unwrap_or(f64::NAN);
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a * b);
}
OxiOp::Div => {
let b = stack.pop().unwrap_or(f64::NAN);
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a / b);
}
OxiOp::Neg => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(-a);
}
OxiOp::Exp => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.exp());
}
OxiOp::Ln => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.ln());
}
OxiOp::Sin => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.sin());
}
OxiOp::Cos => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.cos());
}
OxiOp::Pow => {
let b = stack.pop().unwrap_or(f64::NAN);
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.powf(b));
}
OxiOp::Tan => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.tan());
}
OxiOp::Sinh => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.sinh());
}
OxiOp::Cosh => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.cosh());
}
OxiOp::Tanh => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.tanh());
}
OxiOp::Arcsin => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.asin());
}
OxiOp::Arccos => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.acos());
}
OxiOp::Arctan => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.atan());
}
OxiOp::Arcsinh => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.asinh());
}
OxiOp::Arccosh => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.acosh());
}
OxiOp::Arctanh => {
let a = stack.pop().unwrap_or(f64::NAN);
stack.push(a.atanh());
}
}
}
stack.pop().unwrap_or(f64::NAN)
}
pub fn eval_batch(&self, data: &[Vec<f64>]) -> Vec<f64> {
let ops = self.to_oxiblas_ops();
#[cfg(feature = "simd")]
{
crate::simd_eval::eval_batch_simd(&ops, data)
}
#[cfg(not(feature = "simd"))]
{
Self::eval_batch_scalar_from_ops(&ops, data)
}
}
pub fn eval_batch_scalar_from_ops(ops: &[OxiOp], data: &[Vec<f64>]) -> Vec<f64> {
data.iter().map(|row| Self::eval_ops(ops, row)).collect()
}
pub fn eval_batch_scalar(&self, data: &[Vec<f64>]) -> Vec<f64> {
let ops = self.to_oxiblas_ops();
Self::eval_batch_scalar_from_ops(&ops, data)
}
pub fn structural_hash<H: std::hash::Hasher>(&self, state: &mut H) {
use std::hash::Hash;
match self {
Self::Const(c) => {
0u8.hash(state);
c.to_bits().hash(state);
}
Self::NamedConst(nc) => {
0u8.hash(state);
nc.value().to_bits().hash(state);
}
Self::Var(i) => {
1u8.hash(state);
i.hash(state);
}
Self::Add(a, b) => {
a.structural_hash(state);
b.structural_hash(state);
2u8.hash(state);
}
Self::Sub(a, b) => {
a.structural_hash(state);
b.structural_hash(state);
3u8.hash(state);
}
Self::Mul(a, b) => {
a.structural_hash(state);
b.structural_hash(state);
4u8.hash(state);
}
Self::Div(a, b) => {
a.structural_hash(state);
b.structural_hash(state);
5u8.hash(state);
}
Self::Exp(a) => {
a.structural_hash(state);
6u8.hash(state);
}
Self::Ln(a) => {
a.structural_hash(state);
7u8.hash(state);
}
Self::Sin(a) => {
a.structural_hash(state);
8u8.hash(state);
}
Self::Cos(a) => {
a.structural_hash(state);
9u8.hash(state);
}
Self::Pow(a, b) => {
a.structural_hash(state);
b.structural_hash(state);
10u8.hash(state);
}
Self::Neg(a) => {
a.structural_hash(state);
11u8.hash(state);
}
Self::Tan(a) => {
a.structural_hash(state);
12u8.hash(state);
}
Self::Sinh(a) => {
a.structural_hash(state);
13u8.hash(state);
}
Self::Cosh(a) => {
a.structural_hash(state);
14u8.hash(state);
}
Self::Tanh(a) => {
a.structural_hash(state);
15u8.hash(state);
}
Self::Arcsin(a) => {
a.structural_hash(state);
16u8.hash(state);
}
Self::Arccos(a) => {
a.structural_hash(state);
17u8.hash(state);
}
Self::Arctan(a) => {
a.structural_hash(state);
18u8.hash(state);
}
Self::Arcsinh(a) => {
a.structural_hash(state);
19u8.hash(state);
}
Self::Arccosh(a) => {
a.structural_hash(state);
20u8.hash(state);
}
Self::Arctanh(a) => {
a.structural_hash(state);
21u8.hash(state);
}
}
}
pub fn to_pretty(&self) -> String {
format!("{self}")
}
pub fn to_latex(&self) -> String {
fn render(op: &LoweredOp, top_level: bool) -> String {
match op {
LoweredOp::NamedConst(nc) => nc.to_latex().to_string(),
LoweredOp::Const(c) => {
if (*c - std::f64::consts::E).abs() < 1e-15 {
"e".to_string()
} else if (*c - std::f64::consts::PI).abs() < 1e-15 {
r"\pi".to_string()
} else if (*c - std::f64::consts::TAU).abs() < 1e-15 {
r"2\pi".to_string()
} else if (*c - (-1.0_f64)).abs() < 1e-15 {
"-1".to_string()
} else if (c - c.round()).abs() < 1e-10 && c.abs() < 1e15 {
format!("{}", *c as i64)
} else {
format!("{c:.6}")
}
}
LoweredOp::Var(i) => format!("x_{{{i}}}"),
LoweredOp::Add(a, b) => {
let inner = format!("{} + {}", render(a, false), render(b, false));
if top_level {
inner
} else {
format!("({inner})")
}
}
LoweredOp::Sub(a, b) => {
let inner = format!("{} - {}", render(a, false), render(b, false));
if top_level {
inner
} else {
format!("({inner})")
}
}
LoweredOp::Mul(a, b) => {
let inner = format!(r"{} \cdot {}", render(a, false), render(b, false));
if top_level {
inner
} else {
format!("({inner})")
}
}
LoweredOp::Div(a, b) => {
format!(r"\frac{{{}}}{{{}}}", render(a, true), render(b, true))
}
LoweredOp::Exp(a) => {
let arg = render(a, true);
format!("e^{{{arg}}}")
}
LoweredOp::Ln(a) => {
format!(r"\ln\left({}\right)", render(a, true))
}
LoweredOp::Sin(a) => {
format!(r"\sin\left({}\right)", render(a, true))
}
LoweredOp::Cos(a) => {
format!(r"\cos\left({}\right)", render(a, true))
}
LoweredOp::Pow(base, exp) => {
let b = render(base, false);
let e = render(exp, true);
format!("{b}^{{{e}}}")
}
LoweredOp::Neg(a) => {
let inner = render(a, false);
format!("-{inner}")
}
LoweredOp::Tan(a) => {
format!(r"\tan{{{}}}", render(a, true))
}
LoweredOp::Sinh(a) => {
format!(r"\sinh{{{}}}", render(a, true))
}
LoweredOp::Cosh(a) => {
format!(r"\cosh{{{}}}", render(a, true))
}
LoweredOp::Tanh(a) => {
format!(r"\tanh{{{}}}", render(a, true))
}
LoweredOp::Arcsin(a) => {
format!(r"\arcsin{{{}}}", render(a, true))
}
LoweredOp::Arccos(a) => {
format!(r"\arccos{{{}}}", render(a, true))
}
LoweredOp::Arctan(a) => {
format!(r"\arctan{{{}}}", render(a, true))
}
LoweredOp::Arcsinh(a) => {
format!(r"\operatorname{{arcsinh}}{{{}}}", render(a, true))
}
LoweredOp::Arccosh(a) => {
format!(r"\operatorname{{arccosh}}{{{}}}", render(a, true))
}
LoweredOp::Arctanh(a) => {
format!(r"\operatorname{{arctanh}}{{{}}}", render(a, true))
}
}
}
render(self, true)
}
pub fn eval(&self, vars: &[f64]) -> f64 {
match self {
Self::Const(c) => *c,
Self::NamedConst(nc) => nc.value(),
Self::Var(i) => vars[*i],
Self::Add(a, b) => a.eval(vars) + b.eval(vars),
Self::Sub(a, b) => a.eval(vars) - b.eval(vars),
Self::Mul(a, b) => a.eval(vars) * b.eval(vars),
Self::Div(a, b) => a.eval(vars) / b.eval(vars),
Self::Exp(a) => a.eval(vars).exp(),
Self::Ln(a) => a.eval(vars).ln(),
Self::Sin(a) => a.eval(vars).sin(),
Self::Cos(a) => a.eval(vars).cos(),
Self::Pow(a, b) => a.eval(vars).powf(b.eval(vars)),
Self::Neg(a) => -a.eval(vars),
Self::Tan(a) => a.eval(vars).tan(),
Self::Sinh(a) => a.eval(vars).sinh(),
Self::Cosh(a) => a.eval(vars).cosh(),
Self::Tanh(a) => a.eval(vars).tanh(),
Self::Arcsin(a) => a.eval(vars).asin(),
Self::Arccos(a) => a.eval(vars).acos(),
Self::Arctan(a) => a.eval(vars).atan(),
Self::Arcsinh(a) => a.eval(vars).asinh(),
Self::Arccosh(a) => a.eval(vars).acosh(),
Self::Arctanh(a) => a.eval(vars).atanh(),
}
}
}
impl fmt::Display for LoweredOp {
    /// Writes a plain-text rendering of the expression.
    ///
    /// The output follows these rules:
    /// * Binary operators are always wrapped in parentheses.
    /// * Functions print as `name(arg)`, variables as `x{i}`, and negation as
    ///   `-{arg}`.
    /// * Constants within `1e-15` of `e` or `π` print symbolically.
    /// * Constants within `1e-10` of an integer and with magnitude below
    ///   `1e15` print as that integer.
    /// * Any other constant prints with six decimal places.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NamedConst(nc) => write!(f, "{}", nc.to_pretty()),
            Self::Const(c) => {
                let c = *c;
                if (c - std::f64::consts::E).abs() < 1e-15 {
                    f.write_str("e")
                } else if (c - std::f64::consts::PI).abs() < 1e-15 {
                    // U+03C0 GREEK SMALL LETTER PI, written as an escape so the
                    // source stays ASCII and cannot be mis-decoded again.
                    f.write_str("\u{3c0}")
                } else if (c - c.round()).abs() < 1e-10 && c.abs() < 1e15 {
                    // Round rather than truncate: `2.99999999999 as i64` is 2.
                    write!(f, "{}", c.round() as i64)
                } else {
                    write!(f, "{c:.6}")
                }
            }
            Self::Var(i) => write!(f, "x{i}"),
            Self::Add(a, b) => write!(f, "({a} + {b})"),
            Self::Sub(a, b) => write!(f, "({a} - {b})"),
            Self::Mul(a, b) => write!(f, "({a} * {b})"),
            Self::Div(a, b) => write!(f, "({a} / {b})"),
            Self::Exp(a) => write!(f, "exp({a})"),
            Self::Ln(a) => write!(f, "ln({a})"),
            Self::Sin(a) => write!(f, "sin({a})"),
            Self::Cos(a) => write!(f, "cos({a})"),
            Self::Pow(a, b) => write!(f, "({a})^({b})"),
            Self::Neg(a) => write!(f, "-{a}"),
            Self::Tan(a) => write!(f, "tan({a})"),
            Self::Sinh(a) => write!(f, "sinh({a})"),
            Self::Cosh(a) => write!(f, "cosh({a})"),
            Self::Tanh(a) => write!(f, "tanh({a})"),
            Self::Arcsin(a) => write!(f, "arcsin({a})"),
            Self::Arccos(a) => write!(f, "arccos({a})"),
            Self::Arctan(a) => write!(f, "arctan({a})"),
            Self::Arcsinh(a) => write!(f, "arcsinh({a})"),
            Self::Arccosh(a) => write!(f, "arccosh({a})"),
            Self::Arctanh(a) => write!(f, "arctanh({a})"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Canonical;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    /// Boxed first input variable, used as an operand throughout.
    fn x0() -> Box<LoweredOp> {
        Box::new(LoweredOp::Var(0))
    }

    /// Boxed numeric literal.
    fn num(value: f64) -> Box<LoweredOp> {
        Box::new(LoweredOp::Const(value))
    }

    /// Runs `structural_hash` through a fresh `DefaultHasher`.
    fn digest(op: &LoweredOp) -> u64 {
        let mut hasher = DefaultHasher::new();
        op.structural_hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_lower_one() {
        assert_eq!(EmlTree::one().lower(), LoweredOp::Const(1.0));
    }

    #[test]
    fn test_lower_var() {
        assert_eq!(EmlTree::var(0).lower(), LoweredOp::Var(0));
    }

    #[test]
    fn test_lower_exp() {
        let exp_x = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        assert_eq!(exp_x.lower(), LoweredOp::Exp(x0()));
    }

    #[test]
    fn test_lower_e_minus_x() {
        let one = EmlTree::one();
        // eml(1, eml(x, 1)) = e - ln(exp(x)) = e - x
        let exp_x = EmlTree::eml(&EmlTree::var(0), &one);
        let e_minus_x = EmlTree::eml(&one, &exp_x);
        let expected = LoweredOp::Sub(num(std::f64::consts::E), x0());
        assert_eq!(e_minus_x.lower(), expected);
    }

    #[test]
    fn test_lower_ln() {
        let one = EmlTree::one();
        // ln(x) = eml(1, eml(eml(1, x), 1))
        let inner = EmlTree::eml(&one, &EmlTree::var(0));
        let middle = EmlTree::eml(&inner, &one);
        let ln_x = EmlTree::eml(&one, &middle);
        assert_eq!(ln_x.lower(), LoweredOp::Ln(x0()));
    }

    #[test]
    fn test_lowered_eval() {
        let sum = LoweredOp::Add(x0(), num(3.0));
        assert!((sum.eval(&[2.0]) - 5.0).abs() < 1e-15);
    }

    #[test]
    fn test_pretty_print() {
        let product = LoweredOp::Mul(x0(), Box::new(LoweredOp::Var(1)));
        assert_eq!(product.to_pretty(), "(x0 * x1)");
    }

    #[test]
    fn test_simplify_exp_ln() {
        let exp_of_ln = LoweredOp::Exp(Box::new(LoweredOp::Ln(x0())));
        assert_eq!(exp_of_ln.simplify(), LoweredOp::Var(0));
    }

    #[test]
    fn test_simplify_constants() {
        let two_plus_three = LoweredOp::Add(num(2.0), num(3.0));
        assert_eq!(two_plus_three.simplify(), LoweredOp::Const(5.0));
    }

    #[test]
    fn test_to_oxiblas_ops_roundtrip() {
        let x = EmlTree::var(0);
        let run = |op: &LoweredOp, input: f64| LoweredOp::eval_ops(&op.to_oxiblas_ops(), &[input]);

        let result = run(&Canonical::exp(&x).lower(), 1.5_f64);
        assert!(
            (result - 1.5_f64.exp()).abs() < 1e-12,
            "exp roundtrip failed: {result}"
        );

        let result_ln = run(&Canonical::ln(&x).lower(), 2.0_f64);
        assert!(
            (result_ln - 2.0_f64.ln()).abs() < 1e-12,
            "ln roundtrip failed: {result_ln}"
        );

        let result_sin = run(&LoweredOp::Sin(x0()), std::f64::consts::PI / 6.0);
        assert!(
            (result_sin - 0.5_f64).abs() < 1e-9,
            "sin roundtrip failed: {result_sin}"
        );
    }

    #[test]
    fn test_eval_batch_scalar_matches_eval() {
        let lowered = Canonical::exp(&EmlTree::var(0)).lower();
        let data: Vec<Vec<f64>> = (0..100_u32)
            .map(|step| vec![f64::from(step) * 0.05])
            .collect();
        let batch_results = lowered.eval_batch_scalar(&data);
        assert_eq!(batch_results.len(), 100);
        for (row, result) in data.iter().zip(&batch_results) {
            let expected = lowered.eval(row);
            assert!(
                (result - expected).abs() < 1e-12,
                "mismatch at x={}: got {result}, expected {expected}",
                row[0]
            );
        }
    }

    #[test]
    fn test_structural_hash_differs() {
        let x = EmlTree::var(0);
        let exp_x = Canonical::exp(&x).lower().simplify();
        let ln_x = Canonical::ln(&x).lower().simplify();
        assert_ne!(
            digest(&exp_x),
            digest(&ln_x),
            "exp and ln should have different structural hashes"
        );
    }

    #[test]
    fn test_structural_hash_same_for_equiv() {
        let x = EmlTree::var(0);
        let first = Canonical::exp(&x).lower().simplify();
        let second = Canonical::exp(&x).lower().simplify();
        assert_eq!(
            digest(&first),
            digest(&second),
            "identical trees should have the same structural hash"
        );
    }

    #[test]
    fn latex_var() {
        assert_eq!(LoweredOp::Var(0).to_latex(), "x_{0}");
        assert_eq!(LoweredOp::Var(3).to_latex(), "x_{3}");
    }

    #[test]
    fn latex_const_pi() {
        assert_eq!(LoweredOp::Const(std::f64::consts::PI).to_latex(), r"\pi");
    }

    #[test]
    fn latex_const_e() {
        assert_eq!(LoweredOp::Const(std::f64::consts::E).to_latex(), "e");
    }

    #[test]
    fn latex_const_integer() {
        assert_eq!(LoweredOp::Const(2.0).to_latex(), "2");
        assert_eq!(LoweredOp::Const(-1.0).to_latex(), "-1");
    }

    #[test]
    fn latex_div() {
        let reciprocal = LoweredOp::Div(num(1.0), x0());
        assert_eq!(reciprocal.to_latex(), r"\frac{1}{x_{0}}");
    }

    #[test]
    fn latex_exp() {
        assert_eq!(LoweredOp::Exp(x0()).to_latex(), r"e^{x_{0}}");
    }

    #[test]
    fn latex_ln() {
        assert_eq!(LoweredOp::Ln(x0()).to_latex(), r"\ln\left(x_{0}\right)");
    }

    #[test]
    fn latex_sin_cos() {
        assert_eq!(LoweredOp::Sin(x0()).to_latex(), r"\sin\left(x_{0}\right)");
        assert_eq!(LoweredOp::Cos(x0()).to_latex(), r"\cos\left(x_{0}\right)");
    }

    #[test]
    fn latex_pow() {
        let square = LoweredOp::Pow(x0(), num(2.0));
        assert_eq!(square.to_latex(), "x_{0}^{2}");
    }

    #[test]
    fn latex_neg() {
        assert_eq!(LoweredOp::Neg(x0()).to_latex(), "-x_{0}");
    }

    #[test]
    fn latex_mul() {
        let scaled = LoweredOp::Mul(num(2.0), x0());
        assert_eq!(scaled.to_latex(), r"2 \cdot x_{0}");
    }

    #[test]
    fn latex_composite() {
        let tangent = LoweredOp::Div(
            Box::new(LoweredOp::Sin(x0())),
            Box::new(LoweredOp::Cos(x0())),
        );
        let latex = tangent.to_latex();
        for fragment in [r"\frac", r"\sin", r"\cos"] {
            assert!(latex.contains(fragment));
        }
    }
}