use crate::ast::{BinaryOp, Expression, Function, UnaryOp, Variable};
use std::fmt;
/// Reasons a symbolic integration attempt can fail.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum IntegrationError {
    /// No implemented rule applies; the payload describes the integrand or the reason.
    CannotIntegrate(String),
    /// The expression kind has no antiderivative in this module (e.g. logical NOT).
    UnsupportedExpression(String),
    /// A division by zero was encountered while integrating.
    DivisionByZero,
}

impl fmt::Display for IntegrationError {
    /// Renders as `<prefix><detail>`; the formatter's width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (prefix, detail) = match self {
            Self::CannotIntegrate(msg) => ("Cannot integrate: ", Some(msg.as_str())),
            Self::UnsupportedExpression(msg) => ("Unsupported expression: ", Some(msg.as_str())),
            Self::DivisionByZero => ("Division by zero in integration", None),
        };
        f.write_str(prefix)?;
        match detail {
            Some(text) => f.write_str(text),
            None => Ok(()),
        }
    }
}

impl std::error::Error for IntegrationError {}
/// Outcome of an integration attempt: the antiderivative on success.
pub type IntegrationResult = Result<Expression, IntegrationError>;
/// Computes an antiderivative of `expr` with respect to the variable named `var`.
///
/// The constant of integration is omitted. This entry point applies only the
/// direct rule table in `integrate_impl`. It does not attempt u-substitution or
/// integration by parts (see `integrate_by_substitution` / `integrate_by_parts`).
///
/// # Errors
/// Returns an [`IntegrationError`] when no implemented rule applies.
pub fn integrate(expr: &Expression, var: &str) -> IntegrationResult {
    integrate_impl(expr, var)
}
/// Core rule table for direct integration of `expr` with respect to `var`.
///
/// Handles constants, symbols, negation, absolute value, sums and differences
/// itself. Products, quotients, powers and function applications are dispatched
/// to dedicated helpers. The constant of integration is omitted.
///
/// # Errors
/// [`IntegrationError::CannotIntegrate`] when no rule applies, or
/// [`IntegrationError::UnsupportedExpression`] for logical NOT.
fn integrate_impl(expr: &Expression, var: &str) -> IntegrationResult {
    /// `c · var`: the antiderivative of an integrand `c` that does not depend on `var`.
    fn times_var(c: &Expression, var: &str) -> Expression {
        Expression::Binary(
            BinaryOp::Mul,
            Box::new(c.clone()),
            Box::new(Expression::Variable(Variable::new(var))),
        )
    }

    match expr {
        Expression::Integer(_)
        | Expression::Float(_)
        | Expression::Rational(_)
        | Expression::Complex(_)
        | Expression::Constant(_) => Ok(times_var(expr, var)),
        // ∫ x dx = x² / 2
        Expression::Variable(v) if v.name == var => {
            let x = Expression::Variable(Variable::new(var));
            let x_squared = Expression::Power(Box::new(x), Box::new(Expression::Integer(2)));
            Ok(Expression::Binary(
                BinaryOp::Div,
                Box::new(x_squared),
                Box::new(Expression::Integer(2)),
            ))
        }
        // Any other symbol is constant with respect to `var`.
        Expression::Variable(_) => Ok(times_var(expr, var)),
        Expression::Unary(UnaryOp::Neg, inner) => {
            let inner_integral = integrate_impl(inner, var)?;
            Ok(Expression::Unary(UnaryOp::Neg, Box::new(inner_integral)))
        }
        // |c| with c free of `var` is a constant.
        Expression::Unary(UnaryOp::Abs, inner) if !inner.contains_variable(var) => {
            Ok(times_var(expr, var))
        }
        // ∫ |x| dx = x·|x| / 2, valid on both sides of 0.
        Expression::Unary(UnaryOp::Abs, inner)
            if matches!(inner.as_ref(), Expression::Variable(v) if v.name == var) =>
        {
            let x = Expression::Variable(Variable::new(var));
            let abs_x = Expression::Unary(UnaryOp::Abs, Box::new(x.clone()));
            let x_abs_x = Expression::Binary(BinaryOp::Mul, Box::new(x), Box::new(abs_x));
            Ok(Expression::Binary(
                BinaryOp::Div,
                Box::new(x_abs_x),
                Box::new(Expression::Integer(2)),
            ))
        }
        Expression::Unary(UnaryOp::Abs, _) => Err(IntegrationError::CannotIntegrate(
            "Cannot integrate |f(x)| symbolically".to_string(),
        )),
        Expression::Unary(UnaryOp::Not, _) => Err(IntegrationError::UnsupportedExpression(
            "Logical NOT cannot be integrated".to_string(),
        )),
        // Linearity: ∫(f ± g) = ∫f ± ∫g
        Expression::Binary(op @ (BinaryOp::Add | BinaryOp::Sub), left, right) => {
            let left_integral = integrate_impl(left, var)?;
            let right_integral = integrate_impl(right, var)?;
            Ok(Expression::Binary(
                op.clone(),
                Box::new(left_integral),
                Box::new(right_integral),
            ))
        }
        Expression::Binary(BinaryOp::Mul, left, right) => integrate_product(left, right, var),
        Expression::Binary(BinaryOp::Div, left, right) => integrate_quotient(left, right, var),
        // (a mod b) with no dependence on `var` is a constant.
        Expression::Binary(BinaryOp::Mod, _, _) if !expr.contains_variable(var) => {
            Ok(times_var(expr, var))
        }
        Expression::Binary(BinaryOp::Mod, _, _) => Err(IntegrationError::CannotIntegrate(
            "Modulo cannot be integrated".to_string(),
        )),
        Expression::Power(base, exponent) => integrate_power(base, exponent, var),
        Expression::Function(func, args) => integrate_function(func, args, var),
    }
}
/// Integrates `left * right` with respect to `var`.
///
/// Constant factors are pulled out of the integral. When both factors depend on
/// `var`, only products of powers of `var` (x^a · x^b) are combined and integrated.
/// Everything else is reported as needing u-substitution.
fn integrate_product(left: &Expression, right: &Expression, var: &str) -> IntegrationResult {
    match (left.contains_variable(var), right.contains_variable(var)) {
        // Entire product is constant: (l·r)·x
        (false, false) => {
            let constant = Expression::Binary(
                BinaryOp::Mul,
                Box::new(left.clone()),
                Box::new(right.clone()),
            );
            let x = Expression::Variable(Variable::new(var));
            Ok(Expression::Binary(BinaryOp::Mul, Box::new(constant), Box::new(x)))
        }
        // c · ∫ g
        (false, true) => {
            let integral = integrate_impl(right, var)?;
            Ok(Expression::Binary(BinaryOp::Mul, Box::new(left.clone()), Box::new(integral)))
        }
        // c · ∫ f  (constant placed on the left)
        (true, false) => {
            let integral = integrate_impl(left, var)?;
            Ok(Expression::Binary(BinaryOp::Mul, Box::new(right.clone()), Box::new(integral)))
        }
        (true, true) => match try_combine_powers(left, right, var) {
            Some(combined) => integrate_power_expr(&combined, var),
            None => Err(IntegrationError::CannotIntegrate(format!(
                "Cannot integrate product {} * {} - try u-substitution",
                left, right
            ))),
        },
    }
}
/// Rewrites `var^a · var^b` as `var^(a + b)`.
///
/// Returns `None` unless both operands are `var` itself or a power with base `var`.
fn try_combine_powers(left: &Expression, right: &Expression, var: &str) -> Option<Expression> {
    let exponent_sum = Expression::Binary(
        BinaryOp::Add,
        Box::new(extract_power(left, var)?),
        Box::new(extract_power(right, var)?),
    );
    let base = Expression::Variable(Variable::new(var));
    Some(Expression::Power(Box::new(base), Box::new(exponent_sum)))
}
/// Returns the exponent `n` when `expr` is `var` (n = 1) or `var^n`; otherwise `None`.
fn extract_power(expr: &Expression, var: &str) -> Option<Expression> {
    let is_var = |e: &Expression| matches!(e, Expression::Variable(v) if v.name == var);
    match expr {
        e if is_var(e) => Some(Expression::Integer(1)),
        Expression::Power(base, exp) if is_var(base.as_ref()) => Some(exp.as_ref().clone()),
        _ => None,
    }
}
/// Routes powers straight to `integrate_power`; any other expression goes through
/// the general rule table.
fn integrate_power_expr(expr: &Expression, var: &str) -> IntegrationResult {
    match expr {
        Expression::Power(base, exp) => integrate_power(base, exp, var),
        other => integrate_impl(other, var),
    }
}
fn integrate_quotient(num: &Expression, denom: &Expression, var: &str) -> IntegrationResult {
let num_has_var = num.contains_variable(var);
let denom_has_var = denom.contains_variable(var);
if !denom_has_var {
let num_integral = integrate_impl(num, var)?;
Ok(Expression::Binary(
BinaryOp::Div,
Box::new(num_integral),
Box::new(denom.clone()),
))
} else if !num_has_var {
if let Expression::Variable(v) = denom {
if v.name == var {
let ln_x = Expression::Function(
Function::Ln,
vec![Expression::Function(
Function::Abs,
vec![Expression::Variable(Variable::new(var))],
)],
);
return Ok(Expression::Binary(
BinaryOp::Mul,
Box::new(num.clone()),
Box::new(ln_x),
));
}
}
if let Some(result) = try_arctan_pattern(num, denom, var) {
return Ok(result);
}
if let Some(result) = try_arcsin_pattern(num, denom, var) {
return Ok(result);
}
Err(IntegrationError::CannotIntegrate(format!(
"Cannot integrate {}/{}",
num, denom
)))
} else {
if let Expression::Variable(v) = denom {
if v.name == var {
if let Some(power) = extract_power(num, var) {
let new_exp = Expression::Binary(
BinaryOp::Sub,
Box::new(power),
Box::new(Expression::Integer(1)),
);
return integrate_power(
&Expression::Variable(Variable::new(var)),
&new_exp,
var,
);
}
}
}
Err(IntegrationError::CannotIntegrate(format!(
"Cannot integrate quotient {}/{} - try partial fractions",
num, denom
)))
}
}
/// Recognises `num / (1 + var²)` and the commuted `num / (var² + 1)`, returning
/// `num · atan(var)`.
///
/// `num` is assumed to be free of `var`; the caller checks this before calling.
fn try_arctan_pattern(num: &Expression, denom: &Expression, var: &str) -> Option<Expression> {
    /// True for exactly `var ^ 2` (integer exponent 2).
    fn is_var_squared(e: &Expression, var: &str) -> bool {
        matches!(
            e,
            Expression::Power(base, exp)
                if matches!(base.as_ref(), Expression::Variable(v) if v.name == var)
                && matches!(exp.as_ref(), Expression::Integer(2))
        )
    }
    fn is_one(e: &Expression) -> bool {
        matches!(e, Expression::Integer(1))
    }

    if let Expression::Binary(BinaryOp::Add, left, right) = denom {
        let (l, r) = (left.as_ref(), right.as_ref());
        // Addition is commutative, so accept both operand orders.
        let matches_form =
            (is_one(l) && is_var_squared(r, var)) || (is_var_squared(l, var) && is_one(r));
        if matches_form {
            let arctan_x = Expression::Function(
                Function::Atan,
                vec![Expression::Variable(Variable::new(var))],
            );
            return Some(Expression::Binary(
                BinaryOp::Mul,
                Box::new(num.clone()),
                Box::new(arctan_x),
            ));
        }
    }
    None
}
/// Recognises `num / sqrt(1 - var²)` (literally that shape) and returns
/// `num · asin(var)`.
fn try_arcsin_pattern(num: &Expression, denom: &Expression, var: &str) -> Option<Expression> {
    let radicand = match denom {
        Expression::Function(Function::Sqrt, args) => args.first()?,
        _ => return None,
    };
    let (minuend, subtrahend): (&Expression, &Expression) = match radicand {
        Expression::Binary(BinaryOp::Sub, l, r) => (l.as_ref(), r.as_ref()),
        _ => return None,
    };
    let subtrahend_is_square = match subtrahend {
        Expression::Power(base, exp) => {
            matches!(base.as_ref(), Expression::Variable(v) if v.name == var)
                && matches!(exp.as_ref(), Expression::Integer(2))
        }
        _ => false,
    };
    if !(matches!(minuend, Expression::Integer(1)) && subtrahend_is_square) {
        return None;
    }
    let arcsin = Expression::Function(
        Function::Asin,
        vec![Expression::Variable(Variable::new(var))],
    );
    Some(Expression::Binary(
        BinaryOp::Mul,
        Box::new(num.clone()),
        Box::new(arcsin),
    ))
}
/// Integrates `base ^ exponent` with respect to `var`.
///
/// * neither side depends on `var` → `base^exponent · x`;
/// * `x^n` with `n` free of `var` → power rule via `integrate_power_of_var`;
/// * `e^x` → `e^x`; `1^x` → `1^x · x`; `a^x` → `a^x / ln(a)`.
///
/// # Errors
/// [`IntegrationError::CannotIntegrate`] for non-variable bases, exponents that
/// are not exactly `var`, or `f(x)^g(x)`.
fn integrate_power(base: &Expression, exponent: &Expression, var: &str) -> IntegrationResult {
    let x = || Expression::Variable(Variable::new(var));
    let power = || Expression::Power(Box::new(base.clone()), Box::new(exponent.clone()));
    let is_var = |e: &Expression| matches!(e, Expression::Variable(v) if v.name == var);

    match (base.contains_variable(var), exponent.contains_variable(var)) {
        // Constant power: c · x
        (false, false) => Ok(Expression::Binary(
            BinaryOp::Mul,
            Box::new(power()),
            Box::new(x()),
        )),
        (true, false) => {
            if is_var(base) {
                integrate_power_of_var(exponent, var)
            } else {
                Err(IntegrationError::CannotIntegrate(format!(
                    "Cannot integrate ({})^{} - complex base",
                    base, exponent
                )))
            }
        }
        (false, true) if is_var(exponent) => match base {
            // ∫ e^x dx = e^x. This must come before the generic a^x rule, which
            // would otherwise return e^x / ln(e) and leave this arm unreachable.
            Expression::Constant(crate::ast::SymbolicConstant::E) => Ok(power()),
            // 1^x ≡ 1. The generic rule would divide by ln(1) = 0.
            Expression::Integer(1) => Ok(Expression::Binary(
                BinaryOp::Mul,
                Box::new(power()),
                Box::new(x()),
            )),
            // ∫ a^x dx = a^x / ln(a)
            _ => {
                let ln_base = Expression::Function(Function::Ln, vec![base.clone()]);
                Ok(Expression::Binary(
                    BinaryOp::Div,
                    Box::new(power()),
                    Box::new(ln_base),
                ))
            }
        },
        (false, true) => Err(IntegrationError::CannotIntegrate(format!(
            "Cannot integrate {}^({}) - exponential with complex exponent",
            base, exponent
        ))),
        (true, true) => Err(IntegrationError::CannotIntegrate(
            "Cannot integrate f(x)^g(x) - requires special techniques".to_string(),
        )),
    }
}
/// Power rule for `∫ var^n d(var)`, with `n` assumed free of `var`.
///
/// Returns `ln|var|` when `n` is -1, and `var^(n+1) / (n+1)` otherwise.
///
/// -1 is recognised in its literal spellings (`Integer(-1)`, `-(1)`, `Rational(-1/1)`).
/// It is also recognised when `n` evaluates numerically to -1 with no variables
/// bound, e.g. `1 + -2` built by `try_combine_powers` or `Float(-1.0)`. Without
/// that check the general branch would produce a zero denominator.
fn integrate_power_of_var(exponent: &Expression, var: &str) -> IntegrationResult {
    /// True when `exponent` is, or numerically evaluates to, -1.
    fn is_minus_one(exponent: &Expression) -> bool {
        let literal = match exponent {
            Expression::Integer(-1) => true,
            Expression::Unary(UnaryOp::Neg, inner) => {
                matches!(inner.as_ref(), Expression::Integer(1))
            }
            Expression::Rational(r) => *r.numer() == -1 && *r.denom() == 1,
            _ => false,
        };
        // Exponents containing free symbols evaluate to None and are not -1.
        literal || exponent.evaluate(&std::collections::HashMap::new()) == Some(-1.0)
    }

    let x = Expression::Variable(Variable::new(var));
    if is_minus_one(exponent) {
        return Ok(Expression::Function(
            Function::Ln,
            vec![Expression::Function(Function::Abs, vec![x])],
        ));
    }
    let n_plus_1 = Expression::Binary(
        BinaryOp::Add,
        Box::new(exponent.clone()),
        Box::new(Expression::Integer(1)),
    );
    let x_to_n_plus_1 = Expression::Power(Box::new(x), Box::new(n_plus_1.clone()));
    Ok(Expression::Binary(
        BinaryOp::Div,
        Box::new(x_to_n_plus_1),
        Box::new(n_plus_1),
    ))
}
/// Integrates a function application `func(args…)` with respect to `var`.
///
/// * If no argument depends on `var`, the application is a constant: `f(c) · x`.
/// * If the first argument is exactly `var`, a table of standard antiderivatives
///   is used: sin, cos, tan, exp, ln, sinh, cosh, tanh, sqrt.
/// * Otherwise a linear argument `a·x + b` is tried via `try_linear_substitution`.
///
/// Only `args[0]` is inspected for the table and linear cases.
///
/// # Errors
/// [`IntegrationError::CannotIntegrate`] for empty argument lists, non-linear
/// arguments, and functions without a table entry.
fn integrate_function(func: &Function, args: &[Expression], var: &str) -> IntegrationResult {
    if args.is_empty() {
        return Err(IntegrationError::CannotIntegrate(
            "Function with no arguments".to_string(),
        ));
    }
    let x = || Expression::Variable(Variable::new(var));

    // f(c) with no dependence on `var` is constant: ∫ f(c) dx = f(c) · x.
    if args.iter().all(|a| !a.contains_variable(var)) {
        let constant = Expression::Function(func.clone(), args.to_vec());
        return Ok(Expression::Binary(
            BinaryOp::Mul,
            Box::new(constant),
            Box::new(x()),
        ));
    }

    let arg = &args[0];
    let is_simple_var = matches!(arg, Expression::Variable(v) if v.name == var);
    if !is_simple_var {
        if let Some(result) = try_linear_substitution(func, arg, var) {
            return Ok(result);
        }
        return Err(IntegrationError::CannotIntegrate(format!(
            "Cannot integrate {}({}) - try u-substitution",
            func_name(func),
            arg
        )));
    }

    match func {
        Function::Sin => Ok(Expression::Unary(
            UnaryOp::Neg,
            Box::new(Expression::Function(Function::Cos, vec![x()])),
        )),
        Function::Cos => Ok(Expression::Function(Function::Sin, vec![x()])),
        // ∫ tan x dx = -ln|cos x|
        Function::Tan => {
            let cos_x = Expression::Function(Function::Cos, vec![x()]);
            let abs_cos = Expression::Function(Function::Abs, vec![cos_x]);
            let ln_abs_cos = Expression::Function(Function::Ln, vec![abs_cos]);
            Ok(Expression::Unary(UnaryOp::Neg, Box::new(ln_abs_cos)))
        }
        Function::Exp => Ok(Expression::Function(Function::Exp, vec![x()])),
        // ∫ ln x dx = x·ln x - x
        Function::Ln => {
            let ln_x = Expression::Function(Function::Ln, vec![x()]);
            let x_ln_x = Expression::Binary(BinaryOp::Mul, Box::new(x()), Box::new(ln_x));
            Ok(Expression::Binary(
                BinaryOp::Sub,
                Box::new(x_ln_x),
                Box::new(x()),
            ))
        }
        Function::Sinh => Ok(Expression::Function(Function::Cosh, vec![x()])),
        Function::Cosh => Ok(Expression::Function(Function::Sinh, vec![x()])),
        // ∫ tanh x dx = ln(cosh x)
        Function::Tanh => {
            let cosh_x = Expression::Function(Function::Cosh, vec![x()]);
            Ok(Expression::Function(Function::Ln, vec![cosh_x]))
        }
        // ∫ x^(1/2) dx = x^(3/2) / (3/2), the same shape the power rule produces.
        Function::Sqrt => {
            let three_halves = || {
                Expression::Binary(
                    BinaryOp::Div,
                    Box::new(Expression::Integer(3)),
                    Box::new(Expression::Integer(2)),
                )
            };
            let x_pow = Expression::Power(Box::new(x()), Box::new(three_halves()));
            Ok(Expression::Binary(
                BinaryOp::Div,
                Box::new(x_pow),
                Box::new(three_halves()),
            ))
        }
        _ => Err(IntegrationError::CannotIntegrate(format!(
            "No standard integral for {}(x)",
            func_name(func)
        ))),
    }
}
fn try_linear_substitution(func: &Function, arg: &Expression, var: &str) -> Option<Expression> {
let (coeff, _offset) = extract_linear_form(arg, var)?;
if matches!(&coeff, Expression::Integer(0)) {
return None;
}
let standard_integral = match func {
Function::Sin => Expression::Unary(
UnaryOp::Neg,
Box::new(Expression::Function(Function::Cos, vec![arg.clone()])),
),
Function::Cos => Expression::Function(Function::Sin, vec![arg.clone()]),
Function::Exp => Expression::Function(Function::Exp, vec![arg.clone()]),
_ => return None,
};
Some(Expression::Binary(
BinaryOp::Div,
Box::new(standard_integral),
Box::new(coeff),
))
}
fn extract_linear_form(expr: &Expression, var: &str) -> Option<(Expression, Expression)> {
match expr {
Expression::Variable(v) if v.name == var => {
Some((Expression::Integer(1), Expression::Integer(0)))
}
Expression::Binary(BinaryOp::Mul, left, right) => {
if !left.contains_variable(var) {
if matches!(right.as_ref(), Expression::Variable(v) if v.name == var) {
return Some((left.as_ref().clone(), Expression::Integer(0)));
}
}
if !right.contains_variable(var) {
if matches!(left.as_ref(), Expression::Variable(v) if v.name == var) {
return Some((right.as_ref().clone(), Expression::Integer(0)));
}
}
None
}
Expression::Binary(BinaryOp::Add, left, right) => {
if !right.contains_variable(var) {
if let Some((a, _)) = extract_linear_form(left, var) {
return Some((a, right.as_ref().clone()));
}
}
if !left.contains_variable(var) {
if let Some((a, _)) = extract_linear_form(right, var) {
return Some((a, left.as_ref().clone()));
}
}
None
}
_ => None,
}
}
/// Lower-case display name of `func`, used in error messages.
///
/// Every `Function::Custom` is reported simply as `"custom"`.
fn func_name(func: &Function) -> &'static str {
    match func {
        // trigonometric and inverse trigonometric
        Function::Sin => "sin",
        Function::Cos => "cos",
        Function::Tan => "tan",
        Function::Asin => "asin",
        Function::Acos => "acos",
        Function::Atan => "atan",
        Function::Atan2 => "atan2",
        // hyperbolic
        Function::Sinh => "sinh",
        Function::Cosh => "cosh",
        Function::Tanh => "tanh",
        // exponential, logarithmic and roots
        Function::Exp => "exp",
        Function::Ln => "ln",
        Function::Log => "log",
        Function::Log2 => "log2",
        Function::Log10 => "log10",
        Function::Sqrt => "sqrt",
        Function::Cbrt => "cbrt",
        // rounding, sign and miscellaneous
        Function::Abs => "abs",
        Function::Floor => "floor",
        Function::Ceil => "ceil",
        Function::Round => "round",
        Function::Min => "min",
        Function::Max => "max",
        Function::Pow => "pow",
        Function::Sign => "sign",
        Function::Custom(_) => "custom",
    }
}
/// Integrates `expr` with respect to `var`, falling back to u-substitution.
///
/// Direct integration (`integrate_impl`) is tried first. Otherwise
/// `find_substitution` searches for a substitution, and the placeholder `u` in
/// the resulting integral is replaced by the chosen inner expression.
///
/// NOTE(review): `verify_by_differentiation` is evaluated, but its verdict does
/// not affect the outcome; the back-substituted result is returned either way.
///
/// # Errors
/// `CannotIntegrate` when neither direct integration nor a substitution succeeds.
pub fn integrate_by_substitution(expr: &Expression, var: &str) -> IntegrationResult {
    if let Ok(direct) = integrate_impl(expr, var) {
        return Ok(direct);
    }
    let (u, _du_dx, integral_in_u) = find_substitution(expr, var).ok_or_else(|| {
        IntegrationError::CannotIntegrate(
            "U-substitution did not find a suitable substitution".to_string(),
        )
    })?;
    let antiderivative = back_substitute(&integral_in_u, &u, var);
    let _ = verify_by_differentiation(&antiderivative, var, expr);
    Ok(antiderivative)
}
/// Searches for a u-substitution for `expr`.
///
/// Returns `(u, du/dvar, integral expressed in the placeholder u)` on success.
/// Currently the only strategy is `try_product_substitution`.
fn find_substitution(expr: &Expression, var: &str) -> Option<(Expression, Expression, Expression)> {
    try_product_substitution(expr, var)
}
/// Looks for a u-substitution in a product `f(g(x)) · g'(x) · …`.
///
/// The factors of `expr` (flattened with `extract_factors`) are scanned in order.
/// For each factor, two strategies are tried:
/// 1. `u` = the inner argument of a function (or of a function/compound raised to
///    a power), giving `∫ f(u) du` via `integrate_impl` in the placeholder `u`;
/// 2. `u` = the base of `base^n` (n free of `var`), giving `u^(n+1)/(n+1)`, or
///    `ln|u|` for n = -1.
///
/// The remaining factors must supply `du/dx` up to a constant `c`, with
/// `du/dx ≡ c · matched` per `match_derivative`. The integral is then divided by `c`.
///
/// Returns `(u, du/dx, integral in placeholder u)`. Callers replace every variable
/// named `"u"` in that integral with `u`. Candidates where anything other than
/// the placeholder would mention `"u"` are therefore rejected.
fn try_product_substitution(
    expr: &Expression,
    var: &str,
) -> Option<(Expression, Expression, Expression)> {
    /// Placeholder name; `back_substitute` later replaces it with the chosen `u`.
    const PLACEHOLDER: &str = "u";

    /// All factors except the one at `skip`, in their original order.
    fn factors_without(factors: &[Expression], skip: usize) -> Vec<Expression> {
        factors
            .iter()
            .enumerate()
            .filter(|(j, _)| *j != skip)
            .map(|(_, f)| f.clone())
            .collect()
    }

    /// True when `e` is, or numerically evaluates to, -1.
    fn is_minus_one(e: &Expression) -> bool {
        matches!(e, Expression::Integer(-1))
            || e.evaluate(&std::collections::HashMap::new()) == Some(-1.0)
    }

    /// Divides by `constant` (unless it is literally 1) and multiplies in the
    /// leftover factors. Rejects the candidate if either of those mentions the
    /// placeholder name, because back-substitution would rewrite it too.
    fn finish(
        integral_in_u: Expression,
        constant: Expression,
        remaining: &[Expression],
    ) -> Option<Expression> {
        if constant.contains_variable(PLACEHOLDER)
            || remaining.iter().any(|f| f.contains_variable(PLACEHOLDER))
        {
            return None;
        }
        let scaled = if is_one(&constant) {
            integral_in_u
        } else {
            Expression::Binary(BinaryOp::Div, Box::new(integral_in_u), Box::new(constant))
        };
        Some(if remaining.is_empty() {
            scaled
        } else {
            Expression::Binary(
                BinaryOp::Mul,
                Box::new(combine_factors(remaining)),
                Box::new(scaled),
            )
        })
    }

    /// Strategy 1: u = inner argument of `factor`; integrate f(u) in the placeholder.
    fn via_inner_function(
        factor: &Expression,
        others: &[Expression],
        var: &str,
    ) -> Option<(Expression, Expression, Expression)> {
        let u = extract_inner_function(factor)?;
        // If u does not depend on `var`, du = 0 and the substitution is invalid.
        if !u.contains_variable(var) {
            return None;
        }
        // An exponent mentioning "u" would be mistaken for the placeholder.
        if let Expression::Power(_, exp) = factor {
            if exp.contains_variable(PLACEHOLDER) {
                return None;
            }
        }
        let du_dx = differentiate_expr(&u, var);
        let (constant, remaining) = match_derivative(others, &du_dx, var)?;
        let f_of_u = rebuild_with_u(factor, &u);
        let integral_in_u = integrate_impl(&f_of_u, PLACEHOLDER).ok()?;
        let result = finish(integral_in_u, constant, &remaining)?;
        Some((u, du_dx, result))
    }

    /// Strategy 2: `factor = base^n` with n free of `var`; u = base, ∫ u^n du.
    fn via_power_base(
        factor: &Expression,
        others: &[Expression],
        var: &str,
    ) -> Option<(Expression, Expression, Expression)> {
        let (base, exponent) = extract_power_form(factor, var)?;
        if !base.contains_variable(var)
            || exponent.contains_variable(var)
            || exponent.contains_variable(PLACEHOLDER)
        {
            return None;
        }
        // du must be the derivative of the base (u), not of the whole power.
        let du_dx = differentiate_expr(&base, var);
        let (constant, remaining) = match_derivative(others, &du_dx, var)?;
        let u = Expression::Variable(Variable::new(PLACEHOLDER));
        let integral_in_u = if is_minus_one(&exponent) {
            // ∫ u^-1 du = ln|u|; the general formula would divide by zero.
            Expression::Function(
                Function::Ln,
                vec![Expression::Function(Function::Abs, vec![u])],
            )
        } else {
            let n_plus_1 = Expression::Binary(
                BinaryOp::Add,
                Box::new(exponent),
                Box::new(Expression::Integer(1)),
            );
            Expression::Binary(
                BinaryOp::Div,
                Box::new(Expression::Power(Box::new(u), Box::new(n_plus_1.clone()))),
                Box::new(n_plus_1),
            )
        };
        let result = finish(integral_in_u, constant, &remaining)?;
        Some((base, du_dx, result))
    }

    let factors = extract_factors(expr);
    if factors.len() < 2 {
        return None;
    }
    factors.iter().enumerate().find_map(|(i, factor)| {
        let others = factors_without(&factors, i);
        via_inner_function(factor, &others, var)
            .or_else(|| via_power_base(factor, &others, var))
    })
}
/// Flattens a tree of `*` nodes into its leaf factors, left to right.
///
/// A non-product expression yields a single-element vector.
fn extract_factors(expr: &Expression) -> Vec<Expression> {
    let mut leaves = Vec::new();
    // Explicit stack; the right child is pushed first so the left subtree is visited first.
    let mut pending: Vec<&Expression> = vec![expr];
    while let Some(node) = pending.pop() {
        match node {
            Expression::Binary(BinaryOp::Mul, left, right) => {
                pending.push(right.as_ref());
                pending.push(left.as_ref());
            }
            leaf => leaves.push(leaf.clone()),
        }
    }
    leaves
}
/// Rebuilds a left-associated product `((f0 * f1) * f2) * …` from `factors`.
///
/// An empty slice yields `Integer(1)`; a single factor is returned unchanged.
fn combine_factors(factors: &[Expression]) -> Expression {
    let mut rest = factors.iter().cloned();
    match rest.next() {
        None => Expression::Integer(1),
        Some(first) => rest.fold(first, |product, next| {
            Expression::Binary(BinaryOp::Mul, Box::new(product), Box::new(next))
        }),
    }
}
/// Picks a candidate `u` for substitution from a single factor.
///
/// * `f(a, …)` → `a`
/// * `f(a, …)^n` → `a`
/// * `base^n` with `base` neither a variable nor an integer → `base`
///
/// Everything else yields `None`.
fn extract_inner_function(expr: &Expression) -> Option<Expression> {
    match expr {
        Expression::Function(_, args) => args.first().cloned(),
        Expression::Power(base, _) => match base.as_ref() {
            Expression::Function(_, args) if !args.is_empty() => Some(args[0].clone()),
            Expression::Variable(_) | Expression::Integer(_) => None,
            compound => Some(compound.clone()),
        },
        _ => None,
    }
}
/// Splits `expr` into `(base, exponent)`: `b^e` → `(b, e)` and a bare variable
/// `v` → `(v, 1)`. Anything else yields `None`. `_var` is currently unused.
fn extract_power_form(expr: &Expression, _var: &str) -> Option<(Expression, Expression)> {
    if let Expression::Power(base, exp) = expr {
        return Some((base.as_ref().clone(), exp.as_ref().clone()));
    }
    matches!(expr, Expression::Variable(_)).then(|| (expr.clone(), Expression::Integer(1)))
}
fn differentiate_expr(expr: &Expression, var: &str) -> Expression {
expr.differentiate(var)
}
fn match_derivative(
factors: &[Expression],
derivative: &Expression,
var: &str,
) -> Option<(Expression, Vec<Expression>)> {
let simplified_deriv = derivative.simplify();
for (i, factor) in factors.iter().enumerate() {
let simplified_factor = factor.simplify();
if expressions_equivalent(&simplified_factor, &simplified_deriv) {
let remaining: Vec<_> = factors
.iter()
.enumerate()
.filter(|(j, _)| *j != i)
.map(|(_, f)| f.clone())
.collect();
return Some((Expression::Integer(1), remaining));
}
if let Some(constant) =
extract_constant_multiple(&simplified_factor, &simplified_deriv, var)
{
let remaining: Vec<_> = factors
.iter()
.enumerate()
.filter(|(j, _)| *j != i)
.map(|(_, f)| f.clone())
.collect();
return Some((constant, remaining));
}
}
if factors.len() == 1 {
return None;
}
let combined = combine_factors(factors);
let simplified_combined = combined.simplify();
if expressions_equivalent(&simplified_combined, &simplified_deriv) {
return Some((Expression::Integer(1), vec![]));
}
if let Some(constant) = extract_constant_multiple(&simplified_combined, &simplified_deriv, var)
{
return Some((constant, vec![]));
}
None
}
/// Builds a string key for structural comparison of expressions.
///
/// The two operands of `+` and `*` are ordered lexicographically, so `a+b` and
/// `b+a` share a key. This is one level per node; associativity is not normalised.
/// Other operators and functions use their `Debug` names, and leaves use `Display`.
fn canonical_key(expr: &Expression) -> String {
    match expr {
        Expression::Binary(op, left, right) => {
            let (lhs, rhs) = (canonical_key(left), canonical_key(right));
            match op {
                BinaryOp::Add | BinaryOp::Mul => {
                    let symbol = if matches!(op, BinaryOp::Add) { "+" } else { "*" };
                    let (low, high) = if lhs <= rhs { (lhs, rhs) } else { (rhs, lhs) };
                    format!("({}{}{})", low, symbol, high)
                }
                _ => format!("({}{:?}{})", lhs, op, rhs),
            }
        }
        Expression::Unary(op, inner) => format!("({:?}{})", op, canonical_key(inner)),
        Expression::Power(base, exp) => {
            format!("({}^{})", canonical_key(base), canonical_key(exp))
        }
        Expression::Function(f, args) => {
            let joined = args.iter().map(canonical_key).collect::<Vec<_>>().join(",");
            format!("{:?}({})", f, joined)
        }
        leaf => leaf.to_string(),
    }
}
/// Structural equality modulo operand order of `+`/`*` (see `canonical_key`).
fn expressions_equivalent(a: &Expression, b: &Expression) -> bool {
    let (key_a, key_b) = (canonical_key(a), canonical_key(b));
    key_a == key_b
}
/// If `expr1` is `c * expr2` or `expr2 * c` (with `c` free of `var`), returns `c`.
///
/// Returns `None` when `expr2` does not depend on `var`, or when `expr1` is not
/// such a product.
fn extract_constant_multiple(
    expr1: &Expression,
    expr2: &Expression,
    var: &str,
) -> Option<Expression> {
    if !expr2.contains_variable(var) {
        return None;
    }
    match expr1 {
        // Check `c * expr2` before `expr2 * c`.
        Expression::Binary(BinaryOp::Mul, left, right) => [(left, right), (right, left)]
            .iter()
            .find(|(coeff, rest)| {
                !coeff.contains_variable(var) && expressions_equivalent(rest, expr2)
            })
            .map(|(coeff, _)| Expression::clone(coeff)),
        _ => None,
    }
}
fn rebuild_with_u(expr: &Expression, _u: &Expression) -> Expression {
match expr {
Expression::Function(func, _) => {
Expression::Function(func.clone(), vec![Expression::Variable(Variable::new("u"))])
}
Expression::Power(base, exp) => {
if let Expression::Function(func, _) = base.as_ref() {
let f_u = Expression::Function(
func.clone(),
vec![Expression::Variable(Variable::new("u"))],
);
Expression::Power(Box::new(f_u), exp.clone())
} else {
Expression::Power(
Box::new(Expression::Variable(Variable::new("u"))),
exp.clone(),
)
}
}
_ => Expression::Variable(Variable::new("u")),
}
}
/// True only for the literal `Integer(1)`. No simplification is applied, so
/// e.g. `Float(1.0)` or `2/2` return false.
fn is_one(expr: &Expression) -> bool {
    matches!(expr, Expression::Integer(1))
}
/// Replaces every variable named `"u"` in `expr` with the expression `u`.
/// `_var` is unused.
fn back_substitute(expr: &Expression, u: &Expression, _var: &str) -> Expression {
    let placeholder = "u";
    substitute_variable(expr, placeholder, u)
}
/// Returns a copy of `expr` with every `Variable` named `var_name` replaced by
/// `replacement`. The replacement itself is not re-scanned.
fn substitute_variable(expr: &Expression, var_name: &str, replacement: &Expression) -> Expression {
    let recurse = |sub: &Expression| Box::new(substitute_variable(sub, var_name, replacement));
    match expr {
        Expression::Variable(v) => {
            if v.name == var_name {
                replacement.clone()
            } else {
                expr.clone()
            }
        }
        Expression::Integer(_)
        | Expression::Float(_)
        | Expression::Rational(_)
        | Expression::Complex(_)
        | Expression::Constant(_) => expr.clone(),
        Expression::Unary(op, inner) => Expression::Unary(op.clone(), recurse(inner.as_ref())),
        Expression::Binary(op, left, right) => {
            Expression::Binary(op.clone(), recurse(left.as_ref()), recurse(right.as_ref()))
        }
        Expression::Power(base, exp) => {
            Expression::Power(recurse(base.as_ref()), recurse(exp.as_ref()))
        }
        Expression::Function(func, args) => {
            let new_args = args
                .iter()
                .map(|arg| substitute_variable(arg, var_name, replacement))
                .collect();
            Expression::Function(func.clone(), new_args)
        }
    }
}
/// Checks whether d(result)/d(var) is structurally equivalent to `original`,
/// comparing both after `simplify()`.
///
/// Always returns `Ok`; a `false` may just mean the two forms differ structurally.
fn verify_by_differentiation(
    result: &Expression,
    var: &str,
    original: &Expression,
) -> Result<bool, IntegrationError> {
    let differentiated = result.differentiate(var).simplify();
    let agrees = expressions_equivalent(&differentiated, &original.simplify());
    Ok(agrees)
}
/// Like `integrate_by_substitution`, but also returns human-readable steps.
///
/// # Errors
/// `CannotIntegrate("No suitable substitution found")` when neither direct
/// integration nor a substitution applies.
pub fn integrate_with_substitution(
    expr: &Expression,
    var: &str,
) -> Result<(Expression, Vec<String>), IntegrationError> {
    if let Ok(direct) = integrate_impl(expr, var) {
        let note = format!("Direct integration of {} with respect to {}", expr, var);
        return Ok((direct, vec![note]));
    }
    let mut steps = vec![format!("Attempting u-substitution for ∫{} d{}", expr, var)];
    match find_substitution(expr, var) {
        Some((u, du_dx, integral_in_u)) => {
            steps.push(format!("Let u = {}", u));
            steps.push(format!("Then du/d{} = {}", var, du_dx));
            steps.push("Substituting: integral becomes ∫... du".to_string());
            let result = back_substitute(&integral_in_u, &u, var);
            steps.push(format!("Back-substituting u = {}", u));
            steps.push(format!("Result: {}", result));
            Ok((result, steps))
        }
        None => Err(IntegrationError::CannotIntegrate(
            "No suitable substitution found".to_string(),
        )),
    }
}
/// LIATE ranking used to choose `u` in integration by parts (higher = better `u`).
///
/// Logarithmic 5, inverse trig 4, algebraic 3, trigonometric 2, exponential 1.
/// A product takes the minimum of its operands, and anything unrecognised gets 2.
/// Factors free of `var` get 100, so they are preferred over everything.
fn liate_priority(expr: &Expression, var: &str) -> u8 {
    if !expr.contains_variable(var) {
        return 100;
    }
    match expr {
        Expression::Function(
            Function::Ln | Function::Log | Function::Log2 | Function::Log10,
            _,
        ) => 5,
        Expression::Function(Function::Asin | Function::Acos | Function::Atan, _) => 4,
        Expression::Variable(v) if v.name == var => 3,
        Expression::Binary(BinaryOp::Add | BinaryOp::Sub, _, _) => 3,
        Expression::Power(base, exp) => {
            let base_is_var = matches!(base.as_ref(), Expression::Variable(v) if v.name == var);
            if !base.contains_variable(var) && exp.contains_variable(var) {
                1 // a^f(x): exponential
            } else if base_is_var && !exp.contains_variable(var) {
                3 // x^n: algebraic
            } else {
                2
            }
        }
        Expression::Binary(BinaryOp::Mul, left, right) => {
            liate_priority(left, var).min(liate_priority(right, var))
        }
        Expression::Function(Function::Sin | Function::Cos | Function::Tan, _) => 2,
        Expression::Function(Function::Exp, _)
        | Expression::Constant(crate::ast::SymbolicConstant::E) => 1,
        _ => 2,
    }
}
/// Integrates `expr` with respect to `var`, allowing integration by parts.
///
/// Delegates to `integrate_by_parts_impl` at recursion depth 0. That function
/// first tries direct integration and u-substitution before splitting `expr`
/// into `u · dv`.
pub fn integrate_by_parts(expr: &Expression, var: &str) -> IntegrationResult {
    integrate_by_parts_impl(expr, var, 0)
}
/// Maximum recursion depth for `integrate_by_parts_impl`. Deeper attempts fail
/// with "Integration by parts exceeded maximum depth".
const MAX_PARTS_DEPTH: usize = 10;
/// Recursive worker for integration by parts: ∫u dv = uv − ∫v du.
///
/// Direct integration and u-substitution are tried first. Then `u` and `dv` are
/// chosen by LIATE priority. `v` and `∫v du` are obtained directly or by recursing
/// at `depth + 1`. Integrals of the form `uv − c·I` are solved algebraically via
/// `try_solve_recurring_integral`.
///
/// # Errors
/// `CannotIntegrate` when `depth` exceeds `MAX_PARTS_DEPTH`, when `expr` is not a
/// product, or when a recursive step fails.
fn integrate_by_parts_impl(expr: &Expression, var: &str, depth: usize) -> IntegrationResult {
    if depth > MAX_PARTS_DEPTH {
        return Err(IntegrationError::CannotIntegrate(
            "Integration by parts exceeded maximum depth".to_string(),
        ));
    }
    if let Ok(found) =
        integrate_impl(expr, var).or_else(|_| integrate_by_substitution(expr, var))
    {
        return Ok(found);
    }

    let factors = extract_factors(expr);
    if factors.len() < 2 {
        return Err(IntegrationError::CannotIntegrate(
            "Integration by parts requires a product".to_string(),
        ));
    }
    let (u, dv) = choose_u_and_dv(&factors, var);
    let du = u.differentiate(var).simplify();
    let v = integrate_impl(&dv, var)
        .or_else(|_| integrate_by_parts_impl(&dv, var, depth + 1))?
        .simplify();

    let uv = Expression::Binary(BinaryOp::Mul, Box::new(u), Box::new(v.clone())).simplify();
    let v_du = Expression::Binary(BinaryOp::Mul, Box::new(v), Box::new(du)).simplify();

    if let Some(closed_form) = try_solve_recurring_integral(expr, &v_du, &uv, var, depth) {
        return Ok(closed_form);
    }

    let integral_v_du = integrate_impl(&v_du, var)
        .or_else(|_| integrate_by_parts_impl(&v_du, var, depth + 1))?
        .simplify();
    Ok(Expression::Binary(BinaryOp::Sub, Box::new(uv), Box::new(integral_v_du)).simplify())
}
/// Splits `factors` into `(u, dv)` for integration by parts.
///
/// `u` is the factor with the highest `liate_priority`; on ties the earliest
/// factor wins. `dv` is the product of the remaining factors, or `1` if none remain.
///
/// # Panics
/// Panics if `factors` is empty (callers pass at least two factors).
fn choose_u_and_dv(factors: &[Expression], var: &str) -> (Expression, Expression) {
    // `max_by_key` keeps the *last* maximum, so scan in reverse to land on the earliest.
    let u_index = factors
        .iter()
        .enumerate()
        .rev()
        .max_by_key(|(_, f)| liate_priority(f, var))
        .map_or(0, |(i, _)| i);
    let u = factors[u_index].clone();
    let rest: Vec<Expression> = factors
        .iter()
        .enumerate()
        .filter_map(|(i, f)| (i != u_index).then(|| f.clone()))
        .collect();
    (u, combine_factors(&rest))
}
/// Solves `I = uv − ∫v du` algebraically when `v du = c · original`.
///
/// Then `I = uv − c·I`, so `I = uv / (1 + c)`. Only attempted for `depth < 2`.
/// Returns `None` when the shapes do not match or when `1 + c` simplifies to `0`.
fn try_solve_recurring_integral(
    original: &Expression,
    v_du: &Expression,
    uv: &Expression,
    var: &str,
    depth: usize,
) -> Option<Expression> {
    if depth >= 2 {
        return None;
    }
    let coefficient = check_same_up_to_constant(&original.simplify(), &v_du.simplify(), var)?;
    let denominator = Expression::Binary(
        BinaryOp::Add,
        Box::new(Expression::Integer(1)),
        Box::new(coefficient),
    )
    .simplify();
    if matches!(denominator, Expression::Integer(0)) {
        None
    } else {
        let solved =
            Expression::Binary(BinaryOp::Div, Box::new(uv.clone()), Box::new(denominator));
        Some(solved.simplify())
    }
}
/// Returns `c` when `expr2` is `expr1`, `-expr1`, `c * expr1` or `expr1 * c`
/// (with `c` free of `var`); otherwise `None`.
fn check_same_up_to_constant(
    expr1: &Expression,
    expr2: &Expression,
    var: &str,
) -> Option<Expression> {
    if expressions_equivalent(expr1, expr2) {
        return Some(Expression::Integer(1));
    }
    match expr2 {
        Expression::Unary(UnaryOp::Neg, inner) if expressions_equivalent(expr1, inner) => {
            Some(Expression::Integer(-1))
        }
        Expression::Binary(BinaryOp::Mul, left, right) => [(left, right), (right, left)]
            .iter()
            .find(|(coeff, rest)| {
                !coeff.contains_variable(var) && expressions_equivalent(expr1, rest)
            })
            .map(|(coeff, _)| Expression::clone(coeff)),
        _ => None,
    }
}
/// Performs one explicit round of integration by parts and records each step.
///
/// `v` must be obtainable by direct integration of `dv`. `∫v du` may use the full
/// recursive machinery (`integrate_by_parts_impl` at depth 1).
///
/// # Errors
/// `CannotIntegrate` when `expr` is not a product, `dv` cannot be integrated
/// directly, or `∫v du` fails.
pub fn integrate_by_parts_with_steps(
    expr: &Expression,
    var: &str,
) -> Result<(Expression, Vec<String>), IntegrationError> {
    let mut steps = vec![
        format!("∫{} d{}", expr, var),
        "Using integration by parts: ∫u dv = uv - ∫v du".to_string(),
    ];
    let factors = extract_factors(expr);
    if factors.len() < 2 {
        return Err(IntegrationError::CannotIntegrate(
            "Integration by parts requires a product".to_string(),
        ));
    }

    let (u, dv) = choose_u_and_dv(&factors, var);
    steps.push(format!("Let u = {}", u));
    steps.push(format!("Let dv = {} d{}", dv, var));

    let du = u.differentiate(var).simplify();
    steps.push(format!("Then du = {} d{}", du, var));

    let v = integrate_impl(&dv, var)
        .map_err(|_| IntegrationError::CannotIntegrate(format!("Cannot integrate dv = {}", dv)))?
        .simplify();
    steps.push(format!("And v = ∫{} d{} = {}", dv, var, v));

    let uv = Expression::Binary(BinaryOp::Mul, Box::new(u), Box::new(v.clone())).simplify();
    steps.push(format!("uv = {}", uv));

    let v_du = Expression::Binary(BinaryOp::Mul, Box::new(v), Box::new(du)).simplify();
    steps.push(format!("v·du = {}", v_du));

    let integral_v_du = integrate_by_parts_impl(&v_du, var, 1)?.simplify();
    steps.push(format!("∫v du = {}", integral_v_du));

    let result =
        Expression::Binary(BinaryOp::Sub, Box::new(uv), Box::new(integral_v_du)).simplify();
    steps.push(format!("Result: {}", result));
    Ok((result, steps))
}
/// Tabular ("DI") integration of `polynomial · integrable` with respect to `var`.
///
/// Computes `p·F₁ − p'·F₂ + p''·F₃ − …`, where `Fₖ` is the k-th repeated
/// antiderivative of `integrable`. The series stops at the first derivative of
/// `p` that simplifies to the literal `0`. Antiderivatives are computed lazily,
/// one per nonzero derivative.
///
/// # Errors
/// `CannotIntegrate` if `polynomial` fails `is_polynomial_like`, if more than 50
/// derivatives are needed, or if `integrable` cannot be integrated repeatedly.
pub fn tabular_integration(
    polynomial: &Expression,
    integrable: &Expression,
    var: &str,
) -> IntegrationResult {
    if !is_polynomial_like(polynomial, var) {
        return Err(IntegrationError::CannotIntegrate(
            "First argument must be a polynomial".to_string(),
        ));
    }

    // D column: p, p', p'', … up to, but not including, the first zero derivative.
    let mut derivatives = vec![polynomial.clone()];
    let mut current_deriv = polynomial.clone();
    loop {
        current_deriv = current_deriv.differentiate(var).simplify();
        if matches!(current_deriv, Expression::Integer(0)) {
            break;
        }
        derivatives.push(current_deriv.clone());
        if derivatives.len() > 50 {
            return Err(IntegrationError::CannotIntegrate(
                "Polynomial degree too high for tabular method".to_string(),
            ));
        }
    }

    // Row k pairs D_k with the (k+1)-th antiderivative F_{k+1}, with alternating signs.
    let mut result = Expression::Integer(0);
    let mut antiderivative = integrable.clone();
    for (k, d) in derivatives.iter().enumerate() {
        // Only reachable when the polynomial itself is the literal 0.
        if matches!(d, Expression::Integer(0)) {
            break;
        }
        antiderivative = integrate_impl(&antiderivative, var)?.simplify();
        let term = Expression::Binary(
            BinaryOp::Mul,
            Box::new(d.clone()),
            Box::new(antiderivative.clone()),
        );
        let op = if k % 2 == 0 { BinaryOp::Add } else { BinaryOp::Sub };
        result = Expression::Binary(op, Box::new(result), Box::new(term));
    }
    Ok(result.simplify())
}
/// Computes ∫_lower^upper expr d(var) symbolically.
///
/// Cheap special cases (zero integrand, equal bounds, odd integrand on a symmetric
/// interval) are tried first. Otherwise the antiderivative `F` is found with
/// [`integrate`], and the simplified `F(upper) - F(lower)` is returned.
///
/// # Errors
/// Propagates any error from [`integrate`].
pub fn definite_integral(
    expr: &Expression,
    var: &str,
    lower: &Expression,
    upper: &Expression,
) -> IntegrationResult {
    match check_definite_special_cases(expr, var, lower, upper) {
        Some(shortcut) => Ok(shortcut),
        None => {
            let antideriv = integrate(expr, var)?;
            let at = |bound: &Expression| substitute_var(&antideriv, var, bound);
            let difference =
                Expression::Binary(BinaryOp::Sub, Box::new(at(upper)), Box::new(at(lower)));
            Ok(difference.simplify())
        }
    }
}
/// Evaluates ∫_lower^upper expr d(var) as an `f64`.
///
/// The exact symbolic route is tried first. Numerical quadrature is used as a fallback.
///
/// The symbolic value is used only when [`definite_integral`] succeeds *and* its result
/// evaluates (with no variable bindings) to a number that is not NaN. A NaN result
/// (e.g. an antiderivative evaluated outside its domain) is not a usable answer, so it
/// also triggers the fallback to [`numerical_integrate`] with `tolerance`. `tolerance`
/// is not used by the symbolic path.
///
/// NOTE(review): the symbolic path applies the fundamental theorem of calculus without
/// checking for singularities of `expr` inside `[lower, upper]`.
///
/// # Errors
/// Returns the error from [`numerical_integrate`] when the fallback is needed and fails.
pub fn definite_integral_with_fallback(
    expr: &Expression,
    var: &str,
    lower: f64,
    upper: f64,
    tolerance: f64,
) -> Result<f64, IntegrationError> {
    let lower_expr = Expression::Float(lower);
    let upper_expr = Expression::Float(upper);
    let empty = std::collections::HashMap::new();
    let symbolic_value = definite_integral(expr, var, &lower_expr, &upper_expr)
        .ok()
        .and_then(|result| result.evaluate(&empty))
        .filter(|value| !value.is_nan());
    match symbolic_value {
        Some(value) => Ok(value),
        None => numerical_integrate(expr, var, lower, upper, tolerance),
    }
}
/// Numerically integrates `expr` with respect to `var` over `[lower, upper]`.
///
/// The interval is cut into 8 equal panels. Each panel is integrated with adaptive
/// Simpson's rule, receives an equal share (`tolerance / 8`) of the error budget, and
/// contributes its result to the sum in panel order. A reversed interval
/// (`lower > upper`) yields the correspondingly signed result.
///
/// # Errors
/// `CannotIntegrate("Numerical integration failed")` if any panel fails. This happens
/// when `expr` cannot be evaluated at a sample point or the tolerance is not reached
/// within the recursion limit. A non-positive tolerance can never be reached.
pub fn numerical_integrate(
    expr: &Expression,
    var: &str,
    lower: f64,
    upper: f64,
    tolerance: f64,
) -> Result<f64, IntegrationError> {
    const INITIAL_SUBDIVISIONS: usize = 8;
    let panels = INITIAL_SUBDIVISIONS as f64;
    let width = (upper - lower) / panels;
    let panel_tolerance = tolerance / panels;
    (0..INITIAL_SUBDIVISIONS).try_fold(0.0, |acc, idx| {
        let left = lower + idx as f64 * width;
        simpsons_rule_adaptive(expr, var, left, left + width, panel_tolerance, 0)
            .map(|piece| acc + piece)
            .ok_or_else(|| {
                IntegrationError::CannotIntegrate("Numerical integration failed".to_string())
            })
    })
}
/// Adaptive Simpson quadrature of `expr` (as a function of `var`) over `[a, b]`.
///
/// One Simpson estimate `s1` on `[a, b]` is compared with the sum `s2` of the two
/// half-interval estimates. When `|s2 - s1| < 15 * tolerance`, the Richardson-corrected
/// value `s2 + (s2 - s1) / 15` is returned. Otherwise each half is refined recursively
/// with half the tolerance.
///
/// Samples already computed at one level (both endpoints and the midpoint) are passed
/// down to the sub-intervals, so each point is evaluated once rather than again by
/// every recursive call. The arithmetic is unchanged, so results are identical to
/// re-evaluating them.
///
/// Returns `None` if `expr` cannot be evaluated at a sample point, or if the tolerance
/// has still not been met once the recursion depth exceeds `MAX_DEPTH`.
fn simpsons_rule_adaptive(
    expr: &Expression,
    var: &str,
    a: f64,
    b: f64,
    tolerance: f64,
    depth: usize,
) -> Option<f64> {
    const MAX_DEPTH: usize = 20;

    // One refinement step on `[a, b]`, given f(a), f((a + b) / 2) and f(b).
    fn refine(
        f: &dyn Fn(f64) -> Option<f64>,
        a: f64,
        b: f64,
        (fa, fm, fb): (f64, f64, f64),
        tolerance: f64,
        depth: usize,
    ) -> Option<f64> {
        if depth > MAX_DEPTH {
            return None;
        }
        let mid = (a + b) / 2.0;
        let f1 = f((a + mid) / 2.0)?;
        let f2 = f((mid + b) / 2.0)?;
        let h = (b - a) / 6.0;
        let s1 = h * (fa + 4.0 * fm + fb);
        let h2 = h / 2.0;
        let s2 = h2 * (fa + 4.0 * f1 + fm) + h2 * (fm + 4.0 * f2 + fb);
        if (s2 - s1).abs() < 15.0 * tolerance {
            // Richardson extrapolation of the coarse and fine Simpson estimates.
            Some(s2 + (s2 - s1) / 15.0)
        } else {
            // The left half reuses f(a), f1, f(mid); the right half reuses f(mid), f2, f(b).
            let left = refine(f, a, mid, (fa, f1, fm), tolerance / 2.0, depth + 1)?;
            let right = refine(f, mid, b, (fm, f2, fb), tolerance / 2.0, depth + 1)?;
            Some(left + right)
        }
    }

    if depth > MAX_DEPTH {
        return None;
    }
    let f = |x: f64| evaluate_at(expr, var, x);
    let fa = f(a)?;
    let fb = f(b)?;
    let fm = f((a + b) / 2.0)?;
    refine(&f, a, b, (fa, fm, fb), tolerance, depth)
}
/// Numerically evaluates `expr` with `var` bound to `value` and no other bindings.
///
/// Returns whatever `Expression::evaluate` returns: `None` when the expression cannot
/// be reduced to a number under that single binding.
fn evaluate_at(expr: &Expression, var: &str, value: f64) -> Option<f64> {
    let bindings = std::collections::HashMap::from([(var.to_string(), value)]);
    expr.evaluate(&bindings)
}
/// Detects definite integrals whose value is known to be `0` without integrating.
///
/// The recognised cases are:
/// - the integrand is the literal `0`;
/// - the bounds are equivalent (an empty interval);
/// - the integrand is odd in `var` and the interval is symmetric about 0.
///
/// Returns `Some(0)` for the first matching case, or `None` if none applies.
fn check_definite_special_cases(
    expr: &Expression,
    var: &str,
    lower: &Expression,
    upper: &Expression,
) -> Option<Expression> {
    let trivially_zero =
        matches!(expr, Expression::Integer(0)) || expressions_equivalent(lower, upper);
    if trivially_zero {
        Some(Expression::Integer(0))
    } else {
        check_odd_function_symmetric(expr, var, lower, upper)
    }
}
/// Returns `Some(0)` when the interval is `[-u, u]` and `expr` is odd in `var`.
///
/// Both conditions are decided by `expressions_equivalent`, after simplification:
/// - `lower` must be equivalent to the simplified `-upper`;
/// - the simplified `expr[var := -var]` must be equivalent to the simplified `-expr`.
///
/// If either check fails (including when equivalence simply cannot be shown), `None`
/// is returned.
///
/// NOTE(review): no check is made that `expr` is integrable on the interval, so an odd
/// integrand with a singularity at 0 (e.g. `1/x`) is also reported as `0`.
fn check_odd_function_symmetric(
    expr: &Expression,
    var: &str,
    lower: &Expression,
    upper: &Expression,
) -> Option<Expression> {
    let mirrored_upper = Expression::Unary(UnaryOp::Neg, Box::new(upper.clone())).simplify();
    if !expressions_equivalent(lower, &mirrored_upper) {
        return None;
    }

    let minus_x = Expression::Unary(
        UnaryOp::Neg,
        Box::new(Expression::Variable(Variable::new(var))),
    );
    let reflected = substitute_var(expr, var, &minus_x).simplify();
    let negated = Expression::Unary(UnaryOp::Neg, Box::new(expr.clone())).simplify();

    if expressions_equivalent(&reflected, &negated) {
        Some(Expression::Integer(0))
    } else {
        None
    }
}
/// Returns a copy of `expr` in which every `Variable` named `var` is replaced by a clone
/// of `replacement`.
///
/// The replacement is purely structural: binary, unary, power and function nodes are
/// rebuilt with substituted children. Leaves other than the matching variable are cloned
/// unchanged. No simplification is performed.
fn substitute_var(expr: &Expression, var: &str, replacement: &Expression) -> Expression {
    let sub = |child: &Expression| Box::new(substitute_var(child, var, replacement));
    match expr {
        Expression::Variable(v) if v.name == var => replacement.clone(),
        Expression::Binary(op, lhs, rhs) => {
            Expression::Binary(*op, sub(lhs.as_ref()), sub(rhs.as_ref()))
        }
        Expression::Unary(op, operand) => Expression::Unary(*op, sub(operand.as_ref())),
        Expression::Power(base, exponent) => {
            Expression::Power(sub(base.as_ref()), sub(exponent.as_ref()))
        }
        Expression::Function(func, args) => {
            let new_args = args
                .iter()
                .map(|arg| substitute_var(arg, var, replacement))
                .collect();
            Expression::Function(func.clone(), new_args)
        }
        Expression::Variable(_)
        | Expression::Integer(_)
        | Expression::Float(_)
        | Expression::Rational(_)
        | Expression::Constant(_)
        | Expression::Complex(_) => expr.clone(),
    }
}
/// Computes the improper integral ∫_lower^∞ expr d(var).
///
/// The antiderivative `F` is found with [`integrate`]. The result is the simplified
/// `lim_{var→∞} F - F(lower)`, with the limit taken by `evaluate_limit_at_infinity`.
///
/// # Errors
/// - Propagates any error from [`integrate`].
/// - `CannotIntegrate` when the limit at infinity cannot be determined. The integral
///   may be divergent, or the limit may just not be recognised.
pub fn improper_integral_to_infinity(
    expr: &Expression,
    var: &str,
    lower: &Expression,
) -> IntegrationResult {
    let antideriv = integrate(expr, var)?;
    let at_lower = substitute_var(&antideriv, var, lower).simplify();
    evaluate_limit_at_infinity(&antideriv, var)
        .map(|at_infinity| {
            Expression::Binary(BinaryOp::Sub, Box::new(at_infinity), Box::new(at_lower))
                .simplify()
        })
        .ok_or_else(|| {
            IntegrationError::CannotIntegrate(
                "Cannot evaluate improper integral (may be divergent)".to_string(),
            )
        })
}
/// Attempts to compute lim_{var → +∞} expr symbolically.
///
/// Returns `None` when no rule below applies (the limit may still exist). Rules:
/// - an expression that does not contain `var` is its own limit;
/// - `var^p` → 0 when the simplified exponent `p` is known to be negative. If `p` cannot
///   be evaluated but is written as a negation (`var^(-a)`), it is also taken as → 0;
/// - `c / g` → 0 when `c` is free of `var` and `g` grows without bound;
/// - `f / c` → (lim f) / c when `c` is free of `var`;
/// - `exp(-var)` → 0;
/// - negations, sums, differences and products with a `var`-free factor combine the
///   limits of their parts.
fn evaluate_limit_at_infinity(expr: &Expression, var: &str) -> Option<Expression> {
    // Anything independent of `var` stays constant as `var` → ∞.
    if !expr.contains_variable(var) {
        return Some(expr.clone());
    }
    match expr {
        Expression::Power(base, exp) => {
            if !matches!(base.as_ref(), Expression::Variable(v) if v.name == var) {
                return None;
            }
            let simplified_exp = exp.simplify();
            if let Expression::Integer(n) = simplified_exp {
                return if n < 0 {
                    Some(Expression::Integer(0))
                } else {
                    None
                };
            }
            let empty = std::collections::HashMap::new();
            match simplified_exp.evaluate(&empty) {
                // A known negative exponent: var^p → 0.
                Some(val) if val < 0.0 => Some(Expression::Integer(0)),
                // A known non-negative (or NaN) exponent: no finite-limit rule applies.
                // A literally negated exponent such as -(-2) lands here, not in the
                // symbolic case below.
                Some(_) => None,
                // A symbolic exponent written as a negation, e.g. var^(-a).
                // NOTE(review): this assumes the negated quantity is positive.
                None if matches!(exp.as_ref(), Expression::Unary(UnaryOp::Neg, _)) => {
                    Some(Expression::Integer(0))
                }
                None => None,
            }
        }
        Expression::Binary(BinaryOp::Div, num, denom) => {
            if !num.contains_variable(var) && grows_to_infinity(denom, var) {
                return Some(Expression::Integer(0));
            }
            if !denom.contains_variable(var) {
                if let Some(num_limit) = evaluate_limit_at_infinity(num, var) {
                    let is_zero = match &num_limit {
                        Expression::Integer(0) => true,
                        Expression::Float(f) if *f == 0.0 => true,
                        _ => false,
                    };
                    if is_zero {
                        return Some(Expression::Integer(0));
                    }
                    return Some(
                        Expression::Binary(BinaryOp::Div, Box::new(num_limit), denom.clone())
                            .simplify(),
                    );
                }
            }
            None
        }
        Expression::Function(Function::Exp, args) if args.len() == 1 => {
            if let Expression::Unary(UnaryOp::Neg, inner) = &args[0] {
                if let Expression::Variable(v) = inner.as_ref() {
                    if v.name == var {
                        return Some(Expression::Integer(0));
                    }
                }
            }
            None
        }
        Expression::Unary(UnaryOp::Neg, inner) => {
            // lim(-f) = -(lim f). Needed e.g. for the antiderivative -e^(-x).
            let inner_limit = evaluate_limit_at_infinity(inner, var)?;
            Some(Expression::Unary(UnaryOp::Neg, Box::new(inner_limit)).simplify())
        }
        Expression::Binary(BinaryOp::Add, left, right) => {
            let l_limit = evaluate_limit_at_infinity(left, var)?;
            let r_limit = evaluate_limit_at_infinity(right, var)?;
            Some(Expression::Binary(BinaryOp::Add, Box::new(l_limit), Box::new(r_limit)).simplify())
        }
        Expression::Binary(BinaryOp::Sub, left, right) => {
            let l_limit = evaluate_limit_at_infinity(left, var)?;
            let r_limit = evaluate_limit_at_infinity(right, var)?;
            Some(Expression::Binary(BinaryOp::Sub, Box::new(l_limit), Box::new(r_limit)).simplify())
        }
        Expression::Binary(BinaryOp::Mul, left, right) => {
            if !left.contains_variable(var) {
                let r_limit = evaluate_limit_at_infinity(right, var)?;
                return Some(
                    Expression::Binary(BinaryOp::Mul, left.clone(), Box::new(r_limit)).simplify(),
                );
            }
            if !right.contains_variable(var) {
                let l_limit = evaluate_limit_at_infinity(left, var)?;
                return Some(
                    Expression::Binary(BinaryOp::Mul, Box::new(l_limit), right.clone()).simplify(),
                );
            }
            None
        }
        _ => None,
    }
}
fn grows_to_infinity(expr: &Expression, var: &str) -> bool {
match expr {
Expression::Variable(v) => v.name == var,
Expression::Power(base, exp) => {
if let Expression::Variable(v) = base.as_ref() {
if v.name == var {
match exp.as_ref() {
Expression::Integer(n) => *n > 0,
_ => false,
}
} else {
false
}
} else {
false
}
}
Expression::Function(Function::Exp, args) if args.len() == 1 => {
args[0].contains_variable(var)
&& !matches!(&args[0], Expression::Unary(UnaryOp::Neg, _))
}
_ => false,
}
}
/// Computes ∫_lower^upper expr d(var) like [`definite_integral`] and also returns a
/// human-readable log of the derivation.
///
/// The first log entry restates the integral. If `check_definite_special_cases` yields
/// a value, that value is logged and returned immediately. Otherwise the log records:
/// the antiderivative `F`, the fundamental-theorem rule, `F(upper)` and `F(lower)`
/// (each simplified), and the simplified difference.
///
/// # Errors
/// Propagates any error from [`integrate`].
pub fn definite_integral_with_steps(
    expr: &Expression,
    var: &str,
    lower: &Expression,
    upper: &Expression,
) -> Result<(Expression, Vec<String>), IntegrationError> {
    let mut log = vec![format!("∫_{{{}}}^{{{}}} {} d{}", lower, upper, expr, var)];

    if let Some(shortcut) = check_definite_special_cases(expr, var, lower, upper) {
        log.push(format!("By special case analysis: {}", shortcut));
        return Ok((shortcut, log));
    }

    log.push("Step 1: Find the antiderivative F(x)".to_string());
    let antideriv = integrate(expr, var)?;
    log.push(format!("F({}) = {}", var, antideriv));

    log.extend([
        "Step 2: Apply Fundamental Theorem of Calculus".to_string(),
        "∫_a^b f(x) dx = F(b) - F(a)".to_string(),
    ]);
    let at_upper = substitute_var(&antideriv, var, upper).simplify();
    let at_lower = substitute_var(&antideriv, var, lower).simplify();
    log.push(format!("F({}) = {}", upper, at_upper));
    log.push(format!("F({}) = {}", lower, at_lower));

    let value = Expression::Binary(
        BinaryOp::Sub,
        Box::new(at_upper.clone()),
        Box::new(at_lower.clone()),
    )
    .simplify();
    log.push(format!("Result: {} - {} = {}", at_upper, at_lower, value));
    Ok((value, log))
}
/// Structural test for whether `expr` looks like a polynomial in `var`.
///
/// Accepted forms:
/// - numeric literals (integer, float, rational);
/// - the variable `var` itself;
/// - `var^e` where `e` does not mention `var`;
/// - sums, differences and negations of accepted forms;
/// - products where every factor that mentions `var` is accepted. A `var`-free left
///   factor is skipped, but then the right factor must be accepted.
///
/// Everything else (other variables or constants on their own, division, functions)
/// is rejected.
///
/// NOTE(review): `var^e` is accepted for ANY `var`-free exponent, including negative or
/// fractional ones (e.g. `x^-1`, `x^(1/2)`), which are not polynomials.
fn is_polynomial_like(expr: &Expression, var: &str) -> bool {
    match expr {
        Expression::Integer(_) | Expression::Float(_) | Expression::Rational(_) => true,
        Expression::Variable(v) => v.name == var,
        Expression::Power(base, exp) => {
            matches!(base.as_ref(), Expression::Variable(v) if v.name == var)
                && !exp.contains_variable(var)
        }
        Expression::Binary(BinaryOp::Add | BinaryOp::Sub, lhs, rhs) => {
            is_polynomial_like(lhs, var) && is_polynomial_like(rhs, var)
        }
        Expression::Binary(BinaryOp::Mul, lhs, rhs) => {
            match (lhs.contains_variable(var), rhs.contains_variable(var)) {
                (false, _) => is_polynomial_like(rhs, var),
                (true, false) => is_polynomial_like(lhs, var),
                (true, true) => is_polynomial_like(lhs, var) && is_polynomial_like(rhs, var),
            }
        }
        Expression::Unary(UnaryOp::Neg, inner) => is_polynomial_like(inner, var),
        _ => false,
    }
}
// Unit tests for the integration routines in this module.
// Indefinite-integration tests mostly check the structural shape of the returned
// `Expression` (its top-level node). Definite, numerical and improper tests evaluate
// the result and compare against a known value within a tolerance.
#[cfg(test)]
mod tests {
use super::*;
// Small constructors for building expression trees concisely.
fn var(name: &str) -> Expression {
Expression::Variable(Variable::new(name))
}
fn int(n: i64) -> Expression {
Expression::Integer(n)
}
fn pow(base: Expression, exp: Expression) -> Expression {
Expression::Power(Box::new(base), Box::new(exp))
}
fn add(left: Expression, right: Expression) -> Expression {
Expression::Binary(BinaryOp::Add, Box::new(left), Box::new(right))
}
fn mul(left: Expression, right: Expression) -> Expression {
Expression::Binary(BinaryOp::Mul, Box::new(left), Box::new(right))
}
fn div(left: Expression, right: Expression) -> Expression {
Expression::Binary(BinaryOp::Div, Box::new(left), Box::new(right))
}
// ∫ 5 dx is expected in the unsimplified form `5 * x`.
#[test]
fn test_integrate_constant() {
let result = integrate(&int(5), "x").unwrap();
assert!(matches!(
result,
Expression::Binary(BinaryOp::Mul, left, right)
if matches!(left.as_ref(), Expression::Integer(5))
&& matches!(right.as_ref(), Expression::Variable(v) if v.name == "x")
));
}
// ∫ x dx is expected as a division by the literal 2.
#[test]
fn test_integrate_x() {
let result = integrate(&var("x"), "x").unwrap();
assert!(matches!(
result,
Expression::Binary(BinaryOp::Div, _num, denom)
if matches!(denom.as_ref(), Expression::Integer(2))
));
}
// ∫ x^2 dx: a Power divided by a denominator that is still an `Add` node
// (presumably the unsimplified n + 1).
#[test]
fn test_integrate_x_squared() {
let x_squared = pow(var("x"), int(2));
let result = integrate(&x_squared, "x").unwrap();
if let Expression::Binary(BinaryOp::Div, num, denom) = result {
assert!(matches!(num.as_ref(), Expression::Power(_, _)));
assert!(matches!(
denom.as_ref(),
Expression::Binary(BinaryOp::Add, _, _)
));
} else {
panic!("Expected division expression");
}
}
#[test]
fn test_integrate_sum() {
let expr = add(pow(var("x"), int(2)), var("x"));
let result = integrate(&expr, "x").unwrap();
assert!(matches!(result, Expression::Binary(BinaryOp::Add, _, _)));
}
#[test]
fn test_integrate_constant_multiple() {
let expr = mul(int(3), var("x"));
let result = integrate(&expr, "x").unwrap();
assert!(matches!(result, Expression::Binary(BinaryOp::Mul, left, _)
if matches!(left.as_ref(), Expression::Integer(3))));
}
// ∫ 1/x dx: expected as a product whose right factor is an `ln(...)` call.
#[test]
fn test_integrate_reciprocal() {
let expr = div(int(1), var("x"));
let result = integrate(&expr, "x").unwrap();
assert!(
matches!(result, Expression::Binary(BinaryOp::Mul, _, ln_part)
if matches!(ln_part.as_ref(), Expression::Function(Function::Ln, _)))
);
}
#[test]
fn test_integrate_sin() {
let sin_x = Expression::Function(Function::Sin, vec![var("x")]);
let result = integrate(&sin_x, "x").unwrap();
assert!(matches!(
result,
Expression::Unary(UnaryOp::Neg, inner)
if matches!(inner.as_ref(), Expression::Function(Function::Cos, _))
));
}
#[test]
fn test_integrate_cos() {
let cos_x = Expression::Function(Function::Cos, vec![var("x")]);
let result = integrate(&cos_x, "x").unwrap();
assert!(matches!(
result,
Expression::Function(Function::Sin, args)
if args.len() == 1
));
}
#[test]
fn test_integrate_exp() {
let exp_x = Expression::Function(Function::Exp, vec![var("x")]);
let result = integrate(&exp_x, "x").unwrap();
assert!(matches!(
result,
Expression::Function(Function::Exp, args)
if args.len() == 1
));
}
// ∫ x^-1 dx is expected as ln(|x|), not via the power rule.
#[test]
fn test_integrate_x_power_negative_one() {
let expr = pow(var("x"), int(-1));
let result = integrate(&expr, "x").unwrap();
assert!(matches!(
result,
Expression::Function(Function::Ln, args)
if matches!(&args[0], Expression::Function(Function::Abs, _))
));
}
#[test]
fn test_integrate_polynomial() {
let poly = add(add(pow(var("x"), int(2)), mul(int(2), var("x"))), int(1));
let result = integrate(&poly, "x");
assert!(result.is_ok());
}
// ∫ sin(2x) dx: the linear inner argument is expected to produce a division.
#[test]
fn test_integrate_linear_sin() {
let two_x = mul(int(2), var("x"));
let sin_2x = Expression::Function(Function::Sin, vec![two_x]);
let result = integrate(&sin_2x, "x").unwrap();
assert!(matches!(result, Expression::Binary(BinaryOp::Div, _, _)));
}
// NOTE(review): despite its name, this test asserts nothing. It only checks that
// integrating and then differentiating x^2 does not panic.
#[test]
fn test_differentiate_integral_equals_original() {
let x_squared = pow(var("x"), int(2));
let integral = integrate(&x_squared, "x").unwrap();
let _derivative = integral.differentiate("x").simplify();
}
// Nested products are flattened into their individual factors.
#[test]
fn test_extract_factors() {
let expr = mul(int(2), var("x"));
let factors = extract_factors(&expr);
assert_eq!(factors.len(), 2);
let expr2 = mul(mul(int(2), var("x")), int(3));
let factors2 = extract_factors(&expr2);
assert_eq!(factors2.len(), 3);
}
#[test]
fn test_combine_factors() {
let factors = vec![int(2), var("x"), int(3)];
let combined = combine_factors(&factors);
assert!(matches!(combined, Expression::Binary(BinaryOp::Mul, _, _)));
}
#[test]
fn test_extract_inner_function() {
let sin_x = Expression::Function(Function::Sin, vec![var("x")]);
let inner = extract_inner_function(&sin_x);
assert!(matches!(inner, Some(Expression::Variable(_))));
let x_squared = pow(var("x"), int(2));
let sin_x2 = Expression::Function(Function::Sin, vec![x_squared]);
let inner2 = extract_inner_function(&sin_x2);
assert!(inner2.is_some());
}
#[test]
fn test_substitute_variable() {
let u = Expression::Variable(Variable::new("u"));
let replacement = pow(var("x"), int(2));
let result = substitute_variable(&u, "u", &replacement);
assert!(matches!(result, Expression::Power(_, _)));
}
#[test]
fn test_integrate_by_substitution_linear() {
let three_x = mul(int(3), var("x"));
let sin_3x = Expression::Function(Function::Sin, vec![three_x]);
let result = integrate_by_substitution(&sin_3x, "x");
assert!(result.is_ok());
}
#[test]
fn test_integrate_by_substitution_exp() {
let two_x = mul(int(2), var("x"));
let exp_2x = Expression::Function(Function::Exp, vec![two_x]);
let result = integrate_by_substitution(&exp_2x, "x");
assert!(result.is_ok());
}
#[test]
fn test_integrate_with_substitution_steps() {
let x_squared = pow(var("x"), int(2));
let (result, steps) = integrate_with_substitution(&x_squared, "x").unwrap();
assert!(!steps.is_empty());
assert!(matches!(result, Expression::Binary(BinaryOp::Div, _, _)));
}
// Equivalence is expected to be insensitive to operand order of + and *.
#[test]
fn test_expressions_equivalent() {
let a = var("x");
let b = var("x");
assert!(expressions_equivalent(&a, &b));
let c = var("y");
assert!(!expressions_equivalent(&a, &c));
}
#[test]
fn test_expressions_equivalent_add_commutativity() {
let xy = add(var("x"), var("y"));
let yx = add(var("y"), var("x"));
assert!(expressions_equivalent(&xy, &yx));
let xz = add(var("x"), var("z"));
assert!(!expressions_equivalent(&xy, &xz));
}
#[test]
fn test_expressions_equivalent_mul_commutativity() {
let two_x = mul(int(2), var("x"));
let x_two = mul(var("x"), int(2));
assert!(expressions_equivalent(&two_x, &x_two));
let three_x = mul(int(3), var("x"));
assert!(!expressions_equivalent(&two_x, &three_x));
}
#[test]
fn test_expressions_equivalent_nested_commutativity() {
let x_plus_x = add(var("x"), var("x")).simplify();
let x_times_2 = mul(var("x"), int(2));
assert!(expressions_equivalent(&x_plus_x, &x_times_2));
}
#[test]
fn test_is_one() {
assert!(is_one(&Expression::Integer(1)));
assert!(!is_one(&Expression::Integer(2)));
assert!(!is_one(&var("x")));
}
// LIATE ordering for choosing u in integration by parts:
// logarithmic > algebraic > trigonometric > exponential.
#[test]
fn test_liate_priority() {
let ln_x = Expression::Function(Function::Ln, vec![var("x")]);
assert!(liate_priority(&ln_x, "x") > liate_priority(&var("x"), "x"));
let sin_x = Expression::Function(Function::Sin, vec![var("x")]);
assert!(liate_priority(&var("x"), "x") > liate_priority(&sin_x, "x"));
let exp_x = Expression::Function(Function::Exp, vec![var("x")]);
assert!(liate_priority(&sin_x, "x") > liate_priority(&exp_x, "x"));
}
#[test]
fn test_is_polynomial_like() {
assert!(is_polynomial_like(&var("x"), "x"));
assert!(is_polynomial_like(&pow(var("x"), int(2)), "x"));
assert!(is_polynomial_like(
&add(pow(var("x"), int(2)), var("x")),
"x"
));
assert!(is_polynomial_like(&int(5), "x"));
let sin_x = Expression::Function(Function::Sin, vec![var("x")]);
assert!(!is_polynomial_like(&sin_x, "x"));
}
// NOTE(review): this only checks that the derivative is not the literal 0. It does
// not verify that d/dx of the result equals x·e^x.
#[test]
fn test_integrate_by_parts_x_exp() {
let x = var("x");
let exp_x = Expression::Function(Function::Exp, vec![x.clone()]);
let expr = mul(x.clone(), exp_x.clone());
let result = integrate_by_parts(&expr, "x");
assert!(result.is_ok());
if let Ok(integral) = result {
let derivative = integral.differentiate("x").simplify();
assert!(!matches!(derivative, Expression::Integer(0)));
}
}
#[test]
fn test_integrate_by_parts_ln_x() {
let ln_x = Expression::Function(Function::Ln, vec![var("x")]);
let result = integrate(&ln_x, "x");
assert!(result.is_ok());
}
#[test]
fn test_integrate_by_parts_x_sin() {
let x = var("x");
let sin_x = Expression::Function(Function::Sin, vec![x.clone()]);
let expr = mul(x.clone(), sin_x.clone());
let result = integrate_by_parts(&expr, "x");
assert!(result.is_ok());
}
#[test]
fn test_integrate_by_parts_x_squared_exp() {
let x = var("x");
let x_squared = pow(x.clone(), int(2));
let exp_x = Expression::Function(Function::Exp, vec![x.clone()]);
let expr = mul(x_squared.clone(), exp_x.clone());
let result = integrate_by_parts(&expr, "x");
assert!(result.is_ok());
}
// NOTE(review): the tabular-integration tests below only check `is_ok()`. They do
// not verify the antiderivative value (e.g. ∫ x·sin x dx = -x·cos x + sin x).
#[test]
fn test_tabular_integration_x_exp() {
let x = var("x");
let exp_x = Expression::Function(Function::Exp, vec![x.clone()]);
let result = tabular_integration(&x, &exp_x, "x");
assert!(result.is_ok());
}
#[test]
fn test_tabular_integration_x_squared_exp() {
let x = var("x");
let x_squared = pow(x.clone(), int(2));
let exp_x = Expression::Function(Function::Exp, vec![x.clone()]);
let result = tabular_integration(&x_squared, &exp_x, "x");
assert!(result.is_ok());
}
#[test]
fn test_tabular_integration_x_sin() {
let x = var("x");
let sin_x = Expression::Function(Function::Sin, vec![x.clone()]);
let result = tabular_integration(&x, &sin_x, "x");
assert!(result.is_ok());
}
#[test]
fn test_integrate_by_parts_with_steps() {
let x = var("x");
let exp_x = Expression::Function(Function::Exp, vec![x.clone()]);
let expr = mul(x.clone(), exp_x.clone());
let result = integrate_by_parts_with_steps(&expr, "x");
assert!(result.is_ok());
if let Ok((_, steps)) = result {
assert!(steps.len() >= 5);
assert!(steps.iter().any(|s| s.contains("integration by parts")));
}
}
// For x·e^x, LIATE picks u = x (algebraic) and dv = e^x (exponential).
#[test]
fn test_choose_u_and_dv() {
let x = var("x");
let exp_x = Expression::Function(Function::Exp, vec![x.clone()]);
let factors = vec![x.clone(), exp_x.clone()];
let (u, dv) = choose_u_and_dv(&factors, "x");
assert!(matches!(u, Expression::Variable(_)));
assert!(matches!(dv, Expression::Function(Function::Exp, _)));
}
// The returned value is the constant ratio between the two expressions:
// -x relative to x gives -1; x and y (different variables) give None.
#[test]
fn test_check_same_up_to_constant() {
let a = var("x");
let b = var("x");
assert!(check_same_up_to_constant(&a, &b, "x").is_some());
let neg_a = Expression::Unary(UnaryOp::Neg, Box::new(a.clone()));
let result = check_same_up_to_constant(&a, &neg_a, "x");
assert!(result.is_some());
assert!(matches!(result, Some(Expression::Integer(-1))));
let c = var("y");
assert!(check_same_up_to_constant(&a, &c, "x").is_none());
}
// ∫₀¹ x² dx = 1/3
#[test]
fn test_definite_integral_x_squared() {
let x_squared = pow(var("x"), int(2));
let result = definite_integral(&x_squared, "x", &int(0), &int(1));
assert!(result.is_ok());
let value = result.unwrap();
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
assert!((numeric - 1.0 / 3.0).abs() < 1e-10);
}
// ∫₀^π sin x dx = 2
#[test]
fn test_definite_integral_sin() {
let sin_x = Expression::Function(Function::Sin, vec![var("x")]);
let pi = Expression::Constant(crate::ast::SymbolicConstant::Pi);
let result = definite_integral(&sin_x, "x", &int(0), &pi);
assert!(result.is_ok());
let value = result.unwrap();
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
assert!((numeric - 2.0).abs() < 1e-10);
}
// ∫₋₁¹ x³ dx = 0. This may be resolved either by the odd-function shortcut or via
// the antiderivative; both must give 0.
#[test]
fn test_definite_integral_odd_function() {
let x_cubed = pow(var("x"), int(3));
let result = definite_integral(&x_cubed, "x", &int(-1), &int(1));
assert!(result.is_ok());
let value = result.unwrap();
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
assert!(numeric.abs() < 1e-10);
}
// ∫₀^a x dx = a²/2, which is 2 at a = 2.
#[test]
fn test_definite_integral_symbolic_upper_bound() {
let x = var("x");
let a = var("a");
let result = definite_integral(&x, "x", &int(0), &a);
assert!(result.is_ok());
let value = result.unwrap();
let mut env = std::collections::HashMap::new();
env.insert("a".to_string(), 2.0);
let numeric = value.evaluate(&env).unwrap();
assert!((numeric - 2.0).abs() < 1e-10);
}
// ∫₀² (3x² + 2x + 1) dx = [x³ + x² + x]₀² = 14
#[test]
fn test_definite_integral_polynomial() {
let x = var("x");
let poly = add(
add(mul(int(3), pow(x.clone(), int(2))), mul(int(2), x.clone())),
int(1),
);
let result = definite_integral(&poly, "x", &int(0), &int(2));
assert!(result.is_ok());
let value = result.unwrap();
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
assert!((numeric - 14.0).abs() < 1e-10);
}
// ∫₀^(π/2) cos x dx = 1
#[test]
fn test_definite_integral_cos() {
let cos_x = Expression::Function(Function::Cos, vec![var("x")]);
let pi = Expression::Constant(crate::ast::SymbolicConstant::Pi);
let upper = div(pi, int(2));
let result = definite_integral(&cos_x, "x", &int(0), &upper);
assert!(result.is_ok());
let value = result.unwrap();
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
assert!((numeric - 1.0).abs() < 1e-10);
}
// ∫₀¹ eˣ dx = e − 1
#[test]
fn test_definite_integral_exp() {
let exp_x = Expression::Function(Function::Exp, vec![var("x")]);
let result = definite_integral(&exp_x, "x", &int(0), &int(1));
assert!(result.is_ok());
let value = result.unwrap();
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
let expected = std::f64::consts::E - 1.0;
assert!((numeric - expected).abs() < 1e-10);
}
#[test]
fn test_numerical_integrate_simple() {
let x_squared = pow(var("x"), int(2));
let result = numerical_integrate(&x_squared, "x", 0.0, 1.0, 1e-8);
assert!(result.is_ok());
let value = result.unwrap();
assert!((value - 1.0 / 3.0).abs() < 1e-6);
}
// ∫₀¹ e^(−x²) dx ≈ 0.746824 (no elementary antiderivative).
#[test]
fn test_numerical_integrate_complex() {
let x = var("x");
let neg_x_squared = Expression::Unary(UnaryOp::Neg, Box::new(pow(x.clone(), int(2))));
let exp_neg_x_squared = Expression::Function(Function::Exp, vec![neg_x_squared]);
let result = numerical_integrate(&exp_neg_x_squared, "x", 0.0, 1.0, 1e-6);
assert!(result.is_ok());
let value = result.unwrap();
assert!((value - 0.74682).abs() < 0.001);
}
#[test]
fn test_definite_integral_with_fallback() {
let x_squared = pow(var("x"), int(2));
let result = definite_integral_with_fallback(&x_squared, "x", 0.0, 1.0, 1e-8);
assert!(result.is_ok());
let value = result.unwrap();
assert!((value - 1.0 / 3.0).abs() < 1e-6);
}
// ∫₁^∞ x⁻² dx = 1
#[test]
fn test_improper_integral_convergent() {
let x = var("x");
let x_neg_2 = pow(x.clone(), int(-2));
let result = improper_integral_to_infinity(&x_neg_2, "x", &int(1));
assert!(result.is_ok());
let value = result.unwrap();
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
assert!((numeric - 1.0).abs() < 1e-10);
}
// The step log must mention the antiderivative (or bounds), and the value must
// still be 1/3.
#[test]
fn test_definite_integral_with_steps() {
let x_squared = pow(var("x"), int(2));
let result = definite_integral_with_steps(&x_squared, "x", &int(0), &int(1));
assert!(result.is_ok());
let (value, steps) = result.unwrap();
assert!(!steps.is_empty());
assert!(steps.iter().any(|s| {
s.to_lowercase().contains("antiderivative") || s.to_lowercase().contains("bound")
}));
let empty = std::collections::HashMap::new();
let numeric = value.evaluate(&empty).unwrap();
assert!((numeric - 1.0 / 3.0).abs() < 1e-10);
}
}