use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
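
/// A primal tensor paired with its tangent (directional derivative), the
/// core value type for forward-mode automatic differentiation. Every
/// `dual_*` op below propagates primal and tangent in a single forward pass.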
#[derive(Debug, Clone)]
pub struct DualTensor<T: Float> {
pub primal: Tensor<T>,
pub tangent: Tensor<T>,
}
impl<T: Float> DualTensor<T> {
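/// Pairs a primal with a tangent, checking that their shapes match.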
pub fn new(primal: Tensor<T>, tangent: Tensor<T>) -> FerrotorchResult<Self> {
if primal.shape() != tangent.shape() {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"DualTensor: primal shape {:?} != tangent shape {:?}",
primal.shape(),
tangent.shape()
),
});
}
Ok(Self { primal, tangent })
}
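/// Lifts a plain tensor into a dual with an all-zero tangent, i.e. a
/// value treated as constant under differentiation.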
pub fn constant(primal: Tensor<T>) -> FerrotorchResult<Self> {
let zero_data = vec![<T as num_traits::Zero>::zero(); primal.numel()];
let tangent = Tensor::from_storage(
TensorStorage::cpu(zero_data),
primal.shape().to_vec(),
false,
)?;
Ok(Self { primal, tangent })
}
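/// Shape shared by the primal and tangent.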
pub fn shape(&self) -> &[usize] {
self.primal.shape()
}
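/// Number of elements in the primal (and tangent).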
pub fn numel(&self) -> usize {
self.primal.numel()
}
}
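
/// Elementwise addition of dual tensors: `d(a + b) = da + db`.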
pub fn dual_add<T: Float>(a: &DualTensor<T>, b: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::arithmetic::add(&a.primal, &b.primal)?;
let tangent = crate::grad_fns::arithmetic::add(&a.tangent, &b.tangent)?;
Ok(DualTensor { primal, tangent })
}
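
/// Elementwise subtraction of dual tensors: `d(a - b) = da - db`.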
pub fn dual_sub<T: Float>(a: &DualTensor<T>, b: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::arithmetic::sub(&a.primal, &b.primal)?;
let tangent = crate::grad_fns::arithmetic::sub(&a.tangent, &b.tangent)?;
Ok(DualTensor { primal, tangent })
}
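
/// Elementwise multiplication via the product rule:
/// `d(a * b) = a * db + da * b`.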
pub fn dual_mul<T: Float>(a: &DualTensor<T>, b: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::arithmetic::mul(&a.primal, &b.primal)?;
let term1 = crate::grad_fns::arithmetic::mul(&a.primal, &b.tangent)?;
let term2 = crate::grad_fns::arithmetic::mul(&a.tangent, &b.primal)?;
let tangent = crate::grad_fns::arithmetic::add(&term1, &term2)?;
Ok(DualTensor { primal, tangent })
}
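
/// Elementwise division via the quotient rule:
/// `d(a / b) = (da * b - a * db) / b^2`.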
pub fn dual_div<T: Float>(a: &DualTensor<T>, b: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::arithmetic::div(&a.primal, &b.primal)?;
let da_b = crate::grad_fns::arithmetic::mul(&a.tangent, &b.primal)?;
let a_db = crate::grad_fns::arithmetic::mul(&a.primal, &b.tangent)?;
let numer = crate::grad_fns::arithmetic::sub(&da_b, &a_db)?;
let b_sq = crate::grad_fns::arithmetic::mul(&b.primal, &b.primal)?;
let tangent = crate::grad_fns::arithmetic::div(&numer, &b_sq)?;
Ok(DualTensor { primal, tangent })
}
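
/// Elementwise negation: `d(-a) = -da`.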
pub fn dual_neg<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::arithmetic::neg(&a.primal)?;
let tangent = crate::grad_fns::arithmetic::neg(&a.tangent)?;
Ok(DualTensor { primal, tangent })
}
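
/// Matrix multiplication via the product rule:
/// `d(A B) = dA · B + A · dB`.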
pub fn dual_matmul<T: Float>(
a: &DualTensor<T>,
b: &DualTensor<T>,
) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::linalg::matmul_differentiable(&a.primal, &b.primal)?;
let term1 = crate::grad_fns::linalg::matmul_differentiable(&a.tangent, &b.primal)?;
let term2 = crate::grad_fns::linalg::matmul_differentiable(&a.primal, &b.tangent)?;
let tangent = crate::grad_fns::arithmetic::add(&term1, &term2)?;
Ok(DualTensor { primal, tangent })
}
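
/// ReLU: the tangent passes through where the primal is positive and is
/// zeroed elsewhere (the subgradient at 0 is taken as 0). The tangent is
/// materialized in CPU storage, as in the other elementwise duals below.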
pub fn dual_relu<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::activation::relu(&a.primal)?;
let a_data = a.primal.data_vec()?;
let da_data = a.tangent.data_vec()?;
let zero = <T as num_traits::Zero>::zero();
let tangent_data: Vec<T> = a_data
.iter()
.zip(da_data.iter())
.map(|(&x, &dx)| if x > zero { dx } else { zero })
.collect();
let tangent = Tensor::from_storage(
TensorStorage::cpu(tangent_data),
a.primal.shape().to_vec(),
false,
)?;
Ok(DualTensor { primal, tangent })
}
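
/// Sigmoid, reusing the primal output: `d sigmoid(x) = s * (1 - s) * dx`
/// where `s = sigmoid(x)`.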
pub fn dual_sigmoid<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::activation::sigmoid(&a.primal)?;
let sigma_data = primal.data_vec()?;
let da_data = a.tangent.data_vec()?;
let one = <T as num_traits::One>::one();
let tangent_data: Vec<T> = sigma_data
.iter()
.zip(da_data.iter())
.map(|(&s, &dx)| dx * s * (one - s))
.collect();
let tangent = Tensor::from_storage(
TensorStorage::cpu(tangent_data),
a.primal.shape().to_vec(),
false,
)?;
Ok(DualTensor { primal, tangent })
}
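
/// Tanh, reusing the primal output: `d tanh(x) = (1 - tanh(x)^2) dx`.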
pub fn dual_tanh<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::activation::tanh(&a.primal)?;
let tanh_data = primal.data_vec()?;
let da_data = a.tangent.data_vec()?;
let one = <T as num_traits::One>::one();
let tangent_data: Vec<T> = tanh_data
.iter()
.zip(da_data.iter())
.map(|(&t, &dx)| dx * (one - t * t))
.collect();
let tangent = Tensor::from_storage(
TensorStorage::cpu(tangent_data),
a.primal.shape().to_vec(),
false,
)?;
Ok(DualTensor { primal, tangent })
}
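
/// Exponential, reusing the primal output: `d exp(x) = exp(x) dx`.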
pub fn dual_exp<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::transcendental::exp(&a.primal)?;
let exp_data = primal.data_vec()?;
let da_data = a.tangent.data_vec()?;
let tangent_data: Vec<T> = exp_data
.iter()
.zip(da_data.iter())
.map(|(&e, &dx)| dx * e)
.collect();
let tangent = Tensor::from_storage(
TensorStorage::cpu(tangent_data),
a.primal.shape().to_vec(),
false,
)?;
Ok(DualTensor { primal, tangent })
}
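
/// Natural logarithm: `d ln(x) = dx / x` (finite only for positive inputs).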
pub fn dual_log<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::transcendental::log(&a.primal)?;
let a_data = a.primal.data_vec()?;
let da_data = a.tangent.data_vec()?;
let tangent_data: Vec<T> = a_data
.iter()
.zip(da_data.iter())
.map(|(&x, &dx)| dx / x)
.collect();
let tangent = Tensor::from_storage(
TensorStorage::cpu(tangent_data),
a.primal.shape().to_vec(),
false,
)?;
Ok(DualTensor { primal, tangent })
}
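
/// Sine: `d sin(x) = cos(x) dx`.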
pub fn dual_sin<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::transcendental::sin(&a.primal)?;
let a_data = a.primal.data_vec()?;
let da_data = a.tangent.data_vec()?;
let tangent_data: Vec<T> = a_data
.iter()
.zip(da_data.iter())
.map(|(&x, &dx)| dx * x.cos())
.collect();
let tangent = Tensor::from_storage(
TensorStorage::cpu(tangent_data),
a.primal.shape().to_vec(),
false,
)?;
Ok(DualTensor { primal, tangent })
}
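
/// Cosine: `d cos(x) = -sin(x) dx`.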
pub fn dual_cos<T: Float>(a: &DualTensor<T>) -> FerrotorchResult<DualTensor<T>> {
let primal = crate::grad_fns::transcendental::cos(&a.primal)?;
let a_data = a.primal.data_vec()?;
let da_data = a.tangent.data_vec()?;
let tangent_data: Vec<T> = a_data
.iter()
.zip(da_data.iter())
.map(|(&x, &dx)| -dx * x.sin())
.collect();
let tangent = Tensor::from_storage(
TensorStorage::cpu(tangent_data),
a.primal.shape().to_vec(),
false,
)?;
Ok(DualTensor { primal, tangent })
}
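
/// Computes an exact Jacobian-vector product `(f(x), J_f(x) · v)` in a
/// single forward pass by seeding the input tangent with `v`.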
pub fn jvp_exact<T: Float, F>(
f: F,
input: &Tensor<T>,
v: &Tensor<T>,
) -> FerrotorchResult<(Tensor<T>, Tensor<T>)>
where
F: Fn(DualTensor<T>) -> FerrotorchResult<DualTensor<T>>,
{
if input.shape() != v.shape() {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"jvp_exact: input shape {:?} != v shape {:?}",
input.shape(),
v.shape()
),
});
}
let dual_input = DualTensor::new(input.clone(), v.clone())?;
let dual_output = f(dual_input)?;
Ok((dual_output.primal, dual_output.tangent))
}
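
/// Builds the dense Jacobian of a function from 1-D tensors to 1-D tensors
/// by running one forward pass per standard basis vector `e_j`; column `j`
/// of the result is `J_f(x) · e_j`. Returns an `m x n` row-major tensor.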
pub fn jacfwd<T: Float, F>(f: F, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
where
F: Fn(DualTensor<T>) -> FerrotorchResult<DualTensor<T>>,
{
let shape = input.shape();
if shape.len() != 1 {
return Err(FerrotorchError::InvalidArgument {
message: format!("jacfwd: input must be 1-D, got shape {:?}", shape),
});
}
let n = shape[0];
// An empty input yields no columns, which would make the output dimension
// below unrecoverable (and index out of bounds), so reject it up front.
if n == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "jacfwd: input must have at least one element".to_string(),
});
}
let zero = <T as num_traits::Zero>::zero();
let one = <T as num_traits::One>::one();
let mut columns: Vec<Tensor<T>> = Vec::with_capacity(n);
for j in 0..n {
let mut basis = vec![zero; n];
basis[j] = one;
let e_j = Tensor::from_storage(TensorStorage::cpu(basis), vec![n], false)?;
let (_primal, tangent) = jvp_exact(&f, input, &e_j)?;
columns.push(tangent);
}
let m = columns[0].numel();
let mut jac_data = vec![zero; m * n];
for j in 0..n {
let col_data = columns[j].data_vec()?;
for i in 0..m {
jac_data[i * n + j] = col_data[i];
}
}
Tensor::from_storage(TensorStorage::cpu(jac_data), vec![m, n], false)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
fn leaf_vec(data: &[f32], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
vec![data.len()],
requires_grad,
)
.unwrap()
}
fn leaf_mat(data: &[f32], rows: usize, cols: usize) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![rows, cols], false).unwrap()
}
fn assert_approx(actual: f32, expected: f32, tol: f32, msg: &str) {
assert!(
(actual - expected).abs() < tol,
"{msg}: expected {expected}, got {actual}"
);
}
#[test]
fn test_dual_tensor_new() {
let primal = leaf_vec(&[1.0, 2.0, 3.0], false);
let tangent = leaf_vec(&[0.1, 0.2, 0.3], false);
let dual = DualTensor::new(primal, tangent).unwrap();
assert_eq!(dual.shape(), &[3]);
assert_eq!(dual.numel(), 3);
}
#[test]
fn test_dual_tensor_shape_mismatch() {
let primal = leaf_vec(&[1.0, 2.0], false);
let tangent = leaf_vec(&[0.1, 0.2, 0.3], false);
assert!(DualTensor::new(primal, tangent).is_err());
}
#[test]
fn test_dual_tensor_constant() {
let primal = leaf_vec(&[1.0, 2.0], false);
let dual = DualTensor::constant(primal).unwrap();
let t_data = dual.tangent.data_vec().unwrap();
assert_eq!(t_data, vec![0.0, 0.0]);
}
#[test]
fn test_dual_add() {
let a =
DualTensor::new(leaf_vec(&[1.0, 2.0], false), leaf_vec(&[0.5, 0.3], false)).unwrap();
let b =
DualTensor::new(leaf_vec(&[3.0, 4.0], false), leaf_vec(&[0.1, 0.2], false)).unwrap();
let c = dual_add(&a, &b).unwrap();
let p = c.primal.data_vec().unwrap();
let t = c.tangent.data_vec().unwrap();
assert_approx(p[0], 4.0, 1e-6, "add primal[0]");
assert_approx(p[1], 6.0, 1e-6, "add primal[1]");
assert_approx(t[0], 0.6, 1e-6, "add tangent[0]");
assert_approx(t[1], 0.5, 1e-6, "add tangent[1]");
}
#[test]
fn test_dual_sub() {
let a =
DualTensor::new(leaf_vec(&[5.0, 3.0], false), leaf_vec(&[1.0, 0.5], false)).unwrap();
let b =
DualTensor::new(leaf_vec(&[2.0, 1.0], false), leaf_vec(&[0.3, 0.1], false)).unwrap();
let c = dual_sub(&a, &b).unwrap();
let p = c.primal.data_vec().unwrap();
let t = c.tangent.data_vec().unwrap();
assert_approx(p[0], 3.0, 1e-6, "sub primal[0]");
assert_approx(p[1], 2.0, 1e-6, "sub primal[1]");
assert_approx(t[0], 0.7, 1e-6, "sub tangent[0]");
assert_approx(t[1], 0.4, 1e-6, "sub tangent[1]");
}
#[test]
fn test_dual_mul() {
let a = DualTensor::new(leaf_vec(&[2.0], false), leaf_vec(&[0.5], false)).unwrap();
let b = DualTensor::new(leaf_vec(&[3.0], false), leaf_vec(&[0.1], false)).unwrap();
let c = dual_mul(&a, &b).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 6.0, 1e-6, "mul primal");
assert_approx(c.tangent.data_vec().unwrap()[0], 1.7, 1e-5, "mul tangent");
}
#[test]
fn test_dual_div() {
let a = DualTensor::new(leaf_vec(&[6.0], false), leaf_vec(&[1.0], false)).unwrap();
let b = DualTensor::new(leaf_vec(&[3.0], false), leaf_vec(&[0.5], false)).unwrap();
let c = dual_div(&a, &b).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 2.0, 1e-6, "div primal");
assert_approx(c.tangent.data_vec().unwrap()[0], 0.0, 1e-5, "div tangent");
}
#[test]
fn test_dual_neg() {
let a =
DualTensor::new(leaf_vec(&[3.0, -2.0], false), leaf_vec(&[1.0, 0.5], false)).unwrap();
let c = dual_neg(&a).unwrap();
let p = c.primal.data_vec().unwrap();
let t = c.tangent.data_vec().unwrap();
assert_approx(p[0], -3.0, 1e-6, "neg primal[0]");
assert_approx(p[1], 2.0, 1e-6, "neg primal[1]");
assert_approx(t[0], -1.0, 1e-6, "neg tangent[0]");
assert_approx(t[1], -0.5, 1e-6, "neg tangent[1]");
}
#[test]
fn test_dual_matmul() {
let a_primal = leaf_mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
let a_tangent = leaf_mat(&[0.1, 0.0, 0.0, 0.1], 2, 2);
let b_primal = leaf_mat(&[5.0, 6.0, 7.0, 8.0], 2, 2);
let b_tangent = leaf_mat(&[0.0, 0.0, 0.0, 0.0], 2, 2);
let a = DualTensor::new(a_primal, a_tangent).unwrap();
let b = DualTensor::new(b_primal, b_tangent).unwrap();
let c = dual_matmul(&a, &b).unwrap();
let p = c.primal.data_vec().unwrap();
assert_approx(p[0], 19.0, 1e-5, "matmul primal[0,0]");
assert_approx(p[1], 22.0, 1e-5, "matmul primal[0,1]");
assert_approx(p[2], 43.0, 1e-5, "matmul primal[1,0]");
assert_approx(p[3], 50.0, 1e-5, "matmul primal[1,1]");
let t = c.tangent.data_vec().unwrap();
assert_approx(t[0], 0.5, 1e-4, "matmul tangent[0,0]");
assert_approx(t[1], 0.6, 1e-4, "matmul tangent[0,1]");
assert_approx(t[2], 0.7, 1e-4, "matmul tangent[1,0]");
assert_approx(t[3], 0.8, 1e-4, "matmul tangent[1,1]");
}
#[test]
fn test_dual_relu_positive() {
let a =
DualTensor::new(leaf_vec(&[2.0, 3.0], false), leaf_vec(&[0.5, 1.0], false)).unwrap();
let c = dual_relu(&a).unwrap();
let p = c.primal.data_vec().unwrap();
let t = c.tangent.data_vec().unwrap();
assert_approx(p[0], 2.0, 1e-6, "relu primal[0]");
assert_approx(p[1], 3.0, 1e-6, "relu primal[1]");
assert_approx(t[0], 0.5, 1e-6, "relu tangent[0]");
assert_approx(t[1], 1.0, 1e-6, "relu tangent[1]");
}
#[test]
fn test_dual_relu_negative() {
let a =
DualTensor::new(leaf_vec(&[-1.0, -5.0], false), leaf_vec(&[0.5, 1.0], false)).unwrap();
let c = dual_relu(&a).unwrap();
let p = c.primal.data_vec().unwrap();
let t = c.tangent.data_vec().unwrap();
assert_approx(p[0], 0.0, 1e-6, "relu neg primal[0]");
assert_approx(p[1], 0.0, 1e-6, "relu neg primal[1]");
assert_approx(t[0], 0.0, 1e-6, "relu neg tangent[0]");
assert_approx(t[1], 0.0, 1e-6, "relu neg tangent[1]");
}
#[test]
fn test_dual_sigmoid() {
let a = DualTensor::new(leaf_vec(&[0.0], false), leaf_vec(&[1.0], false)).unwrap();
let c = dual_sigmoid(&a).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 0.5, 1e-5, "sigmoid primal");
assert_approx(
c.tangent.data_vec().unwrap()[0],
0.25,
1e-5,
"sigmoid tangent",
);
}
#[test]
fn test_dual_tanh() {
let a = DualTensor::new(leaf_vec(&[0.0], false), leaf_vec(&[1.0], false)).unwrap();
let c = dual_tanh(&a).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 0.0, 1e-5, "tanh primal");
assert_approx(c.tangent.data_vec().unwrap()[0], 1.0, 1e-5, "tanh tangent");
}
#[test]
fn test_dual_exp() {
let a = DualTensor::new(leaf_vec(&[0.0], false), leaf_vec(&[1.0], false)).unwrap();
let c = dual_exp(&a).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 1.0, 1e-5, "exp primal");
assert_approx(c.tangent.data_vec().unwrap()[0], 1.0, 1e-5, "exp tangent");
}
#[test]
fn test_dual_exp_nonzero() {
let a = DualTensor::new(leaf_vec(&[1.0], false), leaf_vec(&[2.0], false)).unwrap();
let c = dual_exp(&a).unwrap();
let e = std::f32::consts::E;
assert_approx(c.primal.data_vec().unwrap()[0], e, 1e-5, "exp(1) primal");
assert_approx(
c.tangent.data_vec().unwrap()[0],
2.0 * e,
1e-4,
"exp(1) tangent",
);
}
#[test]
fn test_dual_log() {
let e = std::f32::consts::E;
let a = DualTensor::new(leaf_vec(&[e], false), leaf_vec(&[2.0], false)).unwrap();
let c = dual_log(&a).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 1.0, 1e-5, "log primal");
assert_approx(
c.tangent.data_vec().unwrap()[0],
2.0 / e,
1e-5,
"log tangent",
);
}
#[test]
fn test_dual_sin() {
let a = DualTensor::new(leaf_vec(&[0.0], false), leaf_vec(&[1.0], false)).unwrap();
let c = dual_sin(&a).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 0.0, 1e-6, "sin primal");
assert_approx(c.tangent.data_vec().unwrap()[0], 1.0, 1e-5, "sin tangent");
}
#[test]
fn test_dual_sin_at_pi_half() {
let pi_half = std::f32::consts::FRAC_PI_2;
let a = DualTensor::new(leaf_vec(&[pi_half], false), leaf_vec(&[1.0], false)).unwrap();
let c = dual_sin(&a).unwrap();
assert_approx(
c.primal.data_vec().unwrap()[0],
1.0,
1e-5,
"sin(pi/2) primal",
);
assert_approx(
c.tangent.data_vec().unwrap()[0],
0.0,
1e-5,
"sin(pi/2) tangent",
);
}
#[test]
fn test_dual_cos() {
let a = DualTensor::new(leaf_vec(&[0.0], false), leaf_vec(&[1.0], false)).unwrap();
let c = dual_cos(&a).unwrap();
assert_approx(c.primal.data_vec().unwrap()[0], 1.0, 1e-6, "cos primal");
assert_approx(c.tangent.data_vec().unwrap()[0], 0.0, 1e-6, "cos tangent");
}
#[test]
fn test_dual_cos_at_pi_half() {
let pi_half = std::f32::consts::FRAC_PI_2;
let a = DualTensor::new(leaf_vec(&[pi_half], false), leaf_vec(&[1.0], false)).unwrap();
let c = dual_cos(&a).unwrap();
assert_approx(
c.primal.data_vec().unwrap()[0],
0.0,
1e-5,
"cos(pi/2) primal",
);
assert_approx(
c.tangent.data_vec().unwrap()[0],
-1.0,
1e-5,
"cos(pi/2) tangent",
);
}
#[test]
fn test_jvp_exact_identity() {
let input = leaf_vec(&[1.0, 2.0, 3.0], false);
let v = leaf_vec(&[4.0, 5.0, 6.0], false);
let (primal, tangent) = jvp_exact(|x| Ok(x), &input, &v).unwrap();
let p = primal.data_vec().unwrap();
let t = tangent.data_vec().unwrap();
assert_approx(p[0], 1.0, 1e-6, "jvp identity primal[0]");
assert_approx(t[0], 4.0, 1e-6, "jvp identity tangent[0]");
assert_approx(t[1], 5.0, 1e-6, "jvp identity tangent[1]");
assert_approx(t[2], 6.0, 1e-6, "jvp identity tangent[2]");
}
#[test]
fn test_jvp_exact_square() {
let input = leaf_vec(&[3.0, 4.0], false);
let v = leaf_vec(&[1.0, 1.0], false);
let (_primal, tangent) = jvp_exact(|x| dual_mul(&x, &x), &input, &v).unwrap();
let t = tangent.data_vec().unwrap();
assert_approx(t[0], 6.0, 1e-5, "jvp x^2 tangent[0]");
assert_approx(t[1], 8.0, 1e-5, "jvp x^2 tangent[1]");
}
#[test]
fn test_jvp_exact_composition() {
let input = leaf_vec(&[1.0], false);
let v = leaf_vec(&[1.0], false);
let (_primal, tangent) = jvp_exact(
|x| {
let x2 = dual_mul(&x, &x)?;
dual_exp(&x2)
},
&input,
&v,
)
.unwrap();
let e = std::f32::consts::E;
assert_approx(
tangent.data_vec().unwrap()[0],
2.0 * e,
1e-4,
"jvp exp(x^2) tangent",
);
}
#[test]
fn test_jvp_exact_shape_mismatch() {
let input = leaf_vec(&[1.0, 2.0], false);
let v = leaf_vec(&[1.0], false);
assert!(jvp_exact(|x| Ok(x), &input, &v).is_err());
}
#[test]
fn test_jvp_exact_matches_finite_diff() {
let input = leaf_vec(&[3.0, 4.0], false);
let v = leaf_vec(&[1.0, 1.0], false);
let (_primal, exact_tangent) = jvp_exact(|x| dual_mul(&x, &x), &input, &v).unwrap();
let exact = exact_tangent.data_vec().unwrap();
assert_approx(exact[0], 6.0, 1e-6, "exact jvp[0]");
assert_approx(exact[1], 8.0, 1e-6, "exact jvp[1]");
// Actually compare against the finite difference the test name promises:
// (f(x + h*v) - f(x)) / h for f(x) = x^2 and v = 1.
let h = 1e-3f32;
let xs = [3.0f32, 4.0];
for i in 0..2 {
let fd = ((xs[i] + h).powi(2) - xs[i].powi(2)) / h;
assert_approx(exact[i], fd, 5e-2, &format!("jvp vs finite diff [{i}]"));
}
}
#[test]
fn test_jacfwd_linear() {
let input = leaf_vec(&[1.0, 2.0, 3.0], false);
let jac = jacfwd(
|x| {
let two = DualTensor::constant(
Tensor::from_storage(TensorStorage::cpu(vec![2.0f32; 3]), vec![3], false)
.unwrap(),
)
.unwrap();
dual_mul(&two, &x)
},
&input,
)
.unwrap();
assert_eq!(jac.shape(), &[3, 3]);
let data = jac.data_vec().unwrap();
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 2.0 } else { 0.0 };
assert_approx(
data[i * 3 + j],
expected,
1e-5,
&format!("jacfwd 2x [{i},{j}]"),
);
}
}
}
#[test]
fn test_jacfwd_quadratic() {
let input = leaf_vec(&[1.0, 2.0, 3.0], false);
let jac = jacfwd(|x| dual_mul(&x, &x), &input).unwrap();
assert_eq!(jac.shape(), &[3, 3]);
let data = jac.data_vec().unwrap();
let expected_diag = [2.0, 4.0, 6.0];
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { expected_diag[i] } else { 0.0 };
assert_approx(
data[i * 3 + j],
expected,
1e-5,
&format!("jacfwd x^2 [{i},{j}]"),
);
}
}
}
#[test]
fn test_jacfwd_non_1d_input_error() {
let input = leaf_mat(&[1.0, 2.0, 3.0, 4.0], 2, 2);
assert!(jacfwd(|x| Ok(x), &input).is_err());
}
#[test]
fn test_dual_chain_add_mul() {
let x = DualTensor::new(leaf_vec(&[3.0], false), leaf_vec(&[1.0], false)).unwrap();
let y = DualTensor::constant(leaf_vec(&[2.0], false)).unwrap();
let sum = dual_add(&x, &y).unwrap();
let prod = dual_mul(&sum, &x).unwrap();
assert_approx(
prod.primal.data_vec().unwrap()[0],
15.0,
1e-5,
"chain primal",
);
assert_approx(
prod.tangent.data_vec().unwrap()[0],
8.0,
1e-5,
"chain tangent",
);
}
#[test]
fn test_dual_log_exp_roundtrip() {
let x = DualTensor::new(leaf_vec(&[2.0], false), leaf_vec(&[1.0], false)).unwrap();
let ex = dual_exp(&x).unwrap();
let result = dual_log(&ex).unwrap();
assert_approx(
result.primal.data_vec().unwrap()[0],
2.0,
1e-4,
"log(exp) primal",
);
assert_approx(
result.tangent.data_vec().unwrap()[0],
1.0,
1e-4,
"log(exp) tangent",
);
}
#[test]
fn test_dual_sin_cos_derivative_identity() {
let val = 1.5f32;
let x = DualTensor::new(leaf_vec(&[val], false), leaf_vec(&[1.0], false)).unwrap();
let sx = dual_sin(&x).unwrap();
let cx = dual_cos(&x).unwrap();
let s2 = dual_mul(&sx, &sx).unwrap();
let c2 = dual_mul(&cx, &cx).unwrap();
let sum = dual_add(&s2, &c2).unwrap();
assert_approx(
sum.primal.data_vec().unwrap()[0],
1.0,
1e-4,
"sin^2+cos^2 primal",
);
assert_approx(
sum.tangent.data_vec().unwrap()[0],
0.0,
1e-4,
"sin^2+cos^2 tangent",
);
}
#[test]
fn test_jacfwd_sin() {
let pi_half = std::f32::consts::FRAC_PI_2;
let input = leaf_vec(&[0.0, pi_half], false);
let jac = jacfwd(|x| dual_sin(&x), &input).unwrap();
let data = jac.data_vec().unwrap();
assert_approx(data[0], 1.0, 1e-5, "jacfwd sin [0,0]");
assert_approx(data[1], 0.0, 1e-5, "jacfwd sin [0,1]");
assert_approx(data[2], 0.0, 1e-5, "jacfwd sin [1,0]");
assert_approx(data[3], 0.0, 1e-5, "jacfwd sin [1,1]");
}
}