mathhook_core/calculus/derivatives/partial/
hessian.rs1use crate::calculus::derivatives::Derivative;
3use crate::core::{Expression, Symbol};
4use crate::simplify::Simplify;
/// Stateless namespace type for Hessian-matrix operations.
///
/// The type carries no data; everything in its `impl` block is an associated
/// function that takes the expression and the differentiation variables as
/// arguments (for example `HessianOperations::compute(&expr, &vars)`).
pub struct HessianOperations;
7impl HessianOperations {
8 pub fn compute(expr: &Expression, variables: &[Symbol]) -> Vec<Vec<Expression>> {
28 let n = variables.len();
29 let mut hessian = Vec::with_capacity(n);
30 for _ in 0..n {
31 hessian.push(Vec::with_capacity(n));
32 }
33 for i in 0..n {
34 for j in 0..n {
35 if j >= i {
36 let second_partial = expr
37 .derivative(variables[i].clone())
38 .derivative(variables[j].clone())
39 .simplify();
40 hessian[i].push(second_partial);
41 } else {
42 let symmetric_entry = hessian[j][i].clone();
43 hessian[i].push(symmetric_entry);
44 }
45 }
46 }
47 hessian
48 }
49 pub fn determinant(expr: &Expression, variables: Vec<Symbol>) -> Expression {
67 let hessian = Self::compute(expr, &variables);
68 Self::matrix_determinant(&hessian)
69 }
70 fn matrix_determinant(matrix: &[Vec<Expression>]) -> Expression {
72 let n = matrix.len();
73 match n {
74 0 => Expression::integer(1),
75 1 => matrix[0][0].clone(),
76 2 => {
77 let a = &matrix[0][0];
78 let b = &matrix[0][1];
79 let c = &matrix[1][0];
80 let d = &matrix[1][1];
81 Expression::add(vec![
82 Expression::mul(vec![a.clone(), d.clone()]),
83 Expression::mul(vec![
84 Expression::integer(-1),
85 Expression::mul(vec![b.clone(), c.clone()]),
86 ]),
87 ])
88 .simplify()
89 }
90 _ => {
91 let mut det_terms = Vec::with_capacity(n);
92 for j in 0..n {
93 let cofactor = Self::cofactor(matrix, 0, j);
94 let sign = if j % 2 == 0 { 1 } else { -1 };
95 det_terms.push(Expression::mul(vec![
96 Expression::integer(sign),
97 matrix[0][j].clone(),
98 cofactor,
99 ]));
100 }
101 Expression::add(det_terms).simplify()
102 }
103 }
104 }
105 fn cofactor(matrix: &[Vec<Expression>], row: usize, col: usize) -> Expression {
107 let n = matrix.len();
108 let minor: Vec<Vec<_>> = (0..n)
109 .filter(|&i| i != row)
110 .map(|i| {
111 (0..n)
112 .filter(|&j| j != col)
113 .map(|j| matrix[i][j].clone())
114 .collect()
115 })
116 .collect();
117 Self::matrix_determinant(&minor)
118 }
119 pub fn is_positive_definite(expr: &Expression, variables: Vec<Symbol>) -> bool {
137 let hessian = Self::compute(expr, &variables);
138 Self::check_positive_definite(&hessian)
139 }
140 fn check_positive_definite(hessian: &[Vec<Expression>]) -> bool {
142 let n = hessian.len();
143 for k in 1..=n {
144 let submatrix: Vec<Vec<_>> = (0..k)
145 .map(|i| (0..k).map(|j| hessian[i][j].clone()).collect())
146 .collect();
147 let det = Self::matrix_determinant(&submatrix);
148 if det.is_zero() {
149 return false;
150 }
151 }
152 true
153 }
154 pub fn trace(expr: &Expression, variables: Vec<Symbol>) -> Expression {
172 let hessian = Self::compute(expr, &variables);
173 let n = hessian.len();
174 let mut diagonal_terms = Vec::with_capacity(n);
175 diagonal_terms.extend((0..n).map(|i| hessian[i][i].clone()));
176 Expression::add(diagonal_terms).simplify()
177 }
178}
#[cfg(test)]
mod tests {
    use std::slice::from_ref;

    use super::*;
    use crate::expr;
    use crate::symbol;

    /// Wrap a symbol as an expression.
    fn var(s: &Symbol) -> Expression {
        Expression::symbol(s.clone())
    }

    /// Build `s^2`.
    fn squared(s: &Symbol) -> Expression {
        Expression::pow(var(s), Expression::integer(2))
    }

    /// Build `s^3`.
    fn cubed(s: &Symbol) -> Expression {
        Expression::pow(var(s), Expression::integer(3))
    }

    /// x^2 + y^2 has the constant Hessian [[2, 0], [0, 2]].
    #[test]
    fn test_quadratic_hessian() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::add(vec![squared(&x), squared(&y)]);

        let h = HessianOperations::compute(&f, &[x, y]);

