use crate::deriv::log::{DerivationLog, RewriteStep};
use crate::integrate::engine::IntegrationError;
use crate::kernel::{ExprId, ExprPool};
use crate::simplify::engine::simplify;
use super::alg_field::{AlgElem, AlgExtension, RatFn};
use super::alg_rde::solve_alg_rde;
use super::number_field::{KElem, KPoly, NumberField};
use super::poly_rde::{
apply_const, contains_subexpr, degree, expr_to_qpoly, is_free_of_var, poly_add, poly_deriv,
poly_mul, poly_one, poly_scale, poly_zero, qpoly_to_expr, rational_to_expr, solve_poly_rde,
solve_poly_rde_k, split_const_factor, trim, QPoly,
};
use super::rational_rde::{
expr_to_qrational, poly_gcd, poly_sub, solve_rational_rde, solve_rational_rde_generalized,
solve_rational_rde_k,
};
use super::tower::{decompose_wrt_exp, poly_degree, TowerLevel};
/// Integrates `expr` with respect to `var` over a single exponential tower level
/// `θ = exp(η)`, where `level.generator` is `θ`.
///
/// `decompose_wrt_exp` splits `expr` into a part free of `θ` and terms `c_k · θ^k`.
/// The free part is handed back to the general integrator. The `θ^k` terms are
/// dispatched on the shape of `η'`:
/// * `η'` polynomial in `var` → `integrate_single_exp_term` for each term;
/// * `η'` rational in `var` → `integrate_exp_tower_rational_eta`;
/// * anything else → `try_transcendental_eta_v1`.
///
/// # Errors
/// Returns `NotImplemented` when `level` is not an `Exp` extension. Otherwise it
/// propagates any error from integrating the `θ`-free part, from differentiating `η`,
/// or from the per-term solvers (`NonElementary` when a term has no elementary
/// antiderivative).
pub fn integrate_exp_tower(
    expr: ExprId,
    level: &TowerLevel,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let exp_gen = level.generator;
    let eta = if let super::tower::ExtensionKind::Exp { eta } = level.kind {
        eta
    } else {
        return Err(IntegrationError::NotImplemented(
            "integrate_exp_tower called with non-Exp level".to_string(),
        ));
    };

    let (rational_part, exp_terms) = decompose_wrt_exp(expr, exp_gen, pool);
    let zero = pool.integer(0_i32);

    // Integrate the θ-free part with the general engine. Its derivation steps are
    // merged only on success.
    let int_rational = if is_zero(rational_part, pool) {
        zero
    } else {
        let mut inner_log = DerivationLog::new();
        let antiderivative =
            crate::integrate::engine::integrate_raw(rational_part, var, pool, &mut inner_log)?;
        *log = log.clone().merge(inner_log);
        antiderivative
    };
    if exp_terms.is_empty() {
        return Ok(int_rational);
    }

    let deta_expr = differentiate_poly(eta, var, pool)?;
    let deta = match expr_to_qpoly(deta_expr, var, pool) {
        Some(poly) => poly,
        None => {
            // η' is not a polynomial. Use the rational-η' solver when η' is a rational
            // function, and the transcendental-η' heuristics otherwise.
            return match expr_to_qrational(deta_expr, var, pool) {
                Some((deta_num, deta_den)) => integrate_exp_tower_rational_eta(
                    int_rational,
                    &exp_terms,
                    eta,
                    exp_gen,
                    deta_num,
                    deta_den,
                    var,
                    pool,
                    log,
                ),
                None => try_transcendental_eta_v1(
                    int_rational,
                    &exp_terms,
                    eta,
                    exp_gen,
                    deta_expr,
                    var,
                    pool,
                    log,
                ),
            };
        }
    };

    let mut pieces: Vec<ExprId> = Vec::with_capacity(exp_terms.len() + 1);
    if !is_zero(int_rational, pool) {
        pieces.push(int_rational);
    }
    for &(c_expr, k) in &exp_terms {
        let piece = integrate_single_exp_term(
            c_expr, k, &deta, deta_expr, eta, exp_gen, var, pool, log,
        )?;
        pieces.push(piece);
    }

    let raw = match pieces.len() {
        0 => zero,
        1 => pieces[0],
        _ => pool.add(pieces),
    };
    let simplified = simplify(raw, pool);
    let value = simplified.value;
    *log = log.clone().merge(simplified.log);
    log.push(RewriteStep::simple("risch_exp", expr, value));
    Ok(value)
}
/// Integrates the `exp(η)^k` terms when `η' = deta_num / deta_den` is a rational
/// (non-polynomial) function of `var`.
///
/// For each term, the var-free factor `k_const` is split off the coefficient first.
/// The remainder `c` must be a polynomial or a rational function of `var`. The Risch DE
/// `v' + k·η'·v = c` is then solved over Q(x) with `solve_rational_rde_generalized`,
/// and the term integrates to `k_const · v · exp(η)^k`.
///
/// `int_rational` (the already-integrated `exp`-free part) is added to the sum. The
/// sum is simplified and logged as `risch_exp_rational_eta`.
///
/// # Errors
/// Returns `NonElementary` when some RDE has no rational solution, and
/// `NotImplemented` when a coefficient is not rational in `var`.
#[allow(clippy::too_many_arguments)]
fn integrate_exp_tower_rational_eta(
    int_rational: ExprId,
    exp_terms: &[(ExprId, i64)],
    eta: ExprId,
    exp_gen: ExprId,
    deta_num: QPoly,
    deta_den: QPoly,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let zero = pool.integer(0_i32);
    let mut pieces: Vec<ExprId> = Vec::new();
    if !is_zero(int_rational, pool) {
        pieces.push(int_rational);
    }

    for &(c_expr, k) in exp_terms {
        let exp_k_eta = build_exp_k_eta(k, eta, exp_gen, pool);
        let (k_const, c_rest) = split_const_factor(c_expr, var, pool);
        // f = k·η' = (k·deta_num) / deta_den.
        let f_num = poly_scale(&deta_num, &rug::Rational::from(k));

        // Express the right-hand side as num/den. A polynomial gets denominator 1. The
        // rule name records which representation was recognised.
        let (c_num, c_den, rule) = if let Some(c_poly) = expr_to_qpoly(c_rest, var, pool) {
            (c_poly, poly_one(), "risch_exp_rde_rational_eta")
        } else if let Some((num, den)) = expr_to_qrational(c_rest, var, pool) {
            (num, den, "risch_exp_rde_rational_eta_rational_coeff")
        } else {
            return Err(IntegrationError::NotImplemented(format!(
                "coefficient {} of exp(η)^{k} is not a rational function in {}; \
                 algebraic/mixed coefficients with rational exponents are not yet supported",
                pool.display(c_expr),
                pool.display(var),
            )));
        };

        let (v_num, v_den) =
            match solve_rational_rde_generalized(&f_num, &deta_den, &c_num, &c_den) {
                Some(solution) => solution,
                None => {
                    return Err(IntegrationError::NonElementary(format!(
                        "the Risch DE v'(x) + {k}·η'(x)·v(x) = {}(x) has no rational solution; \
                         the integrand ∫ {} · exp(η)^{k} dx is not an elementary function \
                         (η = {})",
                        pool.display(c_expr),
                        pool.display(c_expr),
                        pool.display(eta),
                    )))
                }
            };

        let v_expr = build_rational(&v_num, &v_den, var, pool);
        let core = build_v_times_exp(v_expr, exp_k_eta, pool);
        let result = apply_const(k_const, core, pool);
        log.push(RewriteStep::simple(rule, c_expr, result));
        pieces.push(result);
    }

    let raw = match pieces.len() {
        0 => zero,
        1 => pieces[0],
        _ => pool.add(pieces),
    };
    let simplified = simplify(raw, pool);
    let value = simplified.value;
    *log = log.clone().merge(simplified.log);
    log.push(RewriteStep::simple(
        "risch_exp_rational_eta",
        pool.integer(0_i32),
        value,
    ));
    Ok(value)
}
/// Handles `∫ Σ c_k · exp(η)^k dx` when `η'` is neither polynomial nor rational in
/// `var`, for example when `η'` itself contains an `exp`.
///
/// For each term, `k_const` is the var-free factor of `c_k` and `c_rest` is the rest:
/// * if `c_rest` simplifies to `k·η'`, the antiderivative is `k_const · exp(η)^k`;
/// * else, if `c_is_rational_in_theta` reports `c_rest` as rational but not
///   polynomial in `η'`, the term is `NonElementary`;
/// * otherwise `lower_tower_poly_cascade` decides. If it declines, the term is
///   `NotImplemented`.
///
/// `int_rational` is added to the result. The result is simplified and logged as
/// `risch_exp_transcendental_eta`.
#[allow(clippy::too_many_arguments)]
fn try_transcendental_eta_v1(
    int_rational: ExprId,
    exp_terms: &[(ExprId, i64)],
    eta: ExprId,
    exp_gen: ExprId,
    deta_expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    let zero = pool.integer(0_i32);
    let deta_simplified = simplify(deta_expr, pool).value;
    let mut pieces: Vec<ExprId> = Vec::new();
    if !is_zero(int_rational, pool) {
        pieces.push(int_rational);
    }

    for &(c_expr, k) in exp_terms {
        let (k_const, c_rest) = split_const_factor(c_expr, var, pool);
        let c_rest_simplified = simplify(c_rest, pool).value;
        // d/dx exp(η)^k has coefficient k·η'.
        let k_deta_simplified = match k {
            1 => deta_simplified,
            _ => simplify(pool.mul(vec![pool.integer(k as i32), deta_expr]), pool).value,
        };

        if c_rest_simplified == k_deta_simplified {
            // The integrand is an exact derivative.
            let exp_k_eta = build_exp_k_eta(k, eta, exp_gen, pool);
            let result = apply_const(k_const, exp_k_eta, pool);
            log.push(RewriteStep::simple("risch_exp_nested_v1", c_expr, result));
            pieces.push(result);
            continue;
        }

        if c_is_rational_in_theta(c_rest, deta_simplified, pool) {
            return Err(IntegrationError::NonElementary(format!(
                "∫ {} · exp(kη) dx: coefficient is rational (not polynomial) \
                 in the inner exp generator; non-elementary by Hermite \
                 reduction / pole-order argument (Bronstein §6.2)",
                pool.display(c_expr),
            )));
        }

        let exp_k_eta = build_exp_k_eta(k, eta, exp_gen, pool);
        let cascade = lower_tower_poly_cascade(
            c_rest,
            k,
            deta_simplified,
            exp_k_eta,
            k_const,
            c_expr,
            var,
            pool,
            log,
        );
        match cascade {
            Some(outcome) => pieces.push(outcome?),
            None => {
                return Err(IntegrationError::NotImplemented(format!(
                    "exponent derivative η'(x) = {} is transcendental and \
                     the lower-tower polynomial cascade did not apply for {}",
                    pool.display(deta_expr),
                    pool.display(c_expr),
                )))
            }
        }
    }

    let raw = match pieces.len() {
        0 => zero,
        1 => pieces[0],
        _ => pool.add(pieces),
    };
    let simplified = simplify(raw, pool);
    let value = simplified.value;
    *log = log.clone().merge(simplified.log);
    log.push(RewriteStep::simple(
        "risch_exp_transcendental_eta",
        pool.integer(0_i32),
        value,
    ));
    Ok(value)
}
fn c_is_rational_in_theta(c_rest: ExprId, theta_inner: ExprId, pool: &ExprPool) -> bool {
use super::tower::decompose_wrt_exp;
let (c0, exp_terms) = decompose_wrt_exp(c_rest, theta_inner, pool);
if contains_subexpr(c0, theta_inner, pool) {
return true;
}
for (coeff, j) in &exp_terms {
if *j < 0 {
return true;
}
if contains_subexpr(*coeff, theta_inner, pool) {
return true;
}
}
false
}
/// Lower-tower polynomial cascade for one term `∫ c · exp(kη) dx`, used when `η'` is
/// itself an inner exponential generator `θ = theta_inner`.
///
/// Preconditions checked here: `theta_inner` is an `exp(..)` node, and its simplified
/// derivative is `theta_inner` itself (`θ' = θ`). NOTE(review): this function assumes,
/// but does not check, that `theta_inner` is the simplified `η'`. Under that
/// assumption the condition `(v · exp(kη))' = c_rest · exp(kη)` reads
/// `v' + k·θ·v = c_rest`.
///
/// `c_rest` is decomposed as `Σ_{j=0}^{N} c_j θ^j` and `v` is sought as
/// `Σ_{j=0}^{N-1} v_j θ^j`. Since `(v_j θ^j)' = (v_j' + j·v_j) θ^j`, matching powers
/// of `θ` gives:
/// * `θ^N`:              `k·v_{N-1} = c_N`
/// * `θ^j`, `1 ≤ j < N`: `v_j' + j·v_j + k·v_{j-1} = c_j`
/// * `θ^0`:              `v_0' = c_0`
/// The first two equations determine every `v_j` top-down. The last one is used as a
/// consistency check.
///
/// Returns:
/// * `None` if a precondition fails, a derivative cannot be computed, or the largest
///   inner power `N` is not positive;
/// * `Some(Err(NonElementary))` if `c_rest` has no inner-`θ` power or the `θ^0` check
///   fails;
/// * otherwise `Some(Ok(k_const · v · exp_k_eta))`, also logged as
///   `risch_exp_lower_tower`.
#[allow(clippy::too_many_arguments)]
fn lower_tower_poly_cascade(
c_rest: ExprId,
k: i64,
theta_inner: ExprId, exp_k_eta: ExprId,
k_const: ExprId,
c_expr: ExprId,
var: ExprId,
pool: &ExprPool,
log: &mut DerivationLog,
) -> Option<Result<ExprId, IntegrationError>> {
use super::tower::decompose_wrt_exp;
// Require theta_inner to have the shape exp(u).
let inner_eta = match pool.get(theta_inner) {
crate::kernel::ExprData::Func { ref name, ref args }
if name == "exp" && args.len() == 1 =>
{
args[0]
}
_ => return None, };
// Require θ' = θ (after simplification).
let d_inner = match crate::diff::diff(theta_inner, var, pool) {
Ok(d) => simplify(d.value, pool).value,
Err(_) => return None,
};
if d_inner != theta_inner {
return None;
}
// inner_eta was only needed to confirm the exp(..) shape above.
let _ = inner_eta;
let (c0, exp_terms) = decompose_wrt_exp(c_rest, theta_inner, pool);
if exp_terms.is_empty() {
return Some(Err(IntegrationError::NonElementary(format!(
"∫ {} · exp(kη) dx: coefficient has no inner-tower exp factor; \
non-elementary by degree bound (Bronstein §5)",
pool.display(c_expr),
))));
}
// N = highest power of θ in c_rest; v has degree N - 1 in θ.
let cap_n = exp_terms.iter().map(|(_, j)| *j).max().unwrap_or(0);
if cap_n <= 0 {
return None; }
let cap_n = cap_n as usize;
let zero = pool.integer(0_i32);
// Dense coefficient vector c_0..=c_N. Repeated powers are summed. Powers j < 1 from
// the decomposition are skipped here (NOTE(review): the caller is expected to have
// rejected negative powers already).
let mut c_coeffs: Vec<ExprId> = vec![zero; cap_n + 1];
c_coeffs[0] = c0;
for (coeff, j) in &exp_terms {
let j = *j;
if j >= 1 && (j as usize) <= cap_n {
let old = c_coeffs[j as usize];
let combined = if is_zero(old, pool) {
*coeff
} else {
pool.add(vec![old, *coeff])
};
c_coeffs[j as usize] = simplify(combined, pool).value;
}
}
// Top equation: v_{N-1} = c_N / k.
// NOTE(review): `k as i32` truncates if |k| exceeds the i32 range.
let mut v_coeffs: Vec<ExprId> = vec![zero; cap_n];
let c_top = simplify(c_coeffs[cap_n], pool).value;
v_coeffs[cap_n - 1] = if k == 1 {
c_top
} else {
let k_inv = pool.pow(pool.integer(k as i32), pool.integer(-1_i32));
simplify(pool.mul(vec![c_top, k_inv]), pool).value
};
// Descend: v_{j-1} = (c_j - v_j' - j·v_j) / k for j = N-1, ..., 1.
for j in (1..cap_n).rev() {
let vj = v_coeffs[j];
let dvj = match crate::diff::diff(vj, var, pool) {
Ok(d) => simplify(d.value, pool).value,
Err(_) => return None,
};
let j_vj = simplify(pool.mul(vec![pool.integer(j as i32), vj]), pool).value;
let cj = simplify(c_coeffs[j], pool).value;
let neg1 = pool.integer(-1_i32);
let num = simplify(
pool.add(vec![
cj,
pool.mul(vec![neg1, dvj]),
pool.mul(vec![neg1, j_vj]),
]),
pool,
)
.value;
v_coeffs[j - 1] = if k == 1 {
num
} else {
let k_inv = pool.pow(pool.integer(k as i32), pool.integer(-1_i32));
simplify(pool.mul(vec![num, k_inv]), pool).value
};
}
// Consistency check for θ^0: v_0' - c_0 must simplify to zero.
let dv0 = match crate::diff::diff(v_coeffs[0], var, pool) {
Ok(d) => simplify(d.value, pool).value,
Err(_) => return None,
};
let residual = simplify(
pool.add(vec![dv0, pool.mul(vec![pool.integer(-1_i32), c_coeffs[0]])]),
pool,
)
.value;
if !is_zero(residual, pool) {
return Some(Err(IntegrationError::NonElementary(format!(
"∫ {} · exp(kη) dx: lower-tower cascade consistency check failed; \
non-elementary by denominator bound (Bronstein §6.2)",
pool.display(c_expr),
))));
}
// Assemble v = Σ v_j θ^j, skipping zero coefficients and unit factors.
let mut v_terms: Vec<ExprId> = Vec::new();
for (j, &vj) in v_coeffs.iter().enumerate() {
let vj_s = simplify(vj, pool).value;
if is_zero(vj_s, pool) {
continue;
}
let theta_j = match j {
0 => vj_s,
1 => {
if is_one(vj_s, pool) {
theta_inner
} else {
pool.mul(vec![vj_s, theta_inner])
}
}
_ => {
let theta_pow = pool.pow(theta_inner, pool.integer(j as i32));
if is_one(vj_s, pool) {
theta_pow
} else {
pool.mul(vec![vj_s, theta_pow])
}
}
};
v_terms.push(theta_j);
}
let v_expr = match v_terms.len() {
0 => zero,
1 => v_terms[0],
_ => pool.add(v_terms),
};
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple("risch_exp_lower_tower", c_expr, result));
Some(Ok(result))
}
/// Integrates one term `c_expr · exp(η)^k` when `η'` is a polynomial in `var`.
///
/// `deta` holds that polynomial's coefficients and `deta_expr` its expression form.
/// The var-free factor `k_const` is split off `c_expr`. The Risch DE
/// `v' + k·η'·v = c_rest` is then solved, and the result is `k_const · v · exp(η)^k`.
///
/// Strategies are tried in this order:
/// 1. `c_rest` polynomial over Q → `solve_poly_rde`;
/// 2. `c_rest` rational over Q → `solve_rational_rde`;
/// 3. exactly one distinct `sqrt(d)` (`detect_sqrt_field`) → polynomial, then rational,
///    RDE over Q(√d);
/// 4. a supported algebraic extension (`detect_algebraic_extension`) → polynomial, then
///    rational, RDE over that number field;
/// 5. the `try_*` helpers (poly-in-log, sqrt, radical, compositum, nested radical), in
///    that order. Each may decline with `None`.
///
/// Once a strategy's conversion succeeds, its RDE result is final: no solution gives
/// `NonElementary` rather than falling through. In steps 3–4, a detected field whose
/// conversions both fail does fall through to the next step. If nothing applies,
/// `NotImplemented` is returned.
#[allow(clippy::too_many_arguments)]
fn integrate_single_exp_term(
c_expr: ExprId,
k: i64,
deta: &[rug::Rational], deta_expr: ExprId, eta: ExprId, exp_gen: ExprId, var: ExprId,
pool: &ExprPool,
log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
let exp_k_eta = build_exp_k_eta(k, eta, exp_gen, pool);
let (k_const, c_rest) = split_const_factor(c_expr, var, pool);
let non_elementary = || {
IntegrationError::NonElementary(format!(
"the Risch DE v'(x) + {}·({}(x))·v(x) = {}(x) has no rational solution;\n\
the integrand ∫ {} · exp(η)^{} dx is not an elementary function\n\
(η = {})",
k,
pool.display(deta_expr),
pool.display(c_expr),
pool.display(c_expr),
k,
pool.display(eta),
))
};
// Strategy 1: polynomial coefficient over Q.
if let Some(c_poly) = expr_to_qpoly(c_rest, var, pool) {
return match solve_poly_rde(k, deta, &c_poly) {
Some(v_poly) => {
let v_expr = qpoly_to_expr(&v_poly, var, pool);
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple("risch_exp_rde", c_expr, result));
Ok(result)
}
None => Err(non_elementary()),
};
}
// Strategy 2: rational coefficient over Q; the RDE coefficient is f = k·η'.
if let Some((c_num, c_den)) = expr_to_qrational(c_rest, var, pool) {
let f = poly_scale(&deta.to_vec(), &rug::Rational::from(k));
return match solve_rational_rde(&f, &c_num, &c_den) {
Some((v_num, v_den)) => {
let v_expr = build_rational(&v_num, &v_den, var, pool);
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple(
"risch_exp_rde_rational",
c_expr,
result,
));
Ok(result)
}
None => Err(non_elementary()),
};
}
// Strategy 3: a single quadratic field Q(√d), modelled as Q[t]/(t² - d).
if let Some((d, sqrt_expr)) = detect_sqrt_field(c_rest, pool) {
let field = NumberField::new(vec![
rug::Rational::from(-d),
rug::Rational::from(0),
rug::Rational::from(1),
]);
let deta_k: KPoly = deta.iter().map(|r| field.from_rational(r)).collect();
if let Some(c_kpoly) = expr_to_kpoly(c_rest, var, sqrt_expr, &field, pool) {
return match solve_poly_rde_k(&field, k, &deta_k, &c_kpoly) {
Some(v) => {
let v_expr = kpoly_to_expr_alg(&v, var, sqrt_expr, pool);
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple(
"risch_exp_rde_algebraic",
c_expr,
result,
));
Ok(result)
}
None => Err(non_elementary()),
};
}
if let Some((c_num, c_den)) = expr_to_krational(c_rest, var, sqrt_expr, &field, pool) {
let f_k = field.kpoly_scale(&deta_k, &field.from_int(k));
return match solve_rational_rde_k(&field, &f_k, &c_num, &c_den) {
Some((v_num, v_den)) => {
let v_expr = build_krational(&v_num, &v_den, var, sqrt_expr, pool);
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple(
"risch_exp_rde_algebraic_rational",
c_expr,
result,
));
Ok(result)
}
None => Err(non_elementary()),
};
}
// Neither conversion applied: fall through to the general extension detector.
}
// Strategy 4: general supported algebraic extension (sqrt, two-sqrt compositum,
// n-th root).
if let Some(ext) = detect_algebraic_extension(c_rest, pool) {
let (field, gens) = build_field_and_gens(&ext);
let deta_k: KPoly = deta.iter().map(|r| field.from_rational(r)).collect();
if let Some(c_kpoly) = expr_to_kpoly_general(c_rest, var, &gens, &field, pool) {
return match solve_poly_rde_k(&field, k, &deta_k, &c_kpoly) {
Some(v) => {
let v_expr = kpoly_to_expr_ext(&v, var, &ext, pool);
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple(
"risch_exp_rde_algebraic_ext",
c_expr,
result,
));
Ok(result)
}
None => Err(non_elementary()),
};
}
if let Some((c_num, c_den)) = expr_to_krational_general(c_rest, var, &gens, &field, pool) {
let f_k = field.kpoly_scale(&deta_k, &field.from_int(k));
return match solve_rational_rde_k(&field, &f_k, &c_num, &c_den) {
Some((v_num, v_den)) => {
let v_expr = build_krational_ext(&v_num, &v_den, var, &ext, pool);
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple(
"risch_exp_rde_algebraic_rational_ext",
c_expr,
result,
));
Ok(result)
}
None => Err(non_elementary()),
};
}
}
// Strategy 5: specialised helpers. Each returns None when it does not apply.
if let Some(result) =
try_poly_in_log_rde(c_rest, k, deta, exp_k_eta, k_const, c_expr, var, pool, log)
{
return result;
}
if let Some(result) =
try_sqrt_poly_rde(c_rest, k, deta, exp_k_eta, k_const, c_expr, var, pool, log)
{
return result;
}
if let Some(result) =
try_radical_poly_rde(c_rest, k, deta, exp_k_eta, k_const, c_expr, var, pool, log)
{
return result;
}
if let Some(result) =
try_compositum_poly_rde(c_rest, k, deta, exp_k_eta, k_const, c_expr, var, pool, log)
{
return result;
}
if let Some(result) =
try_nested_radical_poly_rde(c_rest, k, deta, exp_k_eta, k_const, c_expr, var, pool, log)
{
return result;
}
Err(IntegrationError::NotImplemented(format!(
"coefficient {} of exp(η)^{} is not a polynomial or rational function over \
a supported algebraic extension; mixed/nested generators are not yet supported",
pool.display(c_expr),
k
)))
}
#[allow(clippy::too_many_arguments)]
fn try_poly_in_log_rde(
c_rest: ExprId,
k: i64,
deta: &[rug::Rational], exp_k_eta: ExprId,
k_const: ExprId,
c_expr: ExprId, var: ExprId,
pool: &ExprPool,
log: &mut DerivationLog,
) -> Option<Result<ExprId, IntegrationError>> {
use super::rational_rde::{expr_to_qrational, solve_rational_rde};
use super::tower::{decompose_as_log_poly, find_generators};
let gens = find_generators(c_rest, var, pool);
let log_gens: Vec<_> = gens.iter().filter(|g| g.is_log()).collect();
if log_gens.len() != 1 || gens.iter().any(|g| g.is_exp()) {
return None; }
let log_level = log_gens[0];
let theta = log_level.generator; let h = log_level.argument();
let c_coeffs = decompose_as_log_poly(c_rest, theta, pool)?;
let n = c_coeffs.len().saturating_sub(1);
let h_prime = match crate::diff::diff(h, var, pool) {
Ok(d) => crate::simplify::engine::simplify(d.value, pool).value,
Err(_) => return None,
};
let h_prime_over_h = {
let raw = pool.mul(vec![h_prime, pool.pow(h, pool.integer(-1_i32))]);
crate::simplify::engine::simplify(raw, pool).value
};
let (hp_num, hp_den) = expr_to_qrational(h_prime_over_h, var, pool)?;
let f: Vec<rug::Rational> = deta
.iter()
.map(|r| r.clone() * rug::Rational::from(k))
.collect();
let mut a: Vec<Option<(Vec<rug::Rational>, Vec<rug::Rational>)>> = vec![None; n + 1];
let mut rhs_coeffs: Vec<ExprId> = c_coeffs;
for j in (0..=n).rev() {
let cj_expr = crate::simplify::engine::simplify(rhs_coeffs[j], pool).value;
let (cj_num, cj_den) = match expr_to_qrational(cj_expr, var, pool) {
Some(p) => p,
None => {
return Some(Err(IntegrationError::NotImplemented(format!(
"poly-in-log RDE: coefficient of θ^{j} is not rational in {}",
pool.display(var)
))))
}
};
match solve_rational_rde(&f, &cj_num, &cj_den) {
Some(sol) => {
a[j] = Some(sol);
if j > 0 {
if let Some((aj_num, aj_den)) = &a[j] {
use super::poly_rde::{poly_mul, poly_scale};
let j_rat = rug::Rational::from(j as i64);
let corr_num = poly_scale(&poly_mul(&hp_num, aj_num), &j_rat);
let corr_den = poly_mul(&hp_den, aj_den);
let corr_expr = {
let cn_expr = qpoly_to_expr(&corr_num, var, pool);
let cd_expr = qpoly_to_expr(&corr_den, var, pool);
pool.mul(vec![cn_expr, pool.pow(cd_expr, pool.integer(-1_i32))])
};
let old = rhs_coeffs[j - 1];
let neg_corr = pool.mul(vec![pool.integer(-1_i32), corr_expr]);
rhs_coeffs[j - 1] = pool.add(vec![old, neg_corr]);
}
}
}
None => {
return Some(Err(IntegrationError::NonElementary(format!(
"poly-in-log RDE: no rational solution at degree {j} for \
∫ {} · exp(η)^{k} dx",
pool.display(c_expr)
))));
}
}
}
let mut v_terms: Vec<ExprId> = Vec::new();
for (j, sol) in a.iter().enumerate() {
if let Some((vn, vd)) = sol {
let vn_t = trim(vn.clone());
let vd_t = trim(vd.clone());
if vn_t.is_empty() {
continue; }
let vn_expr = qpoly_to_expr(&vn_t, var, pool);
let coeff_expr = if vd_t == poly_one() {
vn_expr
} else {
let vd_expr = qpoly_to_expr(&vd_t, var, pool);
pool.mul(vec![vn_expr, pool.pow(vd_expr, pool.integer(-1_i32))])
};
let theta_j = match j {
0 => coeff_expr,
1 => pool.mul(vec![coeff_expr, theta]),
_ => pool.mul(vec![coeff_expr, pool.pow(theta, pool.integer(j as i32))]),
};
v_terms.push(theta_j);
}
}
let v_expr = match v_terms.len() {
0 => pool.integer(0_i32),
1 => v_terms[0],
_ => pool.add(v_terms),
};
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple("risch_exp_poly_in_log", c_expr, result));
Some(Ok(result))
}
/// Forms the product `v_expr · exp_k_eta`. The factor is omitted when `v_expr` is one.
fn build_v_times_exp(v_expr: ExprId, exp_k_eta: ExprId, pool: &ExprPool) -> ExprId {
    if !is_one(v_expr, pool) {
        return pool.mul(vec![v_expr, exp_k_eta]);
    }
    exp_k_eta
}
/// An algebraic number field detected in an integrand coefficient, together with the
/// expression nodes that denote its generators. `build_field_and_gens` builds the
/// concrete `NumberField` for each variant.
#[derive(Debug, Clone)]
pub(super) enum AlgebraicExtension {
/// `Q(√d)`. `sqrt_expr` is the `sqrt(d)` node. Represented as `Q[t]/(t² − d)` with
/// `√d ↦ t`.
SingleSqrt { d: i64, sqrt_expr: ExprId },
/// `Q(√a, √b)`. `sqrt_a` and `sqrt_b` are the corresponding `sqrt` nodes, with
/// `a < b` as produced by `detect_algebraic_extension`. Represented with the element
/// `t = √a + √b` and defining polynomial `t⁴ − 2(a+b)t² + (a−b)²`.
CompositumTwoSqrts {
a: i64,
b: i64,
sqrt_a: ExprId,
sqrt_b: ExprId,
},
/// `Q(n^(1/m))`. `root_expr` is the `cbrt(n)` node (with `m = 3`) or the `n^(1/m)`
/// node. Represented as `Q[t]/(t^m − n)` with the root mapped to `t`.
NthRoot { n: i64, m: u32, root_expr: ExprId },
}
pub(super) fn detect_algebraic_extension(
expr: ExprId,
pool: &ExprPool,
) -> Option<AlgebraicExtension> {
let mut sqrts: Vec<(i64, ExprId)> = Vec::new();
let mut nth_roots: Vec<(i64, u32, ExprId)> = Vec::new();
scan_algebraic_gens(expr, pool, &mut sqrts, &mut nth_roots);
let mut dsqrts: Vec<(i64, ExprId)> = Vec::new();
for (d, e) in sqrts {
if !dsqrts.iter().any(|(dd, _)| *dd == d) {
dsqrts.push((d, e));
}
}
let mut droots: Vec<(i64, u32, ExprId)> = Vec::new();
for (n, m, e) in nth_roots {
if !droots.iter().any(|(nn, mm, _)| *nn == n && *mm == m) {
droots.push((n, m, e));
}
}
match (dsqrts.len(), droots.len()) {
(0, 0) => None,
(1, 0) => {
let (d, sqrt_expr) = dsqrts[0];
Some(AlgebraicExtension::SingleSqrt { d, sqrt_expr })
}
(2, 0) => {
let (mut a, mut sqrt_a) = dsqrts[0];
let (mut b, mut sqrt_b) = dsqrts[1];
if a > b {
std::mem::swap(&mut a, &mut b);
std::mem::swap(&mut sqrt_a, &mut sqrt_b);
}
Some(AlgebraicExtension::CompositumTwoSqrts {
a,
b,
sqrt_a,
sqrt_b,
})
}
(0, 1) => {
let (n, m, root_expr) = droots[0];
Some(AlgebraicExtension::NthRoot { n, m, root_expr })
}
_ => None, }
}
/// Recursively collects integer-radicand radicals from `expr`.
///
/// * `sqrt(d)` with integer `d > 1` that is not a perfect square is pushed onto
///   `sqrts` as `(d, node)`.
/// * `cbrt(d)`, and `d^(1/m)` with `m >= 2`, are pushed onto `nth_roots` as
///   `(d, m, node)` when `d > 1` is not a perfect `m`-th power.
/// * The argument of a `sqrt`/`cbrt` call is never scanned. Other unary functions are
///   scanned through. Functions of any other arity are not descended into.
/// * A `Pow` node that is not a recognised root has both its base and its exponent
///   scanned. `Add`/`Mul` nodes have every operand scanned.
fn scan_algebraic_gens(
    expr: ExprId,
    pool: &ExprPool,
    sqrts: &mut Vec<(i64, ExprId)>,
    nth_roots: &mut Vec<(i64, u32, ExprId)>,
) {
    use crate::kernel::ExprData;

    // Value of an integer literal that fits in i64, if `id` is one.
    let small_int = |id: ExprId| -> Option<i64> {
        match pool.get(id) {
            ExprData::Integer(n) => n.0.to_i64(),
            _ => None,
        }
    };

    match pool.get(expr) {
        ExprData::Func { ref name, ref args } if args.len() == 1 => {
            let arg = args[0];
            match name.as_str() {
                "sqrt" => {
                    if let Some(d) = small_int(arg).filter(|&d| d > 1 && !is_perfect_square(d))
                    {
                        sqrts.push((d, expr));
                    }
                }
                "cbrt" => {
                    if let Some(d) =
                        small_int(arg).filter(|&d| d > 1 && !is_perfect_mth_power(d, 3))
                    {
                        nth_roots.push((d, 3, expr));
                    }
                }
                _ => scan_algebraic_gens(arg, pool, sqrts, nth_roots),
            }
        }
        ExprData::Pow { base, exp } => {
            // Recognise d^(1/m): integer base and a rational exponent with numerator 1.
            let unit_fraction_root = match (pool.get(base), pool.get(exp)) {
                (ExprData::Integer(n_int), ExprData::Rational(r)) if *r.0.numer() == 1 => {
                    n_int.0.to_i64().zip(r.0.denom().to_u32())
                }
                _ => None,
            };
            let genuine_root = unit_fraction_root
                .filter(|&(d, m)| m >= 2 && d > 1 && !is_perfect_mth_power(d, m));
            if let Some((d, m)) = genuine_root {
                nth_roots.push((d, m, expr));
                return;
            }
            scan_algebraic_gens(base, pool, sqrts, nth_roots);
            scan_algebraic_gens(exp, pool, sqrts, nth_roots);
        }
        ExprData::Add(args) | ExprData::Mul(args) => {
            args.iter()
                .for_each(|&operand| scan_algebraic_gens(operand, pool, sqrts, nth_roots));
        }
        _ => {}
    }
}
/// Returns `true` when `d` is the `m`-th power of a positive integer.
///
/// Non-positive `d` and `m == 0` are never perfect powers. Every positive `d` is a
/// perfect first power. The floating-point root estimate is corrected by checking its
/// neighbours with overflow-checked exponentiation. The previous `k.pow(m)` panicked in
/// debug builds (and wrapped in release) for large `d` or `m`, e.g. `(2, 64)` or
/// `(i64::MAX, 2)`.
fn is_perfect_mth_power(d: i64, m: u32) -> bool {
    if d <= 0 || m == 0 {
        return false;
    }
    if m == 1 {
        return true;
    }
    let root = (d as f64).powf(1.0 / f64::from(m)).round() as i64;
    (root - 1..=root + 1).any(|k| k > 0 && k.checked_pow(m) == Some(d))
}
pub(super) fn build_field_and_gens(
ext: &AlgebraicExtension,
) -> (NumberField, Vec<(ExprId, KElem)>) {
match ext {
AlgebraicExtension::SingleSqrt { d, sqrt_expr } => {
let field = NumberField::new(vec![
rug::Rational::from(-d),
rug::Rational::from(0),
rug::Rational::from(1),
]);
let kelem = vec![rug::Rational::from(0), rug::Rational::from(1)];
(field, vec![(*sqrt_expr, kelem)])
}
AlgebraicExtension::CompositumTwoSqrts {
a,
b,
sqrt_a,
sqrt_b,
} => {
let a = *a;
let b = *b;
let field = NumberField::new(vec![
rug::Rational::from((a - b) * (a - b)),
rug::Rational::from(0),
rug::Rational::from(-2 * (a + b)),
rug::Rational::from(0),
rug::Rational::from(1),
]);
let two_ab = rug::Rational::from(2 * (a - b));
let kelem_a = vec![
rug::Rational::from(0),
rug::Rational::from(3 * a + b) / two_ab.clone(),
rug::Rational::from(0),
rug::Rational::from(-1) / two_ab.clone(),
];
let kelem_b = vec![
rug::Rational::from(0),
rug::Rational::from(-(a + 3 * b)) / two_ab.clone(),
rug::Rational::from(0),
rug::Rational::from(1) / two_ab,
];
(field, vec![(*sqrt_a, kelem_a), (*sqrt_b, kelem_b)])
}
AlgebraicExtension::NthRoot { n, m, root_expr } => {
let n = *n;
let m = *m;
let mut min_poly = vec![rug::Rational::from(0); m as usize + 1];
min_poly[0] = rug::Rational::from(-n);
min_poly[m as usize] = rug::Rational::from(1);
let field = NumberField::new(min_poly);
let mut kelem = vec![rug::Rational::from(0); m as usize];
kelem[1] = rug::Rational::from(1);
(field, vec![(*root_expr, kelem)])
}
}
}
/// Converts `expr` into a polynomial in `var` whose coefficients lie in `field`.
///
/// `var` becomes `X`. An expression equal to one of the `gens` nodes becomes that
/// generator's field element; the first match wins.
///
/// Supported forms are integer literals (fitting in `i64`), rational literals, sums,
/// products and integer powers. A negative power is accepted only when its base
/// converts to a degree-0 polynomial whose constant is invertible in `field`.
/// Anything else yields `None`.
fn expr_to_kpoly_general(
    expr: ExprId,
    var: ExprId,
    gens: &[(ExprId, KElem)],
    field: &NumberField,
    pool: &ExprPool,
) -> Option<KPoly> {
    use crate::kernel::ExprData;
    if expr == var {
        return Some(vec![NumberField::k_zero(), field.from_int(1)]);
    }
    if let Some((_, kelem)) = gens.iter().find(|(gen_expr, _)| *gen_expr == expr) {
        return Some(vec![kelem.clone()]);
    }
    match pool.get(expr) {
        ExprData::Integer(n) => {
            let value = rug::Rational::from(n.0.to_i64()?);
            Some(vec![field.from_rational(&value)])
        }
        ExprData::Rational(r) => Some(vec![field.from_rational(&r.0)]),
        ExprData::Add(args) => args.iter().try_fold(Vec::new(), |sum: KPoly, a| {
            let term = expr_to_kpoly_general(*a, var, gens, field, pool)?;
            Some(field.kpoly_add(&sum, &term))
        }),
        ExprData::Mul(args) => {
            args.iter()
                .try_fold(vec![field.from_int(1)], |product: KPoly, a| {
                    let factor = expr_to_kpoly_general(*a, var, gens, field, pool)?;
                    Some(field.kpoly_mul(&product, &factor))
                })
        }
        ExprData::Pow { base, exp } => {
            let n = match pool.get(exp) {
                ExprData::Integer(n) => n.0.to_i64()?,
                _ => return None,
            };
            let b = expr_to_kpoly_general(base, var, gens, field, pool)?;
            if n < 0 {
                // Only constants can be inverted inside a polynomial ring.
                if NumberField::kdeg(&b) != 0 {
                    return None;
                }
                let inv = field.inv(&b[0])?;
                let mut acc = field.from_int(1);
                for _ in 0..(-n) {
                    acc = field.mul(&acc, &inv);
                }
                return Some(vec![acc]);
            }
            let mut acc: KPoly = vec![field.from_int(1)];
            for _ in 0..n {
                acc = field.kpoly_mul(&acc, &b);
            }
            Some(acc)
        }
        _ => None,
    }
}
/// Converts `expr` into a rational function `(num, den)` in `var` over `field`.
///
/// `var` becomes `X`. An expression equal to one of the `gens` nodes becomes that
/// generator's field element. Supported forms are integer literals (fitting in
/// `i64`), rational literals, sums, products and integer powers, where a negative power
/// swaps numerator and denominator. Anything else yields `None`.
///
/// Exponents whose magnitude exceeds `u32::MAX` are rejected (`None`). The previous
/// `as u32` casts silently truncated them (e.g. `x^(2^32)` became `x^0`), and `-n`
/// overflowed for `i64::MIN`.
pub(super) fn expr_to_krational_general(
    expr: ExprId,
    var: ExprId,
    gens: &[(ExprId, KElem)],
    field: &NumberField,
    pool: &ExprPool,
) -> Option<(KPoly, KPoly)> {
    use crate::kernel::ExprData;
    let one: KPoly = vec![field.from_int(1)];
    if expr == var {
        return Some((vec![NumberField::k_zero(), field.from_int(1)], one));
    }
    for (gen_expr, kelem) in gens {
        if expr == *gen_expr {
            return Some((vec![kelem.clone()], one.clone()));
        }
    }
    match pool.get(expr) {
        ExprData::Integer(n) => Some((
            vec![field.from_rational(&rug::Rational::from(n.0.to_i64()?))],
            one,
        )),
        ExprData::Rational(r) => Some((vec![field.from_rational(&r.0)], one)),
        ExprData::Add(args) => {
            let mut acc: (KPoly, KPoly) = (Vec::new(), one.clone());
            for a in &args {
                let term = expr_to_krational_general(*a, var, gens, field, pool)?;
                acc = krat_add(field, &acc, &term);
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc: (KPoly, KPoly) = (one.clone(), one.clone());
            for a in &args {
                let factor = expr_to_krational_general(*a, var, gens, field, pool)?;
                acc = krat_mul(field, &acc, &factor);
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => {
            let n = match pool.get(exp) {
                ExprData::Integer(n) => n.0.to_i64()?,
                _ => return None,
            };
            // `kpoly_pow` takes a u32 exponent: refuse magnitudes that do not fit rather
            // than truncating them. `unsigned_abs` cannot overflow, unlike `-n`.
            let magnitude = n.unsigned_abs();
            if magnitude > u64::from(u32::MAX) {
                return None;
            }
            let e = magnitude as u32;
            let (bn, bd) = expr_to_krational_general(base, var, gens, field, pool)?;
            // A negative exponent inverts the base: swap numerator and denominator.
            let (top, bottom) = if n >= 0 { (bn, bd) } else { (bd, bn) };
            Some((field.kpoly_pow(&top, e), field.kpoly_pow(&bottom, e)))
        }
        _ => None,
    }
}
/// Renders a field element of `ext` as an expression in the extension's generator nodes.
///
/// * `SingleSqrt`: delegates to `kelem_to_expr`.
/// * `NthRoot`: `Σ c_i · root^i` over the nonzero coefficients.
/// * `CompositumTwoSqrts`: the element `c0 + c1·t + c2·t² + c3·t³`, with
///   `t = √a + √b`, is rewritten in the basis `1, √a, √b, √a·√b` using
///   `t² = (a+b) + 2√a√b` and `t³ = (a+3b)√a + (3a+b)√b`.
///
/// Unit coefficients are omitted, and an all-zero element renders as `0`.
pub(super) fn kelem_to_expr_ext(e: &KElem, ext: &AlgebraicExtension, pool: &ExprPool) -> ExprId {
    let terms: Vec<ExprId> = match ext {
        AlgebraicExtension::SingleSqrt { sqrt_expr, .. } => {
            return kelem_to_expr(e, *sqrt_expr, pool);
        }
        AlgebraicExtension::NthRoot { root_expr, .. } => e
            .iter()
            .enumerate()
            .filter(|&(_, c)| *c != 0)
            .map(|(i, c)| {
                let c_expr = rational_to_expr(c, pool);
                let power = match i {
                    0 => return c_expr,
                    1 => *root_expr,
                    _ => pool.pow(*root_expr, pool.integer(i as i32)),
                };
                if *c == 1 {
                    power
                } else {
                    pool.mul(vec![c_expr, power])
                }
            })
            .collect(),
        AlgebraicExtension::CompositumTwoSqrts {
            a,
            b,
            sqrt_a,
            sqrt_b,
        } => {
            let (a, b) = (*a, *b);
            let coef = |i: usize| -> rug::Rational {
                e.get(i).cloned().unwrap_or_else(|| rug::Rational::from(0))
            };
            let rational_part = coef(0) + coef(2) * rug::Rational::from(a + b);
            let on_sqrt_a = coef(1) + coef(3) * rug::Rational::from(a + 3 * b);
            let on_sqrt_b = coef(1) + coef(3) * rug::Rational::from(3 * a + b);
            let on_sqrt_ab = coef(2) * rug::Rational::from(2);
            let sqrt_ab = pool.mul(vec![*sqrt_a, *sqrt_b]);

            let mut out: Vec<ExprId> = Vec::new();
            if rational_part != 0 {
                out.push(rational_to_expr(&rational_part, pool));
            }
            for &(coeff, atom) in &[
                (&on_sqrt_a, *sqrt_a),
                (&on_sqrt_b, *sqrt_b),
                (&on_sqrt_ab, sqrt_ab),
            ] {
                if *coeff == 0 {
                    continue;
                }
                out.push(if *coeff == 1 {
                    atom
                } else {
                    pool.mul(vec![rational_to_expr(coeff, pool), atom])
                });
            }
            out
        }
    };
    match terms.len() {
        0 => pool.integer(0_i32),
        1 => terms[0],
        _ => pool.add(terms),
    }
}
/// Renders a polynomial over the field of `ext` as `Σ c_i · var^i`.
///
/// Zero coefficients are skipped. A coefficient that renders as one is dropped from
/// its monomial. The zero polynomial renders as `0`.
pub(super) fn kpoly_to_expr_ext(
    p: &KPoly,
    var: ExprId,
    ext: &AlgebraicExtension,
    pool: &ExprPool,
) -> ExprId {
    let terms: Vec<ExprId> = p
        .iter()
        .enumerate()
        .filter(|&(_, c)| !NumberField::is_zero(c))
        .map(|(i, c)| {
            let coeff = kelem_to_expr_ext(c, ext, pool);
            let monomial = match i {
                0 => return coeff,
                1 => var,
                _ => pool.pow(var, pool.integer(i as i32)),
            };
            if is_one(coeff, pool) {
                monomial
            } else {
                pool.mul(vec![coeff, monomial])
            }
        })
        .collect();
    match terms.len() {
        0 => pool.integer(0_i32),
        1 => terms[0],
        _ => pool.add(terms),
    }
}
/// Renders the rational function `num / den` over the field of `ext`.
///
/// The denominator is omitted when it is a constant polynomial whose (trimmed)
/// constant coefficient is exactly the rational `1`, or when it has no coefficients at
/// all. Otherwise the result is `num · den^(-1)`.
pub(super) fn build_krational_ext(
    num: &KPoly,
    den: &KPoly,
    var: ExprId,
    ext: &AlgebraicExtension,
    pool: &ExprPool,
) -> ExprId {
    let num_expr = kpoly_to_expr_ext(num, var, ext, pool);
    let unit_denominator = NumberField::kdeg(den) <= 0
        && den
            .first()
            .map_or(true, |c0| trim(c0.clone()) == vec![rug::Rational::from(1)]);
    if !unit_denominator {
        let den_expr = kpoly_to_expr_ext(den, var, ext, pool);
        return pool.mul(vec![num_expr, pool.pow(den_expr, pool.integer(-1_i32))]);
    }
    num_expr
}
/// Returns `(d, node)` when every square root found by `scan_sqrt` in `expr` has the
/// same radicand `d`. `node` is the first `sqrt(d)` encountered. Returns `None` when
/// there are no square roots or when two different radicands occur.
fn detect_sqrt_field(expr: ExprId, pool: &ExprPool) -> Option<(i64, ExprId)> {
    let mut found: Vec<(i64, ExprId)> = Vec::new();
    scan_sqrt(expr, pool, &mut found);
    let first = *found.first()?;
    if found.iter().all(|&(d, _)| d == first.0) {
        Some(first)
    } else {
        None
    }
}
/// Recursively collects `sqrt(d)` nodes with integer `d > 1` that are not perfect
/// squares, pushing `(d, node)` onto `out`.
///
/// A unary `sqrt` call is a leaf: its argument is not scanned further. Every other
/// function call, every `Pow` (base and exponent) and every `Add`/`Mul` operand is
/// scanned.
fn scan_sqrt(expr: ExprId, pool: &ExprPool, out: &mut Vec<(i64, ExprId)>) {
    use crate::kernel::ExprData;
    match pool.get(expr) {
        ExprData::Func { ref name, ref args } => {
            if name == "sqrt" && args.len() == 1 {
                let radicand = match pool.get(args[0]) {
                    ExprData::Integer(n) => n.0.to_i64(),
                    _ => None,
                };
                if let Some(d) = radicand {
                    if d > 1 && !is_perfect_square(d) {
                        out.push((d, expr));
                    }
                }
            } else {
                args.iter().for_each(|&a| scan_sqrt(a, pool, out));
            }
        }
        ExprData::Add(args) | ExprData::Mul(args) => {
            args.iter().for_each(|&a| scan_sqrt(a, pool, out));
        }
        ExprData::Pow { base, exp } => {
            scan_sqrt(base, pool, out);
            scan_sqrt(exp, pool, out);
        }
        _ => {}
    }
}
/// Returns `true` when `d` is the square of a non-negative integer.
///
/// Negative `d` is never a perfect square. The floating-point root estimate is
/// corrected by checking its neighbours with overflow-checked multiplication. The
/// previous `c * c` overflowed (panicking in debug builds) for `d` near `i64::MAX`.
fn is_perfect_square(d: i64) -> bool {
    if d < 0 {
        return false;
    }
    let r = (d as f64).sqrt() as i64;
    (r - 1..=r + 1).any(|c| c >= 0 && c.checked_mul(c) == Some(d))
}
/// Converts `expr` into a polynomial in `var` with coefficients in the quadratic
/// number field `field`. The node `sqrt_expr` is treated as the field's
/// square-root generator (coefficient `[0, 1]`).
///
/// Supported shapes:
/// - `var` and `sqrt_expr`;
/// - integer literals (fitting in `i64`) and rational literals;
/// - sums, products, and integer powers.
///
/// A negative power is only accepted when its base is a constant that `field`
/// can invert. Exponents whose magnitude does not fit in `u32` are rejected, as
/// in `expr_to_krational`. This also avoids the overflow of negating
/// `i64::MIN`. Returns `None` for anything else.
fn expr_to_kpoly(
    expr: ExprId,
    var: ExprId,
    sqrt_expr: ExprId,
    field: &NumberField,
    pool: &ExprPool,
) -> Option<KPoly> {
    use crate::kernel::ExprData;
    if expr == var {
        return Some(vec![NumberField::k_zero(), field.from_int(1)]);
    }
    if expr == sqrt_expr {
        return Some(vec![vec![rug::Rational::from(0), rug::Rational::from(1)]]);
    }
    match pool.get(expr) {
        ExprData::Integer(n) => Some(vec![
            field.from_rational(&rug::Rational::from(n.0.to_i64()?))
        ]),
        ExprData::Rational(r) => Some(vec![field.from_rational(&r.0)]),
        ExprData::Add(args) => {
            let mut acc: KPoly = Vec::new();
            for a in &args {
                let p = expr_to_kpoly(*a, var, sqrt_expr, field, pool)?;
                acc = field.kpoly_add(&acc, &p);
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc: KPoly = vec![field.from_int(1)];
            for a in &args {
                let p = expr_to_kpoly(*a, var, sqrt_expr, field, pool)?;
                acc = field.kpoly_mul(&acc, &p);
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => {
            let n = match pool.get(exp) {
                ExprData::Integer(n) => n.0.to_i64()?,
                _ => return None,
            };
            // `unsigned_abs` cannot overflow (unlike `-n` for `n == i64::MIN`), and
            // the `u32` bound keeps the repeated-multiplication loop finite in practice.
            let m = u32::try_from(n.unsigned_abs()).ok()?;
            let b = expr_to_kpoly(base, var, sqrt_expr, field, pool)?;
            if n >= 0 {
                let mut acc: KPoly = vec![field.from_int(1)];
                for _ in 0..m {
                    acc = field.kpoly_mul(&acc, &b);
                }
                Some(acc)
            } else {
                // Only constants can be inverted while staying polynomial.
                if NumberField::kdeg(&b) != 0 {
                    return None;
                }
                let inv = field.inv(&b[0])?;
                let mut acc = field.from_int(1);
                for _ in 0..m {
                    acc = field.mul(&acc, &inv);
                }
                Some(vec![acc])
            }
        }
        _ => None,
    }
}
/// Converts a number-field element stored as `[a, b]` (meaning `a + b·√d`;
/// missing entries count as zero) into an expression, using `sqrt_expr` for `√d`.
///
/// Zero parts are omitted. A unit `b` yields the bare `sqrt_expr`, and an
/// all-zero element yields the integer `0`.
fn kelem_to_expr(e: &KElem, sqrt_expr: ExprId, pool: &ExprPool) -> ExprId {
    let coeff = |i: usize| e.get(i).cloned().unwrap_or_else(|| rug::Rational::from(0));
    let (rational_part, sqrt_part) = (coeff(0), coeff(1));
    let mut terms = Vec::with_capacity(2);
    if rational_part != 0 {
        terms.push(rational_to_expr(&rational_part, pool));
    }
    if sqrt_part != 0 {
        let scaled_sqrt = if sqrt_part == 1 {
            sqrt_expr
        } else {
            pool.mul(vec![rational_to_expr(&sqrt_part, pool), sqrt_expr])
        };
        terms.push(scaled_sqrt);
    }
    if terms.is_empty() {
        pool.integer(0_i32)
    } else if terms.len() == 1 {
        terms[0]
    } else {
        pool.add(terms)
    }
}
/// Renders a polynomial in `var` with number-field coefficients as a sum of
/// `c_i · var^i` terms.
///
/// Zero coefficients are skipped and unit coefficients on non-constant terms
/// are dropped. The zero polynomial yields the integer `0`.
fn kpoly_to_expr_alg(p: &KPoly, var: ExprId, sqrt_expr: ExprId, pool: &ExprPool) -> ExprId {
    let terms: Vec<ExprId> = p
        .iter()
        .enumerate()
        .filter(|&(_, c)| !NumberField::is_zero(c))
        .map(|(i, c)| {
            let coeff = kelem_to_expr(c, sqrt_expr, pool);
            if i == 0 {
                return coeff;
            }
            let power = if i == 1 {
                var
            } else {
                pool.pow(var, pool.integer(i as i32))
            };
            if is_one(coeff, pool) {
                power
            } else {
                pool.mul(vec![coeff, power])
            }
        })
        .collect();
    match terms.len() {
        0 => pool.integer(0_i32),
        1 => terms[0],
        _ => pool.add(terms),
    }
}
/// Converts `expr` into an unreduced rational function `(numerator, denominator)`
/// in `var` with coefficients in the quadratic number field `field`. The node
/// `sqrt_expr` is treated as the field's square-root generator.
///
/// Supported shapes:
/// - `var` and `sqrt_expr`;
/// - integer literals (fitting in `i64`) and rational literals;
/// - sums, products, and integer powers.
///
/// Returns `None` in three cases:
/// - any other shape;
/// - a negative power of the zero polynomial;
/// - an exponent whose magnitude does not fit in `u32`. A plain `as u32` cast
///   would silently truncate it, e.g. turning `x^(2^32)` into `1`.
fn expr_to_krational(
    expr: ExprId,
    var: ExprId,
    sqrt_expr: ExprId,
    field: &NumberField,
    pool: &ExprPool,
) -> Option<(KPoly, KPoly)> {
    use crate::kernel::ExprData;
    let one: KPoly = vec![field.from_int(1)];
    if expr == var {
        return Some((vec![NumberField::k_zero(), field.from_int(1)], one));
    }
    if expr == sqrt_expr {
        return Some((
            vec![vec![rug::Rational::from(0), rug::Rational::from(1)]],
            one,
        ));
    }
    match pool.get(expr) {
        ExprData::Integer(n) => Some((
            vec![field.from_rational(&rug::Rational::from(n.0.to_i64()?))],
            one,
        )),
        ExprData::Rational(r) => Some((vec![field.from_rational(&r.0)], one)),
        ExprData::Add(args) => {
            let mut acc: (KPoly, KPoly) = (Vec::new(), one);
            for a in &args {
                let term = expr_to_krational(*a, var, sqrt_expr, field, pool)?;
                acc = krat_add(field, &acc, &term);
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc: (KPoly, KPoly) = (one.clone(), one);
            for a in &args {
                let factor = expr_to_krational(*a, var, sqrt_expr, field, pool)?;
                acc = krat_mul(field, &acc, &factor);
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => {
            let n = match pool.get(exp) {
                ExprData::Integer(n) => n.0.to_i64()?,
                _ => return None,
            };
            // Checked conversion: no truncation, and no overflow on `-i64::MIN`.
            let m = u32::try_from(n.unsigned_abs()).ok()?;
            let (bn, bd) = expr_to_krational(base, var, sqrt_expr, field, pool)?;
            if n >= 0 {
                Some((field.kpoly_pow(&bn, m), field.kpoly_pow(&bd, m)))
            } else {
                // A negative power of zero would divide by zero.
                if NumberField::kdeg(&bn) < 0 {
                    return None;
                }
                Some((field.kpoly_pow(&bd, m), field.kpoly_pow(&bn, m)))
            }
        }
        _ => None,
    }
}
/// Adds two number-field rational functions given as `(numerator, denominator)`
/// pairs by cross-multiplication: `a/b + c/d = (a·d + c·b) / (b·d)`.
/// The result is not reduced.
fn krat_add(field: &NumberField, a: &(KPoly, KPoly), b: &(KPoly, KPoly)) -> (KPoly, KPoly) {
    let (a_num, a_den) = a;
    let (b_num, b_den) = b;
    let left = field.kpoly_mul(a_num, b_den);
    let right = field.kpoly_mul(b_num, a_den);
    (field.kpoly_add(&left, &right), field.kpoly_mul(a_den, b_den))
}

/// Multiplies two number-field rational functions componentwise
/// (numerators together, denominators together); not reduced.
fn krat_mul(field: &NumberField, a: &(KPoly, KPoly), b: &(KPoly, KPoly)) -> (KPoly, KPoly) {
    let (a_num, a_den) = a;
    let (b_num, b_den) = b;
    (field.kpoly_mul(a_num, b_num), field.kpoly_mul(a_den, b_den))
}
/// Builds the expression `num / den` over `var` and `sqrt_expr`.
///
/// When `den` has degree ≤ 0 and its first coefficient is exactly `1` (or `den`
/// is empty), the numerator expression is returned alone. Otherwise the result
/// is `num · den^(-1)`.
fn build_krational(
    num: &KPoly,
    den: &KPoly,
    var: ExprId,
    sqrt_expr: ExprId,
    pool: &ExprPool,
) -> ExprId {
    let numerator = kpoly_to_expr_alg(num, var, sqrt_expr, pool);
    let is_constant = NumberField::kdeg(den) <= 0;
    let leading_is_one = || match den.first() {
        Some(c) => trim(c.clone()) == vec![rug::Rational::from(1)],
        None => true,
    };
    if is_constant && leading_is_one() {
        return numerator;
    }
    let denominator = kpoly_to_expr_alg(den, var, sqrt_expr, pool);
    let reciprocal = pool.pow(denominator, pool.integer(-1_i32));
    pool.mul(vec![numerator, reciprocal])
}
pub(super) fn build_rational(
num: &[rug::Rational],
den: &[rug::Rational],
var: ExprId,
pool: &ExprPool,
) -> ExprId {
let num_expr = qpoly_to_expr(&num.to_vec(), var, pool);
if super::poly_rde::degree(&den.to_vec()) <= 0 && den.first().map(|c| *c == 1).unwrap_or(true) {
return num_expr;
}
let den_expr = qpoly_to_expr(&den.to_vec(), var, pool);
let den_inv = pool.pow(den_expr, pool.integer(-1_i32));
pool.mul(vec![num_expr, den_inv])
}
/// Differentiates `poly_expr` with respect to `var` and returns the derivative
/// expression. Any failure of the differentiator is reported as
/// `IntegrationError::NotImplemented`.
fn differentiate_poly(
    poly_expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Result<ExprId, IntegrationError> {
    use crate::diff::diff;
    diff(poly_expr, var, pool)
        .map(|derived| derived.value)
        .map_err(|e| {
            IntegrationError::NotImplemented(format!("could not differentiate exponent: {e}"))
        })
}
/// Returns an expression for the `k`-th power of the exponential generator.
///
/// - `k = 0` gives the integer `1`.
/// - `k = 1` reuses `exp_gen` (the tower generator, i.e. `exp(eta)`).
/// - Any other `k` gives a new `exp(k·eta)` node.
fn build_exp_k_eta(k: i64, eta: ExprId, exp_gen: ExprId, pool: &ExprPool) -> ExprId {
    if k == 0 {
        return pool.integer(1_i32);
    }
    if k == 1 {
        return exp_gen;
    }
    let scaled_exponent = pool.mul(vec![pool.integer(k), eta]);
    pool.func("exp", vec![scaled_exponent])
}
/// `true` iff `expr` is literally the integer `0` (no simplification is attempted).
fn is_zero(expr: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;
    if let ExprData::Integer(n) = pool.get(expr) {
        n.0 == 0
    } else {
        false
    }
}

/// `true` iff `expr` is literally the integer `1` (no simplification is attempted).
fn is_one(expr: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;
    if let ExprData::Integer(n) = pool.get(expr) {
        n.0 == 1
    } else {
        false
    }
}
/// Reports whether `expr` contains a radical that depends on `var`. A match is
/// either `sqrt(u)` with `u` depending on `var`, or `u^r` with a rational
/// literal exponent `r` and `u` depending on `var`.
///
/// Only sums, products and the bases of non-rational powers are searched.
/// Arguments of other functions and power exponents are not.
///
/// NOTE(review): `cbrt(...)` is not recognised here, although
/// `scan_radical_generator` treats it as a radical. Confirm whether that
/// asymmetry is intended.
fn contains_var_algebraic(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;
    let mut pending = vec![expr];
    while let Some(node) = pending.pop() {
        let found = match pool.get(node) {
            ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => {
                !is_free_of_var(args[0], var, pool)
            }
            ExprData::Pow { base, exp } => {
                if matches!(pool.get(exp), ExprData::Rational(_)) {
                    !is_free_of_var(base, var, pool)
                } else {
                    pending.push(base);
                    false
                }
            }
            ExprData::Add(args) | ExprData::Mul(args) => {
                pending.extend(args.iter().copied());
                false
            }
            _ => false,
        };
        if found {
            return true;
        }
    }
    false
}
/// Finds the single polynomial radicand (degree ≥ 1) that appears under a
/// square-root form in `expr`, i.e. `sqrt(p)` or `p^(m/2)`.
///
/// Returns it together with the first node at which it was found. Returns
/// `None` when no such radical exists or when two different radicands occur
/// (compared after trimming).
fn detect_sqrt_of_poly(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(QPoly, ExprId)> {
    let mut found = Vec::new();
    scan_sqrt_of_poly(expr, var, pool, &mut found);
    let mut remaining = found.into_iter();
    let first = remaining.next()?;
    let key = trim(first.0.clone());
    let all_same = remaining.all(|(p, _)| trim(p) == key);
    if all_same {
        Some(first)
    } else {
        None
    }
}
/// Records `(p, node)` for every `sqrt(p)` or `p^(m/2)` node in `expr` whose
/// radicand `p` is a polynomial in `var` of degree ≥ 1.
///
/// Sums, products and the bases of powers are searched. A recorded half-integer
/// power is not searched further, and neither are function arguments.
fn scan_sqrt_of_poly(expr: ExprId, var: ExprId, pool: &ExprPool, out: &mut Vec<(QPoly, ExprId)>) {
    use crate::kernel::ExprData;
    let radicand_poly = |r: ExprId| expr_to_qpoly(r, var, pool).filter(|p| degree(p) >= 1);
    match pool.get(expr) {
        ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => {
            if let Some(p) = radicand_poly(args[0]) {
                out.push((p, expr));
            }
        }
        ExprData::Pow { base, exp } => {
            let half_integer =
                matches!(pool.get(exp), ExprData::Rational(ref r) if *r.0.denom() == 2);
            if half_integer {
                if let Some(p) = radicand_poly(base) {
                    out.push((p, expr));
                    return;
                }
            }
            scan_sqrt_of_poly(base, var, pool, out);
        }
        ExprData::Add(args) | ExprData::Mul(args) => {
            for child in args.iter().copied() {
                scan_sqrt_of_poly(child, var, pool, out);
            }
        }
        _ => {}
    }
}
/// Returns `f` when `alpha` is exactly `sqrt(f)` or `f^(1/2)`.
/// Any other shape, including other half-integer powers, yields `None`.
fn get_radicand_expr(alpha: ExprId, pool: &ExprPool) -> Option<ExprId> {
    use crate::kernel::ExprData;
    match pool.get(alpha) {
        ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => Some(args[0]),
        ExprData::Pow { base, exp } => match pool.get(exp) {
            ExprData::Rational(r) if *r.0.numer() == 1 && *r.0.denom() == 2 => Some(base),
            _ => None,
        },
        _ => None,
    }
}
/// A rational function over ℚ stored as an unreduced `(numerator, denominator)` pair.
type QRat = (QPoly, QPoly);

/// The rational function `0 / 1`.
fn qr_zero() -> QRat {
    (poly_zero(), poly_one())
}

/// The rational function `1 / 1`.
fn qr_one() -> QRat {
    (poly_one(), poly_one())
}

/// Sum by cross-multiplication, `a/b + c/d = (a·d + c·b) / (b·d)`; not reduced.
fn qr_add(a: &QRat, b: &QRat) -> QRat {
    let (a_num, a_den) = a;
    let (b_num, b_den) = b;
    let cross_sum = poly_add(&poly_mul(a_num, b_den), &poly_mul(b_num, a_den));
    (cross_sum, poly_mul(a_den, b_den))
}

/// Product of numerators over product of denominators; not reduced.
fn qr_mul(a: &QRat, b: &QRat) -> QRat {
    let (a_num, a_den) = a;
    let (b_num, b_den) = b;
    (poly_mul(a_num, b_num), poly_mul(a_den, b_den))
}

/// Multiplies the rational function `a` by the polynomial `p` (numerator only).
fn qr_scale_poly(a: &QRat, p: &QPoly) -> QRat {
    let (num, den) = a;
    (poly_mul(num, p), den.clone())
}
/// An element `c0 + c1·α` of `ℚ(x)(α)` with `α² = p(x)`, stored as `(c0, c1)`.
type KPair = (QRat, QRat);

/// The element `0`.
fn kp_zero() -> KPair {
    (qr_zero(), qr_zero())
}

/// The element `1`.
fn kp_one() -> KPair {
    (qr_one(), qr_zero())
}

/// The generator `α` itself.
fn kp_alpha() -> KPair {
    (qr_zero(), qr_one())
}

/// Embeds a rational function `r` as `r + 0·α`.
fn kp_from_qr(r: QRat) -> KPair {
    (r, qr_zero())
}

/// Componentwise sum.
fn kp_add(a: &KPair, b: &KPair) -> KPair {
    let (a0, a1) = a;
    let (b0, b1) = b;
    (qr_add(a0, b0), qr_add(a1, b1))
}

/// Product using `α² = p`:
/// `(a0 + a1·α)(b0 + b1·α) = (a0·b0 + a1·b1·p) + (a0·b1 + a1·b0)·α`.
fn kp_mul(a: &KPair, b: &KPair, p: &QPoly) -> KPair {
    let (a0, a1) = a;
    let (b0, b1) = b;
    let rational_part = qr_add(&qr_mul(a0, b0), &qr_scale_poly(&qr_mul(a1, b1), p));
    let alpha_part = qr_add(&qr_mul(a0, b1), &qr_mul(a1, b0));
    (rational_part, alpha_part)
}
fn kp_inv(a: &KPair, p: &QPoly) -> Option<KPair> {
let a0sq = qr_mul(&a.0, &a.0);
let a1sq_p = qr_scale_poly(&qr_mul(&a.1, &a.1), p);
let norm_num = poly_sub(&poly_mul(&a0sq.0, &a1sq_p.1), &poly_mul(&a1sq_p.0, &a0sq.1));
let norm_den = poly_mul(&a0sq.1, &a1sq_p.1);
if trim(norm_num.clone()).is_empty() {
return None; }
let inv_c0 = (poly_mul(&a.0 .0, &norm_den), poly_mul(&a.0 .1, &norm_num));
let neg_a1_num = poly_scale(&a.1 .0, &rug::Rational::from(-1));
let inv_c1 = (
poly_mul(&neg_a1_num, &norm_den),
poly_mul(&a.1 .1, &norm_num),
);
Some((inv_c0, inv_c1))
}
/// Raises `a` to the integer power `n` by repeated multiplication.
///
/// A negative `n` first inverts `a` (returning `None` if the norm vanishes)
/// and then raises the inverse to `-n`.
fn kp_pow(a: &KPair, n: i64, p: &QPoly) -> Option<KPair> {
    match n {
        0 => Some(kp_one()),
        _ if n < 0 => kp_pow(&kp_inv(a, p)?, -n, p),
        _ => Some((0..n).fold(kp_one(), |acc, _| kp_mul(&acc, a, p))),
    }
}
/// Writes `expr` as `c0 + c1·α` with `c0, c1 ∈ ℚ(x)`. Here `alpha` is the
/// expression node standing for `α = √p(x)` and `p_poly` is `p`.
///
/// Cases handled:
/// - subexpressions that are rational functions of `var` map to `c0`;
/// - the node `alpha` maps to `α`;
/// - sums, products and integer powers use the `kp_*` arithmetic;
/// - a power `f^(m/2)` whose base is the radicand of `alpha` maps to `α^m`.
///
/// Returns `None` for anything else, or when a negative power is not invertible.
fn decompose_over_alpha(
    expr: ExprId,
    alpha: ExprId,
    p_poly: &QPoly,
    var: ExprId,
    pool: &ExprPool,
) -> Option<KPair> {
    use crate::kernel::ExprData;
    if !contains_subexpr(expr, alpha, pool) {
        // Fast path for plain rational functions. Do not give up when this fails:
        // `p^(3/2)` never contains the node `sqrt(p)`, yet it is handled by the
        // half-integer power case below.
        if let Some(r) = expr_to_qrational(expr, var, pool) {
            return Some(kp_from_qr(r));
        }
    }
    if expr == alpha {
        return Some(kp_alpha());
    }
    match pool.get(expr) {
        ExprData::Add(args) => {
            let mut acc = kp_zero();
            for &a in &args {
                let t = decompose_over_alpha(a, alpha, p_poly, var, pool)?;
                acc = kp_add(&acc, &t);
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc = kp_one();
            for &a in &args {
                let t = decompose_over_alpha(a, alpha, p_poly, var, pool)?;
                acc = kp_mul(&acc, &t, p_poly);
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => match pool.get(exp) {
            ExprData::Integer(n) => {
                let n = n.0.to_i64()?;
                let b = decompose_over_alpha(base, alpha, p_poly, var, pool)?;
                kp_pow(&b, n, p_poly)
            }
            ExprData::Rational(r) => {
                // base^(m/2) = (√base)^m = α^m when `base` is alpha's radicand.
                if get_radicand_expr(alpha, pool) == Some(base) && *r.0.denom() == 2 {
                    let m = r.0.numer().to_i64()?;
                    return kp_pow(&kp_alpha(), m, p_poly);
                }
                None
            }
            _ => None,
        },
        _ => None,
    }
}
/// Finds the single radical generator `p(x)^(1/n)` used in `expr`.
///
/// Returns `(n, p)` when every radical found by `scan_radical_generator` has the
/// same index `n` and the same radicand (compared after trimming). The
/// untrimmed radicand of the first occurrence is returned. Returns `None` when
/// there is no radical or when several distinct generators occur.
pub(super) fn detect_radical_generator(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Option<(usize, QPoly)> {
    let mut found = Vec::new();
    scan_radical_generator(expr, var, pool, &mut found);
    let mut remaining = found.into_iter();
    let first = remaining.next()?;
    let index = first.0;
    let radicand = trim(first.1.clone());
    let unique = remaining.all(|(n, p)| n == index && trim(p) == radicand);
    if unique {
        Some(first)
    } else {
        None
    }
}
/// Records `(n, p)` for every radical in `expr` whose radicand `p` is a
/// polynomial in `var` of degree ≥ 1:
/// - `sqrt(p)` gives `n = 2`;
/// - `cbrt(p)` gives `n = 3`;
/// - `p^(m/n)` with a reduced denominator `n ≥ 2` gives `n`.
///
/// Sums, products and the bases of powers are searched. A recorded power is
/// not searched further, and neither are function arguments.
fn scan_radical_generator(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    out: &mut Vec<(usize, QPoly)>,
) {
    use crate::kernel::ExprData;
    let radicand_poly = |r: ExprId| expr_to_qpoly(r, var, pool).filter(|p| degree(p) >= 1);
    match pool.get(expr) {
        ExprData::Func { ref name, ref args }
            if args.len() == 1 && (name == "sqrt" || name == "cbrt") =>
        {
            let index = if name == "sqrt" { 2 } else { 3 };
            if let Some(p) = radicand_poly(args[0]) {
                out.push((index, p));
            }
        }
        ExprData::Pow { base, exp } => {
            let root_index = match pool.get(exp) {
                ExprData::Rational(r) => r.0.denom().to_i64().filter(|&d| d >= 2),
                _ => None,
            };
            if let Some(d) = root_index {
                if let Some(p) = radicand_poly(base) {
                    out.push((d as usize, p));
                    return;
                }
            }
            scan_radical_generator(base, var, pool, out);
        }
        ExprData::Add(args) | ExprData::Mul(args) => {
            for child in args.iter().copied() {
                scan_radical_generator(child, var, pool, out);
            }
        }
        _ => {}
    }
}
/// Expresses `expr` as an element of the algebraic extension `e`, whose
/// generator `y = e.generator()` is presumably `p_radicand^(1/n)`. Callers build
/// `e` with `AlgExtension::radical(n, p_radicand)`; confirm against `alg_field`.
///
/// - Rational functions of `var` become constants of `e`.
/// - Sums, products and integer powers use `e`'s arithmetic.
/// - `base^(numr/den)`, `sqrt(base)` and `cbrt(base)` map to `y^(numr·n/den)`,
///   provided `den` divides `n` and `base` equals `p_radicand` as a polynomial
///   (after trimming).
///
/// Returns `None` for any other shape or when `e.pow` returns `None`.
pub(super) fn decompose_over_alg_generator(
    expr: ExprId,
    n: usize,
    p_radicand: &QPoly,
    e: &AlgExtension,
    var: ExprId,
    pool: &ExprPool,
) -> Option<AlgElem> {
    use crate::kernel::ExprData;
    // Rational functions of `var` are constants of the base field.
    if let Some((num, den)) = expr_to_qrational(expr, var, pool) {
        return Some(e.constant(RatFn::new(num, den)));
    }
    // base^(numr/den) = y^(numr · n/den) when `base` is the radicand and the
    // root index `den` divides the extension degree `n`.
    let as_generator_power = |base: ExprId, numr: i64, den: i64| -> Option<AlgElem> {
        if den >= 2 && (n as i64) % den == 0 {
            let bp = expr_to_qpoly(base, var, pool)?;
            if trim(bp) == trim(p_radicand.clone()) {
                return e.pow(&e.generator(), numr * (n as i64 / den));
            }
        }
        None
    };
    match pool.get(expr) {
        ExprData::Add(args) => {
            let mut acc = e.from_int(0);
            for &a in &args {
                let t = decompose_over_alg_generator(a, n, p_radicand, e, var, pool)?;
                acc = e.add(&acc, &t);
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc = e.from_int(1);
            for &a in &args {
                let t = decompose_over_alg_generator(a, n, p_radicand, e, var, pool)?;
                acc = e.mul(&acc, &t);
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => match pool.get(exp) {
            ExprData::Integer(m) => {
                let m = m.0.to_i64()?;
                let b = decompose_over_alg_generator(base, n, p_radicand, e, var, pool)?;
                e.pow(&b, m)
            }
            ExprData::Rational(r) => {
                as_generator_power(base, r.0.numer().to_i64()?, r.0.denom().to_i64()?)
            }
            _ => None,
        },
        // sqrt(u) = u^(1/2)
        ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => {
            as_generator_power(args[0], 1, 2)
        }
        // cbrt(u) = u^(1/3)
        ExprData::Func { ref name, ref args } if name == "cbrt" && args.len() == 1 => {
            as_generator_power(args[0], 1, 3)
        }
        _ => None,
    }
}
/// Detects the unique radical generator `p^(1/n)` of `expr`, builds the
/// extension `ℚ(x)(p^(1/n))`, and expresses `expr` in it.
///
/// Returns `None` if there is no unique generator or `expr` does not decompose.
#[allow(dead_code)] // NOTE(review): appears to be used only by the unit tests in this file
fn decompose_radical(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Option<(AlgExtension, AlgElem)> {
    let (n, radicand) = detect_radical_generator(expr, var, pool)?;
    let extension = AlgExtension::radical(n, &radicand);
    decompose_over_alg_generator(expr, n, &radicand, &extension, var, pool)
        .map(|elem| (extension, elem))
}
#[allow(clippy::too_many_arguments)]
fn try_sqrt_poly_rde(
c_rest: ExprId,
k: i64,
deta: &[rug::Rational],
exp_k_eta: ExprId,
k_const: ExprId,
c_expr: ExprId,
var: ExprId,
pool: &ExprPool,
log: &mut DerivationLog,
) -> Option<Result<ExprId, IntegrationError>> {
let (p_poly, alpha) = detect_sqrt_of_poly(c_rest, var, pool)?;
let (c0, c1) = decompose_over_alpha(c_rest, alpha, &p_poly, var, pool)?;
let f: QPoly = deta
.iter()
.map(|r| r.clone() * rug::Rational::from(k))
.collect();
let ne = || {
IntegrationError::NonElementary(format!(
"the Risch DE over ℚ(x)(√({})) for ∫ {} · exp(η)^{k} dx \
has no rational solution",
pool.display(pool.func("placeholder", vec![])), pool.display(c_expr),
))
};
let a_sol = match solve_rational_rde(&f, &c0.0, &c0.1) {
Some(s) => s,
None => return Some(Err(ne())),
};
let p_prime = poly_deriv(&p_poly);
let f_eff_num = poly_add(
&poly_scale(&poly_mul(&f, &p_poly), &rug::Rational::from(2)),
&p_prime,
);
let f_eff_den = poly_scale(&p_poly, &rug::Rational::from(2));
let b_sol = match solve_rational_rde_generalized(&f_eff_num, &f_eff_den, &c1.0, &c1.1) {
Some(s) => s,
None => return Some(Err(ne())),
};
let a_expr = build_rational(&a_sol.0, &a_sol.1, var, pool);
let b_expr = build_rational(&b_sol.0, &b_sol.1, var, pool);
let a_zero = trim(a_sol.0.clone()).is_empty();
let b_zero = trim(b_sol.0.clone()).is_empty();
let v_expr = match (a_zero, b_zero) {
(true, true) => pool.integer(0_i32),
(true, false) => pool.mul(vec![b_expr, alpha]),
(false, true) => a_expr,
(false, false) => pool.add(vec![a_expr, pool.mul(vec![b_expr, alpha])]),
};
let core = build_v_times_exp(v_expr, exp_k_eta, pool);
let result = apply_const(k_const, core, pool);
log.push(RewriteStep::simple("risch_exp_sqrt_poly", c_expr, result));
Some(Ok(result))
}
/// Attempts the Risch differential equation for `c_rest · exp(η)^k` when
/// `c_rest` lies in `ℚ(x)(y)` with `y = a(x)^(1/n)` and `n ≥ 3`.
///
/// The ansatz `v = Σ vᵢ·yⁱ` decouples into one scalar equation per power,
/// because `(yⁱ)' = (i·a'/(n·a))·yⁱ`:
/// - `v₀` is solved with coefficient `f = k·η'`;
/// - each `vᵢ` (`i ≥ 1`) is solved with coefficient
///   `f + i·a'/(n·a) = (n·f·a + i·a')/(n·a)`.
///
/// The result is `apply_const(k_const, build_v_times_exp(v, exp_k_eta))`.
///
/// Returns:
/// - `None` when the pattern does not apply: no unique radical, `n < 3`, a
///   constant or non-squarefree radicand, or `c_rest` does not decompose;
/// - `Some(Err(NonElementary))` when some coefficient equation has no rational
///   solution;
/// - `Some(Ok(result))` otherwise, after pushing a `risch_exp_radical_poly`
///   log step.
#[allow(clippy::too_many_arguments)]
fn try_radical_poly_rde(
    c_rest: ExprId,
    k: i64,
    deta: &[rug::Rational],
    exp_k_eta: ExprId,
    k_const: ExprId,
    c_expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Option<Result<ExprId, IntegrationError>> {
    let (n, a) = detect_radical_generator(c_rest, var, pool)?;
    // Square roots (n = 2) are left to other routines (presumably try_sqrt_poly_rde).
    if n < 3 {
        return None;
    }
    let a = trim(a);
    if degree(&a) < 1 {
        return None;
    }
    let a_prime = poly_deriv(&a);
    // Require a squarefree radicand: gcd(a, a') must be constant.
    if degree(&poly_gcd(&a, &a_prime)) >= 1 {
        return None;
    }
    let e = AlgExtension::radical(n, &a);
    let elem = decompose_over_alg_generator(c_rest, n, &a, &e, var, pool)?;
    // f = k · η'
    let f: QPoly = deta
        .iter()
        .map(|r| r.clone() * rug::Rational::from(k))
        .collect();
    let a_expr = qpoly_to_expr(&a, var, pool);
    let ne = || {
        IntegrationError::NonElementary(format!(
            "the Risch DE over ℚ(x)({}^(1/{n})) for ∫ {} · exp(η)^{k} dx \
             has no rational solution",
            pool.display(a_expr),
            pool.display(c_expr),
        ))
    };
    let mut terms: Vec<ExprId> = Vec::new();
    for i in 0..n {
        // Missing trailing coefficients of `elem` are zero.
        let (c_num, c_den) = match elem.get(i) {
            Some(r) => (r.numer().clone(), r.denom().clone()),
            None => (QPoly::new(), poly_one()),
        };
        if trim(c_num.clone()).is_empty() {
            continue;
        }
        let (vn, vd) = if i == 0 {
            match solve_rational_rde(&f, &c_num, &c_den) {
                Some(s) => s,
                None => return Some(Err(ne())),
            }
        } else {
            // Coefficient for yⁱ: (n·f·a + i·a') / (n·a).
            let f_eff_num = poly_add(
                &poly_scale(&poly_mul(&f, &a), &rug::Rational::from(n as i64)),
                &poly_scale(&a_prime, &rug::Rational::from(i as i64)),
            );
            let f_eff_den = poly_scale(&a, &rug::Rational::from(n as i64));
            match solve_rational_rde_generalized(&f_eff_num, &f_eff_den, &c_num, &c_den) {
                Some(s) => s,
                None => return Some(Err(ne())),
            }
        };
        if trim(vn.clone()).is_empty() {
            continue;
        }
        let v_expr = build_rational(&vn, &vd, var, pool);
        if i == 0 {
            terms.push(v_expr);
        } else {
            // yⁱ is rendered as a^(i/n).
            let yi = pool.pow(a_expr, pool.rational(i as i32, n as i32));
            terms.push(pool.mul(vec![v_expr, yi]));
        }
    }
    let v_expr = match terms.len() {
        0 => pool.integer(0_i32),
        1 => terms[0],
        _ => pool.add(terms),
    };
    let core = build_v_times_exp(v_expr, exp_k_eta, pool);
    let result = apply_const(k_const, core, pool);
    log.push(RewriteStep::simple(
        "risch_exp_radical_poly",
        c_expr,
        result,
    ));
    Some(Ok(result))
}
/// Attempts the Risch differential equation for `c_rest · exp(η)^k` when
/// `c_rest` involves the nested radical `α = √(a + √b)`, with `a, b ∈ ℚ[x]`
/// and `b` not a perfect square.
///
/// From `(α² − a)² = b`, `α` is a root of `α⁴ − 2a·α² + (a² − b)` (`q_min`).
/// The extension `ℚ(x)(α)` is built from `q_min`, with `√b` represented as
/// `α² − a`. The coupled equation is handed to `solve_alg_rde`, and the
/// solution `Σ yⱼ·αʲ` is rebuilt with `α = sqrt(a + sqrt(b))`.
///
/// NOTE(review): irreducibility of `q_min` is not checked here.
///
/// Returns:
/// - `None` when the pattern does not apply;
/// - `Some(Err(NonElementary))` when `solve_alg_rde` finds no solution;
/// - `Some(Ok(result))` otherwise, after pushing a
///   `risch_exp_nested_radical_poly` log step.
#[allow(clippy::too_many_arguments)]
fn try_nested_radical_poly_rde(
    c_rest: ExprId,
    k: i64,
    deta: &[rug::Rational],
    exp_k_eta: ExprId,
    k_const: ExprId,
    c_expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Option<Result<ExprId, IntegrationError>> {
    let (a, b) = detect_nested_radical(c_rest, var, pool)?;
    let a2_minus_b = poly_sub(&poly_mul(&a, &a), &b);
    // Coefficients of α⁴ − 2a·α² + (a² − b), lowest degree first.
    let q_min = vec![
        a2_minus_b,
        poly_zero(),
        poly_scale(&a, &rug::Rational::from(-2)),
        poly_zero(),
        poly_one(),
    ];
    let e = AlgExtension::new(&q_min);
    // α itself, in the power basis.
    let sqrt_outer: AlgElem = vec![RatFn::int(0), RatFn::int(1)];
    // √b = α² − a.
    let sqrt_inner: AlgElem = e.reduce(&[
        RatFn::new(poly_scale(&a, &rug::Rational::from(-1)), poly_one()),
        RatFn::int(0),
        RatFn::int(1),
    ]);
    let g = decompose_over_nested_radical(c_rest, &a, &b, &sqrt_outer, &sqrt_inner, &e, var, pool)?;
    // f = k · η'
    let f_poly: QPoly = deta
        .iter()
        .map(|r| r.clone() * rug::Rational::from(k))
        .collect();
    let f = RatFn::from_poly(&f_poly);
    let ne = || {
        IntegrationError::NonElementary(format!(
            "the coupled Risch DE over ℚ(x)(√(a+√b)) for ∫ {} · exp(η)^{k} dx \
             has no rational solution",
            pool.display(c_expr),
        ))
    };
    let y = match solve_alg_rde(&e, &f, &g) {
        Some(y) => y,
        None => return Some(Err(ne())),
    };
    // Rebuild the symbolic generator α = sqrt(a + sqrt(b)).
    let a_expr = qpoly_to_expr(&a, var, pool);
    let b_expr = qpoly_to_expr(&b, var, pool);
    let inner = pool.add(vec![a_expr, pool.func("sqrt", vec![b_expr])]);
    let alpha = pool.func("sqrt", vec![inner]);
    let mut terms: Vec<ExprId> = Vec::new();
    for (j, yj) in y.iter().enumerate() {
        if yj.numer().is_empty() {
            continue;
        }
        let coeff = build_rational(yj.numer(), yj.denom(), var, pool);
        let term = if j == 0 {
            coeff
        } else {
            pool.mul(vec![coeff, pool.pow(alpha, pool.integer(j as i32))])
        };
        terms.push(term);
    }
    let v_expr = match terms.len() {
        0 => pool.integer(0_i32),
        1 => terms[0],
        _ => pool.add(terms),
    };
    let core = build_v_times_exp(v_expr, exp_k_eta, pool);
    let result = apply_const(k_const, core, pool);
    log.push(RewriteStep::simple(
        "risch_exp_nested_radical_poly",
        c_expr,
        result,
    ));
    Some(Ok(result))
}
/// Finds the single nested radical `√(a + √b)` used in `expr` and returns its
/// untrimmed `(a, b)` from the first occurrence.
///
/// Returns `None` when there is no such radical or when occurrences disagree
/// on `a` or `b` (compared after trimming).
fn detect_nested_radical(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(QPoly, QPoly)> {
    let mut found = Vec::new();
    scan_nested_radical(expr, var, pool, &mut found);
    let mut remaining = found.into_iter();
    let first = remaining.next()?;
    let key_a = trim(first.0.clone());
    let key_b = trim(first.1.clone());
    let unique = remaining.all(|(a, b)| trim(a) == key_a && trim(b) == key_b);
    if unique {
        Some(first)
    } else {
        None
    }
}
/// Records `(a, b)` for every square root (`sqrt(r)` or `r^(m/2)`) in `expr`
/// whose radicand `r` has the shape `a + √b` (see `match_a_plus_sqrt_b`).
///
/// The search continues inside radicands, sums, products and the bases of all
/// powers. Other function arguments are not searched.
fn scan_nested_radical(expr: ExprId, var: ExprId, pool: &ExprPool, out: &mut Vec<(QPoly, QPoly)>) {
    use crate::kernel::ExprData;
    let radicand = match pool.get(expr) {
        ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => args[0],
        ExprData::Pow { base, exp } => {
            let is_half_power = matches!(
                pool.get(exp),
                ExprData::Rational(ref r) if r.0.denom().to_i64() == Some(2)
            );
            if !is_half_power {
                scan_nested_radical(base, var, pool, out);
                return;
            }
            base
        }
        ExprData::Add(args) | ExprData::Mul(args) => {
            for child in args.iter().copied() {
                scan_nested_radical(child, var, pool, out);
            }
            return;
        }
        _ => return,
    };
    // A square root was found: record its radicand if it looks like `a + √b`,
    // then keep searching inside it.
    if let Some(pair) = match_a_plus_sqrt_b(radicand, var, pool) {
        out.push(pair);
    }
    scan_nested_radical(radicand, var, pool, out);
}
/// Matches `expr` against the shape `a(x) + √b(x)`. This requires an `Add` in
/// which every term is a polynomial in `var` (these are summed into `a`),
/// except exactly one term of the form `sqrt(b)` or `b^(1/2)` with `b`
/// polynomial.
///
/// Returns `(a, trim(b))`. Returns `None` when:
/// - the shape does not match (including scaled roots such as `2·sqrt(b)`);
/// - a second square-root term appears;
/// - `b` is constant or a perfect square in ℚ[x], so `√b` would not be a
///   genuine radical.
fn match_a_plus_sqrt_b(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(QPoly, QPoly)> {
    use crate::kernel::ExprData;
    let ExprData::Add(args) = pool.get(expr) else {
        return None;
    };
    let mut a = poly_zero();
    let mut inner_b: Option<QPoly> = None;
    for &term in &args {
        // Polynomial terms accumulate into `a`.
        if let Some(p) = expr_to_qpoly(term, var, pool) {
            a = poly_add(&a, &p);
            continue;
        }
        // Any other term must be √(polynomial); otherwise the match fails.
        let b = match pool.get(term) {
            ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => {
                expr_to_qpoly(args[0], var, pool)
            }
            ExprData::Pow { base, exp } => match pool.get(exp) {
                ExprData::Rational(r)
                    if r.0.denom().to_i64() == Some(2) && r.0.numer().to_i64() == Some(1) =>
                {
                    expr_to_qpoly(base, var, pool)
                }
                _ => return None,
            },
            _ => return None,
        }?;
        // Only one square-root term is allowed.
        if inner_b.is_some() {
            return None;
        }
        inner_b = Some(b);
    }
    let b = inner_b?;
    if degree(&b) < 1 || is_perfect_square_poly(&b) {
        return None;
    }
    Some((a, trim(b)))
}
/// Tests whether `b` is the square of a polynomial in ℚ[x].
///
/// The zero polynomial counts as a square and odd degrees never do. Otherwise
/// the candidate root from `poly_sqrt_exact` is squared and compared with `b`.
fn is_perfect_square_poly(b: &QPoly) -> bool {
    let trimmed = trim(b.clone());
    match degree(&trimmed) {
        d if d < 0 => true,
        d if d % 2 != 0 => false,
        _ => poly_sqrt_exact(&trimmed)
            .is_some_and(|root| trim(poly_mul(&root, &root)) == trimmed),
    }
}
/// Computes a candidate square root `s` of `b` in ℚ[x] by matching
/// coefficients from the top down.
///
/// Let `deg b = 2n` and `s = Σ sᵢ·xⁱ`, with `s_n = √(lead b)`. For
/// `k = n−1, ..., 0`, the coefficient of `x^(n+k)` in `s²` is `2·s_n·s_k` plus
/// products `sᵢ·sⱼ` (`i + j = n + k`, `k < i, j < n`) that are already known.
/// Solving that equation gives `s_k`.
///
/// Only the upper half of `b`'s coefficients is used, so the result is a
/// candidate. Callers must verify `s² = b` (as `is_perfect_square_poly` does).
///
/// Returns `Some(0)` for the zero polynomial. Returns `None` for odd degree or
/// when the leading coefficient is not the square of a rational.
fn poly_sqrt_exact(b: &QPoly) -> Option<QPoly> {
    let b = trim(b.clone());
    let d = degree(&b);
    if d < 0 {
        return Some(poly_zero());
    }
    if d % 2 != 0 {
        return None;
    }
    // n = deg s; lead_b is the coefficient of x^(2n).
    let n = (d / 2) as usize;
    let lead_b = b[d as usize].clone();
    let lead_s = rational_sqrt(&lead_b)?;
    let mut s = vec![rug::Rational::from(0); n + 1];
    s[n] = lead_s;
    for k in (0..n).rev() {
        // Target coefficient of x^(n+k) in b.
        let target = b
            .get(n + k)
            .cloned()
            .unwrap_or_else(|| rug::Rational::from(0));
        // Sum of the already-known products s_i·s_j with i + j = n + k.
        let mut rest = rug::Rational::from(0);
        for i in 0..s.len() {
            let jj = (n + k) as i64 - i as i64;
            if jj < 0 || jj as usize >= s.len() {
                continue;
            }
            let j = jj as usize;
            // Skip the two terms involving the unknown s_k; they give 2·s_n·s_k.
            if (i == n && j == k) || (i == k && j == n) {
                continue;
            }
            rest += s[i].clone() * s[j].clone();
        }
        let two_sn = rug::Rational::from(2) * s[n].clone();
        s[k] = (target - rest) / two_sn;
    }
    Some(trim(s))
}
/// Returns the exact non-negative rational square root of `r`.
///
/// Returns `None` when `r` is negative or when its numerator or denominator is
/// not a perfect square.
pub(super) fn rational_sqrt(r: &rug::Rational) -> Option<rug::Rational> {
    if *r.numer() < 0 {
        return None;
    }
    let exact_root = |x: &rug::Integer| -> Option<rug::Integer> {
        let root = x.clone().sqrt();
        (root.clone() * root.clone() == *x).then_some(root)
    };
    let num_root = exact_root(r.numer())?;
    let den_root = exact_root(r.denom())?;
    Some(rug::Rational::from((num_root, den_root)))
}
/// Expresses `expr` as an element of `e = ℚ(x)(α)` with `α = √(a + √b)`.
/// `sqrt_outer` and `sqrt_inner` are the images of `α` and `√b` in `e`.
///
/// - Rational functions of `var` become constants.
/// - Sums, products and integer powers use `e`'s arithmetic.
/// - `sqrt(r)` and `r^(m/2)` are accepted when `r` is `a + √b` (mapped to `α`)
///   or `b` (mapped to `√b`); the power form raises that root to `m`.
///
/// Returns `None` for anything else.
#[allow(clippy::too_many_arguments)]
fn decompose_over_nested_radical(
    expr: ExprId,
    a: &QPoly,
    b: &QPoly,
    sqrt_outer: &AlgElem,
    sqrt_inner: &AlgElem,
    e: &AlgExtension,
    var: ExprId,
    pool: &ExprPool,
) -> Option<AlgElem> {
    use crate::kernel::ExprData;
    // Rational functions of `var` are base-field constants.
    if let Some((num, den)) = expr_to_qrational(expr, var, pool) {
        return Some(e.constant(RatFn::new(num, den)));
    }
    // Maps a square-root radicand to the matching basis element: `a + √b` ↦ α,
    // `b` ↦ √b; anything else is unsupported.
    let radical_in_basis = |radicand: ExprId| -> Option<AlgElem> {
        if let Some((ra, rb)) = match_a_plus_sqrt_b(radicand, var, pool) {
            if trim(ra) == trim(a.clone()) && trim(rb) == trim(b.clone()) {
                return Some(sqrt_outer.clone());
            }
        }
        if let Some(rb) = expr_to_qpoly(radicand, var, pool) {
            if trim(rb) == trim(b.clone()) {
                return Some(sqrt_inner.clone());
            }
        }
        None
    };
    match pool.get(expr) {
        ExprData::Add(args) => {
            let mut acc = e.from_int(0);
            for &arg in &args {
                acc = e.add(
                    &acc,
                    &decompose_over_nested_radical(
                        arg, a, b, sqrt_outer, sqrt_inner, e, var, pool,
                    )?,
                );
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc = e.from_int(1);
            for &arg in &args {
                acc = e.mul(
                    &acc,
                    &decompose_over_nested_radical(
                        arg, a, b, sqrt_outer, sqrt_inner, e, var, pool,
                    )?,
                );
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => match pool.get(exp) {
            ExprData::Integer(m) => {
                let m = m.0.to_i64()?;
                let bb = decompose_over_nested_radical(
                    base, a, b, sqrt_outer, sqrt_inner, e, var, pool,
                )?;
                e.pow(&bb, m)
            }
            // r^(m/2) = (√r)^m
            ExprData::Rational(r) if r.0.denom().to_i64() == Some(2) => {
                let m = r.0.numer().to_i64()?;
                let g = radical_in_basis(base)?;
                e.pow(&g, m)
            }
            _ => None,
        },
        ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => {
            radical_in_basis(args[0])
        }
        _ => None,
    }
}
/// Attempts the Risch differential equation for `c_rest · exp(η)^k` when
/// `c_rest` involves exactly two distinct square roots `√p` and `√q` of
/// polynomials. It uses the primitive element `α = √p + √q`.
///
/// From `(α² − (p + q))² = 4pq`, `α` is a root of `α⁴ − 2(p + q)·α² + (p − q)²`
/// (`q_min`). `√p` and `√q` are written in the basis `1, α, α², α³` by
/// `sqrt_in_alpha_basis`. The coupled equation is solved by `solve_alg_rde`,
/// and the solution `Σ yⱼ·αʲ` is rebuilt with `α = sqrt(p) + sqrt(q)`.
///
/// NOTE(review): irreducibility of `q_min` (e.g. `p·q` not a square) is not
/// checked here.
///
/// Returns:
/// - `None` when the pattern does not apply;
/// - `Some(Err(NonElementary))` when `solve_alg_rde` finds no solution;
/// - `Some(Ok(result))` otherwise, after pushing a `risch_exp_compositum_poly`
///   log step.
#[allow(clippy::too_many_arguments)]
fn try_compositum_poly_rde(
    c_rest: ExprId,
    k: i64,
    deta: &[rug::Rational],
    exp_k_eta: ExprId,
    k_const: ExprId,
    c_expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Option<Result<ExprId, IntegrationError>> {
    let (p, q) = detect_two_sqrt_compositum(c_rest, var, pool)?;
    // p = q would make sqrt_in_alpha_basis divide by zero.
    if trim(poly_sub(&q, &p)).is_empty() {
        return None;
    }
    let pq_sum = poly_add(&p, &q);
    let pq_diff = poly_sub(&p, &q);
    // Coefficients of α⁴ − 2(p + q)·α² + (p − q)², lowest degree first.
    let q_min = vec![
        poly_mul(&pq_diff, &pq_diff),
        poly_zero(),
        poly_scale(&pq_sum, &rug::Rational::from(-2)),
        poly_zero(),
        poly_one(),
    ];
    let e = AlgExtension::new(&q_min);
    let sqrt_p = sqrt_in_alpha_basis(&p, &q);
    let sqrt_q = sqrt_in_alpha_basis(&q, &p);
    let g = decompose_over_compositum(c_rest, &p, &q, &sqrt_p, &sqrt_q, &e, var, pool)?;
    // f = k · η'
    let f_poly: QPoly = deta
        .iter()
        .map(|r| r.clone() * rug::Rational::from(k))
        .collect();
    let f = RatFn::from_poly(&f_poly);
    let ne = || {
        IntegrationError::NonElementary(format!(
            "the coupled Risch DE over ℚ(x)(√p+√q) for ∫ {} · exp(η)^{k} dx \
             has no rational solution",
            pool.display(c_expr),
        ))
    };
    let y = match solve_alg_rde(&e, &f, &g) {
        Some(y) => y,
        None => return Some(Err(ne())),
    };
    // Rebuild the symbolic generator α = sqrt(p) + sqrt(q).
    let p_expr = qpoly_to_expr(&p, var, pool);
    let q_expr = qpoly_to_expr(&q, var, pool);
    let alpha = pool.add(vec![
        pool.func("sqrt", vec![p_expr]),
        pool.func("sqrt", vec![q_expr]),
    ]);
    let mut terms: Vec<ExprId> = Vec::new();
    for (j, yj) in y.iter().enumerate() {
        if yj.numer().is_empty() {
            continue;
        }
        let coeff = build_rational(yj.numer(), yj.denom(), var, pool);
        let term = if j == 0 {
            coeff
        } else {
            pool.mul(vec![coeff, pool.pow(alpha, pool.integer(j as i32))])
        };
        terms.push(term);
    }
    let v_expr = match terms.len() {
        0 => pool.integer(0_i32),
        1 => terms[0],
        _ => pool.add(terms),
    };
    let core = build_v_times_exp(v_expr, exp_k_eta, pool);
    let result = apply_const(k_const, core, pool);
    log.push(RewriteStep::simple(
        "risch_exp_compositum_poly",
        c_expr,
        result,
    ));
    Some(Ok(result))
}
/// Expresses `√s` in the power basis `1, α, α², α³` of `α = √s + √t`, where
/// `s = self_rad` and `t = other_rad`.
///
/// Since `α³ = (s + 3t)·√s + (3s + t)·√t`, eliminating `√t` with `α` gives
/// `√s = (α³ − (3s + t)·α) / (2(t − s))`. The coefficient vector is therefore
/// `[0, −(3s + t)/(2(t − s)), 0, 1/(2(t − s))]`. The caller must ensure
/// `s ≠ t`.
fn sqrt_in_alpha_basis(self_rad: &QPoly, other_rad: &QPoly) -> AlgElem {
    let denominator = poly_scale(&poly_sub(other_rad, self_rad), &rug::Rational::from(2));
    let three_s_plus_t = poly_add(&poly_scale(self_rad, &rug::Rational::from(3)), other_rad);
    let alpha_coeff = RatFn::new(
        poly_scale(&three_s_plus_t, &rug::Rational::from(-1)),
        denominator.clone(),
    );
    let alpha_cubed_coeff = RatFn::new(poly_one(), denominator);
    vec![RatFn::int(0), alpha_coeff, RatFn::int(0), alpha_cubed_coeff]
}
/// Returns the two distinct polynomial radicands `(p, q)` (in order of first
/// appearance) when `expr` contains square roots of exactly two different
/// polynomials of degree ≥ 1. Otherwise returns `None`.
fn detect_two_sqrt_compositum(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Option<(QPoly, QPoly)> {
    let mut radicands = Vec::new();
    scan_sqrt_radicands(expr, var, pool, &mut radicands);
    let mut distinct: Vec<QPoly> = Vec::with_capacity(2);
    for candidate in radicands {
        let key = trim(candidate.clone());
        if distinct.iter().all(|seen| trim(seen.clone()) != key) {
            distinct.push(candidate);
        }
    }
    if distinct.len() != 2 {
        return None;
    }
    let mut pair = distinct.into_iter();
    Some((pair.next()?, pair.next()?))
}
/// Pushes the radicand `p` of every `sqrt(p)` or `p^(m/2)` in `expr` onto `out`,
/// provided `p` is a polynomial in `var` of degree ≥ 1.
///
/// Sums, products and the bases of other powers are searched. The inside of a
/// square root is not searched.
fn scan_sqrt_radicands(expr: ExprId, var: ExprId, pool: &ExprPool, out: &mut Vec<QPoly>) {
    use crate::kernel::ExprData;
    let radicand = match pool.get(expr) {
        ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => args[0],
        ExprData::Pow { base, exp } => {
            let is_half_power = matches!(
                pool.get(exp),
                ExprData::Rational(ref r) if r.0.denom().to_i64() == Some(2)
            );
            if !is_half_power {
                scan_sqrt_radicands(base, var, pool, out);
                return;
            }
            base
        }
        ExprData::Add(args) | ExprData::Mul(args) => {
            for child in args.iter().copied() {
                scan_sqrt_radicands(child, var, pool, out);
            }
            return;
        }
        _ => return,
    };
    if let Some(p) = expr_to_qpoly(radicand, var, pool).filter(|p| degree(p) >= 1) {
        out.push(p);
    }
}
/// Expresses `expr` as an element of `e = ℚ(x)(α)` with `α = √p + √q`.
/// `sqrt_p` and `sqrt_q` are the images of `√p` and `√q` in `e`.
///
/// - Rational functions of `var` become constants.
/// - Sums, products and integer powers use `e`'s arithmetic.
/// - `sqrt(r)` and `r^(m/2)` are accepted when `r` equals `p` or `q` as a
///   polynomial (after trimming), giving that root to the power `m`.
///
/// Returns `None` for anything else.
#[allow(clippy::too_many_arguments)]
fn decompose_over_compositum(
    expr: ExprId,
    p: &QPoly,
    q: &QPoly,
    sqrt_p: &AlgElem,
    sqrt_q: &AlgElem,
    e: &AlgExtension,
    var: ExprId,
    pool: &ExprPool,
) -> Option<AlgElem> {
    use crate::kernel::ExprData;
    // Rational functions of `var` are base-field constants.
    if let Some((num, den)) = expr_to_qrational(expr, var, pool) {
        return Some(e.constant(RatFn::new(num, den)));
    }
    // base^(a/2) = (√base)^a for base ∈ {p, q}.
    let half_power = |base: ExprId, a: i64| -> Option<AlgElem> {
        let r = trim(expr_to_qpoly(base, var, pool)?);
        if r == trim(p.clone()) {
            e.pow(sqrt_p, a)
        } else if r == trim(q.clone()) {
            e.pow(sqrt_q, a)
        } else {
            None
        }
    };
    match pool.get(expr) {
        ExprData::Add(args) => {
            let mut acc = e.from_int(0);
            for &a in &args {
                acc = e.add(
                    &acc,
                    &decompose_over_compositum(a, p, q, sqrt_p, sqrt_q, e, var, pool)?,
                );
            }
            Some(acc)
        }
        ExprData::Mul(args) => {
            let mut acc = e.from_int(1);
            for &a in &args {
                acc = e.mul(
                    &acc,
                    &decompose_over_compositum(a, p, q, sqrt_p, sqrt_q, e, var, pool)?,
                );
            }
            Some(acc)
        }
        ExprData::Pow { base, exp } => match pool.get(exp) {
            ExprData::Integer(m) => {
                let m = m.0.to_i64()?;
                let b = decompose_over_compositum(base, p, q, sqrt_p, sqrt_q, e, var, pool)?;
                e.pow(&b, m)
            }
            ExprData::Rational(r) if r.0.denom().to_i64() == Some(2) => {
                half_power(base, r.0.numer().to_i64()?)
            }
            _ => None,
        },
        ExprData::Func { ref name, ref args } if name == "sqrt" && args.len() == 1 => {
            half_power(args[0], 1)
        }
        _ => None,
    }
}
pub fn needs_exp_risch(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
needs_exp_risch_inner(expr, var, pool)
}
/// `true` when `expr` is `base^n` with a negative integer literal `n` (fitting
/// in `i64`) and a `base` that depends on `var`. Such a factor acts as a
/// `var`-dependent denominator.
fn is_var_dependent_denominator(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;
    let ExprData::Pow { base, exp } = pool.get(expr) else {
        return false;
    };
    let negative_integer_exponent = match pool.get(exp) {
        ExprData::Integer(n) => n.0.to_i64().is_some_and(|v| v < 0),
        _ => false,
    };
    negative_integer_exponent && !is_free_of_var(base, var, pool)
}
/// Recursive worker for `needs_exp_risch`.
///
/// - A bare `exp(η)` qualifies when `η` depends on `var` and is not a polynomial
///   of degree ≤ 1 (degree ≥ 2, or not a polynomial at all).
/// - A product qualifies when any direct `exp` factor is non-linear in that
///   sense. It also qualifies when a linear `exp` factor is multiplied by:
///   - a polynomial factor of degree ≥ 2;
///   - a negative integer power of a `var`-dependent base;
///   - a factor containing a radical in `var` (see `contains_var_algebraic`).
/// - A product that does not qualify this way has each factor examined
///   recursively.
/// - A sum qualifies if any term does; a power if its base or exponent does.
fn needs_exp_risch_inner(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    use crate::kernel::ExprData;
    match pool.get(expr) {
        ExprData::Func { ref name, ref args } if name == "exp" && args.len() == 1 => {
            let eta = args[0];
            if is_free_of_var(eta, var, pool) {
                return false;
            }
            if let Some(d) = poly_degree(eta, var, pool) {
                if d >= 2 {
                    return true;
                }
                return false;
            }
            // Non-polynomial exponent.
            true
        }
        ExprData::Mul(args) => {
            // Classify the direct factors of the product.
            let mut has_linear_exp = false;
            let mut max_poly_deg: u32 = 0;
            let mut has_nonlinear_exp = false;
            let mut has_rational_coeff = false;
            for &a in &args {
                match pool.get(a) {
                    ExprData::Func { ref name, ref args } if name == "exp" && args.len() == 1 => {
                        let eta = args[0];
                        if is_free_of_var(eta, var, pool) {
                            // Constant exponent: exp(c) is just a constant factor.
                        } else if let Some(d) = poly_degree(eta, var, pool) {
                            if d >= 2 {
                                has_nonlinear_exp = true;
                            } else {
                                has_linear_exp = true;
                            }
                        } else {
                            has_nonlinear_exp = true;
                        }
                    }
                    _ => {
                        if let Some(d) = poly_degree(a, var, pool) {
                            max_poly_deg = max_poly_deg.max(d);
                        } else if is_var_dependent_denominator(a, var, pool) {
                            has_rational_coeff = true;
                        }
                    }
                }
            }
            if has_nonlinear_exp {
                return true;
            }
            // From here on has_nonlinear_exp is false, so the `||` checks below
            // effectively test has_linear_exp alone.
            if has_linear_exp && max_poly_deg >= 2 {
                return true;
            }
            if (has_linear_exp || has_nonlinear_exp) && has_rational_coeff {
                return true;
            }
            let has_var_alg = args.iter().any(|&a| contains_var_algebraic(a, var, pool));
            if (has_linear_exp || has_nonlinear_exp) && has_var_alg {
                return true;
            }
            args.iter().any(|&a| needs_exp_risch_inner(a, var, pool))
        }
        ExprData::Add(args) => args.iter().any(|&a| needs_exp_risch_inner(a, var, pool)),
        ExprData::Pow { base, exp } => {
            needs_exp_risch_inner(base, var, pool) || needs_exp_risch_inner(exp, var, pool)
        }
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the exponential-tower integrator and the radical /
    //! algebraic-constant detection helpers it relies on.
    //!
    //! Elementary cases are checked numerically: the returned antiderivative
    //! is differentiated, simplified, and compared with the integrand at a few
    //! sample points. Non-elementary cases must surface
    //! `IntegrationError::NonElementary`.
    use super::super::tower::find_generators;
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    fn pool() -> ExprPool {
        ExprPool::new()
    }

    /// Builds a `QPoly` from integer coefficients, lowest degree first
    /// (e.g. `qp(&[1, 1])` is `1 + x`).
    fn qp(coeffs: &[i64]) -> QPoly {
        coeffs.iter().map(|&c| rug::Rational::from(c)).collect()
    }

    /// Finds the first exponential generator of `integrand` and runs
    /// [`integrate_exp_tower`] on it with a fresh derivation log.
    ///
    /// Panics if `find_generators` reports no exponential level.
    fn integrate_on_exp_level(
        integrand: ExprId,
        x: ExprId,
        pool: &ExprPool,
    ) -> Result<ExprId, IntegrationError> {
        let gens = find_generators(integrand, x, pool);
        let level = gens
            .iter()
            .find(|g| g.is_exp())
            .expect("an exp generator");
        let mut log = DerivationLog::new();
        integrate_exp_tower(integrand, level, x, pool, &mut log)
    }

    /// Evaluates `expr` numerically with `x` bound to `xv`.
    ///
    /// Supports integers, rationals, the symbol `pi`, `+`, `*`, `^`, and the
    /// unary functions `exp`, `log`, `sqrt`, `cbrt`; panics on anything else.
    fn eval_f64(expr: ExprId, x: ExprId, xv: f64, pool: &ExprPool) -> f64 {
        use crate::kernel::ExprData;
        if expr == x {
            return xv;
        }
        match pool.get(expr) {
            ExprData::Symbol { ref name, .. } if name == "pi" => std::f64::consts::PI,
            ExprData::Integer(n) => n.0.to_f64(),
            ExprData::Rational(r) => r.0.to_f64(),
            ExprData::Add(args) => args.iter().map(|&a| eval_f64(a, x, xv, pool)).sum(),
            ExprData::Mul(args) => args.iter().map(|&a| eval_f64(a, x, xv, pool)).product(),
            ExprData::Pow { base, exp } => {
                eval_f64(base, x, xv, pool).powf(eval_f64(exp, x, xv, pool))
            }
            ExprData::Func { ref name, ref args } if args.len() == 1 => {
                let a = eval_f64(args[0], x, xv, pool);
                match name.as_str() {
                    "exp" => a.exp(),
                    "log" => a.ln(),
                    "sqrt" => a.sqrt(),
                    "cbrt" => a.cbrt(),
                    other => panic!("eval_f64: unsupported func {other}"),
                }
            }
            other => panic!("eval_f64: unsupported node {other:?}"),
        }
    }

    /// Asserts that the simplified derivative of `antideriv` matches
    /// `integrand` (absolute error below `tol`) at every point in `points`.
    fn assert_antiderivative(
        integrand: ExprId,
        antideriv: ExprId,
        x: ExprId,
        points: &[f64],
        tol: f64,
        pool: &ExprPool,
    ) {
        let d = crate::diff::diff(antideriv, x, pool).unwrap();
        let ds = crate::simplify::engine::simplify(d.value, pool).value;
        for &xv in points {
            let lhs = eval_f64(ds, x, xv, pool);
            let rhs = eval_f64(integrand, x, xv, pool);
            assert!(
                (lhs - rhs).abs() < tol,
                "d/dx F ≠ f at x={xv}: got {lhs}, expected {rhs}\n F = {}",
                pool.display(antideriv)
            );
        }
    }

    /// Integrates `integrand` on its exp level, requires success, and checks
    /// the result numerically at x ∈ {0.7, 1.3, 2.1}.
    fn verify_exp_tower(integrand: ExprId, x: ExprId, pool: &ExprPool) {
        let f = integrate_on_exp_level(integrand, x, pool).expect("should be elementary");
        assert_antiderivative(integrand, f, x, &[0.7, 1.3, 2.1], 1e-7, pool);
    }

    #[test]
    fn detect_radical_generator_forms() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let cbrt_x = pool.func("cbrt", vec![x]);
        assert_eq!(
            detect_radical_generator(cbrt_x, x, &pool),
            Some((3, qp(&[0, 1])))
        );
        let x_pow = pool.pow(x, pool.rational(1_i32, 3_i32));
        assert_eq!(
            detect_radical_generator(x_pow, x, &pool),
            Some((3, qp(&[0, 1])))
        );
        let sqrt = pool.func("sqrt", vec![pool.add(vec![x, pool.integer(1_i32)])]);
        assert_eq!(
            detect_radical_generator(sqrt, x, &pool),
            Some((2, qp(&[1, 1])))
        );
        assert_eq!(
            detect_radical_generator(pool.add(vec![x, pool.integer(1_i32)]), x, &pool),
            None
        );
        // Two different radicals in one expression are not a single generator.
        let mixed = pool.add(vec![pool.func("sqrt", vec![x]), pool.func("cbrt", vec![x])]);
        assert_eq!(detect_radical_generator(mixed, x, &pool), None);
    }

    #[test]
    fn decompose_x_plus_cbrt_x() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.func("cbrt", vec![x])]);
        let (e, elem) = decompose_radical(expr, x, &pool).expect("decomposes");
        assert_eq!(e.degree(), 3);
        // With y = cbrt(x): x + y ↦ [x, 1].
        let expected = vec![RatFn::from_poly(&qp(&[0, 1])), RatFn::int(1)];
        assert!(e.elem_eq(&elem, &expected), "got {elem:?}");
    }

    #[test]
    fn decompose_cbrt_x_squared() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sq = pool.pow(pool.func("cbrt", vec![x]), pool.integer(2_i32));
        let (e, elem) = decompose_radical(sq, x, &pool).expect("decomposes");
        let expected = vec![RatFn::int(0), RatFn::int(0), RatFn::int(1)];
        assert!(e.elem_eq(&elem, &expected), "got {elem:?}");
        // cbrt(x)^2 and x^(2/3) must decompose to the same element y².
        let frac = pool.pow(x, pool.rational(2_i32, 3_i32));
        let (e2, elem2) = decompose_radical(frac, x, &pool).expect("decomposes");
        assert!(e2.elem_eq(&elem2, &expected), "x^(2/3): got {elem2:?}");
    }

    #[test]
    fn decompose_inverse_of_one_plus_cbrt_x() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let base = pool.add(vec![pool.integer(1_i32), pool.func("cbrt", vec![x])]);
        let inv_expr = pool.pow(base, pool.integer(-1_i32));
        let (e, inv_elem) = decompose_radical(inv_expr, x, &pool).expect("decomposes");
        // Multiplying the decomposed inverse back by 1 + y must give 1.
        let base_elem = vec![RatFn::int(1), RatFn::int(1)];
        let product = e.mul(&inv_elem, &base_elem);
        assert!(
            e.elem_eq(&product, &e.from_int(1)),
            "inv·(1+y) = {product:?}"
        );
    }

    #[test]
    fn decompose_degree2_matches_kpair_semantics() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![x, pool.func("sqrt", vec![x])]);
        let (e, elem) = decompose_radical(expr, x, &pool).expect("decomposes");
        assert_eq!(e.degree(), 2);
        let expected = vec![RatFn::from_poly(&qp(&[0, 1])), RatFn::int(1)];
        assert!(e.elem_eq(&elem, &expected), "got {elem:?}");
    }

    #[test]
    fn exp_x2_is_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        assert_eq!(find_generators(exp_x2, x, &pool).len(), 1);
        let result = integrate_on_exp_level(exp_x2, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ exp(x²) dx should be NonElementary, got: {result:?}"
        );
    }

    #[test]
    fn mixed_cbrt_times_exp_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        // (1 + 1/(3x))·x^(1/3)·eˣ = d/dx [x^(1/3)·eˣ].
        let x13 = pool.pow(x, pool.rational(1_i32, 3_i32));
        let inv_3x = pool.mul(vec![
            pool.rational(1_i32, 3_i32),
            pool.pow(x, pool.integer(-1_i32)),
        ]);
        let coeff = pool.add(vec![pool.integer(1_i32), inv_3x]);
        let integrand = pool.mul(vec![coeff, x13, pool.func("exp", vec![x])]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn cbrt_times_exp_is_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let x13 = pool.pow(x, pool.rational(1_i32, 3_i32));
        let integrand = pool.mul(vec![x13, pool.func("exp", vec![x])]);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ x^(1/3)·exp(x) dx should be NonElementary, got: {result:?}"
        );
    }

    #[test]
    fn x_times_exp_x2_is_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        let integrand = pool.mul(vec![x, exp_x2]);
        assert_eq!(find_generators(integrand, x, &pool).len(), 1);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            result.is_ok(),
            "∫ x·exp(x²) dx should be elementary, got: {result:?}"
        );
        let antideriv = result.unwrap();
        let s = pool.display(antideriv).to_string();
        assert!(s.contains("exp"), "result should contain exp: {s}");
        // The textual check alone cannot catch a wrong coefficient.
        assert_antiderivative(integrand, antideriv, x, &[0.7, 1.3, 2.1], 1e-7, &pool);
    }

    #[test]
    fn two_x_times_exp_x2_equals_exp_x2() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        let two = pool.integer(2_i32);
        let integrand = pool.mul(vec![two, x, exp_x2]);
        let result = integrate_on_exp_level(integrand, x, &pool).unwrap();
        let s = pool.display(result).to_string();
        assert!(s.contains("exp"), "result should contain exp: {s}");
        // d/dx F = 2x·exp(x²) pins F to exp(x²) up to an additive constant.
        assert_antiderivative(integrand, result, x, &[0.7, 1.3, 2.1], 1e-7, &pool);
    }

    #[test]
    fn x2_times_exp_x_is_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x = pool.func("exp", vec![x]);
        let integrand = pool.mul(vec![x2, exp_x]);
        assert_eq!(find_generators(integrand, x, &pool).len(), 1);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            result.is_ok(),
            "∫ x²·exp(x) dx should be elementary, got: {result:?}"
        );
        let antideriv = result.unwrap();
        let s = pool.display(antideriv).to_string();
        assert!(s.contains("exp"), "result should contain exp: {s}");
        assert_antiderivative(integrand, antideriv, x, &[0.7, 1.3, 2.1], 1e-7, &pool);
    }

    #[test]
    fn needs_exp_risch_detection() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        assert!(needs_exp_risch(exp_x2, x, &pool));
        let exp_x = pool.func("exp", vec![x]);
        assert!(!needs_exp_risch(exp_x, x, &pool));
        let x2_times_exp_x = pool.mul(vec![x2, exp_x]);
        assert!(needs_exp_risch(x2_times_exp_x, x, &pool));
        let x_times_exp_x = pool.mul(vec![x, exp_x]);
        assert!(!needs_exp_risch(x_times_exp_x, x, &pool));
        let x_times_exp_x2 = pool.mul(vec![x, exp_x2]);
        assert!(needs_exp_risch(x_times_exp_x2, x, &pool));
    }

    #[test]
    fn compositum_two_sqrts_times_exp() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        let xp1 = pool.add(vec![x, pool.integer(1_i32)]);
        let sx = pool.func("sqrt", vec![x]);
        let sx1 = pool.func("sqrt", vec![xp1]);
        // (√x + √(x+1) + 1/(2√x) + 1/(2√(x+1)))·eˣ = d/dx [(√x + √(x+1))·eˣ].
        let half = pool.rational(1_i32, 2_i32);
        let inv2sx = pool.mul(vec![half, pool.pow(x, pool.rational(-1_i32, 2_i32))]);
        let inv2sx1 = pool.mul(vec![half, pool.pow(xp1, pool.rational(-1_i32, 2_i32))]);
        let c_rest = pool.add(vec![sx, sx1, inv2sx, inv2sx1]);
        let integrand = pool.mul(vec![c_rest, exp_x]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn nested_radical_sqrt_x_plus_sqrt_x_times_exp() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        // α = √(x + √x), so α' = (2 + x^(-1/2)) / (4α).
        let sx = pool.func("sqrt", vec![x]);
        let inner = pool.add(vec![x, sx]);
        let alpha = pool.func("sqrt", vec![inner]);
        let x_inv_half = pool.pow(x, pool.rational(-1_i32, 2_i32));
        let num = pool.add(vec![pool.integer(2_i32), x_inv_half]);
        let alpha_inv = pool.pow(alpha, pool.integer(-1_i32));
        let quarter = pool.rational(1_i32, 4_i32);
        let d_alpha = pool.mul(vec![quarter, num, alpha_inv]);
        // (α' + α)·eˣ = d/dx [α·eˣ].
        let c_rest = pool.add(vec![d_alpha, alpha]);
        let integrand = pool.mul(vec![c_rest, exp_x]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn detect_nested_radical_cases() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sx = pool.func("sqrt", vec![x]);
        let nested = pool.func("sqrt", vec![pool.add(vec![x, sx])]);
        assert_eq!(
            detect_nested_radical(nested, x, &pool),
            Some((qp(&[0, 1]), qp(&[0, 1])))
        );
        let x2p1 = pool.add(vec![pool.pow(x, pool.integer(2_i32)), pool.integer(1_i32)]);
        let inner = pool.add(vec![pool.integer(1_i32), pool.func("sqrt", vec![x2p1])]);
        let nested2 = pool.func("sqrt", vec![inner]);
        assert_eq!(
            detect_nested_radical(nested2, x, &pool),
            Some((qp(&[1]), qp(&[1, 0, 1])))
        );
        // √(x²) is not a genuine radical extension, so this is not nested.
        let sqrt_x2 = pool.func("sqrt", vec![pool.pow(x, pool.integer(2_i32))]);
        let spurious = pool.func("sqrt", vec![pool.add(vec![x, sqrt_x2])]);
        assert_eq!(detect_nested_radical(spurious, x, &pool), None);
    }

    #[test]
    fn nested_radical_times_exp_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        let sx = pool.func("sqrt", vec![x]);
        let alpha = pool.func("sqrt", vec![pool.add(vec![x, sx])]);
        let integrand = pool.mul(vec![alpha, exp_x]);
        let r = crate::integrate::engine::integrate(integrand, x, &pool);
        assert!(
            matches!(r, Err(IntegrationError::NonElementary(_))),
            "expected NonElementary; got {r:?}"
        );
    }

    #[test]
    fn compositum_two_sqrts_times_exp_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        let xp1 = pool.add(vec![x, pool.integer(1_i32)]);
        // Same as `compositum_two_sqrts_times_exp` but without the
        // derivative-correction terms.
        let c_rest = pool.add(vec![
            pool.func("sqrt", vec![x]),
            pool.func("sqrt", vec![xp1]),
        ]);
        let integrand = pool.mul(vec![c_rest, exp_x]);
        let r = crate::integrate::engine::integrate(integrand, x, &pool);
        assert!(
            matches!(r, Err(IntegrationError::NonElementary(_))),
            "expected NonElementary; got {r:?}"
        );
    }

    #[test]
    fn rational_const_factor_exp_x2() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x2 = pool.func("exp", vec![pool.pow(x, pool.integer(2_i32))]);
        let half = pool.rational(1_i32, 2_i32);
        let integrand = pool.mul(vec![half, x, exp_x2]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn algebraic_const_factor_exp_x2() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x2 = pool.func("exp", vec![pool.pow(x, pool.integer(2_i32))]);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let integrand = pool.mul(vec![sqrt2, x, exp_x2]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn symbolic_const_factor_poly_exp_x() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let pi = pool.symbol("pi", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        let integrand = pool.mul(vec![pi, pool.pow(x, pool.integer(2_i32)), exp_x]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn algebraic_const_factor_rational_coeff() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x = pool.func("exp", vec![x]);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        // √2·(x − 1)/x²·eˣ = d/dx [√2·eˣ/x].
        let num = pool.add(vec![x, pool.integer(-1_i32)]);
        let inv_x2 = pool.pow(x, pool.integer(-2_i32));
        let integrand = pool.mul(vec![sqrt2, num, inv_x2, exp_x]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn algebraic_const_factor_nonelementary_preserved() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let exp_x2 = pool.func("exp", vec![pool.pow(x, pool.integer(2_i32))]);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let integrand = pool.mul(vec![sqrt2, exp_x2]);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ √2·exp(x²) dx must remain NonElementary; got {result:?}"
        );
    }

    #[test]
    fn algebraic_coeff_linear_exp_x() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let exp_x = pool.func("exp", vec![x]);
        let coeff = pool.add(vec![x, sqrt2]);
        let integrand = pool.mul(vec![coeff, exp_x]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn algebraic_coeff_quadratic_exp_x() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt3 = pool.func("sqrt", vec![pool.integer(3_i32)]);
        let exp_x = pool.func("exp", vec![x]);
        let coeff = pool.add(vec![
            pool.mul(vec![sqrt3, pool.pow(x, pool.integer(2_i32))]),
            x,
        ]);
        let integrand = pool.mul(vec![coeff, exp_x]);
        verify_exp_tower(integrand, x, &pool);
    }

    #[test]
    fn algebraic_coeff_nonelementary_preserved() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let exp_x2 = pool.func("exp", vec![pool.pow(x, pool.integer(2_i32))]);
        let coeff = pool.add(vec![x, sqrt2]);
        let integrand = pool.mul(vec![coeff, exp_x2]);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ (x+√2)·exp(x²) dx must be NonElementary; got {result:?}"
        );
    }

    #[test]
    fn algebraic_rational_coeff_exp_x() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let neg_sqrt2 = pool.mul(vec![pool.integer(-1_i32), sqrt2]);
        // (x − √2 − 1)/(x − √2)²·eˣ = d/dx [eˣ/(x − √2)].
        let base = pool.add(vec![x, neg_sqrt2]);
        let num = pool.add(vec![x, neg_sqrt2, pool.integer(-1_i32)]);
        let exp_x = pool.func("exp", vec![x]);
        let integrand = pool.mul(vec![num, pool.pow(base, pool.integer(-2_i32)), exp_x]);
        let f = integrate_on_exp_level(integrand, x, &pool)
            .expect("∫ (x−√2−1)/(x−√2)²·exp(x) dx should be elementary");
        // Sample points lie to the right of the pole at x = √2.
        assert_antiderivative(integrand, f, x, &[2.5, 3.3, 4.1], 1e-7, &pool);
    }

    #[test]
    fn algebraic_rational_coeff_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let neg_sqrt2 = pool.mul(vec![pool.integer(-1_i32), sqrt2]);
        let base = pool.add(vec![x, neg_sqrt2]);
        let exp_x = pool.func("exp", vec![x]);
        // x²/(x − √2) = x + √2 + 2/(x − √2); the simple-pole term leaves an
        // Ei-type remainder.
        let integrand = pool.mul(vec![
            pool.pow(x, pool.integer(2_i32)),
            pool.pow(base, pool.integer(-1_i32)),
            exp_x,
        ]);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ x²/(x−√2)·exp(x) dx must be NonElementary; got {result:?}"
        );
    }

    #[test]
    fn detect_sqrt_field_cases() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let e = pool.add(vec![x, sqrt2]);
        assert_eq!(detect_sqrt_field(e, &pool).map(|(d, _)| d), Some(2));
        // √4 is rational, so it does not define a quadratic field.
        let sqrt4 = pool.func("sqrt", vec![pool.integer(4_i32)]);
        let e = pool.add(vec![x, sqrt4]);
        assert_eq!(detect_sqrt_field(e, &pool), None);
        let sqrt3 = pool.func("sqrt", vec![pool.integer(3_i32)]);
        let e = pool.add(vec![sqrt2, sqrt3]);
        assert_eq!(detect_sqrt_field(e, &pool), None);
        let e = pool.add(vec![x, pool.integer(1_i32)]);
        assert_eq!(detect_sqrt_field(e, &pool), None);
    }

    #[test]
    fn split_const_factor_cases() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let c = pool.mul(vec![sqrt2, x]);
        let (k, rest) = split_const_factor(c, x, &pool);
        assert_eq!(k, sqrt2);
        assert_eq!(rest, x);
        let (k, rest) = split_const_factor(sqrt2, x, &pool);
        assert_eq!(k, sqrt2);
        assert_eq!(rest, pool.integer(1_i32));
        let (k, rest) = split_const_factor(x, x, &pool);
        assert_eq!(k, pool.integer(1_i32));
        assert_eq!(rest, x);
    }

    #[test]
    fn exp_inv_x_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let inv_x = pool.pow(x, pool.integer(-1_i32));
        let f = pool.func("exp", vec![inv_x]);
        assert!(
            needs_exp_risch(f, x, &pool),
            "exp(1/x) should be routed to Risch"
        );
        assert_eq!(find_generators(f, x, &pool).len(), 1);
        let result = integrate_on_exp_level(f, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ exp(1/x) dx must be NonElementary; got {result:?}"
        );
    }

    #[test]
    fn inv_x2_times_exp_inv_x_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        // (1/x²)·e^(1/x) = −d/dx e^(1/x).
        let inv_x = pool.pow(x, pool.integer(-1_i32));
        let exp_inv_x = pool.func("exp", vec![inv_x]);
        let inv_x2 = pool.pow(x, pool.integer(-2_i32));
        let integrand = pool.mul(vec![inv_x2, exp_inv_x]);
        assert_eq!(find_generators(integrand, x, &pool).len(), 1);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            result.is_ok(),
            "∫ (1/x²)·exp(1/x) dx must be elementary; got {result:?}"
        );
        assert_antiderivative(integrand, result.unwrap(), x, &[0.5, 1.0, 2.0], 1e-8, &pool);
    }

    #[test]
    fn two_inv_x3_times_exp_neg_inv_x2_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        // (2/x³)·e^(−1/x²) = d/dx e^(−1/x²).
        let neg_inv_x2 = pool.mul(vec![
            pool.integer(-1_i32),
            pool.pow(x, pool.integer(-2_i32)),
        ]);
        let exp_neg_inv_x2 = pool.func("exp", vec![neg_inv_x2]);
        let two_inv_x3 = pool.mul(vec![pool.integer(2_i32), pool.pow(x, pool.integer(-3_i32))]);
        let integrand = pool.mul(vec![two_inv_x3, exp_neg_inv_x2]);
        assert_eq!(find_generators(integrand, x, &pool).len(), 1);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            result.is_ok(),
            "∫ (2/x³)·exp(−1/x²) dx must be elementary; got {result:?}"
        );
        assert_antiderivative(integrand, result.unwrap(), x, &[0.5, 1.0, 2.0], 1e-8, &pool);
    }

    #[test]
    fn detection_rational_eta() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let inv_x = pool.pow(x, pool.integer(-1_i32));
        let exp_inv_x = pool.func("exp", vec![inv_x]);
        assert!(
            needs_exp_risch(exp_inv_x, x, &pool),
            "exp(1/x) should need Risch"
        );
        let exp_x = pool.func("exp", vec![x]);
        assert!(
            !needs_exp_risch(exp_x, x, &pool),
            "exp(x) alone should NOT route to Risch"
        );
        let x_exp_inv_x = pool.mul(vec![x, exp_inv_x]);
        assert!(
            needs_exp_risch(x_exp_inv_x, x, &pool),
            "x·exp(1/x) should need Risch"
        );
    }

    #[test]
    fn compositum_two_sqrts_exp_x2_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let sqrt3 = pool.func("sqrt", vec![pool.integer(3_i32)]);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let exp_x2 = pool.func("exp", vec![x2]);
        // x·exp(x²) alone is elementary, but the constant offset
        // (√2 + √3)·exp(x²) is not, so the whole integral is not.
        let coeff = pool.add(vec![x, sqrt2, sqrt3]);
        let integrand = pool.mul(vec![coeff, exp_x2]);
        assert_eq!(find_generators(integrand, x, &pool).len(), 1);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ (x+√2+√3)·exp(x²) dx: const-offset term is non-elementary; got {result:?}"
        );
    }

    #[test]
    fn compositum_two_sqrts_exp_x_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let sqrt3 = pool.func("sqrt", vec![pool.integer(3_i32)]);
        let exp_x = pool.func("exp", vec![x]);
        let coeff = pool.add(vec![x, sqrt2, sqrt3]);
        let integrand = pool.mul(vec![coeff, exp_x]);
        assert_eq!(find_generators(integrand, x, &pool).len(), 1);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            result.is_ok(),
            "∫ (x+√2+√3)·exp(x) dx must be elementary; got {result:?}"
        );
        assert_antiderivative(integrand, result.unwrap(), x, &[0.5, 1.0, 2.0], 1e-8, &pool);
    }

    #[test]
    fn nth_root_cbrt3_exp_x_elementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let cbrt3 = pool.func("cbrt", vec![pool.integer(3_i32)]);
        let exp_x = pool.func("exp", vec![x]);
        let coeff = pool.add(vec![x, cbrt3]);
        let integrand = pool.mul(vec![coeff, exp_x]);
        assert_eq!(find_generators(integrand, x, &pool).len(), 1);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            result.is_ok(),
            "∫ (x+cbrt(3))·exp(x) dx must be elementary; got {result:?}"
        );
        // Compare a central finite difference of F against the integrand.
        let cbrt3_v: f64 = 3.0f64.powf(1.0 / 3.0);
        let f = result.unwrap();
        let eval = |expr: ExprId, xv: f64| -> f64 { eval_f64(expr, x, xv, &pool) };
        let h = 1e-6_f64;
        for &xv in &[0.5_f64, 1.2, 2.7] {
            let fd = (eval(f, xv + h) - eval(f, xv - h)) / (2.0 * h);
            let exact = (xv + cbrt3_v) * xv.exp();
            assert!(
                (fd - exact).abs() < 1e-5,
                "finite-diff check at x={xv}: fd={fd}, exact={exact}"
            );
        }
    }

    #[test]
    fn nth_root_pow_1_3_exp_x2_nonelementary() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let cbrt2 = pool.pow(pool.integer(2_i32), pool.rational(1_i32, 3_i32));
        let x2 = pool.pow(x, pool.integer(2_i32));
        let integrand = pool.mul(vec![cbrt2, pool.func("exp", vec![x2])]);
        let result = integrate_on_exp_level(integrand, x, &pool);
        assert!(
            matches!(result, Err(IntegrationError::NonElementary(_))),
            "∫ 2^(1/3)·exp(x²) dx must be NonElementary; got {result:?}"
        );
    }

    #[test]
    fn detect_algebraic_extension_cases() {
        let pool = pool();
        let x = pool.symbol("x", Domain::Real);
        let sqrt2 = pool.func("sqrt", vec![pool.integer(2_i32)]);
        let sqrt3 = pool.func("sqrt", vec![pool.integer(3_i32)]);
        let cbrt5 = pool.func("cbrt", vec![pool.integer(5_i32)]);
        let e1 = pool.add(vec![x, sqrt2]);
        assert!(
            matches!(
                detect_algebraic_extension(e1, &pool),
                Some(AlgebraicExtension::SingleSqrt { d: 2, .. })
            ),
            "x+√2 should give SingleSqrt(2)"
        );
        let e2 = pool.add(vec![x, sqrt2, sqrt3]);
        assert!(
            matches!(
                detect_algebraic_extension(e2, &pool),
                Some(AlgebraicExtension::CompositumTwoSqrts { a: 2, b: 3, .. })
            ),
            "x+√2+√3 should give CompositumTwoSqrts(2,3)"
        );
        let e3 = pool.add(vec![x, cbrt5]);
        assert!(
            matches!(
                detect_algebraic_extension(e3, &pool),
                Some(AlgebraicExtension::NthRoot { n: 5, m: 3, .. })
            ),
            "x+cbrt(5) should give NthRoot(5,3)"
        );
        assert!(detect_algebraic_extension(x, &pool).is_none());
    }
}