//! Higher-order differentiation utilities built on reverse-mode `grad`:
//! Hessian-vector products, full Hessians, JVPs/VJPs, mixed partials, and
//! nth-order gradients.
use crate::{error::AutogradError, tensor::Tensor, Context, Float, Result};
pub mod extensions;
pub mod hessian;
pub mod jacobian;
pub mod vjp;
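/// Computes the Hessian-vector product H·v of a scalar function `f` of `x`
/// without materializing the full Hessian, using two reverse-mode passes:
/// differentiate `f` to get g = ∂f/∂x, then differentiate ⟨g, v⟩ w.r.t. `x`.
///
/// Returns an error if `f` is not scalar (shape [] or [1]) or if `x` and `v`
/// have different shapes. `ctx` is the graph context shared by all inputs.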
pub fn hessian_vector_product<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
v: &Tensor<'graph, T>,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
if f.shape().len() > 1 || (f.shape().len() == 1 && f.shape()[0] != 1) {
return Err(AutogradError::shape_error(
"Function f must be scalar (shape [] or [1])".to_string(),
));
}
if x.shape() != v.shape() {
return Err(AutogradError::shape_error(format!(
"Shapes of x {:?} and v {:?} must match",
x.shape(),
v.shape()
)));
}
// First backward pass: g = ∂f/∂x.
let grad_f = &crate::tensor_ops::grad(&[*f], &[*x])[0];
// Scalar surrogate s = ⟨g, v⟩.
let grad_f_flat = crate::tensor_ops::flatten(*grad_f);
let v_flat = crate::tensor_ops::flatten(*v);
let dot_product = crate::tensor_ops::reduction::sum_all(grad_f_flat * v_flat);
// Second backward pass: ∂⟨g, v⟩/∂x, which equals H·v when v does not depend on x.
let hvp = &crate::tensor_ops::grad(&[dot_product], &[*x])[0];
Ok(*hvp)
}
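/// Computes the full Hessian of a scalar function `f` with respect to `x` by
/// differentiating each component of the gradient, i.e. n + 1 backward passes
/// for n = x.len(). The rows are returned concatenated as a flat row-major
/// tensor of length n * n; prefer `hessian_vector_product` when only H·v is
/// needed or when n is large.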
pub fn hessian<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
if f.shape().len() > 1 || (f.shape().len() == 1 && f.shape()[0] != 1) {
return Err(AutogradError::shape_error(
"Function f must be scalar".to_string(),
));
}
let x_shape = x.shape();
let n: usize = x_shape.iter().product();
if n > 1000 {
eprintln!("Warning: Computing full Hessian for {} parameters. Consider using hessian_vector_product instead.", n);
}
// First backward pass: the flattened gradient g = ∂f/∂x of length n.
let grad_f = &crate::tensor_ops::grad(&[*f], &[*x])[0];
let grad_f_flat = crate::tensor_ops::flatten(*grad_f);
// One extra backward pass per gradient component: row i of the Hessian is ∂g_i/∂x.
let mut hessian_rows = Vec::with_capacity(n);
for i in 0..n {
let grad_i = crate::tensor_ops::slice(grad_f_flat, [i as isize], [(i + 1) as isize]);
let hessian_row = &crate::tensor_ops::grad(&[grad_i], &[*x])[0];
let hessian_row_flat = crate::tensor_ops::flatten(*hessian_row);
hessian_rows.push(hessian_row_flat);
}
// The n rows are concatenated into a flat row-major tensor of length n * n;
// callers that need an [n, n] matrix can reshape the result.
let hessian_tensor = crate::tensor_ops::linear_algebra::concat(&hessian_rows, 0);
Ok(hessian_tensor)
}
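/// Computes the directional derivative ⟨∂f/∂x, v⟩ of `f` along `v`; for a
/// scalar-valued `f` this is the Jacobian-vector product J·v. Returns an error
/// if `x` and `v` have different shapes.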
pub fn jacobian_vector_product<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
v: &Tensor<'graph, T>,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
if x.shape() != v.shape() {
return Err(AutogradError::shape_error(format!(
"Shapes of x {:?} and v {:?} must match",
x.shape(),
v.shape()
)));
}
// Reverse-mode gradient of f with respect to x.
let grad_f = &crate::tensor_ops::grad(&[*f], &[*x])[0];
// Directional derivative ⟨∂f/∂x, v⟩ via an element-wise product and a full sum.
let weighted = *grad_f * *v;
let weighted_flat = crate::tensor_ops::flatten(weighted);
let jvp = crate::tensor_ops::reduction::sum_all(weighted_flat);
Ok(jvp)
}
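/// Computes the vector-Jacobian product vᵀ·J of `f` with respect to `x` via the
/// dot-product trick: the gradient of sum(v ⊙ f) w.r.t. `x`. Returns an error if
/// `f` and `v` have different shapes.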
pub fn vector_jacobian_product<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
v: &Tensor<'graph, T>,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
if f.shape() != v.shape() {
return Err(AutogradError::shape_error(format!(
"Shapes of f {:?} and v {:?} must match",
f.shape(),
v.shape()
)));
}
// Dot-product trick: the gradient of s = ⟨v, f⟩ w.r.t. x is vᵀ·J.
let vf = crate::tensor_ops::reduction::sum_all(*v * *f);
let vjp = &crate::tensor_ops::grad(&[vf], &[*x])[0];
Ok(*vjp)
}
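/// Computes the mixed second-order partial ∂²f/∂y∂x of a scalar function `f`
/// by differentiating ∂f/∂x with respect to `y`. For sufficiently smooth `f`
/// this equals ∂²f/∂x∂y.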
pub fn mixed_partial<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
y: &Tensor<'graph, T>,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
if f.shape().len() > 1 || (f.shape().len() == 1 && f.shape()[0] != 1) {
return Err(AutogradError::shape_error(
"Function f must be scalar".to_string(),
));
}
// ∂f/∂x, then differentiate that result w.r.t. y to obtain ∂²f/∂y∂x.
let df_dx = &crate::tensor_ops::grad(&[*f], &[*x])[0];
let d2f_dxdy = &crate::tensor_ops::grad(&[*df_dx], &[*y])[0];
Ok(*d2f_dxdy)
}
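/// Computes the `order`-th derivative of a scalar function `f` with respect to
/// `x` by repeated reverse-mode differentiation; between passes the gradient is
/// summed to a scalar, so for non-scalar `x` the result is the gradient of the
/// summed previous gradient rather than a full higher-order tensor.
/// `order == 0` returns `f` unchanged.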
pub fn nth_order_gradient<'graph, T: Float>(
f: &Tensor<'graph, T>,
x: &Tensor<'graph, T>,
order: usize,
ctx: &'graph Context<'graph, T>,
) -> Result<Tensor<'graph, T>> {
if order == 0 {
return Ok(*f);
}
if f.shape().len() > 1 || (f.shape().len() == 1 && f.shape()[0] != 1) {
return Err(AutogradError::shape_error(
"Function f must be scalar".to_string(),
));
}
// First derivative.
let mut current_grad = crate::tensor_ops::grad(&[*f], &[*x])[0];
for _ in 1..order {
// Reduce the previous gradient to a scalar so it can be differentiated again;
// for scalar x this sum is a no-op and the loop yields the nth derivative.
let scalar_grad = crate::tensor_ops::reduction::sum_all(current_grad);
current_grad = crate::tensor_ops::grad(&[scalar_grad], &[*x])[0];
}
Ok(current_grad)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_ops::*;
#[test]
fn test_hessian_vector_product() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let axes = [0_isize];
let f: crate::tensor::Tensor<f64> = reduce_sum(x * x, &axes, false);
let v = convert_to_tensor(
scirs2_core::ndarray::Array1::from(vec![1.0, 1.0]).into_dyn(),
ctx,
);
let hvp = hessian_vector_product(&f, &x, &v, ctx).expect("Should compute HVP");
let x_val = scirs2_core::ndarray::arr1(&[1.0, 1.0]);
let result = ctx
.evaluator()
.push(&hvp)
.feed(x, x_val.view().into_dyn())
.run();
let result_data = result[0]
.as_ref()
.expect("Should evaluate")
.as_slice()
.expect("Should get slice");
assert!((result_data[0] - 2.0).abs() < 1e-6);
assert!((result_data[1] - 2.0).abs() < 1e-6);
});
}
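// A minimal additional check for jacobian_vector_product, written as a sketch
// against the same placeholder/convert_to_tensor/evaluator pattern used in
// test_hessian_vector_product above; it assumes only that API. For
// f = sum(x ⊙ x), the directional derivative along v = [1, 1] at x = [1, 2]
// is 2*1 + 2*2 = 6.
#[test]
fn test_jacobian_vector_product() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[2]);
let axes = [0_isize];
let f: crate::tensor::Tensor<f64> = reduce_sum(x * x, &axes, false);
let v = convert_to_tensor(
scirs2_core::ndarray::Array1::from(vec![1.0, 1.0]).into_dyn(),
ctx,
);
let jvp = jacobian_vector_product(&f, &x, &v, ctx).expect("Should compute JVP");
let x_val = scirs2_core::ndarray::arr1(&[1.0, 2.0]);
let result = ctx
.evaluator()
.push(&jvp)
.feed(x, x_val.view().into_dyn())
.run();
let val = result[0]
.as_ref()
.expect("Should evaluate")
.first()
.copied()
.unwrap_or(0.0);
assert!((val - 6.0).abs() < 1e-6);
});
}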
#[test]
fn test_mixed_partial() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[]);
let y = ctx.placeholder("y", &[]);
let f = x * y;
let mixed = mixed_partial(&f, &x, &y, ctx).expect("Should compute mixed partial");
let result = mixed.eval(ctx).expect("Should evaluate");
let result_val = result.first().copied().unwrap_or(0.0);
assert!((result_val - 1.0).abs() < 1e-6);
});
}
#[test]
fn test_nth_order_gradient() {
crate::run(|ctx: &mut Context<f64>| {
let x = ctx.placeholder("x", &[]);
let f = x * x * x;
let grad1 = nth_order_gradient(&f, &x, 1, ctx).expect("Should compute 1st derivative");
let grad2 = nth_order_gradient(&f, &x, 2, ctx).expect("Should compute 2nd derivative");
let x_val = scirs2_core::ndarray::arr0(2.0);
let result1 = ctx
.evaluator()
.push(&grad1)
.feed(x, x_val.view().into_dyn())
.run();
let val1 = result1[0]
.as_ref()
.expect("Should evaluate")
.first()
.copied()
.unwrap_or(0.0);
assert!((val1 - 12.0).abs() < 1e-6);
let result2 = ctx
.evaluator()
.push(&grad2)
.feed(x, x_val.view().into_dyn())
.run();
let val2 = result2[0]
.as_ref()
.expect("Should evaluate")
.first()
.copied()
.unwrap_or(0.0);
assert!((val2 - 12.0).abs() < 1e-6);
});
}
}