use crate::tree::EmlTree;
/// Canonical constructions of elementary constants and functions as [`EmlTree`]s.
///
/// Every construction is assembled only from `EmlTree::one()`, the binary
/// `EmlTree::eml(a, b)` node, and the trees passed in. The formulas in the item
/// docs below read `eml(a, b)` as `exp(a) - ln(b)`, so that `eml(x, 1) = exp(x)`
/// and `eml(1, 1) = e`.
///
/// Several constructions only make sense under complex evaluation with
/// principal-branch logarithms. This covers anything involving
/// [`Canonical::imag_unit`] and anything that takes the logarithm of a
/// non-positive intermediate. Real evaluation of such trees may fail.
pub struct Canonical;

impl Canonical {
    /// `exp(x)`, built as the single node `eml(x, 1) = exp(x) - ln(1)`.
    pub fn exp(x: &EmlTree) -> EmlTree {
        let one = EmlTree::one();
        EmlTree::eml(x, &one)
    }

    /// Natural logarithm `ln(x)`, built from three nested `eml` nodes:
    /// `eml(1, eml(eml(1, x), 1)) = e - ln(exp(e - ln(x))) = ln(x)`.
    ///
    /// Under complex evaluation the cancellation `ln(exp(z)) = z` holds only while
    /// `Im(z)` lies in the principal strip `(-π, π]`.
    pub fn ln(x: &EmlTree) -> EmlTree {
        let one = EmlTree::one();
        // e - ln(x)
        let inner = EmlTree::eml(&one, x);
        // exp(e - ln(x))
        let middle = EmlTree::eml(&inner, &one);
        // e - ln(exp(e - ln(x))) = ln(x)
        EmlTree::eml(&one, &middle)
    }

    /// Euler's number, built as `eml(1, 1) = exp(1) - ln(1) = e`.
    pub fn euler() -> EmlTree {
        let one = EmlTree::one();
        EmlTree::eml(&one, &one)
    }

    /// Negation `-x`, built as `eml(ln(e - x), exp(e)) = (e - x) - e`.
    ///
    /// For `x >= e` the inner logarithm receives a non-positive argument. Evaluating
    /// the tree then goes through the logarithm of a non-positive value.
    pub fn neg(x: &EmlTree) -> EmlTree {
        let e_minus_x = Self::e_minus(x);
        let ln_e_minus_x = Self::ln(&e_minus_x);
        let exp_e = Self::exp(&Self::euler());
        EmlTree::eml(&ln_e_minus_x, &exp_e)
    }

    /// Addition `x + y`, built as `x - (-y)`.
    pub fn add(x: &EmlTree, y: &EmlTree) -> EmlTree {
        Self::sub(x, &Self::neg(y))
    }

    /// Subtraction `x - y`, built as `eml(ln(x), exp(y)) = x - ln(exp(y))`.
    pub fn sub(x: &EmlTree, y: &EmlTree) -> EmlTree {
        let one = EmlTree::one();
        let ln_x = Self::ln(x);
        let exp_y = EmlTree::eml(y, &one);
        EmlTree::eml(&ln_x, &exp_y)
    }

    /// Multiplication `x * y`, built as `exp(ln(x) + ln(y))`.
    pub fn mul(x: &EmlTree, y: &EmlTree) -> EmlTree {
        let ln_x = Self::ln(x);
        let ln_y = Self::ln(y);
        Self::exp(&Self::add(&ln_x, &ln_y))
    }

    /// Division `x / y`, built as `exp(ln(x) - ln(y))`.
    pub fn div(x: &EmlTree, y: &EmlTree) -> EmlTree {
        let ln_x = Self::ln(x);
        let ln_y = Self::ln(y);
        Self::exp(&Self::sub(&ln_x, &ln_y))
    }

    /// Power `x^y`, built as `exp(y * ln(x))`.
    pub fn pow(x: &EmlTree, y: &EmlTree) -> EmlTree {
        let ln_x = Self::ln(x);
        Self::exp(&Self::mul(y, &ln_x))
    }

