mathhook_core/calculus/integrals/risch/
rde.rs1use super::{
6 differential_extension::DifferentialExtension,
7 helpers::{extract_division, is_just_variable, is_one},
8 RischResult,
9};
10use crate::calculus::derivatives::Derivative;
11use crate::core::{Expression, Number, Symbol};
12use crate::simplify::Simplify;
13pub fn integrate_transcendental(
39 expr: &Expression,
40 _extensions: &[DifferentialExtension],
41 var: &Symbol,
42) -> RischResult {
43 if let Some(result) = try_simple_exponential(expr, var) {
44 return RischResult::Integral(result);
45 }
46 if let Some(result) = try_logarithmic_derivative(expr, var) {
47 return RischResult::Integral(result);
48 }
49 if let Some(result) = try_exponential_product(expr, var) {
50 return RischResult::Integral(result);
51 }
52 if is_non_elementary_pattern(expr, var) {
53 return RischResult::NonElementary;
54 }
55 RischResult::Unknown
56}
57fn try_simple_exponential(expr: &Expression, var: &Symbol) -> Option<Expression> {
62 match expr {
63 Expression::Function { name, args } if name == "exp" && args.len() == 1 => {
64 let arg = &args[0];
65 if let Some(coeff) = extract_linear_coefficient(arg, var) {
66 return Some(Expression::div(expr.clone(), coeff));
67 }
68 if is_just_variable(arg, var) {
69 return Some(expr.clone());
70 }
71 None
72 }
73 _ => None,
74 }
75}
76fn try_logarithmic_derivative(expr: &Expression, var: &Symbol) -> Option<Expression> {
80 if let Some((num, den)) = extract_division(expr) {
81 if is_one(&num) {
82 if is_just_variable(&den, var) {
83 return Some(Expression::function("ln", vec![den]));
84 }
85 if let Some((a, b)) = extract_linear_form(&den, var) {
86 let ln_arg = if b == Expression::integer(0) {
87 Expression::mul(vec![a.clone(), Expression::symbol(var.clone())])
88 } else {
89 Expression::add(vec![
90 Expression::mul(vec![a.clone(), Expression::symbol(var.clone())]),
91 b,
92 ])
93 };
94 return Some(Expression::div(Expression::function("ln", vec![ln_arg]), a));
95 }
96 }
97 if let Some(log_arg) = is_logarithmic_derivative_pattern(&num, &den, var.clone()) {
98 return Some(Expression::function("ln", vec![log_arg]));
99 }
100 }
101 None
102}
103fn try_exponential_product(expr: &Expression, var: &Symbol) -> Option<Expression> {
107 match expr {
108 Expression::Mul(factors) if factors.len() == 2 => {
109 let f1 = &factors[0];
110 let f2 = &factors[1];
111 if let Some(result) = check_exp_product(f1, f2, var) {
112 return Some(result);
113 }
114 if let Some(result) = check_exp_product(f2, f1, var) {
115 return Some(result);
116 }
117 None
118 }
119 _ => None,
120 }
121}
122fn check_exp_product(
124 linear: &Expression,
125 exp_part: &Expression,
126 var: &Symbol,
127) -> Option<Expression> {
128 if let Expression::Function { name, args } = exp_part {
129 if name == "exp" && args.len() == 1 {
130 let exp_arg = &args[0];
131 if is_just_variable(linear, var) && is_just_variable(exp_arg, var) {
132 return Some(Expression::mul(vec![
133 Expression::add(vec![
134 Expression::symbol(var.clone()),
135 Expression::integer(-1),
136 ]),
137 exp_part.clone(),
138 ]));
139 }
140 }
141 }
142 None
143}
144fn is_non_elementary_pattern(expr: &Expression, var: &Symbol) -> bool {
148 if let Some((num, den)) = extract_division(expr) {
149 if is_exponential_of_var(&num, var) && is_just_variable(&den, var) {
150 return true;
151 }
152 if is_sine_of_var(&num, var) && is_just_variable(&den, var) {
153 return true;
154 }
155 if is_one(&num) && is_logarithm_of_var(&den, var) {
156 return true;
157 }
158 }
159 if let Expression::Function { name, args } = expr {
160 if name == "exp" && args.len() == 1 && is_quadratic(&args[0], var) {
161 return true;
162 }
163 }
164 false
165}
166fn extract_linear_coefficient(expr: &Expression, var: &Symbol) -> Option<Expression> {
168 match expr {
169 Expression::Symbol(s) if s == var => Some(Expression::integer(1)),
170 Expression::Mul(factors) => {
171 let mut coeff = None;
172 let mut has_var = false;
173 for factor in &**factors {
174 if is_just_variable(factor, var) {
175 has_var = true;
176 } else if !factor.contains_variable(var) {
177 coeff = Some(factor.clone());
178 }
179 }
180 if has_var {
181 coeff.or(Some(Expression::integer(1)))
182 } else {
183 None
184 }
185 }
186 _ => None,
187 }
188}
189fn extract_linear_form(expr: &Expression, var: &Symbol) -> Option<(Expression, Expression)> {
191 match expr {
192 Expression::Symbol(s) if s == var => Some((Expression::integer(1), Expression::integer(0))),
193 Expression::Add(terms) if terms.len() == 2 => {
194 let t1 = &terms[0];
195 let t2 = &terms[1];
196 if let Some(a) = extract_linear_coefficient(t1, var) {
197 if !t2.contains_variable(var) {
198 return Some((a, t2.clone()));
199 }
200 }
201 if let Some(a) = extract_linear_coefficient(t2, var) {
202 if !t1.contains_variable(var) {
203 return Some((a, t1.clone()));
204 }
205 }
206 None
207 }
208 Expression::Mul(_) => {
209 extract_linear_coefficient(expr, var).map(|a| (a, Expression::integer(0)))
210 }
211 _ => None,
212 }
213}
214fn is_logarithmic_derivative_pattern(
231 num: &Expression,
232 den: &Expression,
233 var: Symbol,
234) -> Option<Expression> {
235 let den_derivative = den.derivative(var).simplify();
236 let num_simplified = num.simplify();
237 if num_simplified == den_derivative {
238 Some(den.clone())
239 } else {
240 None
241 }
242}
243fn is_exponential_of_var(expr: &Expression, var: &Symbol) -> bool {
245 match expr {
246 Expression::Function { name, args } if name == "exp" && args.len() == 1 => {
247 is_just_variable(&args[0], var)
248 }
249 _ => false,
250 }
251}
252fn is_sine_of_var(expr: &Expression, var: &Symbol) -> bool {
254 match expr {
255 Expression::Function { name, args } if name == "sin" && args.len() == 1 => {
256 is_just_variable(&args[0], var)
257 }
258 _ => false,
259 }
260}
261fn is_logarithm_of_var(expr: &Expression, var: &Symbol) -> bool {
263 match expr {
264 Expression::Function { name, args }
265 if (name == "ln" || name == "log") && args.len() == 1 =>
266 {
267 is_just_variable(&args[0], var)
268 }
269 _ => false,
270 }
271}
272fn is_quadratic(expr: &Expression, var: &Symbol) -> bool {
274 match expr {
275 Expression::Pow(base, exp) => is_just_variable(base, var) && is_integer_two(exp),
276 Expression::Mul(factors) if factors.len() == 2 => {
277 if is_negative_one(&factors[0]) {
278 is_quadratic(&factors[1], var)
279 } else if is_negative_one(&factors[1]) {
280 is_quadratic(&factors[0], var)
281 } else {
282 false
283 }
284 }
285 _ => false,
286 }
287}
288fn is_negative_one(expr: &Expression) -> bool {
290 match expr {
291 Expression::Number(Number::Integer(n)) if *n == -1 => true,
292 Expression::Mul(factors) if factors.len() == 2 => {
293 matches!(&factors[0], Expression::Number(Number::Integer(-1))) && is_one(&factors[1])
294 || is_one(&factors[0])
295 && matches!(&factors[1], Expression::Number(Number::Integer(-1)))
296 }
297 _ => false,
298 }
299}
300fn is_integer_two(expr: &Expression) -> bool {
302 matches!(expr, Expression::Number(Number::Integer(2)))
303}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// Runs [`integrate_transcendental`] over a purely rational base extension.
    fn integrate(expr: &Expression, x: &Symbol) -> RischResult {
        let extensions = [DifferentialExtension::Rational];
        integrate_transcendental(expr, &extensions, x)
    }

    /// Builds `x^exponent`.
    fn power_of(x: &Symbol, exponent: Expression) -> Expression {
        Expression::pow(Expression::symbol(x.clone()), exponent)
    }

    /// Builds `x^2 + 1`, the denominator shared by the logarithmic-derivative tests.
    fn x_squared_plus_one(x: &Symbol) -> Expression {
        Expression::add(vec![
            power_of(x, Expression::integer(2)),
            Expression::integer(1),
        ])
    }

    #[test]
    fn test_simple_exp_x() {
        let x = symbol!(x);
        let integrand = Expression::function("exp", vec![Expression::symbol(x.clone())]);
        assert!(matches!(integrate(&integrand, &x), RischResult::Integral(_)));
    }

    #[test]
    fn test_simple_exp_2x() {
        let x = symbol!(x);
        let two_x = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let integrand = Expression::function("exp", vec![two_x]);
        assert!(matches!(integrate(&integrand, &x), RischResult::Integral(_)));
    }

    #[test]
    fn test_logarithmic_derivative_one_over_x() {
        let x = symbol!(x);
        let integrand = Expression::div(Expression::integer(1), Expression::symbol(x.clone()));
        assert!(matches!(integrate(&integrand, &x), RischResult::Integral(_)));
    }

    #[test]
    fn test_non_elementary_exp_x_squared() {
        let x = symbol!(x);
        let integrand = Expression::function("exp", vec![power_of(&x, Expression::integer(2))]);
        assert!(matches!(integrate(&integrand, &x), RischResult::NonElementary));
    }

    #[test]
    fn test_non_elementary_exp_over_x() {
        let x = symbol!(x);
        let exp_x = Expression::function("exp", vec![Expression::symbol(x.clone())]);
        let integrand = Expression::div(exp_x, Expression::symbol(x.clone()));
        assert!(matches!(integrate(&integrand, &x), RischResult::NonElementary));
    }

    #[test]
    fn test_extract_linear_coefficient_simple() {
        let x = symbol!(x);
        let coeff = extract_linear_coefficient(&Expression::symbol(x.clone()), &x);
        assert_eq!(coeff, Some(Expression::integer(1)));
    }

    #[test]
    fn test_extract_linear_coefficient_scaled() {
        let x = symbol!(x);
        let three_x = Expression::mul(vec![Expression::integer(3), Expression::symbol(x.clone())]);
        assert_eq!(
            extract_linear_coefficient(&three_x, &x),
            Some(Expression::integer(3))
        );
    }

    #[test]
    fn test_is_quadratic_x_squared() {
        let x = symbol!(x);
        assert!(is_quadratic(&power_of(&x, Expression::integer(2)), &x));
    }

    #[test]
    fn test_is_not_quadratic_x_cubed() {
        let x = symbol!(x);
        assert!(!is_quadratic(&power_of(&x, Expression::integer(3)), &x));
    }

    #[test]
    fn test_logarithmic_derivative_pattern_basic() {
        let x = symbol!(x);
        let num = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let den = x_squared_plus_one(&x);
        let result = is_logarithmic_derivative_pattern(&num, &den, x);
        assert_eq!(result, Some(den));
    }

    #[test]
    fn test_logarithmic_derivative_pattern_no_match() {
        let x = symbol!(x);
        let num = Expression::symbol(x.clone());
        let den = x_squared_plus_one(&x);
        assert!(is_logarithmic_derivative_pattern(&num, &den, x).is_none());
    }
}