burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(repeat)]
mod tests {
    use super::*;
    // Only `TensorData` is used directly; the tensor types come from the
    // `TestTensor*` aliases provided by `super::*`.
    use burn_tensor::TensorData;

    /// Repeating a `[1, 3]` float tensor 4 times along dim 0 stacks 4 copies of
    /// the row; a factor of 1 leaves dim 1 unchanged.
    #[test]
    fn should_support_repeat_ops_one_dimension() {
        let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]);
        let tensor = TestTensor::<2>::from_data(data, &Default::default());

        // One repeat factor per dimension of the rank-2 tensor.
        let output = tensor.repeat(&[4, 1]);
        let expected = TensorData::from([
            [0.0f32, 1.0f32, 2.0f32],
            [0.0f32, 1.0f32, 2.0f32],
            [0.0f32, 1.0f32, 2.0f32],
            [0.0f32, 1.0f32, 2.0f32],
        ]);

        output.into_data().assert_eq(&expected, false);
    }

    /// Same as the float case, for a bool tensor.
    #[test]
    fn should_support_bool_repeat_ops_one_dimension() {
        let data = TensorData::from([[true, false, false]]);
        let tensor = TestTensorBool::<2>::from_data(data, &Default::default());

        // One repeat factor per dimension of the rank-2 tensor.
        let output = tensor.repeat(&[4, 1]);
        let expected = TensorData::from([
            [true, false, false],
            [true, false, false],
            [true, false, false],
            [true, false, false],
        ]);
        output.into_data().assert_eq(&expected, true);
    }

    /// Same as the float case, for an int tensor.
    #[test]
    fn should_support_int_repeat_ops_one_dimension() {
        let data = TensorData::from([[0i32, 1i32, 2i32]]);
        let tensor = TestTensorInt::<2>::from_data(data, &Default::default());

        // One repeat factor per dimension of the rank-2 tensor.
        let output = tensor.repeat(&[4, 1]);
        let expected = TensorData::from([
            [0i32, 1i32, 2i32],
            [0i32, 1i32, 2i32],
            [0i32, 1i32, 2i32],
            [0i32, 1i32, 2i32],
        ]);

        output.into_data().assert_eq(&expected, false);
    }

    /// Repeating a `[4, 2, 2]` float tensor by `[2, 3, 2]` yields `[8, 6, 4]`.
    /// Every dimension is tiled as a whole block (e.g. `[1, 2]` becomes
    /// `[1, 2, 1, 2]`), not element-wise interleaved.
    #[test]
    fn should_support_float_repeat_repeating_on_many_dimensions() {
        let data = TensorData::from([
            [[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
            [[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
            [[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
            [[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
        ]);
        let tensor = TestTensor::<3>::from_data(data, &Default::default());

        let output = tensor.repeat(&[2, 3, 2]);
        let expected = TensorData::from([
            [
                [1.0f32, 2.0f32, 1.0f32, 2.0f32],
                [3.0f32, 4.0f32, 3.0f32, 4.0f32],
                [1.0f32, 2.0f32, 1.0f32, 2.0f32],
                [3.0f32, 4.0f32, 3.0f32, 4.0f32],
                [1.0f32, 2.0f32, 1.0f32, 2.0f32],
                [3.0f32, 4.0f32, 3.0f32, 4.0f32],
            ],
            [
                [5.0f32, 6.0f32, 5.0f32, 6.0f32],
                [7.0f32, 8.0f32, 7.0f32, 8.0f32],
                [5.0f32, 6.0f32, 5.0f32, 6.0f32],
                [7.0f32, 8.0f32, 7.0f32, 8.0f32],
                [5.0f32, 6.0f32, 5.0f32, 6.0f32],
                [7.0f32, 8.0f32, 7.0f32, 8.0f32],
            ],
            [
                [9.0f32, 10.0f32, 9.0f32, 10.0f32],
                [11.0f32, 12.0f32, 11.0f32, 12.0f32],
                [9.0f32, 10.0f32, 9.0f32, 10.0f32],
                [11.0f32, 12.0f32, 11.0f32, 12.0f32],
                [9.0f32, 10.0f32, 9.0f32, 10.0f32],
                [11.0f32, 12.0f32, 11.0f32, 12.0f32],
            ],
            [
                [13.0f32, 14.0f32, 13.0f32, 14.0f32],
                [15.0f32, 16.0f32, 15.0f32, 16.0f32],
                [13.0f32, 14.0f32, 13.0f32, 14.0f32],
                [15.0f32, 16.0f32, 15.0f32, 16.0f32],
                [13.0f32, 14.0f32, 13.0f32, 14.0f32],
                [15.0f32, 16.0f32, 15.0f32, 16.0f32],
            ],
            [
                [1.0f32, 2.0f32, 1.0f32, 2.0f32],
                [3.0f32, 4.0f32, 3.0f32, 4.0f32],
                [1.0f32, 2.0f32, 1.0f32, 2.0f32],
                [3.0f32, 4.0f32, 3.0f32, 4.0f32],
                [1.0f32, 2.0f32, 1.0f32, 2.0f32],
                [3.0f32, 4.0f32, 3.0f32, 4.0f32],
            ],
            [
                [5.0f32, 6.0f32, 5.0f32, 6.0f32],
                [7.0f32, 8.0f32, 7.0f32, 8.0f32],
                [5.0f32, 6.0f32, 5.0f32, 6.0f32],
                [7.0f32, 8.0f32, 7.0f32, 8.0f32],
                [5.0f32, 6.0f32, 5.0f32, 6.0f32],
                [7.0f32, 8.0f32, 7.0f32, 8.0f32],
            ],
            [
                [9.0f32, 10.0f32, 9.0f32, 10.0f32],
                [11.0f32, 12.0f32, 11.0f32, 12.0f32],
                [9.0f32, 10.0f32, 9.0f32, 10.0f32],
                [11.0f32, 12.0f32, 11.0f32, 12.0f32],
                [9.0f32, 10.0f32, 9.0f32, 10.0f32],
                [11.0f32, 12.0f32, 11.0f32, 12.0f32],
            ],
            [
                [13.0f32, 14.0f32, 13.0f32, 14.0f32],
                [15.0f32, 16.0f32, 15.0f32, 16.0f32],
                [13.0f32, 14.0f32, 13.0f32, 14.0f32],
                [15.0f32, 16.0f32, 15.0f32, 16.0f32],
                [13.0f32, 14.0f32, 13.0f32, 14.0f32],
                [15.0f32, 16.0f32, 15.0f32, 16.0f32],
            ],
        ]);

        output.into_data().assert_eq(&expected, false);
    }

    /// Int counterpart of the many-dimension float test: `[4, 2, 2]` repeated
    /// by `[2, 3, 2]` yields `[8, 6, 4]`.
    #[test]
    fn should_support_int_repeat_on_many_dims() {
        let data = TensorData::from([
            [[1i32, 2i32], [3i32, 4i32]],
            [[5i32, 6i32], [7i32, 8i32]],
            [[9i32, 10i32], [11i32, 12i32]],
            [[13i32, 14i32], [15i32, 16i32]],
        ]);
        let tensor = TestTensorInt::<3>::from_data(data, &Default::default());

        let output = tensor.repeat(&[2, 3, 2]);

        let expected = TensorData::from([
            [
                [1i32, 2i32, 1i32, 2i32],
                [3i32, 4i32, 3i32, 4i32],
                [1i32, 2i32, 1i32, 2i32],
                [3i32, 4i32, 3i32, 4i32],
                [1i32, 2i32, 1i32, 2i32],
                [3i32, 4i32, 3i32, 4i32],
            ],
            [
                [5i32, 6i32, 5i32, 6i32],
                [7i32, 8i32, 7i32, 8i32],
                [5i32, 6i32, 5i32, 6i32],
                [7i32, 8i32, 7i32, 8i32],
                [5i32, 6i32, 5i32, 6i32],
                [7i32, 8i32, 7i32, 8i32],
            ],
            [
                [9i32, 10i32, 9i32, 10i32],
                [11i32, 12i32, 11i32, 12i32],
                [9i32, 10i32, 9i32, 10i32],
                [11i32, 12i32, 11i32, 12i32],
                [9i32, 10i32, 9i32, 10i32],
                [11i32, 12i32, 11i32, 12i32],
            ],
            [
                [13i32, 14i32, 13i32, 14i32],
                [15i32, 16i32, 15i32, 16i32],
                [13i32, 14i32, 13i32, 14i32],
                [15i32, 16i32, 15i32, 16i32],
                [13i32, 14i32, 13i32, 14i32],
                [15i32, 16i32, 15i32, 16i32],
            ],
            [
                [1i32, 2i32, 1i32, 2i32],
                [3i32, 4i32, 3i32, 4i32],
                [1i32, 2i32, 1i32, 2i32],
                [3i32, 4i32, 3i32, 4i32],
                [1i32, 2i32, 1i32, 2i32],
                [3i32, 4i32, 3i32, 4i32],
            ],
            [
                [5i32, 6i32, 5i32, 6i32],
                [7i32, 8i32, 7i32, 8i32],
                [5i32, 6i32, 5i32, 6i32],
                [7i32, 8i32, 7i32, 8i32],
                [5i32, 6i32, 5i32, 6i32],
                [7i32, 8i32, 7i32, 8i32],
            ],
            [
                [9i32, 10i32, 9i32, 10i32],
                [11i32, 12i32, 11i32, 12i32],
                [9i32, 10i32, 9i32, 10i32],
                [11i32, 12i32, 11i32, 12i32],
                [9i32, 10i32, 9i32, 10i32],
                [11i32, 12i32, 11i32, 12i32],
            ],
            [
                [13i32, 14i32, 13i32, 14i32],
                [15i32, 16i32, 15i32, 16i32],
                [13i32, 14i32, 13i32, 14i32],
                [15i32, 16i32, 15i32, 16i32],
                [13i32, 14i32, 13i32, 14i32],
                [15i32, 16i32, 15i32, 16i32],
            ],
        ]);

        output.into_data().assert_eq(&expected, false);
    }

    /// Bool counterpart of the many-dimension tests: `[2, 2, 2]` repeated by
    /// `[2, 3, 2]` yields `[4, 6, 4]`.
    #[test]
    fn should_support_bool_repeat_on_many_dimension() {
        let data = TensorData::from([
            [[false, true], [true, false]],
            [[true, true], [false, false]],
        ]);
        let tensor = TestTensorBool::<3>::from_data(data, &Default::default());

        let output = tensor.repeat(&[2, 3, 2]);
        let expected = TensorData::from([
            [
                [false, true, false, true],
                [true, false, true, false],
                [false, true, false, true],
                [true, false, true, false],
                [false, true, false, true],
                [true, false, true, false],
            ],
            [
                [true, true, true, true],
                [false, false, false, false],
                [true, true, true, true],
                [false, false, false, false],
                [true, true, true, true],
                [false, false, false, false],
            ],
            [
                [false, true, false, true],
                [true, false, true, false],
                [false, true, false, true],
                [true, false, true, false],
                [false, true, false, true],
                [true, false, true, false],
            ],
            [
                [true, true, true, true],
                [false, false, false, false],
                [true, true, true, true],
                [false, false, false, false],
                [true, true, true, true],
                [false, false, false, false],
            ],
        ]);

        output.into_data().assert_eq(&expected, true);
    }
}