burn-tensor 0.16.1

Tensor library with user-friendly APIs and automatic differentiation support
Documentation
#[burn_tensor_testgen::testgen(module_max_pool2d)]
mod tests {
    use super::*;
    use burn_tensor::module::{max_pool2d, max_pool2d_with_indices};
    use burn_tensor::{backend::Backend, Tensor, TensorData};

    #[test]
    fn test_max_pool2d_simple() {
        let batch_size = 2;
        let channels_in = 2;
        let kernel_size_1 = 3;
        let kernel_size_2 = 3;
        let padding_1 = 1;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 1;
        let dilation_1 = 1;
        let dilation_2 = 1;

        let x = TestTensor::from([
            [
                [
                    [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],
                    [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],
                    [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182],
                    [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392],
                    [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605],
                    [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068],
                ],
                [
                    [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473],
                    [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947],
                    [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149],
                    [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426],
                    [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544],
                    [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572],
                ],
            ],
            [
                [
                    [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910],
                    [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546],
                    [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584],
                    [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441],
                    [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881],
                    [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467],
                ],
                [
                    [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952],
                    [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454],
                    [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628],
                    [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527],
                    [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827],
                    [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421],
                ],
            ],
        ]);
        let y = TestTensor::<4>::from([
            [
                [
                    [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
                    [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
                    [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
                    [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
                    [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
                    [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
                ],
                [
                    [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947],
                    [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149],
                    [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149],
                    [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149],
                    [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025],
                    [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775],
                ],
            ],
            [
                [
                    [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
                    [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],
                    [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546],
                    [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
                    [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],
                    [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037],
                ],
                [
                    [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378],
                    [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378],
                    [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445],
                    [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128],
                    [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128],
                    [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128],
                ],
            ],
        ]);

        let output = max_pool2d(
            x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
            [dilation_1, dilation_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
    }

    #[test]
    fn test_max_pool2d_different_padding_stride_kernel() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 3;
        let kernel_size_2 = 1;
        let padding_1 = 1;
        let padding_2 = 0;
        let stride_1 = 1;
        let stride_2 = 2;
        let dilation_1 = 1;
        let dilation_2 = 1;

        let x = TestTensor::from([[[
            [0.6309, 0.6112, 0.6998],
            [0.4708, 0.9161, 0.5402],
            [0.4577, 0.7397, 0.9870],
            [0.6380, 0.4352, 0.5884],
            [0.6277, 0.5139, 0.4525],
            [0.9333, 0.9846, 0.5006],
        ]]]);
        let y = TestTensor::<4>::from([[[
            [0.6309, 0.6998],
            [0.6309, 0.9870],
            [0.6380, 0.9870],
            [0.6380, 0.9870],
            [0.9333, 0.5884],
            [0.9333, 0.5006],
        ]]]);

        let output = max_pool2d(
            x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
            [dilation_1, dilation_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
    }

    #[test]
    fn test_max_pool2d_with_neg() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 3;
        let kernel_size_2 = 3;
        let padding_1 = 1;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 1;
        let dilation_1 = 1;
        let dilation_2 = 1;

        let x = TestTensor::from([[[
            [0.6309, 0.6112, 0.6998],
            [0.4708, 0.9161, 0.5402],
            [0.4577, 0.7397, 0.9870],
            [0.6380, 0.4352, 0.5884],
            [0.6277, 0.5139, 0.4525],
            [0.9333, 0.9846, 0.5006],
        ]]])
        .neg();
        let y = TestTensor::<4>::from([[[
            [-0.4708, -0.4708, -0.5402],
            [-0.4577, -0.4577, -0.5402],
            [-0.4352, -0.4352, -0.4352],
            [-0.4352, -0.4352, -0.4352],
            [-0.4352, -0.4352, -0.4352],
            [-0.5139, -0.4525, -0.4525],
        ]]]);

        let output = max_pool2d(
            x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
            [dilation_1, dilation_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
    }

    #[test]
    fn test_max_pool2d_with_dilation() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 2;
        let kernel_size_2 = 2;
        let padding_1 = 0;
        let padding_2 = 0;
        let stride_1 = 1;
        let stride_2 = 1;
        let dilation_1 = 2;
        let dilation_2 = 2;

        let x = TestTensor::from([[[
            [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
            [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],
            [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],
            [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],
            [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],
            [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],
        ]]]);
        let y = TestTensor::<4>::from([[[
            [0.9861, 0.9861, 0.9540, 0.9490],
            [0.9861, 0.9861, 0.9540, 0.9490],
            [0.9540, 0.9540, 0.9540, 0.9490],
            [0.9540, 0.9540, 0.9540, 0.9432],
        ]]]);

        let output = max_pool2d(
            x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
            [dilation_1, dilation_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
    }

    fn test_max_pool2d_with_indices() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 2;
        let kernel_size_2 = 2;
        let padding_1 = 1;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 1;
        let dilation_1 = 1;
        let dilation_2 = 1;

        let x = TestTensor::from([[[
            [0.2479, 0.6386, 0.3166, 0.5742],
            [0.7065, 0.1940, 0.6305, 0.8959],
            [0.5416, 0.8602, 0.8129, 0.1662],
            [0.3358, 0.3059, 0.8293, 0.0990],
        ]]]);
        let indices = TensorData::from([[[
            [0, 1, 1, 3, 3],
            [4, 4, 1, 7, 7],
            [4, 9, 9, 7, 7],
            [8, 9, 9, 14, 11],
            [12, 12, 14, 14, 15],
        ]]]);
        let y = TestTensor::<4>::from([[[
            [0.2479, 0.6386, 0.6386, 0.5742, 0.5742],
            [0.7065, 0.7065, 0.6386, 0.8959, 0.8959],
            [0.7065, 0.8602, 0.8602, 0.8959, 0.8959],
            [0.5416, 0.8602, 0.8602, 0.8293, 0.1662],
            [0.3358, 0.3358, 0.8293, 0.8293, 0.0990],
        ]]]);

        let (output, output_indices) = max_pool2d_with_indices(
            x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
            [dilation_1, dilation_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
        output_indices.into_data().assert_eq(&indices, false);
    }

    #[test]
    fn test_max_pool2d_complex() {
        let batch_size = 1;
        let channels_in = 1;
        let kernel_size_1 = 4;
        let kernel_size_2 = 2;
        let padding_1 = 2;
        let padding_2 = 1;
        let stride_1 = 1;
        let stride_2 = 2;
        let dilation_1 = 1;
        let dilation_2 = 1;

        let x = TestTensor::from([[[
            [0.5388, 0.0676, 0.7122, 0.8316, 0.0653],
            [0.9154, 0.1536, 0.9089, 0.8016, 0.7518],
            [0.2073, 0.0501, 0.8811, 0.5604, 0.5075],
            [0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
            [0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
        ]]]);
        let indices = TensorData::from([[[
            [5, 7, 3],
            [5, 7, 3],
            [5, 16, 3],
            [5, 16, 8],
            [15, 16, 24],
            [15, 16, 24],
        ]]]);
        let y = TestTensor::<4>::from([[[
            [0.9154, 0.9089, 0.8316],
            [0.9154, 0.9089, 0.8316],
            [0.9154, 0.9963, 0.8316],
            [0.9154, 0.9963, 0.8016],
            [0.4384, 0.9963, 0.688],
            [0.4384, 0.9963, 0.688],
        ]]]);
        let (output, output_indices) = max_pool2d_with_indices(
            x,
            [kernel_size_1, kernel_size_2],
            [stride_1, stride_2],
            [padding_1, padding_2],
            [dilation_1, dilation_2],
        );

        y.to_data().assert_approx_eq(&output.into_data(), 3);
        output_indices.into_data().assert_eq(&indices, false);
    }
}