mathhook_core/calculus/integrals/risch/
differential_extension.rs

//! Differential extension tower construction
//!
//! Builds a tower of differential extensions for integrand expressions.
//! Each extension above the rational base field represents one transcendental
//! function (an exponential or a logarithm).
5
6use super::helpers::is_one;
7use crate::core::{Expression, Symbol};
8
9/// Differential extension tower element
10#[derive(Debug, Clone, PartialEq)]
11pub enum DifferentialExtension {
12    /// Base field (rational functions)
13    Rational,
14
15    /// Exponential extension: e^(argument)
16    Exponential {
17        argument: Box<Expression>,
18        derivative: Box<Expression>,
19    },
20
21    /// Logarithmic extension: ln(argument)
22    Logarithmic {
23        argument: Box<Expression>,
24        derivative: Box<Expression>,
25    },
26}
27
28/// Build differential extension tower for expression
29///
30/// Analyzes the expression structure and identifies transcendental
31/// extensions (exponentials and logarithms).
32///
33/// # Arguments
34///
35/// * `expr` - The expression to analyze
36/// * `var` - The variable of integration
37///
38/// # Examples
39///
40/// ```rust
41/// use mathhook_core::calculus::integrals::risch::differential_extension::build_extension_tower;
42/// use mathhook_core::Expression;
43/// use mathhook_core::symbol;
44///
45/// let x = symbol!(x);
46/// let expr = Expression::function("exp", vec![Expression::symbol(x.clone())]);
47/// let tower = build_extension_tower(&expr, x);
48/// assert!(tower.is_some());
49/// ```
50pub fn build_extension_tower(expr: &Expression, var: Symbol) -> Option<Vec<DifferentialExtension>> {
51    let mut extensions = vec![DifferentialExtension::Rational];
52
53    // Detect exponential extensions
54    if let Some(exp_ext) = detect_exponential_extension(expr, var.clone()) {
55        extensions.push(exp_ext);
56    }
57
58    // Detect logarithmic extensions
59    if let Some(log_ext) = detect_logarithmic_extension(expr, var) {
60        extensions.push(log_ext);
61    }
62
63    Some(extensions)
64}
65
66/// Detect exponential extension in expression
67///
68/// Looks for patterns like exp(x), exp(ax), exp(ax+b).
69fn detect_exponential_extension(expr: &Expression, var: Symbol) -> Option<DifferentialExtension> {
70    match expr {
71        Expression::Function { name, args } if name == "exp" && args.len() == 1 => {
72            let arg = &args[0];
73
74            // Check if argument contains the variable
75            if arg.contains_variable(&var) {
76                // For basic implementation: handle e^x, e^(ax), e^(ax+b)
77                Some(DifferentialExtension::Exponential {
78                    argument: Box::new(arg.clone()),
79                    derivative: Box::new(compute_exponential_derivative(arg, var)),
80                })
81            } else {
82                None
83            }
84        }
85        Expression::Mul(factors) => {
86            // Check factors for exponential
87            for factor in &**factors {
88                if let Some(ext) = detect_exponential_extension(factor, var.clone()) {
89                    return Some(ext);
90                }
91            }
92            None
93        }
94        _ => None,
95    }
96}
97
98/// Detect logarithmic extension in expression
99///
100/// Looks for patterns like ln(x), ln(ax), 1/x patterns.
101fn detect_logarithmic_extension(expr: &Expression, var: Symbol) -> Option<DifferentialExtension> {
102    use super::helpers::extract_division;
103
104    match expr {
105        Expression::Function { name, args }
106            if (name == "ln" || name == "log") && args.len() == 1 =>
107        {
108            let arg = &args[0];
109
110            if arg.contains_variable(&var) {
111                Some(DifferentialExtension::Logarithmic {
112                    argument: Box::new(arg.clone()),
113                    derivative: Box::new(compute_logarithmic_derivative(arg, var)),
114                })
115            } else {
116                None
117            }
118        }
119        Expression::Mul(_) => {
120            // Check for division pattern: numerator * denominator^(-1)
121            if let Some((num, den)) = extract_division(expr) {
122                // Check for 1/x pattern (logarithmic derivative)
123                if is_one(&num) && den.contains_variable(&var) {
124                    return Some(DifferentialExtension::Logarithmic {
125                        argument: Box::new(den.clone()),
126                        derivative: Box::new(Expression::div(Expression::integer(1), den)),
127                    });
128                }
129            }
130            None
131        }
132        Expression::Pow(_, _) => {
133            // Check for den^(-1) pattern (represents 1/den)
134            if let Some((num, den)) = extract_division(expr) {
135                // Check for 1/x pattern (logarithmic derivative)
136                if is_one(&num) && den.contains_variable(&var) {
137                    return Some(DifferentialExtension::Logarithmic {
138                        argument: Box::new(den.clone()),
139                        derivative: Box::new(Expression::div(Expression::integer(1), den)),
140                    });
141                }
142            }
143            None
144        }
145        _ => None,
146    }
147}
148
149/// Compute derivative of exponential extension
150///
151/// For t = e^g, derivative is g' * e^g = g' * t
152fn compute_exponential_derivative(arg: &Expression, var: Symbol) -> Expression {
153    // Derivative of e^g is g' * e^g
154    let arg_derivative = derivative_of(arg, var);
155    Expression::mul(vec![
156        arg_derivative,
157        Expression::function("exp", vec![arg.clone()]),
158    ])
159}
160
161/// Compute derivative of logarithmic extension
162///
163/// For t = ln(g), derivative is g'/g
164fn compute_logarithmic_derivative(arg: &Expression, var: Symbol) -> Expression {
165    // Derivative of ln(g) is g'/g
166    let arg_derivative = derivative_of(arg, var);
167    Expression::div(arg_derivative, arg.clone())
168}
169
170/// Compute simple derivative (basic cases only)
171///
172/// This is a simplified derivative computation for the Risch algorithm.
173/// For more complex cases, use the full derivative system.
174fn derivative_of(expr: &Expression, var: Symbol) -> Expression {
175    match expr {
176        Expression::Symbol(s) if *s == var => Expression::integer(1),
177        Expression::Number(_) | Expression::Constant(_) => Expression::integer(0),
178        Expression::Symbol(_) => Expression::integer(0),
179        Expression::Mul(factors) => {
180            // Product rule: (fg)' = f'g + fg'
181            if factors.len() == 2 {
182                let f = &factors[0];
183                let g = &factors[1];
184                let f_prime = derivative_of(f, var.clone());
185                let g_prime = derivative_of(g, var);
186                Expression::add(vec![
187                    Expression::mul(vec![f_prime, g.clone()]),
188                    Expression::mul(vec![f.clone(), g_prime]),
189                ])
190            } else {
191                // For simplicity, handle basic cases
192                Expression::integer(0)
193            }
194        }
195        Expression::Add(terms) => {
196            // Sum rule: (f+g)' = f' + g'
197            Expression::add(
198                terms
199                    .iter()
200                    .map(|t| derivative_of(t, var.clone()))
201                    .collect(),
202            )
203        }
204        Expression::Pow(base, exp) => {
205            // Power rule for constant exponent
206            if !exp.contains_variable(&var) {
207                // (x^n)' = n*x^(n-1) * x'
208                let base_derivative = derivative_of(base, var);
209                Expression::mul(vec![
210                    (**exp).clone(),
211                    Expression::pow(
212                        (**base).clone(),
213                        Expression::add(vec![(**exp).clone(), Expression::integer(-1)]),
214                    ),
215                    base_derivative,
216                ])
217            } else {
218                Expression::integer(0)
219            }
220        }
221        _ => Expression::integer(0),
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_detect_exponential_simple() {
        let x = symbol!(x);
        let exp_x = Expression::function("exp", vec![Expression::symbol(x.clone())]);

        let detected = detect_exponential_extension(&exp_x, x);
        assert!(
            matches!(detected, Some(DifferentialExtension::Exponential { .. })),
            "exp(x) should yield an exponential extension, got {:?}",
            detected
        );
    }

    #[test]
    fn test_detect_logarithmic_simple() {
        let x = symbol!(x);
        let ln_x = Expression::function("ln", vec![Expression::symbol(x.clone())]);

        let detected = detect_logarithmic_extension(&ln_x, x);
        assert!(
            matches!(detected, Some(DifferentialExtension::Logarithmic { .. })),
            "ln(x) should yield a logarithmic extension, got {:?}",
            detected
        );
    }

    #[test]
    fn test_detect_logarithmic_derivative() {
        let x = symbol!(x);
        let reciprocal = Expression::div(Expression::integer(1), Expression::symbol(x.clone()));

        let detected = detect_logarithmic_extension(&reciprocal, x);
        assert!(
            matches!(detected, Some(DifferentialExtension::Logarithmic { .. })),
            "1/x should yield a logarithmic extension, got {:?}",
            detected
        );
    }

    #[test]
    fn test_build_tower_exponential() {
        let x = symbol!(x);
        let exp_x = Expression::function("exp", vec![Expression::symbol(x.clone())]);

        let extensions =
            build_extension_tower(&exp_x, x).expect("tower should be built for exp(x)");
        assert!(
            extensions.len() >= 2,
            "expected the rational base plus an exponential level, got {:?}",
            extensions
        );
    }
}