use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
use crate::kernel::{ExprData, ExprId, ExprPool};
use crate::simplify::engine::simplify;
use std::collections::HashMap;
use std::fmt;
/// Ways in which symbolic integration can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntegrationError {
    /// The integrand lies outside the implemented rule set; the payload describes it.
    NotImplemented(String),
    /// A division by zero was encountered.
    DivisionByZero,
    /// An algebraic extension of the given degree appeared; only degree 2 (`sqrt`) is handled.
    UnsupportedExtensionDegree(u32),
    /// The integrand is known to have no elementary antiderivative; the payload says why.
    NonElementary(String),
}

impl fmt::Display for IntegrationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotImplemented(what) => write!(f, "integrate: not implemented: {what}"),
            Self::DivisionByZero => f.write_str("integrate: division by zero"),
            Self::UnsupportedExtensionDegree(degree) => write!(
                f,
                "integrate: algebraic extension of degree {degree} is not supported \
                 (v1.1 supports only degree-2 / sqrt extensions)"
            ),
            Self::NonElementary(why) => {
                write!(f, "integrate: no elementary antiderivative exists: {why}")
            }
        }
    }
}

impl std::error::Error for IntegrationError {}

impl IntegrationError {
    /// A short, user-facing hint on how to work around this error, if one exists.
    pub fn remediation(&self) -> Option<&'static str> {
        let hint = match self {
            Self::DivisionByZero => return None,
            Self::NotImplemented(_) => {
                "only power, linearity, sin/cos/exp rules and algebraic (sqrt) rules \
                 are implemented; use a numeric integrator for arbitrary functions"
            }
            Self::UnsupportedExtensionDegree(_) => {
                "v1.1 supports sqrt(P(x)) only; higher-degree radicals (cbrt, nth-root) \
                 are planned for v2.0"
            }
            Self::NonElementary(_) => {
                "this integrand has no closed-form antiderivative in terms of elementary \
                 functions; use a numeric integrator or elliptic-integral library"
            }
        };
        Some(hint)
    }

    /// Source span associated with the error; integration errors never carry one.
    pub fn span(&self) -> Option<(usize, usize)> {
        None
    }
}
/// Exposes integration errors through the crate-wide diagnostic interface.
impl crate::errors::AlkahestError for IntegrationError {
    /// Stable diagnostic code; one per variant.
    fn code(&self) -> &'static str {
        let code = match self {
            Self::NotImplemented(_) => "E-INT-001",
            Self::DivisionByZero => "E-INT-002",
            Self::UnsupportedExtensionDegree(_) => "E-INT-003",
            Self::NonElementary(_) => "E-INT-004",
        };
        code
    }

    /// Forwards to the inherent `IntegrationError::remediation`.
    fn remediation(&self) -> Option<&'static str> {
        IntegrationError::remediation(self)
    }
}
/// Integrates `c · log(h)^n` when `c` is exactly `h'/h` (as a rational function of `var`).
///
/// Returns `log(h)^(n+1) / (n+1)` for `n ≠ -1`, and `log(log(h))` for `n = -1`.
/// Returns `None` when the integrand has anything other than a single logarithmic
/// tower generator, is not of the form `c · θⁿ` with `n ≠ 0`, or when `c` differs
/// from `h'/h` (constant multiples of `h'/h` are not matched).
fn try_log_derivative(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<ExprId> {
    use super::risch::poly_rde::{poly_mul, rational_to_expr, trim};
    use super::risch::rational_rde::expr_to_qrational;
    use super::risch::tower::find_generators;

    // Require exactly one tower generator, and it must be a logarithm θ = log(h).
    let generators = find_generators(expr, var, pool);
    let [log_gen] = generators.as_slice() else {
        return None;
    };
    if !log_gen.is_log() {
        return None;
    }
    let theta = log_gen.generator;
    let h = log_gen.argument();

    // Split the integrand into coefficient · θⁿ.
    let (coeff, n) = extract_log_power(expr, theta, pool)?;
    if n == 0 {
        return None;
    }

    // Cross-multiply to test coeff == h'/h without dividing polynomials.
    let (c_num, c_den) = expr_to_qrational(coeff, var, pool)?;
    let h_prime = crate::diff::diff(h, var, pool).ok()?.value;
    let (hp_num, hp_den) = expr_to_qrational(h_prime, var, pool)?;
    let (h_num, h_den) = expr_to_qrational(h, var, pool)?;
    let ratio_num = poly_mul(&hp_num, &h_den);
    let ratio_den = poly_mul(&hp_den, &h_num);
    let lhs = trim(poly_mul(&c_num, &ratio_den));
    let rhs = trim(poly_mul(&ratio_num, &c_den));
    if lhs != rhs {
        return None;
    }

    if n == -1 {
        return Some(pool.func("log", vec![theta]));
    }
    let exponent = n + 1;
    let power = pool.pow(theta, pool.integer(exponent));
    let scale = rational_to_expr(&rug::Rational::from((1_i64, exponent)), pool);
    Some(pool.mul(vec![scale, power]))
}
/// Decomposes `expr` as `coeff · θⁿ` for the given generator `theta`.
///
/// Accepts `θ`, `θ^m` (integer `m`), or a product whose `θ`/`θ^m` factors are summed
/// into `n` and whose remaining factors form `coeff`. Returns `None` for other shapes,
/// when a product's total exponent is zero, or when an integer exponent does not fit
/// in `i64`.
fn extract_log_power(expr: ExprId, theta: ExprId, pool: &ExprPool) -> Option<(ExprId, i64)> {
    if expr == theta {
        return Some((pool.integer(1_i32), 1));
    }
    let factors = match pool.get(expr) {
        ExprData::Pow { base, exp } if base == theta => {
            return match pool.get(exp) {
                ExprData::Integer(m) => Some((pool.integer(1_i32), m.0.to_i64()?)),
                _ => None,
            };
        }
        ExprData::Mul(factors) => factors,
        _ => return None,
    };

    let mut total_power: i64 = 0;
    let mut coeff_factors: Vec<ExprId> = Vec::new();
    for factor in factors {
        if factor == theta {
            total_power += 1;
            continue;
        }
        // `Some(k)` when the factor is θ^k with integer k; a non-fitting k aborts.
        let theta_power = match pool.get(factor) {
            ExprData::Pow { base, exp } if base == theta => match pool.get(exp) {
                ExprData::Integer(m) => Some(m.0.to_i64()?),
                _ => None,
            },
            _ => None,
        };
        match theta_power {
            Some(k) => total_power += k,
            None => coeff_factors.push(factor),
        }
    }

    if total_power == 0 {
        return None;
    }
    let coeff = match coeff_factors.as_slice() {
        [] => pool.integer(1_i32),
        [single] => *single,
        _ => pool.mul(coeff_factors),
    };
    Some((coeff, total_power))
}
/// Returns the value of `expr` if it is an integer literal that fits in `i64`.
fn as_integer(expr: ExprId, pool: &ExprPool) -> Option<i64> {
    pool.with(expr, |data| {
        if let ExprData::Integer(value) = data {
            value.0.to_i64()
        } else {
            None
        }
    })
}
/// Returns `true` when `var` does not occur anywhere inside `expr`.
///
/// Only `Add`, `Mul`, `Pow` and `Func` nodes are descended into; every other node
/// kind (other than `var` itself) is treated as a leaf that is free of `var`.
fn is_free_of(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    is_free_of_inner(expr, var, pool, &mut HashMap::new())
}

/// Memoised worker for [`is_free_of`]; `cache` maps visited nodes to their verdict.
fn is_free_of_inner(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    cache: &mut HashMap<ExprId, bool>,
) -> bool {
    if expr == var {
        return false;
    }
    if let Some(&known) = cache.get(&expr) {
        return known;
    }
    let kids: Vec<ExprId> = pool.with(expr, |data| match data {
        ExprData::Pow { base, exp } => vec![*base, *exp],
        ExprData::Add(xs) | ExprData::Mul(xs) | ExprData::Func { args: xs, .. } => xs.clone(),
        _ => Vec::new(),
    });
    let free = kids
        .iter()
        .all(|&kid| is_free_of_inner(kid, var, pool, cache));
    cache.insert(expr, free);
    free
}
/// If `expr` has the form `a·var + b` with `a` and `b` free of `var`, returns `(a, b)`.
///
/// Recognised shapes are `var`, a product containing `var` once with `var`-free
/// co-factors, and a sum with exactly one such linear term plus `var`-free terms.
/// Sums with two separate `var` terms (e.g. an unsimplified `x + 2x`) are rejected
/// conservatively, as is any term that depends on `var` non-linearly.
fn is_linear_in(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
    if expr == var {
        return Some((pool.integer(1_i32), pool.integer(0_i32)));
    }
    match pool.get(expr) {
        ExprData::Mul(args) => {
            let a = var_coefficient_of_product(&args, var, pool)?;
            is_free_of(a, var, pool).then(|| (a, pool.integer(0_i32)))
        }
        ExprData::Add(args) => {
            let mut a_opt: Option<ExprId> = None;
            let mut b_parts: Vec<ExprId> = vec![];
            for &arg in &args {
                // Coefficient of `var` in this term, if the term contains `var` as a factor.
                let coeff = if arg == var {
                    Some(pool.integer(1_i32))
                } else if let ExprData::Mul(margs) = pool.get(arg) {
                    var_coefficient_of_product(&margs, var, pool)
                } else {
                    None
                };
                match coeff {
                    Some(c) => {
                        // A second linear term, or a coefficient that itself depends on
                        // `var` (e.g. `x·sin(x)`), means the sum is not `a·x + b`.
                        // Such a term must never be folded into the constant part `b`.
                        if a_opt.is_some() || !is_free_of(c, var, pool) {
                            return None;
                        }
                        a_opt = Some(c);
                    }
                    None if is_free_of(arg, var, pool) => b_parts.push(arg),
                    None => return None,
                }
            }
            let a = a_opt?;
            let b = match b_parts.len() {
                0 => pool.integer(0_i32),
                1 => b_parts[0],
                _ => pool.add(b_parts),
            };
            Some((a, b))
        }
        _ => None,
    }
}

/// For a product whose factors include `var`, returns the product of the other
/// factors (`1` if there are none, the factor itself if there is one); `None` if
/// `var` is not a direct factor. Only the first occurrence of `var` is removed.
fn var_coefficient_of_product(factors: &[ExprId], var: ExprId, pool: &ExprPool) -> Option<ExprId> {
    let var_pos = factors.iter().position(|&f| f == var)?;
    let others: Vec<ExprId> = factors
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != var_pos)
        .map(|(_, &f)| f)
        .collect();
    Some(match others.len() {
        0 => pool.integer(1_i32),
        1 => others[0],
        _ => pool.mul(others),
    })
}
/// Closed form for `c · x · exp(x)`: returns `c · exp(x) · (x − 1)` and logs `int_x_exp`.
///
/// `expr` must be a product containing `exp(var)` and `var` as direct factors, with
/// every remaining factor free of `var`; otherwise `None` is returned and `log` is
/// left untouched.
fn try_x_times_func(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Option<ExprId> {
    let ExprData::Mul(factors) = pool.get(expr) else {
        return None;
    };
    let is_exp_of_var = |id: ExprId| {
        pool.with(id, |d| {
            matches!(d, ExprData::Func { name, args }
                if name == "exp" && args.len() == 1 && args[0] == var)
        })
    };
    let exp_pos = factors.iter().position(|&f| is_exp_of_var(f))?;
    let var_pos = factors.iter().position(|&f| f == var)?;
    let coeffs: Vec<ExprId> = factors
        .iter()
        .enumerate()
        .filter_map(|(i, &f)| (i != exp_pos && i != var_pos).then_some(f))
        .collect();
    if coeffs.iter().any(|&c| !is_free_of(c, var, pool)) {
        return None;
    }

    let exp_x = factors[exp_pos];
    let x_minus_1 = pool.add(vec![var, pool.integer(-1_i32)]);
    let mut product = vec![exp_x, x_minus_1];
    product.extend(coeffs.iter().copied());
    let result = pool.mul(product);
    log.push(RewriteStep::simple("int_x_exp", expr, result));
    Some(result)
}
/// Name of the special function that arises from integrating `func(x)/x`.
///
/// For example `exp → Ei`. Returns `None` for functions without such an entry.
fn special_integral_name(func: &str) -> Option<&'static str> {
    const TABLE: [(&str, &str); 5] = [
        ("exp", "Ei"),
        ("sin", "Si"),
        ("cos", "Ci"),
        ("sinh", "Shi"),
        ("cosh", "Chi"),
    ];
    TABLE
        .iter()
        .find(|(name, _)| *name == func)
        .map(|&(_, special)| special)
}
/// `true` when `exp` is an integer literal (fitting in `i64`) strictly below zero.
fn is_negative_integer(exp: ExprId, pool: &ExprPool) -> bool {
    matches!(as_integer(exp, pool), Some(k) if k < 0)
}
/// Structural test for "polynomial in `var`": `var`, `var`-free terms, sums and
/// products of polynomials, and polynomials raised to non-negative integer powers.
fn is_polynomial_in(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    if expr == var || is_free_of(expr, var, pool) {
        return true;
    }
    match pool.get(expr) {
        ExprData::Pow { base, exp } => {
            is_polynomial_in(base, var, pool) && matches!(as_integer(exp, pool), Some(k) if k >= 0)
        }
        ExprData::Add(terms) | ExprData::Mul(terms) => {
            terms.into_iter().all(|t| is_polynomial_in(t, var, pool))
        }
        _ => false,
    }
}
/// `true` when `base` is a non-constant polynomial in `var`.
fn is_simple_denominator_base(base: ExprId, var: ExprId, pool: &ExprPool) -> bool {
    let depends_on_var = !is_free_of(base, var, pool);
    depends_on_var && is_polynomial_in(base, var, pool)
}
/// Certifies a few integrand shapes that have no elementary antiderivative,
/// returning a human-readable reason; `None` means "not certified" (not "elementary").
///
/// Recognised (constant factors are ignored):
/// * `1/log(a·x+b)^k` alone: the logarithmic integral `li`;
/// * `F(a·x+b) / P(x)…` with `F ∈ {exp, sin, cos, sinh, cosh}` and only non-constant
///   polynomial denominators: `Ei`/`Si`/`Ci`/`Shi`/`Chi`.
///
/// Any other var-dependent factor makes the function bail out. Mixing a log
/// denominator with polynomial denominators or with `F` is deliberately *not*
/// certified. Such products can be elementary, e.g. `1/(3x·log x) = (1/3)·(log log x)'`
/// or `1/(3x·log(x)^2) = (−1/(3·log x))'`. Exact `h'/h` multiples are handled
/// upstream, but constant multiples of them are not.
fn known_nonelementary(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<String> {
    if let Some(msg) = match_log_denominator(expr, var, pool) {
        return Some(msg);
    }
    let args = match pool.get(expr) {
        ExprData::Mul(args) => args,
        _ => return None,
    };
    let mut special: Option<String> = None;
    let mut has_poly_denom = false;
    let mut log_denom: Option<String> = None;
    for &a in &args {
        if is_free_of(a, var, pool) {
            continue;
        }
        if let ExprData::Func { ref name, ref args } = pool.get(a) {
            if args.len() == 1
                && special_integral_name(name).is_some()
                && is_linear_in(args[0], var, pool).is_some()
            {
                // Two special-function factors: outside the certified shapes.
                if special.is_some() {
                    return None;
                }
                special = Some(pool.display(a).to_string());
                continue;
            }
        }
        if let ExprData::Pow { base, exp } = pool.get(a) {
            if is_negative_integer(exp, pool) {
                if let Some(msg) = match_log_denominator(a, var, pool) {
                    if log_denom.is_some() {
                        return None;
                    }
                    log_denom = Some(msg);
                    continue;
                }
                if is_simple_denominator_base(base, var, pool) {
                    has_poly_denom = true;
                    continue;
                }
            }
        }
        // Any other var-dependent factor: make no claim.
        return None;
    }
    match (special, has_poly_denom, log_denom) {
        (Some(f), true, None) => Some(format!(
            "{f} divided by a polynomial gives a special-function integral \
             (Ei/Si/Ci/Shi/Chi), which is not elementary (Liouville's theorem)"
        )),
        // The log power must be the only var-dependent factor for the li certificate.
        (None, false, Some(msg)) => Some(msg),
        _ => None,
    }
}
/// Recognises `log(a·x + b)^(-k)` (integer `k > 0`) and returns a message explaining
/// that its integral is the non-elementary logarithmic integral; `None` otherwise.
fn match_log_denominator(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<String> {
    let ExprData::Pow { base, exp } = pool.get(expr) else {
        return None;
    };
    if !is_negative_integer(exp, pool) {
        return None;
    }
    match pool.get(base) {
        ExprData::Func { name, args }
            if name == "log" && args.len() == 1 && is_linear_in(args[0], var, pool).is_some() =>
        {
            Some(format!(
                "1/{} is the logarithmic integral li, which is not elementary \
                 (Liouville's theorem)",
                pool.display(base)
            ))
        }
        _ => None,
    }
}
/// Rule-based integrator: returns an unsimplified antiderivative of `expr` with
/// respect to `var`, pushing one [`RewriteStep`] per rule applied onto `log`.
///
/// Tried in order: `c·x·exp(x)` (via [`try_x_times_func`]), the bare variable,
/// constants, sums (term by term), products (pulling out `var`-free factors),
/// powers (constant powers, `x^n`, `1/(a·x+b)`), and single-argument functions
/// (`var`-free arguments, `exp(a·x+b)`, and `sin`/`cos`/`exp`/`log` of `var`).
///
/// # Errors
/// Returns [`IntegrationError::NotImplemented`] for integrands outside these rules
/// (callers use this as the cue to try fallbacks) and propagates errors from
/// recursive calls on sub-terms.
pub(crate) fn integrate_raw(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
    if let Some(result) = try_x_times_func(expr, var, pool, log) {
        return Ok(result);
    }
    // Owned snapshot of the node shape, presumably so no pool borrow is held while
    // the arms below recurse and build new expressions.
    enum Node {
        IsVar,
        Constant,
        Add(Vec<ExprId>),
        Mul(Vec<ExprId>),
        Pow { base: ExprId, exp: ExprId },
        Func { name: String, arg: ExprId },
        Unknown,
    }
    let node = pool.with(expr, |data| match data {
        ExprData::Symbol { .. } if expr == var => Node::IsVar,
        ExprData::Symbol { .. }
        | ExprData::Integer(_)
        | ExprData::Rational(_)
        | ExprData::Float(_) => Node::Constant,
        ExprData::Add(args) => Node::Add(args.clone()),
        ExprData::Mul(args) => Node::Mul(args.clone()),
        ExprData::Pow { base, exp } => Node::Pow {
            base: *base,
            exp: *exp,
        },
        ExprData::Func { name, args } if args.len() == 1 => Node::Func {
            name: name.clone(),
            arg: args[0],
        },
        _ => Node::Unknown,
    });
    match node {
        Node::IsVar => {
            let two = pool.integer(2_i32);
            let inv_two = pool.pow(two, pool.integer(-1_i32));
            let result = pool.mul(vec![pool.pow(var, two), inv_two]);
            log.push(RewriteStep::simple("power_rule", expr, result));
            Ok(result)
        }
        Node::Constant => {
            let result = pool.mul(vec![expr, var]);
            log.push(RewriteStep::simple("constant_rule", expr, result));
            Ok(result)
        }
        Node::Add(args) => {
            let mut int_args = Vec::with_capacity(args.len());
            for a in &args {
                let ia = integrate_raw(*a, var, pool, log)?;
                int_args.push(ia);
            }
            let result = pool.add(int_args);
            log.push(RewriteStep::simple("sum_rule", expr, result));
            Ok(result)
        }
        Node::Mul(args) => {
            let (consts, non_consts): (Vec<ExprId>, Vec<ExprId>) =
                args.iter().partition(|&&a| is_free_of(a, var, pool));
            if non_consts.is_empty() {
                let result = pool.mul(vec![expr, var]);
                log.push(RewriteStep::simple("constant_rule", expr, result));
                return Ok(result);
            }
            let inner = match non_consts.len() {
                1 => non_consts[0],
                _ => pool.mul(non_consts.clone()),
            };
            let const_factor = match consts.len() {
                0 => None,
                1 => Some(consts[0]),
                _ => Some(pool.mul(consts.clone())),
            };
            // No constant factor could be pulled out: recursing would loop forever.
            if inner == expr {
                return Err(IntegrationError::NotImplemented(format!(
                    "∫ {} — irreducible product of var-dependent factors",
                    pool.display(expr)
                )));
            }
            let int_inner = integrate_raw(inner, var, pool, log)?;
            let result = match const_factor {
                None => int_inner,
                Some(c) => {
                    let r = pool.mul(vec![c, int_inner]);
                    log.push(RewriteStep::simple("constant_multiple_rule", expr, r));
                    r
                }
            };
            Ok(result)
        }
        Node::Pow { base, exp } => {
            // A power that does not mention `var` (e.g. `2^(1/2)` or `a^b`) is a
            // constant whatever its exponent; previously only integer exponents
            // reached the constant rule and everything else was NotImplemented.
            if is_free_of(expr, var, pool) {
                let result = pool.mul(vec![expr, var]);
                log.push(RewriteStep::simple("constant_rule", expr, result));
                return Ok(result);
            }
            if let Some(n) = as_integer(exp, pool) {
                if base == var {
                    if n == -1 {
                        let result = pool.func("log", vec![var]);
                        log.push(RewriteStep::simple("log_rule", expr, result));
                        return Ok(result);
                    }
                    let np1 = pool.integer(n + 1);
                    let inv_np1 = pool.pow(np1, pool.integer(-1_i32));
                    let result = pool.mul(vec![pool.pow(var, np1), inv_np1]);
                    log.push(RewriteStep::simple("power_rule", expr, result));
                    return Ok(result);
                }
                // ∫ 1/(a·x + b) dx = log(a·x + b) / a
                if n == -1 {
                    if let Some((a, _b)) = is_linear_in(base, var, pool) {
                        let log_base = pool.func("log", vec![base]);
                        let a_inv = pool.pow(a, pool.integer(-1_i32));
                        let result = pool.mul(vec![a_inv, log_base]);
                        log.push(RewriteStep::simple("int_linear_inv", expr, result));
                        return Ok(result);
                    }
                }
            }
            Err(IntegrationError::NotImplemented(
                "∫ (expr)^(exp) where base or exp is non-trivial".to_string(),
            ))
        }
        Node::Func { name, arg } => {
            if arg != var {
                if is_free_of(arg, var, pool) {
                    let result = pool.mul(vec![expr, var]);
                    log.push(RewriteStep::simple("constant_rule", expr, result));
                    return Ok(result);
                }
                // ∫ exp(a·x + b) dx = exp(a·x + b) / a
                if name == "exp" {
                    if let Some((a, _b)) = is_linear_in(arg, var, pool) {
                        let exp_expr = pool.func("exp", vec![arg]);
                        let a_inv = pool.pow(a, pool.integer(-1_i32));
                        let result = pool.mul(vec![a_inv, exp_expr]);
                        log.push(RewriteStep::simple("int_exp_linear", expr, result));
                        return Ok(result);
                    }
                }
                return Err(IntegrationError::NotImplemented(format!(
                    "∫ {name}(non-trivial arg) — chain rule not implemented"
                )));
            }
            match name.as_str() {
                "sin" => {
                    let neg_one = pool.integer(-1_i32);
                    let result = pool.mul(vec![neg_one, pool.func("cos", vec![var])]);
                    log.push(RewriteStep::simple("int_sin", expr, result));
                    Ok(result)
                }
                "cos" => {
                    let result = pool.func("sin", vec![var]);
                    log.push(RewriteStep::simple("int_cos", expr, result));
                    Ok(result)
                }
                "exp" => {
                    let result = pool.func("exp", vec![var]);
                    log.push(RewriteStep::simple("int_exp", expr, result));
                    Ok(result)
                }
                // ∫ log(x) dx = x·log(x) − x
                "log" => {
                    let log_x = pool.func("log", vec![var]);
                    let x_log_x = pool.mul(vec![var, log_x]);
                    let neg_x = pool.mul(vec![pool.integer(-1_i32), var]);
                    let result = pool.add(vec![x_log_x, neg_x]);
                    log.push(RewriteStep::simple("int_log", expr, result));
                    Ok(result)
                }
                "sqrt" => Err(IntegrationError::NotImplemented(
                    "∫ sqrt(x) — not in the supported Risch subset".to_string(),
                )),
                other => Err(IntegrationError::NotImplemented(format!("∫ {other}(x)"))),
            }
        }
        Node::Unknown => Err(IntegrationError::NotImplemented(
            "unsupported expression node".to_string(),
        )),
    }
}
/// Computes an antiderivative of `expr` with respect to `var`, with a derivation log.
///
/// Dispatch order: integrands with a Risch-form subterm go to the Risch integrator;
/// otherwise integrands with algebraic subterms go to the algebraic integrator. Remaining
/// integrands try the log-derivative rule, then a check for known non-elementary
/// shapes, and finally the rule-based integrator with its fallbacks.
///
/// # Errors
/// Returns [`IntegrationError::NonElementary`] for recognised non-elementary
/// integrands, and otherwise whatever the chosen integrator reports.
pub fn integrate(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> Result<DerivedExpr<ExprId>, IntegrationError> {
    let has_algebraic = super::algebraic::contains_algebraic_subterm(expr, pool)
        || super::algebraic::contains_algebraic_func_of_var(expr, var, pool);
    let has_transcendental = super::risch::contains_risch_form(expr, var, pool);
    match (has_algebraic, has_transcendental) {
        (_, true) => return super::risch::integrate_risch(expr, var, pool),
        (true, false) => return super::algebraic::integrate_algebraic(expr, var, pool),
        (false, false) => {}
    }

    if let Some(antiderivative) = try_log_derivative(expr, var, pool) {
        let simplified = simplify(antiderivative, pool);
        let mut steps = DerivationLog::new();
        steps.push(RewriteStep::simple(
            "log_derivative_rule",
            expr,
            simplified.value,
        ));
        let combined = steps.merge(simplified.log);
        return Ok(DerivedExpr::with_log(simplified.value, combined));
    }

    match known_nonelementary(expr, var, pool) {
        Some(reason) => Err(IntegrationError::NonElementary(reason)),
        None => integrate_inner(expr, var, pool, 0),
    }
}
/// Runs the rule-based integrator and simplifies its result.
///
/// When the rules report `NotImplemented`, two fallbacks are tried in turn:
/// rational-function integration, then u-substitution (bounded by `depth`). A
/// successful fallback is logged as a single step. If neither applies, the original
/// `NotImplemented` message is returned. Other errors are passed through unchanged.
fn integrate_inner(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    depth: u32,
) -> Result<DerivedExpr<ExprId>, IntegrationError> {
    let mut log = DerivationLog::new();
    let unsupported = match integrate_raw(expr, var, pool, &mut log) {
        Ok(raw) => {
            let simplified = simplify(raw, pool);
            let combined = log.merge(simplified.log);
            return Ok(DerivedExpr::with_log(simplified.value, combined));
        }
        Err(IntegrationError::NotImplemented(msg)) => msg,
        Err(other) => return Err(other),
    };

    // Simplify a fallback's result and record it as one named rewrite of `expr`.
    let single_step = |rule: &'static str, found: ExprId| {
        let simplified = simplify(found, pool);
        let mut steps = DerivationLog::new();
        steps.push(RewriteStep::simple(rule, expr, simplified.value));
        let combined = steps.merge(simplified.log);
        DerivedExpr::with_log(simplified.value, combined)
    };

    if let Some(found) = super::risch::rational_integrate::try_integrate_rational(expr, var, pool)
    {
        return Ok(single_step("rothstein_trager", found));
    }
    match try_u_substitution(expr, var, pool, depth) {
        Some(found) => Ok(single_step("u_substitution", found)),
        None => Err(IntegrationError::NotImplemented(unsupported)),
    }
}
/// Evaluates `∫_lower^upper expr d(var)` as `F(upper) − F(lower)`, with `F` from [`integrate`].
///
/// The returned log is the antiderivative's log, then one
/// `fundamental_theorem_of_calculus` step, then the final simplification's log.
/// NOTE(review): no check is made that `F` is defined and continuous on the interval,
/// so integrands with singularities between the bounds can yield meaningless results.
///
/// # Errors
/// Propagates any error from [`integrate`].
pub fn integrate_definite(
    expr: ExprId,
    var: ExprId,
    lower: ExprId,
    upper: ExprId,
    pool: &ExprPool,
) -> Result<DerivedExpr<ExprId>, IntegrationError> {
    let antiderivative = integrate(expr, var, pool)?;
    let f = antiderivative.value;
    let at_upper = subs_var(f, var, upper, pool);
    let at_lower = subs_var(f, var, lower, pool);
    let difference = pool.add(vec![
        at_upper,
        pool.mul(vec![pool.integer(-1_i32), at_lower]),
    ]);
    let simplified = simplify(difference, pool);

    let mut ftc_step = DerivationLog::new();
    ftc_step.push(RewriteStep::simple(
        "fundamental_theorem_of_calculus",
        expr,
        simplified.value,
    ));
    let combined = antiderivative.log.merge(ftc_step).merge(simplified.log);
    Ok(DerivedExpr::with_log(simplified.value, combined))
}
/// Replaces every occurrence of `var` in `expr` with `value`.
fn subs_var(expr: ExprId, var: ExprId, value: ExprId, pool: &ExprPool) -> ExprId {
    let replacement = HashMap::from([(var, value)]);
    crate::kernel::subs(expr, &replacement, pool)
}
/// Maximum nesting of u-substitutions (each nested attempt increments `depth`).
const U_SUBST_MAX_DEPTH: u32 = 3;
/// Maximum number of candidate inner functions `g` tried per integrand form.
const U_SUBST_MAX_CANDIDATES: usize = 12;

/// Attempts `∫ f(x) dx = ∫ h(u) du` with `u = g(x)`, for candidate subterms `g`.
///
/// For each form of `expr` (as given, and trig-expanded if that differs) and each
/// candidate `g`, it computes `f / g'` and replaces `g` by a fresh symbol `u`. If
/// the result no longer mentions `var`, it is integrated in `u` at `depth + 1`,
/// `g` is substituted back, and the candidate is accepted only if
/// [`verify_antiderivative`] confirms it against the original `expr`.
fn try_u_substitution(expr: ExprId, var: ExprId, pool: &ExprPool, depth: u32) -> Option<ExprId> {
    if depth >= U_SUBST_MAX_DEPTH {
        return None;
    }
    // The substitution variable must differ from every enclosing one. With a single
    // shared name, at depth ≥ 1 the new `u` would *be* the current `var`, so the
    // `is_free_of(replaced, var)` check below could never succeed and nested
    // substitution would be dead.
    let u_name = format!("__usub_u{depth}");
    let mut variants = vec![expr];
    let expanded = trig_expand(expr, pool);
    if expanded != expr {
        variants.push(expanded);
    }
    for &form in &variants {
        let candidates = collect_usub_candidates(form, var, pool);
        for g in candidates.into_iter().take(U_SUBST_MAX_CANDIDATES) {
            if g == var || is_free_of(g, var, pool) {
                continue;
            }
            let Ok(dg_raw) = crate::diff::diff(g, var, pool) else {
                continue;
            };
            let dg = simplify(dg_raw.value, pool).value;
            if is_zero(dg, pool) {
                continue;
            }
            let inv = reciprocal(dg, pool);
            let quotient = simplify(pool.mul(vec![form, inv]), pool).value;
            let u = pool.symbol(u_name.as_str(), crate::kernel::Domain::Real);
            let mut fwd = HashMap::new();
            fwd.insert(g, u);
            let replaced = crate::kernel::subs(quotient, &fwd, pool);
            // Substitution only works if every `var` was absorbed into `u`.
            if !is_free_of(replaced, var, pool) {
                continue;
            }
            let Ok(inner) = integrate_inner(replaced, u, pool, depth + 1) else {
                continue;
            };
            let mut back = HashMap::new();
            back.insert(u, g);
            let result = simplify(crate::kernel::subs(inner.value, &back, pool), pool).value;
            if verify_antiderivative(result, expr, var, pool) {
                return Some(result);
            }
        }
    }
    None
}
/// Rewrites `expr` using only the trigonometric rule set, with the default simplifier config.
fn trig_expand(expr: ExprId, pool: &ExprPool) -> ExprId {
    use crate::simplify::engine::{simplify_with, SimplifyConfig};
    use crate::simplify::rulesets::trig_rules;
    let rewritten = simplify_with(expr, pool, &trig_rules(), SimplifyConfig::default());
    rewritten.value
}
/// Builds `1/expr` structurally (unsimplified).
///
/// The reciprocal of a product is the product of the reciprocals, `b^e` becomes
/// `b^(-1·e)`, and anything else becomes `expr^(-1)`.
fn reciprocal(expr: ExprId, pool: &ExprPool) -> ExprId {
    let minus_one = pool.integer(-1_i32);
    match pool.get(expr) {
        ExprData::Pow { base, exp } => pool.pow(base, pool.mul(vec![minus_one, exp])),
        ExprData::Mul(factors) => pool.mul(
            factors
                .into_iter()
                .map(|f| reciprocal(f, pool))
                .collect::<Vec<ExprId>>(),
        ),
        _ => pool.pow(expr, minus_one),
    }
}
/// Candidate inner functions `g` for u-substitution, in the order they should be tried.
///
/// Function arguments and power bases found anywhere in `expr` come first, largest
/// (by node count) first, ties keeping discovery order. After them come the direct
/// var-dependent factors of a top-level product. Candidates are deduplicated, and
/// `var` itself and `var`-free subterms are excluded.
fn collect_usub_candidates(expr: ExprId, var: ExprId, pool: &ExprPool) -> Vec<ExprId> {
    let mut out: Vec<ExprId> = Vec::new();
    let mut seen: std::collections::HashSet<ExprId> = std::collections::HashSet::new();
    // Registered in `seen` first so they are excluded from the nested list below
    // and only appear once, at the end.
    let mut factor_candidates: Vec<ExprId> = Vec::new();
    if let ExprData::Mul(args) = pool.get(expr) {
        for &a in &args {
            if a != var && !is_free_of(a, var, pool) && seen.insert(a) {
                factor_candidates.push(a);
            }
        }
    }
    collect_usub_inner(expr, var, pool, &mut out, &mut seen);
    // `node_count` walks the whole subtree; the cached variant evaluates it once per
    // candidate instead of on every comparison, and is stable like `sort_by_key`.
    out.sort_by_cached_key(|&c| std::cmp::Reverse(node_count(c, pool)));
    out.extend(factor_candidates);
    out
}
/// Recursive worker for [`collect_usub_candidates`].
///
/// Appends to `out` every function argument and power base that depends on `var`,
/// is not `var` itself, and is not yet in `seen`, in depth-first order.
fn collect_usub_inner(
    expr: ExprId,
    var: ExprId,
    pool: &ExprPool,
    out: &mut Vec<ExprId>,
    seen: &mut std::collections::HashSet<ExprId>,
) {
    /// Records `cand` as a candidate if it qualifies and is new.
    fn offer(
        cand: ExprId,
        var: ExprId,
        pool: &ExprPool,
        out: &mut Vec<ExprId>,
        seen: &mut std::collections::HashSet<ExprId>,
    ) {
        if cand != var && !is_free_of(cand, var, pool) && seen.insert(cand) {
            out.push(cand);
        }
    }

    match pool.get(expr) {
        ExprData::Func { args, .. } => {
            for arg in args {
                offer(arg, var, pool, out, seen);
                collect_usub_inner(arg, var, pool, out, seen);
            }
        }
        ExprData::Pow { base, exp } => {
            offer(base, var, pool, out, seen);
            collect_usub_inner(base, var, pool, out, seen);
            collect_usub_inner(exp, var, pool, out, seen);
        }
        ExprData::Add(terms) | ExprData::Mul(terms) => {
            for term in terms {
                collect_usub_inner(term, var, pool, out, seen);
            }
        }
        _ => {}
    }
}
/// Number of nodes in the expression tree rooted at `expr`, counting shared
/// subterms once per occurrence. Uses an explicit stack instead of recursion.
fn node_count(expr: ExprId, pool: &ExprPool) -> usize {
    let mut pending = vec![expr];
    let mut total = 0_usize;
    while let Some(node) = pending.pop() {
        total += 1;
        let kids: Vec<ExprId> = pool.with(node, |data| match data {
            ExprData::Add(xs) | ExprData::Mul(xs) | ExprData::Func { args: xs, .. } => {
                xs.clone()
            }
            ExprData::Pow { base, exp } => vec![*base, *exp],
            _ => Vec::new(),
        });
        pending.extend(kids);
    }
    total
}
/// `true` only for the integer literal `0` (no symbolic zero test).
fn is_zero(expr: ExprId, pool: &ExprPool) -> bool {
    matches!(as_integer(expr, pool), Some(0))
}
/// Checks that `d(candidate)/d(var)` equals `integrand`.
///
/// First tries a symbolic check: the simplified difference must be the literal `0`.
/// Failing that, both sides are evaluated numerically at six fixed positive sample
/// points. Samples where either value is non-finite are skipped. Any sample that
/// cannot be evaluated, or that disagrees beyond a relative tolerance of `1e-7`,
/// rejects the candidate. Acceptance requires at least two agreeing samples.
fn verify_antiderivative(
    candidate: ExprId,
    integrand: ExprId,
    var: ExprId,
    pool: &ExprPool,
) -> bool {
    const SAMPLES: [f64; 6] = [0.3719_f64, 0.9137, 1.4231, 2.1719, 2.8123, 3.6411];

    let Ok(derivative) = crate::diff::diff(candidate, var, pool) else {
        return false;
    };
    let d = simplify(derivative.value, pool).value;
    let negated = pool.mul(vec![pool.integer(-1_i32), integrand]);
    let residual = simplify(pool.add(vec![d, negated]), pool).value;
    if is_zero(residual, pool) {
        return true;
    }

    let mut env = HashMap::new();
    let mut agreeing = 0_usize;
    for &point in SAMPLES.iter() {
        env.insert(var, point);
        let lhs = crate::jit::eval_interp(d, &env, pool);
        let rhs = crate::jit::eval_interp(integrand, &env, pool);
        let (Some(dv), Some(fv)) = (lhs, rhs) else {
            return false;
        };
        // Poles and overflow at a sample point say nothing either way.
        if !(dv.is_finite() && fv.is_finite()) {
            continue;
        }
        let tolerance = 1e-7 * (1.0 + dv.abs().max(fv.abs()));
        if (dv - fv).abs() > tolerance {
            return false;
        }
        agreeing += 1;
    }
    agreeing >= 2
}
#[cfg(test)]
mod tests {
use super::*;
use crate::diff::diff;
use crate::kernel::{Domain, ExprPool};
use crate::poly::UniPoly;
fn p() -> ExprPool {
ExprPool::new()
}
fn coeffs_equal(a: ExprId, b: ExprId, x: ExprId, pool: &ExprPool) -> bool {
let ap = UniPoly::from_symbolic(a, x, pool);
let bp = UniPoly::from_symbolic(b, x, pool);
match (ap, bp) {
(Ok(a), Ok(b)) => a.coefficients_i64() == b.coefficients_i64(),
_ => a == b,
}
}
fn verify(expr: ExprId, x: ExprId, pool: &ExprPool) {
let integral = integrate(expr, x, pool).unwrap();
let deriv = diff(integral.value, x, pool).unwrap();
assert!(
coeffs_equal(deriv.value, expr, x, pool),
"diff(integrate(f)) ≠ f for f = {}",
pool.display(expr)
);
}
#[test]
fn integrate_constant() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.integer(5_i32), x, &pool).unwrap();
let expected = pool.mul(vec![pool.integer(5_i32), x]);
assert!(coeffs_equal(r.value, expected, x, &pool));
}
#[test]
fn integrate_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
verify(x, x, &pool);
}
#[test]
fn integrate_x_squared() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let x2 = pool.pow(x, pool.integer(2_i32));
verify(x2, x, &pool);
}
#[test]
fn integrate_polynomial() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.add(vec![
pool.pow(x, pool.integer(2_i32)),
pool.mul(vec![pool.integer(2_i32), x]),
]);
let r = integrate(expr, x, &pool).unwrap();
let d = diff(r.value, x, &pool).unwrap();
assert!(
coeffs_equal(d.value, expr, x, &pool),
"diff(∫(x²+2x)) ≠ x²+2x; got {}",
pool.display(d.value)
);
}
#[test]
fn integrate_one_over_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let x_inv = pool.pow(x, pool.integer(-1_i32));
let r = integrate(x_inv, x, &pool).unwrap();
assert_eq!(r.value, pool.func("log", vec![x]));
assert!(r.log.steps().iter().any(|s| s.rule_name == "log_rule"));
}
#[test]
fn integrate_sin() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let sin_x = pool.func("sin", vec![x]);
let r = integrate(sin_x, x, &pool).unwrap();
let neg_one = pool.integer(-1_i32);
let expected = pool.mul(vec![neg_one, pool.func("cos", vec![x])]);
assert_eq!(r.value, expected);
assert!(r.log.steps().iter().any(|s| s.rule_name == "int_sin"));
}
#[test]
fn integrate_cos() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.func("cos", vec![x]), x, &pool).unwrap();
assert_eq!(r.value, pool.func("sin", vec![x]));
}
#[test]
fn integrate_exp() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.func("exp", vec![x]), x, &pool).unwrap();
assert_eq!(r.value, pool.func("exp", vec![x]));
}
#[test]
fn integrate_constant_multiple() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.mul(vec![pool.integer(3_i32), pool.pow(x, pool.integer(2_i32))]);
verify(expr, x, &pool);
}
#[test]
fn integrate_not_implemented() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let x2 = pool.pow(x, pool.integer(2_i32));
let err = integrate(pool.func("sin", vec![x2]), x, &pool);
assert!(matches!(err, Err(IntegrationError::NotImplemented(_))));
}
#[test]
fn integrate_log_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let log_x = pool.func("log", vec![x]);
let r = integrate(log_x, x, &pool).unwrap();
assert!(
r.log.steps().iter().any(|s| s.rule_name == "int_log"),
"should have logged int_log step"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("log"),
"result should contain log: {result_str}"
);
}
#[test]
fn integrate_exp_linear_arg() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let two = pool.integer(2_i32);
let two_x = pool.mul(vec![two, x]);
let expr = pool.func("exp", vec![two_x]);
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log
.steps()
.iter()
.any(|s| s.rule_name == "int_exp_linear"),
"should fire int_exp_linear"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("exp"),
"result should contain exp: {result_str}"
);
}
#[test]
fn integrate_x_times_exp_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.mul(vec![x, pool.func("exp", vec![x])]);
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log.steps().iter().any(|s| s.rule_name == "int_x_exp"),
"should fire int_x_exp"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("exp"),
"result should contain exp: {result_str}"
);
}
#[test]
fn integrate_const_times_x_times_exp_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let three = pool.integer(3_i32);
let expr = pool.mul(vec![three, x, pool.func("exp", vec![x])]);
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log.steps().iter().any(|s| s.rule_name == "int_x_exp"),
"should fire int_x_exp for 3*x*exp(x)"
);
}
#[test]
fn integrate_one_over_linear() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let two = pool.integer(2_i32);
let three = pool.integer(3_i32);
let linear = pool.add(vec![pool.mul(vec![two, x]), three]);
let expr = pool.pow(linear, pool.integer(-1_i32));
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log
.steps()
.iter()
.any(|s| s.rule_name == "int_linear_inv"),
"should fire int_linear_inv"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("log"),
"result should contain log: {result_str}"
);
}
#[test]
fn integrate_x_cubed_plus_2x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.add(vec![
pool.pow(x, pool.integer(3_i32)),
pool.mul(vec![pool.integer(2_i32), x]),
]);
verify(expr, x, &pool);
}
#[test]
fn integrate_derivation_log_nonempty() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.pow(x, pool.integer(2_i32)), x, &pool).unwrap();
assert!(
!r.log.is_empty(),
"integration should produce a derivation log"
);
assert!(r.log.steps().iter().any(|s| s.rule_name == "power_rule"));
}
#[test]
fn integrate_sqrt_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let sqrt_x = pool.func("sqrt", vec![x]);
let result = integrate(sqrt_x, x, &pool);
match &result {
Ok(r) => println!("sqrt(x) integral = {}", pool.display(r.value)),
Err(e) => println!("ERROR: {e}"),
}
assert!(result.is_ok(), "∫ sqrt(x) dx failed: {:?}", result);
}
#[test]
fn integrate_inv_sqrt_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let sqrt_x = pool.func("sqrt", vec![x]);
let inv_sqrt_x = pool.pow(sqrt_x, pool.integer(-1_i32));
let result = integrate(inv_sqrt_x, x, &pool);
match &result {
Ok(r) => println!("1/sqrt(x) integral = {}", pool.display(r.value)),
Err(e) => println!("ERROR: {e}"),
}
assert!(result.is_ok(), "∫ 1/sqrt(x) dx failed: {:?}", result);
}
#[test]
fn integrate_sqrt_x2_plus_1() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let p_expr = pool.add(vec![pool.pow(x, pool.integer(2_i32)), pool.integer(1_i32)]);
let sqrt_p = pool.func("sqrt", vec![p_expr]);
let result = integrate(sqrt_p, x, &pool);
match &result {
Ok(r) => println!("sqrt(x^2+1) integral = {}", pool.display(r.value)),
Err(e) => println!("ERROR: {e}"),
}
assert!(result.is_ok(), "∫ sqrt(x²+1) dx failed: {:?}", result);
}
fn over(pool: &ExprPool, num: ExprId, denom: ExprId) -> ExprId {
let inv = pool.pow(denom, pool.integer(-1_i32));
pool.mul(vec![num, inv])
}
#[test]
fn sin_over_x_is_nonelementary_not_crash() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let f = over(&pool, pool.func("sin", vec![x]), x);
let r = integrate(f, x, &pool);
assert!(
matches!(r, Err(IntegrationError::NonElementary(_))),
"∫ sin(x)/x dx should be NonElementary; got {r:?}"
);
}
#[test]
fn exp_over_x_is_nonelementary() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let f = over(&pool, pool.func("exp", vec![x]), x);
let r = integrate(f, x, &pool);
assert!(
matches!(r, Err(IntegrationError::NonElementary(_))),
"∫ exp(x)/x dx should be NonElementary; got {r:?}"
);
}
#[test]
fn cos_over_linear_is_nonelementary() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let denom = pool.add(vec![
pool.mul(vec![pool.integer(2_i32), x]),
pool.integer(1_i32),
]);
let f = over(&pool, pool.func("cos", vec![x]), denom);
let r = integrate(f, x, &pool);
assert!(
matches!(r, Err(IntegrationError::NonElementary(_))),
"∫ cos(x)/(2x+1) dx should be NonElementary; got {r:?}"
);
}
#[test]
fn one_over_log_is_nonelementary() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let f = pool.pow(pool.func("log", vec![x]), pool.integer(-1_i32));
let r = integrate(f, x, &pool);
assert!(
matches!(r, Err(IntegrationError::NonElementary(_))),
"∫ 1/log(x) dx should be NonElementary; got {r:?}"
);
}
#[test]
fn exp_over_x_squared_is_nonelementary() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let x2 = pool.pow(x, pool.integer(2_i32));
let f = over(&pool, pool.func("exp", vec![x]), x2);
let r = integrate(f, x, &pool);
assert!(
matches!(r, Err(IntegrationError::NonElementary(_))),
"∫ exp(x)/x² dx should be NonElementary; got {r:?}"
);
}
#[test]
fn log_over_x_is_elementary_not_misclassified() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let f = over(&pool, pool.func("log", vec![x]), x);
let r = integrate(f, x, &pool);
assert!(
!matches!(r, Err(IntegrationError::NonElementary(_))),
"∫ log(x)/x dx must not be flagged NonElementary; got {r:?}"
);
}
/// x·sin(x)/x must not be certified non-elementary by `known_nonelementary`
/// merely because the expression has the shape "sin(x)·… / x".
#[test]
fn x_times_sin_over_x_not_flagged() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let sin_x = pool.func("sin", vec![var]);
    let numerator = pool.mul(vec![var, sin_x]);
    let integrand = over(&pool, numerator, var);
    let verdict = known_nonelementary(integrand, var, &pool);
    assert!(
        verdict.is_none(),
        "x·sin(x)/x must not be certified NonElementary"
    );
}
/// ∫ 1/(x²−1) dx must be handled by the rational-function fallback, and the
/// result must contain `log` terms.
#[test]
fn rational_integration_via_fallback() {
    let pool = p();
    let x = pool.symbol("x", Domain::Real);
    let den = pool.add(vec![pool.pow(x, pool.integer(2_i32)), pool.integer(-1_i32)]);
    let f = pool.pow(den, pool.integer(-1_i32));
    // Unwrap with a descriptive panic instead of `is_ok()` followed by `unwrap()`.
    let r = integrate(f, x, &pool).unwrap_or_else(|e| {
        panic!("∫ 1/(x²−1) dx should integrate via fallback; got {e:?}")
    });
    let shown = pool.display(r.value).to_string();
    // Print the antiderivative on failure so the wrong shape is visible.
    assert!(
        shown.contains("log"),
        "expected log terms in the antiderivative; got {shown}"
    );
}
/// ∫ x⁻² dx must still go through the power rule: d/dx of the antiderivative
/// must equal x⁻² at a couple of sample points.
#[test]
fn power_rule_not_regressed_by_fallback() {
    let pool = p();
    let x = pool.symbol("x", Domain::Real);
    let f = pool.pow(x, pool.integer(-2_i32));
    let r = integrate(f, x, &pool).expect("∫ x^-2 dx must integrate");
    let d = diff(r.value, x, &pool).expect("antiderivative must be differentiable");
    for &xv in &[1.5_f64, 2.5] {
        let lhs = eval_simple(d.value, x, xv, &pool);
        let expected = xv.powi(-2);
        // Report the computed derivative, the target and F itself so a
        // regression can be diagnosed from the failure message alone.
        assert!(
            (lhs - expected).abs() < 1e-9,
            "power rule regressed at {xv}: d/dx F = {lhs}, expected {expected}\n F = {}",
            pool.display(r.value)
        );
    }
}
/// ∫ 1/(x²+1) dx must integrate, and the antiderivative must contain an `atan` term.
#[test]
fn arctan_case_via_fallback() {
    let pool = p();
    let x = pool.symbol("x", Domain::Real);
    let den = pool.add(vec![pool.pow(x, pool.integer(2_i32)), pool.integer(1_i32)]);
    let f = pool.pow(den, pool.integer(-1_i32));
    let r = integrate(f, x, &pool)
        .unwrap_or_else(|e| panic!("∫ 1/(x²+1) dx should integrate; got {e:?}"));
    let shown = pool.display(r.value).to_string();
    // The original assertion had no message; include the antiderivative.
    assert!(
        shown.contains("atan"),
        "expected an atan term in the antiderivative; got {shown}"
    );
}
/// Numerically evaluates `expr` with the symbol `x` bound to `xv`.
///
/// Only integer/rational literals and `Add`, `Mul` and `Pow` nodes are
/// understood. Any other node, including function calls and other symbols,
/// causes a panic.
fn eval_simple(expr: ExprId, x: ExprId, xv: f64, pool: &ExprPool) -> f64 {
    let eval = |node: ExprId| eval_simple(node, x, xv, pool);
    if expr == x {
        return xv;
    }
    match pool.get(expr) {
        ExprData::Integer(n) => n.0.to_f64(),
        ExprData::Rational(q) => q.0.to_f64(),
        ExprData::Add(terms) => terms.iter().map(|&t| eval(t)).sum(),
        ExprData::Mul(factors) => factors.iter().map(|&fac| eval(fac)).product(),
        ExprData::Pow { base, exp } => {
            let b = eval(base);
            b.powf(eval(exp))
        }
        other => panic!("eval_simple: unsupported {other:?}"),
    }
}
/// A plain sin(x) must integrate successfully and must never be certified
/// non-elementary.
#[test]
fn plain_sin_not_flagged() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let sin_x = pool.func("sin", vec![var]);
    let integrated = integrate(sin_x, var, &pool);
    assert!(integrated.is_ok());
    let verdict = known_nonelementary(sin_x, var, &pool);
    assert!(verdict.is_none());
}
/// Numerically evaluates `expr` with `x` bound to `xv`.
///
/// This is `eval_simple` plus support for the unary `log` function (natural
/// log). Any other function, or any other unsupported node, causes a panic.
fn eval_log(expr: ExprId, x: ExprId, xv: f64, pool: &ExprPool) -> f64 {
    if expr == x {
        return xv;
    }
    match pool.get(expr) {
        ExprData::Integer(n) => n.0.to_f64(),
        ExprData::Rational(q) => q.0.to_f64(),
        ExprData::Add(terms) => terms.iter().map(|&t| eval_log(t, x, xv, pool)).sum(),
        ExprData::Mul(factors) => factors
            .iter()
            .map(|&fac| eval_log(fac, x, xv, pool))
            .product(),
        ExprData::Pow { base, exp } => {
            let b = eval_log(base, x, xv, pool);
            let e = eval_log(exp, x, xv, pool);
            b.powf(e)
        }
        ExprData::Func { ref name, ref args } if args.len() == 1 => {
            // The argument is evaluated before the function name is checked.
            let arg = eval_log(args[0], x, xv, pool);
            match name.as_str() {
                "log" => arg.ln(),
                other => panic!("eval_log: unsupported func {other}"),
            }
        }
        other => panic!("eval_log: unsupported node {other:?}"),
    }
}
/// Integrates `f`, differentiates the result, simplifies it, and checks
/// numerically that d/dx ∫f equals `f` at a few sample points in (1, ∞).
///
/// Both sides are evaluated with `eval_log`, so `f` and the simplified
/// derivative may only use literals, `+`, `*`, `^` and `log`.
///
/// # Panics
/// Panics if `integrate` fails, if a sample is non-finite, or if the two sides
/// differ beyond tolerance.
fn verify_log(f: ExprId, x: ExprId, pool: &ExprPool) {
    let r = integrate(f, x, pool).unwrap_or_else(|e| panic!("expected elementary: {e:?}"));
    let d = diff(r.value, x, pool).unwrap();
    let ds = simplify(d.value, pool).value;
    for &xv in &[1.3_f64, 2.1, 3.4] {
        let lhs = eval_log(ds, x, xv, pool);
        let rhs = eval_log(f, x, xv, pool);
        // Use the same relative tolerance as `verify_numeric`. Integrands such
        // as 1/(x·log³x) reach large magnitudes near x = 1, where a fixed
        // absolute bound is unreasonably strict. The bound is never below the
        // previous absolute 1e-7. The finiteness check is required because an
        // infinite `lhs`/`rhs` would otherwise make `tol` infinite and pass.
        let tol = 1e-7 * (1.0 + lhs.abs().max(rhs.abs()));
        assert!(
            lhs.is_finite() && rhs.is_finite() && (lhs - rhs).abs() <= tol,
            "d/dx F ≠ f at x={xv}: {lhs} vs {rhs}\n F = {}",
            pool.display(r.value)
        );
    }
}
/// ∫ 1/(x·log x) dx = log(log x): the log-derivative rule must find it, and
/// the displayed result must contain a nested `log(log`.
#[test]
fn log_derivative_one_over_x_log_x() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let log_x = pool.func("log", vec![var]);
    let inv_x = pool.pow(var, pool.integer(-1));
    let inv_log = pool.pow(log_x, pool.integer(-1));
    let integrand = pool.mul(vec![inv_x, inv_log]);
    verify_log(integrand, var, &pool);

    let antideriv = integrate(integrand, var, &pool).unwrap();
    let rendered = pool.display(antideriv.value).to_string();
    assert!(
        rendered.contains("log(log"),
        "expected log(log(x)); got {}",
        rendered
    );
}
/// ∫ 1/(x·log(x)^m) dx for m = 2, 3. The antiderivative is
/// log(x)^(1−m)/(1−m), and each case is checked numerically via `verify_log`.
#[test]
fn log_derivative_negative_powers() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let log_x = pool.func("log", vec![var]);
    for neg_m in [-2_i32, -3] {
        let inv_x = pool.pow(var, pool.integer(-1));
        let log_pow = pool.pow(log_x, pool.integer(neg_m));
        let integrand = pool.mul(vec![inv_x, log_pow]);
        verify_log(integrand, var, &pool);
    }
}
/// The log-derivative rule with a non-trivial inner function h = x² + 1.
/// The integrand is (h'/h)·(1/log h), whose antiderivative is log(log h).
#[test]
fn log_derivative_polynomial_argument() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let x_sq = pool.pow(var, pool.integer(2_i32));
    let one = pool.integer(1_i32);
    let h = pool.add(vec![x_sq, one]);
    let log_h = pool.func("log", vec![h]);
    // h'/h written as 2 · x · h⁻¹.
    let two = pool.integer(2_i32);
    let inv_h = pool.pow(h, pool.integer(-1));
    let dh_over_h = pool.mul(vec![two, var, inv_h]);
    let inv_log_h = pool.pow(log_h, pool.integer(-1));
    let integrand = pool.mul(vec![dh_over_h, inv_log_h]);
    verify_log(integrand, var, &pool);
}
/// The log-derivative rule must only fire when the cofactor is exactly the
/// derivative of the log's argument.
///
/// 1/log(x) (cofactor 1) must still be `NonElementary`, and x/log(x) (cofactor
/// x rather than 1/x) must not be integrated at all.
#[test]
fn log_derivative_does_not_misfire() {
    let pool = p();
    let x = pool.symbol("x", Domain::Real);
    let logx = pool.func("log", vec![x]);
    let f = pool.pow(logx, pool.integer(-1));
    let r = integrate(f, x, &pool);
    // Include the actual result so a misfire shows what was produced.
    assert!(
        matches!(r, Err(IntegrationError::NonElementary(_))),
        "∫ 1/log(x) dx must remain NonElementary; got {r:?}"
    );
    let f = pool.mul(vec![x, pool.pow(logx, pool.integer(-1))]);
    let r = integrate(f, x, &pool);
    assert!(
        r.is_err(),
        "∫ x/log(x) dx must not be (mis)integrated by the log-derivative rule; got {r:?}"
    );
}
/// Numerically evaluates a closed (symbol-free) expression.
///
/// Supports integer/rational literals, `+`, `*` and `^`, plus the unary
/// functions `log`, `atan` and `sqrt`. Powers whose exponent is an integer
/// literal that fits in `i32` use `powi`; all other powers use `powf`. Any
/// other node panics.
fn eval_num(expr: ExprId, pool: &ExprPool) -> f64 {
    match pool.get(expr) {
        ExprData::Integer(n) => n.0.to_f64(),
        ExprData::Rational(q) => q.0.to_f64(),
        ExprData::Add(terms) => terms.iter().map(|&t| eval_num(t, pool)).sum(),
        ExprData::Mul(factors) => factors.iter().map(|&fac| eval_num(fac, pool)).product(),
        ExprData::Pow { base, exp } => {
            let b = eval_num(base, pool);
            let small_int_exp = match pool.get(exp) {
                ExprData::Integer(n) => n.0.to_i32(),
                _ => None,
            };
            match small_int_exp {
                Some(k) => b.powi(k),
                None => b.powf(eval_num(exp, pool)),
            }
        }
        ExprData::Func { ref name, ref args } if args.len() == 1 => {
            let arg = eval_num(args[0], pool);
            match name.as_str() {
                "log" => arg.ln(),
                "atan" => arg.atan(),
                "sqrt" => arg.sqrt(),
                other => panic!("eval_num: unsupported func {other}"),
            }
        }
        other => panic!("eval_num: unsupported {other:?}"),
    }
}
/// Asserts that the closed expression `result` evaluates (via `eval_num`) to
/// within an absolute 1e-9 of `expected`.
fn assert_num(result: ExprId, expected: f64, pool: &ExprPool) {
    let got = eval_num(result, pool);
    let deviation = (got - expected).abs();
    assert!(
        deviation < 1e-9,
        "definite integral = {got}, expected {expected}"
    );
}
/// ∫₀¹ x² dx = 1/3.
#[test]
fn definite_x_squared_0_1() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let integrand = pool.pow(var, pool.integer(2_i32));
    let lower = pool.integer(0_i32);
    let upper = pool.integer(1_i32);
    let area = integrate_definite(integrand, var, lower, upper, &pool).unwrap();
    assert_num(area.value, 1.0 / 3.0, &pool);
}
/// ∫₀¹ 2x dx = 1.
#[test]
fn definite_two_x_0_1() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let integrand = pool.mul(vec![pool.integer(2_i32), var]);
    let lower = pool.integer(0_i32);
    let upper = pool.integer(1_i32);
    let area = integrate_definite(integrand, var, lower, upper, &pool).unwrap();
    assert_num(area.value, 1.0, &pool);
}
/// ∫₁² 1/x dx = ln 2.
#[test]
fn definite_one_over_x_1_2() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let integrand = pool.pow(var, pool.integer(-1_i32));
    let lower = pool.integer(1_i32);
    let upper = pool.integer(2_i32);
    let area = integrate_definite(integrand, var, lower, upper, &pool).unwrap();
    assert_num(area.value, 2.0_f64.ln(), &pool);
}
/// ∫₀¹ 1/(x²+1) dx = atan(1) − atan(0) = π/4.
///
/// Despite its name, this test involves no `sin`. It exercises evaluation of
/// an `atan` antiderivative at the bounds.
#[test]
fn definite_sin_arctan_bounds() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let x_sq = pool.pow(var, pool.integer(2_i32));
    let one = pool.integer(1_i32);
    let den = pool.add(vec![x_sq, one]);
    let integrand = pool.pow(den, pool.integer(-1_i32));
    let lower = pool.integer(0_i32);
    let upper = pool.integer(1_i32);
    let area = integrate_definite(integrand, var, lower, upper, &pool).unwrap();
    assert_num(area.value, std::f64::consts::FRAC_PI_4, &pool);
}
/// A definite integral of exp(x²), which has no elementary antiderivative,
/// must surface the indefinite integrator's error rather than a value.
#[test]
fn definite_nonelementary_propagates() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let x_sq = pool.pow(var, pool.integer(2_i32));
    let integrand = pool.func("exp", vec![x_sq]);
    let lower = pool.integer(0_i32);
    let upper = pool.integer(1_i32);
    let r = integrate_definite(integrand, var, lower, upper, &pool);
    assert!(
        r.is_err(),
        "∫_0^1 exp(x²) dx must propagate the integration error, got {r:?}"
    );
}
/// ∫₁² sin(x)/x dx has no elementary antiderivative; the definite form must
/// return an error.
#[test]
fn definite_unsupported_propagates() {
    let pool = p();
    let x = pool.symbol("x", Domain::Real);
    let f = pool.mul(vec![
        pool.func("sin", vec![x]),
        pool.pow(x, pool.integer(-1_i32)),
    ]);
    let r = integrate_definite(f, x, pool.integer(1_i32), pool.integer(2_i32), &pool);
    // Report what came back, as `definite_nonelementary_propagates` does.
    assert!(
        r.is_err(),
        "∫ sin(x)/x dx must error in definite form, got {r:?}"
    );
}
/// Checks numerically that d/dx ∫`integrand` dx equals `integrand`.
///
/// The antiderivative is differentiated and simplified, and then both sides
/// are evaluated with `crate::jit::eval_interp` at fixed sample points. A
/// sample is skipped if either side fails to evaluate or is non-finite.
/// Agreement uses a relative tolerance of 1e-7.
///
/// # Panics
/// Panics if integration fails, if any usable sample disagrees, or if fewer
/// than two samples were usable.
fn verify_numeric(integrand: ExprId, x: ExprId, pool: &ExprPool) {
    let integral = integrate(integrand, x, pool)
        .unwrap_or_else(|e| panic!("integrate failed for {}: {e}", pool.display(integrand)));
    let derivative = simplify(diff(integral.value, x, pool).unwrap().value, pool).value;

    // One environment reused for every sample; only the binding for `x` changes.
    let mut env = std::collections::HashMap::new();
    let mut checked = 0;
    for xv in [0.41_f64, 0.93, 1.37, 2.11, 2.83] {
        env.insert(x, xv);
        let lhs = crate::jit::eval_interp(derivative, &env, pool);
        let rhs = crate::jit::eval_interp(integrand, &env, pool);
        let (Some(dv), Some(fv)) = (lhs, rhs) else {
            continue;
        };
        if !(dv.is_finite() && fv.is_finite()) {
            continue;
        }
        let tol = 1e-7 * (1.0 + dv.abs().max(fv.abs()));
        assert!(
            (dv - fv).abs() <= tol,
            "diff(∫f) ≠ f at x={xv}: got {dv}, want {fv}, for f = {}, F = {}",
            pool.display(integrand),
            pool.display(integral.value),
        );
        checked += 1;
    }
    assert!(checked >= 2, "no usable samples to verify antiderivative");
}
/// u-substitution with u = x²: ∫ x·sin(x²) dx.
#[test]
fn usub_x_sin_x2() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let u = pool.pow(var, pool.integer(2_i32));
    let sin_u = pool.func("sin", vec![u]);
    verify_numeric(pool.mul(vec![var, sin_u]), var, &pool);
}
/// u-substitution where the cofactor is exactly u' = 2x: ∫ 2x·exp(x²) dx.
#[test]
fn usub_2x_exp_x2() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let u = pool.pow(var, pool.integer(2_i32));
    let two = pool.integer(2_i32);
    let exp_u = pool.func("exp", vec![u]);
    let integrand = pool.mul(vec![two, var, exp_u]);
    verify_numeric(integrand, var, &pool);
}
/// u-substitution where the cofactor is u'/2 = x: ∫ x·exp(x²) dx.
#[test]
fn usub_x_exp_x2() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let u = pool.pow(var, pool.integer(2_i32));
    let exp_u = pool.func("exp", vec![u]);
    verify_numeric(pool.mul(vec![var, exp_u]), var, &pool);
}
/// u-substitution with u = log(x): ∫ log(x)/x dx.
#[test]
fn usub_lnx_over_x() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let log_x = pool.func("log", vec![var]);
    let inv_x = pool.pow(var, pool.integer(-1_i32));
    let integrand = pool.mul(vec![log_x, inv_x]);
    verify_numeric(integrand, var, &pool);
}
/// ∫ tan(x) dx must integrate, and its derivative must match tan(x) numerically.
#[test]
fn usub_tan_x() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    verify_numeric(pool.func("tan", vec![var]), var, &pool);
}
/// u-substitution with u = exp(x): ∫ exp(x)·cos(exp(x)) dx.
#[test]
fn usub_exp_cos_exp() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let u = pool.func("exp", vec![var]);
    let cos_u = pool.func("cos", vec![u]);
    let integrand = pool.mul(vec![u, cos_u]);
    verify_numeric(integrand, var, &pool);
}
/// u-substitution with u = x² + 1: ∫ x·cos(x² + 1) dx.
#[test]
fn usub_x_cos_x2_plus_1() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let x_sq = pool.pow(var, pool.integer(2_i32));
    let one = pool.integer(1_i32);
    let u = pool.add(vec![x_sq, one]);
    let cos_u = pool.func("cos", vec![u]);
    verify_numeric(pool.mul(vec![var, cos_u]), var, &pool);
}
/// exp(x²) has no x cofactor, so u-substitution must not apply. Integration
/// must still fail.
#[test]
fn usub_nonelementary_still_errors() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);
    let x_sq = pool.pow(var, pool.integer(2_i32));
    let integrand = pool.func("exp", vec![x_sq]);
    let outcome = integrate(integrand, var, &pool);
    assert!(
        outcome.is_err(),
        "∫ e^(x²) dx must error, got {:?}",
        outcome.map(|d| pool.display(d.value))
    );
}
/// Regression guard: adding u-substitution must not break the basic rules for
/// sin(x), x², exp(x) and 1/x.
///
/// Each integrand is built and checked in the same order as before.
#[test]
fn usub_does_not_disturb_basic_rules() {
    let pool = p();
    let var = pool.symbol("x", Domain::Real);

    let sin_x = pool.func("sin", vec![var]);
    verify_numeric(sin_x, var, &pool);

    let x_sq = pool.pow(var, pool.integer(2_i32));
    verify(x_sq, var, &pool);

    let exp_x = pool.func("exp", vec![var]);
    verify_numeric(exp_x, var, &pool);

    let inv_x = pool.pow(var, pool.integer(-1_i32));
    verify_numeric(inv_x, var, &pool);
}
}