cubek-test-utils 0.2.0

CubeK: Test Utils
Documentation
use cubecl::{
    TestRuntime,
    prelude::*,
    std::tensor::{
        TensorHandle, ViewOperations, ViewOperationsExpand, ViewOperationsMut,
        ViewOperationsMutExpand,
    },
    tensor_vector_size_parallel,
    zspace::{shape, strides},
};

/// Launchable kernel entry point that converts `from` (vectors of `From`) into `to`
/// (vectors of `To`), both using the same vector width `N`.
///
/// The body only forwards to `cast_inner`, which performs the per-position conversion.
///
/// `_types` is never read in the body. It is tagged with `#[define(From, To)]`, so it
/// presumably supplies the runtime storage types that stand in for the `From` and `To`
/// generics at launch time, in that order — confirm against the cubecl `#[define]` docs.
#[cube(launch)]
fn cast_launch<From: Numeric, To: Numeric, N: Size>(
    from: &Tensor<Vector<From, N>>,
    to: &mut Tensor<Vector<To, N>>,
    #[define(From, To)] _types: [StorageType; 2],
) {
    cast_inner::<From, To, N>(from, to);
}

/// Converts the vector at this unit's `ABSOLUTE_POS` from `From` to `To`.
///
/// One vector is read from `from`, cast with `Vector::cast_from`, and written to the same
/// position in `to`. Both accesses go through the `*_checked` view operations; presumably
/// these guard against out-of-range positions, so surplus units do nothing — confirm
/// against the cubecl `ViewOperations` documentation.
#[cube]
fn cast_inner<From: Numeric, To: Numeric, N: Size>(
    from: &Tensor<Vector<From, N>>,
    to: &mut Tensor<Vector<To, N>>,
) {
    let position = ABSOLUTE_POS;
    let loaded = from.read_checked(position);
    let converted = Vector::cast_from(loaded);
    to.write_checked(position, converted);
}

pub fn copy_casted(
    client: &ComputeClient<TestRuntime>,
    original: TensorHandle<TestRuntime>,
    target_type: StorageType,
) -> TensorHandle<TestRuntime> {
    if target_type == original.dtype {
        return TensorHandle::new_contiguous(
            original.shape().clone(),
            original.handle.clone(),
            target_type,
        );
    }

    let num_elems: usize = original.shape().num_elements();

    let vector_size = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(target_type.size()),
        &shape![num_elems],
        &strides![1],
        0,
    );

    let working_units: u32 = num_elems as u32 / vector_size as u32;
    let cube_dim = CubeDim::new(client, working_units as usize);
    let cube_count = working_units.div_ceil(cube_dim.num_elems());

    let out = TensorHandle::new_contiguous(
        original.shape().clone(),
        client.empty(target_type.size() * num_elems),
        target_type,
    );

    let dtype = original.dtype;

    cast_launch::launch::<TestRuntime>(
        client,
        CubeCount::Static(cube_count, 1, 1),
        cube_dim,
        vector_size,
        original.into_arg(),
        out.clone().into_arg(),
        [dtype, target_type],
    );

    out
}