cubecl-core 0.10.0-pre.3

CubeCL core crate
Documentation
use alloc::{vec, vec::Vec};

use crate as cubecl;

use cubecl::prelude::*;

/// Copies every dimension of three rank-4 tensors into `out`, each cast to `u32`.
///
/// Layout written to `out`:
/// - `out[0..4]`  = `lhs.shape(0..4)`
/// - `out[4..8]`  = `rhs.shape(0..4)`
/// - `out[8..12]` = `out.shape(0..4)`
///
/// Every unit that passes the guard writes the same 12 values. The guard compares
/// against `out.len()` only; it does not ensure that slots `0..=11` exist. The buffer
/// behind `out` must therefore hold at least 12 elements.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_shape_dim_4(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units positioned past `out.len()` exit without writing anything.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }

    out[0] = lhs.shape(0) as u32;
    out[1] = lhs.shape(1) as u32;
    out[2] = lhs.shape(2) as u32;
    out[3] = lhs.shape(3) as u32;
    out[4] = rhs.shape(0) as u32;
    out[5] = rhs.shape(1) as u32;
    out[6] = rhs.shape(2) as u32;
    out[7] = rhs.shape(3) as u32;
    out[8] = out.shape(0) as u32;
    out[9] = out.shape(1) as u32;
    out[10] = out.shape(2) as u32;
    out[11] = out.shape(3) as u32;
}

/// Copies the shapes and ranks of tensors of differing rank into `out`, cast to `u32`.
///
/// The kernel reads 4 dims of `lhs`, 3 of `rhs` and 2 of `out`. Layout written to `out`:
/// - `out[0..4]`  = `lhs.shape(0..4)`
/// - `out[4..7]`  = `rhs.shape(0..3)`
/// - `out[7..9]`  = `out.shape(0..2)`
/// - `out[9]`, `out[10]`, `out[11]` = ranks of `lhs`, `rhs`, `out`
///
/// The guard compares against `out.len()` only. The buffer behind `out` must hold at
/// least 12 elements, regardless of the shape `out` is launched with.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_shape_different_ranks(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units positioned past `out.len()` exit without writing anything.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }

    out[0] = lhs.shape(0) as u32;
    out[1] = lhs.shape(1) as u32;
    out[2] = lhs.shape(2) as u32;
    out[3] = lhs.shape(3) as u32;
    out[4] = rhs.shape(0) as u32;
    out[5] = rhs.shape(1) as u32;
    out[6] = rhs.shape(2) as u32;
    out[7] = out.shape(0) as u32;
    out[8] = out.shape(1) as u32;
    // Ranks follow the shapes so the host can check both in one read-back.
    out[9] = lhs.rank() as u32;
    out[10] = rhs.rank() as u32;
    out[11] = out.rank() as u32;
}

/// Copies the strides of tensors of differing rank into `out`, cast to `u32`.
///
/// The kernel reads 4 strides of `lhs`, 3 of `rhs` and 2 of `out`. Layout written to `out`:
/// - `out[0..4]` = `lhs.stride(0..4)`
/// - `out[4..7]` = `rhs.stride(0..3)`
/// - `out[7..9]` = `out.stride(0..2)`
///
/// The buffer behind `out` must hold at least 9 elements; the guard only checks `out.len()`.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_stride_different_ranks(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units positioned past `out.len()` exit without writing anything.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }

    out[0] = lhs.stride(0) as u32;
    out[1] = lhs.stride(1) as u32;
    out[2] = lhs.stride(2) as u32;
    out[3] = lhs.stride(3) as u32;
    out[4] = rhs.stride(0) as u32;
    out[5] = rhs.stride(1) as u32;
    out[6] = rhs.stride(2) as u32;
    out[7] = out.stride(0) as u32;
    out[8] = out.stride(1) as u32;
}

/// Writes `len()` of `lhs`, `rhs` and `out` into `out[0]`, `out[1]` and `out[2]`.
///
/// The test driver in this file expects `len()` to equal the product of the launch shape.
/// The buffer behind `out` must hold at least 3 elements.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_len_different_ranks(lhs: &Tensor<f32>, rhs: &Tensor<f32>, out: &mut Tensor<u32>) {
    // Units positioned past `out.len()` exit without writing anything.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }

    out[0] = lhs.len() as u32;
    out[1] = rhs.len() as u32;
    out[2] = out.len() as u32;
}

