use crate::deriv::log::{DerivationLog, RewriteStep};
use crate::integrate::engine::IntegrationError;
use crate::kernel::{ExprData, ExprId, ExprPool};
use crate::simplify::engine::simplify;
use super::poly_rde::is_free_of_var;
use super::tower::{decompose_as_log_poly, ExtensionKind, TowerLevel};
/// Integrates `expr` with respect to `var`, treating it as a polynomial in the
/// generator of a logarithmic tower level.
///
/// `level.generator` is used as the polynomial variable and the `h` carried by
/// `ExtensionKind::Log` is forwarded as the argument of that logarithm.
/// Rewrite steps produced along the way are accumulated into `log`.
///
/// # Errors
///
/// Returns `IntegrationError::NotImplemented` when `level.kind` is not
/// `ExtensionKind::Log`, or when `decompose_as_log_poly` yields no
/// decomposition. Errors from the coefficient integration are propagated.
pub fn integrate_log_tower(
    expr: ExprId,
    level: &TowerLevel,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let log_gen = level.generator;

    // Only logarithmic extensions are handled here.
    let ExtensionKind::Log { h } = level.kind else {
        return Err(IntegrationError::NotImplemented(
            "integrate_log_tower called with non-Log level".to_string(),
        ));
    };

    let Some(raw_coeffs) = decompose_as_log_poly(expr, log_gen, pool) else {
        return Err(IntegrationError::NotImplemented(format!(
            "could not decompose {} as a polynomial in log({})",
            pool.display(expr),
            pool.display(h)
        )));
    };

    // Drop trailing zero coefficients so the last entry is the leading one.
    let trimmed = trim_zero_coeffs(raw_coeffs, pool);
    integrate_log_poly(&trimmed, log_gen, h, var, pool, log)
}
/// Dispatches on the length of the coefficient list `coeffs`, where `coeffs[k]`
/// is the coefficient of `log_gen^k`.
///
/// * no coefficients: the result is the integer zero;
/// * a single coefficient: it is integrated directly by `integrate_base`;
/// * otherwise: the work is handed to `integrate_log_poly_recursive`.
fn integrate_log_poly(
    coeffs: &[ExprId],
    log_gen: ExprId,
    h: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let zero = pool.integer(0_i32);
    match coeffs {
        [] => Ok(zero),
        [constant_term] => integrate_base(*constant_term, var, pool, log),
        _ => integrate_log_poly_recursive(coeffs, log_gen, h, var, pool, log),
    }
}
/// Integrates `Σ_k coeffs[k] · log_gen^k` with respect to `var` by repeatedly
/// integrating the top-degree term by parts.
///
/// Writing `θ` for `log_gen`, `n` for the highest index holding a non-zero
/// coefficient and `P_n` for the integral of `coeffs[n]` (obtained from
/// `integrate_base`), each step uses
///
/// ```text
/// ∫ c_n·θ^n = P_n·θ^n − ∫ n·P_n·(h'/h)·θ^(n−1)
/// ```
///
/// i.e. `h'/h` is taken as the derivative of `θ`. The correction
/// `−n·P_n·h'/h` is added to coefficient `n − 1` and the function recurses on
/// the first `n` coefficients.
///
/// A `"risch_log_ibp"` step is pushed onto `log` for every level that performs
/// an integration by parts.
///
/// NOTE(review): the corrected coefficient is not decomposed again with respect
/// to `log_gen`. If `P_n` itself contains `log_gen`, an expression containing
/// `log_gen` reaches `integrate_base` as a plain coefficient — confirm that
/// `integrate_raw` copes with that case.
///
/// # Errors
///
/// Propagates any error from `integrate_base` or `differentiate_sym`.
fn integrate_log_poly_recursive(
    coeffs: &[ExprId],
    log_gen: ExprId,
    h: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let zero = pool.integer(0_i32);
    if coeffs.is_empty() {
        return Ok(zero);
    }

    let n = find_top_degree(coeffs, pool);
    if n == 0 {
        return integrate_base(coeffs[0], var, pool, log);
    }

    // `find_top_degree` only recognises a literal integer zero; a coefficient
    // may still simplify to zero, in which case the degree simply drops.
    let c_n = simplify(coeffs[n], pool).value;
    if is_zero(c_n, pool) {
        return integrate_log_poly_recursive(&coeffs[..n], log_gen, h, var, pool, log);
    }

    // Top term of the result: P_n · θ^n (the factor is omitted when P_n is 1).
    let p_n_raw = integrate_base(c_n, var, pool, log)?;
    let p_n = simplify(p_n_raw, pool).value;
    let log_n = log_power(log_gen, n as i64, pool);
    let term_top = if is_one(p_n, pool) {
        log_n
    } else {
        pool.mul(vec![p_n, log_n])
    };

    // h'/h, short-circuited to zero when h is the literal 1.
    let h_prime = differentiate_sym(h, var, pool)?;
    let h_prime_expr = simplify(h_prime, pool).value;
    let h_prime_over_h = if is_one(h, pool) {
        zero
    } else {
        let raw = pool.mul(vec![h_prime_expr, pool.pow(h, pool.integer(-1_i32))]);
        simplify(raw, pool).value
    };

    // Correction −n · P_n · h'/h that moves into the coefficient of θ^(n−1).
    let neg_n = pool.integer(-(n as i64));
    let correction_raw = pool.mul(vec![neg_n, p_n, h_prime_over_h]);
    let correction = simplify(correction_raw, pool).value;

    // `n >= 1` here, so the slice below always has at least one element and
    // index `n - 1` is its last entry.
    let mut lower = coeffs[..n].to_vec();
    let combined = pool.add(vec![lower[n - 1], correction]);
    lower[n - 1] = simplify(combined, pool).value;

    let rest = integrate_log_poly_recursive(&lower, log_gen, h, var, pool, log)?;
    let result = if is_zero(rest, pool) {
        term_top
    } else {
        pool.add(vec![term_top, rest])
    };

    let simplified = simplify(result, pool);
    *log = log.clone().merge(simplified.log);
    log.push(RewriteStep::simple("risch_log_ibp", c_n, simplified.value));
    Ok(simplified.value)
}
/// Integrates a single coefficient `expr` with respect to `var` using the
/// general integration engine.
///
/// The input is simplified first; if it simplifies to the integer zero the
/// result is zero and the engine is not called. Otherwise `integrate_raw` runs
/// against a fresh derivation log, its result is simplified, and the fresh log
/// is merged into `log`.
///
/// # Errors
///
/// Propagates any error returned by `integrate_raw`; in that case `log` is
/// left untouched.
fn integrate_base(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let integrand = crate::simplify::engine::simplify(expr, pool).value;
    if is_zero(integrand, pool) {
        return Ok(pool.integer(0_i32));
    }

    let mut scratch_log = DerivationLog::new();
    let raw_antiderivative =
        crate::integrate::engine::integrate_raw(integrand, var, pool, &mut scratch_log)?;
    let antiderivative = crate::simplify::engine::simplify(raw_antiderivative, pool).value;

    // Steps are only recorded once the integration has succeeded.
    *log = log.clone().merge(scratch_log);
    Ok(antiderivative)
}
/// Differentiates `expr` with respect to `var` and returns the derivative
/// expression, discarding everything else `diff` reports.
///
/// # Errors
///
/// A failure from `diff` is converted into
/// `IntegrationError::NotImplemented` with a message naming `expr`.
fn differentiate_sym(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, IntegrationError> {
    use crate::diff::diff;

    diff(expr, var, pool)
        .map(|derived| derived.value)
        .map_err(|e| {
            IntegrationError::NotImplemented(format!(
                "could not differentiate {}: {e}",
                pool.display(expr)
            ))
        })
}
/// Returns `true` only when `expr` is stored as an integer literal equal to 0.
///
/// This is a structural check: expressions that merely simplify to zero are
/// not recognised.
fn is_zero(expr: ExprId, pool: &ExprPool) -> bool {
    match pool.get(expr) {
        ExprData::Integer(value) => value.0 == 0,
        _ => false,
    }
}
/// Returns `true` only when `expr` is stored as an integer literal equal to 1.
///
/// This is a structural check: expressions that merely simplify to one are
/// not recognised.
fn is_one(expr: ExprId, pool: &ExprPool) -> bool {
    match pool.get(expr) {
        ExprData::Integer(value) => value.0 == 1,
        _ => false,
    }
}
/// Builds `log_gen^n`, avoiding a `pow` node for the trivial exponents:
/// `n == 0` gives the integer 1 and `n == 1` gives `log_gen` itself.
fn log_power(log_gen: ExprId, n: i64, pool: &ExprPool) -> ExprId {
    if n == 0 {
        return pool.integer(1_i32);
    }
    if n == 1 {
        return log_gen;
    }
    pool.pow(log_gen, pool.integer(n))
}
/// Returns the highest index whose coefficient is not a literal zero.
///
/// Falls back to 0 when `coeffs` is empty or every entry is a literal zero.
fn find_top_degree(coeffs: &[ExprId], pool: &ExprPool) -> usize {
    coeffs
        .iter()
        .rposition(|&coeff| !is_zero(coeff, pool))
        .unwrap_or(0)
}
/// Removes trailing literal-zero coefficients from `coeffs`.
///
/// The returned vector is never empty: if every coefficient was zero (or the
/// input was empty) it contains a single integer zero.
fn trim_zero_coeffs(mut coeffs: Vec<ExprId>, pool: &ExprPool) -> Vec<ExprId> {
    // Number of leading entries to keep: one past the last non-zero coefficient.
    let keep = coeffs
        .iter()
        .rposition(|&coeff| !is_zero(coeff, pool))
        .map_or(0, |last_nonzero| last_nonzero + 1);
    coeffs.truncate(keep);

    if coeffs.is_empty() {
        coeffs.push(pool.integer(0_i32));
    }
    coeffs
}
/// Reports whether `expr` has a shape that this module's logarithmic
/// integration is meant for.
///
/// Public entry point; the arguments are forwarded unchanged to
/// `needs_log_risch_inner`, which walks `Pow`, `Mul` and `Add` nodes looking
/// for either `log(..)` raised to an integer power of at least 2, or a product
/// that has a `log(..)` factor together with another factor that is neither
/// free of `var` nor itself a `log(..)`.
pub fn needs_log_risch(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
needs_log_risch_inner(expr, var, pool)
}
/// Recursive worker behind `needs_log_risch`.
///
/// Returns `true` when, somewhere inside `expr` (descending through `Pow`,
/// `Mul` and `Add` nodes only), one of these shapes occurs:
///
/// * a single-argument `log` raised to an integer power `>= 2`;
/// * a product with at least one single-argument `log` factor and at least one
///   factor that is neither free of `var` nor such a `log`.
///
/// Every other node kind yields `false`.
fn needs_log_risch_inner(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    match pool.get(expr) {
        ExprData::Pow { base, exp } => {
            // log(..)^k with an integer k >= 2.
            let is_log_power = is_log_expr(base, pool)
                && matches!(pool.get(exp), ExprData::Integer(n) if n.0 >= 2);
            is_log_power
                || needs_log_risch_inner(base, var, pool)
                || needs_log_risch_inner(exp, var, pool)
        }
        ExprData::Mul(factors) => {
            let has_log = factors.iter().any(|&f| is_log_expr(f, pool));
            let has_nonconstant = factors
                .iter()
                .any(|&f| !(is_free_of_var(f, var, pool) || is_log_expr(f, pool)));
            (has_log && has_nonconstant)
                || factors
                    .iter()
                    .any(|&f| needs_log_risch_inner(f, var, pool))
        }
        ExprData::Add(terms) => terms
            .iter()
            .any(|&t| needs_log_risch_inner(t, var, pool)),
        _ => false,
    }
}
/// Returns `true` when `expr` is a function application named `"log"` with
/// exactly one argument. The argument itself is not inspected.
fn is_log_expr(expr: ExprId, pool: &ExprPool) -> bool {
    match pool.get(expr) {
        ExprData::Func { ref name, ref args } => name == "log" && args.len() == 1,
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::super::tower::find_generators;
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh, empty expression pool for a test.
    fn pool() -> ExprPool {
        ExprPool::new()
    }

    /// Creates the real symbol `x` and then `log(x)` in `pool`, returning both.
    fn symbol_and_log(pool: &ExprPool) -> (ExprId, ExprId) {
        let x = pool.symbol("x", Domain::Real);
        let log_x = pool.func("log", vec![x]);
        (x, log_x)
    }

    /// `log(x)^2` integrates and the antiderivative mentions `log`.
    #[test]
    fn log_x_squared() {
        let pool = pool();
        let (x, log_x) = symbol_and_log(&pool);
        let integrand = pool.pow(log_x, pool.integer(2_i32));

        let gens = find_generators(integrand, x, &pool);
        assert_eq!(gens.len(), 1, "should find exactly one log generator");

        let mut inner_log = DerivationLog::new();
        let result = integrate_log_tower(integrand, &gens[0], x, &pool, &mut inner_log);
        assert!(
            result.is_ok(),
            "∫ log(x)² dx should be elementary: {:?}",
            result
        );

        let s = pool.display(result.unwrap()).to_string();
        assert!(s.contains("log"), "result should contain log: {}", s);
    }

    /// `x · log(x)` integrates without error.
    #[test]
    fn x_times_log_x() {
        let pool = pool();
        let (x, log_x) = symbol_and_log(&pool);
        let integrand = pool.mul(vec![x, log_x]);

        let gens = find_generators(integrand, x, &pool);
        assert_eq!(gens.len(), 1);

        let mut inner_log = DerivationLog::new();
        let result = integrate_log_tower(integrand, &gens[0], x, &pool, &mut inner_log);
        assert!(
            result.is_ok(),
            "∫ x·log(x) dx should be elementary: {:?}",
            result
        );
    }

    /// Plain `log(x)` integrates and the antiderivative mentions `log`.
    #[test]
    fn log_x_alone() {
        let pool = pool();
        let (x, log_x) = symbol_and_log(&pool);

        let gens = find_generators(log_x, x, &pool);
        assert_eq!(gens.len(), 1);

        let mut inner_log = DerivationLog::new();
        let result = integrate_log_tower(log_x, &gens[0], x, &pool, &mut inner_log);
        assert!(
            result.is_ok(),
            "∫ log(x) dx should be elementary: {:?}",
            result
        );

        let s = pool.display(result.unwrap()).to_string();
        assert!(s.contains("log"), "result should contain log: {}", s);
    }

    /// The detector rejects bare `log(x)` and accepts `log(x)^2` and `x·log(x)`.
    #[test]
    fn needs_log_risch_detection() {
        let pool = pool();
        let (x, log_x) = symbol_and_log(&pool);
        assert!(!needs_log_risch(log_x, x, &pool));

        let log_squared = pool.pow(log_x, pool.integer(2_i32));
        assert!(needs_log_risch(log_squared, x, &pool));

        let x_log_x = pool.mul(vec![x, log_x]);
        assert!(needs_log_risch(x_log_x, x, &pool));
    }
}