    /// The constant π, built as `ln(-1) / i`.
    ///
    /// `ln(-1)` by itself is `iπ` on the principal branch, an imaginary number, not π.
    /// Dividing by [`Canonical::imag_unit`] (itself `exp(ln(-1) / 2)`, from the same
    /// `ln(-1)` construction) cancels the `i`. If an evaluator produced `-iπ` for that
    /// logarithm instead, the unit would become `-i` and the quotient would still
    /// be π.
    pub fn pi() -> EmlTree {
        let ln_neg_one = Self::ln(&Self::neg_one());
        Self::div(&ln_neg_one, &Self::imag_unit())
    }

    /// `sin(x) = (exp(ix) - exp(-ix)) / (2i)`.
    pub fn sin(x: &EmlTree) -> EmlTree {
        let i = Self::imag_unit();
        let ix = Self::mul(&i, x);
        let exp_ix = Self::exp(&ix);
        let exp_neg_ix = Self::exp(&Self::neg(&ix));
        let diff = Self::sub(&exp_ix, &exp_neg_ix);
        let two_i = Self::mul(&Self::nat(2), &i);
        Self::div(&diff, &two_i)
    }

    /// `cos(x) = (exp(ix) + exp(-ix)) / 2`.
    pub fn cos(x: &EmlTree) -> EmlTree {
        let ix = Self::mul(&Self::imag_unit(), x);
        let exp_ix = Self::exp(&ix);
        let exp_neg_ix = Self::exp(&Self::neg(&ix));
        let sum = Self::add(&exp_ix, &exp_neg_ix);
        Self::div(&sum, &Self::nat(2))
    }

    /// `tan(x) = sin(x) / cos(x)`.
    pub fn tan(x: &EmlTree) -> EmlTree {
        Self::div(&Self::sin(x), &Self::cos(x))
    }

    /// `arcsin(x) = -i * ln(ix + sqrt(1 - x²))`.
    pub fn arcsin(x: &EmlTree) -> EmlTree {
        let i = Self::imag_unit();
        let one = EmlTree::one();
        let ix = Self::mul(&i, x);
        let x_sq = Self::square(x);
        let one_minus_x_sq = Self::sub(&one, &x_sq);
        let sqrt_part = Self::sqrt(&one_minus_x_sq);
        Self::neg(&Self::mul(&i, &Self::ln(&Self::add(&ix, &sqrt_part))))
    }

    /// `arccos(x) = -i * ln(x + i * sqrt(1 - x²))`.
    pub fn arccos(x: &EmlTree) -> EmlTree {
        let i = Self::imag_unit();
        let one = EmlTree::one();
        let x_sq = Self::square(x);
        let one_minus_x_sq = Self::sub(&one, &x_sq);
        let sqrt_part = Self::sqrt(&one_minus_x_sq);
        let i_sqrt = Self::mul(&i, &sqrt_part);
        Self::neg(&Self::mul(&i, &Self::ln(&Self::add(x, &i_sqrt))))
    }

    /// `arctan(x) = -(i/2) * ln((1 + ix) / (1 - ix))`.
    pub fn arctan(x: &EmlTree) -> EmlTree {
        let i = Self::imag_unit();
        let one = EmlTree::one();
        let two = Self::nat(2);
        let ix = Self::mul(&i, x);
        let numerator = Self::add(&one, &ix);
        let denominator = Self::sub(&one, &ix);
        let neg_i_half = Self::neg(&Self::mul(&i, &Self::reciprocal(&two)));
        Self::mul(&neg_i_half, &Self::ln(&Self::div(&numerator, &denominator)))
    }

    /// `sinh(x) = (exp(x) - exp(-x)) / 2`.
    pub fn sinh(x: &EmlTree) -> EmlTree {
        let exp_x = Self::exp(x);
        let exp_neg_x = Self::exp(&Self::neg(x));
        Self::div(&Self::sub(&exp_x, &exp_neg_x), &Self::nat(2))
    }

