burn-tensor 0.1.0

This library provides multiple tensor implementations hidden behind an easy-to-use API that supports reverse-mode automatic differentiation.
use crate::tensor::{
    backend::tch::{TchShape, TchTensor},
    ops::*,
    Data, Distribution, Shape,
};
use rand::distributions::{uniform::SampleUniform, Standard};

/// Builds new `TchTensor`s of the same rank `D`, using an existing tensor as the template.
impl<P, const D: usize> TensorCreationLike<P, D> for TchTensor<P, D>
where
    P: tch::kind::Element + std::fmt::Debug + SampleUniform + Default,
    Standard: rand::distributions::Distribution<P>,
{
    /// Returns the tensor produced by libtorch's `empty_like` on `self`.
    /// The shape and kind metadata are copied from `self`; element values are not initialized.
    fn new_like_empty(&self) -> Self {
        Self {
            tensor: self.tensor.empty_like(),
            shape: self.shape.clone(),
            kind: self.kind.clone(),
        }
    }

    /// Samples `Data` with `self`'s shape from `distribution`.
    /// The result is materialized on the same device as `self`.
    fn new_like_random(&self, distribution: Distribution<P>) -> Self {
        let target_device = self.tensor.device();
        let sampled = Data::<P, D>::random(self.shape.clone(), distribution);
        Self::from_data(sampled, target_device)
    }

    /// Converts `data` into a tensor on the same device as `self`.
    /// `self` contributes only its device here.
    fn new_like_data(&self, data: Data<P, D>) -> Self {
        Self::from_data(data, self.tensor.device())
    }

    /// Zero-filled tensor shaped like `self`; forwards to the `Zeros` implementation.
    fn new_like_zeros(&self) -> Self {
        self.zeros()
    }

    /// One-filled tensor shaped like `self`; forwards to the `Ones` implementation.
    fn new_like_ones(&self) -> Self {
        self.ones()
    }
}

/// Builds `TchTensor`s of a possibly different rank `D2`, placed on the same device as an
/// existing rank-`D` tensor.
impl<P, const D: usize, const D2: usize> TensorCreationFork<P, D, D2, TchTensor<P, D2>>
    for TchTensor<P, D>
where
    P: tch::kind::Element + std::fmt::Debug + SampleUniform + Default + Copy,
    Standard: rand::distributions::Distribution<P>,
{
    /// Allocates a tensor of the requested `shape` with libtorch's `empty`, using `self`'s
    /// kind and device. Element values are not initialized.
    fn new_fork_empty(&self, shape: Shape<D2>) -> TchTensor<P, D2> {
        let kind = self.kind.clone();
        let options = (kind.kind(), self.tensor.device());
        let torch_shape = TchShape::from(shape.clone());

        TchTensor {
            tensor: tch::Tensor::empty(&torch_shape.dims, options),
            kind,
            shape,
        }
    }

    /// Samples `Data` of the requested `shape` from `distribution`.
    /// The result is materialized on `self`'s device.
    fn new_fork_random(
        &self,
        shape: Shape<D2>,
        distribution: Distribution<P>,
    ) -> TchTensor<P, D2> {
        let target_device = self.tensor.device();
        let sampled = Data::<P, D2>::random(shape, distribution);
        TchTensor::from_data(sampled, target_device)
    }

    /// Converts `data` into a tensor on `self`'s device; `self` contributes only its device.
    fn new_fork_data(&self, data: Data<P, D2>) -> TchTensor<P, D2> {
        TchTensor::from_data(data, self.tensor.device())
    }

    /// Allocates a zero-filled tensor of the requested `shape`, using `self`'s kind and device.
    fn new_fork_zeros(&self, shape: Shape<D2>) -> TchTensor<P, D2> {
        let kind = self.kind.clone();
        let options = (kind.kind(), self.tensor.device());
        let torch_shape = TchShape::from(shape.clone());

        TchTensor {
            tensor: tch::Tensor::zeros(&torch_shape.dims, options),
            kind,
            shape,
        }
    }

    /// Allocates a one-filled tensor of the requested `shape`, using `self`'s kind and device.
    fn new_fork_ones(&self, shape: Shape<D2>) -> TchTensor<P, D2> {
        let kind = self.kind.clone();
        let options = (kind.kind(), self.tensor.device());
        let torch_shape = TchShape::from(shape.clone());

        TchTensor {
            tensor: tch::Tensor::ones(&torch_shape.dims, options),
            kind,
            shape,
        }
    }
}

impl<P, const D: usize> Zeros<TchTensor<P, D>> for TchTensor<P, D>
where
    P: tch::kind::Element,
{
    fn zeros(&self) -> TchTensor<P, D> {
        let tensor = self.tensor.zeros_like();
        let shape = self.shape.clone();
        let kind = self.kind.clone();

        Self {
            kind,
            tensor,
            shape,
        }
    }
}

impl<P, const D: usize> Ones<TchTensor<P, D>> for TchTensor<P, D>
where
    P: tch::kind::Element,
{
    fn ones(&self) -> TchTensor<P, D> {
        let tensor = self.tensor.ones_like();
        let shape = self.shape.clone();
        let kind = self.kind.clone();

        Self {
            kind,
            tensor,
            shape,
        }
    }
}