use super::DualTensor;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::TensorOps;
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Computes a Jacobian-vector product (forward-mode AD) for a
/// single-output function.
///
/// Each primal input is seeded with its matching tangent, `f` is run once
/// on the resulting dual tensors, and `(primal_out, tangent_out)` is
/// returned. An absent output tangent is reported as a zeros tensor
/// shaped like the primal output.
///
/// # Panics
/// Panics if `primals` and `tangents` differ in length.
pub fn jvp<R, C, F>(
    f: F,
    primals: &[&Tensor<R>],
    tangents: &[&Tensor<R>],
    client: &C,
) -> Result<(Tensor<R>, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R>,
    F: FnOnce(&[DualTensor<R>], &C) -> Result<DualTensor<R>>,
{
    assert_eq!(
        primals.len(),
        tangents.len(),
        "Number of primals ({}) must match number of tangents ({})",
        primals.len(),
        tangents.len()
    );
    // Pair every primal with its tangent seed.
    let mut dual_inputs = Vec::with_capacity(primals.len());
    for (primal, tangent) in primals.iter().zip(tangents.iter()) {
        dual_inputs.push(DualTensor::with_tangent((*primal).clone(), (*tangent).clone()));
    }
    let dual_out = f(&dual_inputs, client)?;
    let primal_out = dual_out.primal().clone();
    // An output that does not depend on any seeded input carries no
    // tangent; treat that as a zero directional derivative.
    let tangent_out = match dual_out.tangent() {
        Some(t) => t.clone(),
        None => Tensor::zeros(primal_out.shape(), primal_out.dtype(), primal_out.device()),
    };
    Ok((primal_out, tangent_out))
}
/// Computes a Jacobian-vector product for a function with multiple
/// outputs.
///
/// Identical to [`jvp`] except that `f` returns several dual tensors;
/// the result is the list of output primals paired with the list of
/// output tangents (missing tangents become zeros).
///
/// # Panics
/// Panics if `primals` and `tangents` differ in length.
pub fn jvp_multi<R, C, F>(
    f: F,
    primals: &[&Tensor<R>],
    tangents: &[&Tensor<R>],
    client: &C,
) -> Result<(Vec<Tensor<R>>, Vec<Tensor<R>>)>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R>,
    F: FnOnce(&[DualTensor<R>], &C) -> Result<Vec<DualTensor<R>>>,
{
    assert_eq!(
        primals.len(),
        tangents.len(),
        "Number of primals ({}) must match number of tangents ({})",
        primals.len(),
        tangents.len()
    );
    // Seed every primal with its tangent.
    let dual_inputs: Vec<DualTensor<R>> = primals
        .iter()
        .zip(tangents.iter())
        .map(|(p, t)| DualTensor::with_tangent((*p).clone(), (*t).clone()))
        .collect();
    let dual_outputs = f(&dual_inputs, client)?;
    // Split each dual output into its primal/tangent halves; tangent-less
    // outputs contribute a zeros tensor.
    let (output_primals, output_tangents): (Vec<_>, Vec<_>) = dual_outputs
        .into_iter()
        .map(|dual| {
            let primal = dual.primal().clone();
            let tangent = match dual.tangent() {
                Some(t) => t.clone(),
                None => Tensor::zeros(primal.shape(), primal.dtype(), primal.device()),
            };
            (primal, tangent)
        })
        .unzip();
    Ok((output_primals, output_tangents))
}
pub fn jacobian_forward<R, C, F>(f: F, x: &Tensor<R>, client: &C) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + TensorOps<R>,
F: Fn(&DualTensor<R>, &C) -> Result<DualTensor<R>>,
{
let n = x.numel();
let device = x.device();
let dtype = x.dtype();
let mut columns: Vec<Tensor<R>> = Vec::with_capacity(n);
for j in 0..n {
let mut v_data = vec![0.0f64; n];
v_data[j] = 1.0;
let v = match dtype {
crate::dtype::DType::F32 => {
let v_f32: Vec<f32> = v_data.iter().map(|&x| x as f32).collect();
Tensor::<R>::from_slice(&v_f32, x.shape(), device)
}
crate::dtype::DType::F64 => Tensor::<R>::from_slice(&v_data, x.shape(), device),
_ => {
Tensor::<R>::from_slice(&v_data, x.shape(), device)
}
};
let dual_x = DualTensor::with_tangent(x.clone(), v);
let dual_y = f(&dual_x, client)?;
let col = match dual_y.tangent() {
Some(t) => t.clone(),
None => Tensor::zeros(dual_y.shape(), dtype, device),
};
columns.push(col);
}
let col_refs: Vec<&Tensor<R>> = columns.iter().collect();
client.stack(&col_refs, 1)
}
/// Computes a Hessian-vector product `H(x) · v` given a function that
/// evaluates the gradient of the target scalar function.
///
/// Implemented as a single forward-mode JVP through `grad_f`
/// (forward-over-reverse); the gradient's primal output is discarded and
/// only its tangent — the HVP — is returned.
pub fn hvp<R, C, F>(grad_f: F, x: &Tensor<R>, v: &Tensor<R>, client: &C) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R>,
    F: Fn(&DualTensor<R>, &C) -> Result<DualTensor<R>>,
{
    jvp(
        |duals, c| {
            // jvp seeds exactly the inputs we passed below: a single pair.
            assert_eq!(duals.len(), 1);
            grad_f(&duals[0], c)
        },
        &[x],
        &[v],
        client,
    )
    .map(|(_, tangent)| tangent)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::dual_ops::*;
    use crate::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Builds the CPU device/client pair shared by every test below.
    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        (device, client)
    }

    /// f(x) = x², so f(3) = 9 and f'(3)·1 = 6.
    #[test]
    fn test_jvp_square() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device);
        let v = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let (y, dy) =
            jvp(|inputs, c| dual_mul(&inputs[0], &inputs[0], c), &[&x], &[&v], &client).unwrap();
        assert_eq!(y.to_vec::<f32>(), [9.0]);
        assert_eq!(dy.to_vec::<f32>(), [6.0]);
    }

    /// f(x) = Σ xᵢ², so f([1,2,3]) = 14 and ∇f·[1,1,1] = 2+4+6 = 12.
    #[test]
    fn test_jvp_sum_squares() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let v = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device);
        let (y, dy) = jvp(
            |inputs, c| {
                let squared = dual_mul(&inputs[0], &inputs[0], c)?;
                dual_sum(&squared, &[0], false, c)
            },
            &[&x],
            &[&v],
            &client,
        )
        .unwrap();
        assert_eq!(y.to_vec::<f32>(), [14.0]);
        assert_eq!(dy.to_vec::<f32>(), [12.0]);
    }

    /// f(x) = exp(x²), so f(1) = e and f'(1) = 2·e (chain rule).
    #[test]
    fn test_jvp_chain_rule() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let v = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let (y, dy) = jvp(
            |inputs, c| {
                let squared = dual_mul(&inputs[0], &inputs[0], c)?;
                dual_exp(&squared, c)
            },
            &[&x],
            &[&v],
            &client,
        )
        .unwrap();
        let e = std::f32::consts::E;
        assert!((y.to_vec::<f32>()[0] - e).abs() < 1e-5);
        assert!((dy.to_vec::<f32>()[0] - 2.0 * e).abs() < 1e-4);
    }

    /// f(x, y) = x·y with tangent (1, 0): df = y = 3.
    #[test]
    fn test_jvp_two_inputs() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
        let y = Tensor::<CpuRuntime>::from_slice(&[3.0f32], &[1], &device);
        let vx = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let vy = Tensor::<CpuRuntime>::from_slice(&[0.0f32], &[1], &device);
        let (f, df) = jvp(
            |inputs, c| dual_mul(&inputs[0], &inputs[1], c),
            &[&x, &y],
            &[&vx, &vy],
            &client,
        )
        .unwrap();
        assert_eq!(f.to_vec::<f32>(), [6.0]);
        assert_eq!(df.to_vec::<f32>(), [3.0]);
    }

    /// Product rule through matmul: d(AB) = dA·B + A·dB.
    #[test]
    fn test_jvp_matmul() {
        let (device, client) = setup();
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[1, 2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2, 1], &device);
        let da = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0], &[1, 2], &device);
        let db = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[2, 1], &device);
        let (y, dy) = jvp(
            |inputs, c| dual_matmul(&inputs[0], &inputs[1], c),
            &[&a, &b],
            &[&da, &db],
            &client,
        )
        .unwrap();
        assert_eq!(y.to_vec::<f32>(), [3.0]);
        assert_eq!(dy.to_vec::<f32>(), [1.0]);
    }

    /// The Jacobian of x ↦ 2x is 2·I.
    #[test]
    fn test_jacobian_forward_linear() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let jacobian =
            jacobian_forward(|dual_x, c| dual_mul_scalar(dual_x, 2.0, c), &x, &client).unwrap();
        assert_eq!(jacobian.shape(), &[3, 3]);
        let j: Vec<f32> = jacobian.to_vec();
        let expected = [2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0];
        for (got, want) in j.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// Two outputs at once: (x², x³) at x = 2 with their derivatives.
    #[test]
    fn test_jvp_multi() {
        let (device, client) = setup();
        let x = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);
        let v = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device);
        let (ys, dys) = jvp_multi(
            |inputs, c| {
                let x_sq = dual_mul(&inputs[0], &inputs[0], c)?;
                let x_cube = dual_mul(&x_sq, &inputs[0], c)?;
                Ok(vec![x_sq, x_cube])
            },
            &[&x],
            &[&v],
            &client,
        )
        .unwrap();
        assert_eq!(ys.len(), 2);
        assert_eq!(dys.len(), 2);
        assert_eq!(ys[0].to_vec::<f32>(), [4.0]);
        assert_eq!(ys[1].to_vec::<f32>(), [8.0]);
        assert_eq!(dys[0].to_vec::<f32>(), [4.0]);
        assert_eq!(dys[1].to_vec::<f32>(), [12.0]);
    }
}