mathhook_core/calculus/integrals/risch/rde.rs

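//! Risch Differential Equation (RDE) solving
//!
//! Targets the RDE `y' + f*y = g` for basic cases. Rather than running the
//! full Risch algorithm, this module recognises a small set of simple
//! exponential and logarithmic integrand patterns, and flags a few integrands
//! known to have no elementary antiderivative.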
5use super::{
6    differential_extension::DifferentialExtension,
7    helpers::{extract_division, is_just_variable, is_one},
8    RischResult,
9};
10use crate::calculus::derivatives::Derivative;
11use crate::core::{Expression, Number, Symbol};
12use crate::simplify::Simplify;
13/// Integrate transcendental part using RDE
14///
15/// Attempts to solve the Risch differential equation for basic patterns.
16/// Returns Integral, NonElementary, or Unknown based on the analysis.
17///
18/// # Arguments
19///
20/// * `expr` - The transcendental expression to integrate
21/// * `extensions` - The differential extension tower
22/// * `var` - The variable of integration
23///
24/// # Examples
25///
26/// ```rust
27/// use mathhook_core::calculus::integrals::risch::rde::integrate_transcendental;
28/// use mathhook_core::calculus::integrals::risch::differential_extension::DifferentialExtension;
29/// use mathhook_core::Expression;
30/// use mathhook_core::symbol;
31///
32/// let x = symbol!(x);
33/// let expr = Expression::function("exp", vec![Expression::symbol(x.clone())]);
34/// let extensions = vec![DifferentialExtension::Rational];
35///
36/// let result = integrate_transcendental(&expr, &extensions, &x);
37/// ```
38pub fn integrate_transcendental(
39    expr: &Expression,
40    _extensions: &[DifferentialExtension],
41    var: &Symbol,
42) -> RischResult {
43    if let Some(result) = try_simple_exponential(expr, var) {
44        return RischResult::Integral(result);
45    }
46    if let Some(result) = try_logarithmic_derivative(expr, var) {
47        return RischResult::Integral(result);
48    }
49    if let Some(result) = try_exponential_product(expr, var) {
50        return RischResult::Integral(result);
51    }
52    if is_non_elementary_pattern(expr, var) {
53        return RischResult::NonElementary;
54    }
55    RischResult::Unknown
56}
57/// Try to integrate simple exponential e^(ax)
58///
59/// Handles patterns: e^(ax) where a is constant
60/// Result: e^(ax)/a
61fn try_simple_exponential(expr: &Expression, var: &Symbol) -> Option<Expression> {
62    match expr {
63        Expression::Function { name, args } if name == "exp" && args.len() == 1 => {
64            let arg = &args[0];
65            if let Some(coeff) = extract_linear_coefficient(arg, var) {
66                return Some(Expression::div(expr.clone(), coeff));
67            }
68            if is_just_variable(arg, var) {
69                return Some(expr.clone());
70            }
71            None
72        }
73        _ => None,
74    }
75}
76/// Try to integrate logarithmic derivative patterns
77///
78/// Handles: 1/x → ln|x|, 1/(ax+b) → (1/a)*ln|ax+b|
79fn try_logarithmic_derivative(expr: &Expression, var: &Symbol) -> Option<Expression> {
80    if let Some((num, den)) = extract_division(expr) {
81        if is_one(&num) {
82            if is_just_variable(&den, var) {
83                return Some(Expression::function("ln", vec![den]));
84            }
85            if let Some((a, b)) = extract_linear_form(&den, var) {
86                let ln_arg = if b == Expression::integer(0) {
87                    Expression::mul(vec![a.clone(), Expression::symbol(var.clone())])
88                } else {
89                    Expression::add(vec![
90                        Expression::mul(vec![a.clone(), Expression::symbol(var.clone())]),
91                        b,
92                    ])
93                };
94                return Some(Expression::div(Expression::function("ln", vec![ln_arg]), a));
95            }
96        }
97        if let Some(log_arg) = is_logarithmic_derivative_pattern(&num, &den, var.clone()) {
98            return Some(Expression::function("ln", vec![log_arg]));
99        }
100    }
101    None
102}
103/// Try to integrate exponential products
104///
105/// Handles: x*e^x, (ax+b)*e^(cx)
106fn try_exponential_product(expr: &Expression, var: &Symbol) -> Option<Expression> {
107    match expr {
108        Expression::Mul(factors) if factors.len() == 2 => {
109            let f1 = &factors[0];
110            let f2 = &factors[1];
111            if let Some(result) = check_exp_product(f1, f2, var) {
112                return Some(result);
113            }
114            if let Some(result) = check_exp_product(f2, f1, var) {
115                return Some(result);
116            }
117            None
118        }
119        _ => None,
120    }
121}
122/// Check if pattern is f1 * exp(f2) where f1 is linear
123fn check_exp_product(
124    linear: &Expression,
125    exp_part: &Expression,
126    var: &Symbol,
127) -> Option<Expression> {
128    if let Expression::Function { name, args } = exp_part {
129        if name == "exp" && args.len() == 1 {
130            let exp_arg = &args[0];
131            if is_just_variable(linear, var) && is_just_variable(exp_arg, var) {
132                return Some(Expression::mul(vec![
133                    Expression::add(vec![
134                        Expression::symbol(var.clone()),
135                        Expression::integer(-1),
136                    ]),
137                    exp_part.clone(),
138                ]));
139            }
140        }
141    }
142    None
143}
144/// Check if pattern is known to be non-elementary
145///
146/// Detects patterns that provably have no elementary antiderivative.
147fn is_non_elementary_pattern(expr: &Expression, var: &Symbol) -> bool {
148    if let Some((num, den)) = extract_division(expr) {
149        if is_exponential_of_var(&num, var) && is_just_variable(&den, var) {
150            return true;
151        }
152        if is_sine_of_var(&num, var) && is_just_variable(&den, var) {
153            return true;
154        }
155        if is_one(&num) && is_logarithm_of_var(&den, var) {
156            return true;
157        }
158    }
159    if let Expression::Function { name, args } = expr {
160        if name == "exp" && args.len() == 1 && is_quadratic(&args[0], var) {
161            return true;
162        }
163    }
164    false
165}
166/// Extract coefficient from linear expression ax
167fn extract_linear_coefficient(expr: &Expression, var: &Symbol) -> Option<Expression> {
168    match expr {
169        Expression::Symbol(s) if s == var => Some(Expression::integer(1)),
170        Expression::Mul(factors) => {
171            let mut coeff = None;
172            let mut has_var = false;
173            for factor in &**factors {
174                if is_just_variable(factor, var) {
175                    has_var = true;
176                } else if !factor.contains_variable(var) {
177                    coeff = Some(factor.clone());
178                }
179            }
180            if has_var {
181                coeff.or(Some(Expression::integer(1)))
182            } else {
183                None
184            }
185        }
186        _ => None,
187    }
188}
189/// Extract (a, b) from ax+b form
190fn extract_linear_form(expr: &Expression, var: &Symbol) -> Option<(Expression, Expression)> {
191    match expr {
192        Expression::Symbol(s) if s == var => Some((Expression::integer(1), Expression::integer(0))),
193        Expression::Add(terms) if terms.len() == 2 => {
194            let t1 = &terms[0];
195            let t2 = &terms[1];
196            if let Some(a) = extract_linear_coefficient(t1, var) {
197                if !t2.contains_variable(var) {
198                    return Some((a, t2.clone()));
199                }
200            }
201            if let Some(a) = extract_linear_coefficient(t2, var) {
202                if !t1.contains_variable(var) {
203                    return Some((a, t1.clone()));
204                }
205            }
206            None
207        }
208        Expression::Mul(_) => {
209            extract_linear_coefficient(expr, var).map(|a| (a, Expression::integer(0)))
210        }
211        _ => None,
212    }
213}
214/// Check if pattern is f'/f (logarithmic derivative)
215///
216/// Recognizes when the numerator is the derivative of the denominator,
217/// which integrates to ln|denominator|.
218///
219/// # Arguments
220///
221/// * `num` - Numerator of the fraction
222/// * `den` - Denominator of the fraction
223/// * `var` - Variable of integration
224///
225/// # Examples
226///
227/// The pattern f'(x)/f(x) integrates to ln|f(x)|. For example:
228/// - 2x/(x²+1) → ln|x²+1| because d/dx[x²+1] = 2x
229/// - 3x²/(x³+1) → ln|x³+1| because d/dx[x³+1] = 3x²
230fn is_logarithmic_derivative_pattern(
231    num: &Expression,
232    den: &Expression,
233    var: Symbol,
234) -> Option<Expression> {
235    let den_derivative = den.derivative(var).simplify();
236    let num_simplified = num.simplify();
237    if num_simplified == den_derivative {
238        Some(den.clone())
239    } else {
240        None
241    }
242}
243/// Check if expression is e^x
244fn is_exponential_of_var(expr: &Expression, var: &Symbol) -> bool {
245    match expr {
246        Expression::Function { name, args } if name == "exp" && args.len() == 1 => {
247            is_just_variable(&args[0], var)
248        }
249        _ => false,
250    }
251}
252/// Check if expression is sin(x)
253fn is_sine_of_var(expr: &Expression, var: &Symbol) -> bool {
254    match expr {
255        Expression::Function { name, args } if name == "sin" && args.len() == 1 => {
256            is_just_variable(&args[0], var)
257        }
258        _ => false,
259    }
260}
261/// Check if expression is ln(x) or log(x)
262fn is_logarithm_of_var(expr: &Expression, var: &Symbol) -> bool {
263    match expr {
264        Expression::Function { name, args }
265            if (name == "ln" || name == "log") && args.len() == 1 =>
266        {
267            is_just_variable(&args[0], var)
268        }
269        _ => false,
270    }
271}
272/// Check if expression is quadratic in variable (x² or -x²)
273fn is_quadratic(expr: &Expression, var: &Symbol) -> bool {
274    match expr {
275        Expression::Pow(base, exp) => is_just_variable(base, var) && is_integer_two(exp),
276        Expression::Mul(factors) if factors.len() == 2 => {
277            if is_negative_one(&factors[0]) {
278                is_quadratic(&factors[1], var)
279            } else if is_negative_one(&factors[1]) {
280                is_quadratic(&factors[0], var)
281            } else {
282                false
283            }
284        }
285        _ => false,
286    }
287}
288/// Check if expression is the constant -1
289fn is_negative_one(expr: &Expression) -> bool {
290    match expr {
291        Expression::Number(Number::Integer(n)) if *n == -1 => true,
292        Expression::Mul(factors) if factors.len() == 2 => {
293            matches!(&factors[0], Expression::Number(Number::Integer(-1))) && is_one(&factors[1])
294                || is_one(&factors[0])
295                    && matches!(&factors[1], Expression::Number(Number::Integer(-1)))
296        }
297        _ => false,
298    }
299}
300/// Check if expression is the integer 2
301fn is_integer_two(expr: &Expression) -> bool {
302    matches!(expr, Expression::Number(Number::Integer(2)))
303}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Run the RDE entry point with only the rational base field in the tower.
    fn integrate_over_rational(expr: &Expression, var: &Symbol) -> RischResult {
        let tower = vec![DifferentialExtension::Rational];
        integrate_transcendental(expr, &tower, var)
    }

    /// Build `var^exponent`.
    fn var_pow(var: &Symbol, exponent: Expression) -> Expression {
        Expression::pow(Expression::symbol(var.clone()), exponent)
    }

    /// Build `var^2 + 1`.
    fn var_squared_plus_one(var: &Symbol) -> Expression {
        Expression::add(vec![
            var_pow(var, Expression::integer(2)),
            Expression::integer(1),
        ])
    }

    #[test]
    fn test_simple_exp_x() {
        let x = symbol!(x);
        let exp_x = Expression::function("exp", vec![Expression::symbol(x.clone())]);
        assert!(matches!(
            integrate_over_rational(&exp_x, &x),
            RischResult::Integral(_)
        ));
    }

    #[test]
    fn test_simple_exp_2x() {
        let x = symbol!(x);
        let two_x = Expression::mul(vec![
            Expression::integer(2),
            Expression::symbol(x.clone()),
        ]);
        let exp_2x = Expression::function("exp", vec![two_x]);
        assert!(matches!(
            integrate_over_rational(&exp_2x, &x),
            RischResult::Integral(_)
        ));
    }

    #[test]
    fn test_logarithmic_derivative_one_over_x() {
        let x = symbol!(x);
        let reciprocal = Expression::div(Expression::integer(1), Expression::symbol(x.clone()));
        assert!(matches!(
            integrate_over_rational(&reciprocal, &x),
            RischResult::Integral(_)
        ));
    }

    #[test]
    fn test_non_elementary_exp_x_squared() {
        let x = symbol!(x);
        let gaussian = Expression::function("exp", vec![var_pow(&x, Expression::integer(2))]);
        assert!(matches!(
            integrate_over_rational(&gaussian, &x),
            RischResult::NonElementary
        ));
    }

    #[test]
    fn test_non_elementary_exp_over_x() {
        let x = symbol!(x);
        let exp_x = Expression::function("exp", vec![Expression::symbol(x.clone())]);
        let integrand = Expression::div(exp_x, Expression::symbol(x.clone()));
        assert!(matches!(
            integrate_over_rational(&integrand, &x),
            RischResult::NonElementary
        ));
    }

    #[test]
    fn test_extract_linear_coefficient_simple() {
        let x = symbol!(x);
        let bare = Expression::symbol(x.clone());
        assert_eq!(
            extract_linear_coefficient(&bare, &x),
            Some(Expression::integer(1))
        );
    }

    #[test]
    fn test_extract_linear_coefficient_scaled() {
        let x = symbol!(x);
        let three_x = Expression::mul(vec![
            Expression::integer(3),
            Expression::symbol(x.clone()),
        ]);
        assert_eq!(
            extract_linear_coefficient(&three_x, &x),
            Some(Expression::integer(3))
        );
    }

    #[test]
    fn test_is_quadratic_x_squared() {
        let x = symbol!(x);
        assert!(is_quadratic(&var_pow(&x, Expression::integer(2)), &x));
    }

    #[test]
    fn test_is_not_quadratic_x_cubed() {
        let x = symbol!(x);
        assert!(!is_quadratic(&var_pow(&x, Expression::integer(3)), &x));
    }

    #[test]
    fn test_logarithmic_derivative_pattern_basic() {
        let x = symbol!(x);
        let two_x = Expression::mul(vec![
            Expression::integer(2),
            Expression::symbol(x.clone()),
        ]);
        let den = var_squared_plus_one(&x);
        let result = is_logarithmic_derivative_pattern(&two_x, &den, x);
        assert_eq!(result, Some(den));
    }

    #[test]
    fn test_logarithmic_derivative_pattern_no_match() {
        let x = symbol!(x);
        let bare = Expression::symbol(x.clone());
        let den = var_squared_plus_one(&x);
        assert!(is_logarithmic_derivative_pattern(&bare, &den, x).is_none());
    }
}