burn-tensor 0.1.0

This library provides multiple tensor implementations hidden behind an easy-to-use API that supports reverse-mode automatic differentiation.
use crate::graph::node::{ForwardNode, ForwardNodeState};
use crate::graph::ops::{
    BinaryOps, BinaryOpsNodeState, ForwardBinaryRecordedOps, ForwardUnaryRecordedOps, UnaryOps,
    UnaryOpsNodeState,
};
use crate::tensor::backend::autodiff::{ADKind, ADTensor};
use crate::tensor::ops::*;
use crate::tensor::{Element, Tensor};
use std::{ops::Range, sync::Arc};

/// Backward operation recorded by [`TensorOpsIndex::index`] on an [`ADTensor`].
///
/// Keeps the ranges used to slice the forward input so that, during the backward
/// pass, the output gradient can be written back into a zero tensor shaped like
/// the input (see its `UnaryOps::partial` implementation).
#[derive(Debug)]
struct ADTensorOpsIndex<P, const D1: usize, const D2: usize> {
    // The `D2` ranges that were passed to `index` in the forward pass.
    indexes: [Range<usize>; D2],
    // Ties the op to the element type `P`; presumably a zero-sized marker — TODO confirm.
    _kind: ADKind<P>,
}

impl<P: Default, const D1: usize, const D2: usize> ADTensorOpsIndex<P, D1, D2> {
    /// Builds the backward op from the ranges used by the forward `index` call.
    pub fn new(indexes: [Range<usize>; D2]) -> Self {
        let kind = ADKind::new();
        Self {
            indexes,
            _kind: kind,
        }
    }
}

impl<T1, P, const D1: usize, const D2: usize> UnaryOps<T1, T1> for ADTensorOpsIndex<P, D1, D2>
where
    P: Element,
    T1: Tensor<P, D1> + TensorOpsIndex<P, D1, D2>,
{
    /// Gradient of a slice with respect to its source tensor.
    ///
    /// Elements outside the sliced ranges never reached the output, so the result
    /// starts as zeros shaped like the input and only the sliced region receives
    /// the incoming output gradient.
    fn partial(&self, state: &UnaryOpsNodeState<T1, T1>) -> T1 {
        let zeros = state.input.value().zeros();
        let output_grad = state.output.grad();
        zeros.index_assign(self.indexes.clone(), &output_grad)
    }
}

/// Backward operation recorded by [`TensorOpsIndex::index_assign`] on an [`ADTensor`].
///
/// Keeps the ranges that were overwritten so the output gradient can be split
/// between the tensor assigned into (left operand: gradient zeroed inside the
/// ranges) and the assigned values (right operand: gradient sliced from the
/// ranges); see its `BinaryOps` implementation.
#[derive(Debug)]
struct ADTensorOpsIndexAssign<P, const D1: usize, const D2: usize> {
    // The `D2` ranges that were passed to `index_assign` in the forward pass.
    indexes: [Range<usize>; D2],
    // Ties the op to the element type `P`; presumably a zero-sized marker — TODO confirm.
    _kind: ADKind<P>,
}

impl<P: Default, const D1: usize, const D2: usize> ADTensorOpsIndexAssign<P, D1, D2> {
    /// Builds the backward op from the ranges used by the forward `index_assign` call.
    pub fn new(indexes: [Range<usize>; D2]) -> Self {
        let kind = ADKind::new();
        Self {
            indexes,
            _kind: kind,
        }
    }
}

impl<T, P, const D1: usize, const D2: usize> BinaryOps<T, T, T>
    for ADTensorOpsIndexAssign<P, D1, D2>
where
    P: Element,
    T: Tensor<P, D1> + TensorOpsIndex<P, D1, D2>,
{
    /// Gradient with respect to the tensor that was assigned into.
    ///
    /// The overwritten region did not survive into the output, so that region of
    /// the output gradient is replaced by zeros shaped like the assigned values.
    fn partial_left(&self, state: &BinaryOpsNodeState<T, T, T>) -> T {
        let output_grad = state.output.grad();
        let cleared = state.right.value().zeros();
        output_grad.index_assign(self.indexes.clone(), &cleared)
    }

    /// Gradient with respect to the assigned values: the part of the output
    /// gradient that covers the assigned ranges.
    fn partial_right(&self, state: &BinaryOpsNodeState<T, T, T>) -> T {
        let ranges = self.indexes.clone();
        state.output.grad().index(ranges)
    }
}

