burn-cubecl 0.21.0-pre.3

Generic backend that can be compiled just-in-time to any shader language target
Documentation
use core::hash::Hash;
use cubecl::{
    prelude::*,
    std::{
        FastDivmod,
        tensor::{
            View,
            launch::ViewArg,
            layout::fixed_dim::{FixedDimLayout, FixedDimLayoutLaunch},
        },
    },
};

use crate::{CubeRuntime, kernel::utils::decompose_linear, tensor::CubeTensor};

pub trait Pool2dDirectStrategyFamily: Send + Sync + 'static {
    type Indices<N: Size>: LaunchArg;
    type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;
    type Pool2d<T: Numeric, N: Size>: Pool2dDirectStrategy<T, N, Config = Self::Config, Indices = Self::Indices<N>>;
}

pub(super) type Position = (usize, usize, usize, usize);

#[cube]
pub(crate) trait Pool2dDirectStrategy<T: Numeric, N: Size>: Send + Sync + 'static {
    type Accumulator: CubeType;
    type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;

    type Indices: LaunchArg;

    fn initialize(#[comptime] config: &Self::Config) -> Self::Accumulator;

    fn accumulate(
        #[comptime] config: &Self::Config,
        accumulator: &mut Self::Accumulator,
        index: usize,
        result: Vector<T, N>,
    );

    /// Count a position within the kernel window (for avg_pool count_include_pad).
    /// Called for each position in the kernel window with the current ih/iw coordinates.
    /// Only avg_pool uses this; max_pool implements as no-op.
    fn count_position(
        #[comptime] config: &Self::Config,
        accumulator: &mut Self::Accumulator,
        ih: u32,
        iw: u32,
    );

    fn store(
        #[comptime] config: &Self::Config,
        position: Position,
        output: &mut View<Vector<T, N>, Position, ReadWrite>,
        output_indices: &mut Self::Indices,
        accumulator: Self::Accumulator,
    );
}

#[derive(CubeLaunch, CubeType)]
pub struct Pool2dDirectArgs {
    pub strides_0: u32,
    pub strides_1: u32,
    pub dilation_0: u32,
    pub dilation_1: u32,
    pub padding_0: u32,
    pub padding_1: u32,
}

#[cube(launch, address_type = "dynamic")]
pub fn pool2d_direct<E: Numeric, N: Size, S: Pool2dDirectStrategyFamily>(
    input: &Tensor<Vector<E, N>>,
    output: &mut View<Vector<E, N>, Position, ReadWrite>,
    indices: &mut S::Indices<N>,
    out_shape: Sequence<FastDivmod<usize>>,
    working_units: usize,
    args: &Pool2dDirectArgs,
    #[comptime] kernel_size: (u32, u32),
    #[comptime] config: &S::Config,
    #[define(E)] _dtype: StorageType,
) {
    if ABSOLUTE_POS >= working_units {
        terminate!();
    }

    let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape);
    let [b, oh, ow, c] = *pos else { unreachable!() };

    let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2));
    let (in_h, in_w) = (input.shape(1) as u32, input.shape(2) as u32);

    let mut accumulator = S::Pool2d::<E, N>::initialize(config);

    let in_b_off = b * input.stride(0);
    let in_c_off = c * input.stride(3);

    let border_bottom = in_h + args.padding_0;
    let border_right = in_w + args.padding_1;

    for kh in 0..kernel_size.0 {
        let ih = oh as u32 * args.strides_0 + kh * args.dilation_0;
        let within_padding_h = ih >= args.padding_0 && ih < border_bottom;

        for kw in 0..kernel_size.1 {
            let iw = ow as u32 * args.strides_1 + kw * args.dilation_1;
            let within_padding_w = iw >= args.padding_1 && iw < border_right;

            // Let strategy handle position counting (only used by avg_pool)
            S::Pool2d::<E, N>::count_position(config, &mut accumulator, ih, iw);

            // Only accumulate values from valid input positions
            if within_padding_h && within_padding_w {
                let ih_pad = ih - args.padding_0;
                let iw_pad = iw - args.padding_1;

                let in_h_off = ih_pad as usize * in_stride_h;
                let in_w_off = iw_pad as usize * in_stride_w;

                let index_input = in_b_off + in_c_off + in_h_off + in_w_off;

                S::Pool2d::<E, N>::accumulate(
                    config,
                    &mut accumulator,
                    ih_pad as usize * in_w as usize + iw_pad as usize,
                    input[index_input / input.vector_size()],
                );
            }
        }
    }

    S::Pool2d::<E, N>::store(config, (b, oh, ow, c), output, indices, accumulator);
}

pub(super) fn view4d<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    vector_size: VectorSize,
) -> ViewArg<Position, R> {
    let shape = tensor.meta.shape();
    let shape = (shape[0], shape[1], shape[2], shape[3]);
    let binding = tensor.binding();
    let layout = FixedDimLayoutLaunch::<Position, R>::from_shape_handle_unchecked(
        &binding,
        shape,
        vector_size,
    );
    let buffer = binding.into_tensor_arg();
    ViewArg::new_tensor::<FixedDimLayout<Position>>(buffer, layout)
}