use crate::algorithm::LinearAlgebraAlgorithms;
use crate::autograd::var_ops::{var_matmul, var_mul, var_neg};
use crate::autograd::{GradFn, Var};
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{
BinaryOps, LinalgOps, MatmulOps, ScalarOps, TensorOps, TypeConversionOps, UnaryOps,
};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
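/// Computes Phi(X) used by the Cholesky backward pass: the lower-triangular
/// part of the square matrix `x` with the diagonal halved and the strict
/// upper triangle zeroed, implemented as an elementwise mask multiply.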
fn tril_with_halved_diagonal<R: Runtime<DType = DType>>(
x: &Tensor<R>,
client: &R::Client,
) -> Result<Tensor<R>>
where
R::Client: TensorOps<R> + ScalarOps<R>,
{
    debug_assert_eq!(x.shape().len(), 2);
    debug_assert_eq!(x.shape()[0], x.shape()[1]);
    let n = x.shape()[0];
let mut mask_data = vec![0.0f64; n * n];
for i in 0..n {
for j in 0..i {
mask_data[i * n + j] = 1.0;
}
mask_data[i * n + i] = 0.5;
}
let mask_f64 = Tensor::<R>::from_slice(&mask_data, &[n, n], x.device());
let mask = client.cast(&mask_f64, x.dtype())?;
let phi = client.mul(x, &mask)?;
Ok(phi)
}
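/// Backward pass for `trace`: d tr(A)/dA = I, so the gradient is the identity
/// matrix scaled by the (scalar) upstream gradient.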
pub struct TraceBackward<R: Runtime> {
input_ids: [TensorId; 1],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 1],
}
impl<R: Runtime> TraceBackward<R> {
pub fn new(a_id: TensorId, a: Tensor<R>, a_grad_fn: Option<Arc<dyn GradFn<R>>>) -> Self {
Self {
input_ids: [a_id],
saved_tensors: vec![a],
input_grad_fns: [a_grad_fn],
}
}
}
impl<R: Runtime<DType = DType>> GradFn<R> for TraceBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let saved_a = &self.saved_tensors[0];
let n = saved_a.shape()[0];
let client = R::default_client(saved_a.device());
let ones_vec = Tensor::<R>::ones(&[n], saved_a.dtype(), saved_a.device());
let scaled_diag = client.mul(&ones_vec, grad_output)?;
        let grad_a = LinalgOps::diagflat(&client, &scaled_diag)?;
        Ok(vec![Some(grad_a)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
let saved_a = &self.saved_tensors[0];
let n = saved_a.shape()[0];
let client = R::default_client(saved_a.device());
let ones_vec = Tensor::<R>::ones(&[n], saved_a.dtype(), saved_a.device());
let ones_var = Var::new(ones_vec, false);
let scaled_diag = var_mul(&ones_var, grad_output, &client)?;
        let grad_a = LinalgOps::diagflat(&client, scaled_diag.tensor())?;
        Ok(vec![Some(Var::new(grad_a, false))])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"TraceBackward"
}
}
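/// Backward pass for matrix inverse: for Y = A^{-1},
/// grad_A = -Y^T * grad_output * Y^T. Only the inverse Y is saved.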
pub struct InverseBackward<R: Runtime> {
input_ids: [TensorId; 1],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 1],
}
impl<R: Runtime> InverseBackward<R> {
pub fn new(a_id: TensorId, inv_a: Tensor<R>, a_grad_fn: Option<Arc<dyn GradFn<R>>>) -> Self {
Self {
input_ids: [a_id],
saved_tensors: vec![inv_a],
input_grad_fns: [a_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for InverseBackward<R>
where
R::Client: MatmulOps<R> + TensorOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let inv_a = &self.saved_tensors[0];
let inv_a_t = inv_a.t()?;
let temp = client.matmul(&inv_a_t, grad_output)?;
let grad_a = client.matmul(&temp, &inv_a_t)?;
let grad_a = client.neg(&grad_a)?;
Ok(vec![Some(grad_a)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + MatmulOps<R> + TensorOps<R>,
{
let client = R::default_client(grad_output.tensor().device());
let inv_a = &self.saved_tensors[0];
let inv_a_t = inv_a.t()?;
let inv_a_t_var = Var::new(inv_a_t, false);
let temp = var_matmul(&inv_a_t_var, grad_output, &client)?;
let inv_a_t2 = inv_a.t()?;
let inv_a_t_var2 = Var::new(inv_a_t2, false);
let grad_a = var_matmul(&temp, &inv_a_t_var2, &client)?;
let grad_a = var_neg(&grad_a, &client)?;
Ok(vec![Some(grad_a)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"InverseBackward"
}
}
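/// Backward pass for determinant: d det(A)/dA = det(A) * A^{-T}, so
/// grad_A = grad_output * det(A) * A^{-T}. Both A and det(A) are saved.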
pub struct DetBackward<R: Runtime> {
input_ids: [TensorId; 1],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 1],
}
impl<R: Runtime> DetBackward<R> {
pub fn new(
a_id: TensorId,
a: Tensor<R>,
det_output: Tensor<R>,
a_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id],
saved_tensors: vec![a, det_output],
input_grad_fns: [a_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for DetBackward<R>
where
R::Client: TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
use crate::error::Error;
let client = R::default_client(grad_output.device());
let saved_a = &self.saved_tensors[0];
let det_output = &self.saved_tensors[1];
let inv_a = LinalgOps::inverse(&client, saved_a).map_err(|e| {
Error::Internal(format!(
"DetBackward: failed to compute inverse for gradient \
(matrix may be singular or nearly singular): {}",
e
))
})?;
let inv_a_t = inv_a.t()?.contiguous();
let det_scaled = client.mul(&inv_a_t, det_output)?;
let grad_a = client.mul(&det_scaled, grad_output)?;
Ok(vec![Some(grad_a)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
use crate::error::Error;
let client = R::default_client(grad_output.tensor().device());
let saved_a = &self.saved_tensors[0];
let det_output = &self.saved_tensors[1];
let inv_a = LinalgOps::inverse(&client, saved_a).map_err(|e| {
Error::Internal(format!(
"DetBackward: failed to compute inverse for gradient \
(matrix may be singular or nearly singular): {}",
e
))
})?;
let inv_a_t = inv_a.t()?.contiguous();
let det_scaled = client.mul(&inv_a_t, det_output)?;
let det_scaled_var = Var::new(det_scaled, false);
let grad_a = var_mul(&det_scaled_var, grad_output, &client)?;
Ok(vec![Some(grad_a)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"DetBackward"
}
}
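/// Backward pass for `solve(A, B) = X`: grad_B = A^{-T} * grad_output and
/// grad_A = -grad_B * X^T. Both A and the solution X are saved.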
pub struct SolveBackward<R: Runtime> {
    input_ids: [TensorId; 2],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 2],
}
impl<R: Runtime> SolveBackward<R> {
pub fn new(
a_id: TensorId,
b_id: TensorId,
a: Tensor<R>,
        x: Tensor<R>,
        a_grad_fn: Option<Arc<dyn GradFn<R>>>,
b_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_ids: [a_id, b_id],
saved_tensors: vec![a, x],
input_grad_fns: [a_grad_fn, b_grad_fn],
}
}
}
impl<R: Runtime> GradFn<R> for SolveBackward<R>
where
R::Client: MatmulOps<R> + TensorOps<R> + LinearAlgebraAlgorithms<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let saved_a = &self.saved_tensors[0];
let saved_x = &self.saved_tensors[1];
let a_t = saved_a.t()?.contiguous();
let v = LinalgOps::solve(&client, &a_t, grad_output)?;
let grad_b = v.clone();
let x_t = saved_x.t()?;
let grad_a = client.matmul(&v, &x_t)?;
let grad_a = client.neg(&grad_a)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R> + MatmulOps<R> + TensorOps<R> + LinearAlgebraAlgorithms<R>,
{
let client = R::default_client(grad_output.tensor().device());
let saved_a = &self.saved_tensors[0];
let saved_x = &self.saved_tensors[1];
let a_t = saved_a.t()?.contiguous();
let v = LinalgOps::solve(&client, &a_t, grad_output.tensor())?;
let grad_b = Var::new(v.clone(), false);
let v_var = Var::new(v, false);
let x_t = saved_x.t()?;
let x_t_var = Var::new(x_t, false);
let grad_a = var_matmul(&v_var, &x_t_var, &client)?;
let grad_a = var_neg(&grad_a, &client)?;
Ok(vec![Some(grad_a), Some(grad_b)])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"SolveBackward"
}
}
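/// Backward pass for the Cholesky decomposition A = L * L^T:
/// grad_A = (Y + Y^T) / 2 with Y = L^{-T} * Phi(L^T * grad_L) * L^{-1},
/// where Phi keeps the lower triangle and halves the diagonal.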
pub struct CholeskyBackward<R: Runtime> {
input_ids: [TensorId; 1],
    saved_tensors: Vec<Tensor<R>>,
    input_grad_fns: [Option<Arc<dyn GradFn<R>>>; 1],
}
impl<R: Runtime> CholeskyBackward<R> {
pub fn new(a_id: TensorId, l: Tensor<R>, a_grad_fn: Option<Arc<dyn GradFn<R>>>) -> Self {
Self {
input_ids: [a_id],
saved_tensors: vec![l],
input_grad_fns: [a_grad_fn],
}
}
}
impl<R: Runtime<DType = DType>> GradFn<R> for CholeskyBackward<R>
where
R::Client: MatmulOps<R> + TensorOps<R> + ScalarOps<R> + LinearAlgebraAlgorithms<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
use crate::error::Error;
let client = R::default_client(grad_output.device());
let l = &self.saved_tensors[0];
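        // Step 1: S = L^T * grad_L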
let l_t = l.t()?.contiguous();
let s = client.matmul(&l_t, grad_output)?;
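        // Step 2: Phi(S), the lower triangle of S with a halved diagonal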
let phi = tril_with_halved_diagonal::<R>(&s.contiguous(), &client)?;
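        // Step 3: solve L^T * W = Phi(S)^T for W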
let phi_t = phi.t()?.contiguous();
let w = client.solve_triangular_upper(&l_t, &phi_t).map_err(|e| {
Error::Internal(format!(
"CholeskyBackward: triangular solve failed in step 3 \
(L may have zero diagonal elements): {}",
e
))
})?;
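        // Step 4: solve L^T * Y = W^T, giving Y = L^{-T} * Phi(S) * L^{-1}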
let z = w.t()?.contiguous();
let grad_a_raw = client.solve_triangular_upper(&l_t, &z).map_err(|e| {
Error::Internal(format!(
"CholeskyBackward: triangular solve failed in step 4 \
(L may have zero diagonal elements): {}",
e
))
})?;
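        // Step 5: symmetrize the result: grad_A = (Y + Y^T) / 2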
let y_contiguous = grad_a_raw.contiguous();
let y_t = y_contiguous.t()?.contiguous();
let sum = client.add(&y_contiguous, &y_t)?;
let grad_a = client.div_scalar(&sum, 2.0)?;
Ok(vec![Some(grad_a)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
where
R::Client: RuntimeClient<R>
+ MatmulOps<R>
+ TensorOps<R>
+ ScalarOps<R>
+ LinearAlgebraAlgorithms<R>,
{
use crate::error::Error;
let client = R::default_client(grad_output.tensor().device());
let l = &self.saved_tensors[0];
let l_t = l.t()?.contiguous();
let l_t_var = Var::new(l_t.clone(), false);
let s = var_matmul(&l_t_var, grad_output, &client)?;
let phi = tril_with_halved_diagonal::<R>(&s.tensor().contiguous(), &client)?;
let phi_t = phi.t()?.contiguous();
let w = client.solve_triangular_upper(&l_t, &phi_t).map_err(|e| {
Error::Internal(format!(
"CholeskyBackward: triangular solve failed in step 3: {}",
e
))
})?;
let z = w.t()?.contiguous();
let grad_a_raw = client.solve_triangular_upper(&l_t, &z).map_err(|e| {
Error::Internal(format!(
"CholeskyBackward: triangular solve failed in step 4: {}",
e
))
})?;
let y_contiguous = grad_a_raw.contiguous();
let y_t = y_contiguous.t()?.contiguous();
let sum = client.add(&y_contiguous, &y_t)?;
let grad_a = client.div_scalar(&sum, 2.0)?;
Ok(vec![Some(Var::new(grad_a, false))])
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.to_vec()
}
fn saved_tensors(&self) -> &[Tensor<R>] {
&self.saved_tensors
}
fn name(&self) -> &'static str {
"CholeskyBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::DType;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
(a - b).abs() < tol
}
fn approx_eq_vec(a: &[f64], b: &[f64], tol: f64) -> bool {
if a.len() != b.len() {
return false;
}
a.iter().zip(b.iter()).all(|(x, y)| approx_eq(*x, *y, tol))
}
#[test]
fn test_trace_backward() {
let device = CpuDevice::new();
let _client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0, 4.0], &[2, 2], &device);
let grad_out = Tensor::<CpuRuntime>::from_slice(&[1.0f64], &[], &device);
let backward = TraceBackward::<CpuRuntime>::new(a.id(), a.clone(), None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f64> = grads[0].as_ref().unwrap().to_vec();
assert!(approx_eq_vec(&grad_a, &[1.0, 0.0, 0.0, 1.0], 1e-10));
}
#[test]
fn test_inverse_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[2.0f64, 1.0, 1.0, 2.0], &[2, 2], &device);
let inv_a = LinalgOps::inverse(&client, &a).unwrap();
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 2], DType::F64, &device);
let backward = InverseBackward::<CpuRuntime>::new(a.id(), inv_a.clone(), None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 2]);
let grad_a_data: Vec<f64> = grad_a.to_vec();
        assert!(grad_a_data[0] < 0.0);
    }
#[test]
fn test_det_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[2.0f64, 1.0, 1.0, 2.0], &[2, 2], &device);
let det_output = LinalgOps::det(&client, &a).unwrap();
let grad_out = Tensor::<CpuRuntime>::from_slice(&[1.0f64], &[], &device);
let backward = DetBackward::<CpuRuntime>::new(a.id(), a.clone(), det_output, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a: Vec<f64> = grads[0].as_ref().unwrap().to_vec();
assert!(approx_eq_vec(&grad_a, &[2.0, -1.0, -1.0, 2.0], 1e-10));
}
#[test]
fn test_solve_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[2.0f64, 1.0, 1.0, 2.0], &[2, 2], &device);
let b = Tensor::<CpuRuntime>::from_slice(&[3.0f64, 3.0], &[2, 1], &device);
let x = LinalgOps::solve(&client, &a, &b).unwrap();
let grad_out = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0], &[2, 1], &device);
let backward =
SolveBackward::<CpuRuntime>::new(a.id(), b.id(), a.clone(), x.clone(), None, None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
let grad_b = grads[1].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 2]);
assert_eq!(grad_b.shape(), &[2, 1]);
let grad_b_data: Vec<f64> = grad_b.to_vec();
assert!(approx_eq_vec(&grad_b_data, &[1.0 / 3.0, 1.0 / 3.0], 1e-10));
}
#[test]
fn test_cholesky_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Tensor::<CpuRuntime>::from_slice(&[4.0f64, 2.0, 2.0, 5.0], &[2, 2], &device);
let l = client.cholesky_decompose(&a).unwrap().l;
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 2], DType::F64, &device);
let backward = CholeskyBackward::<CpuRuntime>::new(a.id(), l.clone(), None);
let grads = backward.backward(&grad_out).unwrap();
let grad_a = grads[0].as_ref().unwrap();
assert_eq!(grad_a.shape(), &[2, 2]);
let grad_a_data: Vec<f64> = grad_a.to_vec();
assert!(
approx_eq(grad_a_data[1], grad_a_data[2], 1e-10),
"grad_a[0,1] = {}, grad_a[1,0] = {}, diff = {}",
grad_a_data[1],
grad_a_data[2],
(grad_a_data[1] - grad_a_data[2]).abs()
);
}
}