    /// `cosh(x) = (exp(x) + exp(-x)) / 2`.
    pub fn cosh(x: &EmlTree) -> EmlTree {
        let exp_x = Self::exp(x);
        let exp_neg_x = Self::exp(&Self::neg(x));
        Self::div(&Self::add(&exp_x, &exp_neg_x), &Self::nat(2))
    }

    /// `tanh(x) = sinh(x) / cosh(x)`.
    pub fn tanh(x: &EmlTree) -> EmlTree {
        Self::div(&Self::sinh(x), &Self::cosh(x))
    }

    /// `arcsinh(x) = ln(x + sqrt(x² + 1))`.
    pub fn arcsinh(x: &EmlTree) -> EmlTree {
        let one = EmlTree::one();
        let x_sq = Self::square(x);
        Self::ln(&Self::add(x, &Self::sqrt(&Self::add(&x_sq, &one))))
    }

    /// `arccosh(x) = ln(x + sqrt(x² - 1))` (real-valued for `x >= 1`).
    pub fn arccosh(x: &EmlTree) -> EmlTree {
        let one = EmlTree::one();
        let x_sq = Self::square(x);
        Self::ln(&Self::add(x, &Self::sqrt(&Self::sub(&x_sq, &one))))
    }

    /// `arctanh(x) = (1/2) * ln((1 + x) / (1 - x))`.
    pub fn arctanh(x: &EmlTree) -> EmlTree {
        let one = EmlTree::one();
        let half = Self::reciprocal(&Self::nat(2));
        let numerator = Self::add(&one, x);
        let denominator = Self::sub(&one, x);
        Self::mul(&half, &Self::ln(&Self::div(&numerator, &denominator)))
    }

    /// `x²`, built as `pow(x, 2) = exp(2 * ln(x))`.
    pub fn square(x: &EmlTree) -> EmlTree {
        Self::pow(x, &Self::nat(2))
    }

    /// `sqrt(x)`, built as `pow(x, 1/2) = exp(ln(x) / 2)`.
    pub fn sqrt(x: &EmlTree) -> EmlTree {
        let half = Self::reciprocal(&Self::nat(2));
        Self::pow(x, &half)
    }

    /// `|x|`, built as `sqrt(x²)`.
    ///
    /// For negative `x` the intermediate `ln(x)` is complex, so the result depends
    /// on principal-branch complex evaluation.
    pub fn abs(x: &EmlTree) -> EmlTree {
        Self::sqrt(&Self::square(x))
    }

    /// The constant `-1`, built as `neg(1)`.
    pub fn neg_one() -> EmlTree {
        Self::neg(&EmlTree::one())
    }

    /// The constant `-2`, built as `neg(1 + 1)`.
    pub fn neg_two() -> EmlTree {
        Self::neg(&Self::nat(2))
    }

    /// The imaginary unit `i = exp(ln(-1) / 2)`, taking `ln(-1) = iπ` on the
    /// principal branch.
    ///
    /// The value is not real, so this tree needs complex evaluation.
    pub fn imag_unit() -> EmlTree {
        let half = Self::reciprocal(&Self::nat(2));
        let ln_neg_one = Self::ln(&Self::neg_one());
        Self::exp(&Self::mul(&half, &ln_neg_one))
    }

    /// `e - x`, built as `eml(1, exp(x)) = e - ln(exp(x))`; the building block of `neg`.
    fn e_minus(x: &EmlTree) -> EmlTree {
        let one = EmlTree::one();
        let exp_x = EmlTree::eml(x, &one);
        EmlTree::eml(&one, &exp_x)
    }

    /// `1 / x`, built as `exp(-ln(x))`.
    pub fn reciprocal(x: &EmlTree) -> EmlTree {
        let ln_x = Self::ln(x);
        Self::exp(&Self::neg(&ln_x))
    }

