// burn-tensor 0.16.1
//
// Tensor library with user-friendly APIs and automatic differentiation support.
// Documentation: see the burn-tensor crate documentation.
#[burn_tensor_testgen::testgen(module_conv3d)]
mod tests {
    use super::*;
    use burn_tensor::module::conv3d;
    use burn_tensor::ops::ConvOptions;
    use burn_tensor::{Shape, Tensor};

    /// Single-group 3D convolution: 2 -> 2 channels, 3x3x3 kernel, padding 1 and
    /// unit stride/dilation on a 4x4x4 volume.
    ///
    /// With padding 1 and a 3-wide kernel, each spatial axis keeps its size
    /// (`(4 + 2*1 - 1*(3-1) - 1) / 1 + 1 = 4`), so the expected tensor has shape
    /// `[1, 2, 4, 4, 4]`.
    #[test]
    fn test_conv3d_simple() {
        let test = Conv3dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 1,
            padding_2: 1,
            padding_3: 1,
            stride_1: 1,
            stride_2: 1,
            stride_3: 1,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 1,
            depth: 4,
            height: 4,
            width: 4,
        };

        // Expected output, nested as [batch][channel_out][depth][height][width].
        test.assert_output(TestTensor::from([[
            [
                [
                    [29980.0, 44860.0, 45640.0, 30324.0],
                    [45072.0, 67380.0, 68496.0, 45468.0],
                    [48096.0, 71844.0, 72960.0, 48396.0],
                    [31780.0, 47428.0, 48136.0, 31900.0],
                ],
                [
                    [47292.0, 70548.0, 71556.0, 47400.0],
                    [70335.0, 104823.0, 106254.0, 70317.0],
                    [74223.0, 110547.0, 111978.0, 74061.0],
                    [48552.0, 72240.0, 73140.0, 48324.0],
                ],
                [
                    [58236.0, 86676.0, 87684.0, 57960.0],
                    [85887.0, 127719.0, 129150.0, 85293.0],
                    [89775.0, 133443.0, 134874.0, 89037.0],
                    [58344.0, 86640.0, 87540.0, 57732.0],
                ],
                [
                    [36148.0, 53620.0, 54184.0, 35692.0],
                    [52740.0, 78144.0, 78936.0, 51936.0],
                    [54900.0, 81312.0, 82104.0, 54000.0],
                    [35260.0, 52156.0, 52648.0, 34580.0],
                ],
            ],
            [
                [
                    [66701.0, 100589.0, 102665.0, 68773.0],
                    [102745.0, 154861.0, 157921.0, 105733.0],
                    [110953.0, 167101.0, 170161.0, 113845.0],
                    [75413.0, 113525.0, 115529.0, 77261.0],
                ],
                [
                    [112741.0, 169693.0, 172645.0, 115441.0],
                    [172396.0, 259372.0, 263719.0, 176266.0],
                    [184060.0, 276760.0, 281107.0, 187786.0],
                    [124369.0, 186937.0, 189781.0, 126733.0],
                ],
                [
                    [144421.0, 216925.0, 219877.0, 146737.0],
                    [219052.0, 328924.0, 333271.0, 222346.0],
                    [230716.0, 346312.0, 350659.0, 233866.0],
                    [154897.0, 232441.0, 235285.0, 156877.0],
                ],
                [
                    [100517.0, 150821.0, 152681.0, 101789.0],
                    [151885.0, 227833.0, 230569.0, 153673.0],
                    [159229.0, 238777.0, 241513.0, 160921.0],
                    [106541.0, 159725.0, 161513.0, 107589.0],
                ],
            ],
        ]]));
    }

    /// Grouped 3D convolution with `groups = 2` and 2 -> 2 channels.
    ///
    /// Each group holds a single input channel (`channels_in / groups = 1`), so output
    /// channel 0 is computed only from input channel 0 and output channel 1 only from
    /// input channel 1. No padding with unit stride/dilation shrinks every 5-wide axis
    /// to `5 - 3 + 1 = 3`, so the expected tensor has shape `[1, 2, 3, 3, 3]`.
    #[test]
    fn test_conv3d_groups() {
        let test = Conv3dTestCase {
            batch_size: 1,
            channels_in: 2,
            channels_out: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            kernel_size_3: 3,
            padding_1: 0,
            padding_2: 0,
            padding_3: 0,
            stride_1: 1,
            stride_2: 1,
            stride_3: 1,
            dilation_1: 1,
            dilation_2: 1,
            dilation_3: 1,
            groups: 2,
            depth: 5,
            height: 5,
            width: 5,
        };

        // Expected output, nested as [batch][channel_out][depth][height][width].
        test.assert_output(TestTensor::from([[
            [
                [
                    [15219., 15570., 15921.],
                    [16974., 17325., 17676.],
                    [18729., 19080., 19431.],
                ],
                [
                    [23994., 24345., 24696.],
                    [25749., 26100., 26451.],
                    [27504., 27855., 28206.],
                ],
                [
                    [32769., 33120., 33471.],
                    [34524., 34875., 35226.],
                    [36279., 36630., 36981.],
                ],
            ],
            [
                [
                    [172819., 173899., 174979.],
                    [178219., 179299., 180379.],
                    [183619., 184699., 185779.],
                ],
                [
                    [199819., 200899., 201979.],
                    [205219., 206299., 207379.],
                    [210619., 211699., 212779.],
                ],
                [
                    [226819., 227899., 228979.],
                    [232219., 233299., 234379.],
                    [237619., 238699., 239779.],
                ],
            ],
        ]]));
    }

    /// Exercises every convolution option at once with distinct per-axis values.
    ///
    /// The case covers:
    /// - a batch of 2 and 3 -> 4 channels;
    /// - kernel `[4, 3, 2]`, padding `[1, 2, 3]`, stride `[2, 3, 4]` and dilation
    ///   `[1, 2, 3]`;
    /// - a 4x5x6 input volume.
    ///
    /// Because each axis uses different values, swapping the depth/height/width
    /// parameters would change the result.
    ///
    /// Output size per axis is `(in + 2*pad - dil*(k-1) - 1) / stride + 1` (integer
    /// division). This gives depth 2, height 2 and width 3, so the expected tensor
    /// has shape `[2, 4, 2, 2, 3]`.
    #[test]
    fn test_conv3d_complex() {
        let test = Conv3dTestCase {
            batch_size: 2,
            channels_in: 3,
            channels_out: 4,
            kernel_size_1: 4,
            kernel_size_2: 3,
            kernel_size_3: 2,
            padding_1: 1,
            padding_2: 2,
            padding_3: 3,
            stride_1: 2,
            stride_2: 3,
            stride_3: 4,
            dilation_1: 1,
            dilation_2: 2,
            dilation_3: 3,
            groups: 1,
            depth: 4,
            height: 5,
            width: 6,
        };

        // Expected output, nested as [batch][channel_out][depth][height][width].
        test.assert_output(TestTensor::from([
            [
                [
                    [[149148., 299070., 149850.], [147636., 295758., 148050.]],
                    [[150660., 301014., 150282.], [147420., 294246., 146754.]],
                ],
                [
                    [[351325., 709903., 358507.], [357589., 722143., 364483.]],
                    [[391717., 789607., 397819.], [396253., 798391., 402067.]],
                ],
                [
                    [[553502., 1120736., 567164.], [567542., 1148528., 580916.]],
                    [[632774., 1278200., 645356.], [645086., 1302536., 657380.]],
                ],
                [
                    [[755679., 1531569., 775821.], [777495., 1574913., 797349.]],
                    [[873831., 1766793., 892893.], [893919., 1806681., 912693.]],
                ],
            ],
            [
                [
                    [[408348., 810990., 402570.], [393876., 781758., 387810.]],
                    [[370980., 735174., 364122.], [354780., 702486., 347634.]],
                ],
                [
                    [
                        [1077085., 2154943., 1077787.],
                        [1070389., 2141263., 1070803.],
                    ],
                    [
                        [1078597., 2156887., 1078219.],
                        [1070173., 2139751., 1069507.],
                    ],
                ],
                [
                    [
                        [1745822., 3498896., 1753004.],
                        [1746902., 3500768., 1753796.],
                    ],
                    [
                        [1786214., 3578600., 1792316.],
                        [1785566., 3577016., 1791380.],
                    ],
                ],
                [
                    [
                        [2414559., 4842849., 2428221.],
                        [2423415., 4860273., 2436789.],
                    ],
                    [
                        [2493831., 5000313., 2506413.],
                        [2500959., 5014281., 2513253.],
                    ],
                ],
            ],
        ]));
    }

    /// Parameters describing one `conv3d` test case.
    ///
    /// Per-axis settings use the suffixes `_1`, `_2` and `_3`. They are paired
    /// positionally with the `depth`, `height` and `width` input dimensions, matching
    /// the order of the spatial axes in the input and weight shapes.
    struct Conv3dTestCase {
        // Shape of the input volume: [batch_size, channels_in, depth, height, width].
        batch_size: usize,
        channels_in: usize,
        depth: usize,
        height: usize,
        width: usize,

        // Number of output channels and how channels are split into groups.
        channels_out: usize,
        groups: usize,

        // Kernel extent along each spatial axis.
        kernel_size_1: usize,
        kernel_size_2: usize,
        kernel_size_3: usize,

        // Zero padding applied on each side of each spatial axis.
        padding_1: usize,
        padding_2: usize,
        padding_3: usize,

        // Step between consecutive kernel applications along each axis.
        stride_1: usize,
        stride_2: usize,
        stride_3: usize,

        // Spacing between kernel taps along each axis.
        dilation_1: usize,
        dilation_2: usize,
        dilation_3: usize,
    }

    impl Conv3dTestCase {
        /// Runs `conv3d` on deterministic inputs derived from this configuration and
        /// asserts that the result matches the expected tensor `y`.
        ///
        /// Every input is built with `arange` (consecutive integers starting at 0),
        /// reshaped where needed, and converted into a `TestTensor`:
        /// - `x`: shape `[batch_size, channels_in, depth, height, width]`;
        /// - `weight`: shape
        ///   `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2, kernel_size_3]`;
        /// - `bias`: shape `[channels_out]`, values `0..channels_out`.
        ///
        /// Stride, padding and dilation are passed to `ConvOptions::new` in
        /// `[_1, _2, _3]` order. The comparison uses
        /// `TensorData::assert_approx_eq_diff` with a tolerance of `0.002`.
        ///
        /// # Panics
        ///
        /// Panics in either of these cases:
        /// - `groups` is zero, or does not evenly divide both `channels_in` and
        ///   `channels_out`;
        /// - the computed output does not match `y`.
        fn assert_output(self, y: TestTensor<5>) {
            // Grouped convolution needs both channel counts to be multiples of
            // `groups`. Check this up front. Otherwise `channels_in / groups` below
            // silently truncates, and the mis-shaped weight only fails later, deep
            // inside the backend, far from the faulty test configuration.
            assert!(
                self.groups > 0
                    && self.channels_in % self.groups == 0
                    && self.channels_out % self.groups == 0,
                "invalid Conv3dTestCase: groups ({}) must evenly divide channels_in ({}) and channels_out ({})",
                self.groups,
                self.channels_in,
                self.channels_out,
            );

            let shape_x = Shape::new([
                self.batch_size,
                self.channels_in,
                self.depth,
                self.height,
                self.width,
            ]);
            let shape_weight = Shape::new([
                self.channels_out,
                self.channels_in / self.groups,
                self.kernel_size_1,
                self.kernel_size_2,
                self.kernel_size_3,
            ]);
            let device = Default::default();
            let weight = TestTensor::from(
                TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
                    .reshape::<5, _>(shape_weight)
                    .into_data(),
            );
            let bias = TestTensor::from(
                TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
            );
            let x = TestTensor::from(
                TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
                    .reshape::<5, _>(shape_x)
                    .into_data(),
            );
            let output = conv3d(
                x,
                weight,
                Some(bias),
                ConvOptions::new(
                    [self.stride_1, self.stride_2, self.stride_3],
                    [self.padding_1, self.padding_2, self.padding_3],
                    [self.dilation_1, self.dilation_2, self.dilation_3],
                    self.groups,
                ),
            );

            y.to_data()
                .assert_approx_eq_diff(&output.into_data(), 0.002);
        }
    }
}