use crate::error::AutogradError;
use crate::tensor::Tensor;
use crate::{Context, Float, Result};
/// Computes the full Hessian of a scalar-valued function `f` at `x`.
///
/// Row `i` of the Hessian is obtained as the Hessian-vector product
/// `H · e_i` with the i-th standard basis vector via [`hvp`]. The rows are
/// flattened and concatenated along axis 0, so the result is a flat tensor
/// of `n * n` elements in row-major order (callers index it as
/// `H[i * n + j]`).
///
/// # Errors
/// Returns a shape error when `n == 0`, and propagates any error from the
/// underlying [`hvp`] calls.
pub fn hessian_matrix<'graph, F, Func>(
    f: Func,
    x: &Tensor<'graph, F>,
    ctx: &'graph Context<'graph, F>,
    n: usize,
) -> Result<Tensor<'graph, F>>
where
    F: Float,
    Func: Fn(&Tensor<'graph, F>) -> Tensor<'graph, F> + Copy,
{
    if n == 0 {
        return Err(AutogradError::shape_error(
            "hessian_matrix: dimension n must be positive".to_string(),
        ));
    }
    let mut hessian_rows = Vec::with_capacity(n);
    for i in 0..n {
        // Build the i-th standard basis vector e_i as a graph constant.
        let mut e_i_vec = vec![F::zero(); n];
        e_i_vec[i] = F::one();
        let e_i_arr = scirs2_core::ndarray::Array1::from(e_i_vec).into_dyn();
        let e_i = crate::tensor_ops::convert_to_tensor(e_i_arr, ctx);
        // Row i of the Hessian is H · e_i. `f` is `Copy`, so it is passed
        // directly instead of wrapping it in a redundant `|t| f(t)` closure.
        let h_row = hvp(f, x, &e_i, ctx)?;
        hessian_rows.push(crate::tensor_ops::flatten(h_row));
    }
    Ok(crate::tensor_ops::linear_algebra::concat(&hessian_rows, 0))
}
/// Hessian-vector product `∇²f(x) · v`, computed with two reverse-mode
/// passes: first `g = ∇f(x)`, then the gradient of `⟨g, v⟩` with respect
/// to `x` (where `v` is treated as a constant).
///
/// # Errors
/// Returns a shape error when `x` and `v` have different shapes.
pub fn hvp<'graph, F, Func>(
    f: Func,
    x: &Tensor<'graph, F>,
    v: &Tensor<'graph, F>,
    ctx: &'graph Context<'graph, F>,
) -> Result<Tensor<'graph, F>>
where
    F: Float,
    Func: Fn(&Tensor<'graph, F>) -> Tensor<'graph, F>,
{
    let xs = x.shape();
    let vs = v.shape();
    if xs != vs {
        return Err(AutogradError::shape_error(format!(
            "hvp: x shape {:?} and v shape {:?} must match",
            xs, vs
        )));
    }
    // First backward pass: gradient of the scalar objective.
    let gradient = crate::tensor_ops::grad(&[f(x)], &[*x])[0];
    // Second backward pass through g · v yields H · v.
    Ok(crate::tensor_ops::grad(&[gradient * *v], &[*x])[0])
}
/// Computes the Jacobian of `f` at `x` one output component at a time.
///
/// The gradient of each of the `m` (flattened) output components is
/// flattened and concatenated along axis 0, producing a flat tensor of
/// `m * n` elements in row-major order (`J[i * n + j] = ∂y_i/∂x_j`).
///
/// # Errors
/// Returns a shape error when `m == 0` or `n == 0`.
pub fn jacobian_matrix<'graph, F, Func>(
    f: Func,
    x: &Tensor<'graph, F>,
    ctx: &'graph Context<'graph, F>,
    m: usize,
    n: usize,
) -> Result<Tensor<'graph, F>>
where
    F: Float,
    Func: Fn(&Tensor<'graph, F>) -> Tensor<'graph, F>,
{
    if m == 0 || n == 0 {
        return Err(AutogradError::shape_error(
            "jacobian_matrix: dimensions m and n must be positive".to_string(),
        ));
    }
    let flat_output = crate::tensor_ops::flatten(f(x));
    // One backward pass per output component; each gradient becomes a row.
    let rows: Vec<_> = (0..m)
        .map(|row| {
            let start = row as isize;
            let component = crate::tensor_ops::slice(flat_output, [start], [start + 1]);
            crate::tensor_ops::flatten(crate::tensor_ops::grad(&[component], &[*x])[0])
        })
        .collect();
    Ok(crate::tensor_ops::linear_algebra::concat(&rows, 0))
}
/// Jacobian-vector product `J_f(x) · v` built from reverse-mode gradients.
///
/// `m` is the (flattened) output dimension of `f`. For `m == 1` the result
/// is the scalar `∇f(x) · v`; otherwise it is a length-`m` vector whose
/// i-th entry is `⟨∇y_i(x), v⟩`.
///
/// # Errors
/// Returns a shape error when `x` and `v` shapes differ or when `m == 0`.
pub fn jvp<'graph, F, Func>(
    f: Func,
    x: &Tensor<'graph, F>,
    v: &Tensor<'graph, F>,
    ctx: &'graph Context<'graph, F>,
    m: usize,
) -> Result<Tensor<'graph, F>>
where
    F: Float,
    Func: Fn(&Tensor<'graph, F>) -> Tensor<'graph, F>,
{
    let x_shape = x.shape();
    let v_shape = v.shape();
    if x_shape != v_shape {
        return Err(AutogradError::shape_error(format!(
            "jvp: x shape {:?} and v shape {:?} must match",
            x_shape, v_shape
        )));
    }
    if m == 0 {
        return Err(AutogradError::shape_error(
            "jvp: output dimension m must be positive".to_string(),
        ));
    }
    let y = f(x);
    if m == 1 {
        // Scalar output: the JVP reduces to the dot product ∇f(x) · v.
        let g = crate::tensor_ops::grad(&[y], &[*x])[0];
        let gv = g * *v;
        // If gv is already rank-0 there is nothing to reduce; otherwise sum
        // all elements. (Check the rank directly rather than building a
        // throwaway Vec of axis indices just to test it for emptiness.)
        let jvp_scalar = if gv.shape().is_empty() {
            gv
        } else {
            crate::tensor_ops::reduction::sum_all(gv)
        };
        return Ok(jvp_scalar);
    }
    let y_flat = crate::tensor_ops::flatten(y);
    let v_flat = crate::tensor_ops::flatten(*v);
    let mut jvp_elements = Vec::with_capacity(m);
    for i in 0..m {
        // Differentiate the i-th output component in isolation.
        let y_i = crate::tensor_ops::slice(y_flat, [i as isize], [(i + 1) as isize]);
        let grad_i = crate::tensor_ops::grad(&[y_i], &[*x])[0];
        let grad_i_flat = crate::tensor_ops::flatten(grad_i);
        // ⟨∇y_i, v⟩ is the i-th entry of J·v; reshape to 1-D so the
        // entries can be concatenated into a vector below.
        let jvp_i = crate::tensor_ops::reduction::sum_all(grad_i_flat * v_flat);
        let jvp_i_1d = crate::tensor_ops::reshape(jvp_i, &[1_isize]);
        jvp_elements.push(jvp_i_1d);
    }
    Ok(crate::tensor_ops::linear_algebra::concat(&jvp_elements, 0))
}
/// Vector-Jacobian product `vᵀ · J_f(x)` via a single reverse pass.
///
/// Since `∂⟨v, y⟩/∂x = vᵀ J`, one backward pass through the scalar dot
/// product of `v` with the flattened output gives the full VJP without
/// ever materializing the Jacobian.
pub fn vjp<'graph, F, Func>(
    f: Func,
    x: &Tensor<'graph, F>,
    v: &Tensor<'graph, F>,
    ctx: &'graph Context<'graph, F>,
) -> Result<Tensor<'graph, F>>
where
    F: Float,
    Func: Fn(&Tensor<'graph, F>) -> Tensor<'graph, F>,
{
    let output = crate::tensor_ops::flatten(f(x));
    let cotangent = crate::tensor_ops::flatten(*v);
    // Scalar ⟨v, y⟩; its gradient w.r.t. x is the VJP.
    let inner = crate::tensor_ops::reduction::sum_all(cotangent * output);
    Ok(crate::tensor_ops::grad(&[inner], &[*x])[0])
}
#[cfg(test)]
mod tests {
    //! Numeric smoke tests for the differential-operator helpers.
    //! Each test builds a small graph with a `[2]`-shaped placeholder,
    //! evaluates the operator at a concrete point, and checks the result
    //! against the analytically known derivative.
    use super::*;
    use crate::tensor_ops::*;

    // f(x) = x·x has constant Hessian 2·I; check all four entries of the
    // flattened 2×2 Hessian at x = (1, 1).
    #[test]
    fn test_hessian_matrix_quadratic() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let h = hessian_matrix(
                |t| {
                    let axes = [0_isize];
                    reduce_sum(*t * *t, &axes, false)
                },
                &x,
                ctx,
                2,
            )
            .expect("hessian_matrix should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 1.0]);
            let result = ctx
                .evaluator()
                .push(&h)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = result[0].as_ref().expect("should evaluate");
            // Row-major flat layout: s[i*2 + j] = H[i][j].
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 2.0).abs() < 1e-6, "H[0,0] expected 2, got {}", s[0]);
            assert!((s[1]).abs() < 1e-6, "H[0,1] expected 0, got {}", s[1]);
            assert!((s[2]).abs() < 1e-6, "H[1,0] expected 0, got {}", s[2]);
            assert!((s[3] - 2.0).abs() < 1e-6, "H[1,1] expected 2, got {}", s[3]);
        });
    }

    // Same quadratic at a different point — the Hessian must not depend
    // on the evaluation point.
    #[test]
    fn test_hessian_matrix_constant_hessian() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let h = hessian_matrix(
                |t| {
                    let axes = [0_isize];
                    reduce_sum(*t * *t, &axes, false)
                },
                &x,
                ctx,
                2,
            )
            .expect("hessian_matrix should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[3.0f64, 4.0]);
            let result = ctx
                .evaluator()
                .push(&h)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = result[0].as_ref().expect("should evaluate");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 2.0).abs() < 1e-6, "H[0,0] = 2, got {}", s[0]);
            assert!((s[1]).abs() < 1e-6, "H[0,1] = 0, got {}", s[1]);
            assert!((s[2]).abs() < 1e-6, "H[1,0] = 0, got {}", s[2]);
            assert!((s[3] - 2.0).abs() < 1e-6, "H[1,1] = 2, got {}", s[3]);
        });
    }

    // For f(x) = Σ x², H = 2I, so H·(1,1) = (2,2) regardless of x.
    // (The `w` tensor below is constructed but intentionally unused in
    // the product — it only exercises constant-tensor creation.)
    #[test]
    fn test_hvp_diagonal_hessian() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let v_arr = scirs2_core::ndarray::arr1(&[1.0f64, 1.0]).into_dyn();
            let v = convert_to_tensor(v_arr, ctx);
            let w_arr = scirs2_core::ndarray::arr1(&[1.0f64, 2.0]).into_dyn();
            let w = convert_to_tensor(w_arr, ctx);
            let result = hvp(
                |t| {
                    let axes = [0_isize];
                    reduce_sum(*t * *t, &axes, false)
                },
                &x,
                &v,
                ctx,
            )
            .expect("hvp should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[1.0f64, 1.0]);
            let out = ctx
                .evaluator()
                .push(&result)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should evaluate");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 2.0).abs() < 1e-6, "H·v[0] expected 2, got {}", s[0]);
            assert!((s[1] - 2.0).abs() < 1e-6, "H·v[1] expected 2, got {}", s[1]);
        });
    }

    // hvp must reject v whose shape differs from x (here 3 vs 2 elements).
    #[test]
    fn test_hvp_shape_mismatch_error() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let v_arr = scirs2_core::ndarray::arr1(&[1.0f64, 1.0, 1.0]).into_dyn();
            let v = convert_to_tensor(v_arr, ctx);
            let result = hvp(|t| reduction::sum_all(*t), &x, &v, ctx);
            assert!(result.is_err(), "mismatched shapes should return error");
        });
    }

    // Jacobian of the identity map is the 2×2 identity matrix.
    #[test]
    fn test_jacobian_matrix_linear() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let jac = jacobian_matrix(|t| *t, &x, ctx, 2, 2)
                .expect("jacobian_matrix should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[3.0f64, 5.0]);
            let out = ctx
                .evaluator()
                .push(&jac)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should evaluate");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 1.0).abs() < 1e-6);
            assert!((s[1]).abs() < 1e-6);
            assert!((s[2]).abs() < 1e-6);
            assert!((s[3] - 1.0).abs() < 1e-6);
        });
    }

    // f(x0,x1) = (x0², x0·x1) has J = [[2x0, 0], [x1, x0]];
    // at (2, 3) that is [[4, 0], [3, 2]].
    #[test]
    fn test_jacobian_matrix_nonlinear() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let jac = jacobian_matrix(
                |t| {
                    let x0 = slice(*t, [0isize], [1isize]);
                    let x1 = slice(*t, [1isize], [2isize]);
                    linear_algebra::concat(&[x0 * x0, x0 * x1], 0)
                },
                &x,
                ctx,
                2,
                2,
            )
            .expect("jacobian_matrix should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[2.0f64, 3.0]);
            let out = ctx
                .evaluator()
                .push(&jac)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should evaluate");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 4.0).abs() < 1e-6, "J[0,0] expected 4, got {}", s[0]);
            assert!((s[1]).abs() < 1e-6, "J[0,1] expected 0, got {}", s[1]);
            assert!((s[2] - 3.0).abs() < 1e-6, "J[1,0] expected 3, got {}", s[2]);
            assert!((s[3] - 2.0).abs() < 1e-6, "J[1,1] expected 2, got {}", s[3]);
        });
    }

    // Scalar f(x) = Σ x² has ∇f = 2x, so J·e_0 at (2, 3) is 2·2 = 4.
    #[test]
    fn test_jvp_unit_vectors() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let v_arr = scirs2_core::ndarray::arr1(&[1.0f64, 0.0]).into_dyn();
            let v = convert_to_tensor(v_arr, ctx);
            let result = jvp(
                |t| {
                    let axes = [0_isize];
                    reduce_sum(*t * *t, &axes, false)
                },
                &x,
                &v,
                ctx,
                1, )
            .expect("jvp should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[2.0f64, 3.0]);
            let out = ctx
                .evaluator()
                .push(&result)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should evaluate");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 4.0).abs() < 1e-6, "JVP expected 4, got {}", s[0]);
        });
    }

    // VJP of f(x) = x0² + x1² with v = (1,) is the gradient itself:
    // (2x0, 2x1) = (4, 6) at (2, 3).
    #[test]
    fn test_vjp_squared_norm() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let v_arr = scirs2_core::ndarray::arr1(&[1.0f64]).into_dyn();
            let v = convert_to_tensor(v_arr, ctx);
            let result = vjp(
                |t| {
                    let x0 = slice(*t, [0isize], [1isize]);
                    let x1 = slice(*t, [1isize], [2isize]);
                    reduction::sum_all(x0 * x0 + x1 * x1)
                },
                &x,
                &v,
                ctx,
            )
            .expect("vjp should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[2.0f64, 3.0]);
            let out = ctx
                .evaluator()
                .push(&result)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should evaluate");
            let s = arr.as_slice().expect("slice");
            assert!((s[0] - 4.0).abs() < 1e-6, "VJP[0] expected 4, got {}", s[0]);
            assert!((s[1] - 6.0).abs() < 1e-6, "VJP[1] expected 6, got {}", s[1]);
        });
    }

    // Basic sanity check of the underlying grad op on a scalar placeholder.
    #[test]
    fn test_grad_x_squared_is_2x() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[]);
            let y = x * x; let dy_dx = &crate::tensor_ops::grad(&[y], &[x])[0];
            let x_val = scirs2_core::ndarray::arr0(3.0f64);
            let out = ctx
                .evaluator()
                .push(dy_dx)
                .feed(x, x_val.view().into_dyn())
                .run();
            let arr = out[0].as_ref().expect("should evaluate");
            let val = arr.first().copied().expect("first element");
            assert!((val - 6.0).abs() < 1e-9, "d(x^2)/dx at x=3 should be 6, got {}", val);
        });
    }

    // For scalar f, JVP with e_0 and VJP with (1,) both reduce to the
    // gradient: at x = (2, 1), ∇f = (4, 2), so JVP = 4 and VJP = (4, 2).
    #[test]
    fn test_jvp_vjp_consistency_scalar() {
        crate::run(|ctx: &mut Context<f64>| {
            let x = ctx.placeholder("x", &[2]);
            let v1_arr = scirs2_core::ndarray::arr1(&[1.0f64, 0.0]).into_dyn();
            let v1 = convert_to_tensor(v1_arr, ctx);
            let jvp_val = jvp(
                |t| {
                    let axes = [0_isize];
                    reduce_sum(*t * *t, &axes, false)
                },
                &x,
                &v1,
                ctx,
                1,
            )
            .expect("jvp should succeed");
            let v2_arr = scirs2_core::ndarray::arr1(&[1.0f64]).into_dyn();
            let v2 = convert_to_tensor(v2_arr, ctx);
            let vjp_val = vjp(
                |t| {
                    let axes = [0_isize];
                    reduce_sum(*t * *t, &axes, false)
                },
                &x,
                &v2,
                ctx,
            )
            .expect("vjp should succeed");
            let x_val = scirs2_core::ndarray::arr1(&[2.0f64, 1.0]);
            let outs = ctx
                .evaluator()
                .push(&jvp_val)
                .push(&vjp_val)
                .feed(x, x_val.view().into_dyn())
                .run();
            let jvp_arr = outs[0].as_ref().expect("jvp eval");
            let jvp_s = jvp_arr.as_slice().expect("jvp slice");
            assert!((jvp_s[0] - 4.0).abs() < 1e-6, "JVP expected 4, got {}", jvp_s[0]);
            let vjp_arr = outs[1].as_ref().expect("vjp eval");
            let vjp_s = vjp_arr.as_slice().expect("vjp slice");
            assert!((vjp_s[0] - 4.0).abs() < 1e-6, "VJP[0] expected 4, got {}", vjp_s[0]);
            assert!((vjp_s[1] - 2.0).abs() < 1e-6, "VJP[1] expected 2, got {}", vjp_s[1]);
        });
    }
}