use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
use crate::kernel::{ExprData, ExprId, ExprPool};
use crate::simplify::engine::simplify;
use std::fmt;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntegrationError {
NotImplemented(String),
DivisionByZero,
UnsupportedExtensionDegree(u32),
NonElementary(String),
}
impl fmt::Display for IntegrationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IntegrationError::NotImplemented(msg) => write!(f, "integrate: not implemented: {msg}"),
IntegrationError::DivisionByZero => write!(f, "integrate: division by zero"),
IntegrationError::UnsupportedExtensionDegree(q) => write!(
f,
"integrate: algebraic extension of degree {q} is not supported \
(v1.1 supports only degree-2 / sqrt extensions)"
),
IntegrationError::NonElementary(msg) => {
write!(f, "integrate: no elementary antiderivative exists: {msg}")
}
}
}
}
impl std::error::Error for IntegrationError {}
impl IntegrationError {
pub fn remediation(&self) -> Option<&'static str> {
match self {
IntegrationError::NotImplemented(_) => Some(
"only power, linearity, sin/cos/exp rules and algebraic (sqrt) rules \
are implemented; use a numeric integrator for arbitrary functions",
),
IntegrationError::DivisionByZero => None,
IntegrationError::UnsupportedExtensionDegree(_) => Some(
"v1.1 supports sqrt(P(x)) only; higher-degree radicals (cbrt, nth-root) \
are planned for v2.0",
),
IntegrationError::NonElementary(_) => Some(
"this integrand has no closed-form antiderivative in terms of elementary \
functions; use a numeric integrator or elliptic-integral library",
),
}
}
pub fn span(&self) -> Option<(usize, usize)> {
None
}
}
impl crate::errors::AlkahestError for IntegrationError {
fn code(&self) -> &'static str {
match self {
IntegrationError::NotImplemented(_) => "E-INT-001",
IntegrationError::DivisionByZero => "E-INT-002",
IntegrationError::UnsupportedExtensionDegree(_) => "E-INT-003",
IntegrationError::NonElementary(_) => "E-INT-004",
}
}
fn remediation(&self) -> Option<&'static str> {
IntegrationError::remediation(self)
}
}
fn as_integer(expr: ExprId, pool: &ExprPool) -> Option<i64> {
pool.with(expr, |data| match data {
ExprData::Integer(n) => n.0.to_i64(),
_ => None,
})
}
fn is_free_of(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
if expr == var {
return false;
}
let children: Vec<ExprId> = pool.with(expr, |data| match data {
ExprData::Add(args) | ExprData::Mul(args) => args.clone(),
ExprData::Pow { base, exp } => vec![*base, *exp],
ExprData::Func { args, .. } => args.clone(),
_ => vec![],
});
children.into_iter().all(|c| is_free_of(c, var, pool))
}
fn is_linear_in(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
if expr == var {
return Some((pool.integer(1_i32), pool.integer(0_i32)));
}
match pool.get(expr) {
ExprData::Mul(args) => {
let var_pos = args.iter().position(|&a| a == var)?;
let others: Vec<ExprId> = args
.iter()
.enumerate()
.filter(|&(i, _)| i != var_pos)
.map(|(_, &a)| a)
.collect();
let a = match others.len() {
0 => pool.integer(1_i32),
1 => others[0],
_ => pool.mul(others),
};
if is_free_of(a, var, pool) {
Some((a, pool.integer(0_i32)))
} else {
None
}
}
ExprData::Add(args) => {
let mut a_opt: Option<ExprId> = None;
let mut b_parts: Vec<ExprId> = vec![];
for &arg in &args {
if arg == var {
if a_opt.is_some() {
return None;
}
a_opt = Some(pool.integer(1_i32));
} else {
match pool.get(arg) {
ExprData::Mul(margs) => {
let vpos = margs.iter().position(|&m| m == var);
if let Some(vp) = vpos {
if a_opt.is_some() {
return None;
}
let others: Vec<ExprId> = margs
.iter()
.enumerate()
.filter(|&(i, _)| i != vp)
.map(|(_, &m)| m)
.collect();
let coeff = match others.len() {
0 => pool.integer(1_i32),
1 => others[0],
_ => pool.mul(others),
};
if is_free_of(coeff, var, pool) {
a_opt = Some(coeff);
} else {
b_parts.push(arg);
}
} else if is_free_of(arg, var, pool) {
b_parts.push(arg);
} else {
return None;
}
}
_ if is_free_of(arg, var, pool) => b_parts.push(arg),
_ => return None,
}
}
}
let a = a_opt?;
let b = match b_parts.len() {
0 => pool.integer(0_i32),
1 => b_parts[0],
_ => pool.add(b_parts),
};
Some((a, b))
}
_ => None,
}
}
fn try_x_times_func(
expr: ExprId,
var: ExprId,
pool: &ExprPool,
log: &mut DerivationLog,
) -> Option<ExprId> {
let args = match pool.get(expr) {
ExprData::Mul(v) => v,
_ => return None,
};
let exp_pos = args.iter().position(|&a| {
pool.with(a, |d| match d {
ExprData::Func { name, args } => name == "exp" && args.len() == 1 && args[0] == var,
_ => false,
})
})?;
let var_pos = args.iter().position(|&a| a == var)?;
let others: Vec<ExprId> = args
.iter()
.enumerate()
.filter(|&(i, _)| i != exp_pos && i != var_pos)
.map(|(_, &a)| a)
.collect();
if !others.iter().all(|&a| is_free_of(a, var, pool)) {
return None;
}
let exp_x = args[exp_pos];
let x_minus_1 = pool.add(vec![var, pool.integer(-1_i32)]);
let mut factors = vec![exp_x, x_minus_1];
factors.extend_from_slice(&others);
let result = pool.mul(factors);
log.push(RewriteStep::simple("int_x_exp", expr, result));
Some(result)
}
pub(crate) fn integrate_raw(
expr: ExprId,
var: ExprId,
pool: &ExprPool,
log: &mut DerivationLog,
) -> Result<ExprId, IntegrationError> {
if let Some(result) = try_x_times_func(expr, var, pool, log) {
return Ok(result);
}
enum Node {
IsVar,
Constant,
Add(Vec<ExprId>),
Mul(Vec<ExprId>),
Pow { base: ExprId, exp: ExprId },
Func { name: String, arg: ExprId },
Unknown,
}
let node = pool.with(expr, |data| match data {
ExprData::Symbol { .. } if expr == var => Node::IsVar,
ExprData::Symbol { .. }
| ExprData::Integer(_)
| ExprData::Rational(_)
| ExprData::Float(_) => Node::Constant,
ExprData::Add(args) => Node::Add(args.clone()),
ExprData::Mul(args) => Node::Mul(args.clone()),
ExprData::Pow { base, exp } => Node::Pow {
base: *base,
exp: *exp,
},
ExprData::Func { name, args } if args.len() == 1 => Node::Func {
name: name.clone(),
arg: args[0],
},
_ => Node::Unknown,
});
match node {
Node::IsVar => {
let two = pool.integer(2_i32);
let inv_two = pool.pow(two, pool.integer(-1_i32));
let result = pool.mul(vec![pool.pow(var, two), inv_two]);
log.push(RewriteStep::simple("power_rule", expr, result));
Ok(result)
}
Node::Constant => {
let result = pool.mul(vec![expr, var]);
log.push(RewriteStep::simple("constant_rule", expr, result));
Ok(result)
}
Node::Add(args) => {
let mut int_args = Vec::with_capacity(args.len());
for a in &args {
let ia = integrate_raw(*a, var, pool, log)?;
int_args.push(ia);
}
let result = pool.add(int_args);
log.push(RewriteStep::simple("sum_rule", expr, result));
Ok(result)
}
Node::Mul(args) => {
let (consts, non_consts): (Vec<ExprId>, Vec<ExprId>) =
args.iter().partition(|&&a| is_free_of(a, var, pool));
if non_consts.is_empty() {
let result = pool.mul(vec![expr, var]);
log.push(RewriteStep::simple("constant_rule", expr, result));
return Ok(result);
}
let inner = match non_consts.len() {
1 => non_consts[0],
_ => pool.mul(non_consts.clone()),
};
let const_factor = match consts.len() {
0 => None,
1 => Some(consts[0]),
_ => Some(pool.mul(consts.clone())),
};
let int_inner = integrate_raw(inner, var, pool, log)?;
let result = match const_factor {
None => int_inner,
Some(c) => {
let r = pool.mul(vec![c, int_inner]);
log.push(RewriteStep::simple("constant_multiple_rule", expr, r));
r
}
};
Ok(result)
}
Node::Pow { base, exp } => {
let n_opt = as_integer(exp, pool);
if let Some(n) = n_opt {
if base == var {
if n == -1 {
let result = pool.func("log", vec![var]);
log.push(RewriteStep::simple("log_rule", expr, result));
return Ok(result);
}
let np1 = pool.integer(n + 1);
let inv_np1 = pool.pow(np1, pool.integer(-1_i32));
let result = pool.mul(vec![pool.pow(var, np1), inv_np1]);
log.push(RewriteStep::simple("power_rule", expr, result));
return Ok(result);
}
if n == -1 {
if let Some((a, _b)) = is_linear_in(base, var, pool) {
let log_base = pool.func("log", vec![base]);
let a_inv = pool.pow(a, pool.integer(-1_i32));
let result = pool.mul(vec![a_inv, log_base]);
log.push(RewriteStep::simple("int_linear_inv", expr, result));
return Ok(result);
}
}
if is_free_of(base, var, pool) {
let result = pool.mul(vec![expr, var]);
log.push(RewriteStep::simple("constant_rule", expr, result));
return Ok(result);
}
}
Err(IntegrationError::NotImplemented(
"∫ (expr)^(exp) where base or exp is non-trivial".to_string(),
))
}
Node::Func { name, arg } => {
if arg != var {
if is_free_of(arg, var, pool) {
let result = pool.mul(vec![expr, var]);
log.push(RewriteStep::simple("constant_rule", expr, result));
return Ok(result);
}
if name == "exp" {
if let Some((a, _b)) = is_linear_in(arg, var, pool) {
let exp_expr = pool.func("exp", vec![arg]);
let a_inv = pool.pow(a, pool.integer(-1_i32));
let result = pool.mul(vec![a_inv, exp_expr]);
log.push(RewriteStep::simple("int_exp_linear", expr, result));
return Ok(result);
}
}
return Err(IntegrationError::NotImplemented(format!(
"∫ {name}(non-trivial arg) — chain rule not implemented"
)));
}
match name.as_str() {
"sin" => {
let neg_one = pool.integer(-1_i32);
let result = pool.mul(vec![neg_one, pool.func("cos", vec![var])]);
log.push(RewriteStep::simple("int_sin", expr, result));
Ok(result)
}
"cos" => {
let result = pool.func("sin", vec![var]);
log.push(RewriteStep::simple("int_cos", expr, result));
Ok(result)
}
"exp" => {
let result = pool.func("exp", vec![var]);
log.push(RewriteStep::simple("int_exp", expr, result));
Ok(result)
}
"log" => {
let log_x = pool.func("log", vec![var]);
let x_log_x = pool.mul(vec![var, log_x]);
let neg_x = pool.mul(vec![pool.integer(-1_i32), var]);
let result = pool.add(vec![x_log_x, neg_x]);
log.push(RewriteStep::simple("int_log", expr, result));
Ok(result)
}
"sqrt" => Err(IntegrationError::NotImplemented(
"∫ sqrt(x) — not in the supported Risch subset".to_string(),
)),
other => Err(IntegrationError::NotImplemented(format!("∫ {other}(x)"))),
}
}
Node::Unknown => Err(IntegrationError::NotImplemented(
"unsupported expression node".to_string(),
)),
}
}
pub fn integrate(
expr: ExprId,
var: ExprId,
pool: &ExprPool,
) -> Result<DerivedExpr<ExprId>, IntegrationError> {
if super::algebraic::contains_algebraic_subterm(expr, pool) {
return super::algebraic::integrate_algebraic(expr, var, pool);
}
let mut log = DerivationLog::new();
let raw = integrate_raw(expr, var, pool, &mut log)?;
let simplified = simplify(raw, pool);
let final_log = log.merge(simplified.log);
Ok(DerivedExpr::with_log(simplified.value, final_log))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::diff::diff;
use crate::kernel::{Domain, ExprPool};
use crate::poly::UniPoly;
fn p() -> ExprPool {
ExprPool::new()
}
fn coeffs_equal(a: ExprId, b: ExprId, x: ExprId, pool: &ExprPool) -> bool {
let ap = UniPoly::from_symbolic(a, x, pool);
let bp = UniPoly::from_symbolic(b, x, pool);
match (ap, bp) {
(Ok(a), Ok(b)) => a.coefficients_i64() == b.coefficients_i64(),
_ => a == b,
}
}
fn verify(expr: ExprId, x: ExprId, pool: &ExprPool) {
let integral = integrate(expr, x, pool).unwrap();
let deriv = diff(integral.value, x, pool).unwrap();
assert!(
coeffs_equal(deriv.value, expr, x, pool),
"diff(integrate(f)) ≠ f for f = {}",
pool.display(expr)
);
}
#[test]
fn integrate_constant() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.integer(5_i32), x, &pool).unwrap();
let expected = pool.mul(vec![pool.integer(5_i32), x]);
assert!(coeffs_equal(r.value, expected, x, &pool));
}
#[test]
fn integrate_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
verify(x, x, &pool);
}
#[test]
fn integrate_x_squared() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let x2 = pool.pow(x, pool.integer(2_i32));
verify(x2, x, &pool);
}
#[test]
fn integrate_polynomial() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.add(vec![
pool.pow(x, pool.integer(2_i32)),
pool.mul(vec![pool.integer(2_i32), x]),
]);
let r = integrate(expr, x, &pool).unwrap();
let d = diff(r.value, x, &pool).unwrap();
assert!(
coeffs_equal(d.value, expr, x, &pool),
"diff(∫(x²+2x)) ≠ x²+2x; got {}",
pool.display(d.value)
);
}
#[test]
fn integrate_one_over_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let x_inv = pool.pow(x, pool.integer(-1_i32));
let r = integrate(x_inv, x, &pool).unwrap();
assert_eq!(r.value, pool.func("log", vec![x]));
assert!(r.log.steps().iter().any(|s| s.rule_name == "log_rule"));
}
#[test]
fn integrate_sin() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let sin_x = pool.func("sin", vec![x]);
let r = integrate(sin_x, x, &pool).unwrap();
let neg_one = pool.integer(-1_i32);
let expected = pool.mul(vec![neg_one, pool.func("cos", vec![x])]);
assert_eq!(r.value, expected);
assert!(r.log.steps().iter().any(|s| s.rule_name == "int_sin"));
}
#[test]
fn integrate_cos() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.func("cos", vec![x]), x, &pool).unwrap();
assert_eq!(r.value, pool.func("sin", vec![x]));
}
#[test]
fn integrate_exp() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.func("exp", vec![x]), x, &pool).unwrap();
assert_eq!(r.value, pool.func("exp", vec![x]));
}
#[test]
fn integrate_constant_multiple() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.mul(vec![pool.integer(3_i32), pool.pow(x, pool.integer(2_i32))]);
verify(expr, x, &pool);
}
#[test]
fn integrate_not_implemented() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let x2 = pool.pow(x, pool.integer(2_i32));
let err = integrate(pool.func("sin", vec![x2]), x, &pool);
assert!(matches!(err, Err(IntegrationError::NotImplemented(_))));
}
#[test]
fn integrate_log_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let log_x = pool.func("log", vec![x]);
let r = integrate(log_x, x, &pool).unwrap();
assert!(
r.log.steps().iter().any(|s| s.rule_name == "int_log"),
"should have logged int_log step"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("log"),
"result should contain log: {result_str}"
);
}
#[test]
fn integrate_exp_linear_arg() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let two = pool.integer(2_i32);
let two_x = pool.mul(vec![two, x]);
let expr = pool.func("exp", vec![two_x]);
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log
.steps()
.iter()
.any(|s| s.rule_name == "int_exp_linear"),
"should fire int_exp_linear"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("exp"),
"result should contain exp: {result_str}"
);
}
#[test]
fn integrate_x_times_exp_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.mul(vec![x, pool.func("exp", vec![x])]);
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log.steps().iter().any(|s| s.rule_name == "int_x_exp"),
"should fire int_x_exp"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("exp"),
"result should contain exp: {result_str}"
);
}
#[test]
fn integrate_const_times_x_times_exp_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let three = pool.integer(3_i32);
let expr = pool.mul(vec![three, x, pool.func("exp", vec![x])]);
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log.steps().iter().any(|s| s.rule_name == "int_x_exp"),
"should fire int_x_exp for 3*x*exp(x)"
);
}
#[test]
fn integrate_one_over_linear() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let two = pool.integer(2_i32);
let three = pool.integer(3_i32);
let linear = pool.add(vec![pool.mul(vec![two, x]), three]);
let expr = pool.pow(linear, pool.integer(-1_i32));
let r = integrate(expr, x, &pool).unwrap();
assert!(
r.log
.steps()
.iter()
.any(|s| s.rule_name == "int_linear_inv"),
"should fire int_linear_inv"
);
let result_str = pool.display(r.value).to_string();
assert!(
result_str.contains("log"),
"result should contain log: {result_str}"
);
}
#[test]
fn integrate_x_cubed_plus_2x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let expr = pool.add(vec![
pool.pow(x, pool.integer(3_i32)),
pool.mul(vec![pool.integer(2_i32), x]),
]);
verify(expr, x, &pool);
}
#[test]
fn integrate_derivation_log_nonempty() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let r = integrate(pool.pow(x, pool.integer(2_i32)), x, &pool).unwrap();
assert!(
!r.log.is_empty(),
"integration should produce a derivation log"
);
assert!(r.log.steps().iter().any(|s| s.rule_name == "power_rule"));
}
#[test]
fn integrate_sqrt_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let sqrt_x = pool.func("sqrt", vec![x]);
let result = integrate(sqrt_x, x, &pool);
match &result {
Ok(r) => println!("sqrt(x) integral = {}", pool.display(r.value)),
Err(e) => println!("ERROR: {e}"),
}
assert!(result.is_ok(), "∫ sqrt(x) dx failed: {:?}", result);
}
#[test]
fn integrate_inv_sqrt_x() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let sqrt_x = pool.func("sqrt", vec![x]);
let inv_sqrt_x = pool.pow(sqrt_x, pool.integer(-1_i32));
let result = integrate(inv_sqrt_x, x, &pool);
match &result {
Ok(r) => println!("1/sqrt(x) integral = {}", pool.display(r.value)),
Err(e) => println!("ERROR: {e}"),
}
assert!(result.is_ok(), "∫ 1/sqrt(x) dx failed: {:?}", result);
}
#[test]
fn integrate_sqrt_x2_plus_1() {
let pool = p();
let x = pool.symbol("x", Domain::Real);
let p_expr = pool.add(vec![pool.pow(x, pool.integer(2_i32)), pool.integer(1_i32)]);
let sqrt_p = pool.func("sqrt", vec![p_expr]);
let result = integrate(sqrt_p, x, &pool);
match &result {
Ok(r) => println!("sqrt(x^2+1) integral = {}", pool.display(r.value)),
Err(e) => println!("ERROR: {e}"),
}
assert!(result.is_ok(), "∫ sqrt(x²+1) dx failed: {:?}", result);
}
}