mathhook_core/calculus/derivatives/partial/
hessian.rs

1//! Hessian matrix operations for second-order partial derivatives
2use crate::calculus::derivatives::Derivative;
3use crate::core::{Expression, Symbol};
4use crate::simplify::Simplify;
/// Stateless namespace for Hessian-matrix operations (second-order partial
/// derivatives): building the matrix, its determinant, its trace, and a
/// leading-principal-minor check.
#[derive(Debug, Clone, Copy, Default)]
pub struct HessianOperations;
7impl HessianOperations {
8    /// Compute Hessian matrix
9    ///
10    /// # Examples
11    ///
12    /// ```rust
13    /// use mathhook_core::simplify::Simplify;
14    /// use mathhook_core::calculus::derivatives::Derivative;
15    /// use mathhook_core::{Expression};
16    /// use mathhook_core::symbol;
17    /// use mathhook_core::calculus::derivatives::HessianOperations;
18    ///
19    /// let x = symbol!(x);
20    /// let y = symbol!(y);
21    /// let expr = Expression::add(vec![
22    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
23    ///     Expression::pow(Expression::symbol(y.clone()), Expression::integer(2))
24    /// ]);
25    /// let hessian = HessianOperations::compute(&expr, &[x, y]);
26    /// ```
27    pub fn compute(expr: &Expression, variables: &[Symbol]) -> Vec<Vec<Expression>> {
28        let n = variables.len();
29        let mut hessian = Vec::with_capacity(n);
30        for _ in 0..n {
31            hessian.push(Vec::with_capacity(n));
32        }
33        for i in 0..n {
34            for j in 0..n {
35                if j >= i {
36                    let second_partial = expr
37                        .derivative(variables[i].clone())
38                        .derivative(variables[j].clone())
39                        .simplify();
40                    hessian[i].push(second_partial);
41                } else {
42                    let symmetric_entry = hessian[j][i].clone();
43                    hessian[i].push(symmetric_entry);
44                }
45            }
46        }
47        hessian
48    }
49    /// Compute Hessian determinant
50    ///
51    /// # Examples
52    ///
53    /// ```rust
54    /// use mathhook_core::{Expression, Symbol};
55    /// use mathhook_core::symbol;
56    /// use mathhook_core::calculus::derivatives::HessianOperations;
57    ///
58    /// let x = symbol!(x);
59    /// let y = symbol!(y);
60    /// let expr = Expression::add(vec![
61    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
62    ///     Expression::pow(Expression::symbol(y.clone()), Expression::integer(2))
63    /// ]);
64    /// let det = HessianOperations::determinant(&expr, vec![x, y]);
65    /// ```
66    pub fn determinant(expr: &Expression, variables: Vec<Symbol>) -> Expression {
67        let hessian = Self::compute(expr, &variables);
68        Self::matrix_determinant(&hessian)
69    }
70    /// Compute matrix determinant recursively
71    fn matrix_determinant(matrix: &[Vec<Expression>]) -> Expression {
72        let n = matrix.len();
73        match n {
74            0 => Expression::integer(1),
75            1 => matrix[0][0].clone(),
76            2 => {
77                let a = &matrix[0][0];
78                let b = &matrix[0][1];
79                let c = &matrix[1][0];
80                let d = &matrix[1][1];
81                Expression::add(vec![
82                    Expression::mul(vec![a.clone(), d.clone()]),
83                    Expression::mul(vec![
84                        Expression::integer(-1),
85                        Expression::mul(vec![b.clone(), c.clone()]),
86                    ]),
87                ])
88                .simplify()
89            }
90            _ => {
91                let mut det_terms = Vec::with_capacity(n);
92                for j in 0..n {
93                    let cofactor = Self::cofactor(matrix, 0, j);
94                    let sign = if j % 2 == 0 { 1 } else { -1 };
95                    det_terms.push(Expression::mul(vec![
96                        Expression::integer(sign),
97                        matrix[0][j].clone(),
98                        cofactor,
99                    ]));
100                }
101                Expression::add(det_terms).simplify()
102            }
103        }
104    }
105    /// Compute cofactor for matrix determinant
106    fn cofactor(matrix: &[Vec<Expression>], row: usize, col: usize) -> Expression {
107        let n = matrix.len();
108        let minor: Vec<Vec<_>> = (0..n)
109            .filter(|&i| i != row)
110            .map(|i| {
111                (0..n)
112                    .filter(|&j| j != col)
113                    .map(|j| matrix[i][j].clone())
114                    .collect()
115            })
116            .collect();
117        Self::matrix_determinant(&minor)
118    }
119    /// Check if Hessian is positive definite (for optimization)
120    ///
121    /// # Examples
122    ///
123    /// ```rust
124    /// use mathhook_core::{Expression};
125    /// use mathhook_core::symbol;
126    /// use mathhook_core::calculus::derivatives::HessianOperations;
127    ///
128    /// let x = symbol!(x);
129    /// let y = symbol!(y);
130    /// let expr = Expression::add(vec![
131    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
132    ///     Expression::pow(Expression::symbol(y.clone()), Expression::integer(2))
133    /// ]);
134    /// let is_pos_def = HessianOperations::is_positive_definite(&expr, vec![x, y]);
135    /// ```
136    pub fn is_positive_definite(expr: &Expression, variables: Vec<Symbol>) -> bool {
137        let hessian = Self::compute(expr, &variables);
138        Self::check_positive_definite(&hessian)
139    }
140    /// Check positive definiteness using leading principal minors
141    fn check_positive_definite(hessian: &[Vec<Expression>]) -> bool {
142        let n = hessian.len();
143        for k in 1..=n {
144            let submatrix: Vec<Vec<_>> = (0..k)
145                .map(|i| (0..k).map(|j| hessian[i][j].clone()).collect())
146                .collect();
147            let det = Self::matrix_determinant(&submatrix);
148            if det.is_zero() {
149                return false;
150            }
151        }
152        true
153    }
154    /// Compute trace of Hessian matrix
155    ///
156    /// # Examples
157    ///
158    /// ```rust
159    /// use mathhook_core::{Expression, Symbol};
160    /// use mathhook_core::symbol;
161    /// use mathhook_core::calculus::derivatives::HessianOperations;
162    ///
163    /// let x = symbol!(x);
164    /// let y = symbol!(y);
165    /// let expr = Expression::add(vec![
166    ///     Expression::mul(vec![Expression::integer(3), Expression::pow(Expression::symbol(x.clone()), Expression::integer(2))]),
167    ///     Expression::mul(vec![Expression::integer(5), Expression::pow(Expression::symbol(y.clone()), Expression::integer(2))])
168    /// ]);
169    /// let trace = HessianOperations::trace(&expr, vec![x, y]);
170    /// ```
171    pub fn trace(expr: &Expression, variables: Vec<Symbol>) -> Expression {
172        let hessian = Self::compute(expr, &variables);
173        let n = hessian.len();
174        let mut diagonal_terms = Vec::with_capacity(n);
175        diagonal_terms.extend((0..n).map(|i| hessian[i][i].clone()));
176        Expression::add(diagonal_terms).simplify()
177    }
178}
#[cfg(test)]
mod tests {
    use std::slice::from_ref;