/// Stores `out.buffer_len()`, cast to `u32` and wrapped in a `Vector`, at `out[0]`.
///
/// `N` is the vector width of the elements of `out`. The tests in this file expect
/// `buffer_len()` to count vector-sized elements of the backing buffer rather than the
/// logical shape. For example, they expect 32 `u32`s with `N = 4` to report 8.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn kernel_buffer_len<N: Size>(out: &mut Tensor<Vector<u32, N>>) {
    // Units positioned past `out.len()` exit without writing anything.
    if ABSOLUTE_POS >= out.len() {
        terminate!();
    }

    // NOTE(review): `Vector::new` presumably broadcasts the scalar to all N lanes — confirm.
    out[0] = Vector::new(out.buffer_len() as u32);
}

/// Launches `kernel_shape_dim_4` on three rank-4 tensors. Checks that every dimension
/// read back through `shape(i)` matches the shape given at launch.
///
/// Returns without asserting when `client` does not support `addr_type`.
pub fn test_shape_dim_4<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // The kernel stores 12 `u32` values; every buffer is sized for exactly that.
    let byte_len = 12 * core::mem::size_of::<u32>();
    let lhs_handle = client.empty(byte_len);
    let rhs_handle = client.empty(byte_len);
    let out_handle = client.empty(byte_len);

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);

    // SAFETY: a single unit is launched and `out_handle` backs exactly the
    // 12 `u32` slots the kernel writes.
    unsafe {
        kernel_shape_dim_4::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            addr_type,
            TensorArg::from_raw_parts(lhs_handle, [1, 1, 1, 1].into(), [2, 3, 4, 5].into()),
            TensorArg::from_raw_parts(rhs_handle, [1, 1, 1, 1].into(), [9, 8, 7, 6].into()),
            TensorArg::from_raw_parts(
                out_handle.clone(),
                [1, 1, 1, 1].into(),
                [10, 11, 12, 13].into(),
            ),
        )
    };

    let bytes = client.read_one_unchecked(out_handle);
    let shapes = u32::from_bytes(&bytes);
    // lhs dims, then rhs dims, then out dims.
    let expected: Vec<u32> = vec![2, 3, 4, 5, 9, 8, 7, 6, 10, 11, 12, 13];

    assert_eq!(shapes, &expected);
}

/// Launches `kernel_shape_different_ranks` on rank-4, rank-3 and rank-2 tensors. Checks
/// every reported dimension and each tensor's rank.
///
/// Returns without asserting when `client` does not support `addr_type`.
pub fn test_shape_different_ranks<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // 9 shape values + 3 ranks = 12 `u32` slots.
    let byte_len = 12 * core::mem::size_of::<u32>();
    let lhs_handle = client.empty(byte_len);
    let rhs_handle = client.empty(byte_len);
    let out_handle = client.empty(byte_len);

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);

    // SAFETY: a single unit is launched and `out_handle` backs the 12 `u32`
    // slots the kernel writes, even though `out` is launched with a 2-D shape.
    unsafe {
        kernel_shape_different_ranks::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            addr_type,
            TensorArg::from_raw_parts(lhs_handle, [1, 1, 1, 1].into(), [2, 3, 4, 5].into()),
            TensorArg::from_raw_parts(rhs_handle, [1, 1, 1].into(), [9, 8, 7].into()),
            TensorArg::from_raw_parts(out_handle.clone(), [1, 1].into(), [10, 11].into()),
        )
    };

    let bytes = client.read_one_unchecked(out_handle);
    let reported = u32::from_bytes(&bytes);
    // lhs dims, rhs dims, out dims, then the ranks of lhs / rhs / out.
    let expected: Vec<u32> = vec![2, 3, 4, 5, 9, 8, 7, 10, 11, 4, 3, 2];

    assert_eq!(reported, &expected);
}

