burn-cubecl 0.21.0-pre.4

Generic backend that can be compiled just-in-time to any shader language target
Documentation
use crate::{
    CubeRuntime,
    kernel::utils::address_type,
    ops::{max_vector_size, numeric::empty_device_dtype},
    tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{calculate_cube_count_elemwise, prelude::*};

/// Element-wise cast kernel.
///
/// Each unit reads the vector at `ABSOLUTE_POS` from `input`, converts it with
/// `Vector::cast_from` from element type `I` to element type `O`, and writes it to
/// the same position in `output`.
///
/// `N` is the vector width; the host-side launcher passes it as `vector_size`.
/// The concrete storage types bound to `I` and `O` are not fixed at compile time.
/// They are supplied at launch through `_dtypes` (`#[define(I, O)]`), ordered
/// `[input, output]`.
#[cube(launch, address_type = "dynamic")]
pub(crate) fn cast_element<I: Numeric, O: Numeric, N: Size>(
    input: &LinearView<Vector<I, N>>,
    output: &mut LinearView<Vector<O, N>, ReadWrite>,
    #[define(I, O)] _dtypes: [StorageType; 2],
) {
    // Units whose position falls outside `output` exit without writing.
    // Presumably this is needed because the launched cube count can cover more
    // units than there are vectors.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS]);
}

/// Cast a tensor to the given element type.
///
/// When choosing the kernel element types, `Flex32` is treated as `F32`.
/// If the input and requested dtypes map to the same kernel type (e.g. `Flex32` <-> `F32`,
/// or identical dtypes), no kernel is launched. The input tensor is returned with its dtype
/// label set to `dtype`.
///
/// Otherwise, a new tensor is allocated on the input's device with the input's shape and
/// dtype `dtype`. The `cast_element` kernel then fills it.
///
/// In both cases the returned tensor reports exactly the requested `dtype`.
///
/// Note: When input element is semantically a boolean, prefer bool_cast function.
pub fn cast<R: CubeRuntime>(mut input: CubeTensor<R>, dtype: DType) -> CubeTensor<R> {
    // Element types as seen by the kernel: Flex32 is computed as F32.
    let dtype_output = match dtype {
        DType::Flex32 => DType::F32,
        _ => dtype,
    };
    let dtype_input = match input.dtype {
        DType::Flex32 => DType::F32,
        _ => input.dtype,
    };

    if dtype_input == dtype_output {
        // Same kernel storage type, so the data needs no conversion. Still relabel the
        // tensor so the caller gets the dtype it asked for (e.g. Flex32 -> F32). This
        // matches the kernel path below, which also labels its output with `dtype`.
        input.dtype = dtype;
        return input;
    }

    let client = input.client.clone();

    let vector_size = max_vector_size(&input);

    let num_elems: usize = input.meta.num_elements();

    // One working unit per vector of `vector_size` elements.
    let working_units = num_elems / vector_size as usize;
    let cube_dim = CubeDim::new(&client, working_units);
    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);

    let output = empty_device_dtype(
        client.clone(),
        input.device.clone(),
        input.shape(),
        dtype, // We take the same dtype as passed as input (Flex32 not F32)
    );

    cast_element::launch(
        &client,
        cube_count,
        cube_dim,
        address_type!(input, output),
        vector_size,
        input.into_linear_view(),
        output.clone().into_linear_view(),
        // The kernel is instantiated with the normalised (Flex32 -> F32) element types.
        [dtype_input.into(), dtype_output.into()],
    );

    output
}