use crate::deriv::log::{DerivationLog, RewriteStep};
use crate::integrate::engine::IntegrationError;
use crate::kernel::{ExprId, ExprPool};
use crate::simplify::engine::simplify;
use super::poly_rde::{expr_to_qpoly, is_free_of_var, qpoly_to_expr, solve_poly_rde};
use super::tower::{decompose_wrt_exp, poly_degree, TowerLevel};
/// Integrates `expr` with respect to `var` over the exponential tower level `level`.
///
/// The work is done in four stages:
///
/// 1. `decompose_wrt_exp` splits `expr` (relative to `level.generator`) into a
///    `rational_part` and a list of `(coefficient, power)` terms in the generator.
/// 2. Unless `rational_part` is the literal integer `0`, it is integrated with
///    `integrate_raw` using a scratch log, which is merged into `log` only when that
///    integration succeeds.
/// 3. Every `(coefficient, power)` term is passed to `integrate_single_exp_term`.
/// 4. The partial results are added, simplified, and a `"risch_exp"` rewrite step from
///    `expr` to the simplified result is appended to `log`.
///
/// # Errors
///
/// * `IntegrationError::NotImplemented` when `level.kind` is not `ExtensionKind::Exp`,
///   when the exponent cannot be differentiated, or when the exponent's derivative
///   cannot be converted to a polynomial in `var`.
/// * Whatever `integrate_raw` returns for the `rational_part`.
/// * Whatever `integrate_single_exp_term` returns for an exponential term.
///
/// On an error raised after stage 2, `log` may already contain the steps merged or pushed
/// by the earlier stages.
pub fn integrate_exp_tower(
    expr: ExprId,
    level: &TowerLevel,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let exp_gen = level.generator;
    let eta = match level.kind {
        super::tower::ExtensionKind::Exp { eta } => eta,
        _ => {
            return Err(IntegrationError::NotImplemented(
                "integrate_exp_tower called with non-Exp level".to_string(),
            ))
        }
    };

    let (rational_part, exp_terms) = decompose_wrt_exp(expr, exp_gen, pool);
    let zero = pool.integer(0_i32);

    let int_rational = if is_zero(rational_part, pool) {
        zero
    } else {
        // A scratch log keeps the caller's log untouched if this sub-integration fails.
        let mut inner_log = DerivationLog::new();
        let integrated =
            crate::integrate::engine::integrate_raw(rational_part, var, pool, &mut inner_log)?;
        // Move the current log out instead of cloning it: `merge` needs an owned value,
        // and copying the whole log just to overwrite it afterwards is wasted work.
        *log = std::mem::replace(log, DerivationLog::new()).merge(inner_log);
        integrated
    };

    // No exponential terms: the integral of the rational part is the whole answer.
    if exp_terms.is_empty() {
        return Ok(int_rational);
    }

    let deta_expr = differentiate_poly(eta, var, pool)?;
    let deta = expr_to_qpoly(deta_expr, var, pool).ok_or_else(|| {
        IntegrationError::NotImplemented(format!(
            "exponent derivative η'(x) = {} is not a polynomial in the integration variable; \
             only polynomial exponents are supported",
            pool.display(deta_expr)
        ))
    })?;

    let mut result_terms: Vec<ExprId> = Vec::new();
    if !is_zero(int_rational, pool) {
        result_terms.push(int_rational);
    }
    for (c_expr, k) in &exp_terms {
        let term_result = integrate_single_exp_term(
            *c_expr, *k, &deta, deta_expr, eta, exp_gen, var, pool, log,
        )?;
        result_terms.push(term_result);
    }

    // Avoid wrapping zero or one term in an n-ary sum.
    let raw = match result_terms.len() {
        0 => zero,
        1 => result_terms[0],
        _ => pool.add(result_terms),
    };

    let simplified = simplify(raw, pool);
    *log = std::mem::replace(log, DerivationLog::new()).merge(simplified.log);
    log.push(RewriteStep::simple("risch_exp", expr, simplified.value));
    Ok(simplified.value)
}
/// Integrates one term `c_expr · exp(η)^k` of the decomposed integrand.
///
/// `c_expr` is converted to a rational-coefficient polynomial in `var` and passed, with
/// `k` and the polynomial `deta`, to `solve_poly_rde`. As the error text below states,
/// the equation being solved is `v' + k·η'·v = c`. When a solution `v` exists the result
/// is `v · exp(k·η)` (just `exp(k·η)` when `v` is the literal integer `1`), and a
/// `"risch_exp_rde"` step from `c_expr` to that result is appended to `log`.
///
/// `deta_expr` is used only to render the error message.
///
/// # Errors
///
/// * `IntegrationError::NotImplemented` when `c_expr` is not a polynomial in `var`.
/// * `IntegrationError::NonElementary` when `solve_poly_rde` finds no solution.
// The arguments are the per-term data plus the shared exponent context; they are passed
// individually, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
fn integrate_single_exp_term(
    c_expr: ExprId,
    k: i64,
    deta: &[rug::Rational],
    deta_expr: ExprId,
    eta: ExprId,
    exp_gen: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let c_poly = match expr_to_qpoly(c_expr, var, pool) {
        Some(poly) => poly,
        None => {
            return Err(IntegrationError::NotImplemented(format!(
                "coefficient {} of exp(η)^{} is not a polynomial in the integration variable; \
                 only polynomial coefficients are supported",
                pool.display(c_expr),
                k
            )))
        }
    };

    // No polynomial solution means this term has no elementary antiderivative.
    let v_poly = solve_poly_rde(k, deta, &c_poly).ok_or_else(|| {
        IntegrationError::NonElementary(format!(
            "the Risch DE v'(x) + {}·({}(x))·v(x) = {}(x) has no polynomial solution;\n\
             the integrand ∫ {} · exp(η)^{} dx is not an elementary function\n\
             (η = {})",
            k,
            pool.display(deta_expr),
            pool.display(c_expr),
            pool.display(c_expr),
            k,
            pool.display(eta),
        ))
    })?;

    let v_expr = qpoly_to_expr(&v_poly, var, pool);
    let exp_k_eta = build_exp_k_eta(k, eta, exp_gen, pool);
    // Skip the multiplication when the polynomial factor is exactly 1.
    let antiderivative = if is_one(v_expr, pool) {
        exp_k_eta
    } else {
        pool.mul(vec![v_expr, exp_k_eta])
    };

    log.push(RewriteStep::simple("risch_exp_rde", c_expr, antiderivative));
    Ok(antiderivative)
}
/// Differentiates `poly_expr` with respect to `var` using `crate::diff::diff`.
///
/// Returns the `value` of the differentiation result. A differentiation failure is
/// reported as `IntegrationError::NotImplemented`, with the underlying error embedded
/// in the message.
fn differentiate_poly(
    poly_expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, IntegrationError> {
    use crate::diff::diff;

    diff(poly_expr, var, pool)
        .map(|derived| derived.value)
        .map_err(|e| {
            IntegrationError::NotImplemented(format!("could not differentiate exponent: {e}"))
        })
}
/// Builds the expression for `exp(η)^k`.
///
/// * `k == 0` yields the integer `1`.
/// * `k == 1` yields `exp_gen` itself, so the existing generator expression is reused.
/// * any other `k` yields a fresh `exp(k · eta)` node.
fn build_exp_k_eta(k: i64, eta: ExprId, exp_gen: ExprId, pool: &ExprPool) -> ExprId {
    if k == 0 {
        return pool.integer(1_i32);
    }
    if k == 1 {
        return exp_gen;
    }

    let scaled_exponent = pool.mul(vec![pool.integer(k), eta]);
    pool.func("exp", vec![scaled_exponent])
}
/// Returns `true` only when `expr` is stored in `pool` as an `Integer` node whose value
/// equals `0`. This is a structural check on the node, not an algebraic zero test.
fn is_zero(expr: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;

    match pool.get(expr) {
        ExprData::Integer(n) => n.0 == 0,
        _ => false,
    }
}
/// Returns `true` only when `expr` is stored in `pool` as an `Integer` node whose value
/// equals `1`. This is a structural check on the node, not an algebraic equality test.
fn is_one(expr: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;

    if let ExprData::Integer(n) = pool.get(expr) {
        n.0 == 1
    } else {
        false
    }
}
/// Reports whether `expr` contains an exponential pattern in `var` that the detector in
/// `needs_exp_risch_inner` flags.
///
/// This is a thin public wrapper; all of the logic lives in `needs_exp_risch_inner`,
/// which returns `true` for:
///
/// * `exp(η)` where `η` depends on `var` and `poly_degree` reports degree 2 or more;
/// * a product containing such an `exp(η)` factor;
/// * a product containing an `exp` factor whose exponent depends on `var` (degree below 2,
///   or no polynomial degree at all) together with a factor of polynomial degree 2 or more;
/// * any sum, product, or power with a sub-expression matching one of the above.
pub fn needs_exp_risch(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
needs_exp_risch_inner(expr, var, pool)
}
/// Recursive worker behind `needs_exp_risch`.
///
/// The decision depends on the node kind of `expr`:
///
/// * single-argument `exp(η)`: `true` exactly when `η` depends on `var` and `poly_degree`
///   reports a degree of at least 2. The argument is not searched any further.
/// * `Mul`: the factors are scanned once. The result is `true` when some `exp` factor has
///   a `var`-dependent exponent of degree at least 2, or when some `exp` factor has a
///   `var`-dependent exponent of any other kind and some non-`exp` factor has polynomial
///   degree at least 2. Otherwise each factor is checked recursively.
/// * `Add`: `true` when any term qualifies.
/// * `Pow`: `true` when the base or the exponent qualifies.
/// * anything else: `false`.
fn needs_exp_risch_inner(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;

    match pool.get(expr) {
        ExprData::Func { ref name, ref args } if name == "exp" && args.len() == 1 => {
            let eta = args[0];
            !is_free_of_var(eta, var, pool)
                && matches!(poly_degree(eta, var, pool), Some(d) if d >= 2)
        }
        ExprData::Mul(factors) => {
            let mut saw_low_degree_exp = false;
            let mut saw_nonlinear_exp = false;
            // NOTE(review): this records the largest degree of any single factor, not the
            // degree of the product, so `x · x · exp(x)` is seen as degree 1. Confirm
            // whether the pool's normalisation makes that unreachable.
            let mut top_poly_deg: u32 = 0;

            for &factor in &factors {
                match pool.get(factor) {
                    ExprData::Func { ref name, ref args } if name == "exp" && args.len() == 1 => {
                        let eta = args[0];
                        // An exponent free of `var` makes the factor a constant; ignore it.
                        if !is_free_of_var(eta, var, pool) {
                            match poly_degree(eta, var, pool) {
                                Some(d) if d >= 2 => saw_nonlinear_exp = true,
                                // Degree below 2 and non-polynomial exponents are
                                // recorded under the same flag.
                                _ => saw_low_degree_exp = true,
                            }
                        }
                    }
                    _ => {
                        top_poly_deg = poly_degree(factor, var, pool)
                            .map_or(top_poly_deg, |d| top_poly_deg.max(d));
                    }
                }
            }

            saw_nonlinear_exp
                || (saw_low_degree_exp && top_poly_deg >= 2)
                || factors
                    .iter()
                    .any(|&factor| needs_exp_risch_inner(factor, var, pool))
        }
        ExprData::Add(terms) => terms
            .iter()
            .any(|&term| needs_exp_risch_inner(term, var, pool)),
        ExprData::Pow { base, exp } => {
            needs_exp_risch_inner(base, var, pool) || needs_exp_risch_inner(exp, var, pool)
        }
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::super::tower::find_generators;
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Creates a fresh expression pool for one test.
    fn pool() -> ExprPool {
        ExprPool::new()
    }

    /// Builds `exp(x^2)` inside `pool`.
    fn exp_of_square(pool: &ExprPool, x: ExprId) -> ExprId {
        let squared = pool.pow(x, pool.integer(2_i32));
        pool.func("exp", vec![squared])
    }

    /// Asserts that the printed form of `expr` mentions `exp`.
    fn assert_mentions_exp(pool: &ExprPool, expr: ExprId) {
        let rendered = pool.display(expr).to_string();
        assert!(
            rendered.contains("exp"),
            "result should contain exp: {}",
            rendered
        );
    }

    /// `exp(x²)` alone must be reported as non-elementary.
    #[test]
    fn exp_x2_is_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let integrand = exp_of_square(&pool, x);

        let levels = find_generators(integrand, x, &pool);
        assert_eq!(levels.len(), 1);

        let mut log = DerivationLog::new();
        let outcome = integrate_exp_tower(integrand, &levels[0], x, &pool, &mut log);
        assert!(
            matches!(outcome, Err(IntegrationError::NonElementary(_))),
            "∫ exp(x²) dx should be NonElementary, got: {:?}",
            outcome
        );
    }

    /// `x·exp(x²)` integrates successfully and the result still involves `exp`.
    #[test]
    fn x_times_exp_x2_is_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x2 = exp_of_square(&pool, x);
        let integrand = pool.mul(vec![x, exp_x2]);

        let levels = find_generators(integrand, x, &pool);
        assert_eq!(levels.len(), 1);

        let mut log = DerivationLog::new();
        let outcome = integrate_exp_tower(integrand, &levels[0], x, &pool, &mut log);
        assert!(
            outcome.is_ok(),
            "∫ x·exp(x²) dx should be elementary, got: {:?}",
            outcome
        );
        assert_mentions_exp(&pool, outcome.unwrap());
    }

    /// `2·x·exp(x²)` integrates successfully and the result still involves `exp`.
    #[test]
    fn two_x_times_exp_x2_equals_exp_x2() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x2 = exp_of_square(&pool, x);
        let two = pool.integer(2_i32);
        let integrand = pool.mul(vec![two, x, exp_x2]);

        let levels = find_generators(integrand, x, &pool);

        let mut log = DerivationLog::new();
        let antideriv =
            integrate_exp_tower(integrand, &levels[0], x, &pool, &mut log).unwrap();
        assert_mentions_exp(&pool, antideriv);
    }

    /// `x²·exp(x)` integrates successfully and the result still involves `exp`.
    #[test]
    fn x2_times_exp_x_is_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x = pool.func("exp", vec![x]);
        let integrand = pool.mul(vec![x2, exp_x]);

        let levels = find_generators(integrand, x, &pool);
        assert_eq!(levels.len(), 1);

        let mut log = DerivationLog::new();
        let outcome = integrate_exp_tower(integrand, &levels[0], x, &pool, &mut log);
        assert!(
            outcome.is_ok(),
            "∫ x²·exp(x) dx should be elementary, got: {:?}",
            outcome
        );
        assert_mentions_exp(&pool, outcome.unwrap());
    }

    /// Checks the detector on the five representative shapes below.
    #[test]
    fn needs_exp_risch_detection() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);

        // exp(x²): exponent of degree 2 is flagged.
        let exp_x2 = exp_of_square(&pool, x);
        assert!(needs_exp_risch(exp_x2, x, &pool));

        // exp(x): linear exponent alone is not flagged.
        let exp_x = pool.func("exp", vec![x]);
        assert!(!needs_exp_risch(exp_x, x, &pool));

        // x²·exp(x): degree-2 polynomial factor next to a linear exp is flagged.
        let x2_times_exp_x = pool.mul(vec![pool.pow(x, pool.integer(2_i32)), exp_x]);
        assert!(needs_exp_risch(x2_times_exp_x, x, &pool));

        // x·exp(x): degree-1 polynomial factor next to a linear exp is not flagged.
        let x_times_exp_x = pool.mul(vec![x, exp_x]);
        assert!(!needs_exp_risch(x_times_exp_x, x, &pool));

        // x·exp(x²): any product containing a degree-2 exponent is flagged.
        let x_times_exp_x2 = pool.mul(vec![x, exp_x2]);
        assert!(needs_exp_risch(x_times_exp_x2, x, &pool));
    }
}