use std::collections::HashMap;
use num_traits::ToPrimitive;
use crate::symbolic::core::Expr;
/// Numerically evaluates a symbolic expression to an `f64`.
///
/// Variables are looked up by name in `vars`. The map may use any
/// `BuildHasher`, so callers are not tied to the default hasher.
///
/// Operands are evaluated left to right, with one exception. For `Div`, the
/// denominator is evaluated and checked before the numerator is evaluated.
///
/// # Errors
///
/// Returns `Err(String)` describing the first problem encountered:
/// - a `Dag` node whose `to_expr()` conversion fails;
/// - a `BigInt` / `Rational` for which `to_f64()` returns `None`;
/// - a `Variable` that is not present in `vars`;
/// - a `Div` whose denominator is exactly `0.0`. The reciprocal functions
///   (`Sec`, `Csc`, `Cot`, `Sech`, `Csch`, `Coth`, `ArcCot`) compute
///   `1.0 / x` and may therefore yield infinities instead of an error;
/// - an argument outside the real domain checked for `Power`, `Sqrt`,
///   `ArcSin`, `ArcCos`, `ArcSec`, `ArcCsc`, `ArcCosh`, `ArcTanh`, `Log`
///   and `LogBase`;
/// - any `Expr` variant not handled below.
pub fn eval_expr<S: ::std::hash::BuildHasher>(
    expr: &Expr,
    vars: &HashMap<String, f64, S>,
) -> Result<f64, String> {
    /// Converts a violated domain condition into an `Err` carrying `msg`.
    fn reject(
        violated: bool,
        msg: &str,
    ) -> Result<(), String> {
        if violated {
            Err(msg.to_owned())
        } else {
            Ok(())
        }
    }

    match expr {
        // ---- leaves -------------------------------------------------------
        | Expr::Constant(c) => Ok(*c),
        | Expr::Pi => Ok(std::f64::consts::PI),
        | Expr::E => Ok(std::f64::consts::E),
        | Expr::Infinity => Ok(f64::INFINITY),
        | Expr::NegativeInfinity => Ok(f64::NEG_INFINITY),
        | Expr::BigInt(i) => {
            i.to_f64()
                .ok_or_else(|| String::from("BigInt overflow during evaluation"))
        },
        | Expr::Rational(r) => {
            r.to_f64()
                .ok_or_else(|| String::from("Rational overflow during evaluation"))
        },
        | Expr::Variable(v) => {
            match vars.get(v) {
                | Some(&value) => Ok(value),
                | None => Err(format!("Unknown variable: '{v}'")),
            }
        },
        | Expr::Dag(node) => {
            // Materialise the DAG node as a plain tree, then evaluate that.
            node.to_expr()
                .map_err(|e| format!("Invalid DAG node: {e}"))
                .and_then(|tree| eval_expr(&tree, vars))
        },

        // ---- arithmetic ---------------------------------------------------
        | Expr::Add(lhs, rhs) => {
            let left = eval_expr(lhs, vars)?;
            let right = eval_expr(rhs, vars)?;
            Ok(left + right)
        },
        | Expr::AddList(list) => {
            list.iter()
                .try_fold(0.0, |acc, item| eval_expr(item, vars).map(|v| acc + v))
        },
        | Expr::Sub(lhs, rhs) => {
            let left = eval_expr(lhs, vars)?;
            let right = eval_expr(rhs, vars)?;
            Ok(left - right)
        },
        | Expr::Mul(lhs, rhs) => {
            let left = eval_expr(lhs, vars)?;
            let right = eval_expr(rhs, vars)?;
            Ok(left * right)
        },
        | Expr::MulList(list) => {
            list.iter()
                .try_fold(1.0, |acc, item| eval_expr(item, vars).map(|v| acc * v))
        },
        | Expr::Div(num, den) => {
            // The divisor is checked before the numerator is evaluated, so a
            // zero divisor is reported even if the numerator would also fail.
            let divisor = eval_expr(den, vars)?;
            reject(divisor == 0.0, "Division by zero")?;
            eval_expr(num, vars).map(|dividend| dividend / divisor)
        },
        | Expr::Neg(a) => eval_expr(a, vars).map(|v| -v),

        // ---- powers, roots, magnitude -------------------------------------
        | Expr::Power(b, e) => {
            let base = eval_expr(b, vars)?;
            let exponent = eval_expr(e, vars)?;
            reject(
                base == 0.0 && exponent < 0.0,
                "Undefined operation: 0^negative power",
            )?;
            reject(
                base < 0.0 && exponent.fract() != 0.0,
                "Complex result: negative base raised to non-integer power",
            )?;
            Ok(base.powf(exponent))
        },
        | Expr::Sqrt(a) => {
            let radicand = eval_expr(a, vars)?;
            reject(radicand < 0.0, "Square root of negative number")?;
            Ok(radicand.sqrt())
        },
        | Expr::Abs(a) => eval_expr(a, vars).map(f64::abs),

        // ---- circular trigonometry ----------------------------------------
        | Expr::Sin(a) => eval_expr(a, vars).map(f64::sin),
        | Expr::Cos(a) => eval_expr(a, vars).map(f64::cos),
        | Expr::Tan(a) => eval_expr(a, vars).map(f64::tan),
        | Expr::Sec(a) => eval_expr(a, vars).map(|v| v.cos().recip()),
        | Expr::Csc(a) => eval_expr(a, vars).map(|v| v.sin().recip()),
        | Expr::Cot(a) => eval_expr(a, vars).map(|v| v.tan().recip()),

        // ---- inverse circular trigonometry --------------------------------
        | Expr::ArcSin(a) => {
            let arg = eval_expr(a, vars)?;
            reject(
                !(-1.0..=1.0).contains(&arg),
                "Inverse sine argument out of domain [-1, 1]",
            )?;
            Ok(arg.asin())
        },
        | Expr::ArcCos(a) => {
            let arg = eval_expr(a, vars)?;
            reject(
                !(-1.0..=1.0).contains(&arg),
                "Inverse cosine argument out of domain [-1, 1]",
            )?;
            Ok(arg.acos())
        },
        | Expr::ArcTan(a) => eval_expr(a, vars).map(f64::atan),
        | Expr::Atan2(y, x) => {
            let rise = eval_expr(y, vars)?;
            let run = eval_expr(x, vars)?;
            Ok(rise.atan2(run))
        },
        | Expr::ArcSec(a) => {
            let arg = eval_expr(a, vars)?;
            reject(
                arg.abs() < 1.0,
                "Inverse secant argument out of domain (-inf, -1] U [1, inf)",
            )?;
            Ok(arg.recip().acos())
        },
        | Expr::ArcCsc(a) => {
            let arg = eval_expr(a, vars)?;
            reject(
                arg.abs() < 1.0,
                "Inverse cosecant argument out of domain (-inf, -1] U [1, inf)",
            )?;
            Ok(arg.recip().asin())
        },
        | Expr::ArcCot(a) => eval_expr(a, vars).map(|v| v.recip().atan()),

        // ---- hyperbolic functions and their inverses ----------------------
        | Expr::Sinh(a) => eval_expr(a, vars).map(f64::sinh),
        | Expr::Cosh(a) => eval_expr(a, vars).map(f64::cosh),
        | Expr::Tanh(a) => eval_expr(a, vars).map(f64::tanh),
        | Expr::Sech(a) => eval_expr(a, vars).map(|v| v.cosh().recip()),
        | Expr::Csch(a) => eval_expr(a, vars).map(|v| v.sinh().recip()),
        | Expr::Coth(a) => eval_expr(a, vars).map(|v| v.tanh().recip()),
        | Expr::ArcSinh(a) => eval_expr(a, vars).map(f64::asinh),
        | Expr::ArcCosh(a) => {
            let arg = eval_expr(a, vars)?;
            reject(arg < 1.0, "Inverse hyperbolic cosine argument < 1")?;
            Ok(arg.acosh())
        },
        | Expr::ArcTanh(a) => {
            let arg = eval_expr(a, vars)?;
            reject(
                arg <= -1.0 || arg >= 1.0,
                "Inverse hyperbolic tangent argument out of domain (-1, 1)",
            )?;
            Ok(arg.atanh())
        },

        // ---- exponentials and logarithms ----------------------------------
        | Expr::Exp(a) => eval_expr(a, vars).map(f64::exp),
        | Expr::Log(a) => {
            let arg = eval_expr(a, vars)?;
            reject(arg <= 0.0, "Logarithm of non-positive number")?;
            Ok(arg.ln())
        },
        | Expr::LogBase(b, a) => {
            // Both operands are evaluated before either domain check runs.
            let base = eval_expr(b, vars)?;
            let arg = eval_expr(a, vars)?;
            reject(
                base <= 0.0 || (base - 1.0).abs() < f64::EPSILON,
                "Invalid logarithm base",
            )?;
            reject(arg <= 0.0, "Logarithm of non-positive number")?;
            Ok(arg.log(base))
        },

        // ---- miscellaneous ------------------------------------------------
        | Expr::Floor(a) => eval_expr(a, vars).map(f64::floor),
        | Expr::Max(a, b) => {
            let first = eval_expr(a, vars)?;
            let second = eval_expr(b, vars)?;
            Ok(first.max(second))
        },

        | _ => Err(format!("Numerical evaluation of {expr:?} is not supported")),
    }
}
/// Evaluates `expr` with exactly one variable bound.
///
/// This is a convenience wrapper around [`eval_expr`]. It binds the variable
/// named `x_name` to `x_val` in a fresh map and evaluates `expr` against it.
///
/// # Errors
///
/// Propagates every error from [`eval_expr`]. This includes an
/// `Unknown variable` error when `expr` references any variable other than
/// `x_name`.
pub fn eval_expr_single(
    expr: &Expr,
    x_name: &str,
    x_val: f64,
) -> Result<f64, String> {
    let bindings: HashMap<String, f64> = HashMap::from([(x_name.to_owned(), x_val)]);
    eval_expr(expr, &bindings)
}
/// Free-function forms of common `f64` operations.
///
/// Every function forwards directly to the inherent [`f64`] method of the
/// same (or, for [`pure::pow`], the `powf`) name. Results, including any
/// `NaN` or infinite outputs, are therefore exactly those of the standard
/// library.
pub mod pure {
    /// Sine of `x`; forwards to [`f64::sin`].
    #[must_use]
    pub fn sin(x: f64) -> f64 {
        f64::sin(x)
    }

