cubecl-core 0.10.0-pre.3

CubeCL core crate
Documentation
use crate::{self as cubecl, as_bytes, as_type};
use cubecl::prelude::*;

/// Kernel: copies element 0 of `input.slice(2, 3)` (which is `input[2]`) into `output[0]`.
///
/// Only the unit with `UNIT_POS == 0` performs the copy; every other unit does nothing.
#[cube(launch)]
pub fn slice_select<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if UNIT_POS == 0 {
        // Half-open range [2, 3): a one-element view starting at index 2.
        let slice = input.slice(2, 3);
        output[0] = slice[0];
    }
}

/// Kernel: writes the length of `input.slice(2, 4)` into `output[0]` as a `u32`.
///
/// Only the unit with `UNIT_POS == 0` writes; the host test expects the value 2.
#[cube(launch)]
pub fn slice_len<F: Float>(input: &Array<F>, output: &mut Array<u32>) {
    if UNIT_POS == 0 {
        // Half-open range [2, 4), so two elements.
        let slice = input.slice(2, 4);
        output[0] = slice.len() as u32;
    }
}

/// Kernel: sums the elements of `input.slice(2, 4)` and stores the result in `output[0]`.
///
/// Exercises `for` iteration directly over a slice. Only the unit with
/// `UNIT_POS == 0` computes and writes the sum.
#[cube(launch)]
pub fn slice_for<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if UNIT_POS == 0 {
        let mut sum = F::new(0.0);

        // Iterates input[2] and input[3].
        for item in input.slice(2, 4) {
            sum += item;
        }

        output[0] = sum;
    }
}

/// Kernel: writes `input[0]` through a mutable slice of `output`.
///
/// `output.slice_mut(2, 3)[0]` aliases `output[2]`, so after the launch `output[2]`
/// holds `input[0]` and the other elements are untouched. Only the unit with
/// `UNIT_POS == 0` performs the write.
#[cube(launch)]
pub fn slice_mut_assign<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if UNIT_POS == 0 {
        let slice_1 = &mut output.slice_mut(2, 3);
        slice_1[0] = input[0];
    }
}

/// Kernel: writes the length of `output.slice_mut(0, 2).into_vectorized()` into `output[0]`.
///
/// Only the unit with `UNIT_POS == 0` writes; the host test expects the value 2.
#[cube(launch)]
pub fn slice_mut_len(output: &mut Array<u32>) {
    if UNIT_POS == 0 {
        // NOTE(review): `into_vectorized` is defined elsewhere. The host test expects
        // a length of 2 here, which presumably means the vectorization factor is 1.
        // Confirm against the `into_vectorized` implementation.
        let slice = output.slice_mut(0, 2).into_vectorized();
        output[0] = slice.len() as u32;
    }
}

/// Runs [`slice_select`] on `[0, 1, 2, 3, 4]` with a single unit and checks that the
/// value read through `input.slice(2, 3)` is `2.0`.
pub fn test_slice_select<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
    let src = client.create_from_slice(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]);
    let dst = client.empty(core::mem::size_of::<F>());

    // SAFETY: the declared array lengths (5 and 1) match the element counts of the
    // uploaded input and the allocated output buffer.
    unsafe {
        slice_select::launch::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            ArrayArg::from_raw_parts(src, 5),
            ArrayArg::from_raw_parts(dst.clone(), 1),
        )
    };

    let bytes = client.read_one_unchecked(dst);
    let values = F::from_bytes(&bytes);
    let expected = F::new(2.0);

    assert_eq!(values[0], expected);
}

/// Runs [`slice_len`] with a single unit and checks that `input.slice(2, 4)` has
/// length 2. The whole one-element output buffer is compared, not just index 0.
pub fn test_slice_len<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
    let data = client.create_from_slice(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]);
    let len_out = client.empty(core::mem::size_of::<u32>());

    // SAFETY: the declared array lengths (5 and 1) match the element counts of the
    // uploaded input and the single-`u32` output buffer.
    unsafe {
        slice_len::launch::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            ArrayArg::from_raw_parts(data, 5),
            ArrayArg::from_raw_parts(len_out.clone(), 1),
        )
    };

    let bytes = client.read_one_unchecked(len_out);
    assert_eq!(u32::from_bytes(&bytes), &[2]);
}