impl<P, const D1: usize, const D2: usize, T> TensorOpsIndex<P, D1, D2> for ADTensor<P, D1, T>
where
    P: Element,
    T: Tensor<P, D1> + TensorOpsIndex<P, D1, D2>,
{
    /// Slices the wrapped tensor and records the slice in the autodiff graph.
    ///
    /// The new node is a unary child of `self.node`; its backward step is an
    /// [`ADTensorOpsIndex`] holding a copy of `indexes`.
    fn index(&self, indexes: [Range<usize>; D2]) -> Self {
        let inner = self.tensor();
        // Fully qualified call: dispatch to the wrapped tensor's implementation.
        let sliced = TensorOpsIndex::index(&inner, indexes.clone());
        let shape = sliced.shape().clone();
        let forward_state = ForwardNodeState::new(sliced);

        let backward_op = Arc::new(ADTensorOpsIndex::<P, D1, D2>::new(indexes));
        let recorded = Arc::new(ForwardUnaryRecordedOps::new(
            self.node.clone(),
            backward_op,
        ));
        let node = Arc::new(ForwardNode::from_unary(&self.node, forward_state, recorded));

        Self {
            node,
            shape,
            kind: self.kind.clone(),
        }
    }

    /// Overwrites the `indexes` region of the wrapped tensor with `values` and
    /// records the operation in the autodiff graph.
    ///
    /// The new node is a binary child of `self.node` (left) and `values.node`
    /// (right); its backward step is an [`ADTensorOpsIndexAssign`].
    fn index_assign(&self, indexes: [Range<usize>; D2], values: &Self) -> Self {
        let inner = self.tensor();
        // Fully qualified call: dispatch to the wrapped tensor's implementation.
        let assigned = TensorOpsIndex::index_assign(&inner, indexes.clone(), &values.tensor());
        let shape = assigned.shape().clone();
        let forward_state = ForwardNodeState::new(assigned);

        let backward_op = Arc::new(ADTensorOpsIndexAssign::<P, D1, D2>::new(indexes));
        let recorded = Arc::new(ForwardBinaryRecordedOps::new(
            self.node.clone(),
            values.node.clone(),
            backward_op,
        ));
        let node = Arc::new(ForwardNode::from_binary(
            &self.node,
            &values.node,
            forward_state,
            recorded,
        ));

        Self {
            node,
            shape,
            kind: self.kind.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};

    #[test]
    fn should_diff_matmul_with_index() {
        let lhs_data: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
        let wide_data: Data<f64, 2> = Data::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]);

        let lhs = TestADTensor::from_data(lhs_data);
        let wide = TestADTensor::from_data(wide_data);

        // Only the leading 2x2 block of `wide` takes part in the product, so its
        // third column must receive a zero gradient.
        let rhs = wide.index([0..2, 0..2]);
        let product = &lhs.matmul(&rhs);
        let grads = product.backward();

        let grad_lhs = grads.wrt(&lhs).unwrap();
        let grad_wide = grads.wrt(&wide).unwrap();

        assert_eq!(grad_lhs.to_data(), Data::from([[11.0, 5.0], [11.0, 5.0]]));
        assert_eq!(
            grad_wide.to_data(),
            Data::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]])
        );
    }

    #[test]
    fn should_diff_matmul_with_index_assign() {
        let lhs_data: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
        let rhs_data: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
        let patch_data: Data<f64, 2> = Data::from([[9.0]]);

        let lhs = TestADTensor::from_data(lhs_data);
        let rhs = TestADTensor::from_data(rhs_data);
        let patch = TestADTensor::from_data(patch_data);

        // The top-left cell of the product is overwritten by a constant, so no
        // gradient may flow back through that cell into `lhs` or `rhs`.
        let product = lhs.matmul(&rhs);
        let patched = product.index_assign([0..1, 0..1], &patch);
        let result = &patched.matmul(&lhs);

        let grads = result.backward();

        let grad_lhs = grads.wrt(&lhs).unwrap();
        let grad_rhs = grads.wrt(&rhs).unwrap();

        assert_eq!(grad_lhs.to_data(), Data::from([[58.0, 38.0], [118.0, 82.0]]));
        assert_eq!(grad_rhs.to_data(), Data::from([[16.0, 15.0], [24.0, 50.0]]));
    }

    #[test]
    fn should_diff_matmul_with_index_assign_complex() {
        let lhs_data: Data<f64, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
        let rhs_data: Data<f64, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
        let scale_data: Data<f64, 2> = Data::from([[9.0]]);

        let lhs = TestADTensor::from_data(lhs_data);
        let rhs = TestADTensor::from_data(rhs_data);
        let scale = TestADTensor::from_data(scale_data);

        // Here the assigned value is itself differentiable (a slice of `rhs`
        // times `scale`), so gradient reaches `rhs` through both paths.
        let product = lhs.matmul(&rhs);
        let corner = rhs.index([0..1, 0..1]);
        let scaled_corner = corner.mul(&scale);
        let patched = product.index_assign([0..1, 0..1], &scaled_corner);
        let result = &patched.matmul(&lhs);

        let grads = result.backward();

        let grad_lhs = grads.wrt(&lhs).unwrap();
        let grad_rhs = grads.wrt(&rhs).unwrap();
        let grad_scale = grads.wrt(&scale).unwrap();

        assert_eq!(grad_scale.to_data(), Data::from([[32.0]]));
        assert_eq!(grad_lhs.to_data(), Data::from([[85.0, 65.0], [118.0, 82.0]]));
        assert_eq!(grad_rhs.to_data(), Data::from([[88.0, 15.0], [24.0, 50.0]]));
    }
}