Skip to main content

alkahest_cas/integrate/
engine.rs

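//! Symbolic integration — rule-based Risch subset.
//!
//! Handles:
//! - Constants: `∫ c dx = c·x`
//! - Power rule: `∫ x^n dx = x^(n+1)/(n+1)` for integer `n ≠ -1`
//! - Logarithm: `∫ x^(-1) dx = log(x)` and `∫ 1/(a·x + b) dx = log(a·x + b)/a`
//! - Sum rule: `∫ (f + g) dx = ∫f dx + ∫g dx`
//! - Constant-multiple rule: `∫ c·f dx = c · ∫f dx`
//! - Known functions of `x`: `sin`, `cos`, `exp`, `log`
//! - Exponential of a linear argument: `∫ exp(a·x + b) dx = exp(a·x + b)/a`
//! - `∫ c·x·exp(x) dx = c·exp(x)·(x - 1)`
//!
//! Integrands containing algebraic subterms are handed to the algebraic
//! engine by `integrate`; every other unsupported form returns
//! `Err(IntegrationError::NotImplemented)`.
//!
//! The result of `integrate` is simplified with the rule-based simplifier
//! before returning; `integrate_raw` returns the unsimplified antiderivative.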
14use crate::deriv::log::{DerivationLog, DerivedExpr, RewriteStep};
15use crate::kernel::{ExprData, ExprId, ExprPool};
16use crate::simplify::engine::simplify;
17use std::fmt;
18
19// ---------------------------------------------------------------------------
20// Error type
21// ---------------------------------------------------------------------------
22
/// Failure modes reported by the symbolic integrator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntegrationError {
    /// The integrand is outside the supported Risch subset; the payload
    /// describes the form that could not be handled.
    NotImplemented(String),
    /// A division by zero would occur (e.g. power rule with n = -1 on a non-x base).
    DivisionByZero,
    /// The algebraic extension has the carried degree, which exceeds 2
    /// (v1.1 supports only sqrt / degree-2 extensions).
    UnsupportedExtensionDegree(u32),
    /// The integrand provably has no elementary antiderivative
    /// (e.g. elliptic integrals); the payload says why.
    NonElementary(String),
}

/// User-facing messages; every message starts with `integrate: `.
impl fmt::Display for IntegrationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotImplemented(detail) => write!(f, "integrate: not implemented: {detail}"),
            Self::DivisionByZero => f.write_str("integrate: division by zero"),
            Self::UnsupportedExtensionDegree(degree) => write!(
                f,
                "integrate: algebraic extension of degree {degree} is not supported \
                 (v1.1 supports only degree-2 / sqrt extensions)"
            ),
            Self::NonElementary(detail) => {
                write!(f, "integrate: no elementary antiderivative exists: {detail}")
            }
        }
    }
}

impl std::error::Error for IntegrationError {}

impl IntegrationError {
    /// Suggests what the user can do about this error, if anything.
    ///
    /// Returns `None` for [`IntegrationError::DivisionByZero`], which has no
    /// associated hint.
    pub fn remediation(&self) -> Option<&'static str> {
        let hint = match self {
            Self::DivisionByZero => return None,
            Self::NotImplemented(_) => {
                "only power, linearity, sin/cos/exp rules and algebraic (sqrt) rules \
                 are implemented; use a numeric integrator for arbitrary functions"
            }
            Self::UnsupportedExtensionDegree(_) => {
                "v1.1 supports sqrt(P(x)) only; higher-degree radicals (cbrt, nth-root) \
                 are planned for v2.0"
            }
            Self::NonElementary(_) => {
                "this integrand has no closed-form antiderivative in terms of elementary \
                 functions; use a numeric integrator or elliptic-integral library"
            }
        };
        Some(hint)
    }

    /// Source span `(start_byte, end_byte)` within the input text.
    ///
    /// Integration errors currently carry no position, so this is always `None`.
    pub fn span(&self) -> Option<(usize, usize)> {
        None
    }
}
79
80impl crate::errors::AlkahestError for IntegrationError {
81    fn code(&self) -> &'static str {
82        match self {
83            IntegrationError::NotImplemented(_) => "E-INT-001",
84            IntegrationError::DivisionByZero => "E-INT-002",
85            IntegrationError::UnsupportedExtensionDegree(_) => "E-INT-003",
86            IntegrationError::NonElementary(_) => "E-INT-004",
87        }
88    }
89
90    fn remediation(&self) -> Option<&'static str> {
91        IntegrationError::remediation(self)
92    }
93}
94
95// ---------------------------------------------------------------------------
96// Helpers
97// ---------------------------------------------------------------------------
98
99/// Return the i64 value of an integer expression, or None.
100fn as_integer(expr: ExprId, pool: &ExprPool) -> Option<i64> {
101    pool.with(expr, |data| match data {
102        ExprData::Integer(n) => n.0.to_i64(),
103        _ => None,
104    })
105}
106
107/// Return `true` if `expr` does not involve `var` (is a constant w.r.t. `var`).
108fn is_free_of(expr: ExprId, var: ExprId, pool: &ExprPool) -> bool {
109    if expr == var {
110        return false;
111    }
112    let children: Vec<ExprId> = pool.with(expr, |data| match data {
113        ExprData::Add(args) | ExprData::Mul(args) => args.clone(),
114        ExprData::Pow { base, exp } => vec![*base, *exp],
115        ExprData::Func { args, .. } => args.clone(),
116        _ => vec![],
117    });
118    children.into_iter().all(|c| is_free_of(c, var, pool))
119}
120
121/// If `expr = a*var + b` where `a`, `b` are free of `var`, return `Some((a, b))`.
122/// Returns `Some((1, 0))` when `expr == var`.
123fn is_linear_in(expr: ExprId, var: ExprId, pool: &ExprPool) -> Option<(ExprId, ExprId)> {
124    if expr == var {
125        return Some((pool.integer(1_i32), pool.integer(0_i32)));
126    }
127    match pool.get(expr) {
128        ExprData::Mul(args) => {
129            let var_pos = args.iter().position(|&a| a == var)?;
130            let others: Vec<ExprId> = args
131                .iter()
132                .enumerate()
133                .filter(|&(i, _)| i != var_pos)
134                .map(|(_, &a)| a)
135                .collect();
136            let a = match others.len() {
137                0 => pool.integer(1_i32),
138                1 => others[0],
139                _ => pool.mul(others),
140            };
141            if is_free_of(a, var, pool) {
142                Some((a, pool.integer(0_i32)))
143            } else {
144                None
145            }
146        }
147        ExprData::Add(args) => {
148            let mut a_opt: Option<ExprId> = None;
149            let mut b_parts: Vec<ExprId> = vec![];
150            for &arg in &args {
151                if arg == var {
152                    if a_opt.is_some() {
153                        return None;
154                    }
155                    a_opt = Some(pool.integer(1_i32));
156                } else {
157                    match pool.get(arg) {
158                        ExprData::Mul(margs) => {
159                            let vpos = margs.iter().position(|&m| m == var);
160                            if let Some(vp) = vpos {
161                                if a_opt.is_some() {
162                                    return None;
163                                }
164                                let others: Vec<ExprId> = margs
165                                    .iter()
166                                    .enumerate()
167                                    .filter(|&(i, _)| i != vp)
168                                    .map(|(_, &m)| m)
169                                    .collect();
170                                let coeff = match others.len() {
171                                    0 => pool.integer(1_i32),
172                                    1 => others[0],
173                                    _ => pool.mul(others),
174                                };
175                                if is_free_of(coeff, var, pool) {
176                                    a_opt = Some(coeff);
177                                } else {
178                                    b_parts.push(arg);
179                                }
180                            } else if is_free_of(arg, var, pool) {
181                                b_parts.push(arg);
182                            } else {
183                                return None;
184                            }
185                        }
186                        _ if is_free_of(arg, var, pool) => b_parts.push(arg),
187                        _ => return None,
188                    }
189                }
190            }
191            let a = a_opt?;
192            let b = match b_parts.len() {
193                0 => pool.integer(0_i32),
194                1 => b_parts[0],
195                _ => pool.add(b_parts),
196            };
197            Some((a, b))
198        }
199        _ => None,
200    }
201}
202
203/// Match `∫ c * x * exp(x) dx = c * exp(x) * (x - 1)`.
204///
205/// Recognises any `Mul` containing exactly one `exp(var)` factor, exactly one
206/// `var` factor, and zero or more constant (free-of-var) factors.
207fn try_x_times_func(
208    expr: ExprId,
209    var: ExprId,
210    pool: &ExprPool,
211    log: &mut DerivationLog,
212) -> Option<ExprId> {
213    let args = match pool.get(expr) {
214        ExprData::Mul(v) => v,
215        _ => return None,
216    };
217
218    let exp_pos = args.iter().position(|&a| {
219        pool.with(a, |d| match d {
220            ExprData::Func { name, args } => name == "exp" && args.len() == 1 && args[0] == var,
221            _ => false,
222        })
223    })?;
224
225    let var_pos = args.iter().position(|&a| a == var)?;
226
227    let others: Vec<ExprId> = args
228        .iter()
229        .enumerate()
230        .filter(|&(i, _)| i != exp_pos && i != var_pos)
231        .map(|(_, &a)| a)
232        .collect();
233    if !others.iter().all(|&a| is_free_of(a, var, pool)) {
234        return None;
235    }
236
237    // ∫ c * x * exp(x) dx = c * exp(x) * (x - 1)
238    let exp_x = args[exp_pos];
239    let x_minus_1 = pool.add(vec![var, pool.integer(-1_i32)]);
240    let mut factors = vec![exp_x, x_minus_1];
241    factors.extend_from_slice(&others);
242    let result = pool.mul(factors);
243    log.push(RewriteStep::simple("int_x_exp", expr, result));
244    Some(result)
245}
246
247// ---------------------------------------------------------------------------
248// Core integration (no simplification yet)
249// ---------------------------------------------------------------------------
250
251/// Crate-internal entry to the rule-based integrator (no algebraic dispatch).
252/// Used by the algebraic engine to integrate the rational part A(x).
253pub(crate) fn integrate_raw(
254    expr: ExprId,
255    var: ExprId,
256    pool: &ExprPool,
257    log: &mut DerivationLog,
258) -> Result<ExprId, IntegrationError> {
259    // Fast-path: ∫ c * x * exp(x) dx = c * exp(x) * (x - 1)
260    if let Some(result) = try_x_times_func(expr, var, pool, log) {
261        return Ok(result);
262    }
263
264    // Snapshot node type without holding the lock during recursive calls.
265    enum Node {
266        IsVar,
267        Constant,
268        Add(Vec<ExprId>),
269        Mul(Vec<ExprId>),
270        Pow { base: ExprId, exp: ExprId },
271        Func { name: String, arg: ExprId },
272        Unknown,
273    }
274
275    let node = pool.with(expr, |data| match data {
276        ExprData::Symbol { .. } if expr == var => Node::IsVar,
277        ExprData::Symbol { .. }
278        | ExprData::Integer(_)
279        | ExprData::Rational(_)
280        | ExprData::Float(_) => Node::Constant,
281        ExprData::Add(args) => Node::Add(args.clone()),
282        ExprData::Mul(args) => Node::Mul(args.clone()),
283        ExprData::Pow { base, exp } => Node::Pow {
284            base: *base,
285            exp: *exp,
286        },
287        ExprData::Func { name, args } if args.len() == 1 => Node::Func {
288            name: name.clone(),
289            arg: args[0],
290        },
291        _ => Node::Unknown,
292    });
293
294    match node {
295        // ∫ x dx = x²/2
296        Node::IsVar => {
297            let two = pool.integer(2_i32);
298            let inv_two = pool.pow(two, pool.integer(-1_i32));
299            let result = pool.mul(vec![pool.pow(var, two), inv_two]);
300            log.push(RewriteStep::simple("power_rule", expr, result));
301            Ok(result)
302        }
303
304        // ∫ c dx = c*x  (c free of var)
305        Node::Constant => {
306            let result = pool.mul(vec![expr, var]);
307            log.push(RewriteStep::simple("constant_rule", expr, result));
308            Ok(result)
309        }
310
311        // Sum rule: ∫(f + g + …) = ∫f + ∫g + …
312        Node::Add(args) => {
313            let mut int_args = Vec::with_capacity(args.len());
314            for a in &args {
315                let ia = integrate_raw(*a, var, pool, log)?;
316                int_args.push(ia);
317            }
318            let result = pool.add(int_args);
319            log.push(RewriteStep::simple("sum_rule", expr, result));
320            Ok(result)
321        }
322
323        // Constant-multiple / power rule for Mul
324        Node::Mul(args) => {
325            // Partition args into constants (free of var) and non-constants
326            let (consts, non_consts): (Vec<ExprId>, Vec<ExprId>) =
327                args.iter().partition(|&&a| is_free_of(a, var, pool));
328
329            if non_consts.is_empty() {
330                // All factors are constants — treat whole expression as constant
331                let result = pool.mul(vec![expr, var]);
332                log.push(RewriteStep::simple("constant_rule", expr, result));
333                return Ok(result);
334            }
335
336            // Build the non-constant part
337            let inner = match non_consts.len() {
338                1 => non_consts[0],
339                _ => pool.mul(non_consts.clone()),
340            };
341
342            // Build the constant factor
343            let const_factor = match consts.len() {
344                0 => None,
345                1 => Some(consts[0]),
346                _ => Some(pool.mul(consts.clone())),
347            };
348
349            // Integrate the non-constant part
350            let int_inner = integrate_raw(inner, var, pool, log)?;
351
352            let result = match const_factor {
353                None => int_inner,
354                Some(c) => {
355                    let r = pool.mul(vec![c, int_inner]);
356                    log.push(RewriteStep::simple("constant_multiple_rule", expr, r));
357                    r
358                }
359            };
360            Ok(result)
361        }
362
363        // Power rule: ∫ f^n dx
364        Node::Pow { base, exp } => {
365            // Check if exponent is a constant integer
366            let n_opt = as_integer(exp, pool);
367
368            if let Some(n) = n_opt {
369                if base == var {
370                    if n == -1 {
371                        // ∫ x^(-1) dx = ln(x)
372                        let result = pool.func("log", vec![var]);
373                        log.push(RewriteStep::simple("log_rule", expr, result));
374                        return Ok(result);
375                    }
376                    // ∫ x^n dx = x^(n+1) / (n+1)
377                    let np1 = pool.integer(n + 1);
378                    let inv_np1 = pool.pow(np1, pool.integer(-1_i32));
379                    let result = pool.mul(vec![pool.pow(var, np1), inv_np1]);
380                    log.push(RewriteStep::simple("power_rule", expr, result));
381                    return Ok(result);
382                }
383
384                // ∫ 1/(a*x + b) dx = log(a*x + b) / a
385                if n == -1 {
386                    if let Some((a, _b)) = is_linear_in(base, var, pool) {
387                        let log_base = pool.func("log", vec![base]);
388                        let a_inv = pool.pow(a, pool.integer(-1_i32));
389                        let result = pool.mul(vec![a_inv, log_base]);
390                        log.push(RewriteStep::simple("int_linear_inv", expr, result));
391                        return Ok(result);
392                    }
393                }
394
395                // base is free of var: ∫ c^n dx = c^n * x
396                if is_free_of(base, var, pool) {
397                    let result = pool.mul(vec![expr, var]);
398                    log.push(RewriteStep::simple("constant_rule", expr, result));
399                    return Ok(result);
400                }
401            }
402
403            Err(IntegrationError::NotImplemented(
404                "∫ (expr)^(exp) where base or exp is non-trivial".to_string(),
405            ))
406        }
407
408        // Named single-argument functions
409        Node::Func { name, arg } => {
410            if arg != var {
411                // Only handle f(x) directly; chain rule is out of scope
412                if is_free_of(arg, var, pool) {
413                    // ∫ f(c) dx = f(c) * x
414                    let result = pool.mul(vec![expr, var]);
415                    log.push(RewriteStep::simple("constant_rule", expr, result));
416                    return Ok(result);
417                }
418                // ∫ exp(a*x + b) dx = exp(a*x + b) / a
419                if name == "exp" {
420                    if let Some((a, _b)) = is_linear_in(arg, var, pool) {
421                        let exp_expr = pool.func("exp", vec![arg]);
422                        let a_inv = pool.pow(a, pool.integer(-1_i32));
423                        let result = pool.mul(vec![a_inv, exp_expr]);
424                        log.push(RewriteStep::simple("int_exp_linear", expr, result));
425                        return Ok(result);
426                    }
427                }
428                return Err(IntegrationError::NotImplemented(format!(
429                    "∫ {name}(non-trivial arg) — chain rule not implemented"
430                )));
431            }
432            match name.as_str() {
433                // ∫ sin(x) dx = -cos(x)
434                "sin" => {
435                    let neg_one = pool.integer(-1_i32);
436                    let result = pool.mul(vec![neg_one, pool.func("cos", vec![var])]);
437                    log.push(RewriteStep::simple("int_sin", expr, result));
438                    Ok(result)
439                }
440                // ∫ cos(x) dx = sin(x)
441                "cos" => {
442                    let result = pool.func("sin", vec![var]);
443                    log.push(RewriteStep::simple("int_cos", expr, result));
444                    Ok(result)
445                }
446                // ∫ exp(x) dx = exp(x)
447                "exp" => {
448                    let result = pool.func("exp", vec![var]);
449                    log.push(RewriteStep::simple("int_exp", expr, result));
450                    Ok(result)
451                }
452                // ∫ log(x) dx = x*log(x) - x  (integration by parts)
453                "log" => {
454                    let log_x = pool.func("log", vec![var]);
455                    let x_log_x = pool.mul(vec![var, log_x]);
456                    let neg_x = pool.mul(vec![pool.integer(-1_i32), var]);
457                    let result = pool.add(vec![x_log_x, neg_x]);
458                    log.push(RewriteStep::simple("int_log", expr, result));
459                    Ok(result)
460                }
461                "sqrt" => Err(IntegrationError::NotImplemented(
462                    "∫ sqrt(x) — not in the supported Risch subset".to_string(),
463                )),
464                other => Err(IntegrationError::NotImplemented(format!("∫ {other}(x)"))),
465            }
466        }
467
468        Node::Unknown => Err(IntegrationError::NotImplemented(
469            "unsupported expression node".to_string(),
470        )),
471    }
472}
473
474// ---------------------------------------------------------------------------
475// Public API
476// ---------------------------------------------------------------------------
477
478/// Symbolically integrate `expr` with respect to `var`.
479///
480/// Returns the antiderivative (without the constant of integration) after
481/// applying the rule-based simplifier.  The derivation log records every
482/// rule applied.
483///
484/// # Supported operations
485///
486/// | Input              | Result                      | Rule                    |
487/// |--------------------|-----------------------------|-------------------------|
488/// | `c` (constant)     | `c·x`                       | `constant_rule`         |
489/// | `x^n` (n≠-1)      | `x^(n+1)/(n+1)`             | `power_rule`            |
490/// | `x^(-1)`           | `ln(x)`                     | `log_rule`              |
491/// | `f + g`            | `∫f + ∫g`                   | `sum_rule`              |
492/// | `c · f`            | `c · ∫f`                    | `constant_multiple_rule`|
493/// | `sin(x)`           | `-cos(x)`                   | `int_sin`               |
494/// | `cos(x)`           | `sin(x)`                    | `int_cos`               |
495/// | `exp(x)`           | `exp(x)`                    | `int_exp`               |
496/// | `exp(a*x + b)`     | `exp(a*x+b) / a`            | `int_exp_linear`        |
497/// | `log(x)`           | `x*log(x) - x`              | `int_log`               |
498/// | `x * exp(x)`       | `exp(x)*(x-1)`              | `int_x_exp`             |
499/// | `1/(a*x + b)`      | `log(a*x+b) / a`            | `int_linear_inv`        |
500///
501/// # Verification
502///
503/// For all supported inputs, `diff(integrate(f, x), x)` should simplify to
504/// `f` (modulo simplification of the constant rule).  The property tests in
505/// this module verify this on random polynomials.
506pub fn integrate(
507    expr: ExprId,
508    var: ExprId,
509    pool: &ExprPool,
510) -> Result<DerivedExpr<ExprId>, IntegrationError> {
511    // V1-2: Route algebraic integrands to the Trager/Risch algebraic engine.
512    if super::algebraic::contains_algebraic_subterm(expr, pool) {
513        return super::algebraic::integrate_algebraic(expr, var, pool);
514    }
515
516    let mut log = DerivationLog::new();
517    let raw = integrate_raw(expr, var, pool, &mut log)?;
518    let simplified = simplify(raw, pool);
519    let final_log = log.merge(simplified.log);
520    Ok(DerivedExpr::with_log(simplified.value, final_log))
521}
522
523// ---------------------------------------------------------------------------
524// Tests
525// ---------------------------------------------------------------------------
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diff::diff;
    use crate::kernel::{Domain, ExprPool};
    use crate::poly::UniPoly;

    /// Fresh pool; every test works in its own.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// The real-valued integration variable `x`.
    fn real_x(pool: &ExprPool) -> ExprId {
        pool.symbol("x", Domain::Real)
    }

    /// `x^n` for an integer exponent.
    fn x_pow(pool: &ExprPool, x: ExprId, n: i32) -> ExprId {
        pool.pow(x, pool.integer(n))
    }

    /// Whether the derivation log of `result` contains a step named `rule`.
    fn fired(result: &DerivedExpr<ExprId>, rule: &str) -> bool {
        result.log.steps().iter().any(|step| step.rule_name == rule)
    }

    /// Print the outcome of an integration attempt for the test output.
    fn report(label: &str, result: &Result<DerivedExpr<ExprId>, IntegrationError>, pool: &ExprPool) {
        match result {
            Ok(r) => println!("{label} integral = {}", pool.display(r.value)),
            Err(e) => println!("ERROR: {e}"),
        }
    }

    /// Compare two expressions as polynomials in `x` when both convert;
    /// otherwise fall back to comparing the expression ids.
    fn coeffs_equal(a: ExprId, b: ExprId, x: ExprId, pool: &ExprPool) -> bool {
        let lhs = UniPoly::from_symbolic(a, x, pool);
        let rhs = UniPoly::from_symbolic(b, x, pool);
        if let (Ok(lhs), Ok(rhs)) = (lhs, rhs) {
            lhs.coefficients_i64() == rhs.coefficients_i64()
        } else {
            a == b
        }
    }

    /// Round-trip check: diff(∫f) should equal f (mod simplification).
    fn verify(expr: ExprId, x: ExprId, pool: &ExprPool) {
        let antiderivative = integrate(expr, x, pool).unwrap();
        let roundtrip = diff(antiderivative.value, x, pool).unwrap();
        assert!(
            coeffs_equal(roundtrip.value, expr, x, pool),
            "diff(integrate(f)) ≠ f for f = {}",
            pool.display(expr)
        );
    }

    #[test]
    fn integrate_constant() {
        // ∫ 5 dx = 5x
        let pool = p();
        let x = real_x(&pool);
        let r = integrate(pool.integer(5_i32), x, &pool).unwrap();
        let five_x = pool.mul(vec![pool.integer(5_i32), x]);
        assert!(coeffs_equal(r.value, five_x, x, &pool));
    }

    #[test]
    fn integrate_x() {
        // ∫ x dx = x²/2
        let pool = p();
        let x = real_x(&pool);
        verify(x, x, &pool);
    }

    #[test]
    fn integrate_x_squared() {
        // ∫ x² dx = x³/3
        let pool = p();
        let x = real_x(&pool);
        verify(x_pow(&pool, x, 2), x, &pool);
    }

    #[test]
    fn integrate_polynomial() {
        // ∫ (x² + 2x) dx = x³/3 + x²
        let pool = p();
        let x = real_x(&pool);
        let square = x_pow(&pool, x, 2);
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let expr = pool.add(vec![square, two_x]);
        let r = integrate(expr, x, &pool).unwrap();
        // Check by differentiating the result.
        let d = diff(r.value, x, &pool).unwrap();
        assert!(
            coeffs_equal(d.value, expr, x, &pool),
            "diff(∫(x²+2x)) ≠ x²+2x; got {}",
            pool.display(d.value)
        );
    }

    #[test]
    fn integrate_one_over_x() {
        // ∫ x^(-1) dx = log(x)
        let pool = p();
        let x = real_x(&pool);
        let r = integrate(x_pow(&pool, x, -1), x, &pool).unwrap();
        assert_eq!(r.value, pool.func("log", vec![x]));
        assert!(fired(&r, "log_rule"));
    }

    #[test]
    fn integrate_sin() {
        // ∫ sin(x) dx = -cos(x)
        let pool = p();
        let x = real_x(&pool);
        let r = integrate(pool.func("sin", vec![x]), x, &pool).unwrap();
        let minus_cos = pool.mul(vec![pool.integer(-1_i32), pool.func("cos", vec![x])]);
        assert_eq!(r.value, minus_cos);
        assert!(fired(&r, "int_sin"));
    }

    #[test]
    fn integrate_cos() {
        // ∫ cos(x) dx = sin(x)
        let pool = p();
        let x = real_x(&pool);
        let cos_x = pool.func("cos", vec![x]);
        let r = integrate(cos_x, x, &pool).unwrap();
        assert_eq!(r.value, pool.func("sin", vec![x]));
    }

    #[test]
    fn integrate_exp() {
        // ∫ exp(x) dx = exp(x)
        let pool = p();
        let x = real_x(&pool);
        let exp_x = pool.func("exp", vec![x]);
        let r = integrate(exp_x, x, &pool).unwrap();
        assert_eq!(r.value, pool.func("exp", vec![x]));
    }

    #[test]
    fn integrate_constant_multiple() {
        // ∫ 3*x² dx = 3 * x³/3 = x³
        let pool = p();
        let x = real_x(&pool);
        let three = pool.integer(3_i32);
        let expr = pool.mul(vec![three, x_pow(&pool, x, 2)]);
        verify(expr, x, &pool);
    }

    #[test]
    fn integrate_not_implemented() {
        // ∫ sin(x²) dx has no elementary antiderivative and is outside the supported subset
        let pool = p();
        let x = real_x(&pool);
        let sin_x2 = pool.func("sin", vec![x_pow(&pool, x, 2)]);
        let err = integrate(sin_x2, x, &pool);
        assert!(matches!(err, Err(IntegrationError::NotImplemented(_))));
    }

    // --- New rules (v0.5 Risch extension) ---

    #[test]
    fn integrate_log_x() {
        // ∫ log(x) dx = x*log(x) - x
        let pool = p();
        let x = real_x(&pool);
        let r = integrate(pool.func("log", vec![x]), x, &pool).unwrap();
        assert!(fired(&r, "int_log"), "should have logged int_log step");
        // Structural check: the printed result mentions log.
        let result_str = pool.display(r.value).to_string();
        assert!(
            result_str.contains("log"),
            "result should contain log: {result_str}"
        );
    }

    #[test]
    fn integrate_exp_linear_arg() {
        // ∫ exp(2*x) dx = exp(2*x) / 2
        let pool = p();
        let x = real_x(&pool);
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let r = integrate(pool.func("exp", vec![two_x]), x, &pool).unwrap();
        assert!(fired(&r, "int_exp_linear"), "should fire int_exp_linear");
        // Structural check: the printed result mentions exp.
        let result_str = pool.display(r.value).to_string();
        assert!(
            result_str.contains("exp"),
            "result should contain exp: {result_str}"
        );
    }

    #[test]
    fn integrate_x_times_exp_x() {
        // ∫ x * exp(x) dx = exp(x) * (x - 1)
        let pool = p();
        let x = real_x(&pool);
        let expr = pool.mul(vec![x, pool.func("exp", vec![x])]);
        let r = integrate(expr, x, &pool).unwrap();
        assert!(fired(&r, "int_x_exp"), "should fire int_x_exp");
        let result_str = pool.display(r.value).to_string();
        assert!(
            result_str.contains("exp"),
            "result should contain exp: {result_str}"
        );
    }

    #[test]
    fn integrate_const_times_x_times_exp_x() {
        // ∫ 3 * x * exp(x) dx  — constant factor should be preserved
        let pool = p();
        let x = real_x(&pool);
        let three = pool.integer(3_i32);
        let expr = pool.mul(vec![three, x, pool.func("exp", vec![x])]);
        let r = integrate(expr, x, &pool).unwrap();
        assert!(
            fired(&r, "int_x_exp"),
            "should fire int_x_exp for 3*x*exp(x)"
        );
    }

    #[test]
    fn integrate_one_over_linear() {
        // ∫ 1/(2*x + 3) dx = log(2*x + 3) / 2
        let pool = p();
        let x = real_x(&pool);
        let two = pool.integer(2_i32);
        let three = pool.integer(3_i32);
        let linear = pool.add(vec![pool.mul(vec![two, x]), three]);
        let reciprocal = pool.pow(linear, pool.integer(-1_i32));
        let r = integrate(reciprocal, x, &pool).unwrap();
        assert!(fired(&r, "int_linear_inv"), "should fire int_linear_inv");
        let result_str = pool.display(r.value).to_string();
        assert!(
            result_str.contains("log"),
            "result should contain log: {result_str}"
        );
    }

    #[test]
    fn integrate_x_cubed_plus_2x() {
        // ∫ (x³ + 2x) dx — antiderivative check
        let pool = p();
        let x = real_x(&pool);
        let cube = x_pow(&pool, x, 3);
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        verify(pool.add(vec![cube, two_x]), x, &pool);
    }

    #[test]
    fn integrate_derivation_log_nonempty() {
        let pool = p();
        let x = real_x(&pool);
        let r = integrate(x_pow(&pool, x, 2), x, &pool).unwrap();
        assert!(
            !r.log.is_empty(),
            "integration should produce a derivation log"
        );
        assert!(fired(&r, "power_rule"));
    }

    #[test]
    fn integrate_sqrt_x() {
        // ∫ sqrt(x) dx  should succeed (linear P)
        let pool = p();
        let x = real_x(&pool);
        let result = integrate(pool.func("sqrt", vec![x]), x, &pool);
        report("sqrt(x)", &result, &pool);
        assert!(result.is_ok(), "∫ sqrt(x) dx failed: {:?}", result);
    }

    #[test]
    fn integrate_inv_sqrt_x() {
        // ∫ 1/sqrt(x) dx = 2·sqrt(x)
        let pool = p();
        let x = real_x(&pool);
        let sqrt_x = pool.func("sqrt", vec![x]);
        let inv_sqrt_x = pool.pow(sqrt_x, pool.integer(-1_i32));
        let result = integrate(inv_sqrt_x, x, &pool);
        report("1/sqrt(x)", &result, &pool);
        assert!(result.is_ok(), "∫ 1/sqrt(x) dx failed: {:?}", result);
    }

    #[test]
    fn integrate_sqrt_x2_plus_1() {
        // ∫ sqrt(x²+1) dx  should succeed (quadratic P)
        let pool = p();
        let x = real_x(&pool);
        let radicand = pool.add(vec![x_pow(&pool, x, 2), pool.integer(1_i32)]);
        let result = integrate(pool.func("sqrt", vec![radicand]), x, &pool);
        report("sqrt(x^2+1)", &result, &pool);
        assert!(result.is_ok(), "∫ sqrt(x²+1) dx failed: {:?}", result);
    }
}
833}