cubecl-core 0.10.0-pre.3

CubeCL core crate
Documentation
use alloc::{vec, vec::Vec};
use std::println;

use crate::{self as cubecl, as_type};
use cubecl::prelude::*;
use cubecl_common::{e2m1x2, e2m3, e3m2, e4m3, e5m2, ue8m0};
use cubecl_ir::features::TypeUsage;

/// Encodes the first input vector as e4m3 and e5m2, then round-trips it through e4m3.
///
/// Only the unit at `ABSOLUTE_POS == 0` does any work:
/// - `out[0]` receives the raw e4m3 bits of `input[0]`,
/// - `out[1]` receives the raw e5m2 bits of `input[0]`,
/// - `input[0]` is overwritten with the e4m3 bits decoded back to `F`, so the host can
///   inspect the rounded values.
///
/// `input` must hold at least one vector and `out` at least two.
#[cube(launch_unchecked)]
pub fn kernel_fp8<F: Float, N: Size>(
    input: &mut Array<Vector<F, N>>,
    out: &mut Array<Vector<u8, N>>,
) {
    if ABSOLUTE_POS == 0 {
        let value = input[0];

        // Cast to the minifloat type, then reinterpret its bits as bytes for the host to read.
        out[0] = Vector::reinterpret(Vector::<e4m3, N>::cast_from(value));
        out[1] = Vector::reinterpret(Vector::<e5m2, N>::cast_from(value));
        // Decode the e4m3 bytes back to `F` so the host can check the rounding.
        input[0] = Vector::cast_from(Vector::<e4m3, N>::reinterpret(out[0]));
    }
}

/// Encodes the first input vector as e2m3 and e3m2, then round-trips it through e2m3.
///
/// Only the unit at `ABSOLUTE_POS == 0` does any work:
/// - `out[0]` receives the e2m3 encoding of `input[0]`, reinterpreted as bytes,
/// - `out[1]` receives the e3m2 encoding of `input[0]`, reinterpreted as bytes,
/// - `input[0]` is overwritten with the e2m3 bytes decoded back to `F`.
///
/// `input` must hold at least one vector and `out` at least two.
#[cube(launch_unchecked)]
pub fn kernel_fp6<F: Float, N: Size>(
    input: &mut Array<Vector<F, N>>,
    out: &mut Array<Vector<u8, N>>,
) {
    if ABSOLUTE_POS == 0 {
        let value = input[0];

        // Cast to each 6-bit format, then expose the result as raw bytes for the host.
        out[0] = Vector::reinterpret(Vector::<e2m3, N>::cast_from(value));
        out[1] = Vector::reinterpret(Vector::<e3m2, N>::cast_from(value));
        // Decode the e2m3 bytes back to `F` so the host can check the rounding.
        input[0] = Vector::cast_from(Vector::<e2m3, N>::reinterpret(out[0]));
    }
}

/// Encodes the first input vector as packed e2m1x2, then round-trips it back to `F`.
///
/// `N` is the number of `F` lanes and `N2` the number of packed output bytes. The host test
/// in this file launches with `N2 = N / 2`, presumably because each e2m1x2 byte holds two
/// values. Only the unit at `ABSOLUTE_POS == 0` writes:
/// - `out[0]` gets the packed bytes,
/// - `input[0]` is overwritten with those bytes decoded back to `F`.
#[cube(launch_unchecked)]
pub fn kernel_fp4<F: Float, N: Size, N2: Size>(
    input: &mut Array<Vector<F, N>>,
    out: &mut Array<Vector<u8, N2>>,
) {
    if ABSOLUTE_POS == 0 {
        let value = input[0];

        // `N` floats cast into an `N2`-wide vector of packed e2m1x2 values, stored as raw bytes.
        out[0] = Vector::reinterpret(Vector::<e2m1x2, N2>::cast_from(value));
        // Unpack and decode the bytes back into `N` floats.
        input[0] = Vector::cast_from(Vector::<e2m1x2, N2>::reinterpret(out[0]));
    }
}

