use crate::core::number::Number;
use crate::core::Expression;
use crate::error::MathError;
use num_bigint::BigInt;
use num_rational::BigRational;
use std::collections::HashMap;
/// Numeric evaluation of an expression.
pub trait EvalNumeric {
    /// Evaluates `self` numerically and returns the resulting expression.
    ///
    /// `precision` is the requested precision; how (or whether) it is honored is up
    /// to the implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`MathError`] when evaluation fails.
    fn eval_numeric(&self, precision: u32) -> Result<Expression, MathError>;
}
/// Settings describing how an expression should be evaluated.
///
/// NOTE(review): `EvalNumeric::eval_numeric` in this module takes only a precision and
/// never reads an `EvalContext`; these settings are presumably consumed elsewhere —
/// confirm.
#[derive(Debug, Clone)]
pub struct EvalContext {
    /// Variable bindings, keyed by variable name.
    pub variables: HashMap<String, Expression>,
    /// Numeric-mode flag: `true` when built by `EvalContext::numeric`, `false` when
    /// built by `EvalContext::symbolic`.
    pub numeric: bool,
    /// Requested precision; both constructors set 53 (presumably bits — confirm).
    pub precision: u32,
    /// Simplify-before-evaluating flag: `true` when built by `EvalContext::numeric`,
    /// `false` when built by `EvalContext::symbolic`.
    pub simplify_first: bool,
}
impl EvalContext {
    /// Precision assigned by both constructors. Presumably measured in bits: 53 is the
    /// significand width of an IEEE-754 `f64` — confirm.
    const DEFAULT_PRECISION: u32 = 53;

    /// Creates a context for symbolic evaluation.
    ///
    /// The context has no variable bindings, `numeric` and `simplify_first` are both
    /// `false`, and the precision is 53.
    #[must_use]
    pub fn symbolic() -> Self {
        Self {
            variables: HashMap::new(),
            numeric: false,
            precision: Self::DEFAULT_PRECISION,
            simplify_first: false,
        }
    }

    /// Creates a context for numeric evaluation with the given variable bindings.
    ///
    /// `numeric` and `simplify_first` are both `true`, and the precision is 53.
    #[must_use]
    pub fn numeric(variables: HashMap<String, Expression>) -> Self {
        Self {
            variables,
            numeric: true,
            precision: Self::DEFAULT_PRECISION,
            simplify_first: true,
        }
    }

    /// Returns this context with its precision replaced by `precision`.
    ///
    /// The context is consumed, so the returned value must be kept.
    #[must_use]
    pub fn with_precision(self, precision: u32) -> Self {
        Self { precision, ..self }
    }