        assert_eq!(h.len(), 2);
        for row in &h {
            assert_eq!(row.len(), 2);
        }
        for (i, j, want) in [(0, 0, 2), (1, 1, 2), (0, 1, 0), (1, 0, 0)] {
            assert_eq!(h[i][j].simplify(), Expression::integer(want));
        }
    }

    /// x*y has the constant Hessian [[0, 1], [1, 0]].
    #[test]
    fn test_mixed_partial_hessian() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::mul(vec![var(&x), var(&y)]);

        let h = HessianOperations::compute(&f, &[x, y]);

        for (i, j, want) in [(0, 0, 0), (1, 1, 0), (0, 1, 1), (1, 0, 1)] {
            assert_eq!(h[i][j].simplify(), Expression::integer(want));
        }
    }

    /// x^3 + x^2*y + y^3: no Hessian entry is reported as zero.
    #[test]
    fn test_cubic_polynomial_hessian() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::add(vec![
            cubed(&x),
            Expression::mul(vec![squared(&x), var(&y)]),
            cubed(&y),
        ]);

        let h = HessianOperations::compute(&f, &[x, y]);

        assert_eq!(h.len(), 2);
        for (i, j) in [(0, 0), (1, 1), (0, 1), (1, 0)] {
            assert!(!h[i][j].is_zero());
        }
    }

    /// A single variable produces a 1x1 Hessian holding the second derivative.
    #[test]
    fn test_single_variable_hessian() {
        let x = symbol!(x);
        let f = Expression::pow(var(&x), Expression::integer(4));

        let h = HessianOperations::compute(&f, from_ref(&x));

        assert_eq!(h.len(), 1);
        assert_eq!(h[0].len(), 1);
        assert!(!h[0][0].is_zero());
    }

    /// x^2 + y^2 + z^2 has the Hessian 2 * identity.
    #[test]
    fn test_three_variable_hessian() {
        let (x, y, z) = (symbol!(x), symbol!(y), symbol!(z));
        let f = Expression::add(vec![squared(&x), squared(&y), squared(&z)]);

        let h = HessianOperations::compute(&f, &[x, y, z]);

        assert_eq!(h.len(), 3);
        for (i, row) in h.iter().enumerate() {
            assert_eq!(row.len(), 3);
            for (j, entry) in row.iter().enumerate() {
                let want = Expression::integer(if i == j { 2 } else { 0 });
                assert_eq!(entry.simplify(), want);
            }
        }
    }

    /// det(Hessian(x^2 + y^2)) = 2 * 2 - 0 * 0 = 4.
    #[test]
    fn test_hessian_determinant_2x2() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::add(vec![squared(&x), squared(&y)]);

        let det = HessianOperations::determinant(&f, vec![x, y]);

        assert_eq!(det.simplify(), Expression::integer(4));
    }

    /// trace(Hessian(3x^2 + 5y^2)) = 6 + 10 = 16.
    #[test]
    fn test_hessian_trace() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::add(vec![
            Expression::mul(vec![Expression::integer(3), squared(&x)]),
            Expression::mul(vec![Expression::integer(5), squared(&y)]),
        ]);

        let trace = HessianOperations::trace(&f, vec![x, y]);

        assert_eq!(trace.simplify(), Expression::integer(16));
    }

    /// Every second partial of a constant is zero.
    #[test]
    fn test_constant_function_hessian() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::integer(42);

        let h = HessianOperations::compute(&f, &[x, y]);

        for row in &h {
            for entry in row {
                assert_eq!(entry.simplify(), expr!(0));
            }
        }
    }

    /// Every second partial of 2x + 3y + 1 is zero.
    #[test]
    fn test_linear_function_hessian() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::add(vec![
            Expression::mul(vec![Expression::integer(2), var(&x)]),
            Expression::mul(vec![Expression::integer(3), var(&y)]),
            Expression::integer(1),
        ]);

        let h = HessianOperations::compute(&f, &[x, y]);

        for row in &h {
            for entry in row {
                assert_eq!(entry.simplify(), expr!(0));
            }
        }
    }

    /// The off-diagonal entries of Hessian(x^2*y + x*y^2) are identical.
    #[test]
    fn test_hessian_symmetry() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::add(vec![
            Expression::mul(vec![squared(&x), var(&y)]),
            Expression::mul(vec![var(&x), squared(&y)]),
        ]);

        let h = HessianOperations::compute(&f, &[x, y]);

        assert_eq!(h[0][1], h[1][0]);
    }

    /// sin(x) + cos(y): diagonal entries are not reported as zero and the
    /// mixed partials vanish.
    #[test]
    fn test_trigonometric_hessian() {
        let (x, y) = (symbol!(x), symbol!(y));
        let f = Expression::add(vec![
            Expression::function("sin", vec![var(&x)]),
            Expression::function("cos", vec![var(&y)]),
        ]);

        let h = HessianOperations::compute(&f, &[x, y]);

        assert_eq!(h.len(), 2);
        for i in 0..2 {
            assert!(!h[i][i].is_zero());
        }
        for (i, j) in [(0, 1), (1, 0)] {
            assert_eq!(h[i][j].simplify(), Expression::integer(0));
        }
    }
}