use super::ops::*;
use crate::algorithm::LinearAlgebraAlgorithms;
use crate::autograd::Var;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use std::sync::Arc;
/// Differentiable trace of `a`.
///
/// The forward value comes from `client.trace`. If `a` requires gradients, the value is
/// attached to a `TraceBackward` node that holds `a`'s id, a clone of the input tensor and
/// `a`'s own grad function. Otherwise the value is wrapped with `Var::new(.., false)` and
/// no gradient is tracked.
///
/// # Errors
///
/// Propagates any error returned by `client.trace`.
pub fn var_trace<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + LinearAlgebraAlgorithms<R>,
    R::Client: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
    let traced = client.trace(a.tensor())?;

    // Inputs that do not need gradients produce an untracked result.
    if !a.requires_grad() {
        return Ok(Var::new(traced, false));
    }

    // The backward pass needs the original input, so keep a clone of it.
    let backward_node = TraceBackward::<R>::new(a.id(), a.tensor().clone(), a.grad_fn().cloned());
    Ok(Var::from_op(traced, Arc::new(backward_node)))
}
/// Differentiable matrix inverse of `a`.
///
/// The forward value comes from `client.inverse`. When `a` requires gradients, the
/// `InverseBackward` node stores a clone of the computed inverse rather than the input,
/// together with `a`'s id and grad function. Otherwise the result is wrapped with
/// `Var::new(.., false)`.
///
/// # Errors
///
/// Propagates any error returned by `client.inverse`. Presumably this happens for
/// singular or non-square input; confirm against the backend.
pub fn var_inverse<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + LinearAlgebraAlgorithms<R>,
    R::Client: TensorOps<R>,
{
    let inverse = client.inverse(a.tensor())?;

    Ok(if a.requires_grad() {
        // Build the node before moving `inverse` into the result Var.
        let node = InverseBackward::<R>::new(a.id(), inverse.clone(), a.grad_fn().cloned());
        Var::from_op(inverse, Arc::new(node))
    } else {
        Var::new(inverse, false)
    })
}
/// Differentiable determinant of `a`.
///
/// The forward value comes from `client.det`. When `a` requires gradients, a
/// `DetBackward` node is recorded that keeps clones of both the input tensor and the
/// computed determinant, plus `a`'s id and grad function. Otherwise the result is wrapped
/// with `Var::new(.., false)`.
///
/// # Errors
///
/// Propagates any error returned by `client.det`.
pub fn var_det<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + LinearAlgebraAlgorithms<R>,
    R::Client: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
    let determinant = client.det(a.tensor())?;

    let result = if a.requires_grad() {
        // The backward node needs both the input matrix and the forward value.
        let node = DetBackward::<R>::new(
            a.id(),
            a.tensor().clone(),
            determinant.clone(),
            a.grad_fn().cloned(),
        );
        Var::from_op(determinant, Arc::new(node))
    } else {
        Var::new(determinant, false)
    };

    Ok(result)
}
/// Differentiable linear solve with coefficient `a` and right-hand side `b`.
///
/// The forward value comes from `client.solve(a, b)`. The tests in this file treat it as
/// the solution `x` of `a · x = b`. If either operand requires gradients, a
/// `SolveBackward` node is recorded. It holds both ids, a clone of `a`, a clone of the
/// solution, and both operands' grad functions. Otherwise the result is wrapped with
/// `Var::new(.., false)`.
///
/// # Errors
///
/// Propagates any error returned by `client.solve`.
pub fn var_solve<R, C>(a: &Var<R>, b: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + LinearAlgebraAlgorithms<R>,
    R::Client: TensorOps<R> + LinearAlgebraAlgorithms<R>,
{
    let solution = client.solve(a.tensor(), b.tensor())?;

    // Track the op if *either* input participates in differentiation.
    let needs_grad = a.requires_grad() || b.requires_grad();
    if !needs_grad {
        return Ok(Var::new(solution, false));
    }

    let node = SolveBackward::<R>::new(
        a.id(),
        b.id(),
        a.tensor().clone(),
        solution.clone(),
        a.grad_fn().cloned(),
        b.grad_fn().cloned(),
    );
    Ok(Var::from_op(solution, Arc::new(node)))
}
/// Differentiable Cholesky factor of `a`.
///
/// Runs `client.cholesky_decompose` and returns its `l` field. By naming this is presumably
/// the lower-triangular factor; confirm against the decomposition type. When `a` requires
/// gradients, a `CholeskyBackward` node is recorded that keeps a clone of that factor plus
/// `a`'s id and grad function. Otherwise the factor is wrapped with `Var::new(.., false)`.
///
/// # Errors
///
/// Propagates any error returned by `client.cholesky_decompose`. Presumably this covers
/// non-positive-definite input; confirm against the backend.
pub fn var_cholesky<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + LinearAlgebraAlgorithms<R>,
    R::Client: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
    // Only the `l` factor is exposed as the op's output.
    let factor = client.cholesky_decompose(a.tensor())?.l;

    if !a.requires_grad() {
        return Ok(Var::new(factor, false));
    }

    let node = CholeskyBackward::<R>::new(a.id(), factor.clone(), a.grad_fn().cloned());
    Ok(Var::from_op(factor, Arc::new(node)))
}
#[cfg(test)]
mod tests {
    // Backward-pass tests for the linear-algebra autograd wrappers. Each test uses a small
    // 2x2 matrix whose gradient has a closed form, or a finite-difference reference.
    use super::super::reduce;
    use super::*;
    use crate::autograd::backward;
    use crate::runtime::cpu::{CpuDevice, CpuRuntime};
    use crate::tensor::Tensor;