    /// Returns this context with `simplify_first` replaced by `simplify`.
    ///
    /// The context is consumed, so the returned value must be kept.
    #[must_use]
    pub fn with_simplify(self, simplify: bool) -> Self {
        Self {
            simplify_first: simplify,
            ..self
        }
    }
}
impl Default for EvalContext {
fn default() -> Self {
Self::symbolic()
}
}
fn is_number_negative(n: &Number) -> bool {
match n {
Number::Integer(i) => *i < 0,
Number::Float(f) => *f < 0.0,
Number::BigInteger(bi) => **bi < BigInt::from(0),
Number::Rational(r) => **r < BigRational::new(BigInt::from(0), BigInt::from(1)),
}
}
/// Numeric evaluation for [`Expression`] trees.
///
/// The named constants π, e, golden ratio, Euler–Mascheroni γ and the tribonacci
/// constant are replaced by `f64` floats. Compound nodes are rebuilt from their
/// evaluated children through the matching `Expression` constructors. Function
/// applications are first offered to `evaluate_function_dispatch`.
impl EvalNumeric for Expression {
    /// Recursively evaluates `self`.
    ///
    /// `precision` is forwarded to every recursive call but does not currently change
    /// the result: constants are always produced as `f64` values.
    ///
    /// # Errors
    ///
    /// Returns [`MathError::DivisionByZero`] when a power's base evaluates to zero and
    /// its exponent evaluates to a negative number literal. Errors raised while
    /// evaluating sub-expressions are propagated unchanged.
    fn eval_numeric(&self, precision: u32) -> Result<Expression, MathError> {
        match self {
            // Numbers are already numeric; symbols have no value here and stay as-is.
            Expression::Number(_) | Expression::Symbol(_) => Ok(self.clone()),
            Expression::Constant(c) => {
                use crate::core::MathConstant;
                match c {
                    MathConstant::Pi => Ok(Expression::float(std::f64::consts::PI)),
                    MathConstant::E => Ok(Expression::float(std::f64::consts::E)),
                    // Left symbolic rather than converted to a float.
                    MathConstant::I
                    | MathConstant::Infinity
                    | MathConstant::NegativeInfinity
                    | MathConstant::Undefined => Ok(self.clone()),
                    MathConstant::GoldenRatio => {
                        Ok(Expression::float(MathConstant::GoldenRatio.to_f64()))
                    }
                    MathConstant::EulerGamma => {
                        Ok(Expression::float(MathConstant::EulerGamma.to_f64()))
                    }
                    MathConstant::TribonacciConstant => {
                        Ok(Expression::float(MathConstant::TribonacciConstant.to_f64()))
                    }
                }
            }
            Expression::Add(terms) => {
                let evaluated = terms
                    .iter()
                    .map(|t| t.eval_numeric(precision))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(Expression::add(evaluated))
            }
            Expression::Mul(factors) => {
                let evaluated = factors
                    .iter()
                    .map(|f| f.eval_numeric(precision))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(Expression::mul(evaluated))
            }
            Expression::Pow(base, exp) => {
                let base_eval = base.eval_numeric(precision)?;
                let exp_eval = exp.eval_numeric(precision)?;
                // 0^(-k) would be a division by zero. Only a literal negative number
                // exponent is detected; symbolic exponents pass through.
                if base_eval.is_zero()
                    && matches!(&exp_eval, Expression::Number(n) if is_number_negative(n))
                {
                    return Err(MathError::DivisionByZero);
                }
                Ok(Expression::pow(base_eval, exp_eval))
            }
            Expression::Function { name, args } => {
                let eval_args = args
                    .iter()
                    .map(|arg| arg.eval_numeric(precision))
                    .collect::<Result<Vec<_>, _>>()?;
                // Prefer a concrete result from the dispatcher; otherwise keep the call
                // with its arguments evaluated.
                if let Some(result) =
                    super::evaluation::evaluate_function_dispatch(name, &eval_args)
                {
                    return Ok(result);
                }
                Ok(Expression::function(name.clone(), eval_args))
            }
            Expression::Matrix(matrix) => {
                let (rows, cols) = matrix.dimensions();
                let mut new_rows = Vec::with_capacity(rows);
                for i in 0..rows {
                    let mut row = Vec::with_capacity(cols);
                    for j in 0..cols {
                        let element = matrix.get_element(i, j);
                        row.push(element.eval_numeric(precision)?);
                    }
                    new_rows.push(row);
                }
                Ok(Expression::matrix(new_rows))
            }
            Expression::Set(elements) => {
                let evaluated = elements
                    .iter()
                    .map(|e| e.eval_numeric(precision))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(Expression::set(evaluated))
            }
            Expression::Complex(data) => {
                let real_eval = data.real.eval_numeric(precision)?;
                let imag_eval = data.imag.eval_numeric(precision)?;
                Ok(Expression::complex(real_eval, imag_eval))
            }
            Expression::Interval(interval) => {
                let start_eval = interval.start.eval_numeric(precision)?;
                let end_eval = interval.end.eval_numeric(precision)?;
                Ok(Expression::interval(
                    start_eval,
                    end_eval,
                    interval.start_inclusive,
                    interval.end_inclusive,
                ))
            }
            Expression::Piecewise(data) => {
                // Only branch values are evaluated; conditions are cloned verbatim.
                // NOTE(review): constants inside conditions therefore stay symbolic —
                // confirm this is intended.
                let mut new_pieces = Vec::with_capacity(data.pieces.len());
                for (expr, cond) in &data.pieces {
                    new_pieces.push((expr.eval_numeric(precision)?, cond.clone()));
                }
                let default_eval = data
                    .default
                    .as_ref()
                    .map(|default| default.eval_numeric(precision))
                    .transpose()?;
                Ok(Expression::piecewise(new_pieces, default_eval))
            }
            Expression::Relation(rel) => {
                let lhs_eval = rel.left.eval_numeric(precision)?;
                let rhs_eval = rel.right.eval_numeric(precision)?;
                Ok(Expression::relation(lhs_eval, rhs_eval, rel.relation_type))
            }
            // Not evaluated numerically here; returned unchanged.
            Expression::Calculus(_) | Expression::MethodCall(_) => Ok(self.clone()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eval_context_symbolic() {
        let EvalContext {
            variables,
            numeric,
            precision,
            simplify_first,
        } = EvalContext::symbolic();
        assert!(!numeric);
        assert!(variables.is_empty());
        assert_eq!(precision, 53);
        assert!(!simplify_first);
    }

    #[test]
    fn test_eval_context_numeric() {
        let bindings = HashMap::from([("x".to_string(), Expression::integer(5))]);
        let ctx = EvalContext::numeric(bindings);
        assert!(ctx.numeric);
        assert_eq!(ctx.variables.len(), 1);
        assert_eq!(ctx.precision, 53);
        assert!(ctx.simplify_first);
    }

    #[test]
    fn test_eval_context_with_precision() {
        assert_eq!(EvalContext::symbolic().with_precision(128).precision, 128);
    }

    #[test]
    fn test_eval_context_with_simplify() {
        // Exercise both values of the flag in order: true, then false.
        for flag in [true, false] {
            let ctx = EvalContext::symbolic().with_simplify(flag);
            assert_eq!(ctx.simplify_first, flag);
        }
    }

    #[test]
    fn test_eval_context_default() {
        let ctx = EvalContext::default();
        assert!(!ctx.numeric);
        assert!(ctx.variables.is_empty());
    }

    #[test]
    fn test_eval_context_chaining() {
        let bindings = HashMap::from([("x".to_string(), Expression::integer(3))]);
        let ctx = EvalContext::numeric(bindings)
            .with_precision(128)
            .with_simplify(false);
        assert!(ctx.numeric);
        assert_eq!(ctx.variables.len(), 1);
        assert_eq!(ctx.precision, 128);
        assert!(!ctx.simplify_first);
    }
}