mathhook_core/calculus/integrals/risch/
differential_extension.rs1use super::helpers::is_one;
7use crate::core::{Expression, Symbol};
8
/// One level of the differential field tower used by the Risch integration code.
///
/// `build_extension_tower` always emits `Rational` as the first entry and then
/// appends at most one `Exponential` and at most one `Logarithmic` extension.
#[derive(Debug, Clone, PartialEq)]
pub enum DifferentialExtension {
    /// The base of the tower (presumably the field of rational functions in the
    /// integration variable — TODO confirm against the Risch driver).
    Rational,

    /// An exponential monomial `θ = exp(argument)`.
    Exponential {
        /// The expression `g` inside `exp(g)`; it depends on the integration variable.
        argument: Box<Expression>,
        /// The derivative recorded for this extension by the detector that built it
        /// (for `exp(g)` this is the unsimplified product `g' * exp(g)`).
        derivative: Box<Expression>,
    },

    /// A logarithmic monomial `θ = ln(argument)`.
    Logarithmic {
        /// The expression `g` inside `ln(g)` (or the denominator of a detected `1/g`).
        argument: Box<Expression>,
        /// The derivative recorded for this extension by the detector that built it.
        derivative: Box<Expression>,
    },
}
27
28pub fn build_extension_tower(expr: &Expression, var: Symbol) -> Option<Vec<DifferentialExtension>> {
51 let mut extensions = vec![DifferentialExtension::Rational];
52
53 if let Some(exp_ext) = detect_exponential_extension(expr, var.clone()) {
55 extensions.push(exp_ext);
56 }
57
58 if let Some(log_ext) = detect_logarithmic_extension(expr, var) {
60 extensions.push(log_ext);
61 }
62
63 Some(extensions)
64}
65
66fn detect_exponential_extension(expr: &Expression, var: Symbol) -> Option<DifferentialExtension> {
70 match expr {
71 Expression::Function { name, args } if name == "exp" && args.len() == 1 => {
72 let arg = &args[0];
73
74 if arg.contains_variable(&var) {
76 Some(DifferentialExtension::Exponential {
78 argument: Box::new(arg.clone()),
79 derivative: Box::new(compute_exponential_derivative(arg, var)),
80 })
81 } else {
82 None
83 }
84 }
85 Expression::Mul(factors) => {
86 for factor in &**factors {
88 if let Some(ext) = detect_exponential_extension(factor, var.clone()) {
89 return Some(ext);
90 }
91 }
92 None
93 }
94 _ => None,
95 }
96}
97
98fn detect_logarithmic_extension(expr: &Expression, var: Symbol) -> Option<DifferentialExtension> {
102 use super::helpers::extract_division;
103
104 match expr {
105 Expression::Function { name, args }
106 if (name == "ln" || name == "log") && args.len() == 1 =>
107 {
108 let arg = &args[0];
109
110 if arg.contains_variable(&var) {
111 Some(DifferentialExtension::Logarithmic {
112 argument: Box::new(arg.clone()),
113 derivative: Box::new(compute_logarithmic_derivative(arg, var)),
114 })
115 } else {
116 None
117 }
118 }
119 Expression::Mul(_) => {
120 if let Some((num, den)) = extract_division(expr) {
122 if is_one(&num) && den.contains_variable(&var) {
124 return Some(DifferentialExtension::Logarithmic {
125 argument: Box::new(den.clone()),
126 derivative: Box::new(Expression::div(Expression::integer(1), den)),
127 });
128 }
129 }
130 None
131 }
132 Expression::Pow(_, _) => {
133 if let Some((num, den)) = extract_division(expr) {
135 if is_one(&num) && den.contains_variable(&var) {
137 return Some(DifferentialExtension::Logarithmic {
138 argument: Box::new(den.clone()),
139 derivative: Box::new(Expression::div(Expression::integer(1), den)),
140 });
141 }
142 }
143 None
144 }
145 _ => None,
146 }
147}
148
149fn compute_exponential_derivative(arg: &Expression, var: Symbol) -> Expression {
153 let arg_derivative = derivative_of(arg, var);
155 Expression::mul(vec![
156 arg_derivative,
157 Expression::function("exp", vec![arg.clone()]),
158 ])
159}
160
161fn compute_logarithmic_derivative(arg: &Expression, var: Symbol) -> Expression {
165 let arg_derivative = derivative_of(arg, var);
167 Expression::div(arg_derivative, arg.clone())
168}
169
170fn derivative_of(expr: &Expression, var: Symbol) -> Expression {
175 match expr {
176 Expression::Symbol(s) if *s == var => Expression::integer(1),
177 Expression::Number(_) | Expression::Constant(_) => Expression::integer(0),
178 Expression::Symbol(_) => Expression::integer(0),
179 Expression::Mul(factors) => {
180 if factors.len() == 2 {
182 let f = &factors[0];
183 let g = &factors[1];
184 let f_prime = derivative_of(f, var.clone());
185 let g_prime = derivative_of(g, var);
186 Expression::add(vec![
187 Expression::mul(vec![f_prime, g.clone()]),
188 Expression::mul(vec![f.clone(), g_prime]),
189 ])
190 } else {
191 Expression::integer(0)
193 }
194 }
195 Expression::Add(terms) => {
196 Expression::add(
198 terms
199 .iter()
200 .map(|t| derivative_of(t, var.clone()))
201 .collect(),
202 )
203 }
204 Expression::Pow(base, exp) => {
205 if !exp.contains_variable(&var) {
207 let base_derivative = derivative_of(base, var);
209 Expression::mul(vec![
210 (**exp).clone(),
211 Expression::pow(
212 (**base).clone(),
213 Expression::add(vec![(**exp).clone(), Expression::integer(-1)]),
214 ),
215 base_derivative,
216 ])
217 } else {
218 Expression::integer(0)
219 }
220 }
221 _ => Expression::integer(0),
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_detect_exponential_simple() {
        let x = symbol!(x);
        let exp_x = Expression::function("exp", vec![Expression::symbol(x.clone())]);

        let detected = detect_exponential_extension(&exp_x, x);
        assert!(matches!(
            detected,
            Some(DifferentialExtension::Exponential { .. })
        ));
    }

    #[test]
    fn test_detect_logarithmic_simple() {
        let x = symbol!(x);
        let ln_x = Expression::function("ln", vec![Expression::symbol(x.clone())]);

        let detected = detect_logarithmic_extension(&ln_x, x);
        assert!(matches!(
            detected,
            Some(DifferentialExtension::Logarithmic { .. })
        ));
    }

    #[test]
    fn test_detect_logarithmic_derivative() {
        let x = symbol!(x);
        let reciprocal = Expression::div(Expression::integer(1), Expression::symbol(x.clone()));

        let detected = detect_logarithmic_extension(&reciprocal, x);
        assert!(matches!(
            detected,
            Some(DifferentialExtension::Logarithmic { .. })
        ));
    }

    #[test]
    fn test_build_tower_exponential() {
        let x = symbol!(x);
        let exp_x = Expression::function("exp", vec![Expression::symbol(x.clone())]);

        let extensions = build_extension_tower(&exp_x, x).expect("tower should be built");
        assert!(extensions.len() >= 2);
    }
}