burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(cat)]
mod tests {
    // Tests for `Tensor::cat`: concatenation of float, int and bool tensors along a
    // chosen dimension, plus invalid inputs that are expected to panic.
    use super::*;
    use alloc::vec::Vec;
    use burn_tensor::{Bool, Int, Tensor, TensorData};

    /// Two `[1, 3]` float tensors stacked along dim 0 give a `[2, 3]` tensor.
    #[test]
    fn should_support_cat_ops_2d_dim0() {
        let device = Default::default();
        let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);
        let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);

        let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);
        let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        output.into_data().assert_approx_eq(&expected, 3);
    }

    /// Integer tensors concatenate along dim 0 with exact values.
    #[test]
    fn should_support_cat_ops_int() {
        let device = Default::default();
        let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device);
        let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device);

        let output = Tensor::cat(vec![tensor_1, tensor_2], 0);

        output
            .into_data()
            .assert_eq(&TensorData::from([[1, 2, 3], [4, 5, 6]]), false);
    }

    /// Boolean tensors concatenate along dim 0 with exact values.
    #[test]
    fn should_support_cat_ops_bool() {
        let device = Default::default();
        let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device);
        let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device);

        let output = Tensor::cat(vec![tensor_1, tensor_2], 0);

        output.into_data().assert_eq(
            &TensorData::from([[false, true, true], [true, true, false]]),
            false,
        );
    }

    /// Two `[1, 3]` float tensors joined along dim 1 give a `[1, 6]` tensor.
    #[test]
    fn should_support_cat_ops_2d_dim1() {
        let device = Default::default();
        let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);
        let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);

        let output = TestTensor::cat(vec![tensor_1, tensor_2], 1);
        let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]);

        output.into_data().assert_approx_eq(&expected, 3);
    }

    /// Operands may differ in size along the concatenation axis:
    /// `[2, 1, 3]` + `[1, 1, 3]` along dim 0 gives `[3, 1, 3]`.
    #[test]
    fn should_support_cat_ops_3d() {
        let device = Default::default();
        let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);
        let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device);

        let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);
        let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);

        output.into_data().assert_approx_eq(&expected, 3);
    }

    /// Operands whose sizes differ on a non-concatenation axis (dim 1: 3 vs 2) must panic.
    #[test]
    #[should_panic]
    fn should_panic_when_dimensions_are_not_the_same() {
        let device = Default::default();
        let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], &device);
        let tensor_2 = TestTensor::from_data([[4.0, 5.0]], &device);

        // `.into_data()` reads the result back; presumably this forces evaluation on
        // lazy/fused backends so the panic is observed inside the test.
        TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data();
    }

    /// Concatenating an empty list of tensors must panic.
    #[test]
    #[should_panic]
    fn should_panic_when_list_of_vectors_is_empty() {
        let tensor: Vec<TestTensor<2>> = vec![];
        TestTensor::cat(tensor, 0).into_data();
    }

    /// Concatenating along a dimension >= the tensor rank must panic.
    #[test]
    #[should_panic]
    fn should_panic_when_cat_exceeds_dimension() {
        let device = Default::default();
        // Both operands are `[2, 1, 3]`. With identical shapes, the only possible panic
        // cause is the out-of-range `dim` (3 on a rank-3 tensor), not a shape mismatch.
        let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);
        let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]], [[4.1, 5.1, 6.1]]], &device);

        TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data();
    }
}