    /// Absolute tolerance used when comparing gradients against closed-form values.
    const TOL: f64 = 1e-6;

    /// Asserts that `actual` matches `expected` element-wise within `TOL`.
    ///
    /// Both slices are interpreted as row-major matrices with `cols` columns. A failure
    /// names the offending entry as `label[row,col]`.
    fn assert_all_close(label: &str, actual: &[f64], expected: &[f64], cols: usize) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "{} has {} elements, expected {}",
            label,
            actual.len(),
            expected.len()
        );
        for (idx, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < TOL,
                "{}[{},{}] = {}, expected {}",
                label,
                idx / cols,
                idx % cols,
                got,
                want
            );
        }
    }

    #[test]
    fn test_var_trace_backward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let a = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0, 4.0], &[2, 2], &device),
            true,
        );

        let trace = var_trace(&a, &client).unwrap();
        let grads = backward(&trace, &client).unwrap();

        // d tr(A) / dA = I.
        let grad_a: Vec<f64> = grads.get(a.id()).unwrap().to_vec();
        assert_all_close("grad_a", &grad_a, &[1.0, 0.0, 0.0, 1.0], 2);
    }

    #[test]
    fn test_var_inverse_backward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let a = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f64, 1.0, 1.0, 2.0], &[2, 2], &device),
            true,
        );

        let inv_a = var_inverse(&a, &client).unwrap();
        let sum_inv = reduce::var_sum(&inv_a, &[0, 1], false, &client).unwrap();
        let grads = backward(&sum_inv, &client).unwrap();

        let grad_a = grads.get(a.id()).unwrap();
        assert_eq!(grad_a.shape(), &[2, 2]);

        // d sum(A^-1) / dA = -A^-T 1 1^T A^-T. For A = [[2,1],[1,2]], A^-1 1 = [1/3, 1/3],
        // so every entry is -1/9.
        let grad_a_data: Vec<f64> = grad_a.to_vec();
        let expected = -1.0 / 9.0;
        assert_all_close("grad", &grad_a_data, &[expected; 4], 2);
    }

    #[test]
    fn test_var_det_backward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let a = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f64, 1.0, 1.0, 2.0], &[2, 2], &device),
            true,
        );

        let det = var_det(&a, &client).unwrap();
        let grads = backward(&det, &client).unwrap();

        // d det(A) / dA = det(A) * A^-T = 3 * (1/3) [[2,-1],[-1,2]].
        let grad_a: Vec<f64> = grads.get(a.id()).unwrap().to_vec();
        assert_all_close("grad_a", &grad_a, &[2.0, -1.0, -1.0, 2.0], 2);
    }

    #[test]
    fn test_var_solve_backward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let a = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f64, 1.0, 1.0, 2.0], &[2, 2], &device),
            true,
        );
        let b = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[3.0f64, 3.0], &[2, 1], &device),
            true,
        );

        let x = var_solve(&a, &b, &client).unwrap();
        let loss = reduce::var_sum(&x, &[0, 1], false, &client).unwrap();
        let grads = backward(&loss, &client).unwrap();

        let grad_a = grads.get(a.id()).unwrap();
        let grad_b = grads.get(b.id()).unwrap();
        assert_eq!(grad_a.shape(), &[2, 2]);
        assert_eq!(grad_b.shape(), &[2, 1]);

        // x = A^-1 b = [1, 1]^T, grad_b = A^-T 1 = [1/3, 1/3]^T, grad_a = -grad_b x^T.
        let grad_b_data: Vec<f64> = grad_b.to_vec();
        assert_all_close("grad_b", &grad_b_data, &[1.0 / 3.0; 2], 1);

        let grad_a_data: Vec<f64> = grad_a.to_vec();
        assert_all_close("grad_a", &grad_a_data, &[-1.0 / 3.0; 4], 2);
    }

    #[test]
    fn test_var_cholesky_backward() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        // Symmetric positive-definite input, row-major 2x2.
        let a_data = [4.0f64, 2.0, 2.0, 5.0];
        let a = Var::new(
            Tensor::<CpuRuntime>::from_slice(&a_data, &[2, 2], &device),
            true,
        );

        let l = var_cholesky(&a, &client).unwrap();
        let loss = reduce::var_sum(&l, &[0, 1], false, &client).unwrap();
        let grads = backward(&loss, &client).unwrap();

        let grad_a = grads.get(a.id()).unwrap();
        assert_eq!(grad_a.shape(), &[2, 2]);
        let grad_a_data: Vec<f64> = grad_a.to_vec();

        // The analytical gradient is expected to be symmetric, like the input.
        assert!(
            (grad_a_data[1] - grad_a_data[2]).abs() < 1e-10,
            "grad_a[0,1] = {}, grad_a[1,0] = {}",
            grad_a_data[1],
            grad_a_data[2]
        );

        // Reference: central finite differences of sum(cholesky(A).l).
        let sum_of_cholesky = |data: &[f64; 4]| -> f64 {
            let t = Tensor::<CpuRuntime>::from_slice(data, &[2, 2], &device);
            client.cholesky_decompose(&t).unwrap().l.to_vec::<f64>().iter().sum()
        };

        let eps = 1e-5;
        for idx in 0..4 {
            let (i, j) = (idx / 2, idx % 2);
            let mut a_plus = a_data;
            let mut a_minus = a_data;
            a_plus[idx] += eps;
            a_minus[idx] -= eps;
            // Keep the perturbed matrix symmetric: an off-diagonal step also moves the
            // mirrored entry. The finite difference then measures the sum of the two
            // (equal) gradient entries.
            if i != j {
                let mirrored = j * 2 + i;
                a_plus[mirrored] += eps;
                a_minus[mirrored] -= eps;
            }

            let fd_grad = (sum_of_cholesky(&a_plus) - sum_of_cholesky(&a_minus)) / (2.0 * eps);
            let expected_grad = if i != j {
                2.0 * grad_a_data[idx]
            } else {
                grad_a_data[idx]
            };
            assert!(
                (fd_grad - expected_grad).abs() < 1e-4,
                "Finite diff gradient at ({},{}) = {}, analytical = {}",
                i,
                j,
                fd_grad,
                expected_grad
            );
        }
    }
}