use crate::calculus::integrals::{
basic::BasicIntegrals, by_parts::IntegrationByParts, function_integrals::FunctionIntegrals,
rational, risch, substitution, table, trigonometric,
};
use crate::core::{Expression, Number, Symbol};
use std::collections::HashSet;
/// Depth at or beyond which `integrate_with_strategy` stops attempting any
/// strategy and hands back the integral unevaluated.
const MAX_DEPTH: usize = 10;

/// The integration techniques known to this dispatcher.
///
/// Variants are listed in the same order in which `integrate_with_context`
/// attempts the corresponding techniques.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IntegrationStrategy {
    /// Lookup of known antiderivatives via `table::try_table_lookup`.
    TableLookup,
    /// Rational-function integration via `rational::integrate_rational`.
    RationalFunction,
    /// Per-function antiderivatives via `FunctionIntegrals::integrate`.
    FunctionRegistry,
    /// Integration by parts via `IntegrationByParts`.
    IntegrationByParts,
    /// Substitution via `substitution::try_substitution`.
    Substitution,
    /// Trigonometric techniques via `trigonometric::try_trigonometric_integration`.
    Trigonometric,
    /// Risch-based integration via `risch::try_risch_integration`.
    Risch,
    /// Structural fallback rules from `BasicIntegrals`.
    BasicRules,
}
/// Bookkeeping threaded through a single integration attempt.
///
/// Tracks which guarded strategies have been entered on the path leading to
/// this context, and how many such strategy levels deep the attempt is.
#[derive(Debug, Clone, Default)]
pub struct StrategyContext {
    /// Strategies entered through [`StrategyContext::with_strategy`] so far.
    active_strategies: HashSet<IntegrationStrategy>,
    /// Nesting level; each `with_strategy` child is one level deeper.
    depth: usize,
}

impl StrategyContext {
    /// Creates a context at depth 0 with no active strategies.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if `strategy` has already been entered along this path.
    pub fn is_active(&self, strategy: IntegrationStrategy) -> bool {
        self.active_strategies.contains(&strategy)
    }

    /// Runs `f` in a child context in which `strategy` is marked active.
    ///
    /// If `strategy` is already active, `f` is not called and `None` is
    /// returned, so the same strategy is never entered twice along one path.
    /// The child context is one level deeper than `self`; `self` itself is
    /// left untouched.
    pub fn with_strategy<F>(&self, strategy: IntegrationStrategy, f: F) -> Option<Expression>
    where
        F: FnOnce(&Self) -> Option<Expression>,
    {
        if self.is_active(strategy) {
            None
        } else {
            let mut nested = Self {
                active_strategies: self.active_strategies.clone(),
                depth: self.depth + 1,
            };
            nested.active_strategies.insert(strategy);
            f(&nested)
        }
    }