    /// The natural number `n`, built as the left-nested sum `((1 + 1) + 1) + ...`.
    ///
    /// The tree grows linearly with `n` (one `add` per unit beyond the first).
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`; use [`Canonical::zero`] for zero.
    pub fn nat(n: u64) -> EmlTree {
        assert!(n >= 1, "nat(0) not supported; use ln(1) for zero");
        let one = EmlTree::one();
        (1..n).fold(one.clone(), |acc, _| Self::add(&acc, &one))
    }

    /// The constant `0`, built as `ln(1)`.
    pub fn zero() -> EmlTree {
        Self::ln(&EmlTree::one())
    }
}
#[cfg(test)]
mod tests {
    //! Numerical checks for the `Canonical` constructions.
    //!
    //! Each test builds a tree, evaluates it through `EvalCtx`, and compares the
    //! outcome with a reference value within a tolerance. The checks for `tan`,
    //! `tanh`, `arctan`, `arcsin` and `arccos` only inspect the value when real
    //! evaluation succeeds; an evaluation error makes them pass without checking.
    use super::*;
    use crate::eval::EvalCtx;

    /// `exp` is one `eml` node and matches `f64::exp`.
    #[test]
    fn test_exp_construction() {
        let exp_tree = Canonical::exp(&EmlTree::var(0));
        assert_eq!(exp_tree.depth(), 1);
        let ctx = EvalCtx::new(&[2.0]);
        let got = exp_tree
            .eval_real(&ctx)
            .expect("exp(x) eval should succeed");
        assert!((got - 2.0_f64.exp()).abs() < 1e-10);
    }

    /// `euler` is one `eml` node and evaluates to `e`.
    #[test]
    fn test_euler_construction() {
        let e_tree = Canonical::euler();
        assert_eq!(e_tree.depth(), 1);
        let no_vars = EvalCtx::new(&[]);
        let got = e_tree
            .eval_real(&no_vars)
            .expect("euler constant eval should succeed");
        assert!((got - std::f64::consts::E).abs() < 1e-10);
    }

    /// `ln` is three nested `eml` nodes and maps `e` to `1`.
    #[test]
    fn test_ln_construction() {
        let ln_tree = Canonical::ln(&EmlTree::var(0));
        assert_eq!(ln_tree.depth(), 3);
        let ctx = EvalCtx::new(&[std::f64::consts::E]);
        let got = ln_tree.eval_real(&ctx).expect("ln(x) eval should succeed");
        assert!((got - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ln_of_one() {
        let ln_one = Canonical::ln(&EmlTree::one());
        let no_vars = EvalCtx::new(&[]);
        let got = ln_one.eval_real(&no_vars).expect("ln(1) eval should succeed");
        assert!(got.abs() < 1e-10);
    }

    #[test]
    fn test_e_minus_x() {
        let tree = Canonical::e_minus(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[1.0]);
        let got = tree.eval_real(&ctx).expect("e_minus(x) eval should succeed");
        let expected = std::f64::consts::E - 1.0;
        assert!((got - expected).abs() < 1e-10);
    }

    /// `neg(3)` passes through `ln(e - 3)`, a logarithm of a negative value.
    #[test]
    fn test_neg() {
        let tree = Canonical::neg(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[3.0]);
        let got = tree.eval_real(&ctx).expect("neg(x) eval should succeed");
        assert!((got + 3.0).abs() < 1e-8);
    }

    #[test]
    fn test_sub() {
        let (a, b) = (EmlTree::var(0), EmlTree::var(1));
        let tree = Canonical::sub(&a, &b);
        let ctx = EvalCtx::new(&[5.0, 3.0]);
        let got = tree.eval_real(&ctx).expect("sub(x,y) eval should succeed");
        assert!((got - 2.0).abs() < 1e-8);
    }

    #[test]
    fn test_add() {
        let (a, b) = (EmlTree::var(0), EmlTree::var(1));
        let tree = Canonical::add(&a, &b);
        let ctx = EvalCtx::new(&[2.0, 3.0]);
        let got = tree.eval_real(&ctx).expect("add(x,y) eval should succeed");
        assert!((got - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_mul() {
        let (a, b) = (EmlTree::var(0), EmlTree::var(1));
        let tree = Canonical::mul(&a, &b);
        let ctx = EvalCtx::new(&[3.0, 4.0]);
        let got = tree.eval_real(&ctx).expect("mul(x,y) eval should succeed");
        assert!((got - 12.0).abs() < 1e-4);
    }

    #[test]
    fn test_div() {
        let (a, b) = (EmlTree::var(0), EmlTree::var(1));
        let tree = Canonical::div(&a, &b);
        let ctx = EvalCtx::new(&[10.0, 2.0]);
        let got = tree.eval_real(&ctx).expect("div(x,y) eval should succeed");
        assert!((got - 5.0).abs() < 1e-4);
    }

    #[test]
    fn test_pow() {
        let (base, exponent) = (EmlTree::var(0), EmlTree::var(1));
        let tree = Canonical::pow(&base, &exponent);
        let ctx = EvalCtx::new(&[2.0, 3.0]);
        let got = tree.eval_real(&ctx).expect("pow(x,y) eval should succeed");
        assert!((got - 8.0).abs() < 1e-4);
    }

    #[test]
    fn test_reciprocal() {
        let tree = Canonical::reciprocal(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[4.0]);
        let got = tree
            .eval_real(&ctx)
            .expect("reciprocal(x) eval should succeed");
        assert!((got - 0.25).abs() < 1e-8);
    }

    #[test]
    fn test_zero() {
        let tree = Canonical::zero();
        let no_vars = EvalCtx::new(&[]);
        let got = tree
            .eval_real(&no_vars)
            .expect("zero constant eval should succeed");
        assert!(got.abs() < 1e-10);
    }

    /// `nat(n)` for 1..=5; the loose tolerance absorbs error accumulated over
    /// the chained additions.
    #[test]
    fn test_nat() {
        for n in 1..=5u64 {
            let tree = Canonical::nat(n);
            let no_vars = EvalCtx::new(&[]);
            let result = tree.eval_real(&no_vars).expect("nat(n) eval should succeed");
            assert!(
                (result - n as f64).abs() < 0.1,
                "nat({n}) = {result}, expected {n}"
            );
        }
    }

    #[test]
    fn test_sqrt() {
        let tree = Canonical::sqrt(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[4.0]);
        let got = tree.eval_real(&ctx).expect("sqrt(x) eval should succeed");
        assert!((got - 2.0).abs() < 1e-2);
    }

    #[test]
    fn test_abs_positive() {
        let tree = Canonical::abs(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[3.0]);
        let got = tree.eval_real(&ctx).expect("abs(x) eval should succeed");
        assert!((got - 3.0).abs() < 1e-2);
    }

    #[test]
    fn test_square() {
        let tree = Canonical::square(&EmlTree::var(0));
        for &val in [2.0, 3.0, 0.5].iter() {
            let ctx = EvalCtx::new(&[val]);
            let result = tree.eval_real(&ctx).expect("square(x) eval should succeed");
            let expected = val * val;
            assert!(
                (result - expected).abs() < 1e-2,
                "square({val}) = {result}, expected {}",
                expected
            );
        }
    }

    #[test]
    fn test_neg_one() {
        let no_vars = EvalCtx::new(&[]);
        let got = Canonical::neg_one()
            .eval_real(&no_vars)
            .expect("neg_one eval should succeed");
        assert!((got + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_neg_two() {
        let no_vars = EvalCtx::new(&[]);
        let got = Canonical::neg_two()
            .eval_real(&no_vars)
            .expect("neg_two eval should succeed");
        assert!((got + 2.0).abs() < 0.1);
    }

    /// `i` must be rejected by real evaluation but come out as `0 + 1i` under
    /// complex evaluation.
    #[test]
    fn test_imag_unit() {
        let unit = Canonical::imag_unit();
        let no_vars = EvalCtx::new(&[]);
        let real_attempt = unit.eval_real(&no_vars);
        assert!(real_attempt.is_err(), "imag_unit should fail in real mode");

        let z = unit
            .eval_complex(&[])
            .expect("imag unit complex eval should succeed");
        let (re, im) = (z.re, z.im);
        assert!(re.abs() < 1e-4, "Re(i) should be ~0, got {}", re);
        assert!((im - 1.0).abs() < 1e-4, "Im(i) should be ~1, got {}", im);
    }

    /// Checked only if real evaluation succeeds.
    #[test]
    fn test_tan() {
        let tree = Canonical::tan(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[0.0]);
        if let Ok(val) = tree.eval_real(&ctx) {
            assert!(val.abs() < 0.1, "tan(0) should be ~0, got {val}");
        }
    }

    #[test]
    fn test_sinh() {
        let tree = Canonical::sinh(&EmlTree::var(0));
        for &val in [0.0, 1.0].iter() {
            // Building the context first pins `val` to the evaluator's float type
            // before `val.sinh()` is called.
            let ctx = EvalCtx::new(&[val]);
            let result = tree.eval_real(&ctx).expect("sinh(x) eval should succeed");
            let expected = val.sinh();
            assert!(
                (result - expected).abs() < 0.1,
                "sinh({val}) = {result}, expected {}",
                expected
            );
        }
    }

    #[test]
    fn test_cosh() {
        let tree = Canonical::cosh(&EmlTree::var(0));
        for &val in [0.0, 1.0].iter() {
            // Context first, so `val` has a concrete float type for `cosh()`.
            let ctx = EvalCtx::new(&[val]);
            let result = tree.eval_real(&ctx).expect("cosh(x) eval should succeed");
            let expected = val.cosh();
            assert!(
                (result - expected).abs() < 0.1,
                "cosh({val}) = {result}, expected {}",
                expected
            );
        }
    }

    /// Checked only if real evaluation succeeds.
    #[test]
    fn test_tanh() {
        let tree = Canonical::tanh(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[0.0]);
        if let Ok(val) = tree.eval_real(&ctx) {
            assert!(val.abs() < 0.1, "tanh(0) should be ~0, got {val}");
        }
    }

    #[test]
    fn test_arcsinh() {
        let tree = Canonical::arcsinh(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[0.0]);
        let result = tree
            .eval_real(&ctx)
            .expect("arcsinh(0) eval should succeed");
        assert!(result.abs() < 0.1, "arcsinh(0) = {result}, expected 0");
    }

    #[test]
    fn test_arctanh() {
        let tree = Canonical::arctanh(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[0.0]);
        let result = tree
            .eval_real(&ctx)
            .expect("arctanh(0) eval should succeed");
        assert!(result.abs() < 0.1, "arctanh(0) = {result}, expected 0");
    }

    /// Checked only if real evaluation succeeds.
    #[test]
    fn test_arctan() {
        let tree = Canonical::arctan(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[0.0]);
        if let Ok(val) = tree.eval_real(&ctx) {
            assert!(val.abs() < 0.1, "arctan(0) should be ~0, got {val}");
        }
    }

    /// Checked only if real evaluation succeeds.
    #[test]
    fn test_arcsin() {
        let tree = Canonical::arcsin(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[0.0]);
        if let Ok(val) = tree.eval_real(&ctx) {
            assert!(val.abs() < 0.1, "arcsin(0) should be ~0, got {val}");
        }
    }

    /// Checked only if real evaluation succeeds.
    #[test]
    fn test_arccos() {
        let tree = Canonical::arccos(&EmlTree::var(0));
        let ctx = EvalCtx::new(&[1.0]);
        if let Ok(val) = tree.eval_real(&ctx) {
            assert!(val.abs() < 0.2, "arccos(1) should be ~0, got {val}");
        }
    }
}