use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use super::hyperdual::HyperDual;
use crate::autodiff::{Tape, Var};
/// Builds the dense `n x n` Hessian of `f` at `x` from hyper-dual evaluations.
///
/// For every index pair `(i, j)` with `i <= j`, `f` is evaluated once on inputs
/// created with `HyperDual::make_variable(x[k], k == i, k == j)`, and the
/// `eps1eps2()` component of the result is written to both `(i, j)` and `(j, i)`.
/// This costs `n * (n + 1) / 2` evaluations of `f`, and the returned matrix is
/// symmetric by construction.
///
/// NOTE(review): the two boolean flags presumably seed the first and second
/// infinitesimal directions of the hyper-dual number — confirm against `HyperDual`.
///
/// # Arguments
/// * `f` - Scalar function expressed over hyper-dual inputs.
/// * `x` - Point at which the Hessian is evaluated.
///
/// # Errors
/// Returns `NumRs2Error::InvalidInput` when `x` is empty.
pub fn hessian_exact<F, T>(f: F, x: &[T]) -> Result<Array<T>>
where
    F: Fn(&[HyperDual<T>]) -> HyperDual<T>,
    T: Float,
{
    if x.is_empty() {
        return Err(NumRs2Error::InvalidInput(
            "Input vector must be non-empty for Hessian computation".to_string(),
        ));
    }
    let dim = x.len();
    let mut entries = vec![T::zero(); dim * dim];

    // Walk the upper triangle (diagonal included) in row-major order.
    let upper_triangle = (0..dim).flat_map(|row| (row..dim).map(move |col| (row, col)));
    for (row, col) in upper_triangle {
        let seeded: Vec<HyperDual<T>> = x
            .iter()
            .enumerate()
            .map(|(k, &value)| HyperDual::make_variable(value, k == row, k == col))
            .collect();
        let second_order = f(&seeded).eps1eps2();

        // Mirror the entry across the diagonal instead of re-evaluating `f`.
        entries[row * dim + col] = second_order;
        entries[col * dim + row] = second_order;
    }

    Ok(Array::from_vec(entries).reshape(&[dim, dim]))
}
/// Computes the Hessian-vector product `H(x) * v` without forming `H`.
///
/// Component `i` of the result comes from a single evaluation of `f` on inputs
/// built as `HyperDual::new(x[k], e_i[k], v[k], 0)`, where `e_i[k]` is one when
/// `k == i` and zero otherwise; the `eps1eps2()` part of the output is taken as
/// that component. `f` is therefore evaluated `x.len()` times.
///
/// NOTE(review): this assumes the `HyperDual::new` arguments are, in order, the
/// real part, the first and second infinitesimal parts, and the mixed part —
/// confirm against `HyperDual`.
///
/// # Arguments
/// * `f` - Scalar function expressed over hyper-dual inputs.
/// * `x` - Point at which the Hessian is evaluated.
/// * `v` - Vector to multiply by; must have the same length as `x`.
///
/// # Errors
/// * `NumRs2Error::InvalidInput` when `x` is empty.
/// * `NumRs2Error::ShapeMismatch` when `v.len() != x.len()`.
pub fn hessian_vector_product<F, T>(f: F, x: &[T], v: &[T]) -> Result<Vec<T>>
where
    F: Fn(&[HyperDual<T>]) -> HyperDual<T>,
    T: Float,
{
    let dim = x.len();
    if dim == 0 {
        return Err(NumRs2Error::InvalidInput(
            "Input vector must be non-empty".to_string(),
        ));
    }
    if v.len() != dim {
        return Err(NumRs2Error::ShapeMismatch {
            expected: vec![dim],
            actual: vec![v.len()],
        });
    }

    let product: Vec<T> = (0..dim)
        .map(|component| {
            // Lengths were validated above, so zipping `x` with `v` visits every index.
            let seeded: Vec<HyperDual<T>> = x
                .iter()
                .zip(v.iter())
                .enumerate()
                .map(|(k, (&point, &direction))| {
                    let unit = if k == component {
                        T::one()
                    } else {
                        T::zero()
                    };
                    HyperDual::new(point, unit, direction, T::zero())
                })
                .collect();
            f(&seeded).eps1eps2()
        })
        .collect();

    Ok(product)
}
/// Approximates the `n x n` Hessian of `f` at `x` by central differences of
/// tape-computed gradients.
///
/// Despite the name, no forward-mode pass happens in this function. For each
/// coordinate `i`, the gradient is obtained from a fresh [`Tape`] (record `f`,
/// call `backward`, read `grad` for every input) at `x + h_i * e_i` and
/// `x - h_i * e_i`, and row `i` of the result is the difference of the two
/// gradients divided by the distance between the two evaluation points.
/// `f` is evaluated `2 * n` times.
///
/// The step is `h_i = sqrt(T::epsilon()) * max(|x[i]|, 1)`. Scaling by the
/// magnitude of `x[i]` keeps the perturbation representable: with a fixed
/// absolute step, `x[i] + h` and `x[i] - h` both round back to `x[i]` once
/// `|x[i]|` is large, which silently produces an all-zero row.
///
/// The result is a finite-difference approximation and is not symmetrized, so
/// entries `(i, j)` and `(j, i)` may differ by rounding/truncation error.
///
/// # Arguments
/// * `f` - Scalar function recorded onto the supplied tape from the input variables.
/// * `x` - Point at which the Hessian is approximated.
///
/// # Errors
/// Returns `NumRs2Error::InvalidInput` when `x` is empty.
pub fn hessian_forward_over_reverse<F, T>(f: F, x: &[T]) -> Result<Array<T>>
where
    F: Fn(&mut Tape<T>, &[Var]) -> Var,
    T: Float,
{
    let n = x.len();
    if n == 0 {
        return Err(NumRs2Error::InvalidInput(
            "Input vector must be non-empty for Hessian computation".to_string(),
        ));
    }
    let base_step = T::epsilon().sqrt();

    // Gradient of `f` at `point`, computed on a tape that is discarded afterwards.
    let compute_gradient = |point: &[T]| -> Vec<T> {
        let mut tape = Tape::new();
        let vars: Vec<Var> = point.iter().map(|&v| tape.var(v)).collect();
        let output = f(&mut tape, &vars);
        tape.backward(output);
        vars.iter().map(|&v| tape.grad(v)).collect()
    };

    let mut hess_data = Vec::with_capacity(n * n);
    // Perturbation buffers are reused across rows; only coordinate `i` is
    // modified per iteration and it is restored before moving on.
    let mut x_plus = x.to_vec();
    let mut x_minus = x.to_vec();
    for i in 0..n {
        let step = base_step * x[i].abs().max(T::one());
        x_plus[i] = x[i] + step;
        x_minus[i] = x[i] - step;
        // Divide by the step actually realised in floating point rather than
        // the nominal `2 * step`, which removes the rounding error of `x ± step`.
        let span = x_plus[i] - x_minus[i];

        let grad_plus = compute_gradient(&x_plus);
        let grad_minus = compute_gradient(&x_minus);
        hess_data.extend(
            grad_plus
                .iter()
                .zip(grad_minus.iter())
                .map(|(&plus, &minus)| (plus - minus) / span),
        );

        x_plus[i] = x[i];
        x_minus[i] = x[i];
    }
    Ok(Array::from_vec(hess_data).reshape(&[n, n]))
}