use crate::pattern::apply_rules_to_fixpoint;
use crate::simplification_rules::all_simplification_rules;
use num_complex::Complex64;
use num_rational::Rational64;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt;
/// A named equation of the form `left = right`.
#[derive(Debug, Clone, PartialEq)]
pub struct Equation {
    /// Caller-supplied identifier; not included in the `Display` output.
    pub id: String,
    /// Left-hand side of the `=`.
    pub left: Expression,
    /// Right-hand side of the `=`.
    pub right: Expression,
}

impl Equation {
    /// Builds an equation from an identifier and its two sides.
    pub fn new(id: impl Into<String>, left: Expression, right: Expression) -> Self {
        let id: String = id.into();
        Equation { id, left, right }
    }
}

impl fmt::Display for Equation {
    /// Writes `left = right` (the `id` is not printed).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Equation { left, right, .. } = self;
        write!(f, "{} = {}", left, right)
    }
}
/// Mathematical constants that are kept symbolic instead of being stored as floats.
// NOTE(review): variant order is left unchanged in case a serde format encodes variant indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SymbolicConstant {
    /// π.
    Pi,
    /// Euler's number e.
    E,
    /// The imaginary unit i.
    I,
}

impl fmt::Display for SymbolicConstant {
    /// Writes the conventional one-character symbol for the constant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let glyph = match self {
            SymbolicConstant::Pi => "π",
            SymbolicConstant::E => "e",
            SymbolicConstant::I => "i",
        };
        f.write_str(glyph)
    }
}
/// Symbolic expression tree.
///
/// Leaves are numeric literals, named constants and variables; interior nodes are
/// prefix/infix operators, function applications and powers.
#[derive(Debug, Clone, PartialEq)]
pub enum Expression {
/// Integer literal.
Integer(i64),
/// Rational literal, printed as `numerator/denominator`.
Rational(Rational64),
/// Floating-point literal.
Float(f64),
/// Complex literal; `evaluate` only yields a value when its imaginary part is below 1e-10.
Complex(Complex64),
/// Named constant (π, e, i).
Constant(SymbolicConstant),
/// Variable, looked up by name in `evaluate`.
Variable(Variable),
/// Prefix operator applied to one operand.
Unary(UnaryOp, Box<Expression>),
/// Infix operator: `(op, left, right)`.
Binary(BinaryOp, Box<Expression>, Box<Expression>),
/// Function application with its argument list.
Function(Function, Vec<Expression>),
/// `base ^ exponent`.
Power(Box<Expression>, Box<Expression>),
}
/// A named variable, optionally tagged with a dimension label.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Variable {
    /// Identifier; this is what `Display` prints.
    pub name: String,
    /// Optional free-form dimension label; not interpreted by the code in this type.
    pub dimension: Option<String>,
}

impl Variable {
    /// Creates a variable without a dimension.
    pub fn new(name: impl Into<String>) -> Self {
        Variable {
            name: name.into(),
            dimension: None,
        }
    }

    /// Creates a variable that carries a dimension label.
    pub fn with_dimension(name: impl Into<String>, dimension: impl Into<String>) -> Self {
        let mut variable = Variable::new(name);
        variable.dimension = Some(dimension.into());
        variable
    }
}

impl fmt::Display for Variable {
    /// Prints only the name; the dimension is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)
    }
}
/// Prefix (single-operand) operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum UnaryOp {
/// Arithmetic negation, printed `-x`.
Neg,
/// Logical not, printed `!x`; `evaluate` yields 1 for a zero operand and 0 otherwise.
Not,
/// Absolute value, printed `|x|`.
Abs,
}
/// Infix (two-operand) arithmetic operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%` (remainder)
    Mod,
}

impl BinaryOp {
    /// Binding strength used by the printers. Multiplicative operators (2) bind
    /// tighter than additive ones (1).
    pub fn precedence(self) -> u8 {
        match self {
            BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => 2,
            BinaryOp::Add | BinaryOp::Sub => 1,
        }
    }
}
/// Functions that can appear in an [`Expression::Function`] node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Function {
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Inverse sine.
    Asin,
    /// Inverse cosine.
    Acos,
    /// Inverse tangent.
    Atan,
    /// Two-argument arctangent `atan2(y, x)`.
    Atan2,
    /// Hyperbolic sine.
    Sinh,
    /// Hyperbolic cosine.
    Cosh,
    /// Hyperbolic tangent.
    Tanh,
    /// Exponential `e^x`.
    Exp,
    /// Natural logarithm.
    Ln,
    /// Logarithm: `log(x)` is base 10, `log(x, b)` is base `b`.
    Log,
    /// Base-2 logarithm.
    Log2,
    /// Base-10 logarithm.
    Log10,
    /// Square root.
    Sqrt,
    /// Cube root.
    Cbrt,
    /// `pow(base, exponent)`.
    Pow,
    /// Round toward negative infinity.
    Floor,
    /// Round toward positive infinity.
    Ceil,
    /// Round to nearest.
    Round,
    /// Absolute value.
    Abs,
    /// Sign function.
    Sign,
    /// Minimum of all arguments.
    Min,
    /// Maximum of all arguments.
    Max,
    /// Caller-named function with no built-in semantics.
    Custom(String),
}

