pub mod exp_case;
pub mod log_case;
pub mod poly_rde;
pub mod rational_integrate;
pub mod rational_rde;
pub mod tower;
use crate::deriv::log::{DerivationLog, DerivedExpr};
use crate::integrate::engine::IntegrationError;
use crate::kernel::{ExprId, ExprPool};
use crate::simplify::engine::simplify;
use exp_case::{integrate_exp_tower, needs_exp_risch};
use log_case::{integrate_log_tower, needs_log_risch};
use tower::find_generators;
/// Reports whether `expr` contains a sub-form that this module's Risch
/// integrators are meant to handle with respect to `var`.
///
/// This is `true` when either the exponential-case detector
/// (`needs_exp_risch`) or the logarithmic-case detector (`needs_log_risch`)
/// fires. The logarithmic detector is consulted only if the exponential one
/// returned `false`.
pub fn contains_risch_form(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    if needs_exp_risch(expr, var, pool) {
        return true;
    }
    needs_log_risch(expr, var, pool)
}
pub fn integrate_risch(
expr: ExprId,
var: ExprId,
pool: &ExprPool,
) -> Result<DerivedExpr<ExprId>, IntegrationError> {
let generators = find_generators(expr, var, pool);
let exp_gens: Vec<_> = generators.iter().filter(|g| g.is_exp()).collect();
let log_gens: Vec<_> = generators.iter().filter(|g| g.is_log()).collect();
let mut log = DerivationLog::new();
if exp_gens.len() == 1 && log_gens.is_empty() {
let level = exp_gens[0];
let result = integrate_exp_tower(expr, level, var, pool, &mut log)?;
let final_simplified = simplify(result, pool);
let merged = log.merge(final_simplified.log);
return Ok(DerivedExpr::with_log(final_simplified.value, merged));
}
if log_gens.len() == 1 && exp_gens.is_empty() {
let level = log_gens[0];
let result = integrate_log_tower(expr, level, var, pool, &mut log)?;
let final_simplified = simplify(result, pool);
let merged = log.merge(final_simplified.log);
return Ok(DerivedExpr::with_log(final_simplified.value, merged));
}
if !generators.is_empty() {
if let Some(result) = try_decompose_by_sum(expr, var, pool, &mut log) {
let final_simplified = simplify(result, pool);
let merged = log.merge(final_simplified.log);
return Ok(DerivedExpr::with_log(final_simplified.value, merged));
}
}
let gen_names: Vec<String> = generators
.iter()
.map(|g| pool.display(g.generator).to_string())
.collect();
Err(IntegrationError::NotImplemented(format!(
"Risch: multiple interacting generators {:?} not yet supported; \
implement the mixed-tower algorithm (Bronstein 2005, §9)",
gen_names
)))
}
/// Attempts to integrate a sum term by term.
///
/// Returns `None` when `expr` is not an `Add` node, or when any summand fails
/// to integrate. Summands that still contain a Risch form (per
/// `contains_risch_form`) are integrated recursively with `integrate_risch`.
/// All other summands go through the general engine's `integrate_raw`.
///
/// The per-summand derivation logs are collected in a local log and merged
/// into `log` only after every summand has succeeded. A failed attempt
/// therefore leaves the caller's `log` untouched.
///
/// On success the result is the sum of the summand antiderivatives: the single
/// antiderivative itself for a one-term sum, and `0` for an empty `Add`.
fn try_decompose_by_sum(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Option<ExprId> {
    use crate::kernel::ExprData;

    let args = match pool.get(expr) {
        ExprData::Add(args) => args,
        _ => return None,
    };

    // Logs are staged here so nothing reaches `log` unless the whole sum succeeds.
    let mut attempt_log = DerivationLog::new();
    let mut result_terms: Vec<ExprId> = Vec::with_capacity(args.len());

    for &term in &args {
        let int_term = if contains_risch_form(term, var, pool) {
            let derived = integrate_risch(term, var, pool).ok()?;
            attempt_log = attempt_log.merge(derived.log);
            derived.value
        } else {
            let mut inner_log = DerivationLog::new();
            let antideriv =
                crate::integrate::engine::integrate_raw(term, var, pool, &mut inner_log).ok()?;
            attempt_log = attempt_log.merge(inner_log);
            antideriv
        };
        result_terms.push(int_term);
    }

    // Every summand succeeded: commit the staged derivation steps.
    *log = std::mem::take(log).merge(attempt_log);

    Some(match result_terms.len() {
        0 => pool.integer(0_i32),
        1 => result_terms[0],
        _ => pool.add(result_terms),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Abscissae used for numerical antiderivative checks. All are positive,
    /// so `log(x)` and negative powers of `x` are well defined there.
    const SAMPLE_POINTS: [f64; 3] = [1.3, 2.7, 4.1];

    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn exp_x2_nonelementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let f = pool.func("exp", vec![x2]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ exp(x²) dx should be NonElementary; got: {result:?}"
        );
    }

    #[test]
    fn exp_neg_x2_nonelementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let neg_x2 = pool.mul(vec![pool.integer(-1_i32), pool.pow(x, pool.integer(2_i32))]);
        let f = pool.func("exp", vec![neg_x2]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ exp(−x²) dx should be NonElementary; got: {result:?}"
        );
    }

    #[test]
    fn x_times_exp_x2_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let f = pool.mul(vec![x, pool.func("exp", vec![x2])]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ x·exp(x²) dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "x·exp(x²)");
    }

    #[test]
    fn two_x_exp_x2_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        let f = pool.mul(vec![pool.integer(2_i32), x, exp_x2]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ 2x·exp(x²) dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "2x·exp(x²)");
    }

    #[test]
    fn poly_times_exp_x2_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        let two_x2_plus_1 = pool.add(vec![
            pool.mul(vec![pool.integer(2_i32), pool.pow(x, pool.integer(2_i32))]),
            pool.integer(1_i32),
        ]);
        let f = pool.mul(vec![two_x2_plus_1, exp_x2]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ (2x²+1)·exp(x²) dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "(2x²+1)·exp(x²)");
    }

    #[test]
    fn x2_times_exp_x_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x = pool.func("exp", vec![x]);
        let f = pool.mul(vec![x2, exp_x]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ x²·exp(x) dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "x²·exp(x)");
    }

    #[test]
    fn x3_times_exp_x_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x3 = pool.pow(x, pool.integer(3_i32));
        let exp_x = pool.func("exp", vec![x]);
        let f = pool.mul(vec![x3, exp_x]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ x³·exp(x) dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "x³·exp(x)");
    }

    #[test]
    fn log_x_squared_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let log_x = pool.func("log", vec![x]);
        let f = pool.pow(log_x, pool.integer(2_i32));
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ log(x)² dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "log(x)²");
    }

    #[test]
    fn x_times_log_x_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let log_x = pool.func("log", vec![x]);
        let f = pool.mul(vec![x, log_x]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ x·log(x) dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "x·log(x)");
    }

    #[test]
    fn log_x_cubed_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let log_x = pool.func("log", vec![x]);
        let f = pool.pow(log_x, pool.integer(3_i32));
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ log(x)³ dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "log(x)³");
    }

    #[test]
    fn rational_coeff_exp_elementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        let num = pool.add(vec![x, pool.integer(-1_i32)]);
        let inv_x2 = pool.pow(x, pool.integer(-2_i32));
        let f = pool.mul(vec![num, inv_x2, exp_x]);
        assert!(contains_risch_form(f, x, &pool), "should route to Risch");
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ (x−1)/x²·exp(x) dx should be elementary; got {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "(x−1)/x²·exp(x)");
    }

    /// Numerically evaluates `expr` with the symbol `x` bound to `xv`.
    ///
    /// Supports integers, rationals, `Add`, `Mul`, `Pow`, and the unary
    /// functions `exp` and `log` (natural logarithm). It panics on any other
    /// node, so an unverifiable result fails loudly instead of being skipped.
    fn eval_f64(expr: ExprId, x: ExprId, xv: f64, pool: &ExprPool) -> f64 {
        use crate::kernel::ExprData;
        if expr == x {
            return xv;
        }
        match pool.get(expr) {
            ExprData::Integer(n) => n.0.to_f64(),
            ExprData::Rational(r) => r.0.to_f64(),
            ExprData::Add(args) => args.iter().map(|&a| eval_f64(a, x, xv, pool)).sum(),
            ExprData::Mul(args) => args.iter().map(|&a| eval_f64(a, x, xv, pool)).product(),
            ExprData::Pow { base, exp } => {
                eval_f64(base, x, xv, pool).powf(eval_f64(exp, x, xv, pool))
            }
            ExprData::Func { ref name, ref args } if name == "exp" && args.len() == 1 => {
                eval_f64(args[0], x, xv, pool).exp()
            }
            ExprData::Func { ref name, ref args } if name == "log" && args.len() == 1 => {
                eval_f64(args[0], x, xv, pool).ln()
            }
            other => panic!("eval_f64: unsupported node {other:?}"),
        }
    }

    #[test]
    fn rational_coeff_exp_nonelementary() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let inv_xp1 = pool.pow(pool.add(vec![x, pool.integer(1_i32)]), pool.integer(-1_i32));
        let f = pool.mul(vec![x2, inv_xp1, exp_x]);
        assert!(contains_risch_form(f, x, &pool), "should route to Risch");
        let result = integrate_risch(f, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ x²/(x+1)·exp(x) dx should be NonElementary; got {result:?}"
        );
    }

    #[test]
    fn sum_exp_x2_and_x() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        let f = pool.add(vec![pool.mul(vec![x, exp_x2]), x]);
        let result = integrate_risch(f, x, &pool);
        assert!(
            result.is_ok(),
            "∫ (x·exp(x²) + x) dx should be elementary; got: {result:?}"
        );
        let antideriv = result.unwrap().value;
        verify_antiderivative(&pool, x, f, antideriv, "x·exp(x²) + x");
    }

    #[test]
    fn detection_predicate() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let exp_x2 = pool.func("exp", vec![pool.pow(x, pool.integer(2_i32))]);
        assert!(contains_risch_form(exp_x2, x, &pool));
        let exp_x = pool.func("exp", vec![x]);
        assert!(!contains_risch_form(exp_x, x, &pool));
        let log_x = pool.func("log", vec![x]);
        let log2 = pool.pow(log_x, pool.integer(2_i32));
        assert!(contains_risch_form(log2, x, &pool));
        assert!(!contains_risch_form(log_x, x, &pool));
        let x2_exp_x = pool.mul(vec![pool.pow(x, pool.integer(2_i32)), exp_x]);
        assert!(contains_risch_form(x2_exp_x, x, &pool));
        let x_log_x = pool.mul(vec![x, log_x]);
        assert!(contains_risch_form(x_log_x, x, &pool));
    }

    /// Asserts that `d/dx antideriv == f`.
    ///
    /// If both sides convert to univariate polynomials in `x`, their integer
    /// coefficients are compared exactly. Otherwise (always the case for
    /// exp/log integrands) both sides are evaluated at `SAMPLE_POINTS` and
    /// compared with a relative tolerance. Previously that fallback silently
    /// skipped verification.
    fn verify_antiderivative(
        pool: &ExprPool,
        x: ExprId,
        f: ExprId,
        antideriv: ExprId,
        label: &str,
    ) {
        use crate::diff::diff;
        use crate::poly::UniPoly;
        let d_antideriv = diff(antideriv, x, pool).unwrap();
        if let (Ok(a), Ok(b)) = (
            UniPoly::from_symbolic(d_antideriv.value, x, pool),
            UniPoly::from_symbolic(f, x, pool),
        ) {
            assert_eq!(
                a.coefficients_i64(),
                b.coefficients_i64(),
                "{label}: d/dx antideriv ≠ f (polynomial check)"
            );
            return;
        }
        assert_numerically_equal(pool, x, d_antideriv.value, f, label);
    }

    /// Asserts `lhs(x) ≈ rhs(x)` at every sample point.
    ///
    /// The tolerance is relative (`1e-9 · max(1, |rhs|)`) so rapidly growing
    /// integrands such as `exp(x²)` are not rejected by rounding noise. A NaN on
    /// either side fails the comparison.
    fn assert_numerically_equal(pool: &ExprPool, x: ExprId, lhs: ExprId, rhs: ExprId, label: &str) {
        for &xv in &SAMPLE_POINTS {
            let l = eval_f64(lhs, x, xv, pool);
            let r = eval_f64(rhs, x, xv, pool);
            let tol = 1e-9 * r.abs().max(1.0);
            assert!(
                (l - r).abs() <= tol,
                "{label}: d/dx antideriv ≠ f at x={xv}: {l} vs {r}"
            );
        }
    }
}