burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(module_unfold4d)]
mod tests {
    use super::*;
    use burn_tensor::module::unfold4d;
    use burn_tensor::ops::UnfoldOptions;
    use burn_tensor::{Shape, Tensor};

    /// Checks only the dimensions of the `unfold4d` output for a `[2, 5, 3, 4]` input
    /// and a `[2, 3]` kernel with unit stride and dilation and no padding.
    ///
    /// The expected middle dimension, 30, equals `channels_in * 2 * 3` for this case.
    #[test]
    fn test_unfold4d_shape() {
        Unfold4dTestCase {
            height: 3,
            width: 4,
            batch_size: 2,
            channels_in: 5,
            kernel_size: [2, 3],
            stride: [1, 1],
            padding: [0, 0],
            dilation: [1, 1],
        }
        .assert_shape([2, 30, 4]);
    }

    /// Compares the values produced by `unfold4d` with a hand-written table for a
    /// `[1, 2, 4, 4]` input and a `[2, 2]` kernel (unit stride and dilation, no padding).
    ///
    /// `assert_output` fills the input with `0, 1, 2, ...`, so the second channel
    /// holds the values 16 through 31.
    #[test]
    fn test_unfold4d_simple() {
        // Eight rows (four per input channel), nine columns each.
        let expected: TestTensor<3> = TestTensor::from([[
            [0., 1., 2., 4., 5., 6., 8., 9., 10.],
            [1., 2., 3., 5., 6., 7., 9., 10., 11.],
            [4., 5., 6., 8., 9., 10., 12., 13., 14.],
            [5., 6., 7., 9., 10., 11., 13., 14., 15.],
            [16., 17., 18., 20., 21., 22., 24., 25., 26.],
            [17., 18., 19., 21., 22., 23., 25., 26., 27.],
            [20., 21., 22., 24., 25., 26., 28., 29., 30.],
            [21., 22., 23., 25., 26., 27., 29., 30., 31.],
        ]]);

        let case = Unfold4dTestCase {
            height: 4,
            width: 4,
            batch_size: 1,
            channels_in: 2,
            kernel_size: [2, 2],
            stride: [1, 1],
            padding: [0, 0],
            dilation: [1, 1],
        };

        case.assert_output(expected);
    }

    /// Compares the values produced by `unfold4d` with a hand-written table when
    /// padding, stride and dilation are non-trivial in their second component.
    ///
    /// The input is `[1, 2, 3, 4]`, filled with `0, 1, 2, ...` by `assert_output`.
    #[test]
    fn test_unfold4d_complex() {
        // The all-zero rows are presumably kernel positions that read only from the
        // padded border -- NOTE(review): confirm against the `unfold4d` definition.
        let expected: TestTensor<3> = TestTensor::from([[
            [0., 0.],
            [1., 5.],
            [3., 7.],
            [0., 0.],
            [5., 9.],
            [7., 11.],
            [0., 0.],
            [13., 17.],
            [15., 19.],
            [0., 0.],
            [17., 21.],
            [19., 23.],
        ]]);

        let case = Unfold4dTestCase {
            height: 3,
            width: 4,
            batch_size: 1,
            channels_in: 2,
            kernel_size: [2, 3],
            stride: [1, 2],
            padding: [0, 1],
            dilation: [1, 2],
        };

        case.assert_output(expected);
    }

    /// Parameters describing one `unfold4d` test scenario.
    ///
    /// The first four fields give the input tensor shape
    /// `[batch_size, channels_in, height, width]`; the remaining fields are handed
    /// to `unfold4d` and `UnfoldOptions::new`.
    struct Unfold4dTestCase {
        // --- input tensor shape ---
        /// First dimension of the input tensor.
        batch_size: usize,
        /// Second dimension of the input tensor.
        channels_in: usize,
        /// Third dimension of the input tensor.
        height: usize,
        /// Fourth dimension of the input tensor.
        width: usize,

        // --- operator parameters ---
        /// Kernel size passed directly to `unfold4d`.
        kernel_size: [usize; 2],
        /// Stride passed as the first argument of `UnfoldOptions::new`.
        stride: [usize; 2],
        /// Padding passed as the second argument of `UnfoldOptions::new`.
        padding: [usize; 2],
        /// Dilation passed as the third argument of `UnfoldOptions::new`.
        dilation: [usize; 2],
    }

    impl Unfold4dTestCase {
        /// Builds the input tensor for this case and runs `unfold4d` on it.
        ///
        /// The input has shape `[batch_size, channels_in, height, width]` and is
        /// filled with the sequence `0, 1, 2, ...` (an integer `arange` reshaped to
        /// four dimensions), so expected outputs can be written down by hand.
        ///
        /// Both assertions below share this helper so they always exercise exactly
        /// the same input and options.
        fn unfold(&self) -> TestTensor<3> {
            let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
            // Test inputs are tiny, so the element count comfortably fits in an `i64`.
            let num_elements = shape_x.num_elements() as i64;

            // The integer data is handed to `TestTensor::from` as-is; no explicit
            // element conversion is applied before building the float tensor.
            let x = TestTensor::from(
                TestTensorInt::arange(0..num_elements, &Default::default())
                    .reshape::<4, _>(shape_x)
                    .into_data(),
            );

            unfold4d(
                x,
                self.kernel_size,
                UnfoldOptions::new(self.stride, self.padding, self.dilation),
            )
        }

        /// Runs `unfold4d` for this case and asserts that the output dimensions
        /// equal `expected_shape`.
        ///
        /// # Panics
        ///
        /// Panics (failing the test) when the dimensions differ.
        fn assert_shape(self, expected_shape: [usize; 3]) {
            let output = self.unfold();

            assert_eq!(
                output.shape().dims,
                expected_shape,
                "Expected shape doesn't match the actual shape"
            );
        }

        /// Runs `unfold4d` for this case and asserts that the output values
        /// approximately equal `expected`.
        ///
        /// The comparison is delegated to `assert_approx_eq` with a precision
        /// argument of `3`.
        ///
        /// # Panics
        ///
        /// Panics (failing the test) when `assert_approx_eq` rejects the output.
        fn assert_output(self, expected: TestTensor<3>) {
            self.unfold()
                .into_data()
                .assert_approx_eq(&expected.into_data(), 3);
        }
    }
}