burn-cubecl 0.21.0-pre.4

Generic backend that can be compiled just-in-time to any shader language target
Documentation
use super::pool2d::{
    Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct,
};
use crate::{
    CubeRuntime,
    kernel::{
        into_contiguous_aligned,
        pool::pool2d::{Position, view4d},
        utils::{address_type, shape_divmod},
    },
    ops::{
        max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw,
    },
    tensor::CubeTensor,
};
use burn_backend::{Shape, ops::conv::calculate_pool_output_size};
use cubecl::{CubeDim, calculate_cube_count_elemwise, num_traits::Zero};
use cubecl::{prelude::*, std::tensor::View};

/// Marker type that selects average pooling in the generic `pool2d_direct` kernel.
///
/// It serves as its own strategy family: the per-element-type strategy resolves back
/// to this same marker type, and no index output is produced.
struct AvgPoolStrategy;

impl Pool2dDirectStrategyFamily for AvgPoolStrategy {
    type Config = AvgPoolStrategyConfig;
    // Average pooling has nothing to report besides the pooled value.
    type Indices<N: Size> = ();
    type Pool2d<T: Numeric, N: Size> = AvgPoolStrategy;
}

/// Configuration for the average-pooling strategy, passed to the kernel hooks as a
/// `#[comptime]` parameter.
///
/// NOTE(review): because the value is comptime, each distinct configuration (including
/// each distinct padded input size) presumably produces a separately compiled kernel —
/// confirm how cubecl keys its kernel cache on this type.
#[derive(CubeType, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct AvgPoolStrategyConfig {
    /// When `true`, every kernel tap inside the padded input bounds counts toward the
    /// averaging divisor; when `false`, only taps passed to `accumulate` are counted.
    count_include_pad: bool,
    /// Total padded height (input_height + 2 * padding_0)
    padded_h: u32,
    /// Total padded width (input_width + 2 * padding_1)
    padded_w: u32,
}

// Average-pooling hooks for the shared `pool2d_direct` kernel.
//
// The accumulator is a `(sum, count)` pair: `sum` is the running vector sum of the
// values handed to `accumulate`, and `count` is the divisor applied in `store`.
// Which hook increments `count` is chosen at compile time by `config.count_include_pad`.
#[cube]
impl<T: Numeric, N: Size> Pool2dDirectStrategy<T, N> for AvgPoolStrategy {
    // (running sum, divisor)
    type Accumulator = (Vector<T, N>, u32);
    type Config = AvgPoolStrategyConfig;
    // Average pooling produces no per-output indices.
    type Indices = ();

    // Starts each output window with a zero sum and a zero divisor.
    fn initialize(#[comptime] _config: &Self::Config) -> Self::Accumulator {
        let sum = Vector::zero();
        // Count is filled in while the window is walked: by `accumulate` when
        // count_include_pad=false, or by `count_position` when count_include_pad=true.
        let count = 0u32;

        (sum, count)
    }

    // Adds one tap's value to the running sum. The index argument is unused because
    // average pooling tracks no indices.
    // NOTE(review): presumably only called for taps inside the unpadded input, so
    // padding contributes nothing to the sum — confirm in `pool2d_direct`.
    fn accumulate(
        #[comptime] config: &Self::Config,
        accumulator: &mut Self::Accumulator,
        _index: usize,
        result: Vector<T, N>,
    ) {
        let (sum, count) = accumulator;

        // Only count valid positions when count_include_pad=false
        if comptime![!config.count_include_pad] {
            *count += 1;
        }

        *sum += result;
    }

    // Divisor hook for count_include_pad=true: counts a tap if it lies within the
    // padded input extent, independently of whether it hit real data.
    // NOTE(review): `ih`/`iw` look like coordinates in the padded frame (they are
    // compared against `padded_h`/`padded_w`) — confirm against `pool2d_direct`.
    fn count_position(
        #[comptime] config: &Self::Config,
        accumulator: &mut Self::Accumulator,
        ih: u32,
        iw: u32,
    ) {
        // When count_include_pad=true, count positions within padded bounds
        // (excludes ceil_mode extensions beyond the padded input)
        if comptime![config.count_include_pad] && ih < config.padded_h && iw < config.padded_w {
            let (_sum, count) = accumulator;
            *count += 1;
        }
    }

    // Writes `sum / count` to the output at `position`; the divisor is converted to a
    // `Vector<T, N>` before the division.
    // NOTE(review): if no tap contributed, `count` is 0 and this divides by zero —
    // confirm the caller's window bounds guarantee count >= 1.
    fn store(
        #[comptime] _config: &Self::Config,
        position: Position,
        output: &mut View<Vector<T, N>, Position, ReadWrite>,
        _output_indices: &mut (),
        accumulator: Self::Accumulator,
    ) {
        let (sum, count) = accumulator;
        output[position] = sum / Vector::cast_from(count);
    }
}

/// Runs direct 2D average pooling on an NCHW tensor.
///
/// The input is permuted to NHWC and made contiguous/aligned so channels can be
/// vectorized. It is then pooled with the average strategy, and the NHWC result is
/// permuted back to NCHW before being returned.
///
/// * `kernel_size`, `stride`, `padding` — `[height, width]` pairs.
/// * `count_include_pad` — whether padded positions count toward the divisor.
/// * `ceil_mode` — forwarded to `calculate_pool_output_size` to size the output.
pub(crate) fn avg_pool2d<R: CubeRuntime>(
    x: CubeTensor<R>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
    ceil_mode: bool,
) -> CubeTensor<R> {
    // Average pooling never dilates its window.
    let dilation = 1;
    let [batch_size, channels, in_h, in_w] = x.meta.shape().dims();

    // Output extent along one spatial axis (0 = height, 1 = width).
    let pooled_extent = |axis: usize, input_extent: usize| {
        calculate_pool_output_size(
            kernel_size[axis],
            stride[axis],
            padding[axis],
            dilation,
            input_extent,
            ceil_mode,
        )
    };
    let out_h = pooled_extent(0, in_h);
    let out_w = pooled_extent(1, in_w);

    // The padded extents let the kernel bound the divisor when count_include_pad is
    // set and ceil_mode pushes windows past the padded input.
    let config = AvgPoolStrategyConfig {
        count_include_pad,
        padded_h: (in_h + 2 * padding[0]) as u32,
        padded_w: (in_w + 2 * padding[1]) as u32,
    };
    let kernel_dims = (kernel_size[0] as u32, kernel_size[1] as u32);
    let pool_args = Pool2dDirectArgsLaunch::new(
        stride[0] as u32,
        stride[1] as u32,
        dilation as u32,
        dilation as u32,
        padding[0] as u32,
        padding[1] as u32,
    );

    let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));
    let vector_size = max_vector_size(&x);

    let output = empty_device_dtype(
        x.client.clone(),
        x.device.clone(),
        Shape::new([batch_size, out_h, out_w, channels]),
        x.dtype,
    );

    // One work unit per output vector.
    let working_units = output.meta.num_elements() / vector_size as usize;
    let cube_dim = CubeDim::new(&x.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);

    pool2d_direct::launch::<AvgPoolStrategy, R>(
        &output.client,
        cube_count,
        cube_dim,
        address_type!(x, output),
        vector_size,
        x.into_tensor_arg(),
        view4d(output.clone(), vector_size),
        (),
        shape_divmod(&output),
        working_units,
        pool_args,
        kernel_dims,
        config,
        output.dtype.into(),
    );

    permute_nhwc_to_nchw(output)
}