/// Converts the first `f32` input vector to ue8m0 and round-trips it back to `f32`.
///
/// Only the unit at `ABSOLUTE_POS == 0` writes:
/// - `out[0]` gets the ue8m0 values (stored directly, without a byte reinterpret),
/// - `input[0]` is overwritten with those values cast back to `f32`.
#[cube(launch_unchecked)]
pub fn kernel_scale<N: Size>(input: &mut Array<Vector<f32, N>>, out: &mut Array<Vector<ue8m0, N>>) {
    if ABSOLUTE_POS == 0 {
        let value = input[0];

        out[0] = Vector::<ue8m0, N>::cast_from(value);
        // Decode back to f32 so the host can inspect what the scale represents.
        input[0] = Vector::cast_from(out[0]);
    }
}

/// Checks the e4m3/e5m2 encodings of the first `vector_size` sample values and the e4m3
/// round trip, skipping when the backend cannot convert e4m3.
#[allow(clippy::unusual_byte_groupings, reason = "Split by float components")]
pub fn test_fp8<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size: VectorSize,
) {
    // Expected encodings, grouped as sign_exponent_mantissa.
    const E4M3_BITS: [u8; 4] = [0b1_1000_000, 0b0_0111_110, 0b0_0101_101, 0b0_0111_010];
    const E5M2_BITS: [u8; 4] = [0b1_10000_00, 0b0_01111_11, 0b0_01101_10, 0b0_01111_01];

    let can_convert = e4m3::supported_uses(&client).contains(TypeUsage::Conversion);
    if !can_convert {
        println!("Unsupported, skipping");
        return;
    }

    let lanes = vector_size;
    let samples = as_type![F: -2.1, 1.8, 0.4, 1.2];
    let input_handle = client.create_from_slice(F::as_bytes(&samples[..lanes]));
    // Two byte vectors come back: the e4m3 encoding, then the e5m2 encoding.
    let output_len = 2 * lanes;
    let output_handle = client.empty(output_len * size_of::<u8>());

    // SAFETY: `kernel_fp8` only touches `input[0]`, `out[0]` and `out[1]`, and only from the
    // unit with `ABSOLUTE_POS == 0`. `input` holds one `lanes`-wide vector and `out` holds
    // `2 * lanes` bytes, i.e. two vectors.
    unsafe {
        kernel_fp8::launch_unchecked::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            vector_size,
            ArrayArg::from_raw_parts(input_handle.clone(), lanes),
            ArrayArg::from_raw_parts(output_handle.clone(), output_len),
        )
    };

    let encoded_bytes = client.read_one_unchecked(output_handle);
    let encoded = u8::from_bytes(&encoded_bytes);
    let expected: Vec<u8> = E4M3_BITS[..lanes]
        .iter()
        .chain(&E5M2_BITS[..lanes])
        .copied()
        .collect();

    // TODO: Eventually add approx comparison that can deal with arbitrary floats. Manually
    // double check for now
    let roundtrip_bytes = client.read_one_unchecked(input_handle);
    let actual_2 = F::from_bytes(&roundtrip_bytes);
    println!("actual_2: {actual_2:?}");

    // Each sample rounded to its nearest e4m3 value.
    let rounded = as_type![F: -2.0, 1.75, 0.40625, 1.25];

    assert_eq!(encoded, &expected);
    assert_eq!(&actual_2[..lanes], &rounded[..lanes]);
}

