use super::Simplify;
use crate::core::expression::evaluation::evaluate_function_dispatch;
use crate::core::{Expression, Number};
/// Simplifies the function call `name(args)`.
///
/// The steps are applied in order:
///
/// 1. A call with no arguments is returned unchanged as a function node; the
///    evaluator is not consulted.
/// 2. Every argument is simplified recursively (`Simplify::simplify`).
/// 3. Directly nested inverse pairs cancel: `exp(log(x))`, `exp(ln(x))`,
///    `log(exp(x))` and `ln(exp(x))` all become `x` (see `cancel_inverse`).
/// 4. Otherwise the call is passed to `evaluate_function_dispatch`:
///    - `None`, or a result that is itself a function node, leaves the call
///      symbolic, rebuilt from the simplified arguments;
///    - a numeric result is discarded in favour of the symbolic call when
///      `keeps_symbolic_form` says so (e.g. `sin(1)`);
///    - any other result is returned as-is.
pub fn simplify_function(name: &str, args: &[Expression]) -> Expression {
    if args.is_empty() {
        return Expression::function(name, vec![]);
    }

    let simplified_args: Vec<Expression> = args.iter().map(|arg| arg.simplify()).collect();

    if let Some(inner) = cancel_inverse(name, &simplified_args) {
        return inner;
    }

    let keep_symbolic = keeps_symbolic_form(name, &simplified_args);
    let evaluated = evaluate_function_dispatch(name, &simplified_args);

    match evaluated {
        // The evaluator could not reduce the call (or handed back another
        // function node): keep the call built from our simplified arguments.
        Some(Expression::Function { .. }) | None => Expression::function(name, simplified_args),
        Some(Expression::Number(_)) if keep_symbolic => {
            Expression::function(name, simplified_args)
        }
        Some(result) => result,
    }
}

/// Cancels a directly nested inverse pair such as `exp(ln(x))`.
///
/// Returns `Some(x)` when `args` is exactly one function call `inner(x)` that
/// itself has a single argument and `(outer, inner)` is accepted by
/// `is_inverse_pair`; otherwise returns `None`.
fn cancel_inverse(outer: &str, args: &[Expression]) -> Option<Expression> {
    match args {
        [Expression::Function {
            name: inner_name,
            args: inner_args,
        }] if inner_args.len() == 1 && is_inverse_pair(outer, inner_name.as_ref()) => {
            Some(inner_args[0].clone())
        }
        _ => None,
    }
}

/// Returns `true` when `outer(inner(x))` should simplify to `x`.
///
/// Both `log` and `ln` are paired with `exp`, so `log` is treated as the
/// natural logarithm for these cancellations.
/// NOTE(review): `ln(exp(x)) = x` only holds for real `x` (for complex `x` it
/// needs `Im(x)` in `(-π, π]`); confirm the crate assumes real-valued symbols.
fn is_inverse_pair(outer: &str, inner: &str) -> bool {
    matches!(
        (outer, inner),
        ("exp", "log") | ("exp", "ln") | ("log", "exp") | ("ln", "exp")
    )
}

/// Returns `true` when a numeric evaluation of `name(args)` should be discarded
/// in favour of the unevaluated call.
///
/// This applies to the trigonometric functions, their inverses, and the inverse
/// hyperbolic functions when called with a single nonzero integer argument.
/// NOTE(review): presumably so that values like `sin(1)` stay exact instead of
/// becoming a floating-point approximation; confirm against
/// `evaluate_function_dispatch`.
fn keeps_symbolic_form(name: &str, args: &[Expression]) -> bool {
    let is_guarded_function = matches!(
        name,
        "sin" | "cos" | "tan" | "cot" | "sec" | "csc" | "asin" | "acos" | "atan" | "asinh"
            | "acosh" | "atanh"
    );
    is_guarded_function && matches!(args, [Expression::Number(Number::Integer(n))] if *n != 0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Asserts that simplifying the one-argument call `name(arg)` yields `expected`.
    fn assert_simplifies_to(name: &str, arg: Expression, expected: Expression) {
        assert_eq!(simplify_function(name, &[arg]), expected);
    }

    /// Asserts that `outer(inner(x))` simplifies back to the bare symbol `x`.
    fn assert_inverse_cancels(outer: &str, inner: &str) {
        let x = symbol!(x);
        let nested = Expression::function(inner, vec![Expression::symbol(x.clone())]);
        assert_simplifies_to(outer, nested, Expression::symbol(x));
    }

    #[test]
    fn test_trigonometric_simplification() {
        assert_simplifies_to("sin", expr!(0), expr!(0));
        assert_simplifies_to("cos", expr!(0), expr!(1));
        assert_simplifies_to("tan", expr!(0), expr!(0));
    }

    #[test]
    fn test_exponential_simplification() {
        assert_simplifies_to("exp", expr!(0), expr!(1));
        assert_simplifies_to("ln", expr!(1), expr!(0));
    }

    #[test]
    fn test_sqrt_simplification() {
        assert_simplifies_to("sqrt", expr!(0), expr!(0));
        assert_simplifies_to("sqrt", expr!(4), expr!(2));
    }

    #[test]
    fn test_factorial_simplification() {
        assert_simplifies_to("factorial", expr!(0), expr!(1));
        assert_simplifies_to("factorial", expr!(5), expr!(120));
    }

    #[test]
    fn test_universal_evaluation_gamma() {
        assert_simplifies_to("gamma", expr!(5), expr!(24));
        assert_simplifies_to("gamma", expr!(1), expr!(1));
    }

    #[test]
    fn test_universal_evaluation_preserves_symbolic() {
        let unevaluated = simplify_function("gamma", &[expr!(x)]);
        assert!(matches!(unevaluated, Expression::Function { .. }));
    }

    #[test]
    fn test_exp_log_identity() {
        assert_inverse_cancels("exp", "log");
    }

    #[test]
    fn test_log_exp_identity() {
        assert_inverse_cancels("log", "exp");
    }

    #[test]
    fn test_exp_ln_identity() {
        assert_inverse_cancels("exp", "ln");
    }

    #[test]
    fn test_ln_exp_identity() {
        assert_inverse_cancels("ln", "exp");
    }

    #[test]
    fn test_function_composition_stays_symbolic() {
        // cos(0) reduces to 1, but sin(1) must remain an unevaluated call.
        let cos_zero = simplify_function("cos", &[expr!(0)]);
        let result = simplify_function("sin", &[cos_zero]);
        match result {
            Expression::Function {
                name: fn_name,
                args: fn_args,
            } => {
                assert_eq!(fn_name.as_ref(), "sin");
                assert_eq!(fn_args.len(), 1);
                assert_eq!(fn_args[0], expr!(1));
            }
            _ => panic!("Expected Function(sin, [1]), got {:?}", result),
        }
    }
}