/// Launches `kernel_stride_different_ranks` on tensors of rank 4, 3 and 2. Checks that
/// each `stride(i)` matches the strides given at launch.
///
/// Returns without asserting when `client` does not support `addr_type`.
pub fn test_stride_different_ranks<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // 4 + 3 + 2 strides = 9 `u32` slots.
    let byte_len = 9 * core::mem::size_of::<u32>();
    let lhs_handle = client.empty(byte_len);
    let rhs_handle = client.empty(byte_len);
    let out_handle = client.empty(byte_len);

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);

    // SAFETY: a single unit is launched and `out_handle` backs the 9 `u32`
    // slots the kernel writes.
    unsafe {
        kernel_stride_different_ranks::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            addr_type,
            TensorArg::from_raw_parts(lhs_handle, [1, 2, 3, 4].into(), [1, 1, 1, 1].into()),
            TensorArg::from_raw_parts(rhs_handle, [4, 5, 6].into(), [1, 1, 1].into()),
            TensorArg::from_raw_parts(out_handle.clone(), [3, 2].into(), [1, 1].into()),
        )
    };

    let bytes = client.read_one_unchecked(out_handle);
    let strides = u32::from_bytes(&bytes);
    let expected: Vec<u32> = vec![1, 2, 3, 4, 4, 5, 6, 3, 2];

    assert_eq!(strides, &expected);
}

/// Launches `kernel_len_different_ranks`. Checks that each tensor's `len()` equals the
/// product of its launch shape.
///
/// Returns without asserting when `client` does not support `addr_type`.
pub fn test_len_different_ranks<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // One `u32` length per tensor.
    let byte_len = 3 * core::mem::size_of::<u32>();
    let lhs_handle = client.empty(byte_len);
    let rhs_handle = client.empty(byte_len);
    let out_handle = client.empty(byte_len);

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);

    // SAFETY: a single unit is launched and `out_handle` backs the 3 `u32`
    // slots the kernel writes.
    unsafe {
        kernel_len_different_ranks::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            addr_type,
            TensorArg::from_raw_parts(lhs_handle, [1, 1, 1, 1].into(), [2, 3, 4, 5].into()),
            TensorArg::from_raw_parts(rhs_handle, [1, 1, 1].into(), [9, 8, 7].into()),
            TensorArg::from_raw_parts(out_handle.clone(), [1, 1].into(), [10, 11].into()),
        )
    };

    let bytes = client.read_one_unchecked(out_handle);
    let lengths = u32::from_bytes(&bytes);
    let expected: Vec<u32> = vec![2 * 3 * 4 * 5, 9 * 8 * 7, 10 * 11];

    assert_eq!(lengths, &expected);
}

/// Checks that `buffer_len()` reports the full backing allocation (64 `u32`s) for a
/// tensor whose strides skip elements. Its logical shape covers only 16 of them.
///
/// Returns without asserting when `client` does not support `addr_type`.
pub fn test_buffer_len_discontiguous<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    let buffer = client.empty(64 * core::mem::size_of::<u32>());

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);
    // Scalar elements: vector width 1.
    let vector_width = 1;

    // SAFETY: a single unit is launched and only slot 0 of the 64-element
    // buffer is written.
    unsafe {
        kernel_buffer_len::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            addr_type,
            vector_width,
            TensorArg::from_raw_parts(buffer.clone(), [32, 16, 4, 1].into(), [2, 2, 2, 2].into()),
        )
    };

    let bytes = client.read_one_unchecked(buffer);
    let values = u32::from_bytes(&bytes);

    assert_eq!(values[0], 64);
}

/// Checks that `buffer_len()` counts vectors rather than scalars. A 32-`u32` buffer
/// viewed with vector width 4 should report 8.
///
/// Returns without asserting when `client` does not support `addr_type`.
pub fn test_buffer_len_vectorized<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    let buffer = client.empty(32 * core::mem::size_of::<u32>());

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);
    let vector_width = 4;

    // SAFETY: a single unit is launched and only vector slot 0 of the
    // 8-vector buffer is written.
    unsafe {
        kernel_buffer_len::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            addr_type,
            vector_width,
            TensorArg::from_raw_parts(buffer.clone(), [16, 8, 4, 1].into(), [2, 2, 2, 4].into()),
        )
    };

    let bytes = client.read_one_unchecked(buffer);
    let values = u32::from_bytes(&bytes);

    // 32 scalars / 4 lanes = 8 vectors.
    assert_eq!(values[0], 8);
}

