// mathhook_core/calculus/derivatives/partial/utils.rs

//! Utility functions for partial derivative operations.

3use crate::core::constants::EPSILON;
4use crate::core::{Expression, Number};
5use crate::simplify::Simplify;
6
7/// Utility functions for partial derivatives
8pub struct PartialUtils;
9
10impl PartialUtils {
11    /// Fast expression equality check with caching
12    ///
13    /// # Examples
14    ///
15    /// ```rust
16    /// use mathhook_core::simplify::Simplify;
17    /// use mathhook_core::calculus::derivatives::PartialUtils;
18    /// use mathhook_core::calculus::derivatives::MatrixUtils;
19    /// use mathhook_core::{Expression};
20    /// use mathhook_core::symbol;
21    ///
22    /// let x = symbol!(x);
23    /// let expr1 = Expression::symbol(x.clone());
24    /// let expr2 = Expression::symbol(x.clone());
25    /// let equal = PartialUtils::expressions_equal(&expr1, &expr2);
26    /// ```
27    pub fn expressions_equal(expr1: &Expression, expr2: &Expression) -> bool {
28        if std::ptr::eq(expr1, expr2) {
29            return true;
30        }
31
32        match (expr1, expr2) {
33            (Expression::Number(n1), Expression::Number(n2)) => n1 == n2,
34            (Expression::Symbol(s1), Expression::Symbol(s2)) => s1 == s2,
35            _ => format!("{:?}", expr1.simplify()) == format!("{:?}", expr2.simplify()),
36        }
37    }
38
39    /// Fast zero check with pattern matching
40    ///
41    /// # Examples
42    ///
43    /// ```rust
44    /// use mathhook_core::calculus::derivatives::PartialUtils;
45    /// use mathhook_core::Expression;
46    ///
47    /// let zero = Expression::integer(0);
48    /// let is_zero = PartialUtils::is_zero(&zero);
49    /// ```
50    pub fn is_zero(expr: &Expression) -> bool {
51        match expr {
52            Expression::Number(Number::Integer(0)) => true,
53            Expression::Number(Number::Float(f)) if f.abs() < EPSILON => true,
54            _ => matches!(expr.simplify(), Expression::Number(Number::Integer(0))),
55        }
56    }
57
58    /// Validate dimension compatibility early
59    ///
60    /// # Examples
61    ///
62    /// ```rust
63    /// use mathhook_core::calculus::derivatives::PartialUtils;
64    /// let result = PartialUtils::validate_dimensions("gradient", 3, 3);
65    /// assert!(result.is_ok());
66    /// ```
67    pub fn validate_dimensions(name: &str, expected: usize, actual: usize) -> Result<(), String> {
68        if expected != actual {
69            Err(format!(
70                "{}: dimension mismatch - expected {}, got {}",
71                name, expected, actual
72            ))
73        } else {
74            Ok(())
75        }
76    }
77}
78
79/// Matrix operations for partial derivatives
80pub struct MatrixUtils;
81
82impl MatrixUtils {
83    /// Compute matrix determinant with optimized algorithms
84    ///
85    /// # Examples
86    ///
87    /// ```rust
88    /// use mathhook_core::{Expression, calculus::derivatives::MatrixUtils};
89    ///
90    /// let matrix = vec![
91    ///     vec![Expression::integer(1), Expression::integer(2)],
92    ///     vec![Expression::integer(3), Expression::integer(4)],
93    /// ];
94    /// let det = MatrixUtils::determinant(&matrix);
95    /// ```
96    pub fn determinant(matrix: &[Vec<Expression>]) -> Expression {
97        let n = matrix.len();
98        if n == 0 {
99            panic!("Matrix must be square and non-empty");
100        }
101
102        // Check that all rows have the same length and that the matrix is square
103        let expected_cols = matrix[0].len();
104        if expected_cols != n {
105            panic!("Matrix must be square and non-empty");
106        }
107
108        for row in matrix.iter() {
109            if row.len() != expected_cols {
110                panic!("Matrix must be square and non-empty");
111            }
112        }
113
114        match n {
115            1 => matrix[0][0].clone(),
116            2 => Self::det_2x2(matrix),
117            3 => Self::det_3x3(matrix),
118            _ => Self::det_symbolic(matrix),
119        }
120    }
121
122    /// Optimized 2×2 determinant: |a b| = ad - bc
123    ///                            |c d|
124    fn det_2x2(matrix: &[Vec<Expression>]) -> Expression {
125        let ad = Expression::mul(vec![matrix[0][0].clone(), matrix[1][1].clone()]).simplify();
126        let bc = Expression::mul(vec![matrix[0][1].clone(), matrix[1][0].clone()]).simplify();
127        let neg_bc = Expression::mul(vec![Expression::integer(-1), bc]).simplify();
128
129        Expression::add(vec![ad, neg_bc]).simplify()
130    }
131
132    /// Optimized 3×3 determinant using cofactor expansion
133    fn det_3x3(matrix: &[Vec<Expression>]) -> Expression {
134        let mut terms = Vec::with_capacity(3);
135
136        for i in 0..3 {
137            let sign = if i % 2 == 0 { 1 } else { -1 };
138            let cofactor = Self::cofactor_2x2(matrix, 0, i);
139            terms.push(Expression::mul(vec![
140                Expression::integer(sign),
141                matrix[0][i].clone(),
142                cofactor,
143            ]));
144        }
145
146        Expression::add(terms).simplify()
147    }
148
149    /// Compute 2×2 cofactor for 3×3 determinant
150    fn cofactor_2x2(matrix: &[Vec<Expression>], skip_row: usize, skip_col: usize) -> Expression {
151        let elements: Vec<Expression> = (0..3)
152            .filter(|&i| i != skip_row)
153            .flat_map(|i| {
154                (0..3)
155                    .filter(|&j| j != skip_col)
156                    .map(move |j| matrix[i][j].clone())
157            })
158            .collect();
159
160        // 2×2 determinant: ad - bc
161        let ad = Expression::mul(vec![elements[0].clone(), elements[3].clone()]).simplify();
162        let bc = Expression::mul(vec![elements[1].clone(), elements[2].clone()]).simplify();
163        let neg_bc = Expression::mul(vec![Expression::integer(-1), bc]).simplify();
164
165        Expression::add(vec![ad, neg_bc]).simplify()
166    }
167
168    /// Symbolic determinant for large matrices
169    fn det_symbolic(matrix: &[Vec<Expression>]) -> Expression {
170        Expression::function(
171            "det",
172            vec![Expression::function(
173                "matrix",
174                matrix.iter().flat_map(|row| row.iter().cloned()).collect(),
175            )],
176        )
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;
    use crate::Symbol;
    use std::f64::consts::PI;

    fn test_symbols() -> (Symbol, Symbol, Symbol) {
        (symbol!(x), symbol!(y), symbol!(z))
    }

    /// Build a matrix of integer expressions from rows of integer literals.
    ///
    /// Rows may have different lengths, so malformed matrices can be built too.
    fn int_matrix(rows: &[&[i64]]) -> Vec<Vec<Expression>> {
        rows.iter()
            .map(|row| row.iter().map(|&value| Expression::integer(value)).collect())
            .collect()
    }

    #[test]
    fn test_expression_equality() {
        let (x, y, _) = test_symbols();

        // Same expressions
        let expr1 = Expression::symbol(x.clone());
        let expr2 = Expression::symbol(x.clone());
        assert!(PartialUtils::expressions_equal(&expr1, &expr2));

        // Different symbols
        let expr3 = Expression::symbol(y);
        assert!(!PartialUtils::expressions_equal(&expr1, &expr3));

        // Same numbers
        let num1 = Expression::integer(42);
        let num2 = Expression::integer(42);
        assert!(PartialUtils::expressions_equal(&num1, &num2));

        // Different numbers
        let num3 = Expression::integer(24);
        assert!(!PartialUtils::expressions_equal(&num1, &num3));

        // Float numbers
        let float1 = Expression::float(PI);
        let float2 = Expression::float(PI);
        assert!(PartialUtils::expressions_equal(&float1, &float2));

        // A number never equals a symbol
        assert!(!PartialUtils::expressions_equal(&num1, &expr1));
    }

    #[test]
    fn test_complex_expression_equality() {
        let (x, _, _) = test_symbols();

        // x + 1 vs x + 1
        let expr1 = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        let expr2 = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        assert!(PartialUtils::expressions_equal(&expr1, &expr2));

        // x² vs x²
        let poly1 = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        let poly2 = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        assert!(PartialUtils::expressions_equal(&poly1, &poly2));

        // 2x vs 2x
        let mult1 = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);
        let mult2 = Expression::mul(vec![Expression::integer(2), Expression::symbol(x)]);
        assert!(PartialUtils::expressions_equal(&mult1, &mult2));
    }

    #[test]
    fn test_zero_detection() {
        // Integer zero
        assert!(PartialUtils::is_zero(&Expression::integer(0)));

        // Float zero
        assert!(PartialUtils::is_zero(&Expression::float(0.0)));

        // Floats whose magnitude is below EPSILON count as zero
        assert!(PartialUtils::is_zero(&Expression::float(EPSILON / 2.0)));
        assert!(PartialUtils::is_zero(&Expression::float(-EPSILON / 2.0)));

        // Non-zero integers
        assert!(!PartialUtils::is_zero(&Expression::integer(1)));
        assert!(!PartialUtils::is_zero(&Expression::integer(-5)));

        // Non-zero floats
        assert!(!PartialUtils::is_zero(&Expression::float(PI)));
        assert!(!PartialUtils::is_zero(&Expression::float(-2.71)));

        // Symbols are not zero
        let x = symbol!(x);
        assert!(!PartialUtils::is_zero(&Expression::symbol(x)));
    }

    #[test]
    fn test_zero_expressions() {
        let (x, _, _) = test_symbols();

        // 0 + 0 = 0
        let zero_sum = Expression::add(vec![Expression::integer(0), Expression::integer(0)]);
        assert!(PartialUtils::is_zero(&zero_sum));

        // 0 * x = 0
        let zero_mult =
            Expression::mul(vec![Expression::integer(0), Expression::symbol(x.clone())]);
        assert!(PartialUtils::is_zero(&zero_mult));

        // x - x should be zero after simplification
        let diff = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::mul(vec![Expression::integer(-1), Expression::symbol(x)]),
        ]);
        assert!(PartialUtils::is_zero(&diff));
    }

    #[test]
    fn test_dimension_validation() {
        // Valid dimensions
        assert!(PartialUtils::validate_dimensions("test", 3, 3).is_ok());
        assert!(PartialUtils::validate_dimensions("gradient", 2, 2).is_ok());
        assert!(PartialUtils::validate_dimensions("hessian", 4, 4).is_ok());

        // Invalid dimensions
        let result = PartialUtils::validate_dimensions("jacobian", 3, 2);
        let error_message = result.unwrap_err();

        assert!(error_message.contains("dimension mismatch"));
        assert!(error_message.contains("expected 3"));
        assert!(error_message.contains("got 2"));

        // Zero dimensions
        assert!(PartialUtils::validate_dimensions("empty", 0, 0).is_ok());
        let zero_error = PartialUtils::validate_dimensions("non-empty", 1, 0);
        assert!(zero_error.is_err());
    }

    #[test]
    fn test_1x1_determinant() {
        // |5| = 5
        let matrix = int_matrix(&[&[5]]);
        let det = MatrixUtils::determinant(&matrix);
        assert_eq!(det, Expression::integer(5));

        // |x| = x
        let x = symbol!(x);
        let matrix_x = vec![vec![Expression::symbol(x.clone())]];
        let det_x = MatrixUtils::determinant(&matrix_x);
        assert_eq!(det_x, Expression::symbol(x));
    }

    #[test]
    fn test_2x2_determinant() {
        // |1 2| = 1*4 - 2*3 = -2
        // |3 4|
        let matrix = int_matrix(&[&[1, 2], &[3, 4]]);
        let det = MatrixUtils::determinant(&matrix);
        assert_eq!(det.simplify(), Expression::integer(-2));

        // |a b| = ad - bc
        // |c d|
        let (a, b, c, d) = (symbol!(a), symbol!(b), symbol!(c), symbol!(d));
        let symbolic_matrix = vec![
            vec![Expression::symbol(a.clone()), Expression::symbol(b.clone())],
            vec![Expression::symbol(c.clone()), Expression::symbol(d.clone())],
        ];
        let symbolic_det = MatrixUtils::determinant(&symbolic_matrix);

        let expected = Expression::add(vec![
            Expression::mul(vec![Expression::symbol(a), Expression::symbol(d)]), // ad
            Expression::mul(vec![
                Expression::integer(-1),
                Expression::mul(vec![Expression::symbol(b), Expression::symbol(c)]), // -bc
            ]),
        ]);
        assert_eq!(symbolic_det.simplify(), expected.simplify());
    }

    #[test]
    fn test_3x3_determinant() {
        // |1 0 0|
        // |0 1 0| = 1
        // |0 0 1|
        let identity = int_matrix(&[&[1, 0, 0], &[0, 1, 0], &[0, 0, 1]]);
        let det = MatrixUtils::determinant(&identity);
        assert_eq!(det.simplify(), Expression::integer(1));

        // |1 2 3|
        // |4 5 6| = 0 (rows are linearly dependent)
        // |7 8 9|
        let singular = int_matrix(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);
        let det_singular = MatrixUtils::determinant(&singular);
        assert_eq!(det_singular.simplify(), Expression::integer(0));
    }

    #[test]
    fn test_3x3_nonzero_determinant() {
        // |2 -3  1|
        // |2  0 -1| = 2(0*5 - (-1)*4) - (-3)(2*5 - (-1)*1) + 1(2*4 - 0*1)
        // |1  4  5| = 8 + 33 + 8 = 49
        let matrix = int_matrix(&[&[2, -3, 1], &[2, 0, -1], &[1, 4, 5]]);
        let det = MatrixUtils::determinant(&matrix);
        assert_eq!(det.simplify(), Expression::integer(49));
    }

    #[test]
    fn test_3x3_symbolic_determinant() {
        let (x, y, z) = test_symbols();

        // |x 0 0|
        // |0 y 0| = xyz
        // |0 0 z|
        let diagonal = vec![
            vec![
                Expression::symbol(x.clone()),
                Expression::integer(0),
                Expression::integer(0),
            ],
            vec![
                Expression::integer(0),
                Expression::symbol(y.clone()),
                Expression::integer(0),
            ],
            vec![
                Expression::integer(0),
                Expression::integer(0),
                Expression::symbol(z.clone()),
            ],
        ];
        let det = MatrixUtils::determinant(&diagonal);

        let expected = Expression::mul(vec![
            Expression::symbol(x),
            Expression::symbol(y),
            Expression::symbol(z),
        ]);
        assert_eq!(det.simplify(), expected.simplify());
    }

    #[test]
    fn test_large_matrix_symbolic() {
        // A 4×4 matrix should use the unevaluated symbolic representation
        let matrix = int_matrix(&[
            &[1, 2, 3, 4],
            &[5, 6, 7, 8],
            &[9, 10, 11, 12],
            &[13, 14, 15, 16],
        ]);

        let det = MatrixUtils::determinant(&matrix);

        // Should be a function call to det(matrix(...))
        match det {
            Expression::Function { name, .. } => {
                assert_eq!(name, "det");
            }
            _ => panic!("Expected function call for large matrix determinant"),
        }
    }

    #[test]
    fn test_special_matrices() {
        // Zero matrix: |0 0| = 0
        //              |0 0|
        let zero_matrix = int_matrix(&[&[0, 0], &[0, 0]]);
        let det_zero = MatrixUtils::determinant(&zero_matrix);
        assert_eq!(det_zero.simplify(), Expression::integer(0));

        // Upper triangular: |1 2| = 1*3 = 3
        //                   |0 3|
        let upper_tri = int_matrix(&[&[1, 2], &[0, 3]]);
        let det_tri = MatrixUtils::determinant(&upper_tri);
        assert_eq!(det_tri.simplify(), Expression::integer(3));
    }

    #[test]
    fn test_rational_determinant() {
        // |1/2  1/3|
        // |1/4  1/5|
        let rational_matrix = vec![
            vec![Expression::rational(1, 2), Expression::rational(1, 3)],
            vec![Expression::rational(1, 4), Expression::rational(1, 5)],
        ];
        let det = MatrixUtils::determinant(&rational_matrix);

        // (1/2)(1/5) - (1/3)(1/4) = 1/10 - 1/12 = 6/60 - 5/60 = 1/60
        let expected = Expression::rational(1, 60);
        assert_eq!(det.simplify(), expected.simplify());
    }

    #[test]
    #[should_panic(expected = "Matrix must be square and non-empty")]
    fn test_non_square_matrix_panic() {
        // First row matches the row count; the second row is too long
        let non_square = int_matrix(&[&[1, 2], &[3, 4, 5]]);
        MatrixUtils::determinant(&non_square);
    }

    #[test]
    #[should_panic(expected = "Matrix must be square and non-empty")]
    fn test_wide_matrix_panic() {
        // Rows agree with each other but not with the row count (2×3)
        let wide = int_matrix(&[&[1, 2, 3], &[4, 5, 6]]);
        MatrixUtils::determinant(&wide);
    }

    #[test]
    #[should_panic(expected = "Matrix must be square and non-empty")]
    fn test_empty_matrix_panic() {
        let empty: Vec<Vec<Expression>> = vec![];
        MatrixUtils::determinant(&empty);
    }
}