impl fmt::Display for Function {
    /// Writes the lowercase plain-text name (or the custom name verbatim).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name: &str = match self {
            Function::Sin => "sin",
            Function::Cos => "cos",
            Function::Tan => "tan",
            Function::Asin => "asin",
            Function::Acos => "acos",
            Function::Atan => "atan",
            Function::Atan2 => "atan2",
            Function::Sinh => "sinh",
            Function::Cosh => "cosh",
            Function::Tanh => "tanh",
            Function::Exp => "exp",
            Function::Ln => "ln",
            Function::Log => "log",
            Function::Log2 => "log2",
            Function::Log10 => "log10",
            Function::Sqrt => "sqrt",
            Function::Cbrt => "cbrt",
            Function::Pow => "pow",
            Function::Floor => "floor",
            Function::Ceil => "ceil",
            Function::Round => "round",
            Function::Abs => "abs",
            Function::Sign => "sign",
            Function::Min => "min",
            Function::Max => "max",
            Function::Custom(custom) => custom.as_str(),
        };
        f.write_str(name)
    }
}
impl fmt::Display for Expression {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt_with_precedence(f, 0)
}
}
impl Expression {
/// Writes `self` in plain-text infix notation.
///
/// `parent_prec` is the binding strength of the surrounding context:
/// 0 = top level or function argument; 1-2 = operand of an additive or
/// multiplicative operator (see [`BinaryOp::precedence`]; right operands get
/// `prec + 1`); 3 = operand of a prefix operator; 4 = operand of `^`.
/// A node is parenthesised when it would otherwise bind more loosely than its
/// context, so the grouping of the tree stays unambiguous in the text.
fn fmt_with_precedence(&self, f: &mut fmt::Formatter<'_>, parent_prec: u8) -> fmt::Result {
    match self {
        // A negative literal starts with `-`; as a power operand, `-2^x` would read
        // as `-(2^x)`, so bracket it there.
        Expression::Integer(n) => {
            if *n < 0 && parent_prec > 3 {
                write!(f, "({})", n)
            } else {
                write!(f, "{}", n)
            }
        }
        // `n/d` is itself a division and needs the protection a `/` node would get
        // (e.g. `x / (1/2)`, `(1/2)^x`).
        Expression::Rational(r) => {
            if parent_prec > 2 {
                write!(f, "({}/{})", r.numer(), r.denom())
            } else {
                write!(f, "{}/{}", r.numer(), r.denom())
            }
        }
        Expression::Float(x) => {
            if *x < 0.0 && parent_prec > 3 {
                write!(f, "({})", x)
            } else {
                write!(f, "{}", x)
            }
        }
        // `a+bi` contains a top-level sign, so it groups like a sum.
        Expression::Complex(c) => {
            let wrap = parent_prec > 1;
            if wrap {
                write!(f, "(")?;
            }
            if c.im >= 0.0 {
                write!(f, "{}+{}i", c.re, c.im)?;
            } else {
                write!(f, "{}{}i", c.re, c.im)?;
            }
            if wrap {
                write!(f, ")")?;
            }
            Ok(())
        }
        Expression::Constant(c) => write!(f, "{}", c),
        Expression::Variable(v) => write!(f, "{}", v),
        // Prefix operators bind looser than `^`: `-x^2` means `-(x^2)`, so a
        // negation used as a power operand must be bracketed: `(-x)^2`.
        Expression::Unary(UnaryOp::Neg, expr) => {
            let wrap = parent_prec > 3;
            if wrap {
                write!(f, "(")?;
            }
            write!(f, "-")?;
            expr.fmt_with_precedence(f, 3)?;
            if wrap {
                write!(f, ")")?;
            }
            Ok(())
        }
        Expression::Unary(UnaryOp::Not, expr) => {
            let wrap = parent_prec > 3;
            if wrap {
                write!(f, "(")?;
            }
            write!(f, "!")?;
            expr.fmt_with_precedence(f, 3)?;
            if wrap {
                write!(f, ")")?;
            }
            Ok(())
        }
        // The bars already delimit the operand, so it is printed at top level.
        Expression::Unary(UnaryOp::Abs, expr) => {
            write!(f, "|")?;
            expr.fmt_with_precedence(f, 0)?;
            write!(f, "|")
        }
        Expression::Binary(op, left, right) => {
            let prec = op.precedence();
            let needs_parens = prec < parent_prec;
            let symbol = match op {
                BinaryOp::Add => " + ",
                BinaryOp::Sub => " - ",
                BinaryOp::Mul => " * ",
                BinaryOp::Div => " / ",
                BinaryOp::Mod => " % ",
            };
            if needs_parens {
                write!(f, "(")?;
            }
            left.fmt_with_precedence(f, prec)?;
            f.write_str(symbol)?;
            // `prec + 1` on the right keeps left-associativity visible: `a - (b - c)`.
            right.fmt_with_precedence(f, prec + 1)?;
            if needs_parens {
                write!(f, ")")?;
            }
            Ok(())
        }
        Expression::Function(func, args) => {
            write!(f, "{}(", func)?;
            for (i, arg) in args.iter().enumerate() {
                if i > 0 {
                    write!(f, ", ")?;
                }
                arg.fmt_with_precedence(f, 0)?;
            }
            write!(f, ")")
        }
        Expression::Power(base, exp) => {
            // Only another power operand (context 4) forces brackets around a power.
            let needs_parens = parent_prec > 3;
            if needs_parens {
                write!(f, "(")?;
            }
            base.fmt_with_precedence(f, 4)?;
            write!(f, "^")?;
            exp.fmt_with_precedence(f, 4)?;
            if needs_parens {
                write!(f, ")")?;
            }
            Ok(())
        }
    }
}
#[inline]
pub fn pi() -> Self {
Expression::Constant(SymbolicConstant::Pi)
}
#[inline]
pub fn euler() -> Self {
Expression::Constant(SymbolicConstant::E)
}
#[inline]
pub fn i() -> Self {
Expression::Constant(SymbolicConstant::I)
}
/// Renders the expression as a LaTeX math-mode fragment, without math delimiters.
pub fn to_latex(&self) -> String {
    let outermost_precedence = 0;
    self.to_latex_inner(outermost_precedence)
}

/// Wraps [`Expression::to_latex`] in display-math delimiters `\[ ... \]`.
pub fn to_latex_display(&self) -> String {
    [r"\[", self.to_latex().as_str(), r"\]"].concat()
}

/// Wraps [`Expression::to_latex`] in inline-math delimiters `$ ... $`.
pub fn to_latex_inline(&self) -> String {
    let mut inline = String::from("$");
    inline.push_str(&self.to_latex());
    inline.push('$');
    inline
}
fn to_latex_inner(&self, parent_prec: u8) -> String {
match self {
Expression::Integer(n) => {
if *n < 0 {
format!("{{{}}}", n)
} else {
n.to_string()
}
}
Expression::Rational(r) => {
format!(r"\frac{{{}}}{{{}}}", r.numer(), r.denom())
}
Expression::Float(x) => {
if *x < 0.0 {
format!("{{{}}}", x)
} else {
x.to_string()
}
}
Expression::Complex(c) => {
if c.im >= 0.0 {
format!("{}+{}i", c.re, c.im)
} else {
format!("{}{}i", c.re, c.im)
}
}
Expression::Constant(c) => match c {
SymbolicConstant::Pi => r"\pi".to_string(),
SymbolicConstant::E => "e".to_string(),
SymbolicConstant::I => "i".to_string(),
},
Expression::Variable(v) => {
match v.name.as_str() {
"alpha" => r"\alpha".to_string(),
"beta" => r"\beta".to_string(),
"gamma" => r"\gamma".to_string(),
"delta" => r"\delta".to_string(),
"epsilon" => r"\epsilon".to_string(),
"zeta" => r"\zeta".to_string(),
"eta" => r"\eta".to_string(),
"theta" => r"\theta".to_string(),
"iota" => r"\iota".to_string(),
"kappa" => r"\kappa".to_string(),
"lambda" => r"\lambda".to_string(),
"mu" => r"\mu".to_string(),
"nu" => r"\nu".to_string(),
"xi" => r"\xi".to_string(),
"omicron" => r"\omicron".to_string(),
"rho" => r"\rho".to_string(),
"sigma" => r"\sigma".to_string(),
"tau" => r"\tau".to_string(),
"upsilon" => r"\upsilon".to_string(),
"phi" => r"\phi".to_string(),
"chi" => r"\chi".to_string(),
"psi" => r"\psi".to_string(),
"omega" => r"\omega".to_string(),
"Gamma" => r"\Gamma".to_string(),
"Delta" => r"\Delta".to_string(),
"Theta" => r"\Theta".to_string(),
"Lambda" => r"\Lambda".to_string(),
"Xi" => r"\Xi".to_string(),
"Pi" => r"\Pi".to_string(),
"Sigma" => r"\Sigma".to_string(),
"Phi" => r"\Phi".to_string(),
"Psi" => r"\Psi".to_string(),
"Omega" => r"\Omega".to_string(),
name if name.contains('_') => {
let parts: Vec<&str> = name.splitn(2, '_').collect();
if parts.len() == 2 {
format!("{}_{{{}}}", parts[0], parts[1])
} else {
name.to_string()
}
}
name => name.to_string(),
}
}
Expression::Unary(UnaryOp::Neg, expr) => {
format!("-{}", expr.to_latex_inner(3))
}
Expression::Unary(UnaryOp::Not, expr) => {
format!(r"\lnot {}", expr.to_latex_inner(3))
}
Expression::Unary(UnaryOp::Abs, expr) => {
format!(r"\left|{}\right|", expr.to_latex_inner(0))
}
Expression::Binary(BinaryOp::Div, num, denom) => {
format!(
r"\frac{{{}}}{{{}}}",
num.to_latex_inner(0),
denom.to_latex_inner(0)
)
}
Expression::Binary(op, left, right) => {
let prec = op.precedence();
let needs_parens = prec < parent_prec;
let left_str = left.to_latex_inner(prec);
let right_str = right.to_latex_inner(prec + 1);
let op_str = match op {
BinaryOp::Add => " + ",
BinaryOp::Sub => " - ",
BinaryOp::Mul => r" \cdot ",
BinaryOp::Div => unreachable!(), BinaryOp::Mod => r" \bmod ",
};
if needs_parens {
format!(r"\left({}{}{}right)", left_str, op_str, right_str)
} else {
format!("{}{}{}", left_str, op_str, right_str)
}
}
Expression::Function(func, args) => {
let func_name = match func {
Function::Sin => r"\sin",
Function::Cos => r"\cos",
Function::Tan => r"\tan",
Function::Asin => r"\arcsin",
Function::Acos => r"\arccos",
Function::Atan => r"\arctan",
Function::Atan2 => r"\arctan",
Function::Sinh => r"\sinh",
Function::Cosh => r"\cosh",
Function::Tanh => r"\tanh",
Function::Exp => r"\exp",
Function::Ln => r"\ln",
Function::Log2 => r"\log_2",
Function::Log10 => r"\log_{10}",
Function::Sqrt => {
if args.len() == 1 {
return format!(r"\sqrt{{{}}}", args[0].to_latex_inner(0));
}
r"\sqrt"
}
Function::Cbrt => {
if args.len() == 1 {
return format!(r"\sqrt[3]{{{}}}", args[0].to_latex_inner(0));
}
r"\sqrt[3]"
}
Function::Abs => {
if args.len() == 1 {
return format!(r"\left|{}\right|", args[0].to_latex_inner(0));
}
r"\text{abs}"
}
Function::Sign => r"\text{sgn}",
Function::Floor => r"\lfloor",
Function::Ceil => r"\lceil",
Function::Round => r"\text{round}",
Function::Min => r"\min",
Function::Max => r"\max",
Function::Pow => {
if args.len() == 2 {
return format!(
"{}^{{{}}}",
args[0].to_latex_inner(4),
args[1].to_latex_inner(0)
);
}
r"\text{pow}"
}
Function::Log => {
if args.len() == 2 {
return format!(
r"\log_{{{}}}{{{}}}",
args[1].to_latex_inner(0),
args[0].to_latex_inner(0)
);
}
r"\log"
}
Function::Custom(name) => {
let args_str: Vec<String> =
args.iter().map(|a| a.to_latex_inner(0)).collect();
return format!(r"\text{{{}}}\left({}\right)", name, args_str.join(", "));
}
};
if matches!(func, Function::Floor) && args.len() == 1 {
return format!(r"\lfloor {} \rfloor", args[0].to_latex_inner(0));
}
if matches!(func, Function::Ceil) && args.len() == 1 {
return format!(r"\lceil {} \rceil", args[0].to_latex_inner(0));
}
let args_str: Vec<String> = args.iter().map(|a| a.to_latex_inner(0)).collect();
format!(r"{}\left({}\right)", func_name, args_str.join(", "))
}
Expression::Power(base, exp) => {
let base_str = base.to_latex_inner(4);
let exp_str = exp.to_latex_inner(0);
let base_needs_braces = matches!(
**base,
Expression::Binary(_, _, _) | Expression::Unary(_, _)
);
if base_needs_braces {
format!(r"\left({}\right)^{{{}}}", base_str, exp_str)
} else {
format!("{}^{{{}}}", base_str, exp_str)
}
}
}
}
/// Returns the names of all variables occurring anywhere in the expression.
pub fn variables(&self) -> HashSet<String> {
    let mut names = HashSet::new();
    self.collect_variables(&mut names);
    names
}

/// Recursive helper for [`Expression::variables`]: inserts every variable name into `vars`.
fn collect_variables(&self, vars: &mut HashSet<String>) {
    match self {
        Expression::Variable(var) => {
            vars.insert(var.name.clone());
        }
        Expression::Unary(_, operand) => operand.collect_variables(vars),
        Expression::Binary(_, lhs, rhs) | Expression::Power(lhs, rhs) => {
            lhs.collect_variables(vars);
            rhs.collect_variables(vars);
        }
        Expression::Function(_, args) => {
            args.iter().for_each(|arg| arg.collect_variables(vars));
        }
        _ => {}
    }
}
/// Reports whether a variable called `name` occurs anywhere in the tree.
pub fn contains_variable(&self, name: &str) -> bool {
    match self {
        Expression::Variable(var) => var.name == name,
        Expression::Unary(_, operand) => operand.contains_variable(name),
        Expression::Binary(_, lhs, rhs) | Expression::Power(lhs, rhs) => {
            lhs.contains_variable(name) || rhs.contains_variable(name)
        }
        Expression::Function(_, args) => args.iter().any(|arg| arg.contains_variable(name)),
        _ => false,
    }
}
/// Bottom-up rewrite. Every child is mapped first, the node is rebuilt from the
/// mapped children, and then `f` is applied to the rebuilt node.
pub fn map<F>(&self, f: &F) -> Expression
where
    F: Fn(&Expression) -> Expression,
{
    let rebuilt = match self {
        Expression::Unary(op, operand) => {
            let new_operand = operand.map(f);
            Expression::Unary(*op, Box::new(new_operand))
        }
        Expression::Binary(op, lhs, rhs) => {
            let new_lhs = lhs.map(f);
            let new_rhs = rhs.map(f);
            Expression::Binary(*op, Box::new(new_lhs), Box::new(new_rhs))
        }
        Expression::Function(func, args) => {
            let new_args: Vec<Expression> = args.iter().map(|arg| arg.map(f)).collect();
            Expression::Function(func.clone(), new_args)
        }
        Expression::Power(base, exp) => {
            let new_base = base.map(f);
            let new_exp = exp.map(f);
            Expression::Power(Box::new(new_base), Box::new(new_exp))
        }
        leaf => leaf.clone(),
    };
    f(&rebuilt)
}
/// Pre-order fold. `f` is applied to this node first, then to the children from
/// left to right, with the accumulator threaded through every call.
pub fn fold<T, F>(&self, init: T, f: &F) -> T
where
    F: Fn(T, &Expression) -> T,
{
    let after_self = f(init, self);
    match self {
        Expression::Unary(_, operand) => operand.fold(after_self, f),
        Expression::Binary(_, first, second) | Expression::Power(first, second) => {
            let after_first = first.fold(after_self, f);
            second.fold(after_first, f)
        }
        Expression::Function(_, args) => args
            .iter()
            .fold(after_self, |state, child| child.fold(state, f)),
        _ => after_self,
    }
}
pub fn simplify(&self) -> Expression {
let simplified = match self {
Expression::Unary(op, expr) => {
let simplified_expr = expr.simplify();
match op {
UnaryOp::Neg => {
if let Expression::Unary(UnaryOp::Neg, inner) = &simplified_expr {
inner.as_ref().clone()
} else {
Expression::Unary(*op, Box::new(simplified_expr))
}
}
_ => Expression::Unary(*op, Box::new(simplified_expr)),
}
}
Expression::Binary(op, left, right) => {
let left_simplified = left.simplify();
let right_simplified = right.simplify();
let after_identity = match op {
BinaryOp::Add => {
if Self::is_zero(&left_simplified) {
return right_simplified;
}
if Self::is_zero(&right_simplified) {
return left_simplified;
}
let (coef1, base1) = Self::extract_coefficient_and_base(&left_simplified);
let (coef2, base2) = Self::extract_coefficient_and_base(&right_simplified);
if Self::bases_equal(&base1, &base2) && !Self::is_one(&base1) {
let new_coef = coef1 + coef2;
if new_coef.abs() < 1e-10 {
return Expression::Integer(0);
}
let coef_expr = Self::from_numeric_value(new_coef);
if Self::is_one(&coef_expr) {
return base1;
}
return Expression::Binary(
BinaryOp::Mul,
Box::new(coef_expr),
Box::new(base1),
);
}
None
}
BinaryOp::Sub => {
if Self::is_zero(&right_simplified) {
return left_simplified;
}
if left_simplified == right_simplified {
return Expression::Integer(0);
}
let (coef1, base1) = Self::extract_coefficient_and_base(&left_simplified);
let (coef2, base2) = Self::extract_coefficient_and_base(&right_simplified);
if Self::bases_equal(&base1, &base2) && !Self::is_one(&base1) {
let new_coef = coef1 - coef2;
if new_coef.abs() < 1e-10 {
return Expression::Integer(0);
}
let coef_expr = Self::from_numeric_value(new_coef);
if Self::is_one(&coef_expr) {
return base1;
}
if new_coef < 0.0 {
return Expression::Unary(
UnaryOp::Neg,
Box::new(Expression::Binary(
BinaryOp::Mul,
Box::new(Self::from_numeric_value(-new_coef)),
Box::new(base1),
)),
);
}
return Expression::Binary(
BinaryOp::Mul,
Box::new(coef_expr),
Box::new(base1),
);
}
None
}
BinaryOp::Mul => {
if Self::is_zero(&left_simplified) {
return Expression::Integer(0);
}
if Self::is_zero(&right_simplified) {
return Expression::Integer(0);
}
if Self::is_one(&left_simplified) {
return right_simplified;
}
if Self::is_one(&right_simplified) {
return left_simplified;
}
if let (Expression::Power(base1, exp1), Expression::Power(base2, exp2)) =
(&left_simplified, &right_simplified)
{
if base1 == base2 {
let new_exp =
Expression::Binary(BinaryOp::Add, exp1.clone(), exp2.clone())
.simplify();
return Expression::Power(base1.clone(), Box::new(new_exp));
}
}
if left_simplified == right_simplified {
return Expression::Power(
Box::new(left_simplified),
Box::new(Expression::Integer(2)),
);
}
if let Expression::Power(base, exp) = &right_simplified {
if **base == left_simplified {
let new_exp = Expression::Binary(
BinaryOp::Add,
exp.clone(),
Box::new(Expression::Integer(1)),
)
.simplify();
return Expression::Power(base.clone(), Box::new(new_exp));
}
}
if let Expression::Power(base, exp) = &left_simplified {
if **base == right_simplified {
let new_exp = Expression::Binary(
BinaryOp::Add,
exp.clone(),
Box::new(Expression::Integer(1)),
)
.simplify();
return Expression::Power(base.clone(), Box::new(new_exp));
}
}
None
}
BinaryOp::Div => {
if Self::is_one(&right_simplified) {
return left_simplified;
}
None
}
_ => None,
};
if after_identity.is_some() {
return after_identity.unwrap();
}
if Self::is_numeric_constant(&left_simplified)
&& Self::is_numeric_constant(&right_simplified)
{
if let (Some(left_val), Some(right_val)) = (
Self::extract_numeric_value(&left_simplified),
Self::extract_numeric_value(&right_simplified),
) {
let result = match op {
BinaryOp::Add => Some(left_val + right_val),
BinaryOp::Sub => Some(left_val - right_val),
BinaryOp::Mul => Some(left_val * right_val),
BinaryOp::Div => {
if right_val.abs() > 1e-10 {
Some(left_val / right_val)
} else {
None }
}
BinaryOp::Mod => Some(left_val % right_val),
};
if let Some(value) = result {
return Self::from_numeric_value(value);
}
}
}
Expression::Binary(*op, Box::new(left_simplified), Box::new(right_simplified))
}
Expression::Function(func, args) => {
let simplified_args: Vec<Expression> =
args.iter().map(|arg| arg.simplify()).collect();
if simplified_args.iter().all(Self::is_numeric_constant) {
let temp_expr = Expression::Function(func.clone(), simplified_args.clone());
if let Some(value) = temp_expr.evaluate(&HashMap::new()) {
return Self::from_numeric_value(value);
}
}
Expression::Function(func.clone(), simplified_args)
}
Expression::Power(base, exp) => {
let base_simplified = base.simplify();
let exp_simplified = exp.simplify();
if Self::is_zero(&exp_simplified) && !Self::is_zero(&base_simplified) {
return Expression::Integer(1);
}
if Self::is_one(&exp_simplified) {
return base_simplified;
}
if let Expression::Power(inner_base, inner_exp) = &base_simplified {
let new_exp = Expression::Binary(
BinaryOp::Mul,
inner_exp.clone(),
Box::new(exp_simplified.clone()),
)
.simplify();
return Expression::Power(inner_base.clone(), Box::new(new_exp));
}
if Self::is_numeric_constant(&base_simplified)
&& Self::is_numeric_constant(&exp_simplified)
{
if let (Some(base_val), Some(exp_val)) = (
Self::extract_numeric_value(&base_simplified),
Self::extract_numeric_value(&exp_simplified),
) {
let result = base_val.powf(exp_val);
if result.is_finite() {
return Self::from_numeric_value(result);
}
}
}
Expression::Power(Box::new(base_simplified), Box::new(exp_simplified))
}
_ => self.clone(),
};
let rules = all_simplification_rules();
apply_rules_to_fixpoint(&simplified, &rules, 20)
}
/// `true` for a numeric literal (`Integer`, `Float` or `Rational`) equal to zero.
fn is_zero(expr: &Expression) -> bool {
    match expr {
        Expression::Integer(n) => *n == 0,
        Expression::Float(x) => *x == 0.0,
        // Rationals count as numeric constants in `simplify`, so `0/d` is zero too.
        Expression::Rational(r) => *r.numer() == 0 && *r.denom() != 0,
        _ => false,
    }
}

/// `true` for a numeric literal (`Integer`, `Float` or `Rational`) equal to one.
fn is_one(expr: &Expression) -> bool {
    match expr {
        Expression::Integer(n) => *n == 1,
        Expression::Float(x) => *x == 1.0,
        // Compare numerator and denominator so an unreduced `d/d` also counts.
        Expression::Rational(r) => *r.denom() != 0 && r.numer() == r.denom(),
        _ => false,
    }
}
fn is_numeric_constant(expr: &Expression) -> bool {
matches!(
expr,
Expression::Integer(_) | Expression::Float(_) | Expression::Rational(_)
)
}
fn extract_numeric_value(expr: &Expression) -> Option<f64> {
match expr {
Expression::Integer(n) => Some(*n as f64),
Expression::Float(x) => Some(*x),
Expression::Rational(r) => Some(*r.numer() as f64 / *r.denom() as f64),
_ => None,
}
}
/// Converts a computed `f64` back into the most specific literal.
///
/// A value within `1e-10` of an integer that fits in `i64` becomes an
/// [`Expression::Integer`]. Everything else stays an [`Expression::Float`]: real
/// fractions, NaN, infinities, and magnitudes outside the `i64` range.
fn from_numeric_value(value: f64) -> Expression {
    // 2^63 is exactly representable in f64; the i64 range is [-2^63, 2^63).
    const I64_LIMIT: f64 = 9_223_372_036_854_775_808.0;
    let nearest = value.round();
    // Measure the distance to the *nearest* integer (not `fract()`), so values
    // just below an integer (2.9999999999999996) snap like values just above it.
    let is_near_integer = value.is_finite() && (value - nearest).abs() < 1e-10;
    // Guard the cast: `as i64` would silently saturate out-of-range values.
    if is_near_integer && nearest >= -I64_LIMIT && nearest < I64_LIMIT {
        Expression::Integer(nearest as i64)
    } else {
        Expression::Float(value)
    }
}
/// Splits `expr` into `(numeric coefficient, remaining base)` for like-term merging.
///
/// Numeric literals give `(value, 1)`. `-e` negates the coefficient of `e`. A
/// product with a numeric factor on either side, or a quotient by a non-negligible
/// numeric divisor, scales the coefficient of the other side. Anything else is
/// `(1.0, expr)`.
fn extract_coefficient_and_base(expr: &Expression) -> (f64, Expression) {
    if let Some(value) = Self::extract_numeric_value(expr) {
        return (value, Expression::Integer(1));
    }
    match expr {
        Expression::Unary(UnaryOp::Neg, inner) => {
            let (coef, base) = Self::extract_coefficient_and_base(inner);
            (-coef, base)
        }
        Expression::Binary(BinaryOp::Mul, lhs, rhs) => {
            let scale = |factor: f64, rest: &Expression| {
                let (coef, base) = Self::extract_coefficient_and_base(rest);
                (factor * coef, base)
            };
            match (
                Self::extract_numeric_value(lhs),
                Self::extract_numeric_value(rhs),
            ) {
                (Some(factor), _) => scale(factor, rhs.as_ref()),
                (None, Some(factor)) => scale(factor, lhs.as_ref()),
                (None, None) => (1.0, expr.clone()),
            }
        }
        Expression::Binary(BinaryOp::Div, lhs, rhs) => match Self::extract_numeric_value(rhs) {
            Some(divisor) if divisor.abs() > 1e-10 => {
                let (coef, base) = Self::extract_coefficient_and_base(lhs);
                (coef / divisor, base)
            }
            _ => (1.0, expr.clone()),
        },
        _ => (1.0, expr.clone()),
    }
}
/// Structural equality of two coefficient-stripped bases. It uses the derived
/// `PartialEq`, so no algebraic normalisation (such as commutativity) is applied.
fn bases_equal(base1: &Expression, base2: &Expression) -> bool {
    PartialEq::eq(base1, base2)
}
pub fn differentiate(&self, with_respect_to: &str) -> Expression {
match self {
Expression::Integer(_)
| Expression::Rational(_)
| Expression::Float(_)
| Expression::Complex(_)
| Expression::Constant(_) => Expression::Integer(0),
Expression::Variable(v) => {
if v.name == with_respect_to {
Expression::Integer(1)
} else {
Expression::Integer(0)
}
}
Expression::Unary(op, expr) => {
let inner_derivative = expr.differentiate(with_respect_to);
match op {
UnaryOp::Neg => Expression::Unary(UnaryOp::Neg, Box::new(inner_derivative)),
UnaryOp::Abs => {
let sign =
Expression::Function(Function::Sign, vec![expr.as_ref().clone()]);
Expression::Binary(
BinaryOp::Mul,
Box::new(sign),
Box::new(inner_derivative),
)
}
UnaryOp::Not => Expression::Integer(0),
}
}
Expression::Binary(op, left, right) => {
let left_deriv = left.differentiate(with_respect_to);
let right_deriv = right.differentiate(with_respect_to);
match op {
BinaryOp::Add => Expression::Binary(
BinaryOp::Add,
Box::new(left_deriv),
Box::new(right_deriv),
),
BinaryOp::Sub => Expression::Binary(
BinaryOp::Sub,
Box::new(left_deriv),
Box::new(right_deriv),
),
BinaryOp::Mul => {
let term1 =
Expression::Binary(BinaryOp::Mul, left.clone(), Box::new(right_deriv));
let term2 =
Expression::Binary(BinaryOp::Mul, right.clone(), Box::new(left_deriv));
Expression::Binary(BinaryOp::Add, Box::new(term1), Box::new(term2))
}
BinaryOp::Div => {
let numerator_term1 =
Expression::Binary(BinaryOp::Mul, right.clone(), Box::new(left_deriv));
let numerator_term2 =
Expression::Binary(BinaryOp::Mul, left.clone(), Box::new(right_deriv));
let numerator = Expression::Binary(
BinaryOp::Sub,
Box::new(numerator_term1),
Box::new(numerator_term2),
);
let denominator =
Expression::Power(right.clone(), Box::new(Expression::Integer(2)));
Expression::Binary(
BinaryOp::Div,
Box::new(numerator),
Box::new(denominator),
)
}
BinaryOp::Mod => Expression::Integer(0),
}
}
Expression::Power(base, exponent) => {
let base_has_var = base.contains_variable(with_respect_to);
let exp_has_var = exponent.contains_variable(with_respect_to);
if !base_has_var && !exp_has_var {
Expression::Integer(0)
} else if base_has_var && !exp_has_var {
let base_deriv = base.differentiate(with_respect_to);
let n_minus_1 = Expression::Binary(
BinaryOp::Sub,
exponent.clone(),
Box::new(Expression::Integer(1)),
);
let power_term = Expression::Power(base.clone(), Box::new(n_minus_1));
let scaled =
Expression::Binary(BinaryOp::Mul, exponent.clone(), Box::new(power_term));
Expression::Binary(BinaryOp::Mul, Box::new(scaled), Box::new(base_deriv))
} else if !base_has_var && exp_has_var {
let exp_deriv = exponent.differentiate(with_respect_to);
let ln_base = Expression::Function(Function::Ln, vec![base.as_ref().clone()]);
let power_term = Expression::Power(base.clone(), exponent.clone());
let scaled =
Expression::Binary(BinaryOp::Mul, Box::new(power_term), Box::new(ln_base));
Expression::Binary(BinaryOp::Mul, Box::new(scaled), Box::new(exp_deriv))
} else {
let base_deriv = base.differentiate(with_respect_to);
let exp_deriv = exponent.differentiate(with_respect_to);
let ln_base = Expression::Function(Function::Ln, vec![base.as_ref().clone()]);
let term1 =
Expression::Binary(BinaryOp::Mul, Box::new(exp_deriv), Box::new(ln_base));
let u_prime_over_u =
Expression::Binary(BinaryOp::Div, Box::new(base_deriv), base.clone());
let term2 = Expression::Binary(
BinaryOp::Mul,
exponent.clone(),
Box::new(u_prime_over_u),
);
let sum = Expression::Binary(BinaryOp::Add, Box::new(term1), Box::new(term2));
let power = Expression::Power(base.clone(), exponent.clone());
Expression::Binary(BinaryOp::Mul, Box::new(power), Box::new(sum))
}
}
Expression::Function(func, args) => {
if args.is_empty() {
return Expression::Integer(0);
}
match func {
Function::Sin => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let cos_u = Expression::Function(Function::Cos, vec![arg.clone()]);
Expression::Binary(BinaryOp::Mul, Box::new(cos_u), Box::new(arg_deriv))
}
Function::Cos => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let sin_u = Expression::Function(Function::Sin, vec![arg.clone()]);
let neg_sin = Expression::Unary(UnaryOp::Neg, Box::new(sin_u));
Expression::Binary(BinaryOp::Mul, Box::new(neg_sin), Box::new(arg_deriv))
}
Function::Tan => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let cos_u = Expression::Function(Function::Cos, vec![arg.clone()]);
let cos_squared =
Expression::Power(Box::new(cos_u), Box::new(Expression::Integer(2)));
let sec_squared = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(cos_squared),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(sec_squared),
Box::new(arg_deriv),
)
}
Function::Asin => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let u_squared = Expression::Power(
Box::new(arg.clone()),
Box::new(Expression::Integer(2)),
);
let one_minus_u_sq = Expression::Binary(
BinaryOp::Sub,
Box::new(Expression::Integer(1)),
Box::new(u_squared),
);
let sqrt_term = Expression::Function(Function::Sqrt, vec![one_minus_u_sq]);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(sqrt_term),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(deriv_factor),
Box::new(arg_deriv),
)
}
Function::Acos => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let u_squared = Expression::Power(
Box::new(arg.clone()),
Box::new(Expression::Integer(2)),
);
let one_minus_u_sq = Expression::Binary(
BinaryOp::Sub,
Box::new(Expression::Integer(1)),
Box::new(u_squared),
);
let sqrt_term = Expression::Function(Function::Sqrt, vec![one_minus_u_sq]);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(sqrt_term),
);
let neg_deriv = Expression::Unary(UnaryOp::Neg, Box::new(deriv_factor));
Expression::Binary(BinaryOp::Mul, Box::new(neg_deriv), Box::new(arg_deriv))
}
Function::Atan => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let u_squared = Expression::Power(
Box::new(arg.clone()),
Box::new(Expression::Integer(2)),
);
let one_plus_u_sq = Expression::Binary(
BinaryOp::Add,
Box::new(Expression::Integer(1)),
Box::new(u_squared),
);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(one_plus_u_sq),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(deriv_factor),
Box::new(arg_deriv),
)
}
Function::Atan2 => {
Expression::Integer(0)
}
Function::Sinh => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let cosh_u = Expression::Function(Function::Cosh, vec![arg.clone()]);
Expression::Binary(BinaryOp::Mul, Box::new(cosh_u), Box::new(arg_deriv))
}
Function::Cosh => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let sinh_u = Expression::Function(Function::Sinh, vec![arg.clone()]);
Expression::Binary(BinaryOp::Mul, Box::new(sinh_u), Box::new(arg_deriv))
}
Function::Tanh => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let cosh_u = Expression::Function(Function::Cosh, vec![arg.clone()]);
let cosh_squared =
Expression::Power(Box::new(cosh_u), Box::new(Expression::Integer(2)));
let sech_squared = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(cosh_squared),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(sech_squared),
Box::new(arg_deriv),
)
}
Function::Exp => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let exp_u = Expression::Function(Function::Exp, vec![arg.clone()]);
Expression::Binary(BinaryOp::Mul, Box::new(exp_u), Box::new(arg_deriv))
}
Function::Ln => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let one_over_u = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(arg.clone()),
);
Expression::Binary(BinaryOp::Mul, Box::new(one_over_u), Box::new(arg_deriv))
}
Function::Log10 => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let ln_10 =
Expression::Function(Function::Ln, vec![Expression::Integer(10)]);
let u_times_ln10 = Expression::Binary(
BinaryOp::Mul,
Box::new(arg.clone()),
Box::new(ln_10),
);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(u_times_ln10),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(deriv_factor),
Box::new(arg_deriv),
)
}
Function::Log2 => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let ln_2 = Expression::Function(Function::Ln, vec![Expression::Integer(2)]);
let u_times_ln2 = Expression::Binary(
BinaryOp::Mul,
Box::new(arg.clone()),
Box::new(ln_2),
);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(u_times_ln2),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(deriv_factor),
Box::new(arg_deriv),
)
}
Function::Log => {
if args.len() >= 2 {
let arg = &args[0];
let base = &args[1];
let arg_deriv = arg.differentiate(with_respect_to);
let ln_base = Expression::Function(Function::Ln, vec![base.clone()]);
let u_times_lnb = Expression::Binary(
BinaryOp::Mul,
Box::new(arg.clone()),
Box::new(ln_base),
);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(u_times_lnb),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(deriv_factor),
Box::new(arg_deriv),
)
} else {
Expression::Integer(0)
}
}
Function::Sqrt => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let sqrt_u = Expression::Function(Function::Sqrt, vec![arg.clone()]);
let two_sqrt_u = Expression::Binary(
BinaryOp::Mul,
Box::new(Expression::Integer(2)),
Box::new(sqrt_u),
);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(two_sqrt_u),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(deriv_factor),
Box::new(arg_deriv),
)
}
Function::Cbrt => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let two_thirds = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(2)),
Box::new(Expression::Integer(3)),
);
let u_to_2_3 =
Expression::Power(Box::new(arg.clone()), Box::new(two_thirds));
let three_u_2_3 = Expression::Binary(
BinaryOp::Mul,
Box::new(Expression::Integer(3)),
Box::new(u_to_2_3),
);
let deriv_factor = Expression::Binary(
BinaryOp::Div,
Box::new(Expression::Integer(1)),
Box::new(three_u_2_3),
);
Expression::Binary(
BinaryOp::Mul,
Box::new(deriv_factor),
Box::new(arg_deriv),
)
}
Function::Pow => {
if args.len() >= 2 {
let power_expr = Expression::Power(
Box::new(args[0].clone()),
Box::new(args[1].clone()),
);
power_expr.differentiate(with_respect_to)
} else {
Expression::Integer(0)
}
}
Function::Floor | Function::Ceil | Function::Round => Expression::Integer(0),
Function::Abs => {
let arg = &args[0];
let arg_deriv = arg.differentiate(with_respect_to);
let sign_u = Expression::Function(Function::Sign, vec![arg.clone()]);
Expression::Binary(BinaryOp::Mul, Box::new(sign_u), Box::new(arg_deriv))
}
Function::Sign => {
Expression::Integer(0)
}
Function::Min | Function::Max => Expression::Integer(0),
Function::Custom(_) => Expression::Integer(0),
}
}
}
}
/// Numerically evaluates the expression, taking variable values from `vars`.
///
/// Returns `None` when there is no real value:
/// * an unbound variable;
/// * the imaginary unit, or a complex literal with imaginary part ≥ 1e-10;
/// * division by a near-zero value (|d| < 1e-10) or remainder by zero;
/// * a non-positive argument to `ln`, `log`, `log2` or `log10`;
/// * a function or power whose result is NaN (e.g. `sqrt(-1)`, `asin(2)`,
///   `(-8)^(1/3)`);
/// * a `Custom` function, or too few arguments.
pub fn evaluate(&self, vars: &HashMap<String, f64>) -> Option<f64> {
    match self {
        Expression::Integer(n) => Some(*n as f64),
        Expression::Rational(r) => Some(*r.numer() as f64 / *r.denom() as f64),
        Expression::Float(x) => Some(*x),
        Expression::Complex(c) => {
            if c.im.abs() < 1e-10 {
                Some(c.re)
            } else {
                None
            }
        }
        Expression::Constant(c) => match c {
            SymbolicConstant::Pi => Some(std::f64::consts::PI),
            SymbolicConstant::E => Some(std::f64::consts::E),
            // Not a real number.
            SymbolicConstant::I => None,
        },
        Expression::Variable(v) => vars.get(&v.name).copied(),
        Expression::Unary(op, expr) => {
            let val = expr.evaluate(vars)?;
            match op {
                UnaryOp::Neg => Some(-val),
                UnaryOp::Not => Some(if val == 0.0 { 1.0 } else { 0.0 }),
                UnaryOp::Abs => Some(val.abs()),
            }
        }
        Expression::Binary(op, left, right) => {
            let left_val = left.evaluate(vars)?;
            let right_val = right.evaluate(vars)?;
            match op {
                BinaryOp::Add => Some(left_val + right_val),
                BinaryOp::Sub => Some(left_val - right_val),
                BinaryOp::Mul => Some(left_val * right_val),
                // Near-zero divisors are treated as division by zero.
                BinaryOp::Div => {
                    if right_val.abs() < 1e-10 {
                        None
                    } else {
                        Some(left_val / right_val)
                    }
                }
                // `x % 0.0` is NaN; report it as undefined, like division by zero.
                BinaryOp::Mod => {
                    if right_val == 0.0 {
                        None
                    } else {
                        Some(left_val % right_val)
                    }
                }
            }
        }
        Expression::Function(func, args) => {
            let arg_vals = args
                .iter()
                .map(|arg| arg.evaluate(vars))
                .collect::<Option<Vec<f64>>>()?;
            let first = arg_vals.first().copied();
            let value = match func {
                Function::Sin => first?.sin(),
                Function::Cos => first?.cos(),
                Function::Tan => first?.tan(),
                Function::Asin => first?.asin(),
                Function::Acos => first?.acos(),
                Function::Atan => first?.atan(),
                Function::Atan2 => first?.atan2(*arg_vals.get(1)?),
                Function::Sinh => first?.sinh(),
                Function::Cosh => first?.cosh(),
                Function::Tanh => first?.tanh(),
                Function::Exp => first?.exp(),
                Function::Ln => {
                    let x = first?;
                    if x > 0.0 {
                        x.ln()
                    } else {
                        return None
                    }
                }
                // `log(x)` is base 10; `log(x, b)` is base `b`.
                Function::Log => {
                    let x = first?;
                    match arg_vals.get(1) {
                        Some(&base) if x > 0.0 && base > 0.0 => x.log(base),
                        Some(_) => return None,
                        None if x > 0.0 => x.log10(),
                        None => return None,
                    }
                }
                // Same domain rule as `ln`/`log`: no real log of a non-positive value.
                Function::Log2 => {
                    let x = first?;
                    if x > 0.0 {
                        x.log2()
                    } else {
                        return None
                    }
                }
                Function::Log10 => {
                    let x = first?;
                    if x > 0.0 {
                        x.log10()
                    } else {
                        return None
                    }
                }
                Function::Sqrt => first?.sqrt(),
                Function::Cbrt => first?.cbrt(),
                Function::Pow => first?.powf(*arg_vals.get(1)?),
                Function::Floor => first?.floor(),
                Function::Ceil => first?.ceil(),
                Function::Round => first?.round(),
                Function::Abs => first?.abs(),
                Function::Sign => first?.signum(),
                Function::Min => arg_vals.iter().copied().reduce(f64::min)?,
                Function::Max => arg_vals.iter().copied().reduce(f64::max)?,
                Function::Custom(_) => return None,
            };
            // Domain errors (sqrt(-1), asin(2), ...) come back from std as NaN;
            // report them as "no real value" instead of leaking NaN to callers.
            Some(value).filter(|v| !v.is_nan())
        }
        Expression::Power(base, exp) => {
            let base_val = base.evaluate(vars)?;
            let exp_val = exp.evaluate(vars)?;
            // e.g. a negative base with a fractional exponent yields NaN.
            Some(base_val.powf(exp_val)).filter(|v| !v.is_nan())
        }
    }
}
}