/// Runs [`slice_for`] on `[0, 1, 2, 3, 4]` with a single unit and checks that
/// iterating `input.slice(2, 4)` sums to `2 + 3 = 5`.
pub fn test_slice_for<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
    let values = client.create_from_slice(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]);
    // Zero-initialized, so the result is deterministic even if the kernel skipped the write.
    let sum_out = client.create_from_slice(as_bytes![F: 0.0]);

    // SAFETY: the declared array lengths (5 and 1) match the element counts of the
    // two uploaded buffers.
    unsafe {
        slice_for::launch::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            ArrayArg::from_raw_parts(values, 5),
            ArrayArg::from_raw_parts(sum_out.clone(), 1),
        )
    };

    let bytes = client.read_one_unchecked(sum_out);
    let sums = F::from_bytes(&bytes);
    let expected = F::new(5.0);

    assert_eq!(sums[0], expected);
}

/// Runs [`slice_mut_assign`] with a single unit and checks that writing through
/// `output.slice_mut(2, 3)` changes only `output[2]`.
///
/// `input` holds one element (`15.0`) and `output` holds five (`[0, 1, 2, 3, 4]`).
/// The declared `ArrayArg` lengths must match those counts. In particular, `output`
/// must be declared with length 5 because the kernel writes index 2 of it.
pub fn test_slice_mut_assign<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
    let input = client.create_from_slice(as_bytes![F: 15.0]);
    let output = client.create_from_slice(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]);

    // SAFETY: the declared array lengths (1 and 5) match the element counts of the
    // uploaded input and output buffers. Previously they were swapped, which claimed
    // 5 readable input elements and left the write to `output[2]` out of bounds.
    unsafe {
        slice_mut_assign::launch::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            ArrayArg::from_raw_parts(input, 1),
            ArrayArg::from_raw_parts(output.clone(), 5),
        )
    };

    let actual = client.read_one_unchecked(output);
    let actual = F::from_bytes(&actual);

    assert_eq!(&actual[0..5], as_type![F: 0.0, 1.0, 15.0, 3.0, 4.0]);
}

/// Runs [`slice_mut_len`] on a four-element `u32` buffer with a single unit and
/// checks that the length written to index 0 is 2.
pub fn test_slice_mut_len<R: Runtime>(client: ComputeClient<R>) {
    let elem_count = 4;
    let buffer = client.empty(core::mem::size_of::<u32>() * elem_count);

    // SAFETY: the declared array length matches the four `u32` slots allocated above.
    unsafe {
        slice_mut_len::launch(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            ArrayArg::from_raw_parts(buffer.clone(), elem_count),
        )
    };

    let bytes = client.read_one_unchecked(buffer);
    let words = u32::from_bytes(&bytes);

    assert_eq!(words[0], 2);
}

/// Generates the slice runtime tests (`test_slice_*`) for one backend.
///
/// Invoke inside a test module whose parent (`super::*`) provides `TestRuntime`
/// (used as the `Runtime`) and `FloatType` (the float element type under test).
/// Each generated test builds a client with `TestRuntime::client(&Default::default())`
/// and forwards it to the matching function in this crate's
/// `runtime_tests::slice` module.
///
/// All paths are resolved through `$crate`, so the macro still works when this crate
/// is reached through a re-export or a renamed dependency.
#[macro_export]
macro_rules! testgen_slice {
    () => {
        use super::*;

        #[$crate::runtime_tests::test_log::test]
        fn test_slice_select() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::slice::test_slice_select::<TestRuntime, FloatType>(client);
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_slice_len() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::slice::test_slice_len::<TestRuntime, FloatType>(client);
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_slice_for() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::slice::test_slice_for::<TestRuntime, FloatType>(client);
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_slice_mut_assign() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::slice::test_slice_mut_assign::<TestRuntime, FloatType>(
                client,
            );
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_slice_mut_len() {
            let client = TestRuntime::client(&Default::default());
            $crate::runtime_tests::slice::test_slice_mut_len::<TestRuntime>(client);
        }
    };
}