mathhook_core/calculus/derivatives/
checker.rs1use crate::core::{Expression, Symbol};
4
5pub struct DifferentiabilityChecker;
7
8impl DifferentiabilityChecker {
9 pub fn check(expr: &Expression, variable: Symbol) -> bool {
22 match expr {
23 Expression::Number(_) | Expression::Constant(_) | Expression::Symbol(_) => true,
24 Expression::Add(terms) | Expression::Mul(terms) => {
25 terms.iter().all(|term| Self::check(term, variable.clone()))
26 }
27 Expression::Pow(base, exponent) => {
28 Self::check(base, variable.clone()) && Self::check(exponent, variable)
29 }
30 Expression::Function { name, args } => {
31 Self::is_function_differentiable(name)
32 && args.iter().all(|arg| Self::check(arg, variable.clone()))
33 }
34 _ => true,
35 }
36 }
37
38 pub fn is_function_differentiable(name: &str) -> bool {
49 !matches!(name, "abs" | "floor" | "ceil" | "sign")
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;
    use crate::{MathConstant, Number};

    /// Function names the checker must treat as differentiable.
    const SMOOTH_FUNCTIONS: &[&str] = &[
        "sin", "cos", "tan", "sec", "csc", "cot", "sinh", "cosh", "tanh", "sech", "csch",
        "coth", "exp", "ln", "log", "log2", "sqrt", "cbrt", "arcsin", "arccos", "arctan",
        "asinh", "acosh", "atanh", "erf", "erfc", "gamma", "lgamma",
    ];

    /// Function names the checker must reject because of kinks or jumps.
    const NON_DIFFERENTIABLE_FUNCTIONS: &[&str] = &["abs", "floor", "ceil", "sign"];

    /// Runs the checker without consuming the caller's symbol.
    fn is_diff(expr: &Expression, var: &Symbol) -> bool {
        DifferentiabilityChecker::check(expr, var.clone())
    }

    /// Wraps a symbol as an expression.
    fn sym(var: &Symbol) -> Expression {
        Expression::symbol(var.clone())
    }

    /// Builds the single-argument call `name(arg)`.
    fn call(name: &'static str, arg: Expression) -> Expression {
        Expression::function(name, vec![arg])
    }

    #[test]
    fn test_basic_differentiability() {
        let x = symbol!(x);
        let atoms = [
            Expression::integer(5),
            Expression::number(Number::float(std::f64::consts::PI)),
            sym(&x),
            Expression::constant(MathConstant::Pi),
            Expression::constant(MathConstant::E),
        ];

        for (index, atom) in atoms.iter().enumerate() {
            assert!(
                is_diff(atom, &x),
                "atom #{} should be differentiable",
                index
            );
        }
    }

    #[test]
    fn test_arithmetic_operations() {
        let x = symbol!(x);

        let sum = Expression::add(vec![sym(&x), Expression::integer(1)]);
        assert!(is_diff(&sum, &x));

        let product = Expression::mul(vec![sym(&x), Expression::integer(2)]);
        assert!(is_diff(&product, &x));

        let power = Expression::pow(sym(&x), Expression::integer(2));
        assert!(is_diff(&power, &x));
    }

    #[test]
    fn test_smooth_functions() {
        let x = symbol!(x);

        for &func_name in SMOOTH_FUNCTIONS {
            assert!(
                is_diff(&call(func_name, sym(&x)), &x),
                "Function {} should be differentiable",
                func_name
            );
            assert!(
                DifferentiabilityChecker::is_function_differentiable(func_name),
                "Function {} should be marked as differentiable",
                func_name
            );
        }
    }

    #[test]
    fn test_non_differentiable_functions() {
        let x = symbol!(x);

        for &func_name in NON_DIFFERENTIABLE_FUNCTIONS {
            assert!(
                !DifferentiabilityChecker::is_function_differentiable(func_name),
                "Function {} should be marked as non-differentiable",
                func_name
            );
            // Applied directly to the variable, the kink/jump is reachable.
            assert!(
                !is_diff(&call(func_name, sym(&x)), &x),
                "{}(x) should not be differentiable in x",
                func_name
            );
            // Wrapping in a smooth function must not hide the inner non-differentiable call.
            assert!(
                !is_diff(&call("sin", call(func_name, sym(&x))), &x),
                "sin({}(x)) should not be differentiable in x",
                func_name
            );
        }
    }

    #[test]
    fn test_composite_expressions() {
        let x = symbol!(x);

        let sin_plus_cos = Expression::add(vec![call("sin", sym(&x)), call("cos", sym(&x))]);
        assert!(is_diff(&sin_plus_cos, &x));

        let exp_times_ln = Expression::mul(vec![call("exp", sym(&x)), call("ln", sym(&x))]);
        assert!(is_diff(&exp_times_ln, &x));

        let sin_squared = Expression::pow(call("sin", sym(&x)), Expression::integer(2));
        assert!(is_diff(&sin_squared, &x));
    }

    #[test]
    fn test_nested_functions() {
        let x = symbol!(x);

        assert!(is_diff(&call("sin", call("cos", sym(&x))), &x));
        assert!(is_diff(&call("exp", call("ln", sym(&x))), &x));

        let x_squared_plus_one = Expression::add(vec![
            Expression::pow(sym(&x), Expression::integer(2)),
            Expression::integer(1),
        ]);
        assert!(is_diff(&call("sqrt", x_squared_plus_one), &x));
    }

    #[test]
    fn test_multivariate_expressions() {
        let x = symbol!(x);
        let y = symbol!(y);

        let x_plus_y = Expression::add(vec![sym(&x), sym(&y)]);
        assert!(is_diff(&x_plus_y, &x));
        assert!(is_diff(&x_plus_y, &y));

        let sin_xy = call("sin", Expression::mul(vec![sym(&x), sym(&y)]));
        assert!(is_diff(&sin_xy, &x));
        assert!(is_diff(&sin_xy, &y));
    }

    #[test]
    fn test_edge_cases() {
        let x = symbol!(x);
        let y = symbol!(y);

        assert!(is_diff(&Expression::integer(0), &x));
        assert!(is_diff(&Expression::integer(1), &x));
        assert!(is_diff(&sym(&y), &x));
        assert!(is_diff(&Expression::add(vec![]), &x));
        assert!(is_diff(&Expression::mul(vec![]), &x));
    }

    #[test]
    fn test_function_differentiability_lookup() {
        for &name in &["sin", "cos", "exp", "ln", "sqrt", "unknown_function"] {
            assert!(
                DifferentiabilityChecker::is_function_differentiable(name),
                "{} should be marked as differentiable",
                name
            );
        }

        for &name in &["abs", "floor", "ceil", "sign"] {
            assert!(
                !DifferentiabilityChecker::is_function_differentiable(name),
                "{} should be marked as non-differentiable",
                name
            );
        }
    }

    #[test]
    fn test_complex_expressions() {
        let x = symbol!(x);

        let trig_product_plus_exp_squared = Expression::add(vec![
            Expression::mul(vec![call("sin", sym(&x)), call("cos", sym(&x))]),
            Expression::pow(call("exp", sym(&x)), Expression::integer(2)),
        ]);
        assert!(is_diff(&trig_product_plus_exp_squared, &x));

        let ln_of_x_squared_plus_one = call(
            "ln",
            Expression::add(vec![
                Expression::pow(sym(&x), Expression::integer(2)),
                Expression::integer(1),
            ]),
        );
        assert!(is_diff(&ln_of_x_squared_plus_one, &x));
    }
}