cubecl-core 0.10.0-pre.3

CubeCL core crate
Documentation
use alloc::{vec, vec::Vec};

use crate::prelude::*;
use crate::{self as cubecl};

/// Kernel that dumps the cluster-related builtins into `out` so the host can inspect them.
///
/// Only the unit with `UNIT_POS == 0` in each cube writes anything.
///
/// Layout written to `out`:
/// - `out[0..4]` (written only by the cube with `CUBE_POS == 0`): `CUBE_CLUSTER_DIM`,
///   `CUBE_CLUSTER_DIM_X`, `CUBE_CLUSTER_DIM_Y`, `CUBE_CLUSTER_DIM_Z`;
/// - `out[CUBE_POS * 4 + 4..CUBE_POS * 4 + 8]` (written by every cube): `CUBE_POS_CLUSTER`,
///   `CUBE_POS_CLUSTER_X`, `CUBE_POS_CLUSTER_Y`, `CUBE_POS_CLUSTER_Z`.
///
/// `out` must therefore hold at least `4 + 4 * <number of cubes launched>` elements.
#[cube(launch, cluster_dim = CubeDim::new_3d(1, 2, 3))]
fn cluster_meta_kernel(out: &mut Array<u32>) {
    // A single unit per cube does the writes.
    if UNIT_POS == 0 {
        // The cluster dimensions go into a shared 4-element header, written by cube 0 only.
        if CUBE_POS == 0 {
            out[0] = CUBE_CLUSTER_DIM;
            out[1] = CUBE_CLUSTER_DIM_X;
            out[2] = CUBE_CLUSTER_DIM_Y;
            out[3] = CUBE_CLUSTER_DIM_Z;
        }

        // Each cube owns a 4-element record located after the header.
        let offset = CUBE_POS * 4 + 4;

        out[offset] = CUBE_POS_CLUSTER;
        out[offset + 1] = CUBE_POS_CLUSTER_X;
        out[offset + 2] = CUBE_POS_CLUSTER_Y;
        out[offset + 3] = CUBE_POS_CLUSTER_Z;
    }
}

/// Launches `cluster_meta_kernel` on a 2x2x6 grid of cubes (cube dim
/// `CubeDim::new_single()`) and checks the cluster metadata the kernel reports.
///
/// The function returns early, asserting nothing, when the client does not report the
/// `cube_cluster` feature.
///
/// Expected buffer contents (all `u32`):
/// - `[0..4]`: the cluster size followed by the cluster dimensions along x, y and z;
/// - `[4 + 4 * i..8 + 4 * i]`: for the cube at linear position `i` (x fastest, then y,
///   then z), its linear rank inside the cluster followed by its rank along x, y and z.
///
/// # Panics
///
/// Panics if the values read back from the device differ from the expected ones.
pub fn test_cluster_meta<R: Runtime>(client: ComputeClient<R>) {
    if !client.features().cube_cluster {
        return;
    }

    // Must mirror the `cluster_dim` given in the `#[cube]` attribute of `cluster_meta_kernel`.
    let cluster_dim_x = 1;
    let cluster_dim_y = 2;
    let cluster_dim_z = 3;
    let cluster_size = cluster_dim_x * cluster_dim_y * cluster_dim_z;

    let cube_count_x = 2;
    let cube_count_y = 2;
    let cube_count_z = 6;
    let cube_count = CubeCount::new_3d(cube_count_x, cube_count_y, cube_count_z);
    let num_cubes = cube_count_x * cube_count_y * cube_count_z;

    // One 4-element header plus one 4-element record per cube.
    let num_elems = num_cubes as usize * 4 + 4;
    let handle = client.empty(num_elems * size_of::<u32>());

    cluster_meta_kernel::launch(&client, cube_count, CubeDim::new_single(), unsafe {
        // SAFETY: `handle` was allocated just above with room for exactly `num_elems`
        // `u32` values, which is the length declared here.
        ArrayArg::from_raw_parts(handle.clone(), num_elems)
    });

    let actual = client.read_one_unchecked(handle);
    let actual = u32::from_bytes(&actual);

    let mut expected: Vec<u32> = vec![cluster_size, cluster_dim_x, cluster_dim_y, cluster_dim_z];
    // Iterate z-major / x-minor so records are appended in linear cube-position order.
    for z in 0..cube_count_z {
        for y in 0..cube_count_y {
            for x in 0..cube_count_x {
                let rank_x = x % cluster_dim_x;
                let rank_y = y % cluster_dim_y;
                let rank_z = z % cluster_dim_z;
                // Row-major linearisation of the rank inside the cluster (x fastest).
                let rank_abs = (rank_z * cluster_dim_y + rank_y) * cluster_dim_x + rank_x;
                expected.extend([rank_abs, rank_x, rank_y, rank_z]);
            }
        }
    }

    assert_eq!(actual, &expected);
}

/// Generates the cluster runtime tests for the runtime named `TestRuntime`.
///
/// The expansion does `use super::*;`, so it is meant to be invoked inside a module whose
/// parent scope provides a `TestRuntime` type. It emits one `#[test]` function that builds a
/// client with the default device and forwards it to
/// `runtime_tests::cluster::test_cluster_meta`.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_cluster {
    () => {
        use super::*;

        #[$crate::runtime_tests::test_log::test]
        fn test_cluster_meta() {
            let client = TestRuntime::client(&Default::default());
            // `$crate` keeps this path valid even when the crate is renamed or only reachable
            // through a re-export at the call site.
            $crate::runtime_tests::cluster::test_cluster_meta::<TestRuntime>(client);
        }
    };
}