use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use super::hyperdual::HyperDual;
/// Second-order (quadratic) Taylor model of a scalar function of `dim` variables.
///
/// The model represented is
/// `value + Σ_i gradient[i]·h_i + ½ Σ_i Σ_j H[i][j]·h_i·h_j`, with `h = x - center`.
///
/// All fields are public, so callers that mutate them are responsible for
/// keeping the vector lengths consistent with `dim`.
#[derive(Clone, Debug)]
pub struct TaylorExpansion2<T> {
    /// Expansion point; displacements are measured from here.
    pub center: Vec<T>,
    /// Constant term of the model (the function value at `center`).
    pub value: T,
    /// First-order coefficients, one per variable.
    pub gradient: Vec<T>,
    /// Second-order coefficients stored row-major: entry `(i, j)` lives at
    /// index `i * dim + j`.
    pub hessian_flat: Vec<T>,
    /// Number of variables.
    pub dim: usize,
}
impl<T: Float> TaylorExpansion2<T> {
    /// Evaluates the quadratic model at `x`.
    ///
    /// With `h = x - center`, the result is
    /// `value + Σ_i gradient[i]·h_i + Σ_i Σ_j H[i][j]·h_i·h_j / 2`.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ShapeMismatch`] when `x` does not have `dim`
    /// elements, or when the public fields have been left inconsistent with
    /// `dim` (`center`/`gradient` not of length `dim`, or `hessian_flat` not of
    /// length `dim * dim`).
    pub fn evaluate(&self, x: &[T]) -> Result<T> {
        let n = self.dim;
        let mismatch = |expected: Vec<usize>, actual: usize| -> Result<T> {
            Err(NumRs2Error::ShapeMismatch {
                expected,
                actual: vec![actual],
            })
        };

        if x.len() != n {
            return mismatch(vec![n], x.len());
        }
        // The fields are public, so nothing guarantees they still agree with
        // `dim`; report that as an error instead of panicking on an index below.
        if self.center.len() != n {
            return mismatch(vec![n], self.center.len());
        }
        if self.gradient.len() != n {
            return mismatch(vec![n], self.gradient.len());
        }
        if n.checked_mul(n) != Some(self.hessian_flat.len()) {
            return mismatch(vec![n, n], self.hessian_flat.len());
        }

        // Displacement from the expansion point.
        let h: Vec<T> = x
            .iter()
            .zip(&self.center)
            .map(|(&xi, &ci)| xi - ci)
            .collect();

        // Constant and first-order terms.
        let mut result = self.value;
        for (&g, &hi) in self.gradient.iter().zip(&h) {
            result = result + g * hi;
        }

        // Second-order term, accumulated row by row in the same order as the
        // flat storage so the floating-point summation order is stable.
        let two = T::one() + T::one();
        for (i, &hi) in h.iter().enumerate() {
            let row = &self.hessian_flat[i * n..(i + 1) * n];
            for (&hij, &hj) in row.iter().zip(&h) {
                result = result + hij * hi * hj / two;
            }
        }
        Ok(result)
    }

    /// Returns the Hessian entry at row `i`, column `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not smaller than `dim`. Without this check an
    /// out-of-range column could silently alias an entry of another row in the
    /// flat row-major storage.
    pub fn hessian_element(&self, i: usize, j: usize) -> T {
        assert!(
            i < self.dim && j < self.dim,
            "Hessian index ({}, {}) out of range for dimension {}",
            i,
            j,
            self.dim
        );
        self.hessian_flat[i * self.dim + j]
    }
}
/// Builds the second-order Taylor expansion of `f` around `center`.
///
/// `f` is evaluated once for every index pair `(i, j)` with `i <= j`. For each
/// evaluation, input `i` is created with the first `make_variable` flag set and
/// input `j` with the second flag set (presumably seeding the `eps1` and `eps2`
/// perturbations — confirm against `HyperDual::make_variable`). From each result:
///
/// * `real()` of the `(0, 0)` evaluation becomes the expansion value,
/// * `eps1()` of each diagonal `(i, i)` evaluation becomes `gradient[i]`,
/// * `eps1eps2()` becomes the Hessian entry `(i, j)` and its mirror `(j, i)`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidInput`] when `center` is empty.
pub fn multivariate_taylor<F, T>(f: F, center: &[T]) -> Result<TaylorExpansion2<T>>
where
    F: Fn(&[HyperDual<T>]) -> HyperDual<T>,
    T: Float,
{
    if center.is_empty() {
        return Err(NumRs2Error::InvalidInput(
            "Center point must be non-empty".to_string(),
        ));
    }
    let dim = center.len();

    // Start from an all-zero model and fill it in as evaluations come back.
    let mut expansion = TaylorExpansion2 {
        center: center.to_vec(),
        value: T::zero(),
        gradient: vec![T::zero(); dim],
        hessian_flat: vec![T::zero(); dim * dim],
        dim,
    };

    for row in 0..dim {
        // Only the upper triangle is evaluated; the lower one is mirrored.
        for col in row..dim {
            let seeded: Vec<HyperDual<T>> = center
                .iter()
                .enumerate()
                .map(|(k, &c)| HyperDual::make_variable(c, k == row, k == col))
                .collect();
            let out = f(&seeded);

            if row == col {
                if row == 0 {
                    expansion.value = out.real();
                }
                expansion.gradient[row] = out.eps1();
            }

            let mixed = out.eps1eps2();
            expansion.hessian_flat[row * dim + col] = mixed;
            expansion.hessian_flat[col * dim + row] = mixed;
        }
    }

    Ok(expansion)
}