use crate::error::AutogradError;
use crate::forward_mode::DualNumber;
use crate::tensor::Tensor;
use crate::{Context, Float, Result};
use num::Float as NumFloat;
use scirs2_core::ndarray::{Array1, Array2};
use std::fmt;
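
/// Computes the diagonal of the Hessian of a scalar function `f` with
/// respect to `x` via `n` Hessian-vector products against the standard
/// basis vectors `e_i`, keeping only the `i`-th entry of each product.
/// This costs `n` HVP evaluations but never materializes the full
/// `n x n` Hessian. Returns a tensor of shape `[n]`.
///
/// A minimal usage sketch, mirroring `test_hessian_diagonal_standalone`
/// at the bottom of this file (the `run`/`placeholder` calls are this
/// crate's graph API):
///
/// ```ignore
/// crate::run(|ctx: &mut Context<f64>| {
///     let x = ctx.placeholder("x", &[3]);
///     // f(x) = sum(x_i^2), so the Hessian diagonal is [2, 2, 2].
///     let f = crate::tensor_ops::reduction::sum_all(x * x);
///     let diag = hessian_diagonal(&f, &x, 3, ctx).expect("scalar f");
/// });
/// ```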
pub fn hessian_diagonal<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
n: usize,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
validate_scalar(f)?;
if n == 0 {
return Err(AutogradError::shape_error(
"Dimension n must be positive".to_string(),
));
}
let mut diag_elements = Vec::with_capacity(n);
for i in 0..n {
let mut e_i_vec = vec![T::zero(); n];
e_i_vec[i] = T::one();
let e_i_arr = Array1::from(e_i_vec).into_dyn();
let e_i = crate::tensor_ops::convert_to_tensor(e_i_arr, ctx);
let hvp_i = super::hessian_vector_product(f, x, &e_i, ctx)?;
let hvp_i_flat = crate::tensor_ops::flatten(hvp_i);
let diag_elem = crate::tensor_ops::slice(hvp_i_flat, [i as isize], [(i + 1) as isize]);
diag_elements.push(diag_elem);
}
Ok(crate::tensor_ops::linear_algebra::concat(&diag_elements, 0))
}
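
/// Computes the Laplacian `Δf = Σ_i ∂²f/∂x_i² = tr(∇²f)` of a scalar
/// function by summing the Hessian diagonal from [`hessian_diagonal`].
/// Returns a scalar tensor.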
pub fn laplacian<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
n: usize,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
let diag = hessian_diagonal(f, x, n, ctx)?;
Ok(crate::tensor_ops::reduction::sum_all(diag))
}
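
/// Computes the outer-product (empirical) Fisher information at a single
/// point: `F = s sᵀ`, where `s = ∇_params log_prob` is the score vector.
/// The `n x n` matrix is returned flattened to shape `[n * n]` in
/// row-major order; averaging over data samples is left to the caller.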
pub fn fisher_information<'graph, T: Float>(
log_prob: &Tensor<'graph, T>,
params: &Tensor<'graph, T>,
n: usize,
_ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
validate_scalar(log_prob)?;
if n == 0 {
return Err(AutogradError::shape_error(
"Number of parameters must be positive".to_string(),
));
}
let score = crate::tensor_ops::grad(&[*log_prob], &[*params])[0];
let score_flat = crate::tensor_ops::flatten(score);
let mut fisher_rows = Vec::with_capacity(n);
for i in 0..n {
let score_i = crate::tensor_ops::slice(score_flat, [i as isize], [(i + 1) as isize]);
let row = score_i * score_flat;
fisher_rows.push(row);
}
Ok(crate::tensor_ops::linear_algebra::concat(&fisher_rows, 0))
}
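
/// Computes the diagonal of the outer-product Fisher information,
/// `diag(s sᵀ) = s ⊙ s`: the elementwise square of the score. The result
/// has the same shape as `params`.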
pub fn fisher_diagonal<'graph, T: Float>(
log_prob: &Tensor<'graph, T>,
params: &Tensor<'graph, T>,
_ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
validate_scalar(log_prob)?;
let score = crate::tensor_ops::grad(&[*log_prob], &[*params])[0];
Ok(score * score)
}
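
/// Computes a curvature-vector product for a residual vector `r(x)` with
/// `m` components: `u = Jv` is assembled row by row from the Jacobian
/// `J = ∂r/∂x`, and the gradient of `rᵀu` with respect to `x` is returned.
///
/// Note that `u` is itself built from graph `grad` nodes and is not
/// detached, so the final backward pass differentiates through it too:
/// the result equals the Hessian-vector product of `½‖r‖²`, i.e.
/// `JᵀJv + Σ_i r_i (∇²r_i) v`, which reduces to the pure Gauss-Newton
/// product `JᵀJv` only as the residuals approach zero.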
pub fn gauss_newton_product<'graph, T: Float>(
residual: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
v: &Tensor<'graph, T>,
m: usize,
_ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
let r_flat = crate::tensor_ops::flatten(*residual);
let mut jv_elements = Vec::with_capacity(m);
for i in 0..m {
let r_i = crate::tensor_ops::slice(r_flat, [i as isize], [(i + 1) as isize]);
let grad_ri = crate::tensor_ops::grad(&[r_i], &[*x])[0];
let jv_i = crate::tensor_ops::reduction::sum_all(grad_ri * *v);
jv_elements.push(jv_i);
}
let jv = crate::tensor_ops::linear_algebra::concat(&jv_elements, 0);
let jv_flat = crate::tensor_ops::flatten(jv);
let weighted = crate::tensor_ops::reduction::sum_all(jv_flat * r_flat);
let gn_product = crate::tensor_ops::grad(&[weighted], &[*x])[0];
Ok(gn_product)
}
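
/// Computes the exact Hessian diagonal `∂²f/∂x_i²` by grad-of-grad: one
/// backward pass for `∇f`, then one further backward pass per gradient
/// component. The result matches [`hessian_diagonal`]; only the
/// formulation (repeated differentiation instead of Hessian-vector
/// products) differs.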
pub fn efficient_second_order_grad<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
n: usize,
_ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
validate_scalar(f)?;
let grad_f = crate::tensor_ops::grad(&[*f], &[*x])[0];
let grad_f_flat = crate::tensor_ops::flatten(grad_f);
let mut second_derivs = Vec::with_capacity(n);
for i in 0..n {
let grad_i = crate::tensor_ops::slice(grad_f_flat, [i as isize], [(i + 1) as isize]);
let d2_i = crate::tensor_ops::grad(&[grad_i], &[*x])[0];
let d2_i_flat = crate::tensor_ops::flatten(d2_i);
let diag_i = crate::tensor_ops::slice(d2_i_flat, [i as isize], [(i + 1) as isize]);
second_derivs.push(diag_i);
}
Ok(crate::tensor_ops::linear_algebra::concat(&second_derivs, 0))
}
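
/// Computes the Hessian diagonal of `f` at `x` without building a graph,
/// by central-differencing exact forward-mode first derivatives:
/// `∂²f/∂x_i² ≈ (∂_i f(x + ε e_i) - ∂_i f(x - ε e_i)) / (2ε)`, where each
/// `∂_i f` comes from a dual-number evaluation with tangent `e_i`.
/// Truncation error is `O(ε²)`.
///
/// A small sketch, mirroring `test_hessian_diagonal_forward_quadratic`
/// below:
///
/// ```ignore
/// // f(x) = 2*x0^2 + 3*x1^2 has Hessian diagonal [4, 6] everywhere.
/// let f = |xs: &[DualNumber<f64>]| {
///     DualNumber::constant(2.0) * xs[0] * xs[0]
///         + DualNumber::constant(3.0) * xs[1] * xs[1]
/// };
/// let diag = hessian_diagonal_forward(&f, &Array1::from(vec![1.0, 1.0]));
/// ```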
pub fn hessian_diagonal_forward<F, Func>(f: &Func, x: &Array1<F>) -> Array1<F>
where
F: NumFloat + Copy + Send + Sync + fmt::Debug + 'static,
Func: Fn(&[DualNumber<F>]) -> DualNumber<F>,
{
let n = x.len();
let mut diag = Vec::with_capacity(n);
// A step of ~cbrt(machine epsilon) is near-optimal for a central
// difference of an exactly computed first derivative (truncation O(ε²)
// vs. rounding O(ε_mach / ε)); a fixed 1e-7 is too small for f32.
let eps = F::epsilon().cbrt();
let two = F::one() + F::one();
for i in 0..n {
let mut dual_plus = Vec::with_capacity(n);
let mut dual_minus = Vec::with_capacity(n);
for j in 0..n {
let tangent = if i == j { F::one() } else { F::zero() };
let x_plus = if i == j { x[j] + eps } else { x[j] };
let x_minus = if i == j { x[j] - eps } else { x[j] };
dual_plus.push(DualNumber::new(x_plus, tangent));
dual_minus.push(DualNumber::new(x_minus, tangent));
}
let f_plus = f(&dual_plus);
let f_minus = f(&dual_minus);
let d2 = (f_plus.tangent() - f_minus.tangent()) / (two * eps);
diag.push(d2);
}
Array1::from(diag)
}
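
/// Computes the Laplacian `Δf = tr(∇²f)` at `x` by summing the
/// forward-mode Hessian diagonal from [`hessian_diagonal_forward`].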
pub fn laplacian_forward<F, Func>(f: &Func, x: &Array1<F>) -> F
where
F: NumFloat + Copy + Send + Sync + fmt::Debug + 'static,
Func: Fn(&[DualNumber<F>]) -> DualNumber<F>,
{
let diag = hessian_diagonal_forward(f, x);
diag.iter().fold(F::zero(), |acc, &v| acc + v)
}
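
/// Computes the outer-product (empirical) Fisher information matrix
/// `F = g gᵀ` with `g = ∇_θ log_prob(θ)` obtained from an exact
/// forward-mode gradient. Returns a dense `(n, n)` array; the expectation
/// over data must be taken by the caller.
///
/// A small sketch, mirroring `test_fisher_information_forward_mode` below
/// (for a standard normal, `log p ∝ -½‖θ‖²`, so `g = -θ` and `F = θθᵀ`):
///
/// ```ignore
/// let log_prob = |xs: &[DualNumber<f64>]| {
///     DualNumber::constant(-0.5) * (xs[0] * xs[0] + xs[1] * xs[1])
/// };
/// let fisher = fisher_information_forward(&log_prob, &Array1::from(vec![3.0, 4.0]));
/// assert!((fisher[[0, 1]] - 12.0).abs() < 1e-5);
/// ```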
pub fn fisher_information_forward<F, Func>(log_prob: &Func, theta: &Array1<F>) -> Array2<F>
where
F: NumFloat + Copy + Send + Sync + fmt::Debug + 'static,
Func: Fn(&[DualNumber<F>]) -> DualNumber<F>,
{
let n = theta.len();
let grad = crate::forward_mode::gradient_forward(log_prob, theta);
let mut fisher = Array2::zeros((n, n));
for i in 0..n {
for j in 0..n {
fisher[[i, j]] = grad[i] * grad[j];
}
}
fisher
}
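
/// Ensures `f` is scalar-valued: shape `[]` (0-d) or `[1]` is accepted,
/// anything else is a shape error.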
fn validate_scalar<T: Float>(f: &Tensor<T>) -> Result<()> {
let shape = f.shape();
if shape.len() > 1 || (shape.len() == 1 && shape[0] != 1) {
return Err(AutogradError::shape_error(
"Function must be scalar (shape [] or [1])".to_string(),
));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_ops::*;
use scirs2_core::ndarray::Array1;
#[test]
fn test_hessian_diagonal_standalone() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let x_sq = x * x;
let coeffs = convert_to_tensor(Array1::from(vec![1.0, 2.0, 3.0]).into_dyn(), ctx);
let f = crate::tensor_ops::reduction::sum_all(coeffs * x_sq);
let diag = hessian_diagonal(&f, &x, 3, ctx).expect("Should compute diagonal");
let x_val = scirs2_core::ndarray::arr1(&[1.0, 1.0, 1.0]);
let result = ctx
.evaluator()
.push(&diag)
.feed(x, x_val.view().into_dyn())
.run();
let diag_arr = result[0].as_ref().expect("Should evaluate");
let diag_vals = diag_arr.as_slice().expect("Result should be contiguous");
assert!((diag_vals[0] - 2.0).abs() < 1e-5);
assert!((diag_vals[1] - 4.0).abs() < 1e-5);
assert!((diag_vals[2] - 6.0).abs() < 1e-5);
});
}
#[test]
fn test_laplacian_graph() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let f = crate::tensor_ops::reduction::sum_all(x * x);
let lap = laplacian(&f, &x, 3, ctx).expect("Should compute Laplacian");
let x_val = scirs2_core::ndarray::arr1(&[1.0, 2.0, 3.0]);
let result = ctx
.evaluator()
.push(&lap)
.feed(x, x_val.view().into_dyn())
.run();
let lap_val = result[0]
.as_ref()
.expect("Should evaluate")
.first()
.copied()
.unwrap_or(0.0);
assert!((lap_val - 6.0).abs() < 1e-5);
});
}
#[test]
fn test_fisher_information_graph() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let neg_half = convert_to_tensor(scirs2_core::ndarray::arr0(-0.5).into_dyn(), ctx);
let log_prob = neg_half * crate::tensor_ops::reduction::sum_all(x * x);
let fisher = fisher_information(&log_prob, &x, 2, ctx).expect("Should compute Fisher");
let x_val = scirs2_core::ndarray::arr1(&[3.0, 4.0]);
let result = ctx
.evaluator()
.push(&fisher)
.feed(x, x_val.view().into_dyn())
.run();
let f_arr = result[0].as_ref().expect("Should evaluate");
let f_vals = f_arr.as_slice().expect("Result should be contiguous");
assert!((f_vals[0] - 9.0).abs() < 1e-5);
assert!((f_vals[1] - 12.0).abs() < 1e-5);
assert!((f_vals[2] - 12.0).abs() < 1e-5);
assert!((f_vals[3] - 16.0).abs() < 1e-5);
});
}
#[test]
fn test_fisher_diagonal_graph() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let neg_half = convert_to_tensor(scirs2_core::ndarray::arr0(-0.5).into_dyn(), ctx);
let log_prob = neg_half * crate::tensor_ops::reduction::sum_all(x * x);
let fisher_d =
fisher_diagonal(&log_prob, &x, ctx).expect("Should compute Fisher diagonal");
let x_val = scirs2_core::ndarray::arr1(&[3.0, 4.0]);
let result = ctx
.evaluator()
.push(&fisher_d)
.feed(x, x_val.view().into_dyn())
.run();
let f_arr = result[0].as_ref().expect("Should evaluate");
let f_vals = f_arr.as_slice().expect("Result should be contiguous");
assert!((f_vals[0] - 9.0).abs() < 1e-5);
assert!((f_vals[1] - 16.0).abs() < 1e-5);
});
}
#[test]
fn test_hessian_diagonal_forward_quadratic() {
let f = |xs: &[DualNumber<f64>]| {
let two = DualNumber::constant(2.0);
let three = DualNumber::constant(3.0);
two * xs[0] * xs[0] + three * xs[1] * xs[1]
};
let x = Array1::from(vec![1.0, 1.0]);
let diag = hessian_diagonal_forward(&f, &x);
assert!((diag[0] - 4.0).abs() < 1e-3);
assert!((diag[1] - 6.0).abs() < 1e-3);
}
#[test]
fn test_laplacian_forward_mode() {
let f = |xs: &[DualNumber<f64>]| xs[0] * xs[0] + xs[1] * xs[1] + xs[2] * xs[2];
let x = Array1::from(vec![1.0, 2.0, 3.0]);
let lap = laplacian_forward(&f, &x);
assert!((lap - 6.0).abs() < 1e-3);
}
#[test]
fn test_fisher_information_forward_mode() {
let log_prob = |xs: &[DualNumber<f64>]| {
let neg_half = DualNumber::constant(-0.5);
neg_half * (xs[0] * xs[0] + xs[1] * xs[1])
};
let theta = Array1::from(vec![3.0, 4.0]);
let fisher = fisher_information_forward(&log_prob, &theta);
assert!((fisher[[0, 0]] - 9.0).abs() < 1e-5);
assert!((fisher[[0, 1]] - 12.0).abs() < 1e-5);
assert!((fisher[[1, 0]] - 12.0).abs() < 1e-5);
assert!((fisher[[1, 1]] - 16.0).abs() < 1e-5);
}
#[test]
fn test_efficient_second_order_grad() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let x_sq = x * x;
let coeffs = convert_to_tensor(Array1::from(vec![1.0, 2.0]).into_dyn(), ctx);
let f = crate::tensor_ops::reduction::sum_all(coeffs * x_sq);
let d2 = efficient_second_order_grad(&f, &x, 2, ctx)
.expect("Should compute second-order grad");
let x_val = scirs2_core::ndarray::arr1(&[1.0, 1.0]);
let result = ctx
.evaluator()
.push(&d2)
.feed(x, x_val.view().into_dyn())
.run();
let d2_arr = result[0].as_ref().expect("Should evaluate");
let d2_vals = d2_arr.as_slice().expect("Result should be contiguous");
assert!((d2_vals[0] - 2.0).abs() < 1e-5);
assert!((d2_vals[1] - 4.0).abs() < 1e-5);
});
}
#[test]
fn test_validate_scalar_rejects_non_scalar() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[3]);
let result = hessian_diagonal(&x, &x, 3, ctx);
assert!(result.is_err(), "non-scalar objective should be rejected");
});
}
#[test]
fn test_hessian_diagonal_dimension_zero_error() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[]);
let f = x * x;
let result = hessian_diagonal(&f, &x, 0, ctx);
assert!(result.is_err());
});
}
}