    /// Cosine of `x`; forwards to [`f64::cos`].
    #[must_use]
    pub fn cos(x: f64) -> f64 {
        f64::cos(x)
    }

    /// Tangent of `x`; forwards to [`f64::tan`].
    #[must_use]
    pub fn tan(x: f64) -> f64 {
        f64::tan(x)
    }

    /// Arcsine of `x`; forwards to [`f64::asin`].
    #[must_use]
    pub fn asin(x: f64) -> f64 {
        f64::asin(x)
    }

    /// Arccosine of `x`; forwards to [`f64::acos`].
    #[must_use]
    pub fn acos(x: f64) -> f64 {
        f64::acos(x)
    }

    /// Arctangent of `x`; forwards to [`f64::atan`].
    #[must_use]
    pub fn atan(x: f64) -> f64 {
        f64::atan(x)
    }

    /// Four-quadrant arctangent of `y / x`; forwards to [`f64::atan2`] with
    /// `y` as the receiver.
    #[must_use]
    pub fn atan2(
        y: f64,
        x: f64,
    ) -> f64 {
        f64::atan2(y, x)
    }

    /// Hyperbolic sine of `x`; forwards to [`f64::sinh`].
    #[must_use]
    pub fn sinh(x: f64) -> f64 {
        f64::sinh(x)
    }

