1use crate::calculus::derivatives::Derivative;
37use crate::core::{Expression, Number, Symbol};
38use crate::simplify::Simplify;
39
/// Depth limit for nested substitution attempts: `try_substitution` returns
/// `None` as soon as its `depth` argument reaches this value.
const MAX_DEPTH: usize = 10;
54
55pub fn try_substitution(expr: &Expression, var: &Symbol, depth: usize) -> Option<Expression> {
90 if depth >= MAX_DEPTH {
91 return None;
92 }
93
94 let candidates = find_substitution_candidates(expr, var);
95
96 for candidate in candidates.iter() {
97 let g_prime = candidate.derivative(var.clone());
98
99 if let Some((f_of_u, constant_factor)) =
100 check_derivative_match(expr, candidate, &g_prime, var)
101 {
102 let u_symbol = Symbol::scalar("u");
103 let u_expr = Expression::symbol(u_symbol.clone());
104
105 let integrated = integrate_in_u(&f_of_u, u_symbol, depth)?;
106
107 let result = substitute_back(&integrated, &u_expr, candidate);
108
109 let final_result = if (constant_factor - 1.0).abs() > 1e-10 {
110 if constant_factor.abs() < 1.0 {
111 let denom = (1.0 / constant_factor) as i64;
112 Expression::mul(vec![Expression::rational(1, denom), result])
113 } else {
114 let numer = constant_factor as i64;
115 Expression::mul(vec![Expression::integer(numer), result])
116 }
117 } else {
118 result
119 };
120
121 return Some(final_result);
122 }
123 }
124
125 None
126}
127
128fn find_substitution_candidates(expr: &Expression, var: &Symbol) -> Vec<Expression> {
132 let mut candidates = Vec::new();
133
134 collect_candidates_recursive(expr, var, &mut candidates);
135
136 candidates.sort_by_key(|c| std::cmp::Reverse(expression_complexity(c)));
137 candidates.dedup_by(|a, b| expressions_equivalent(a, b));
138
139 candidates
140}
141
142fn collect_candidates_recursive(expr: &Expression, var: &Symbol, candidates: &mut Vec<Expression>) {
144 match expr {
145 Expression::Function { name: _, args } => {
146 if args.len() == 1 && args[0].contains_variable(var) {
149 if is_simple_variable(&args[0], var) {
151 candidates.push(expr.clone());
152 } else {
153 candidates.push(args[0].clone());
155 }
156 }
157 for arg in args.iter() {
158 if arg.contains_variable(var) && !is_simple_variable(arg, var) {
159 candidates.push(arg.clone());
160 }
161 collect_candidates_recursive(arg, var, candidates);
162 }
163 }
164 Expression::Pow(base, exp) => {
165 if base.contains_variable(var) && !is_simple_variable(base, var) {
166 candidates.push((**base).clone());
167 }
168 if exp.contains_variable(var) && !is_simple_variable(exp, var) {
169 candidates.push((**exp).clone());
170 }
171 collect_candidates_recursive(base, var, candidates);
172 collect_candidates_recursive(exp, var, candidates);
173 }
174 Expression::Add(terms) => {
175 for term in terms.iter() {
176 collect_candidates_recursive(term, var, candidates);
177 }
178 }
179 Expression::Mul(factors) => {
180 for factor in factors.iter() {
181 collect_candidates_recursive(factor, var, candidates);
182 }
183 }
184 _ => {}
185 }
186}
187
188fn contains_expression(expr: &Expression, candidate: &Expression) -> bool {
192 if expr == candidate {
193 return true;
194 }
195
196 match expr {
197 Expression::Add(terms) => terms.iter().any(|t| contains_expression(t, candidate)),
198 Expression::Mul(factors) => factors.iter().any(|f| contains_expression(f, candidate)),
199 Expression::Pow(base, exp) => {
200 contains_expression(base, candidate) || contains_expression(exp, candidate)
201 }
202 Expression::Function { name: _, args } => {
203 args.iter().any(|a| contains_expression(a, candidate))
204 }
205 _ => false,
206 }
207}
208
209fn is_simple_variable(expr: &Expression, var: &Symbol) -> bool {
211 matches!(expr, Expression::Symbol(s) if s == var)
212}
213
214fn expression_complexity(expr: &Expression) -> usize {
216 match expr {
217 Expression::Number(_) | Expression::Symbol(_) | Expression::Constant(_) => 1,
218 Expression::Add(terms) => terms.iter().map(expression_complexity).sum::<usize>() + 1,
219 Expression::Mul(factors) => factors.iter().map(expression_complexity).sum::<usize>() + 1,
220 Expression::Pow(base, exp) => expression_complexity(base) + expression_complexity(exp) + 1,
221 Expression::Function { name: _, args } => {
222 args.iter().map(expression_complexity).sum::<usize>() + 2
223 }
224 _ => 1,
225 }
226}
227
228fn expressions_equivalent(a: &Expression, b: &Expression) -> bool {
230 a == b
231}
232
233fn is_constant_derivative(g_prime: &Expression, var: &Symbol) -> bool {
235 !g_prime.contains_variable(var)
236}
237
238fn check_derivative_match(
250 expr: &Expression,
251 g: &Expression,
252 g_prime: &Expression,
253 var: &Symbol,
254) -> Option<(Expression, f64)> {
255 let expr_simplified = expr.clone().simplify();
256 let g_prime_simplified = g_prime.clone().simplify();
257
258 if is_constant_derivative(&g_prime_simplified, var) {
261 if let Some(derivative_value) = extract_constant_value(&g_prime_simplified) {
263 if contains_expression(&expr_simplified, g) {
265 let u_symbol = Symbol::scalar("u");
266 let u_expr = Expression::symbol(u_symbol);
267
268 let f_of_u = replace_expression(&expr_simplified, g, &u_expr);
270
271 return Some((f_of_u, 1.0 / derivative_value));
274 }
275 }
276 }
277
278 if let Expression::Mul(factors) = &expr_simplified {
279 let u_symbol = Symbol::scalar("u");
280 let u_expr = Expression::symbol(u_symbol);
281
282 let (f_of_g_factors, derivative_candidate_factors): (Vec<_>, Vec<_>) =
286 factors.iter().partition(|f| contains_expression(f, g));
287
288 if !f_of_g_factors.is_empty() && !derivative_candidate_factors.is_empty() {
289 let derivative_candidate = if derivative_candidate_factors.len() == 1 {
291 derivative_candidate_factors[0].clone()
292 } else {
293 Expression::mul(
294 derivative_candidate_factors
295 .iter()
296 .map(|f| (*f).clone())
297 .collect(),
298 )
299 };
300
301 if let Some(ratio) = compute_constant_ratio(&derivative_candidate, &g_prime_simplified)
303 {
304 let remaining = if f_of_g_factors.is_empty() {
307 Expression::integer(1)
308 } else if f_of_g_factors.len() == 1 {
309 f_of_g_factors[0].clone()
310 } else {
311 Expression::mul(f_of_g_factors.iter().map(|f| (*f).clone()).collect())
312 };
313
314 let f_of_u = replace_expression(&remaining, g, &u_expr);
316
317 return Some((f_of_u, ratio));
318 }
319 }
320
321 let (derivative_factors, other_factors): (Vec<_>, Vec<_>) = factors
323 .iter()
324 .partition(|f| factor_matches_derivative(f, &g_prime_simplified, var));
325
326 if derivative_factors.is_empty() {
327 return None;
328 }
329
330 let derivative_product = if derivative_factors.len() == 1 {
331 derivative_factors[0].clone()
332 } else {
333 Expression::mul(derivative_factors.iter().map(|f| (*f).clone()).collect())
334 };
335
336 let constant_factor = compute_constant_ratio(&derivative_product, &g_prime_simplified)?;
337
338 let remaining = if other_factors.is_empty() {
339 Expression::integer(1)
340 } else if other_factors.len() == 1 {
341 other_factors[0].clone()
342 } else {
343 Expression::mul(other_factors.iter().map(|f| (*f).clone()).collect())
344 };
345
346 let f_of_u = replace_expression(&remaining, g, &u_expr);
347
348 Some((f_of_u, constant_factor))
349 } else {
350 let constant_factor = compute_constant_ratio(&expr_simplified, &g_prime_simplified)?;
351 let f_of_u = Expression::integer(1);
352 Some((f_of_u, constant_factor))
353 }
354}
355
356fn extract_constant_value(expr: &Expression) -> Option<f64> {
360 match expr {
361 Expression::Number(n) => number_to_f64(n),
362 _ => None,
363 }
364}
365
366fn replace_expression(
370 expr: &Expression,
371 pattern: &Expression,
372 replacement: &Expression,
373) -> Expression {
374 if expr == pattern {
376 return replacement.clone();
377 }
378
379 match expr {
381 Expression::Add(terms) => Expression::add(
382 terms
383 .iter()
384 .map(|t| replace_expression(t, pattern, replacement))
385 .collect(),
386 ),
387 Expression::Mul(factors) => Expression::mul(
388 factors
389 .iter()
390 .map(|f| replace_expression(f, pattern, replacement))
391 .collect(),
392 ),
393 Expression::Pow(base, exp) => Expression::pow(
394 replace_expression(base, pattern, replacement),
395 replace_expression(exp, pattern, replacement),
396 ),
397 Expression::Function { name, args } => Expression::function(
398 name,
399 args.iter()
400 .map(|a| replace_expression(a, pattern, replacement))
401 .collect(),
402 ),
403 _ => expr.clone(),
404 }
405}
406
407fn factor_matches_derivative(factor: &Expression, derivative: &Expression, var: &Symbol) -> bool {
409 if factor == derivative {
410 return true;
411 }
412
413 let factor_simplified = factor.clone().simplify();
414 let derivative_simplified = derivative.clone().simplify();
415
416 if factor_simplified == derivative_simplified {
417 return true;
418 }
419
420 if let (Expression::Mul(f_factors), Expression::Mul(d_factors)) =
421 (&factor_simplified, &derivative_simplified)
422 {
423 let f_non_const: Vec<_> = f_factors
424 .iter()
425 .filter(|f| f.contains_variable(var))
426 .collect();
427 let d_non_const: Vec<_> = d_factors
428 .iter()
429 .filter(|f| f.contains_variable(var))
430 .collect();
431
432 if f_non_const.len() == d_non_const.len() {
433 return f_non_const
434 .iter()
435 .zip(d_non_const.iter())
436 .all(|(f, d)| f == d);
437 }
438 }
439
440 match (&factor_simplified, &derivative_simplified) {
441 (Expression::Symbol(f_sym), Expression::Symbol(d_sym)) => f_sym == d_sym,
442 (Expression::Pow(f_base, f_exp), Expression::Pow(d_base, d_exp)) => {
443 f_base == d_base && f_exp == d_exp
444 }
445 _ => false,
446 }
447}
448
449fn compute_constant_ratio(expr: &Expression, target: &Expression) -> Option<f64> {
457 if expr == target {
458 return Some(1.0);
459 }
460
461 let expr_simp = expr.clone().simplify();
462 let target_simp = target.clone().simplify();
463
464 if expr_simp == target_simp {
465 return Some(1.0);
466 }
467
468 match (&expr_simp, &target_simp) {
470 (Expression::Number(n1), Expression::Number(n2)) => {
471 let v1 = number_to_f64(n1)?;
472 let v2 = number_to_f64(n2)?;
473 if v2.abs() > 1e-10 {
474 let ratio = v1 / v2;
475 Some(ratio)
476 } else {
477 None
478 }
479 }
480 (Expression::Mul(e_factors), Expression::Mul(t_factors)) => {
482 let e_coeff = extract_coefficient(e_factors);
483 let t_coeff = extract_coefficient(t_factors);
484
485 let e_non_const: Vec<_> = e_factors
486 .iter()
487 .filter(|f| !matches!(f, Expression::Number(_)))
488 .collect();
489 let t_non_const: Vec<_> = t_factors
490 .iter()
491 .filter(|f| !matches!(f, Expression::Number(_)))
492 .collect();
493
494 if e_non_const.len() == t_non_const.len()
496 && e_non_const
497 .iter()
498 .zip(t_non_const.iter())
499 .all(|(a, b)| *a == *b)
500 && t_coeff.abs() > 1e-10
501 {
502 let ratio = e_coeff / t_coeff;
503 return Some(ratio);
504 }
505 None
506 }
507 (Expression::Mul(factors), _) => {
509 let coeff = extract_coefficient(factors);
510 let non_const: Vec<_> = factors
511 .iter()
512 .filter(|f| !matches!(f, Expression::Number(_)))
513 .collect();
514
515 let non_const_product = if non_const.is_empty() {
516 Expression::integer(1)
517 } else if non_const.len() == 1 {
518 (*non_const[0]).clone()
519 } else {
520 Expression::mul(non_const.iter().map(|f| (*f).clone()).collect())
521 };
522
523 if non_const_product == target_simp {
524 Some(coeff)
525 } else {
526 None
527 }
528 }
529 (_, Expression::Mul(factors)) => {
531 let coeff = extract_coefficient(factors);
532 let non_const: Vec<_> = factors
533 .iter()
534 .filter(|f| !matches!(f, Expression::Number(_)))
535 .collect();
536
537 let non_const_product = if non_const.is_empty() {
538 Expression::integer(1)
539 } else if non_const.len() == 1 {
540 (*non_const[0]).clone()
541 } else {
542 Expression::mul(non_const.iter().map(|f| (*f).clone()).collect())
543 };
544
545 if expr_simp == non_const_product && coeff.abs() > 1e-10 {
546 let ratio = 1.0 / coeff;
547 Some(ratio)
548 } else {
549 None
550 }
551 }
552 _ => None,
553 }
554}
555
556fn extract_coefficient(factors: &[Expression]) -> f64 {
560 let nums: Vec<f64> = factors
561 .iter()
562 .filter_map(|f| {
563 if let Expression::Number(n) = f {
564 number_to_f64(n)
565 } else {
566 None
567 }
568 })
569 .collect();
570
571 if nums.is_empty() {
572 1.0
573 } else {
574 nums.iter().product()
575 }
576}
577
578fn number_to_f64(num: &Number) -> Option<f64> {
580 match num {
581 Number::Integer(i) => Some(*i as f64),
582 Number::Rational(r) => {
583 use num_traits::ToPrimitive;
584 r.to_f64()
585 }
586 Number::Float(f) => Some(*f),
587 _ => None,
588 }
589}
590
591fn integrate_in_u(expr: &Expression, u: Symbol, depth: usize) -> Option<Expression> {
596 use crate::calculus::integrals::strategy::integrate_with_strategy;
597
598 let result = integrate_with_strategy(expr, u, depth + 1);
599
600 if matches!(result, Expression::Calculus(_)) {
601 None
602 } else {
603 Some(result)
604 }
605}
606
607fn substitute_back(expr: &Expression, u: &Expression, g: &Expression) -> Expression {
612 replace_expression(expr, u, g)
613}
614
// Unit tests for the substitution helpers and for `try_substitution` itself.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    // Only the bare symbol `x` counts as a simple variable.
    #[test]
    fn test_is_simple_variable() {
        let x = symbol!(x);

        assert!(is_simple_variable(&Expression::symbol(x.clone()), &x));
        assert!(!is_simple_variable(&Expression::integer(5), &x));
        assert!(!is_simple_variable(
            &Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            &x
        ));
    }

    // Leaves score 1; a power scores base + exponent + 1; a function scores
    // its arguments + 2.
    #[test]
    fn test_expression_complexity() {
        let x = symbol!(x);

        assert_eq!(expression_complexity(&Expression::integer(5)), 1);
        assert_eq!(expression_complexity(&Expression::symbol(x.clone())), 1);

        let x_squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        assert_eq!(expression_complexity(&x_squared), 3);

        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert_eq!(expression_complexity(&sin_x), 3);
    }

    // For sin(x^2) the composite argument x^2 must be offered as a candidate.
    #[test]
    fn test_find_substitution_candidates_basic() {
        let x = symbol!(x);
        let x_squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        let sin_x_squared = Expression::function("sin", vec![x_squared.clone()]);

        let candidates = find_substitution_candidates(&sin_x_squared, &x);

        assert!(!candidates.is_empty());
        assert!(candidates.contains(&x_squared));
    }

    // Replacing x^2 by u inside exp(x^2) must give exp(u).
    #[test]
    fn test_replace_expression() {
        let x = symbol!(x);
        let u = symbol!(u);

        let x_squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        // exp(x^2)
        let expr = Expression::function("exp", vec![x_squared.clone()]);
        let u_expr = Expression::symbol(u.clone());

        let result = replace_expression(&expr, &x_squared, &u_expr);
        let expected = Expression::function("exp", vec![u_expr]);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_is_constant_derivative() {
        let x = symbol!(x);

        // Numeric literals do not mention x.
        assert!(is_constant_derivative(&Expression::integer(1), &x));
        assert!(is_constant_derivative(&Expression::integer(2), &x));
        assert!(is_constant_derivative(&Expression::rational(3, 2), &x));

        // Expressions mentioning x are not constant.
        assert!(!is_constant_derivative(&Expression::symbol(x.clone()), &x));
        assert!(!is_constant_derivative(
            &Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            &x
        ));
    }

    // Numeric literals convert to f64; a symbol has no constant value.
    #[test]
    fn test_extract_constant_value() {
        assert_eq!(extract_constant_value(&Expression::integer(1)), Some(1.0));
        assert_eq!(extract_constant_value(&Expression::integer(5)), Some(5.0));
        assert_eq!(
            extract_constant_value(&Expression::rational(3, 2)),
            Some(1.5)
        );

        let x = symbol!(x);
        assert_eq!(extract_constant_value(&Expression::symbol(x.clone())), None);
    }

    #[test]
    fn test_exponential_chain_rule_pattern() {
        let x = symbol!(x);
        // Integrand: 2 * x * exp(x^2)
        let expr = Expression::mul(vec![
            Expression::integer(2),
            Expression::symbol(x.clone()),
            Expression::function(
                "exp",
                vec![Expression::pow(
                    Expression::symbol(x.clone()),
                    Expression::integer(2),
                )],
            ),
        ]);

        let result = try_substitution(&expr, &x, 0);
        assert!(
            result.is_some(),
            "Exponential chain rule pattern should succeed"
        );
    }

    #[test]
    fn test_trig_substitution_with_coefficient() {
        let x = symbol!(x);
        // Integrand: x * sin(x^2); the derivative 2x is present only up to a constant.
        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::function(
                "sin",
                vec![Expression::pow(
                    Expression::symbol(x.clone()),
                    Expression::integer(2),
                )],
            ),
        ]);

        let result = try_substitution(&expr, &x, 0);
        assert!(
            result.is_some(),
            "Trig substitution with coefficient should succeed"
        );
    }

    #[test]
    fn test_power_chain_rule_pattern() {
        let x = symbol!(x);
        // Integrand: sin(x)^3 * cos(x)
        let expr = Expression::mul(vec![
            Expression::pow(
                Expression::function("sin", vec![Expression::symbol(x.clone())]),
                Expression::integer(3),
            ),
            Expression::function("cos", vec![Expression::symbol(x.clone())]),
        ]);

        let result = try_substitution(&expr, &x, 0);
        assert!(result.is_some(), "Power chain rule pattern should succeed");
    }

    #[test]
    fn test_constant_derivative_linear() {
        let x = symbol!(x);
        // Integrand: sqrt(x + 1); the inner function x + 1 is linear in x.
        let inner = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        let expr = Expression::function("sqrt", vec![inner.clone()]);

        let result = try_substitution(&expr, &x, 0);
        assert!(
            result.is_some(),
            "Constant derivative substitution should succeed for sqrt(x+1)"
        );
    }

    #[test]
    fn test_max_depth_prevents_infinite_recursion() {
        let x = symbol!(x);

        let simple_expr = Expression::symbol(x.clone());
        // One level below the limit must not panic; its result is not asserted.
        let _result_at_limit = try_substitution(&simple_expr, &x, MAX_DEPTH - 1);

        let result_over_limit = try_substitution(&simple_expr, &x, MAX_DEPTH);
        assert_eq!(
            result_over_limit, None,
            "Should return None when depth >= MAX_DEPTH"
        );
    }
}