/// Checks the e2m3/e3m2 encodings of the first `vector_size` sample values and the e2m3
/// round trip, skipping when the backend cannot convert e2m3.
#[allow(clippy::unusual_byte_groupings, reason = "Split by float components")]
pub fn test_fp6<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size: VectorSize,
) {
    // Expected encodings, grouped as sign_exponent_mantissa (one 6-bit value per byte).
    const E2M3_BITS: [u8; 4] = [0b1_10_000, 0b0_01_110, 0b0_00_011, 0b0_01_010];
    const E3M2_BITS: [u8; 4] = [0b1_100_00, 0b0_011_11, 0b0_001_10, 0b0_011_01];

    let can_convert = e2m3::supported_uses(&client).contains(TypeUsage::Conversion);
    if !can_convert {
        println!("Unsupported, skipping");
        return;
    }

    let lanes = vector_size;
    let samples = as_type![F: -2.1, 1.8, 0.4, 1.2];
    let input_handle = client.create_from_slice(F::as_bytes(&samples[..lanes]));
    // Two byte vectors come back: the e2m3 encoding, then the e3m2 encoding.
    let output_len = 2 * lanes;
    let output_handle = client.empty(output_len * size_of::<u8>());

    // SAFETY: `kernel_fp6` only touches `input[0]`, `out[0]` and `out[1]`, and only from the
    // unit with `ABSOLUTE_POS == 0`. `input` holds one `lanes`-wide vector and `out` holds
    // `2 * lanes` bytes, i.e. two vectors.
    unsafe {
        kernel_fp6::launch_unchecked::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            vector_size,
            ArrayArg::from_raw_parts(input_handle.clone(), lanes),
            ArrayArg::from_raw_parts(output_handle.clone(), output_len),
        )
    };

    let encoded_bytes = client.read_one_unchecked(output_handle);
    let encoded = u8::from_bytes(&encoded_bytes);
    let expected: Vec<u8> = E2M3_BITS[..lanes]
        .iter()
        .chain(&E3M2_BITS[..lanes])
        .copied()
        .collect();

    // TODO: Eventually add approx comparison that can deal with arbitrary floats. Manually
    // double check for now
    let roundtrip_bytes = client.read_one_unchecked(input_handle);
    let actual_2 = F::from_bytes(&roundtrip_bytes);
    println!("actual_2: {actual_2:?}");

    // Each sample rounded to its nearest e2m3 value.
    let rounded = as_type![F: -2.0, 1.75, 0.375, 1.25];

    assert_eq!(encoded, &expected);
    assert_eq!(&actual_2[..lanes], &rounded[..lanes]);
}

/// Checks the packed e2m1x2 encoding of the first `vector_size` sample values and their
/// round trip back to `F`, skipping when the backend cannot convert e2m1x2.
///
/// # Panics
///
/// Panics if `vector_size` is not 2 or 4. Values are packed in pairs and the fixtures hold
/// four samples (two packed bytes). Also panics if the GPU results differ from the expected
/// encodings or rounded values.
#[allow(clippy::unusual_byte_groupings, reason = "Split by float components")]
pub fn test_fp4<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size: VectorSize,
) {
    // Catch bad callers regardless of backend support. Vector size 1 would yield an empty
    // packed vector; sizes above 4 would overrun the fixtures.
    assert!(
        matches!(vector_size, 2 | 4),
        "test_fp4 requires a vector_size of 2 or 4, got {vector_size}"
    );

    if !e2m1x2::supported_uses(&client).contains(TypeUsage::Conversion) {
        println!("Unsupported, skipping");
        return;
    }

    let data = as_type![F: -2.1, 1.8, 0.4, 1.2];
    let num_out = vector_size;
    // e2m1x2 stores two values per byte, so `num_out` inputs encode to `num_out / 2` bytes.
    let num_packed = num_out / 2;
    let handle1 = client.create_from_slice(F::as_bytes(&data[..num_out]));
    let handle2 = client.empty(num_packed * size_of::<u8>());

    // SAFETY: `kernel_fp4` only touches `input[0]` and `out[0]`, and only from the unit with
    // `ABSOLUTE_POS == 0`. `input` holds one `num_out`-wide vector and `out` holds exactly one
    // packed vector of `num_packed` bytes. Each declared length matches its allocation.
    unsafe {
        kernel_fp4::launch_unchecked::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            vector_size,
            num_packed,
            ArrayArg::from_raw_parts(handle1.clone(), num_out),
            // Must agree with the `num_packed`-byte allocation of `handle2`.
            ArrayArg::from_raw_parts(handle2.clone(), num_packed),
        )
    };

    let actual = client.read_one_unchecked(handle2);
    let actual = u8::from_bytes(&actual);
    // LITTLE ENDIAN FOR PACKED VALUES: the first value of each pair sits in the low nibble.
    // Each nibble is grouped as sign_exponent_mantissa.
    let expect_0: Vec<u8> = vec![0b0_10_0__1_10_0, 0b0_01_0__0_00_1];
    let expected = expect_0[..num_packed].to_vec();

    let actual_2 = client.read_one_unchecked(handle1);
    let actual_2 = F::from_bytes(&actual_2);
    println!("actual_2: {actual_2:?}");

    // Data rounded to the nearest e2m1 value
    let expected_data = as_type![F: -2.0, 2.0, 0.5, 1.0];

    assert_eq!(actual, &expected);
    assert_eq!(&actual_2[..num_out], &expected_data[..num_out]);
}

