/// Generates a function that computes the Hessian of every output component of a
/// vector-valued function `f`, using forward-mode automatic differentiation with hyper-dual
/// numbers.
///
/// # Forms
///
/// - `get_vhessian!(f, name)` generates `fn name<S, V>(x0: &V, p: &[f64]) -> Vec<V::DMatrixMxNf64>`.
/// - `get_vhessian!(f, name, T)` generates the same function with `p: &T`, for a custom
///   parameter type `T`.
///
/// `f` is called as `f(&x, p)` with a hyper-dual copy of `x0`. It is therefore expected to be
/// generic over the scalar and vector types (e.g. `fn f<S: Scalar, V: Vector<S>>(x: &V, p: &T)`).
/// It must return a vector of outputs.
///
/// # Returns
///
/// One matrix per output component of `f`, in output order. Entry `(i, j)` of the `k`-th matrix
/// is `∂²f_k / ∂x_i ∂x_j` evaluated at `x0`, for all `i, j < x0.len()`.
///
/// # Call-site requirements
///
/// The expansion names `HyperDual`, `Scalar` and `Vector` unqualified and calls the method
/// `to_hyper_dual_vector`. Those names, and the trait supplying that method (presumably
/// `HyperDualVector`), must be in scope wherever the macro is invoked.
///
/// # Cost
///
/// For an `n`-element `x0`, `f` is evaluated `1 + n(n + 1) / 2` times:
///
/// - once, unperturbed, to learn the number of outputs;
/// - once for each upper-triangular pair `(i, j)` with `i <= j`.
///
/// Each upper-triangular result is mirrored into the lower triangle. This relies on the
/// Hessian being symmetric.
#[macro_export]
macro_rules! get_vhessian {
    ($f:ident, $func_name:ident) => {
        get_vhessian!($f, $func_name, [f64]);
    };
    ($f:ident, $func_name:ident, $param_type:ty) => {
        fn $func_name<S, V>(x0: &V, p: &$param_type) -> Vec<V::DMatrixMxNf64>
        where
            S: Scalar,
            V: Vector<S>,
        {
            // Hyper-dual copy of the evaluation point; every seeded vector starts from it.
            let x_hd = x0.clone().to_hyper_dual_vector();
            let num_inputs = x0.len();

            // Unperturbed evaluation, used only to learn how many outputs `f` produces.
            let num_outputs = $f(&x_hd, p).len();

            let mut hessians = vec![x0.new_dmatrix_m_by_n_f64(num_inputs); num_outputs];

            for row in 0..num_inputs {
                for col in row..num_inputs {
                    // Seed the non-real parts so that the `get_d()` component of each output
                    // equals ∂²f/∂x_row∂x_col. This assumes `HyperDual::new(a, b, c, d)` takes
                    // the real, ε1, ε2 and ε1ε2 parts.
                    let mut seeded = x_hd.clone();
                    let a_row = seeded.vget(row).get_a();
                    if row == col {
                        // Diagonal: the same variable carries both ε1 and ε2.
                        seeded.vset(row, HyperDual::new(a_row, 1.0, 1.0, 0.0));
                    } else {
                        // Off-diagonal: ε1 on `row`, ε2 on `col`.
                        let a_col = seeded.vget(col).get_a();
                        seeded.vset(row, HyperDual::new(a_row, 1.0, 0.0, 0.0));
                        seeded.vset(col, HyperDual::new(a_col, 0.0, 1.0, 0.0));
                    }

                    let outputs = $f(&seeded, p);
                    for (k, hessian) in hessians.iter_mut().enumerate() {
                        let d2 = outputs.vget(k).get_d();
                        // Mirror into the lower triangle (Hessian symmetry).
                        hessian[(row, col)] = d2;
                        hessian[(col, row)] = d2;
                    }
                }
            }

            hessians
        }
    };
}
#[cfg(test)]
mod tests {
    use crate::{HyperDual, HyperDualVector};
    use linalg_traits::{Mat, Matrix, Scalar, Vector};
    use nalgebra::{DMatrix, DVector, SMatrix, SVector, dvector};
    use ndarray::{Array1, Array2, array};
    use numtest::*;

