cubecl-std 0.10.0-pre.3

CubeCL Standard Library.
Documentation
use cubecl::prelude::*;
use cubecl_common::{
    e2m1, e2m1x2,
    quant::scheme::{QuantScheme, QuantValue},
};
use cubecl_core::{self as cubecl};

use crate::tensor::{
    View,
    launch::ViewArg,
    layout::{plain::PlainLayout, *},
};

#[derive(CubeType, CubeLaunch)]
struct TestPerTensorScaleLayout {
    length: usize,
}

#[cube]
impl Layout for TestPerTensorScaleLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, _pos: Self::Coordinates) -> Self::SourceCoordinates {
        0usize.runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), true.runtime())
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    fn shape(&self) -> Self::Coordinates {
        self.length
    }
}

#[cube(launch_unchecked)]
pub fn kernel_quantized_view<F: Float, N: Size>(
    lhs: View<Vector<F, N>, Coords1d>,
    output: &mut Array<Vector<F, N>>,
) {
    if (UNIT_POS as usize) < lhs.shape() {
        output[UNIT_POS as usize] = lhs[UNIT_POS as usize];
    }
}

#[allow(clippy::needless_range_loop)]
pub fn test_quantized_per_tensor_int<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size_values: VectorSize,
) {
    let vector_size_float = 8 * vector_size_values;

    let scheme = QuantScheme::default().with_value(QuantValue::Q4F);
    let float_data = (-8..=7)
        .map(|it| F::new(it as f32 * 3.4))
        .collect::<Vec<_>>();

    let output = client.empty(16 * size_of::<F>());
    let values = client.create_from_slice(u32::as_bytes(&[0xFEDCBA98, 0x76543210]));
    let scales = client.create_from_slice(f32::as_bytes(&[3.4]));

    let float_values = client.create_from_slice(F::as_bytes(&float_data));
    let float_output = client.empty(16 * size_of::<F>());

    let scales_layout = TestPerTensorScaleLayoutLaunch::new(16);

    let values_view =
        ViewArg::new_array::<PlainLayout>(unsafe { ArrayArg::from_raw_parts(values, 2) }, ());
    let scales_view = ViewArg::new_array::<TestPerTensorScaleLayout>(
        unsafe { ArrayArg::from_raw_parts(scales, 1) },
        scales_layout,
    );
    let quantized_view = ViewArg::new_quantized(values_view, scales_view, scheme);
    let float_view = ViewArg::new_array::<PlainLayout>(
        unsafe { ArrayArg::from_raw_parts(float_values, 16) },
        (),
    );

    unsafe {
        kernel_quantized_view::launch_unchecked::<F, R>(
            &client,
            CubeCount::new_single(),
            CubeDim::new_1d(2),
            vector_size_float,
            quantized_view,
            ArrayArg::from_raw_parts(output.clone(), 16),
        );
        kernel_quantized_view::launch_unchecked::<F, R>(
            &client,
            CubeCount::new_single(),
            CubeDim::new_1d(2),
            vector_size_float,
            float_view,
            ArrayArg::from_raw_parts(float_output.clone(), 16),
        );
    }

    let actual = client.read_one_unchecked(output);
    let actual_float = client.read_one_unchecked(float_output);
    let actual = F::from_bytes(&actual);
    let actual_float = F::from_bytes(&actual_float);

    assert_eq!(&actual, &float_data);
    assert_eq!(&actual_float, &float_data);
}

#[allow(clippy::needless_range_loop)]
pub fn test_quantized_per_tensor_fp4<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size_values: VectorSize,
) {
    if !client.properties().supports_type(e2m1x2::cube_type()) {
        return;
    }

    let vector_size_float = 8 * vector_size_values;

    let scheme = QuantScheme::default().with_value(QuantValue::E2M1);
    let float_data = (0..16)
        .map(e2m1::from_bits)
        .map(|it| F::new(it.to_f32() * 3.4))
        .collect::<Vec<_>>();

    let output = client.empty(16 * size_of::<F>());
    let values = client.create_from_slice(u32::as_bytes(&[0x76543210, 0xFEDCBA98]));
    let scales = client.create_from_slice(f32::as_bytes(&[3.4]));

    let float_values = client.create_from_slice(F::as_bytes(&float_data));
    let float_output = client.empty(16 * size_of::<F>());

    let scales_layout = TestPerTensorScaleLayoutLaunch::new(16);

    let values_view =
        ViewArg::new_array::<PlainLayout>(unsafe { ArrayArg::from_raw_parts(values, 2) }, ());
    let scales_view = ViewArg::new_array::<TestPerTensorScaleLayout>(
        unsafe { ArrayArg::from_raw_parts(scales, 1) },
        scales_layout,
    );
    let quantized_view = ViewArg::new_quantized(values_view, scales_view, scheme);
    let float_view = ViewArg::new_array::<PlainLayout>(
        unsafe { ArrayArg::from_raw_parts(float_values, 16) },
        (),
    );

    unsafe {
        kernel_quantized_view::launch_unchecked::<F, R>(
            &client,
            CubeCount::new_single(),
            CubeDim::new_1d(2),
            vector_size_float,
            quantized_view,
            ArrayArg::from_raw_parts(output.clone(), 16),
        );
        kernel_quantized_view::launch_unchecked::<F, R>(
            &client,
            CubeCount::new_single(),
            CubeDim::new_1d(2),
            vector_size_float,
            float_view,
            ArrayArg::from_raw_parts(float_output.clone(), 16),
        );
    }

    let actual = client.read_one_unchecked(output);
    let actual_float = client.read_one_unchecked(float_output);
    let actual = F::from_bytes(&actual);
    let actual_float = F::from_bytes(&actual_float);

    assert_eq!(&actual, &float_data);
    assert_eq!(&actual_float, &float_data);
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_quantized_view {
    ($ty: ty) => {
        use super::*;

        #[$crate::tests::test_log::test]
        fn test_quantized_view_per_tensor_int() {
            let client = TestRuntime::client(&Default::default());
            cubecl_std::tests::view::quantized::test_quantized_per_tensor_int::<TestRuntime, $ty>(
                client.clone(),
                1,
            );
            cubecl_std::tests::view::quantized::test_quantized_per_tensor_int::<TestRuntime, $ty>(
                client, 2,
            );
        }

        #[$crate::tests::test_log::test]
        fn test_quantized_view_per_tensor_fp4() {
            let client = TestRuntime::client(&Default::default());
            cubecl_std::tests::view::quantized::test_quantized_per_tensor_fp4::<TestRuntime, $ty>(
                client.clone(),
                1,
            );
            cubecl_std::tests::view::quantized::test_quantized_per_tensor_fp4::<TestRuntime, $ty>(
                client, 2,
            );
        }
    };
}