use crate::core::{Expression, Number, BinaryOperator, UnaryOperator, MathConstant};
use crate::engine::error::ComputeError;
use crate::engine::simplify::Simplifier;
use std::collections::HashMap;
use num_bigint::BigInt;
use num_rational::BigRational;
use num_traits::{Zero, One, ToPrimitive};
/// Simplifier that layers additional rewrite passes on top of the base
/// [`Simplifier`].
///
/// The base simplifier always runs first; a fixed pipeline of enhanced rules
/// follows, optionally succeeded by an iterative auto-simplify loop.
pub struct EnhancedSimplifier {
    /// Memoised outputs of the enhanced-rule pipeline, keyed by its input.
    /// NOTE(review): entries are only ever inserted in this file, never
    /// evicted, so the cache grows with every distinct input.
    rule_cache: HashMap<Expression, Expression>,
    /// Underlying simplifier used before and during the enhanced passes.
    base_simplifier: Simplifier,
    /// Whether `enhanced_simplify` runs the iterative auto-simplify loop.
    auto_simplify: bool,
}
impl EnhancedSimplifier {
/// Creates a simplifier with auto-simplification enabled and an empty rule cache.
pub fn new() -> Self {
    let rule_cache = HashMap::new();
    let base_simplifier = Simplifier::new();
    Self {
        rule_cache,
        base_simplifier,
        auto_simplify: true,
    }
}
/// Turns the iterative auto-simplify stage of `enhanced_simplify` on or off.
pub fn set_auto_simplify(&mut self, enabled: bool) {
    let flag = &mut self.auto_simplify;
    *flag = enabled;
}
/// Simplifies `expr` with the base simplifier, then the enhanced rule
/// pipeline, and finally (when enabled) the iterative auto-simplify loop.
///
/// # Errors
/// Propagates any [`ComputeError`] raised by a pass, e.g. a numeric division
/// by zero found during constant folding.
pub fn enhanced_simplify(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
    let base = self.base_simplifier.simplify(expr)?;
    let enhanced = self.apply_enhanced_rules(&base)?;
    if !self.auto_simplify {
        return Ok(enhanced);
    }
    self.apply_auto_simplify_rules(&enhanced)
}
fn apply_enhanced_rules(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
if let Some(cached) = self.rule_cache.get(expr) {
return Ok(cached.clone());
}
let mut result = expr.clone();
result = self.apply_constant_folding(&result)?;
result = self.simplify_radicals(&result)?;
result = self.simplify_trigonometric(&result)?;
result = self.apply_advanced_algebraic_rules(&result)?;
self.rule_cache.insert(expr.clone(), result.clone());
Ok(result)
}
fn apply_auto_simplify_rules(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
let mut current = expr.clone();
let mut previous;
let max_iterations = 10; let mut iteration = 0;
loop {
previous = current.clone();
current = self.base_simplifier.simplify(¤t)?;
current = self.combine_like_terms(¤t)?;
current = self.apply_constant_folding(¤t)?;
current = self.simplify_radicals(¤t)?;
current = self.apply_advanced_algebraic_rules(¤t)?;
iteration += 1;
if current == previous || iteration >= max_iterations {
break;
}
}
Ok(current)
}
/// Recursively merges like terms in every sum and difference inside `expr`.
/// Other binary operators, unary operations and function calls are rebuilt
/// around their recursively processed children.
fn combine_like_terms(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
    match expr {
        Expression::BinaryOp { op, left, right } => {
            let lhs = self.combine_like_terms(left)?;
            let rhs = self.combine_like_terms(right)?;
            if matches!(op, BinaryOperator::Add) {
                self.combine_addition_terms(&lhs, &rhs)
            } else if matches!(op, BinaryOperator::Subtract) {
                self.combine_subtraction_terms(&lhs, &rhs)
            } else {
                Ok(Expression::binary_op(op.clone(), lhs, rhs))
            }
        }
        Expression::UnaryOp { op, operand } => {
            let inner = self.combine_like_terms(operand)?;
            Ok(Expression::unary_op(op.clone(), inner))
        }
        Expression::Function { name, args } => {
            let mut processed = Vec::with_capacity(args.len());
            for arg in args.iter() {
                processed.push(self.combine_like_terms(arg)?);
            }
            Ok(Expression::function(name, processed))
        }
        _ => Ok(expr.clone()),
    }
}
/// Flattens `left + right` into its summands, merges like terms, and rebuilds
/// a left-associated sum from whatever terms survive.
fn combine_addition_terms(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    let mut summands = Vec::new();
    for side in &[left, right] {
        self.collect_addition_terms(side, &mut summands);
    }
    let merged = self.merge_like_terms(summands)?;
    self.build_addition_expression(merged)
}
/// Rewrites `left - right` as `left + (-right)` and combines it as a sum.
fn combine_subtraction_terms(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    self.combine_addition_terms(left, &Expression::negate(right.clone()))
}
/// Appends the summands of a (possibly nested) `Add` tree to `terms`, in
/// left-to-right order. Any node that is not an `Add` counts as one summand.
fn collect_addition_terms(&self, expr: &Expression, terms: &mut Vec<Expression>) {
    // Explicit stack instead of recursion. Right children are pushed before
    // left children so the left subtree is visited first, preserving order.
    let mut pending: Vec<&Expression> = vec![expr];
    while let Some(node) = pending.pop() {
        if let Expression::BinaryOp { op: BinaryOperator::Add, left, right } = node {
            pending.push(right.as_ref());
            pending.push(left.as_ref());
        } else {
            terms.push(node.clone());
        }
    }
}
/// Groups `terms` by base and sums the numeric coefficients of each group,
/// e.g. `2*x + 3*x` becomes `5*x`.
///
/// Output follows the order of each group's first occurrence. A group whose
/// coefficient simplifies to zero is dropped. A coefficient of one yields the
/// bare base, and a coefficient of minus one yields its negation. Pure
/// constants (whose base is the number `1`) are emitted as the summed number
/// itself rather than as `c * 1`.
///
/// # Errors
/// Propagates errors from the base simplifier used to fold coefficients.
fn merge_like_terms(&mut self, terms: Vec<Expression>) -> Result<Vec<Expression>, ComputeError> {
    let mut merged = Vec::with_capacity(terms.len());
    let mut used = vec![false; terms.len()];
    for i in 0..terms.len() {
        if used[i] {
            continue;
        }
        used[i] = true;
        let (first_coefficient, base) = match self.extract_coefficient_and_base(&terms[i]) {
            Some(parts) => parts,
            None => {
                // Undecomposable term: keep it verbatim.
                merged.push(terms[i].clone());
                continue;
            }
        };
        let mut coefficient = first_coefficient;
        for j in (i + 1)..terms.len() {
            if used[j] {
                continue;
            }
            if let Some((other_coefficient, other_base)) = self.extract_coefficient_and_base(&terms[j]) {
                if other_base == base {
                    coefficient = Expression::add(coefficient, other_coefficient);
                    used[j] = true;
                }
            }
        }
        let coefficient = self.base_simplifier.simplify(&coefficient)?;
        if self.is_zero(&coefficient) {
            continue;
        }
        let term = if self.is_one(&base) {
            // Constant group: the coefficient already is the value (avoid `c * 1`).
            coefficient
        } else if self.is_one(&coefficient) {
            base
        } else if self.is_negative_one(&coefficient) {
            Expression::negate(base)
        } else {
            Expression::multiply(coefficient, base)
        };
        merged.push(term);
    }
    Ok(merged)
}
fn extract_coefficient_and_base(&self, expr: &Expression) -> Option<(Expression, Expression)> {
match expr {
Expression::BinaryOp { op: BinaryOperator::Multiply, left, right } => {
if matches!(left.as_ref(), Expression::Number(_)) {
Some((left.as_ref().clone(), right.as_ref().clone()))
} else if matches!(right.as_ref(), Expression::Number(_)) {
Some((right.as_ref().clone(), left.as_ref().clone()))
} else {
Some((Expression::Number(Number::one()), expr.clone()))
}
}
Expression::Number(_) => {
Some((expr.clone(), Expression::Number(Number::one())))
}
_ => {
Some((Expression::Number(Number::one()), expr.clone()))
}
}
}
/// Folds `terms` into the left-associated sum `((t0 + t1) + t2) + ...`.
/// A single term is returned as-is, and an empty list yields the number zero.
fn build_addition_expression(&self, terms: Vec<Expression>) -> Result<Expression, ComputeError> {
    let sum = terms
        .into_iter()
        .reduce(|acc, term| Expression::add(acc, term))
        .unwrap_or_else(|| Expression::Number(Number::zero()));
    Ok(sum)
}
/// Folds constant sub-expressions bottom-up.
///
/// Two numeric operands are combined for `+`, `-`, `*` and `/`. For `^` they
/// are combined only when the exponent is an integer in `0..=100`, expanded by
/// `compute_integer_power`. Numeric negation and `Sqrt` of a number are
/// evaluated. Function arguments are folded, but the calls themselves are kept.
///
/// # Errors
/// Returns [`ComputeError::DivisionByZero`] for a numeric division by zero, and
/// propagates errors from the power and square-root helpers.
fn apply_constant_folding(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
    match expr {
        Expression::BinaryOp { op, left, right } => {
            let lhs = self.apply_constant_folding(left)?;
            let rhs = self.apply_constant_folding(right)?;
            let (a, b) = match (&lhs, &rhs) {
                (Expression::Number(a), Expression::Number(b)) => (a, b),
                _ => return Ok(Expression::binary_op(op.clone(), lhs, rhs)),
            };
            let value = match op {
                BinaryOperator::Add => a.clone() + b.clone(),
                BinaryOperator::Subtract => a.clone() - b.clone(),
                BinaryOperator::Multiply => a.clone() * b.clone(),
                BinaryOperator::Divide if b.is_zero() => return Err(ComputeError::DivisionByZero),
                BinaryOperator::Divide => a.clone() / b.clone(),
                BinaryOperator::Power => {
                    // Only small non-negative integer exponents are expanded.
                    let small_exponent = match b {
                        Number::Integer(exp) => exp.to_u32().filter(|&k| k <= 100),
                        _ => None,
                    };
                    match small_exponent {
                        Some(k) => self.compute_integer_power(a, k)?,
                        None => return Ok(Expression::binary_op(op.clone(), lhs, rhs)),
                    }
                }
                _ => return Ok(Expression::binary_op(op.clone(), lhs, rhs)),
            };
            Ok(Expression::Number(value))
        }
        Expression::UnaryOp { op, operand } => {
            let inner = self.apply_constant_folding(operand)?;
            match (op, &inner) {
                (UnaryOperator::Negate, Expression::Number(n)) => Ok(Expression::Number(-n.clone())),
                (UnaryOperator::Sqrt, Expression::Number(n)) => self.simplify_numeric_square_root(n),
                _ => Ok(Expression::unary_op(op.clone(), inner)),
            }
        }
        Expression::Function { name, args } => {
            let mut folded = Vec::with_capacity(args.len());
            for arg in args.iter() {
                folded.push(self.apply_constant_folding(arg)?);
            }
            Ok(Expression::function(name, folded))
        }
        _ => Ok(expr.clone()),
    }
}
/// Raises `base` to the non-negative integer power `exp` by repeated
/// multiplication, left to right (`base * base * ... * base`).
/// `exp == 0` yields one.
fn compute_integer_power(&self, base: &Number, exp: u32) -> Result<Number, ComputeError> {
    if exp == 0 {
        return Ok(Number::one());
    }
    // For exp == 1 the range is empty, so `base` itself is returned.
    Ok((1..exp).fold(base.clone(), |acc, _| acc * base.clone()))
}
/// Recursively simplifies radicals, processing children first. Sums and
/// differences of like radicals are combined, and products and quotients of
/// square roots are merged. Square roots are simplified whether written as the
/// `Sqrt` unary operator or as the one-argument `sqrt` function.
fn simplify_radicals(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
    match expr {
        Expression::BinaryOp { op, left, right } => {
            let l = self.simplify_radicals(left)?;
            let r = self.simplify_radicals(right)?;
            match op {
                BinaryOperator::Add => self.simplify_radical_addition(&l, &r),
                BinaryOperator::Subtract => self.simplify_radical_subtraction(&l, &r),
                BinaryOperator::Multiply => self.simplify_radical_multiplication(&l, &r),
                BinaryOperator::Divide => self.simplify_radical_division(&l, &r),
                _ => Ok(Expression::binary_op(op.clone(), l, r)),
            }
        }
        Expression::UnaryOp { op: UnaryOperator::Sqrt, operand } => {
            let inner = self.simplify_radicals(operand)?;
            self.simplify_square_root(&inner)
        }
        Expression::UnaryOp { op, operand } => {
            let inner = self.simplify_radicals(operand)?;
            Ok(Expression::unary_op(op.clone(), inner))
        }
        Expression::Function { name, args } if name == "sqrt" && args.len() == 1 => {
            let inner = self.simplify_radicals(&args[0])?;
            self.simplify_square_root(&inner)
        }
        Expression::Function { name, args } => {
            let mut out = Vec::with_capacity(args.len());
            for arg in args.iter() {
                out.push(self.simplify_radicals(arg)?);
            }
            Ok(Expression::function(name, out))
        }
        _ => Ok(expr.clone()),
    }
}
/// Combines `p*sqrt(r) + q*sqrt(r)` into `(p+q)*sqrt(r)`. Returns `0` when the
/// summed coefficient simplifies to zero and plain `sqrt(r)` when it simplifies
/// to one. Any other sum is rebuilt unchanged.
fn simplify_radical_addition(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    let like_radicals = match (
        self.extract_radical_coefficient(left),
        self.extract_radical_coefficient(right),
    ) {
        (Some((p, ra)), Some((q, rb))) if ra == rb => Some((p, q, ra)),
        _ => None,
    };
    match like_radicals {
        None => Ok(Expression::add(left.clone(), right.clone())),
        Some((p, q, radicand)) => {
            let coefficient = self.base_simplifier.simplify(&Expression::add(p, q))?;
            let root = self.create_sqrt_expression(&radicand);
            Ok(if self.is_zero(&coefficient) {
                Expression::Number(Number::zero())
            } else if self.is_one(&coefficient) {
                root
            } else {
                Expression::multiply(coefficient, root)
            })
        }
    }
}
/// Combines `p*sqrt(r) - q*sqrt(r)` into `(p-q)*sqrt(r)`. Returns `0` when the
/// coefficient difference simplifies to zero and plain `sqrt(r)` when it
/// simplifies to one. Any other difference is rebuilt unchanged.
fn simplify_radical_subtraction(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    let pair = (
        self.extract_radical_coefficient(left),
        self.extract_radical_coefficient(right),
    );
    if let (Some((p, ra)), Some((q, rb))) = pair {
        if ra == rb {
            let coefficient = self.base_simplifier.simplify(&Expression::subtract(p, q))?;
            if self.is_zero(&coefficient) {
                return Ok(Expression::Number(Number::zero()));
            }
            let root = self.create_sqrt_expression(&ra);
            return Ok(if self.is_one(&coefficient) {
                root
            } else {
                Expression::multiply(coefficient, root)
            });
        }
    }
    Ok(Expression::subtract(left.clone(), right.clone()))
}
fn simplify_radical_multiplication(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
match (left, right) {
(Expression::Function { name: name1, args: args1 },
Expression::Function { name: name2, args: args2 })
if name1 == "sqrt" && name2 == "sqrt" && args1.len() == 1 && args2.len() == 1 => {
let product = Expression::multiply(args1[0].clone(), args2[0].clone());
let simplified_product = self.base_simplifier.simplify(&product)?;
self.simplify_square_root(&simplified_product)
}
_ => Ok(Expression::multiply(left.clone(), right.clone()))
}
}
fn simplify_radical_division(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
match (left, right) {
(Expression::Function { name: name1, args: args1 },
Expression::Function { name: name2, args: args2 })
if name1 == "sqrt" && name2 == "sqrt" && args1.len() == 1 && args2.len() == 1 => {
let quotient = Expression::divide(args1[0].clone(), args2[0].clone());
let simplified_quotient = self.base_simplifier.simplify("ient)?;
Ok(Expression::function("sqrt", vec![simplified_quotient]))
}
_ => Ok(Expression::divide(left.clone(), right.clone()))
}
}
/// Simplifies `sqrt(arg)`.
///
/// `sqrt(x^2)` becomes `abs(x)`, and numeric radicands are delegated to
/// `simplify_numeric_square_root`. Anything else is first offered to the
/// denesting rules and otherwise left as `sqrt(arg)`.
fn simplify_square_root(&mut self, arg: &Expression) -> Result<Expression, ComputeError> {
    if let Expression::Number(n) = arg {
        return self.simplify_numeric_square_root(n);
    }
    if let Expression::BinaryOp { op: BinaryOperator::Power, left, right } = arg {
        if matches!(right.as_ref(), Expression::Number(n) if n.is_two()) {
            return Ok(Expression::function("abs", vec![left.as_ref().clone()]));
        }
    }
    Ok(match self.try_denest_radical(arg)? {
        Some(denested) => denested,
        None => Expression::function("sqrt", vec![arg.clone()]),
    })
}
/// Evaluates the square root of a numeric literal.
///
/// Perfect squares become integers, and other non-negative integers have
/// square factors pulled out. Negative integers and non-integer numbers are
/// left as `sqrt(n)`.
fn simplify_numeric_square_root(&self, n: &Number) -> Result<Expression, ComputeError> {
    let value = match n {
        Number::Integer(i) if i >= &BigInt::zero() => i,
        _ => return Ok(Expression::function("sqrt", vec![Expression::Number(n.clone())])),
    };
    match self.integer_sqrt(value) {
        Some(root) => Ok(Expression::Number(Number::Integer(root))),
        None => self.extract_square_factors(value),
    }
}
/// Recognises `sqrt(r)`, `sqrt(r) * c` and `c * sqrt(r)` and returns `(c, r)`,
/// with `c = 1` for a bare root. In a product the left factor is tested first.
/// Returns `None` for anything else.
fn extract_radical_coefficient(&self, expr: &Expression) -> Option<(Expression, Expression)> {
    // Yields the radicand when `e` is a one-argument `sqrt` call.
    fn radicand_of(e: &Expression) -> Option<&Expression> {
        match e {
            Expression::Function { name, args } if name == "sqrt" && args.len() == 1 => Some(&args[0]),
            _ => None,
        }
    }
    if let Some(r) = radicand_of(expr) {
        return Some((Expression::Number(Number::one()), r.clone()));
    }
    if let Expression::BinaryOp { op: BinaryOperator::Multiply, left, right } = expr {
        if let Some(r) = radicand_of(left) {
            return Some((right.as_ref().clone(), r.clone()));
        }
        if let Some(r) = radicand_of(right) {
            return Some((left.as_ref().clone(), r.clone()));
        }
    }
    None
}
/// Wraps a copy of `arg` in a one-argument `sqrt` function call.
fn create_sqrt_expression(&self, arg: &Expression) -> Expression {
    let radicand = arg.clone();
    Expression::function("sqrt", vec![radicand])
}
/// Returns the exact square root of `n` when `n` is a perfect square, and
/// `None` otherwise (including for negative `n`).
///
/// Works purely in integer arithmetic, so it is exact for arbitrarily large
/// values. Converting to `f64` would lose precision above 2^53 and miss
/// perfect squares there.
fn integer_sqrt(&self, n: &BigInt) -> Option<BigInt> {
    if n < &BigInt::zero() {
        return None;
    }
    if n == &BigInt::zero() || n == &BigInt::one() {
        return Some(n.clone());
    }
    // Integer Newton iteration: starting from x0 = n >= floor(sqrt(n)), the
    // iterates x_{k+1} = (x_k + n / x_k) / 2 strictly decrease until they
    // reach floor(sqrt(n)); the first non-decreasing step signals convergence.
    let two = BigInt::from(2);
    let mut root = n.clone();
    loop {
        let next = (&root + n / &root) / &two;
        if next >= root {
            break;
        }
        root = next;
    }
    if &(&root * &root) == n {
        Some(root)
    } else {
        None
    }
}
/// Rewrites `sqrt(n)` for a positive integer `n` as `k * sqrt(m)`, pulling out
/// square factors `p^2` by trial division with `p` up to `TRIAL_LIMIT`.
///
/// Returns `k` alone when `m == 1` and `sqrt(m)` alone when `k == 1`.
/// Non-positive `n` is left as `sqrt(n)`. Square factors whose root exceeds
/// `TRIAL_LIMIT` are not found; the result is then correct but not fully reduced.
fn extract_square_factors(&self, n: &BigInt) -> Result<Expression, ComputeError> {
    // Bounds the work per call so that huge radicands stay cheap.
    const TRIAL_LIMIT: u32 = 1_000;
    let sqrt_of = |m: BigInt| Expression::function("sqrt", vec![Expression::Number(Number::Integer(m))]);
    if n <= &BigInt::zero() {
        return Ok(sqrt_of(n.clone()));
    }
    let mut remaining = n.clone();
    let mut extracted = BigInt::one();
    for p in 2..=TRIAL_LIMIT {
        let p_big = BigInt::from(p);
        let p_squared = &p_big * &p_big;
        // Once p^2 exceeds what is left, no square >= p^2 can divide it.
        if p_squared > remaining {
            break;
        }
        // A composite p never matches here: the squares of its prime factors
        // were already divided out by smaller p.
        while (&remaining % &p_squared).is_zero() {
            remaining /= &p_squared;
            extracted *= &p_big;
        }
    }
    let one = BigInt::one();
    Ok(match (extracted == one, remaining == one) {
        (true, true) => Expression::Number(Number::Integer(one)),
        (true, false) => sqrt_of(remaining),
        (false, true) => Expression::Number(Number::Integer(extracted)),
        (false, false) => Expression::multiply(
            Expression::Number(Number::Integer(extracted)),
            sqrt_of(remaining),
        ),
    })
}
/// Tries to denest `sqrt(arg)` when `arg` is `a + b*sqrt(c)`, `b*sqrt(c) + a`
/// or `a - b*sqrt(c)` with a numeric `a`.
///
/// Returns `Ok(None)` when `arg` has no such shape or no denesting rule applies.
fn try_denest_radical(&mut self, arg: &Expression) -> Result<Option<Expression>, ComputeError> {
    match arg {
        // Only sums and differences have the `a ± b*sqrt(c)` shape. Other
        // operators, e.g. `3 * (2*sqrt(2))`, must not be treated as a sum.
        Expression::BinaryOp { op, left, right }
            if matches!(op, BinaryOperator::Add | BinaryOperator::Subtract) =>
        {
            let is_subtract = matches!(op, BinaryOperator::Subtract);
            match self.match_nested_radical_pattern(left, right, is_subtract) {
                Some((a, b, c)) => self.try_special_denesting(&a, &b, &c, is_subtract),
                None => Ok(None),
            }
        }
        _ => Ok(None),
    }
}
/// Matches `a ± b*sqrt(c)` with a numeric `a` on the left, or, for sums only,
/// `b*sqrt(c) + a` with the number on the right. Returns `(a, b, c)`.
fn match_nested_radical_pattern(&self, left: &Expression, right: &Expression, is_subtract: bool) -> Option<(Expression, Expression, Expression)> {
    let numeric_left = matches!(left, Expression::Number(_));
    if numeric_left {
        if let Some((b, c)) = self.extract_coefficient_sqrt(right) {
            return Some((left.clone(), b, c));
        }
    }
    let numeric_right = matches!(right, Expression::Number(_));
    if is_subtract || !numeric_right {
        return None;
    }
    self.extract_coefficient_sqrt(left).map(|(b, c)| (right.clone(), b, c))
}
fn extract_coefficient_sqrt(&self, expr: &Expression) -> Option<(Expression, Expression)> {
match expr {
Expression::Function { name, args } if name == "sqrt" && args.len() == 1 => {
Some((Expression::Number(Number::one()), args[0].clone()))
}
Expression::BinaryOp { op: BinaryOperator::Multiply, left, right } => {
if let Expression::Function { name, args } = left.as_ref() {
if name == "sqrt" && args.len() == 1 {
return Some((right.as_ref().clone(), args[0].clone()));
}
}
if let Expression::Function { name, args } = right.as_ref() {
if name == "sqrt" && args.len() == 1 {
return Some((left.as_ref().clone(), args[0].clone()));
}
}
None
}
_ => None
}
}
fn try_special_denesting(&mut self, a: &Expression, b: &Expression, c: &Expression, is_subtract: bool) -> Result<Option<Expression>, ComputeError> {
if let (Expression::Number(a_num), Expression::Number(b_num), Expression::Number(c_num)) = (a, b, c) {
if a_num == &Number::integer(3) && b_num == &Number::integer(2) && c_num == &Number::integer(2) && is_subtract {
let sqrt2 = Expression::function("sqrt", vec![Expression::Number(Number::integer(2))]);
let result = Expression::subtract(sqrt2, Expression::Number(Number::integer(1)));
return Ok(Some(result));
}
if a_num == &Number::integer(3) && b_num == &Number::integer(2) && c_num == &Number::integer(2) && !is_subtract {
let sqrt2 = Expression::function("sqrt", vec![Expression::Number(Number::integer(2))]);
let result = Expression::add(sqrt2, Expression::Number(Number::integer(1)));
return Ok(Some(result));
}
}
Ok(None)
}
/// Bottom-up pass that processes both operands and then applies the identity
/// rules for `+`, `-`, `*` and `/` (the `apply_*_rules` helpers). Other binary
/// operators, unary operations and function arguments are rebuilt around their
/// processed children.
///
/// # Errors
/// Propagates [`ComputeError`]s from the rules, e.g. division by a literal zero.
fn apply_advanced_algebraic_rules(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
    match expr {
        Expression::BinaryOp { op, left, right } => {
            let (lhs, rhs) = (
                self.apply_advanced_algebraic_rules(left)?,
                self.apply_advanced_algebraic_rules(right)?,
            );
            match op {
                BinaryOperator::Add => self.apply_addition_rules(&lhs, &rhs),
                BinaryOperator::Subtract => self.apply_subtraction_rules(&lhs, &rhs),
                BinaryOperator::Multiply => self.apply_multiplication_rules(&lhs, &rhs),
                BinaryOperator::Divide => self.apply_division_rules(&lhs, &rhs),
                _ => Ok(Expression::binary_op(op.clone(), lhs, rhs)),
            }
        }
        Expression::UnaryOp { op, operand } => {
            let inner = self.apply_advanced_algebraic_rules(operand)?;
            Ok(Expression::unary_op(op.clone(), inner))
        }
        Expression::Function { name, args } => {
            let mut rewritten = Vec::with_capacity(args.len());
            for arg in args.iter() {
                rewritten.push(self.apply_advanced_algebraic_rules(arg)?);
            }
            Ok(Expression::function(name, rewritten))
        }
        _ => Ok(expr.clone()),
    }
}
/// Applies additive identities: `0 + x = x`, `x + 0 = x`, and `x + (-x)` and
/// `(-x) + x` both give `0`. Otherwise rebuilds `left + right`.
fn apply_addition_rules(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    if self.is_zero(left) {
        return Ok(right.clone());
    }
    if self.is_zero(right) {
        return Ok(left.clone());
    }
    // True when `candidate` is literally `-(target)`.
    let negation_of = |candidate: &Expression, target: &Expression| {
        matches!(candidate, Expression::UnaryOp { op: UnaryOperator::Negate, operand } if operand.as_ref() == target)
    };
    if negation_of(right, left) || negation_of(left, right) {
        return Ok(Expression::Number(Number::zero()));
    }
    Ok(Expression::add(left.clone(), right.clone()))
}
/// Applies subtractive identities, checked in this order: `x - 0 = x`,
/// `0 - x = -x`, and `x - x = 0`. Otherwise rebuilds `left - right`.
fn apply_subtraction_rules(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    let simplified = if self.is_zero(right) {
        left.clone()
    } else if self.is_zero(left) {
        Expression::negate(right.clone())
    } else if left == right {
        Expression::Number(Number::zero())
    } else {
        Expression::subtract(left.clone(), right.clone())
    };
    Ok(simplified)
}
/// Applies multiplicative identities, checked in this order: a zero factor
/// gives `0`, a factor of one is dropped, and a factor of minus one negates
/// the other operand. Otherwise rebuilds `left * right`.
fn apply_multiplication_rules(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    let product = if self.is_zero(left) || self.is_zero(right) {
        Expression::Number(Number::zero())
    } else if self.is_one(left) {
        right.clone()
    } else if self.is_one(right) {
        left.clone()
    } else if self.is_negative_one(left) {
        Expression::negate(right.clone())
    } else if self.is_negative_one(right) {
        Expression::negate(left.clone())
    } else {
        Expression::multiply(left.clone(), right.clone())
    };
    Ok(product)
}
/// Applies division identities: `0 / x = 0`, `x / 1 = x`, and `x / x = 1`.
/// Otherwise rebuilds `left / right`.
/// NOTE(review): `x / x = 1` presumes a symbolic `x` is non-zero.
///
/// # Errors
/// Returns [`ComputeError::DivisionByZero`] when `right` is the literal zero.
fn apply_division_rules(&mut self, left: &Expression, right: &Expression) -> Result<Expression, ComputeError> {
    if self.is_zero(right) {
        return Err(ComputeError::DivisionByZero);
    }
    Ok(if self.is_zero(left) {
        Expression::Number(Number::zero())
    } else if self.is_one(right) {
        left.clone()
    } else if left == right {
        Expression::Number(Number::one())
    } else {
        Expression::divide(left.clone(), right.clone())
    })
}
/// Hook for trigonometric identities. No identities are implemented yet: the
/// pass only rebuilds the tree with the same operators and function names.
/// Presumably the result equals the input unless the `Expression`
/// constructors normalise; verify against their definitions.
fn simplify_trigonometric(&mut self, expr: &Expression) -> Result<Expression, ComputeError> {
    let rebuilt = match expr {
        Expression::BinaryOp { op, left, right } => {
            let lhs = self.simplify_trigonometric(left)?;
            let rhs = self.simplify_trigonometric(right)?;
            Expression::binary_op(op.clone(), lhs, rhs)
        }
        Expression::UnaryOp { op, operand } => {
            let inner = self.simplify_trigonometric(operand)?;
            Expression::unary_op(op.clone(), inner)
        }
        Expression::Function { name, args } => {
            let mut mapped = Vec::with_capacity(args.len());
            for arg in args.iter() {
                mapped.push(self.simplify_trigonometric(arg)?);
            }
            Expression::function(name, mapped)
        }
        _ => expr.clone(),
    };
    Ok(rebuilt)
}
/// True when `expr` is a numeric literal equal to zero.
fn is_zero(&self, expr: &Expression) -> bool {
    if let Expression::Number(value) = expr {
        value.is_zero()
    } else {
        false
    }
}
/// True when `expr` is a numeric literal equal to one.
fn is_one(&self, expr: &Expression) -> bool {
    match expr {
        Expression::Number(value) => value.is_one(),
        _ => false,
    }
}
/// True for the literal `-1` or for a negation of the literal `1`.
fn is_negative_one(&self, expr: &Expression) -> bool {
    match expr {
        Expression::Number(value) => *value == Number::integer(-1),
        Expression::UnaryOp { op: UnaryOperator::Negate, operand } => self.is_one(operand),
        _ => false,
    }
}
}
impl Default for EnhancedSimplifier {
fn default() -> Self {
Self::new()
}
}
// Unit tests are compiled only in test builds and are loaded from
// `enhanced_simplify_tests.rs` through the `#[path]` attribute.
#[cfg(test)]
#[path = "enhanced_simplify_tests.rs"]
mod enhanced_simplify_tests;