// burn-tensor 0.16.1
//
// Tensor library with user-friendly APIs and automatic differentiation support.
// Documentation
#[burn_tensor_testgen::testgen(module_conv_transpose2d)]
mod tests {
    // Numerical tests for `burn_tensor::module::conv_transpose2d`.
    //
    // Every case fills its input, weights and bias with ramps (0, 1, 2, ...) in row-major
    // order, runs the transposed convolution with the case's hyper-parameters and compares
    // the result against a hard-coded expected tensor.
    use super::*;
    use burn_tensor::module::conv_transpose2d;
    use burn_tensor::ops::ConvTransposeOptions;
    use burn_tensor::{Shape, Tensor};

    #[test]
    fn test_conv_transpose2d_simple_1() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 1,
            channels_out: 1,
            kernel_size: [3, 3],
            padding: [1, 1],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            height: 2,
            width: 2,
        };

        case.assert_output(TestTensor::from([[[[5.0, 11.0], [23.0, 29.0]]]]));
    }

    #[test]
    fn test_conv_transpose2d_simple_2() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 3,
            channels_out: 3,
            kernel_size: [3, 3],
            padding: [1, 1],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            height: 4,
            width: 4,
        };

        case.assert_output(TestTensor::from([[
            [
                [9855., 15207., 15738., 10797.],
                [16290., 25119., 25956., 17793.],
                [18486., 28467., 29304., 20061.],
                [13593., 20913., 21498., 14703.],
            ],
            [
                [11854., 18286., 18979., 13012.],
                [19612., 30223., 31303., 21439.],
                [22456., 34543., 35623., 24355.],
                [16456., 25288., 26035., 17782.],
            ],
            [
                [13853., 21365., 22220., 15227.],
                [22934., 35327., 36650., 25085.],
                [26426., 40619., 41942., 28649.],
                [19319., 29663., 30572., 20861.],
            ],
        ]]));
    }

    #[test]
    fn test_conv_transpose2d_simple_3() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 1,
            channels_out: 1,
            kernel_size: [2, 2],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            height: 2,
            width: 2,
        };

        case.assert_output(TestTensor::from([[[
            [0.0, 0.0, 1.0],
            [0.0, 4.0, 6.0],
            [4.0, 12.0, 9.0],
        ]]]));
    }

    #[test]
    fn test_conv_transpose2d_stride_2() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 1,
            channels_out: 1,
            kernel_size: [2, 2],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [2, 2],
            dilation: [1, 1],
            groups: 1,
            height: 2,
            width: 2,
        };

        case.assert_output(TestTensor::from([[[
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 2.0, 3.0],
            [0.0, 2.0, 0.0, 3.0],
            [4.0, 6.0, 6.0, 9.0],
        ]]]));
    }

    #[test]
    fn test_conv_transpose2d_dilation_2() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size: [3, 3],
            padding: [1, 1],
            padding_out: [1, 1],
            stride: [1, 1],
            dilation: [2, 2],
            groups: 1,
            height: 2,
            width: 2,
        };

        case.assert_output(TestTensor::from([[
            [
                [126., 116., 136., 124., 146.],
                [108., 88., 114., 92., 120.],
                [156., 140., 166., 148., 176.],
                [126., 100., 132., 104., 138.],
                [186., 164., 196., 172., 206.],
            ],
            [
                [217., 189., 227., 197., 237.],
                [163., 125., 169., 129., 175.],
                [247., 213., 257., 221., 267.],
                [181., 137., 187., 141., 193.],
                [277., 237., 287., 245., 297.],
            ],
        ]]));
    }

    #[test]
    fn test_conv_transpose2d_stride2_out_padding() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size: [3, 3],
            padding: [1, 1],
            padding_out: [1, 1],
            stride: [2, 2],
            dilation: [1, 1],
            groups: 1,
            height: 4,
            width: 4,
        };

        case.assert_output(TestTensor::from([[
            [
                [352., 728., 378., 780., 404., 832., 430., 452.],
                [784., 1616., 836., 1720., 888., 1824., 940., 992.],
                [456., 936., 482., 988., 508., 1040., 534., 564.],
                [992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.],
                [560., 1144., 586., 1196., 612., 1248., 638., 676.],
                [1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.],
                [664., 1352., 690., 1404., 716., 1456., 742., 788.],
                [784., 1598., 816., 1662., 848., 1726., 880., 926.],
            ],
            [
                [497., 1035., 541., 1123., 585., 1211., 629., 651.],
                [1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.],
                [673., 1387., 717., 1475., 761., 1563., 805., 835.],
                [1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.],
                [849., 1739., 893., 1827., 937., 1915., 981., 1019.],
                [1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.],
                [1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.],
                [1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.],
            ],
        ]]));
    }

    #[test]
    fn test_conv_transpose2d_groups_2() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size: [3, 3],
            padding: [1, 1],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 2,
            height: 2,
            width: 2,
        };

        case.assert_output(TestTensor::from([[
            [[5., 11.], [23., 29.]],
            [[236., 258.], [302., 324.]],
        ]]));
    }

    #[test]
    fn test_conv_transpose2d_groups_different_channels() {
        let case = ConvTranspose2dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 6,
            kernel_size: [3, 3],
            padding: [0, 0],
            padding_out: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 2,
            height: 2,
            width: 2,
        };

        case.assert_output(TestTensor::from([[
            [
                [0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00],
                [0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01],
                [6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01],
                [1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01],
            ],
            [
                [1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01],
                [1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01],
                [2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01],
                [3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01],
            ],
            [
                [2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01],
                [3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01],
                [4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01],
                [5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01],
            ],
            [
                [1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02],
                [2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02],
                [3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02],
                [2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02],
            ],
            [
                [1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02],
                [3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02],
                [4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02],
                [2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02],
            ],
            [
                [1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02],
                [4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02],
                [4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02],
                [3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02],
            ],
        ]]));
    }

    /// Hyper-parameters of one transposed-convolution test case.
    ///
    /// The `[usize; 2]` fields hold one value per spatial axis. `stride`, `padding`,
    /// `padding_out` and `dilation` are forwarded unchanged to `ConvTransposeOptions::new`;
    /// `kernel_size` becomes the last two dimensions of the weight tensor.
    struct ConvTranspose2dTestCase {
        batch_size: usize,
        channels_in: usize,
        channels_out: usize,
        kernel_size: [usize; 2],
        padding: [usize; 2],
        padding_out: [usize; 2],
        stride: [usize; 2],
        dilation: [usize; 2],
        groups: usize,
        height: usize,
        width: usize,
    }

    impl ConvTranspose2dTestCase {
        /// Runs `conv_transpose2d` for this case and compares the result with `expected`.
        ///
        /// All operands are deterministic ramps (0, 1, 2, ...) in row-major order:
        /// - input `x`: shape `[batch_size, channels_in, height, width]`;
        /// - weights: shape `[channels_in, channels_out / groups, kernel_size[0], kernel_size[1]]`;
        /// - bias: `0..channels_out`, one value per output channel.
        ///
        /// The comparison uses `TensorData::assert_approx_eq` with a precision argument of 3,
        /// which fails the test on mismatch.
        fn assert_output(self, expected: TestTensor<4>) {
            let Self {
                batch_size,
                channels_in,
                channels_out,
                kernel_size: [kernel_h, kernel_w],
                padding,
                padding_out,
                stride,
                dilation,
                groups,
                height,
                width,
            } = self;

            let device = Default::default();

            // Float tensor of the given rank-4 shape holding 0, 1, 2, ... in row-major order.
            // Values are generated as integers and converted through `TensorData`.
            let ramp = |shape: Shape| -> TestTensor<4> {
                let count = shape.num_elements() as i64;
                TestTensor::from(
                    TestTensorInt::arange(0..count, &device)
                        .reshape::<4, _>(shape)
                        .into_data(),
                )
            };

            // Transposed-convolution weights: [in_channels, out_channels / groups, kh, kw].
            let weights = ramp(Shape::new([
                channels_in,
                channels_out / groups,
                kernel_h,
                kernel_w,
            ]));
            let bias = TestTensor::from(
                TestTensorInt::arange(0..channels_out as i64, &device).into_data(),
            );
            let x = ramp(Shape::new([batch_size, channels_in, height, width]));

            let options = ConvTransposeOptions::new(stride, padding, padding_out, dilation, groups);
            let output = conv_transpose2d(x, weights, Some(bias), options);

            expected.to_data().assert_approx_eq(&output.into_data(), 3);
        }
    }
}