    /// One input, one output, `Vec<f64>` input: f(x) = x³.
    #[test]
    fn test_vhessian_1() {
        fn f<S: Scalar, V: Vector<S>>(x: &V, _p: &[f64]) -> V::DVectorT<S> {
            V::DVectorT::from_slice(&[x.vget(0).powi(3)])
        }
        let x0 = vec![2.0];
        let p = [];
        let hess = |x: &Vec<f64>| vec![Mat::from_row_slice(1, 1, &[6.0 * x[0]])];
        get_vhessian!(f, hess_autodiff);
        let hess_eval_autodiff: Vec<Mat<f64>> = hess_autodiff(&x0, &p);
        let hess_eval: Vec<Mat<f64>> = hess(&x0);
        assert_eq!(hess_eval_autodiff.len(), hess_eval.len());
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[0], hess_eval[0], 16);
    }

    /// Two inputs, one output, statically sized nalgebra input: f(x) = x₀² + x₁³.
    #[test]
    fn test_vhessian_2() {
        fn f<S: Scalar, V: Vector<S>>(x: &V, _p: &[f64]) -> V::DVectorT<S> {
            V::DVectorT::from_slice(&[x.vget(0).powi(2) + x.vget(1).powi(3)])
        }
        let x0 = SVector::from_row_slice(&[1.0, 2.0]);
        let p = [];
        let hess = |x: &SVector<f64, 2>| {
            vec![SMatrix::<f64, 2, 2>::from_row_slice(&[
                2.0,
                0.0,
                0.0,
                6.0 * x[1],
            ])]
        };
        get_vhessian!(f, hess_autodiff);
        let hess_eval_autodiff: Vec<DMatrix<f64>> = hess_autodiff(&x0, &p);
        let hess_eval: Vec<SMatrix<f64, 2, 2>> = hess(&x0);
        assert_eq!(hess_eval_autodiff.len(), hess_eval.len());
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[0], hess_eval[0], 16);
    }

    /// Two inputs, one output, ndarray input: f(x) = x₀⁵x₁ + x₀ sin³(x₁).
    #[test]
    fn test_vhessian_3() {
        fn f<S: Scalar, V: Vector<S>>(x: &V, _p: &[f64]) -> V::DVectorT<S> {
            V::DVectorT::from_slice(&[
                x.vget(0).powi(5) * x.vget(1) + x.vget(0) * x.vget(1).sin().powi(3)
            ])
        }
        let x0 = array![1.0, 2.0];
        let p = [];
        let hess = |x: &Array1<f64>| {
            vec![Array2::<f64>::from_row_slice(
                2,
                2,
                &[
                    20.0 * x[0].powi(3) * x[1],
                    5.0 * x[0].powi(4) + 3.0 * x[1].sin().powi(2) * x[1].cos(),
                    5.0 * x[0].powi(4) + 3.0 * x[1].sin().powi(2) * x[1].cos(),
                    6.0 * x[0] * x[1].sin() * x[1].cos().powi(2)
                        - 3.0 * x[0] * x[1].sin().powi(3),
                ],
            )]
        };
        get_vhessian!(f, hess_autodiff);
        let hess_eval_autodiff: Vec<Array2<f64>> = hess_autodiff(&x0, &p);
        let hess_eval: Vec<Array2<f64>> = hess(&x0);
        assert_eq!(hess_eval_autodiff.len(), hess_eval.len());
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[0], hess_eval[0], 16);
    }

    /// Two inputs, two outputs, dynamically sized nalgebra input.
    #[test]
    fn test_vhessian_4() {
        fn f<S: Scalar, V: Vector<S>>(x: &V, _p: &[f64]) -> V::DVectorT<S> {
            V::DVectorT::from_slice(&[
                x.vget(0).powi(5) * x.vget(1) + x.vget(0) * x.vget(1).sin().powi(3),
                x.vget(0).powi(3) + x.vget(1).powi(4)
                    - S::new(3.0) * x.vget(0).powi(2) * x.vget(1).powi(2),
            ])
        }
        let x0 = DVector::<f64>::from_slice(&[1.0, 2.0]);
        let p = [];
        let hess = |x: &DVector<f64>| {
            vec![
                DMatrix::<f64>::from_row_slice(
                    2,
                    2,
                    &[
                        20.0 * x[0].powi(3) * x[1],
                        5.0 * x[0].powi(4) + 3.0 * x[1].sin().powi(2) * x[1].cos(),
                        5.0 * x[0].powi(4) + 3.0 * x[1].sin().powi(2) * x[1].cos(),
                        6.0 * x[0] * x[1].sin() * x[1].cos().powi(2)
                            - 3.0 * x[0] * x[1].sin().powi(3),
                    ],
                ),
                DMatrix::<f64>::from_row_slice(
                    2,
                    2,
                    &[
                        6.0 * x[0] - 6.0 * x[1].powi(2),
                        -12.0 * x[0] * x[1],
                        -12.0 * x[0] * x[1],
                        12.0 * x[1].powi(2) - 6.0 * x[0].powi(2),
                    ],
                ),
            ]
        };
        get_vhessian!(f, hess_autodiff);
        let hess_eval_autodiff: Vec<DMatrix<f64>> = hess_autodiff(&x0, &p);
        let hess_eval: Vec<DMatrix<f64>> = hess(&x0);
        assert_eq!(hess_eval_autodiff.len(), hess_eval.len());
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[0], hess_eval[0], 16);
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[1], hess_eval[1], 16);
    }

    /// Two inputs, two outputs, with runtime parameters passed as a `&[f64]` slice.
    #[test]
    fn test_vhessian_5() {
        // `a`-`d` intentionally mirror the coefficient symbols in the formulas below.
        #[allow(clippy::many_single_char_names)]
        fn f<S: Scalar, V: Vector<S>>(x: &V, p: &[f64]) -> V::DVectorT<S> {
            let a = S::new(p[0]);
            let b = S::new(p[1]);
            let c = S::new(p[2]);
            let d = S::new(p[3]);
            V::DVectorT::from_slice(&[
                a * x.vget(0).powi(2) * x.vget(1) + b * x.vget(1).powi(2),
                c * x.vget(0) * x.vget(1) + d * x.vget(0).powi(2),
            ])
        }
        let p = [1.5, 2.0, 0.8, 3.0];
        let x0 = dvector![1.0, -0.5];
        let hess = |x: &DVector<f64>, p: &[f64]| {
            vec![
                DMatrix::<f64>::from_row_slice(
                    2,
                    2,
                    &[
                        2.0 * p[0] * x[1],
                        2.0 * p[0] * x[0],
                        2.0 * p[0] * x[0],
                        2.0 * p[1],
                    ],
                ),
                DMatrix::<f64>::from_row_slice(2, 2, &[2.0 * p[3], p[2], p[2], 0.0]),
            ]
        };
        get_vhessian!(f, hess_autodiff);
        let hess_eval_autodiff: Vec<DMatrix<f64>> = hess_autodiff(&x0, &p);
        let hess_eval: Vec<DMatrix<f64>> = hess(&x0, &p);
        assert_eq!(hess_eval_autodiff.len(), hess_eval.len());
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[0], hess_eval[0], 16);
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[1], hess_eval[1], 16);
    }

    /// Three inputs, so off-diagonal entries for non-adjacent index pairs such as (0, 2) are
    /// exercised: f(x) = x₀x₁x₂ + x₀²x₂ + x₁³.
    #[test]
    fn test_vhessian_three_inputs() {
        fn f<S: Scalar, V: Vector<S>>(x: &V, _p: &[f64]) -> V::DVectorT<S> {
            V::DVectorT::from_slice(&[x.vget(0) * x.vget(1) * x.vget(2)
                + x.vget(0).powi(2) * x.vget(2)
                + x.vget(1).powi(3)])
        }
        let x0 = vec![1.0, 2.0, 3.0];
        let p = [];
        let hess = |x: &Vec<f64>| {
            vec![Mat::from_row_slice(
                3,
                3,
                &[
                    2.0 * x[2],
                    x[2],
                    x[1] + 2.0 * x[0],
                    x[2],
                    6.0 * x[1],
                    x[0],
                    x[1] + 2.0 * x[0],
                    x[0],
                    0.0,
                ],
            )]
        };
        get_vhessian!(f, hess_autodiff);
        let hess_eval_autodiff: Vec<Mat<f64>> = hess_autodiff(&x0, &p);
        let hess_eval: Vec<Mat<f64>> = hess(&x0);
        assert_eq!(hess_eval_autodiff.len(), hess_eval.len());
        assert_arrays_equal_to_decimal!(hess_eval_autodiff[0], hess_eval[0], 16);
    }

    /// Two inputs, two outputs, with a custom (non-slice) parameter type.
    #[test]
    fn test_vhessian_custom_params() {
        struct Data {
            a: f64,
            b: f64,
            c: f64,
            d: f64,
        }
        // `a`-`d` intentionally mirror the coefficient symbols in the formulas below.
        #[allow(clippy::many_single_char_names)]
        fn f<S: Scalar, V: Vector<S>>(x: &V, p: &Data) -> V::DVectorT<S> {
            let a = S::new(p.a);
            let b = S::new(p.b);
            let c = S::new(p.c);
            let d = S::new(p.d);
            V::DVectorT::from_slice(&[
                a * x.vget(0).powi(2) * x.vget(1) + b * x.vget(1).powi(2),
                c * x.vget(0) * x.vget(1) + d * x.vget(0).powi(2),
            ])
        }
        let p = Data {
            a: 1.5,
            b: 2.0,
            c: 0.8,
            d: 3.0,
        };
        let x0 = vec![1.0, -0.5];
        get_vhessian!(f, hess, Data);
        let hess_f0_true = Mat::from_row_slice(
            2,
            2,
            &[
                2.0 * p.a * x0[1],
                2.0 * p.a * x0[0],
                2.0 * p.a * x0[0],
                2.0 * p.b,
            ],
        );
        let hess_f1_true = Mat::from_row_slice(2, 2, &[2.0 * p.d, p.c, p.c, 0.0]);
        let hess_true = [hess_f0_true, hess_f1_true];
        let hess_eval: Vec<Mat<f64>> = hess(&x0, &p);
        assert_eq!(hess_eval.len(), hess_true.len());
        assert_arrays_equal_to_decimal!(hess_eval[0], hess_true[0], 16);
        assert_arrays_equal_to_decimal!(hess_eval[1], hess_true[1], 16);
    }
}