    /// Returns the current nesting level.
    pub fn depth(&self) -> usize {
        self.depth
    }
}
pub fn integrate_with_strategy(expr: &Expression, var: Symbol, depth: usize) -> Expression {
if depth >= MAX_DEPTH {
return Expression::integral(expr.clone(), var);
}
let context = StrategyContext {
active_strategies: HashSet::new(),
depth,
};
integrate_with_context(expr, var, &context)
}
/// Tries every integration strategy in a fixed order and returns the first success.
///
/// The order is:
/// 1. table lookup;
/// 2. rational-function integration (only if `is_rational_function` accepts `expr`);
/// 3. the function registry;
/// 4. integration by parts;
/// 5. substitution;
/// 6. trigonometric techniques;
/// 7. Risch;
/// 8. the basic structural rules.
///
/// Steps 4–7 go through [`StrategyContext::with_strategy`]: each is skipped when
/// already active in `ctx`, and otherwise runs in a child context one level
/// deeper. If every step declines, the integral is returned unevaluated.
fn integrate_with_context(expr: &Expression, var: Symbol, ctx: &StrategyContext) -> Expression {
    try_table_lookup_with_context(expr, &var, ctx)
        .or_else(|| {
            if is_rational_function(expr, &var) {
                try_rational_function(expr, &var)
            } else {
                None
            }
        })
        .or_else(|| try_registry_integration_with_context(expr, &var, ctx))
        .or_else(|| {
            ctx.with_strategy(IntegrationStrategy::IntegrationByParts, |nested| {
                try_by_parts_with_context(expr, &var, nested, nested.depth())
            })
        })
        .or_else(|| {
            ctx.with_strategy(IntegrationStrategy::Substitution, |nested| {
                try_substitution_with_context(expr, &var, nested)
            })
        })
        .or_else(|| {
            ctx.with_strategy(IntegrationStrategy::Trigonometric, |nested| {
                try_trigonometric_with_context(expr, &var, nested)
            })
        })
        .or_else(|| {
            ctx.with_strategy(IntegrationStrategy::Risch, |nested| {
                try_risch_with_context(expr, &var, nested)
            })
        })
        .or_else(|| try_basic_rules_with_context(expr, &var, ctx))
        .unwrap_or_else(|| Expression::integral(expr.clone(), var))
}
/// Table-lookup strategy; delegates to `table::try_table_lookup`.
///
/// The strategy context is accepted but currently not consulted.
fn try_table_lookup_with_context(
    expr: &Expression,
    var: &Symbol,
    _ctx: &StrategyContext,
) -> Option<Expression> {
    table::try_table_lookup(expr, var)
}
/// Rational-function strategy; delegates to `rational::integrate_rational`.
///
/// Within this module it is only invoked after `is_rational_function` has
/// accepted `expr`.
fn try_rational_function(expr: &Expression, var: &Symbol) -> Option<Expression> {
    let integrated = rational::integrate_rational(expr, var);
    integrated.or(None)
}
/// Integrates a function application using the `FunctionIntegrals` registry.
///
/// Only `Expression::Function` nodes are considered; every other shape yields
/// `None`. A registry answer that is itself a `Calculus` node (see
/// `is_symbolic_integral`) is treated as "not integrated" and also yields `None`.
pub fn try_registry_integration(expr: &Expression, var: &Symbol) -> Option<Expression> {
    if let Expression::Function { name, args } = expr {
        let integrated = FunctionIntegrals::integrate(name, args, var.clone());
        Some(integrated).filter(|candidate| !is_symbolic_integral(candidate))
    } else {
        None
    }
}
/// Function-registry strategy; forwards to [`try_registry_integration`].
///
/// The strategy context is accepted but currently not consulted.
fn try_registry_integration_with_context(
    expr: &Expression,
    var: &Symbol,
    _ctx: &StrategyContext,
) -> Option<Expression> {
    let outcome = try_registry_integration(expr, var);
    outcome.or(None)
}
/// Integration by parts without a strategy context.
///
/// Delegates to `IntegrationByParts::integrate` at recursion level `depth`.
pub fn try_by_parts(expr: &Expression, var: &Symbol, depth: usize) -> Option<Expression> {
    let owned_var = var.clone();
    IntegrationByParts::integrate(expr, owned_var, depth)
}
/// Integration-by-parts strategy; delegates to
/// `IntegrationByParts::integrate_with_context`.
///
/// `depth` is passed through as given; the dispatcher in this module supplies
/// the context's own depth for it.
fn try_by_parts_with_context(
    expr: &Expression,
    var: &Symbol,
    ctx: &StrategyContext,
    depth: usize,
) -> Option<Expression> {
    let owned_var = var.clone();
    IntegrationByParts::integrate_with_context(expr, owned_var, ctx, depth)
}
/// Substitution strategy; delegates to `substitution::try_substitution`,
/// passing along the context's depth.
fn try_substitution_with_context(
    expr: &Expression,
    var: &Symbol,
    ctx: &StrategyContext,
) -> Option<Expression> {
    let level = ctx.depth();
    substitution::try_substitution(expr, var, level)
}
/// Trigonometric strategy; delegates to
/// `trigonometric::try_trigonometric_integration`.
///
/// The strategy context is accepted but currently not consulted.
fn try_trigonometric_with_context(
    expr: &Expression,
    var: &Symbol,
    _ctx: &StrategyContext,
) -> Option<Expression> {
    trigonometric::try_trigonometric_integration(expr, var)
}
/// Risch strategy; delegates to `risch::try_risch_integration`.
///
/// The strategy context is accepted but currently not consulted.
fn try_risch_with_context(
    expr: &Expression,
    var: &Symbol,
    _ctx: &StrategyContext,
) -> Option<Expression> {
    risch::try_risch_integration(expr, var)
}
/// Structural polynomial test.
///
/// The following are accepted:
/// - numbers, constants and every symbol (symbols other than the integration
///   variable act as coefficients);
/// - sums and products of polynomials;
/// - a polynomial raised to a non-negative integer power.
///
/// Everything else is rejected, e.g. function applications and negative or
/// non-integer exponents.
///
/// NOTE(review): `_var` is only threaded through the recursion and never
/// inspected, so e.g. `sin(y)` is rejected even though it is constant in the
/// integration variable.
pub fn is_polynomial(expr: &Expression, _var: &Symbol) -> bool {
    match expr {
        Expression::Number(_) | Expression::Constant(_) | Expression::Symbol(_) => true,
        Expression::Add(terms) => terms.iter().all(|term| is_polynomial(term, _var)),
        Expression::Mul(factors) => factors.iter().all(|factor| is_polynomial(factor, _var)),
        Expression::Pow(base, exponent) => {
            is_polynomial(base, _var)
                && matches!(exponent.as_ref(), Expression::Number(Number::Integer(k)) if *k >= 0)
        }
        _ => false,
    }
}
/// Returns `true` if `expr` has a shape the rational-function integrator accepts.
///
/// Accepted shapes are:
/// * a polynomial (see [`is_polynomial`]);
/// * a polynomial raised to any integer power, e.g. `(x + 1)^-2`;
/// * a product whose every factor is one of the two shapes above,
///   e.g. `x * (x^2 + 1)^-1`.
///
/// Sums of non-polynomial quotients (e.g. `1/x + 1/(x + 1)`) are not recognised.
fn is_rational_function(expr: &Expression, var: &Symbol) -> bool {
    // One factor of a rational function: a polynomial, or a polynomial raised
    // to an integer (possibly negative) power. A power with a non-integer
    // exponent is rejected even if its base is polynomial.
    fn is_rational_factor(factor: &Expression, var: &Symbol) -> bool {
        match factor {
            Expression::Pow(base, exp) => {
                matches!(exp.as_ref(), Expression::Number(Number::Integer(_)))
                    && is_polynomial(base, var)
            }
            other => is_polynomial(other, var),
        }
    }

    // A polynomial product has only polynomial factors, each of which is also a
    // rational factor, so this single check subsumes the plain polynomial case.
    match expr {
        Expression::Mul(factors) => factors.iter().all(|factor| is_rational_factor(factor, var)),
        other => is_rational_factor(other, var),
    }
}
/// Fallback strategy: the structural rules from `BasicIntegrals`.
///
/// Dispatches on the top-level shape of `expr`:
/// * numbers and named constants go to `handle_constant`;
/// * symbols go to `handle_symbol`;
/// * sums and products go to `handle_sum` / `handle_product`, which are
///   passed `context.depth()`;
/// * powers go to `handle_power`;
/// * nested calculus expressions go to `handle_calculus`.
///
/// Any other shape yields `None`.
fn try_basic_rules_with_context(
    expr: &Expression,
    var: &Symbol,
    context: &StrategyContext,
) -> Option<Expression> {
    match expr {
        // Named constants (e.g. pi, e) do not depend on `var` any more than
        // numeric literals do; `is_polynomial` groups the two the same way.
        // Previously `Constant` fell through to `None`, leaving the integral
        // unevaluated whenever every earlier strategy declined.
        Expression::Number(_) | Expression::Constant(_) => {
            Some(BasicIntegrals::handle_constant(expr, var.clone()))
        }
        Expression::Symbol(sym) => Some(BasicIntegrals::handle_symbol(sym, var)),
        Expression::Add(terms) => Some(BasicIntegrals::handle_sum(terms, var, context.depth())),
        Expression::Mul(factors) => Some(BasicIntegrals::handle_product(
            factors,
            var.clone(),
            context.depth(),
        )),
        Expression::Pow(base, exp) => Some(BasicIntegrals::handle_power(base, exp, var.clone())),
        Expression::Calculus(data) => {
            Some(BasicIntegrals::handle_calculus(expr, data, var.clone()))
        }
        _ => None,
    }
}
/// Returns `true` when `expr` is a `Calculus` node.
///
/// `try_registry_integration` uses this to treat such a registry answer as
/// "could not integrate".
///
/// NOTE(review): every `Calculus` variant matches, not only integrals. Confirm
/// that the registry never legitimately returns another kind of calculus node.
fn is_symbolic_integral(expr: &Expression) -> bool {
    matches!(expr, &Expression::Calculus(_))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_is_polynomial_constant() {
        let x = symbol!(x);
        assert!(is_polynomial(&Expression::integer(5), &x));
    }

    #[test]
    fn test_is_polynomial_variable() {
        let x = symbol!(x);
        assert!(is_polynomial(&Expression::symbol(x.clone()), &x));
    }

    #[test]
    fn test_is_polynomial_sum() {
        let x = symbol!(x);
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::symbol(x.clone()),
            Expression::integer(1),
        ]);
        assert!(is_polynomial(&poly, &x));
    }

    #[test]
    fn test_is_polynomial_product() {
        let x = symbol!(x);
        let poly = Expression::mul(vec![
            Expression::integer(3),
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
        ]);
        assert!(is_polynomial(&poly, &x));
    }

    #[test]
    fn test_is_not_polynomial_negative_power() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(-1));
        assert!(!is_polynomial(&expr, &x));
    }

    #[test]
    fn test_is_not_polynomial_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert!(!is_polynomial(&expr, &x));
    }

    // `is_rational_function` gates the rational-integration path in
    // `integrate_with_context`, so pin both its accepting and rejecting shapes.

    #[test]
    fn test_is_rational_function_polynomial() {
        let x = symbol!(x);
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::integer(1),
        ]);
        assert!(is_rational_function(&poly, &x));
    }

    #[test]
    fn test_is_rational_function_reciprocal_of_polynomial() {
        let x = symbol!(x);
        let expr = Expression::pow(
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
            Expression::integer(-1),
        );
        assert!(is_rational_function(&expr, &x));
    }

    #[test]
    fn test_is_rational_function_quotient() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::pow(
                Expression::add(vec![
                    Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
                    Expression::integer(1),
                ]),
                Expression::integer(-1),
            ),
        ]);
        assert!(is_rational_function(&expr, &x));
    }

    #[test]
    fn test_is_not_rational_function_with_transcendental_factor() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::function("sin", vec![Expression::symbol(x.clone())]),
        ]);
        assert!(!is_rational_function(&expr, &x));
    }

    #[test]
    fn test_is_not_rational_function_symbolic_exponent() {
        let x = symbol!(x);
        let n = symbol!(n);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::symbol(n));
        assert!(!is_rational_function(&expr, &x));
    }
}