use crate::DType;
use numr::autograd::{DualTensor, Var, backward, jacobian_forward, jvp, var_mul, var_sum};
use numr::error::Result;
use numr::ops::TensorOps;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Computes the Jacobian of `f` at `x` by delegating to numr's
/// [`jacobian_forward`].
///
/// This wrapper only reorders the arguments: it takes the client first,
/// whereas `jacobian_forward` takes it last.
///
/// # Arguments
///
/// * `client` - Runtime client handed to `f` and to `jacobian_forward`.
/// * `f` - Function under differentiation, expressed on [`DualTensor`]s.
/// * `x` - Point at which the Jacobian is evaluated.
///
/// # Returns
///
/// The tensor produced by `jacobian_forward`. The unit tests in this file
/// exercise a 3-element input mapped to a 3-element output and expect a
/// `[3, 3]` result.
///
/// # Errors
///
/// Propagates any error returned by `jacobian_forward` (including errors
/// raised by `f`).
pub fn jacobian_autograd<R, C, F>(client: &C, f: F, x: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
    F: Fn(&DualTensor<R>, &C) -> Result<DualTensor<R>>,
{
    jacobian_forward(f, x, client)
}
/// Computes a Jacobian-vector product for a single-input function by
/// delegating to numr's [`jvp`].
///
/// `x` and `v` are each wrapped in a one-element slice, so `f` receives a
/// slice of dual tensors whose only entry (index 0) corresponds to `x`.
///
/// # Arguments
///
/// * `client` - Runtime client handed to `f` and to `jvp`.
/// * `f` - Function under differentiation; called with the slice of duals.
/// * `x` - Primal input.
/// * `v` - Tangent direction paired with `x`.
///
/// # Returns
///
/// The pair returned by `jvp`. The unit test in this file checks it as
/// `(f(x), J·v)`: for `f(x) = x²`, `x = 2`, `v = 1` it expects `(4, 4)`.
///
/// # Errors
///
/// Propagates any error returned by `jvp` (including errors raised by `f`).
pub fn jvp_autograd<R, C, F>(
    client: &C,
    f: F,
    x: &Tensor<R>,
    v: &Tensor<R>,
) -> Result<(Tensor<R>, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
    F: FnOnce(&[DualTensor<R>], &C) -> Result<DualTensor<R>>,
{
    // `jvp` works on lists of primals/tangents; this wrapper has exactly one of each.
    let primals = [x];
    let tangents = [v];
    jvp(f, &primals, &tangents, client)
}
/// Computes a vector-Jacobian product of `f` at `x` using reverse-mode
/// autograd.
///
/// The function output is multiplied by `v` with [`var_mul`], the product is
/// summed over every dimension with [`var_sum`], and [`backward`] is run on
/// that sum. The gradient recorded for `x` is returned; when `var_mul` is
/// element-wise this is the gradient of `sum(f(x) * v)` with respect to `x`.
///
/// # Arguments
///
/// * `client` - Runtime client handed to `f` and to the autograd ops.
/// * `f` - Function under differentiation, expressed on [`Var`]s.
/// * `x` - Point at which `f` is evaluated.
/// * `v` - Weights multiplied into the output of `f` before reduction.
///
/// # Returns
///
/// `(f(x), vjp)`. If `backward` records no gradient for `x`, `vjp` is a zero
/// tensor with `x`'s shape, dtype and device.
///
/// # Errors
///
/// Propagates errors from `f`, `var_mul`, `var_sum` and `backward`.
pub fn vjp_autograd<R, C, F>(
    client: &C,
    f: F,
    x: &Tensor<R>,
    v: &Tensor<R>,
) -> Result<(Tensor<R>, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R>,
    F: Fn(&Var<R>, &C) -> Result<Var<R>>,
{
    // The input is created with the flag `true` (presumably "requires grad");
    // its id is what the gradient lookup below uses.
    let input = Var::new(x.clone(), true);
    let output = f(&input, client)?;
    let value = output.tensor().clone();

    // `v` is created with the flag `false`: no gradient is requested for it.
    let cotangent = Var::new(v.clone(), false);
    let weighted = var_mul(&output, &cotangent, client)?;

    // Reduce over every axis of the product so `backward` starts from one sum.
    let rank = weighted.tensor().shape().len();
    let reduce_axes: Vec<usize> = (0..rank).collect();
    let scalar = var_sum(&weighted, &reduce_axes, false, client)?;

    let gradients = backward(&scalar, client)?;
    let vjp = if let Some(grad) = gradients.get(input.id()) {
        grad.clone()
    } else {
        // No gradient recorded for the input: report zeros shaped like `x`.
        Tensor::<R>::zeros(x.shape(), x.dtype(), x.device())
    };

    Ok((value, vjp))
}
/// Computes vector-Jacobian products of `f(t, y, p)` with respect to both
/// `y` and `p` in a single reverse-mode pass.
///
/// `t` is wrapped as a one-element tensor on `y`'s device and passed to `f`
/// without gradient tracking. The output of `f` is multiplied by `v` with
/// [`var_mul`], summed over every dimension with [`var_sum`], and
/// [`backward`] is run on the sum.
///
/// # Arguments
///
/// * `client` - Runtime client handed to `f` and to the autograd ops.
/// * `f` - Function under differentiation, called as `f(t, y, p, client)`.
/// * `t` - Scalar passed to `f` as a `[1]`-shaped tensor.
/// * `y` - First differentiated input.
/// * `p` - Second differentiated input.
/// * `v` - Weights multiplied into the output of `f` before reduction.
///
/// # Returns
///
/// `(f(t, y, p), vjp_y, vjp_p)`. A gradient that `backward` did not record
/// is replaced by zeros with the corresponding input's shape and dtype.
///
/// # Errors
///
/// Propagates errors from `f`, `var_mul`, `var_sum` and `backward`.
pub fn vjp_with_params<R, C, F>(
    client: &C,
    f: F,
    t: f64,
    y: &Tensor<R>,
    p: &Tensor<R>,
    v: &Tensor<R>,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
    R::Client: TensorOps<R>,
    F: Fn(&Var<R>, &Var<R>, &Var<R>, &C) -> Result<Var<R>>,
{
    let dev = y.device();

    // NOTE(review): the time tensor is always built from an `f64` slice,
    // independent of `y.dtype()` — confirm `f` copes with that for non-f64 `y`.
    let time = Var::new(Tensor::<R>::from_slice(&[t], &[1], dev), false);
    let state = Var::new(y.clone(), true);
    let params = Var::new(p.clone(), true);

    let output = f(&time, &state, &params, client)?;
    let value = output.tensor().clone();

    // `v` is created with the flag `false`: no gradient is requested for it.
    let cotangent = Var::new(v.clone(), false);
    let weighted = var_mul(&output, &cotangent, client)?;

    // Reduce over every axis of the product so `backward` starts from one sum.
    let rank = weighted.tensor().shape().len();
    let reduce_axes: Vec<usize> = (0..rank).collect();
    let scalar = var_sum(&weighted, &reduce_axes, false, client)?;

    let gradients = backward(&scalar, client)?;

    let vjp_y = if let Some(grad) = gradients.get(state.id()) {
        grad.clone()
    } else {
        Tensor::<R>::zeros(y.shape(), y.dtype(), dev)
    };
    let vjp_p = if let Some(grad) = gradients.get(params.id()) {
        grad.clone()
    } else {
        // NOTE(review): the fallback uses `y`'s device for `p` as well —
        // presumably both live on the same device; confirm.
        Tensor::<R>::zeros(p.shape(), p.dtype(), dev)
    };

    Ok((value, vjp_y, vjp_p))
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::autograd::dual_ops::{dual_mul, dual_mul_scalar};
    use numr::autograd::var_mul_scalar;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Absolute tolerance used by every numeric comparison in this module.
    const TOL: f64 = 1e-10;

    /// Builds a CPU device together with its default client.
    fn setup() -> (CpuDevice, CpuClient) {
        let cpu = CpuDevice::new();
        let client = CpuRuntime::default_client(&cpu);
        (cpu, client)
    }

    /// f(x) = 2x: the Jacobian is 2 on the diagonal and 0 on the checked
    /// off-diagonal entries.
    #[test]
    fn test_jacobian_autograd_linear() {
        let (device, client) = setup();
        let point = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);

        let jac =
            jacobian_autograd(&client, |dual, ctx| dual_mul_scalar(dual, 2.0, ctx), &point)
                .unwrap();
        assert_eq!(jac.shape(), &[3, 3]);

        let flat: Vec<f64> = jac.to_vec();
        // Diagonal entries of the row-major 3x3 result.
        for diag in [0usize, 4, 8] {
            assert!((flat[diag] - 2.0).abs() < TOL);
        }
        // A sample of off-diagonal entries.
        for off in [1usize, 2, 3] {
            assert!(flat[off].abs() < TOL);
        }
    }

    /// f(x) = x * x (element-wise): the diagonal of the Jacobian is 2x.
    #[test]
    fn test_jacobian_autograd_quadratic() {
        let (device, client) = setup();
        let point = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);

        let jac =
            jacobian_autograd(&client, |dual, ctx| dual_mul(dual, dual, ctx), &point).unwrap();

        let flat: Vec<f64> = jac.to_vec();
        for (diag, want) in [(0usize, 2.0f64), (4, 4.0), (8, 6.0)] {
            assert!((flat[diag] - want).abs() < TOL);
        }
    }

    /// f(x) = x² at x = 2 with tangent 1: value 4, directional derivative 4.
    #[test]
    fn test_jvp_autograd() {
        let (device, client) = setup();
        let point = Tensor::<CpuRuntime>::from_slice(&[2.0f64], &[1], &device);
        let tangent = Tensor::<CpuRuntime>::from_slice(&[1.0f64], &[1], &device);

        let (value, directional) = jvp_autograd(
            &client,
            |duals, ctx| dual_mul(&duals[0], &duals[0], ctx),
            &point,
            &tangent,
        )
        .unwrap();

        assert!((value.to_vec::<f64>()[0] - 4.0).abs() < TOL);
        assert!((directional.to_vec::<f64>()[0] - 4.0).abs() < TOL);
    }

    /// f(x) = x² at x = 2 with v = 1: value 4, vector-Jacobian product 4.
    #[test]
    fn test_vjp_autograd_simple() {
        let (device, client) = setup();
        let point = Tensor::<CpuRuntime>::from_slice(&[2.0f64], &[1], &device);
        let weights = Tensor::<CpuRuntime>::from_slice(&[1.0f64], &[1], &device);

        let (value, pulled_back) =
            vjp_autograd(&client, |var, ctx| var_mul(var, var, ctx), &point, &weights).unwrap();

        assert!((value.to_vec::<f64>()[0] - 4.0).abs() < TOL);
        assert!((pulled_back.to_vec::<f64>()[0] - 4.0).abs() < TOL);
    }

    /// f(x) = 2x with v = [1, 1, 1]: value is 2x, every vjp entry is 2.
    #[test]
    fn test_vjp_autograd_linear() {
        let (device, client) = setup();
        let point = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);
        let weights = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0, 1.0], &[3], &device);

        let (value, pulled_back) = vjp_autograd(
            &client,
            |var, ctx| var_mul_scalar(var, 2.0, ctx),
            &point,
            &weights,
        )
        .unwrap();

        let value_entries: Vec<f64> = value.to_vec();
        for (idx, want) in [(0usize, 2.0f64), (1, 4.0), (2, 6.0)] {
            assert!((value_entries[idx] - want).abs() < TOL);
        }

        let vjp_entries: Vec<f64> = pulled_back.to_vec();
        for idx in [0usize, 1, 2] {
            assert!((vjp_entries[idx] - 2.0).abs() < TOL);
        }
    }

    /// f(t, y, p) = p * y at y = 2, p = 3: value 6, d/dy = p = 3, d/dp = y = 2.
    #[test]
    fn test_vjp_with_params() {
        let (device, client) = setup();
        let state = Tensor::<CpuRuntime>::from_slice(&[2.0f64], &[1], &device);
        let params = Tensor::<CpuRuntime>::from_slice(&[3.0f64], &[1], &device);
        let weights = Tensor::<CpuRuntime>::from_slice(&[1.0f64], &[1], &device);

        let (value, vjp_y, vjp_p) = vjp_with_params(
            &client,
            |_time, y_var, p_var, ctx| var_mul(p_var, y_var, ctx),
            0.0,
            &state,
            &params,
            &weights,
        )
        .unwrap();

        assert!((value.to_vec::<f64>()[0] - 6.0).abs() < TOL);

        let grad_y = vjp_y.to_vec::<f64>()[0];
        assert!((grad_y - 3.0).abs() < TOL, "vjp_y = {}", grad_y);

        let grad_p = vjp_p.to_vec::<f64>()[0];
        assert!((grad_p - 2.0).abs() < TOL, "vjp_p = {}", grad_p);
    }
}