    /// Hyperbolic cosine of `x`; forwards to [`f64::cosh`].
    #[must_use]
    pub fn cosh(x: f64) -> f64 {
        f64::cosh(x)
    }

    /// Hyperbolic tangent of `x`; forwards to [`f64::tanh`].
    #[must_use]
    pub fn tanh(x: f64) -> f64 {
        f64::tanh(x)
    }

    /// Inverse hyperbolic sine of `x`; forwards to [`f64::asinh`].
    #[must_use]
    pub fn asinh(x: f64) -> f64 {
        f64::asinh(x)
    }

    /// Inverse hyperbolic cosine of `x`; forwards to [`f64::acosh`].
    #[must_use]
    pub fn acosh(x: f64) -> f64 {
        f64::acosh(x)
    }

    /// Inverse hyperbolic tangent of `x`; forwards to [`f64::atanh`].
    #[must_use]
    pub fn atanh(x: f64) -> f64 {
        f64::atanh(x)
    }

    /// Absolute value of `x`; forwards to [`f64::abs`]. Usable in `const`
    /// contexts.
    #[must_use]
    pub const fn abs(x: f64) -> f64 {
        f64::abs(x)
    }

    /// Square root of `x`; forwards to [`f64::sqrt`].
    #[must_use]
    pub fn sqrt(x: f64) -> f64 {
        f64::sqrt(x)
    }

    /// Natural logarithm of `x`; forwards to [`f64::ln`].
    #[must_use]
    pub fn ln(x: f64) -> f64 {
        f64::ln(x)
    }

    /// Logarithm of `x` in the given `base`; forwards to [`f64::log`].
    /// Note the argument order: value first, base second.
    #[must_use]
    pub fn log(
        x: f64,
        base: f64,
    ) -> f64 {
        f64::log(x, base)
    }

    /// `e` raised to `x`; forwards to [`f64::exp`].
    #[must_use]
    pub fn exp(x: f64) -> f64 {
        f64::exp(x)
    }

    /// `base` raised to the real power `exp`; forwards to [`f64::powf`].
    #[must_use]
    pub fn pow(
        base: f64,
        exp: f64,
    ) -> f64 {
        f64::powf(base, exp)
    }

    /// Largest integer value not greater than `x`; forwards to
    /// [`f64::floor`]. Usable in `const` contexts.
    #[must_use]
    pub const fn floor(x: f64) -> f64 {
        f64::floor(x)
    }

    /// Smallest integer value not less than `x`; forwards to [`f64::ceil`].
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn ceil(x: f64) -> f64 {
        f64::ceil(x)
    }

    /// Nearest integer to `x`; forwards to [`f64::round`]. Usable in `const`
    /// contexts.
    #[must_use]
    pub const fn round(x: f64) -> f64 {
        f64::round(x)
    }

    /// Sign of `x`; forwards to [`f64::signum`]. Usable in `const` contexts.
    #[must_use]
    pub const fn signum(x: f64) -> f64 {
        f64::signum(x)
    }
}