/// Checks the ue8m0 encoding of the first `vector_size` sample `f32` values, skipping when
/// the backend cannot convert ue8m0. The decoded round-trip values are only printed.
pub fn test_scale<R: Runtime>(client: ComputeClient<R>, vector_size: VectorSize) {
    let can_convert = ue8m0::supported_uses(&client).contains(TypeUsage::Conversion);
    if !can_convert {
        println!("Unsupported, skipping");
        return;
    }

    // 2^1, 2^10, a value between 2^15 and 2^16, and 2^127 (f32 exponent field 254).
    let inputs = [2.0, 1024.0, 57312.0, f32::from_bits(0x7F000000)];
    let count = vector_size;
    let input_handle = client.create_from_slice(f32::as_bytes(&inputs[..count]));
    let scale_handle = client.empty(count * size_of::<u8>());

    // SAFETY: `kernel_scale` only touches `input[0]` and `out[0]`, and only from the unit with
    // `ABSOLUTE_POS == 0`. Both buffers hold exactly one `count`-wide vector.
    unsafe {
        kernel_scale::launch_unchecked(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            vector_size,
            ArrayArg::from_raw_parts(input_handle.clone(), count),
            ArrayArg::from_raw_parts(scale_handle.clone(), count),
        )
    };

    let scale_bytes = client.read_one_unchecked(scale_handle);
    let scales = u8::from_bytes(&scale_bytes);
    // Expected bytes read as biased exponents 128, 137, 143 and 254. With the bias of 127
    // implied by 2.0 -> 128, these are 2^1, 2^10, 2^16 and 2^127.
    let expect: [u8; 4] = [0b1000_0000, 0b1000_1001, 0b1000_1111, 0b1111_1110];

    // TODO: Eventually add approx comparison that can deal with arbitrary floats. Manually
    // double check for now
    let roundtrip_bytes = client.read_one_unchecked(input_handle);
    let actual_2 = f32::from_bytes(&roundtrip_bytes);
    println!("actual_2: {actual_2:?}");

    assert_eq!(scales, &expect[..count]);
    // NOTE(review): the round trip is printed but not asserted against `inputs`. 57312.0 is
    // encoded as 2^16 above, so it cannot come back unchanged.
}

/// Generates the minifloat runtime tests (`test_fp8`, `test_fp6`, `test_fp4`, `test_scale`).
///
/// Invoke inside a test module whose parent module provides `TestRuntime` and `FloatType`
/// (brought in via `use super::*`). Each test runs the matching helper from
/// `runtime_tests::minifloat` for every vector size that helper supports.
#[macro_export]
macro_rules! testgen_minifloat {
    () => {
        use super::*;

        #[$crate::runtime_tests::test_log::test]
        fn test_fp8() {
            let client = TestRuntime::client(&Default::default());
            for vector_size in [1, 2, 4] {
                $crate::runtime_tests::minifloat::test_fp8::<TestRuntime, FloatType>(
                    client.clone(),
                    vector_size,
                );
            }
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_fp6() {
            let client = TestRuntime::client(&Default::default());
            for vector_size in [1, 2, 4] {
                $crate::runtime_tests::minifloat::test_fp6::<TestRuntime, FloatType>(
                    client.clone(),
                    vector_size,
                );
            }
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_fp4() {
            let client = TestRuntime::client(&Default::default());
            // e2m1x2 packs values in pairs, so only even vector sizes apply.
            for vector_size in [2, 4] {
                $crate::runtime_tests::minifloat::test_fp4::<TestRuntime, FloatType>(
                    client.clone(),
                    vector_size,
                );
            }
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_scale() {
            let client = TestRuntime::client(&Default::default());
            for vector_size in [1, 2, 4] {
                $crate::runtime_tests::minifloat::test_scale::<TestRuntime>(
                    client.clone(),
                    vector_size,
                );
            }
        }
    };
}