burn-cubecl 0.21.0-pre.3

Generic backend that can be compiled just-in-time to any shader language target
Documentation
use crate::{
    CubeRuntime,
    kernel::utils::{address_type, shape_divmod},
    ops::numeric::empty_device_dtype,
    tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use cubecl::{
    calculate_cube_count_elemwise,
    prelude::*,
    std::{FastDivmod, tensor::layout::linear::LinearView},
};

/// Element-wise kernel that copies `input` into `output` with selected dimensions reversed.
///
/// Each invocation handles the output element at `ABSOLUTE_POS` and terminates early
/// when that position lies outside `output`.
///
/// * `input` - tensor being read; its shape and strides are used to locate the source element.
/// * `output` - linear view receiving one element per invocation.
/// * `in_shape` - one `FastDivmod` per dimension, used to split the linear position into
///   per-dimension coordinates.
/// * `indices` - one scalar per dimension; a value equal to `1` (read as `Bool`) marks that
///   dimension as flipped.
/// * `_dtypes` - storage types bound to the generics `E` and `Bool` through `#[define]`.
#[cube(launch_unchecked, address_type = "dynamic")]
fn flip_kernel<E: Numeric, Bool: Int>(
    input: &Tensor<E>,
    output: &mut LinearView<E, ReadWrite>,
    in_shape: Sequence<FastDivmod<usize>>,
    indices: Sequence<InputScalar>,
    #[define(E, Bool)] _dtypes: [StorageType; 2],
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = in_shape.len().comptime();

    // Part of the linear position that has not yet been assigned to a dimension.
    let mut remaining = ABSOLUTE_POS;
    // Offset into `input`, accumulated one dimension at a time.
    let mut input_offset = 0;

    // Walk the dimensions from the last one to the first one.
    #[unroll]
    for step in 0..rank {
        let axis = rank - step - 1;
        let extent = input.shape(axis);
        let is_flipped = indices.index(axis).get::<Bool>() == Bool::from_int(1);

        // The first component is carried over to the next dimension, the second one is
        // the coordinate along `axis`.
        let (carry, coord) = in_shape[axis].div_mod(remaining);
        remaining = carry;

        // A flipped dimension reads from the mirrored coordinate.
        let source_coord = select(is_flipped, extent - coord - 1, coord);

        input_offset += source_coord * input.stride(axis);
    }

    output[ABSOLUTE_POS] = input[input_offset];
}

/// Flips `tensor` along the dimensions listed in `indices` and returns the result in a
/// newly allocated tensor.
///
/// The output is created with `empty_device_dtype` using the client, device, shape and
/// dtype of `tensor`; the actual work is delegated to [`flip_on_output`].
///
/// * `tensor` - tensor to flip.
/// * `indices` - dimensions to reverse.
/// * `dtype_bool` - dtype used to encode the per-dimension flip flags passed to the kernel.
pub(crate) fn flip<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    indices: &[usize],
    dtype_bool: DType,
) -> CubeTensor<R> {
    let client = tensor.client.clone();
    let device = tensor.device.clone();
    let shape = tensor.shape();
    let dtype = tensor.dtype;

    let output = empty_device_dtype(client, device, shape, dtype);

    flip_on_output(tensor, output, indices, dtype_bool)
}

/// Flips `tensor` along the dimensions listed in `indices`, writing into `output`.
///
/// One flag is built per dimension of `tensor`: `1` when the dimension appears in
/// `indices`, `0` otherwise. Entries of `indices` that are not a valid dimension of
/// `tensor` therefore have no effect. The launch is sized from the number of elements
/// in `output`.
///
/// * `tensor` - tensor to read from.
/// * `output` - tensor to write into; nothing here checks that it matches the shape of
///   `tensor`.
/// * `indices` - dimensions to reverse.
/// * `dtype_bool` - dtype used to encode the flip flags.
///
/// Returns `output` after the kernel launch has been issued.
pub(crate) fn flip_on_output<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    output: CubeTensor<R>,
    indices: &[usize],
    dtype_bool: DType,
) -> CubeTensor<R> {
    let input_dtype = tensor.dtype;
    let rank = tensor.meta.num_dims();

    // Encode, for every dimension, whether it has to be flipped.
    let mut flip_flags = SequenceArg::<R, InputScalar>::new();
    for dim in 0..rank {
        let is_flipped = indices.contains(&dim);
        flip_flags.push(InputScalar::new(u8::from(is_flipped), dtype_bool));
    }

    // One kernel invocation per output element.
    let total = output.meta.num_elements();
    let cube_dim = CubeDim::new(&tensor.client, total);
    let cube_count = calculate_cube_count_elemwise(&tensor.client, total, cube_dim);

    // Must be computed before `tensor` is consumed by `into_tensor_arg` below.
    let in_shape = shape_divmod(&tensor);

    // SAFETY (review): the launch is unchecked; `flip_kernel` exits for positions outside
    // `output` before writing. Callers are presumably expected to pass an `output` with
    // the same shape as `tensor` - confirm at call sites.
    unsafe {
        flip_kernel::launch_unchecked(
            &output.client,
            cube_count,
            cube_dim,
            address_type!(tensor, output),
            tensor.into_tensor_arg(),
            // `output` is returned to the caller, so the view is built from a clone.
            output.clone().into_linear_view(),
            in_shape,
            flip_flags,
            [input_dtype.into(), dtype_bool.into()],
        )
    }

    output
}