mathhook_core/calculus/derivatives/checker.rs

1//! Differentiability checking utilities
2
3use crate::core::{Expression, Symbol};
4
5/// Differentiability checker
6pub struct DifferentiabilityChecker;
7
8impl DifferentiabilityChecker {
9    /// Check if an expression is differentiable
10    ///
11    /// # Examples
12    ///
13    /// ```rust
14    /// use mathhook_core::{Expression, symbol};
15    /// use mathhook_core::calculus::derivatives::DifferentiabilityChecker;
16    ///
17    /// let x = symbol!(x);
18    /// let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
19    /// let is_diff = DifferentiabilityChecker::check(&expr, x.clone());
20    /// ```
21    pub fn check(expr: &Expression, variable: Symbol) -> bool {
22        match expr {
23            Expression::Number(_) | Expression::Constant(_) | Expression::Symbol(_) => true,
24            Expression::Add(terms) | Expression::Mul(terms) => {
25                terms.iter().all(|term| Self::check(term, variable.clone()))
26            }
27            Expression::Pow(base, exponent) => {
28                Self::check(base, variable.clone()) && Self::check(exponent, variable)
29            }
30            Expression::Function { name, args } => {
31                Self::is_function_differentiable(name)
32                    && args.iter().all(|arg| Self::check(arg, variable.clone()))
33            }
34            _ => true,
35        }
36    }
37
38    /// Check if a specific function is differentiable
39    ///
40    /// # Examples
41    ///
42    /// ```rust
43    /// use mathhook_core::calculus::derivatives::DifferentiabilityChecker;
44    ///
45    /// let is_sin_diff = DifferentiabilityChecker::is_function_differentiable("sin");
46    /// let is_abs_diff = DifferentiabilityChecker::is_function_differentiable("abs");
47    /// ```
48    pub fn is_function_differentiable(name: &str) -> bool {
49        !matches!(name, "abs" | "floor" | "ceil" | "sign")
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;
    use crate::{MathConstant, Number};

    #[test]
    fn test_basic_differentiability() {
        let x = symbol!(x);

        assert!(DifferentiabilityChecker::check(
            &Expression::integer(5),
            x.clone()
        ));
        assert!(DifferentiabilityChecker::check(
            &Expression::number(Number::float(std::f64::consts::PI)),
            x.clone()
        ));
        assert!(DifferentiabilityChecker::check(
            &Expression::symbol(x.clone()),
            x.clone()
        ));
        assert!(DifferentiabilityChecker::check(
            &Expression::constant(MathConstant::Pi),
            x.clone()
        ));
        assert!(DifferentiabilityChecker::check(
            &Expression::constant(MathConstant::E),
            x.clone()
        ));
    }

    #[test]
    fn test_arithmetic_operations() {
        let x = symbol!(x);

        let sum = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        assert!(DifferentiabilityChecker::check(&sum, x.clone()));

        let product = Expression::mul(vec![Expression::symbol(x.clone()), Expression::integer(2)]);
        assert!(DifferentiabilityChecker::check(&product, x.clone()));

        let power = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        assert!(DifferentiabilityChecker::check(&power, x.clone()));
    }

    #[test]
    fn test_smooth_functions() {
        let x = symbol!(x);

        let smooth_functions = vec![
            "sin", "cos", "tan", "sec", "csc", "cot", "sinh", "cosh", "tanh", "sech", "csch",
            "coth", "exp", "ln", "log", "log2", "sqrt", "cbrt", "arcsin", "arccos", "arctan",
            "asinh", "acosh", "atanh", "erf", "erfc", "gamma", "lgamma",
        ];

        for func_name in smooth_functions {
            let func_expr = Expression::function(func_name, vec![Expression::symbol(x.clone())]);
            assert!(
                DifferentiabilityChecker::check(&func_expr, x.clone()),
                "Function {} should be differentiable",
                func_name
            );
            assert!(
                DifferentiabilityChecker::is_function_differentiable(func_name),
                "Function {} should be marked as differentiable",
                func_name
            );
        }
    }

    #[test]
    fn test_non_differentiable_functions() {
        let non_diff_functions = vec!["abs", "floor", "ceil", "sign"];

        for func_name in non_diff_functions {
            assert!(
                !DifferentiabilityChecker::is_function_differentiable(func_name),
                "Function {} should be marked as non-differentiable",
                func_name
            );
        }
    }

    // `check` itself (not just the name lookup) must reject non-differentiable
    // functions of the variable, both directly and when nested in larger
    // expressions. Without this, no test ever observes `check` returning `false`.
    #[test]
    fn test_check_rejects_non_differentiable_functions() {
        let x = symbol!(x);

        for &func_name in &["abs", "floor", "ceil", "sign"] {
            let applied = Expression::function(func_name, vec![Expression::symbol(x.clone())]);
            assert!(
                !DifferentiabilityChecker::check(&applied, x.clone()),
                "Function {} applied to x should not be differentiable",
                func_name
            );
        }

        let abs_x = || Expression::function("abs", vec![Expression::symbol(x.clone())]);

        let in_sum = Expression::add(vec![Expression::symbol(x.clone()), abs_x()]);
        assert!(!DifferentiabilityChecker::check(&in_sum, x.clone()));

        let in_product = Expression::mul(vec![Expression::integer(3), abs_x()]);
        assert!(!DifferentiabilityChecker::check(&in_product, x.clone()));

        let in_exponent = Expression::pow(
            Expression::integer(2),
            Expression::function("floor", vec![Expression::symbol(x.clone())]),
        );
        assert!(!DifferentiabilityChecker::check(&in_exponent, x.clone()));

        let nested = Expression::function("sin", vec![abs_x()]);
        assert!(!DifferentiabilityChecker::check(&nested, x.clone()));
    }

    #[test]
    fn test_composite_expressions() {
        let x = symbol!(x);

        let composite1 = Expression::add(vec![
            Expression::function("sin", vec![Expression::symbol(x.clone())]),
            Expression::function("cos", vec![Expression::symbol(x.clone())]),
        ]);
        assert!(DifferentiabilityChecker::check(&composite1, x.clone()));

        let composite2 = Expression::mul(vec![
            Expression::function("exp", vec![Expression::symbol(x.clone())]),
            Expression::function("ln", vec![Expression::symbol(x.clone())]),
        ]);
        assert!(DifferentiabilityChecker::check(&composite2, x.clone()));

        let composite3 = Expression::pow(
            Expression::function("sin", vec![Expression::symbol(x.clone())]),
            Expression::integer(2),
        );
        assert!(DifferentiabilityChecker::check(&composite3, x.clone()));
    }

    #[test]
    fn test_nested_functions() {
        let x = symbol!(x);

        let nested1 = Expression::function(
            "sin",
            vec![Expression::function(
                "cos",
                vec![Expression::symbol(x.clone())],
            )],
        );
        assert!(DifferentiabilityChecker::check(&nested1, x.clone()));

        let nested2 = Expression::function(
            "exp",
            vec![Expression::function(
                "ln",
                vec![Expression::symbol(x.clone())],
            )],
        );
        assert!(DifferentiabilityChecker::check(&nested2, x.clone()));

        let nested3 = Expression::function(
            "sqrt",
            vec![Expression::add(vec![
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
                Expression::integer(1),
            ])],
        );
        assert!(DifferentiabilityChecker::check(&nested3, x.clone()));
    }

    #[test]
    fn test_multivariate_expressions() {
        let x = symbol!(x);
        let y = symbol!(y);

        let multivar1 = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        assert!(DifferentiabilityChecker::check(&multivar1, x.clone()));
        assert!(DifferentiabilityChecker::check(&multivar1, y.clone()));

        let multivar2 = Expression::function(
            "sin",
            vec![Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ])],
        );
        assert!(DifferentiabilityChecker::check(&multivar2, x.clone()));
        assert!(DifferentiabilityChecker::check(&multivar2, y.clone()));
    }

    #[test]
    fn test_edge_cases() {
        let x = symbol!(x);
        let y = symbol!(y);

        let zero_expr = Expression::integer(0);
        assert!(DifferentiabilityChecker::check(&zero_expr, x.clone()));

        let one_expr = Expression::integer(1);
        assert!(DifferentiabilityChecker::check(&one_expr, x.clone()));

        let other_var = Expression::symbol(y.clone());
        assert!(DifferentiabilityChecker::check(&other_var, x.clone()));

        let empty_sum = Expression::add(vec![]);
        assert!(DifferentiabilityChecker::check(&empty_sum, x.clone()));

        let empty_product = Expression::mul(vec![]);
        assert!(DifferentiabilityChecker::check(&empty_product, x.clone()));
    }

    #[test]
    fn test_function_differentiability_lookup() {
        assert!(DifferentiabilityChecker::is_function_differentiable("sin"));
        assert!(DifferentiabilityChecker::is_function_differentiable("cos"));
        assert!(DifferentiabilityChecker::is_function_differentiable("exp"));
        assert!(DifferentiabilityChecker::is_function_differentiable("ln"));
        assert!(DifferentiabilityChecker::is_function_differentiable("sqrt"));

        assert!(!DifferentiabilityChecker::is_function_differentiable("abs"));
        assert!(!DifferentiabilityChecker::is_function_differentiable(
            "floor"
        ));
        assert!(!DifferentiabilityChecker::is_function_differentiable(
            "ceil"
        ));
        assert!(!DifferentiabilityChecker::is_function_differentiable(
            "sign"
        ));

        // Unrecognized names are optimistically treated as differentiable.
        assert!(DifferentiabilityChecker::is_function_differentiable(
            "unknown_function"
        ));
    }

    #[test]
    fn test_complex_expressions() {
        let x = symbol!(x);

        let complex1 = Expression::add(vec![
            Expression::mul(vec![
                Expression::function("sin", vec![Expression::symbol(x.clone())]),
                Expression::function("cos", vec![Expression::symbol(x.clone())]),
            ]),
            Expression::pow(
                Expression::function("exp", vec![Expression::symbol(x.clone())]),
                Expression::integer(2),
            ),
        ]);
        assert!(DifferentiabilityChecker::check(&complex1, x.clone()));

        let complex2 = Expression::function(
            "ln",
            vec![Expression::add(vec![
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
                Expression::integer(1),
            ])],
        );
        assert!(DifferentiabilityChecker::check(&complex2, x.clone()));
    }
}
283        assert!(DifferentiabilityChecker::check(&complex2, x.clone()));
284    }
285}