/// Checks that `buffer_len()` honours handle offsets. The test trims 64 `u32`s from each
/// end of a 256-`u32` allocation. The 128 remaining `u32`s, viewed with vector width 2,
/// should report 64.
///
/// Returns without asserting when `client` does not support `addr_type`.
pub fn test_buffer_len_offset<R: Runtime>(client: ComputeClient<R>, addr_type: AddressType) {
    let supported = client.properties().supports_address(addr_type);
    if !supported {
        return;
    }

    // Each trim is 64 `u32`s = 256 bytes. 256 bytes is WebGPU's default storage-buffer
    // offset alignment. Since wgpu 22 it is also the `min_storage_buffer_offset_alignment`
    // reported for Metal GPUs, so the offset is valid on every backend.
    let trim_bytes = 64 * core::mem::size_of::<u32>() as u64;
    let buffer = client
        .empty(256 * core::mem::size_of::<u32>())
        .offset_start(trim_bytes)
        .offset_end(trim_bytes);

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);
    let vector_width = 2;

    // SAFETY: a single unit is launched and only vector slot 0 of the
    // offset view is written.
    unsafe {
        kernel_buffer_len::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            addr_type,
            vector_width,
            TensorArg::from_raw_parts(buffer.clone(), [32, 16, 4, 1].into(), [4, 4, 4, 8].into()),
        )
    };

    let bytes = client.read_one_unchecked(buffer);
    let values = u32::from_bytes(&bytes);

    // (256 - 64 - 64) scalars / 2 lanes = 64 vectors.
    assert_eq!(values[0], 64);
}

/// Generates a `metadata` test module that runs every metadata runtime test against
/// `TestRuntime`, once with 32-bit and once with 64-bit addressing.
///
/// The generated module pulls names from the invoking scope via `use super::*`. That
/// scope must therefore provide `TestRuntime` and `AddressType`, and must make
/// `TestRuntime::client` callable (presumably by bringing the `Runtime` trait into scope).
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_metadata {
    () => {
        mod metadata {
            use super::*;
            // Resolve the test bodies through `$crate`, as the test attribute already does.
            // Then the macro works whatever name the caller uses for this crate (e.g. via
            // the `cubecl` facade, or from inside cubecl-core itself).
            use $crate::runtime_tests::metadata as suite;

            #[$crate::runtime_tests::test_log::test]
            fn test_shape() {
                let client = TestRuntime::client(&Default::default());
                suite::test_shape_dim_4::<TestRuntime>(client.clone(), AddressType::U32);
                suite::test_shape_dim_4::<TestRuntime>(client.clone(), AddressType::U64);
                suite::test_shape_different_ranks::<TestRuntime>(
                    client.clone(),
                    AddressType::U32,
                );
                suite::test_shape_different_ranks::<TestRuntime>(client, AddressType::U64);
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_stride() {
                let client = TestRuntime::client(&Default::default());
                suite::test_stride_different_ranks::<TestRuntime>(
                    client.clone(),
                    AddressType::U32,
                );
                suite::test_stride_different_ranks::<TestRuntime>(client, AddressType::U64);
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_len() {
                let client = TestRuntime::client(&Default::default());
                suite::test_len_different_ranks::<TestRuntime>(client.clone(), AddressType::U32);
                suite::test_len_different_ranks::<TestRuntime>(client, AddressType::U64);
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_buffer_len_discontiguous() {
                let client = TestRuntime::client(&Default::default());
                suite::test_buffer_len_discontiguous::<TestRuntime>(
                    client.clone(),
                    AddressType::U32,
                );
                suite::test_buffer_len_discontiguous::<TestRuntime>(client, AddressType::U64);
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_buffer_len_vectorized() {
                let client = TestRuntime::client(&Default::default());
                suite::test_buffer_len_vectorized::<TestRuntime>(
                    client.clone(),
                    AddressType::U32,
                );
                suite::test_buffer_len_vectorized::<TestRuntime>(client, AddressType::U64);
            }

            #[$crate::runtime_tests::test_log::test]
            fn test_buffer_len_offset() {
                let client = TestRuntime::client(&Default::default());
                suite::test_buffer_len_offset::<TestRuntime>(client.clone(), AddressType::U32);
                suite::test_buffer_len_offset::<TestRuntime>(client, AddressType::U64);
            }
        }
    };
}