use crate::error::{NumRs2Error, Result};
use num_traits::Float;
use super::hyperdual::{float_const, HyperDual};
use crate::autodiff::Dual;
/// Computes the gradient of `f` at `x` by evaluating `f` on `Dual` inputs.
///
/// `f` is evaluated once per coordinate of `x`. On pass `i`, every input value
/// is wrapped with `Dual::new(value, seed)`, where `seed` is one for
/// coordinate `i` and zero for all other coordinates. The `deriv()` of the
/// value returned by `f` becomes entry `i` of the gradient.
///
/// # Returns
///
/// A vector with one entry per element of `x`. When `x` is empty the result
/// is empty and `f` is never called.
///
/// # Errors
///
/// This function itself never produces an error; the `Result` wrapper is part
/// of its signature only.
pub fn compute_gradient_ad<F, T>(f: F, x: &[T]) -> Result<Vec<T>>
where
    F: Fn(&[Dual<T>]) -> Dual<T>,
    T: Float,
{
    let gradient: Vec<T> = (0..x.len())
        .map(|seed_idx| {
            // Only the coordinate being differentiated carries a unit seed.
            let seeded: Vec<Dual<T>> = x
                .iter()
                .enumerate()
                .map(|(pos, &value)| {
                    let seed = if pos == seed_idx {
                        T::one()
                    } else {
                        T::zero()
                    };
                    Dual::new(value, seed)
                })
                .collect();
            f(&seeded).deriv()
        })
        .collect();
    Ok(gradient)
}
/// Computes the gradient of `f` at `x` by evaluating `f` on `HyperDual` inputs.
///
/// `f` is evaluated once per coordinate of `x`. On pass `i`, each input is
/// built with `HyperDual::new(value, seed, 0, 0)`, where `seed` is one for
/// coordinate `i` and zero elsewhere. Entry `i` of the gradient is the
/// `eps1()` of the value returned by `f`.
///
/// NOTE(review): the second constructor argument is presumably the `eps1`
/// component that is read back afterwards — confirm against `HyperDual::new`.
///
/// # Returns
///
/// A vector with one entry per element of `x`. When `x` is empty the result
/// is empty and `f` is never called.
///
/// # Errors
///
/// This function itself never produces an error; the `Result` wrapper is part
/// of its signature only.
pub fn compute_gradient_hyperdual<F, T>(f: &F, x: &[T]) -> Result<Vec<T>>
where
    F: Fn(&[HyperDual<T>]) -> HyperDual<T>,
    T: Float,
{
    let dim = x.len();
    let mut gradient = Vec::with_capacity(dim);
    for seed_idx in 0..dim {
        let mut seeded: Vec<HyperDual<T>> = Vec::with_capacity(dim);
        for (pos, &value) in x.iter().enumerate() {
            // Only the coordinate being differentiated carries a unit seed.
            let seed = if pos == seed_idx {
                T::one()
            } else {
                T::zero()
            };
            seeded.push(HyperDual::new(value, seed, T::zero(), T::zero()));
        }
        let output = f(&seeded);
        gradient.push(output.eps1());
    }
    Ok(gradient)
}
/// Approximates the gradient of `f` at `x` with central finite differences.
///
/// For each coordinate `i` the function evaluates `f` at `x` with `x[i]`
/// shifted up and down by a step `h_i` and returns
/// `(f(x + h_i e_i) - f(x - h_i e_i)) / ((x_i + h_i) - (x_i - h_i))`.
///
/// The step is `sqrt(T::epsilon()) * max(|x_i|, 1)`. Scaling by the magnitude
/// of the coordinate keeps the perturbation representable: a fixed absolute
/// step is rounded away entirely once `|x_i|` is large enough, which would
/// make both evaluations identical and the derivative zero. For `|x_i| <= 1`
/// the step is exactly `sqrt(T::epsilon())`.
///
/// # Returns
///
/// A vector with one entry per element of `x`; empty when `x` is empty, in
/// which case `f` is never called. `f` is called exactly twice per
/// coordinate, first with the upward shift and then with the downward shift.
///
/// # Errors
///
/// This function itself never produces an error; the `Result` wrapper is part
/// of its signature only.
pub(super) fn numerical_gradient_central<T: Float>(
    f: &dyn Fn(&[T]) -> T,
    x: &[T],
) -> Result<Vec<T>> {
    let base_step = T::epsilon().sqrt();
    // One scratch copy of `x`, perturbed in place and restored after each
    // coordinate, instead of two fresh copies per coordinate.
    let mut probe = x.to_vec();
    let mut grad = Vec::with_capacity(x.len());
    for (i, &center) in x.iter().enumerate() {
        // Relative step for large coordinates, absolute step near zero.
        let step = base_step * center.abs().max(T::one());
        let upper = center + step;
        let lower = center - step;

        probe[i] = upper;
        let f_plus = f(&probe);
        probe[i] = lower;
        let f_minus = f(&probe);
        probe[i] = center;

        // Divide by the span that was actually applied after rounding rather
        // than the nominal `2 * step`.
        grad.push((f_plus - f_minus) / (upper - lower));
    }
    Ok(grad)
}
/// Outcome of comparing an analytical gradient against a finite-difference
/// gradient, as produced by `gradient_check`.
#[derive(Debug, Clone)]
pub struct GradientCheckResult<T> {
/// Largest per-component absolute difference between the analytical and
/// numerical gradients.
pub max_abs_error: T,
/// Largest per-component relative error. For components whose magnitude is
/// below a small threshold the absolute error is used in place of the
/// relative one.
pub max_rel_error: T,
/// `true` when `max_rel_error` was strictly below the requested tolerance.
pub passed: bool,
/// Absolute difference for each component, in the same order as the inputs.
pub component_errors: Vec<T>,
/// Copy of the analytical gradient that was checked.
pub analytical: Vec<T>,
/// Finite-difference gradient the analytical gradient was compared against.
pub numerical: Vec<T>,
}
/// Checks an analytical gradient against a central finite-difference estimate.
///
/// The numerical gradient of `f` at `x` is obtained from
/// `numerical_gradient_central` and compared component by component with
/// `analytical_gradient`.
///
/// For every component the absolute error `|analytical - numerical|` is
/// recorded. The relative error is the absolute error divided by the larger
/// of the two magnitudes. When that magnitude is below `100 * T::epsilon()`
/// the absolute error is used instead, which avoids dividing by a value that
/// is effectively zero.
///
/// NaN errors are propagated into `max_abs_error` and `max_rel_error` rather
/// than being ignored. A NaN or infinite gradient component therefore makes
/// the check fail, because a NaN maximum never compares below `tolerance`.
///
/// # Arguments
///
/// * `f` - scalar function whose gradient is being verified.
/// * `x` - point at which the gradient is evaluated.
/// * `analytical_gradient` - gradient to verify; must have the same length as `x`.
/// * `tolerance` - the check passes when the maximum relative error is
///   strictly below this value.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when `analytical_gradient` and `x`
/// differ in length. Errors from `numerical_gradient_central` and
/// `float_const` are passed through unchanged.
pub fn gradient_check<F, T>(
    f: F,
    x: &[T],
    analytical_gradient: &[T],
    tolerance: T,
) -> Result<GradientCheckResult<T>>
where
    F: Fn(&[T]) -> T,
    T: Float,
{
    let n = x.len();
    if analytical_gradient.len() != n {
        return Err(NumRs2Error::ShapeMismatch {
            expected: vec![n],
            actual: vec![analytical_gradient.len()],
        });
    }

    let numerical = numerical_gradient_central(&f, x)?;
    // Below this magnitude a relative error is meaningless, so the absolute
    // error is reported in its place.
    let threshold = T::epsilon() * float_const::<T>(100.0)?;

    let mut max_abs = T::zero();
    let mut max_rel = T::zero();
    let mut component_errors = Vec::with_capacity(n);
    for (&expected, &approx) in analytical_gradient.iter().zip(numerical.iter()) {
        let abs_err = (expected - approx).abs();
        component_errors.push(abs_err);
        // `>` is always false for NaN, so NaN is handled explicitly; once a
        // maximum is NaN it stays NaN for the rest of the loop.
        if abs_err > max_abs || abs_err.is_nan() {
            max_abs = abs_err;
        }

        let max_mag = expected.abs().max(approx.abs());
        let rel_err = if max_mag < threshold {
            abs_err
        } else {
            abs_err / max_mag
        };
        if rel_err > max_rel || rel_err.is_nan() {
            max_rel = rel_err;
        }
    }

    Ok(GradientCheckResult {
        max_abs_error: max_abs,
        max_rel_error: max_rel,
        // A NaN `max_rel` compares false here, so the check fails.
        passed: max_rel < tolerance,
        component_errors,
        analytical: analytical_gradient.to_vec(),
        numerical,
    })
}
/// Verifies a `Dual`-based gradient against a finite-difference gradient.
///
/// The gradient of `f_ad` at `x` is first computed with
/// `compute_gradient_ad`. That gradient is then handed to `gradient_check`
/// as the analytical gradient, with `f_numerical` as the function to
/// differentiate numerically.
///
/// # Arguments
///
/// * `f_ad` - the function expressed over `Dual` values.
/// * `f_numerical` - the function expressed over plain `T` values;
///   presumably the same mathematical function as `f_ad`.
/// * `x` - point at which both gradients are evaluated.
/// * `tolerance` - forwarded unchanged to `gradient_check`.
///
/// # Errors
///
/// Returns whatever error `compute_gradient_ad` or `gradient_check` returns.
pub fn gradient_check_ad<F1, F2, T>(
    f_ad: F1,
    f_numerical: F2,
    x: &[T],
    tolerance: T,
) -> Result<GradientCheckResult<T>>
where
    F1: Fn(&[Dual<T>]) -> Dual<T>,
    F2: Fn(&[T]) -> T,
    T: Float,
{
    compute_gradient_ad(f_ad, x)
        .and_then(|ad_gradient| gradient_check(f_numerical, x, &ad_gradient, tolerance))
}