mathhook_core/calculus/integrals/
substitution.rs

1//! U-substitution for integration
2//!
3//! Implements automatic u-substitution detection and execution for composite
4//! functions. Handles patterns like f'(g(x)) * g'(x) by substituting u = g(x).
5//!
6//! # Algorithm
7//!
8//! 1. Identify candidate substitutions u = g(x) from the integrand structure
9//! 2. Compute du = g'(x) dx for each candidate
10//! 3. Check if integrand can be rewritten as f(u) * du (possibly with constant factor)
11//! 4. Integrate f(u) with respect to u
12//! 5. Substitute back u = g(x) in the result
13//!
14//! # Supported Patterns
15//!
16//! - Polynomial inner functions: ∫2x*sin(x²) dx = -cos(x²)
17//! - Exponential compositions: ∫e^x*sin(e^x) dx = -cos(e^x)
18//! - Logarithmic patterns: ∫1/(x*ln(x)) dx = ln|ln(x)|
19//! - Rational functions: ∫x/(x²+1) dx = (1/2)*ln(x²+1)
20//! - Linear inner functions: ∫sqrt(x+1) dx = (2/3)(x+1)^(3/2)
21//!
22//! # Patterns Recognized
23//!
24//! **Pattern 1**: `f'(x)·g(f(x))` - Exact derivative match
25//! - Example: `2x·e^(x²)` where u = x², du = 2x dx
26//!
27//! **Pattern 2**: `c·f'(x)·g(f(x))` - Derivative with coefficient
28//! - Example: `x·sin(x²)` where u = x², du = 2x dx, coefficient = 1/2
29//!
30//! **Pattern 3**: `f^n(x)·f'(x)` - Power of function times derivative
31//! - Example: `sin³(x)·cos(x)` where u = sin(x), du = cos(x) dx
32//!
33//! **Pattern 4**: `f(ax+b)` - Constant derivative (linear inner function)
34//! - Example: `sqrt(x+1)` where u = x+1, du = 1 dx (constant derivative)
35
36use crate::calculus::derivatives::Derivative;
37use crate::core::{Expression, Number, Symbol};
38use crate::simplify::Simplify;
39
/// Upper bound on the recursion depth of integration attempts.
///
/// `try_substitution` gives up (returns `None`) as soon as the depth it is
/// handed is greater than or equal to this value, and every nested
/// integration in `u` is started with the depth incremented by one.
///
/// The cap exists to break recursion chains that would otherwise never end:
/// - Circular substitutions (u = v, v = u)
/// - Self-referential patterns that don't simplify
/// - Nested substitutions that don't converge
///
/// **Rationale (per the original author's note)**: most legitimate integrals
/// are reported to converge within 3-4 recursive calls in SymPy, so 10 leaves
/// a safety margin. SymPy itself is described as using cache-based loop
/// detection; a fixed depth limit was chosen here as the simpler approach.
const MAX_DEPTH: usize = 10;
54
55/// Try to integrate using u-substitution
56///
57/// Automatically detects composite function patterns and applies substitution.
58///
59/// # Arguments
60///
61/// * `expr` - The integrand expression
62/// * `var` - The variable of integration
63/// * `depth` - Current recursion depth (prevents infinite recursion)
64///
65/// # Returns
66///
67/// Some(result) if substitution succeeds, None if no suitable substitution found
68///
69/// # Examples
70///
71/// ```rust
72/// use mathhook_core::calculus::integrals::substitution::try_substitution;
73/// use mathhook_core::symbol;
74/// use mathhook_core::core::Expression;
75///
76/// let x = symbol!(x);
77/// // ∫2x*sin(x²) dx
78/// let integrand = Expression::mul(vec![
79///     Expression::integer(2),
80///     Expression::symbol(x.clone()),
81///     Expression::function("sin", vec![
82///         Expression::pow(Expression::symbol(x.clone()), Expression::integer(2))
83///     ])
84/// ]);
85///
86/// let result = try_substitution(&integrand, &x, 0);
87/// assert!(result.is_some());
88/// ```
89pub fn try_substitution(expr: &Expression, var: &Symbol, depth: usize) -> Option<Expression> {
90    if depth >= MAX_DEPTH {
91        return None;
92    }
93
94    let candidates = find_substitution_candidates(expr, var);
95
96    for candidate in candidates.iter() {
97        let g_prime = candidate.derivative(var.clone());
98
99        if let Some((f_of_u, constant_factor)) =
100            check_derivative_match(expr, candidate, &g_prime, var)
101        {
102            let u_symbol = Symbol::scalar("u");
103            let u_expr = Expression::symbol(u_symbol.clone());
104
105            let integrated = integrate_in_u(&f_of_u, u_symbol, depth)?;
106
107            let result = substitute_back(&integrated, &u_expr, candidate);
108
109            let final_result = if (constant_factor - 1.0).abs() > 1e-10 {
110                if constant_factor.abs() < 1.0 {
111                    let denom = (1.0 / constant_factor) as i64;
112                    Expression::mul(vec![Expression::rational(1, denom), result])
113                } else {
114                    let numer = constant_factor as i64;
115                    Expression::mul(vec![Expression::integer(numer), result])
116                }
117            } else {
118                result
119            };
120
121            return Some(final_result);
122        }
123    }
124
125    None
126}
127
128/// Find candidate expressions for substitution u = g(x)
129///
130/// Looks for inner functions, polynomial expressions, exponential/logarithm arguments.
131fn find_substitution_candidates(expr: &Expression, var: &Symbol) -> Vec<Expression> {
132    let mut candidates = Vec::new();
133
134    collect_candidates_recursive(expr, var, &mut candidates);
135
136    candidates.sort_by_key(|c| std::cmp::Reverse(expression_complexity(c)));
137    candidates.dedup_by(|a, b| expressions_equivalent(a, b));
138
139    candidates
140}
141
142/// Recursively collect substitution candidates from expression tree
143fn collect_candidates_recursive(expr: &Expression, var: &Symbol, candidates: &mut Vec<Expression>) {
144    match expr {
145        Expression::Function { name: _, args } => {
146            // For function arguments, consider the function itself as a candidate
147            // Example: sin(x) is a candidate in sin³(x)·cos(x)
148            if args.len() == 1 && args[0].contains_variable(var) {
149                // If argument is just x, consider the whole function
150                if is_simple_variable(&args[0], var) {
151                    candidates.push(expr.clone());
152                } else {
153                    // If argument is composite, consider the argument
154                    candidates.push(args[0].clone());
155                }
156            }
157            for arg in args.iter() {
158                if arg.contains_variable(var) && !is_simple_variable(arg, var) {
159                    candidates.push(arg.clone());
160                }
161                collect_candidates_recursive(arg, var, candidates);
162            }
163        }
164        Expression::Pow(base, exp) => {
165            if base.contains_variable(var) && !is_simple_variable(base, var) {
166                candidates.push((**base).clone());
167            }
168            if exp.contains_variable(var) && !is_simple_variable(exp, var) {
169                candidates.push((**exp).clone());
170            }
171            collect_candidates_recursive(base, var, candidates);
172            collect_candidates_recursive(exp, var, candidates);
173        }
174        Expression::Add(terms) => {
175            for term in terms.iter() {
176                collect_candidates_recursive(term, var, candidates);
177            }
178        }
179        Expression::Mul(factors) => {
180            for factor in factors.iter() {
181                collect_candidates_recursive(factor, var, candidates);
182            }
183        }
184        _ => {}
185    }
186}
187
188/// Check if expression contains the given candidate expression
189///
190/// This is used to separate f(u) from g'(x): factors containing the candidate are f(u)
191fn contains_expression(expr: &Expression, candidate: &Expression) -> bool {
192    if expr == candidate {
193        return true;
194    }
195
196    match expr {
197        Expression::Add(terms) => terms.iter().any(|t| contains_expression(t, candidate)),
198        Expression::Mul(factors) => factors.iter().any(|f| contains_expression(f, candidate)),
199        Expression::Pow(base, exp) => {
200            contains_expression(base, candidate) || contains_expression(exp, candidate)
201        }
202        Expression::Function { name: _, args } => {
203            args.iter().any(|a| contains_expression(a, candidate))
204        }
205        _ => false,
206    }
207}
208
209/// Check if expression is just the variable itself
210fn is_simple_variable(expr: &Expression, var: &Symbol) -> bool {
211    matches!(expr, Expression::Symbol(s) if s == var)
212}
213
214/// Measure complexity of expression (for prioritizing candidates)
215fn expression_complexity(expr: &Expression) -> usize {
216    match expr {
217        Expression::Number(_) | Expression::Symbol(_) | Expression::Constant(_) => 1,
218        Expression::Add(terms) => terms.iter().map(expression_complexity).sum::<usize>() + 1,
219        Expression::Mul(factors) => factors.iter().map(expression_complexity).sum::<usize>() + 1,
220        Expression::Pow(base, exp) => expression_complexity(base) + expression_complexity(exp) + 1,
221        Expression::Function { name: _, args } => {
222            args.iter().map(expression_complexity).sum::<usize>() + 2
223        }
224        _ => 1,
225    }
226}
227
228/// Check if two expressions are equivalent
229fn expressions_equivalent(a: &Expression, b: &Expression) -> bool {
230    a == b
231}
232
233/// Check if derivative is constant (does not depend on variable)
234fn is_constant_derivative(g_prime: &Expression, var: &Symbol) -> bool {
235    !g_prime.contains_variable(var)
236}
237
238/// Check if derivative g'(x) appears in the integrand
239///
240/// Returns Some((f(u), constant_factor)) if a match is found, where:
241/// - f(u) is the expression in terms of u
242/// - constant_factor accounts for numerical differences between g'(x) and actual factor
243///
244/// This function recognizes patterns where the derivative appears as:
245/// 1. Exact match: `g'(x)` appears as-is
246/// 2. With coefficient: `c·g'(x)` where c is a constant
247/// 3. Distributed across factors: g'(x) = a*b and both a and b appear separately in the product
248/// 4. Constant derivative (implicit): For f(ax+b), derivative is constant and doesn't appear explicitly
249fn check_derivative_match(
250    expr: &Expression,
251    g: &Expression,
252    g_prime: &Expression,
253    var: &Symbol,
254) -> Option<(Expression, f64)> {
255    let expr_simplified = expr.clone().simplify();
256    let g_prime_simplified = g_prime.clone().simplify();
257
258    // Pattern 4: Check for constant derivative (linear inner function like x+1, 2x+3)
259    // For expressions like sqrt(x+1), derivative is 1 (constant), not appearing explicitly
260    if is_constant_derivative(&g_prime_simplified, var) {
261        // Extract the constant value of the derivative
262        if let Some(derivative_value) = extract_constant_value(&g_prime_simplified) {
263            // If the expression contains g, we can use it for substitution
264            if contains_expression(&expr_simplified, g) {
265                let u_symbol = Symbol::scalar("u");
266                let u_expr = Expression::symbol(u_symbol);
267
268                // Replace g with u in the entire expression
269                let f_of_u = replace_expression(&expr_simplified, g, &u_expr);
270
271                // Return with reciprocal of derivative as constant factor
272                // For du = c·dx, we need 1/c factor when substituting
273                return Some((f_of_u, 1.0 / derivative_value));
274            }
275        }
276    }
277
278    if let Expression::Mul(factors) = &expr_simplified {
279        let u_symbol = Symbol::scalar("u");
280        let u_expr = Expression::symbol(u_symbol);
281
282        // NEW STRATEGY: Separate factors into:
283        // 1. Those that contain the candidate g (these are f(u))
284        // 2. The rest (these could be g'(x) or constants)
285        let (f_of_g_factors, derivative_candidate_factors): (Vec<_>, Vec<_>) =
286            factors.iter().partition(|f| contains_expression(f, g));
287
288        if !f_of_g_factors.is_empty() && !derivative_candidate_factors.is_empty() {
289            // Reconstruct what we think is the derivative from available factors
290            let derivative_candidate = if derivative_candidate_factors.len() == 1 {
291                derivative_candidate_factors[0].clone()
292            } else {
293                Expression::mul(
294                    derivative_candidate_factors
295                        .iter()
296                        .map(|f| (*f).clone())
297                        .collect(),
298                )
299            };
300
301            // Check if this matches the derivative (possibly with a constant ratio)
302            if let Some(ratio) = compute_constant_ratio(&derivative_candidate, &g_prime_simplified)
303            {
304                // Success! We found the derivative (with coefficient ratio)
305                // The remaining factors (those containing g) become f(u)
306                let remaining = if f_of_g_factors.is_empty() {
307                    Expression::integer(1)
308                } else if f_of_g_factors.len() == 1 {
309                    f_of_g_factors[0].clone()
310                } else {
311                    Expression::mul(f_of_g_factors.iter().map(|f| (*f).clone()).collect())
312                };
313
314                // Replace g with u in the remaining expression
315                let f_of_u = replace_expression(&remaining, g, &u_expr);
316
317                return Some((f_of_u, ratio));
318            }
319        }
320
321        // Fallback: Try the old partitioning strategy for backward compatibility
322        let (derivative_factors, other_factors): (Vec<_>, Vec<_>) = factors
323            .iter()
324            .partition(|f| factor_matches_derivative(f, &g_prime_simplified, var));
325
326        if derivative_factors.is_empty() {
327            return None;
328        }
329
330        let derivative_product = if derivative_factors.len() == 1 {
331            derivative_factors[0].clone()
332        } else {
333            Expression::mul(derivative_factors.iter().map(|f| (*f).clone()).collect())
334        };
335
336        let constant_factor = compute_constant_ratio(&derivative_product, &g_prime_simplified)?;
337
338        let remaining = if other_factors.is_empty() {
339            Expression::integer(1)
340        } else if other_factors.len() == 1 {
341            other_factors[0].clone()
342        } else {
343            Expression::mul(other_factors.iter().map(|f| (*f).clone()).collect())
344        };
345
346        let f_of_u = replace_expression(&remaining, g, &u_expr);
347
348        Some((f_of_u, constant_factor))
349    } else {
350        let constant_factor = compute_constant_ratio(&expr_simplified, &g_prime_simplified)?;
351        let f_of_u = Expression::integer(1);
352        Some((f_of_u, constant_factor))
353    }
354}
355
356/// Extract constant value from an expression that doesn't depend on variables
357///
358/// Returns Some(value) if expression is a constant number, None otherwise
359fn extract_constant_value(expr: &Expression) -> Option<f64> {
360    match expr {
361        Expression::Number(n) => number_to_f64(n),
362        _ => None,
363    }
364}
365
366/// Replace all occurrences of `pattern` with `replacement` in `expr`
367///
368/// This is used to convert f(g(x)) to f(u) by replacing g(x) with u.
369fn replace_expression(
370    expr: &Expression,
371    pattern: &Expression,
372    replacement: &Expression,
373) -> Expression {
374    // Check if the entire expression matches the pattern
375    if expr == pattern {
376        return replacement.clone();
377    }
378
379    // Recursively replace in subexpressions
380    match expr {
381        Expression::Add(terms) => Expression::add(
382            terms
383                .iter()
384                .map(|t| replace_expression(t, pattern, replacement))
385                .collect(),
386        ),
387        Expression::Mul(factors) => Expression::mul(
388            factors
389                .iter()
390                .map(|f| replace_expression(f, pattern, replacement))
391                .collect(),
392        ),
393        Expression::Pow(base, exp) => Expression::pow(
394            replace_expression(base, pattern, replacement),
395            replace_expression(exp, pattern, replacement),
396        ),
397        Expression::Function { name, args } => Expression::function(
398            name,
399            args.iter()
400                .map(|a| replace_expression(a, pattern, replacement))
401                .collect(),
402        ),
403        _ => expr.clone(),
404    }
405}
406
407/// Check if a factor matches the derivative (possibly with constant multiple)
408fn factor_matches_derivative(factor: &Expression, derivative: &Expression, var: &Symbol) -> bool {
409    if factor == derivative {
410        return true;
411    }
412
413    let factor_simplified = factor.clone().simplify();
414    let derivative_simplified = derivative.clone().simplify();
415
416    if factor_simplified == derivative_simplified {
417        return true;
418    }
419
420    if let (Expression::Mul(f_factors), Expression::Mul(d_factors)) =
421        (&factor_simplified, &derivative_simplified)
422    {
423        let f_non_const: Vec<_> = f_factors
424            .iter()
425            .filter(|f| f.contains_variable(var))
426            .collect();
427        let d_non_const: Vec<_> = d_factors
428            .iter()
429            .filter(|f| f.contains_variable(var))
430            .collect();
431
432        if f_non_const.len() == d_non_const.len() {
433            return f_non_const
434                .iter()
435                .zip(d_non_const.iter())
436                .all(|(f, d)| f == d);
437        }
438    }
439
440    match (&factor_simplified, &derivative_simplified) {
441        (Expression::Symbol(f_sym), Expression::Symbol(d_sym)) => f_sym == d_sym,
442        (Expression::Pow(f_base, f_exp), Expression::Pow(d_base, d_exp)) => {
443            f_base == d_base && f_exp == d_exp
444        }
445        _ => false,
446    }
447}
448
449/// Compute constant ratio between two expressions
450///
451/// Returns Some(ratio) where expr = ratio * target
452/// This handles cases like:
453/// - expr = 2x, target = 2x → ratio = 1.0
454/// - expr = x, target = 2x → ratio = 0.5
455/// - expr = 3x, target = 2x → ratio = 1.5
456fn compute_constant_ratio(expr: &Expression, target: &Expression) -> Option<f64> {
457    if expr == target {
458        return Some(1.0);
459    }
460
461    let expr_simp = expr.clone().simplify();
462    let target_simp = target.clone().simplify();
463
464    if expr_simp == target_simp {
465        return Some(1.0);
466    }
467
468    // Try to match structurally by extracting coefficients
469    match (&expr_simp, &target_simp) {
470        (Expression::Number(n1), Expression::Number(n2)) => {
471            let v1 = number_to_f64(n1)?;
472            let v2 = number_to_f64(n2)?;
473            if v2.abs() > 1e-10 {
474                let ratio = v1 / v2;
475                Some(ratio)
476            } else {
477                None
478            }
479        }
480        // Both are products - try to extract coefficients
481        (Expression::Mul(e_factors), Expression::Mul(t_factors)) => {
482            let e_coeff = extract_coefficient(e_factors);
483            let t_coeff = extract_coefficient(t_factors);
484
485            let e_non_const: Vec<_> = e_factors
486                .iter()
487                .filter(|f| !matches!(f, Expression::Number(_)))
488                .collect();
489            let t_non_const: Vec<_> = t_factors
490                .iter()
491                .filter(|f| !matches!(f, Expression::Number(_)))
492                .collect();
493
494            // Check if non-constant parts match
495            if e_non_const.len() == t_non_const.len()
496                && e_non_const
497                    .iter()
498                    .zip(t_non_const.iter())
499                    .all(|(a, b)| *a == *b)
500                && t_coeff.abs() > 1e-10
501            {
502                let ratio = e_coeff / t_coeff;
503                return Some(ratio);
504            }
505            None
506        }
507        // expr is product, target is not - check if they match structurally
508        (Expression::Mul(factors), _) => {
509            let coeff = extract_coefficient(factors);
510            let non_const: Vec<_> = factors
511                .iter()
512                .filter(|f| !matches!(f, Expression::Number(_)))
513                .collect();
514
515            let non_const_product = if non_const.is_empty() {
516                Expression::integer(1)
517            } else if non_const.len() == 1 {
518                (*non_const[0]).clone()
519            } else {
520                Expression::mul(non_const.iter().map(|f| (*f).clone()).collect())
521            };
522
523            if non_const_product == target_simp {
524                Some(coeff)
525            } else {
526                None
527            }
528        }
529        // target is product, expr is not
530        (_, Expression::Mul(factors)) => {
531            let coeff = extract_coefficient(factors);
532            let non_const: Vec<_> = factors
533                .iter()
534                .filter(|f| !matches!(f, Expression::Number(_)))
535                .collect();
536
537            let non_const_product = if non_const.is_empty() {
538                Expression::integer(1)
539            } else if non_const.len() == 1 {
540                (*non_const[0]).clone()
541            } else {
542                Expression::mul(non_const.iter().map(|f| (*f).clone()).collect())
543            };
544
545            if expr_simp == non_const_product && coeff.abs() > 1e-10 {
546                let ratio = 1.0 / coeff;
547                Some(ratio)
548            } else {
549                None
550            }
551        }
552        _ => None,
553    }
554}
555
556/// Extract numeric coefficient from a product of factors
557///
558/// Returns the product of all numeric factors, or 1.0 if there are none
559fn extract_coefficient(factors: &[Expression]) -> f64 {
560    let nums: Vec<f64> = factors
561        .iter()
562        .filter_map(|f| {
563            if let Expression::Number(n) = f {
564                number_to_f64(n)
565            } else {
566                None
567            }
568        })
569        .collect();
570
571    if nums.is_empty() {
572        1.0
573    } else {
574        nums.iter().product()
575    }
576}
577
578/// Convert Number to f64
579fn number_to_f64(num: &Number) -> Option<f64> {
580    match num {
581        Number::Integer(i) => Some(*i as f64),
582        Number::Rational(r) => {
583            use num_traits::ToPrimitive;
584            r.to_f64()
585        }
586        Number::Float(f) => Some(*f),
587        _ => None,
588    }
589}
590
591/// Integrate expression with respect to u
592///
593/// Depth is incremented and passed to prevent infinite recursion in nested substitutions.
594/// When depth reaches MAX_DEPTH, integration returns None to break recursion chains.
595fn integrate_in_u(expr: &Expression, u: Symbol, depth: usize) -> Option<Expression> {
596    use crate::calculus::integrals::strategy::integrate_with_strategy;
597
598    let result = integrate_with_strategy(expr, u, depth + 1);
599
600    if matches!(result, Expression::Calculus(_)) {
601        None
602    } else {
603        Some(result)
604    }
605}
606
607/// Substitute u = g(x) back into the result
608///
609/// After integrating f(u), we have a result in terms of u.
610/// This function replaces u with g(x) to get the final answer.
611fn substitute_back(expr: &Expression, u: &Expression, g: &Expression) -> Expression {
612    replace_expression(expr, u, g)
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Builds `x²` for the given symbol.
    fn squared(x: &Symbol) -> Expression {
        Expression::pow(Expression::symbol(x.clone()), Expression::integer(2))
    }

    #[test]
    fn test_is_simple_variable() {
        let x = symbol!(x);
        let plain_x = Expression::symbol(x.clone());
        let five = Expression::integer(5);

        assert!(is_simple_variable(&plain_x, &x));
        assert!(!is_simple_variable(&five, &x));
        assert!(!is_simple_variable(&squared(&x), &x));
    }

    #[test]
    fn test_expression_complexity() {
        let x = symbol!(x);

        // Leaves score 1.
        assert_eq!(expression_complexity(&Expression::integer(5)), 1);
        assert_eq!(expression_complexity(&Expression::symbol(x.clone())), 1);

        // x² = base (1) + exponent (1) + 1.
        assert_eq!(expression_complexity(&squared(&x)), 3);

        // sin(x) = argument (1) + 2.
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert_eq!(expression_complexity(&sin_x), 3);
    }

    #[test]
    fn test_find_substitution_candidates_basic() {
        let x = symbol!(x);
        let x_squared = squared(&x);
        let sin_x_squared = Expression::function("sin", vec![x_squared.clone()]);

        let candidates = find_substitution_candidates(&sin_x_squared, &x);

        assert!(!candidates.is_empty());
        assert!(candidates.contains(&x_squared));
    }

    #[test]
    fn test_replace_expression() {
        let x = symbol!(x);
        let u = symbol!(u);

        // Replacing x² with u in exp(x²) must give exp(u).
        let x_squared = squared(&x);
        let u_expr = Expression::symbol(u.clone());
        let expr = Expression::function("exp", vec![x_squared.clone()]);

        let result = replace_expression(&expr, &x_squared, &u_expr);
        let expected = Expression::function("exp", vec![u_expr]);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_is_constant_derivative() {
        let x = symbol!(x);

        // Constant expressions don't contain x
        assert!(is_constant_derivative(&Expression::integer(1), &x));
        assert!(is_constant_derivative(&Expression::integer(2), &x));
        assert!(is_constant_derivative(&Expression::rational(3, 2), &x));

        // Non-constant expressions contain x
        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        assert!(!is_constant_derivative(&Expression::symbol(x.clone()), &x));
        assert!(!is_constant_derivative(&two_x, &x));
    }

    #[test]
    fn test_extract_constant_value() {
        assert_eq!(extract_constant_value(&Expression::integer(1)), Some(1.0));
        assert_eq!(extract_constant_value(&Expression::integer(5)), Some(5.0));
        assert_eq!(
            extract_constant_value(&Expression::rational(3, 2)),
            Some(1.5)
        );

        let x = symbol!(x);
        assert_eq!(extract_constant_value(&Expression::symbol(x.clone())), None);
    }

    #[test]
    fn test_exponential_chain_rule_pattern() {
        // Test 3: ∫2x·e^(x²) dx
        let x = symbol!(x);
        let exp_of_square = Expression::function("exp", vec![squared(&x)]);
        let expr = Expression::mul(vec![
            Expression::integer(2),
            Expression::symbol(x.clone()),
            exp_of_square,
        ]);

        let result = try_substitution(&expr, &x, 0);
        assert!(
            result.is_some(),
            "Exponential chain rule pattern should succeed"
        );
    }

    #[test]
    fn test_trig_substitution_with_coefficient() {
        // Test 4: ∫x·sin(x²) dx
        let x = symbol!(x);
        let sin_of_square = Expression::function("sin", vec![squared(&x)]);
        let expr = Expression::mul(vec![Expression::symbol(x.clone()), sin_of_square]);

        let result = try_substitution(&expr, &x, 0);
        assert!(
            result.is_some(),
            "Trig substitution with coefficient should succeed"
        );
    }

    #[test]
    fn test_power_chain_rule_pattern() {
        // Test 7: ∫sin³(x)·cos(x) dx
        let x = symbol!(x);
        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        let cos_x = Expression::function("cos", vec![Expression::symbol(x.clone())]);
        let expr = Expression::mul(vec![
            Expression::pow(sin_x, Expression::integer(3)),
            cos_x,
        ]);

        let result = try_substitution(&expr, &x, 0);
        assert!(result.is_some(), "Power chain rule pattern should succeed");
    }

    #[test]
    fn test_constant_derivative_linear() {
        // Test: ∫sqrt(x+1) dx - constant derivative case
        let x = symbol!(x);
        let inner = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        let expr = Expression::function("sqrt", vec![inner.clone()]);

        let result = try_substitution(&expr, &x, 0);
        assert!(
            result.is_some(),
            "Constant derivative substitution should succeed for sqrt(x+1)"
        );
    }

    #[test]
    fn test_max_depth_prevents_infinite_recursion() {
        let x = symbol!(x);
        let simple_expr = Expression::symbol(x.clone());

        // Just below the limit the call must still run without panicking.
        let _result_at_limit = try_substitution(&simple_expr, &x, MAX_DEPTH - 1);

        // At the limit the call must refuse to go any deeper.
        let result_over_limit = try_substitution(&simple_expr, &x, MAX_DEPTH);
        assert_eq!(
            result_over_limit, None,
            "Should return None when depth >= MAX_DEPTH"
        );
    }
}