    use super::*;
    use crate::expr;
    use crate::symbol;

    #[test]
    fn test_quadratic_hessian() {
        let x = symbol!(x);
        let y = symbol!(y);
        let quadratic = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
        ]);
        let hessian = HessianOperations::compute(&quadratic, &[x.clone(), y.clone()]);
        assert_eq!(hessian.len(), 2);
        assert_eq!(hessian[0].len(), 2);
        assert_eq!(hessian[1].len(), 2);
        assert_eq!(hessian[0][0].simplify(), Expression::integer(2));
        assert_eq!(hessian[1][1].simplify(), Expression::integer(2));
        assert_eq!(hessian[0][1].simplify(), Expression::integer(0));
        assert_eq!(hessian[1][0].simplify(), Expression::integer(0));
    }

    #[test]
    fn test_mixed_partial_hessian() {
        let x = symbol!(x);
        let y = symbol!(y);
        let mixed = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        let hessian = HessianOperations::compute(&mixed, &[x.clone(), y.clone()]);
        assert_eq!(hessian[0][0].simplify(), Expression::integer(0));
        assert_eq!(hessian[1][1].simplify(), Expression::integer(0));
        assert_eq!(hessian[0][1].simplify(), Expression::integer(1));
        assert_eq!(hessian[1][0].simplify(), Expression::integer(1));
    }

    #[test]
    fn test_cubic_polynomial_hessian() {
        let x = symbol!(x);
        let y = symbol!(y);
        let cubic = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(3)),
            Expression::mul(vec![
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
                Expression::symbol(y.clone()),
            ]),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(3)),
        ]);
        let hessian = HessianOperations::compute(&cubic, &[x.clone(), y.clone()]);
        assert_eq!(hessian.len(), 2);
        assert!(!hessian[0][0].is_zero());
        assert!(!hessian[1][1].is_zero());
        assert!(!hessian[0][1].is_zero());
        assert!(!hessian[1][0].is_zero());
    }

    #[test]
    fn test_single_variable_hessian() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(4));
        let hessian = HessianOperations::compute(&expr, from_ref(&x));
        assert_eq!(hessian.len(), 1);
        assert_eq!(hessian[0].len(), 1);
        assert!(!hessian[0][0].is_zero());
    }

    #[test]
    fn test_three_variable_hessian() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);
        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(z.clone()), Expression::integer(2)),
        ]);
        let hessian = HessianOperations::compute(&expr, &[x.clone(), y.clone(), z.clone()]);
        assert_eq!(hessian.len(), 3);
        for (i, row) in hessian.iter().enumerate() {
            assert_eq!(row.len(), 3);
            for (j, val) in row.iter().enumerate() {
                let expected = if i == j {
                    Expression::integer(2)
                } else {
                    Expression::integer(0)
                };
                assert_eq!(val.simplify(), expected);
            }
        }
    }

    #[test]
    fn test_hessian_determinant_2x2() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
        ]);
        let det = HessianOperations::determinant(&expr, vec![x.clone(), y.clone()]);
        assert_eq!(det.simplify(), Expression::integer(4));
    }

    #[test]
    fn test_hessian_determinant_3x3_diagonal() {
        // x^2 + y^2 + z^2 has Hessian diag(2, 2, 2); this exercises the
        // Laplace-expansion branch (n >= 3) of the determinant.
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);
        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(z.clone()), Expression::integer(2)),
        ]);
        let det = HessianOperations::determinant(&expr, vec![x, y, z]);
        assert_eq!(det.simplify(), Expression::integer(8));
    }

    #[test]
    fn test_hessian_determinant_3x3_with_coupling() {
        // x^2 + y^2 + z^2 + x*y has Hessian [[2, 1, 0], [1, 2, 0], [0, 0, 2]],
        // whose determinant is 2*(2*2) - 1*(1*2) + 0 = 6.
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);
        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(z.clone()), Expression::integer(2)),
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::symbol(y.clone()),
            ]),
        ]);
        let det = HessianOperations::determinant(&expr, vec![x, y, z]);
        assert_eq!(det.simplify(), Expression::integer(6));
    }

    #[test]
    fn test_positive_definite_quadratic_bowl() {
        // Leading principal minors of diag(2, 2) are 2 and 4.
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
        ]);
        assert!(HessianOperations::is_positive_definite(&expr, vec![x, y]));
    }

    #[test]
    fn test_not_positive_definite_when_leading_minor_vanishes() {
        // x*y has Hessian [[0, 1], [1, 0]]: the first leading minor is 0.
        let x = symbol!(x);
        let y = symbol!(y);
        let mixed = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        assert!(!HessianOperations::is_positive_definite(&mixed, vec![x, y]));
    }

    #[test]
    fn test_constant_function_not_positive_definite() {
        let x = symbol!(x);
        let y = symbol!(y);
        let constant = Expression::integer(42);
        assert!(!HessianOperations::is_positive_definite(&constant, vec![x, y]));
    }

    #[test]
    fn test_hessian_trace() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(3),
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            ]),
            Expression::mul(vec![
                Expression::integer(5),
                Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            ]),
        ]);
        let trace = HessianOperations::trace(&expr, vec![x.clone(), y.clone()]);
        assert_eq!(trace.simplify(), Expression::integer(16));
    }

    #[test]
    fn test_constant_function_hessian() {
        let x = symbol!(x);
        let y = symbol!(y);
        let constant = Expression::integer(42);
        let hessian = HessianOperations::compute(&constant, &[x.clone(), y.clone()]);
        for value in hessian.iter().flatten() {
            assert_eq!(value.simplify(), expr!(0));
        }
    }

    #[test]
    fn test_linear_function_hessian() {
        let x = symbol!(x);
        let y = symbol!(y);
        let linear = Expression::add(vec![
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::mul(vec![Expression::integer(3), Expression::symbol(y.clone())]),
            Expression::integer(1),
        ]);
        let hessian = HessianOperations::compute(&linear, &[x.clone(), y.clone()]);
        for value in hessian.iter().flatten() {
            assert_eq!(value.simplify(), expr!(0));
        }
    }

    #[test]
    fn test_hessian_symmetry() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::mul(vec![
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
                Expression::symbol(y.clone()),
            ]),
            Expression::mul(vec![
                Expression::symbol(x.clone()),
                Expression::pow(Expression::symbol(y.clone()), Expression::integer(2)),
            ]),
        ]);
        let hessian = HessianOperations::compute(&expr, &[x.clone(), y.clone()]);
        assert_eq!(hessian[0][1], hessian[1][0]);
    }

    #[test]
    fn test_trigonometric_hessian() {
        let x = symbol!(x);
        let y = symbol!(y);
        let trig_expr = Expression::add(vec![
            Expression::function("sin", vec![Expression::symbol(x.clone())]),
            Expression::function("cos", vec![Expression::symbol(y.clone())]),
        ]);
        let hessian = HessianOperations::compute(&trig_expr, &[x.clone(), y.clone()]);
        assert_eq!(hessian.len(), 2);
        assert!(!hessian[0][0].is_zero());
        assert!(!hessian[1][1].is_zero());
        assert_eq!(hessian[0][1].simplify(), Expression::integer(0));
        assert_eq!(hessian[1][0].simplify(), Expression::integer(0));
    }
}