use alloc::{vec, vec::Vec};
use std::println;
use crate::{self as cubecl, as_type};
use cubecl::prelude::*;
use cubecl_common::{e2m1x2, e2m3, e3m2, e4m3, e5m2, ue8m0};
use cubecl_ir::features::TypeUsage;
/// Converts the first vector of `input` to both 8-bit float formats and back.
///
/// All work is guarded by `ABSOLUTE_POS == 0`:
/// - `out[0]` receives `input[0]` cast to `e4m3` and reinterpreted as `u8`;
/// - `out[1]` receives `input[0]` cast to `e5m2` and reinterpreted as `u8`;
/// - `input[0]` is then overwritten with `out[0]` reinterpreted as `e4m3` and cast back to `F`.
#[cube(launch_unchecked)]
pub fn kernel_fp8<F: Float, N: Size>(
input: &mut Array<Vector<F, N>>,
out: &mut Array<Vector<u8, N>>,
) {
if ABSOLUTE_POS == 0 {
// Read the source vector before `input[0]` is overwritten at the end.
let value = input[0];
// Encode as e4m3 and e5m2; `reinterpret` stores the encoded vectors in the `u8` output.
out[0] = Vector::reinterpret(Vector::<e4m3, N>::cast_from(value));
out[1] = Vector::reinterpret(Vector::<e5m2, N>::cast_from(value));
// Decode the e4m3 result back to `F` so the host can inspect the round-tripped values.
input[0] = Vector::cast_from(Vector::<e4m3, N>::reinterpret(out[0]));
}
}
/// Converts the first vector of `input` to both 6-bit float formats and back.
///
/// All work is guarded by `ABSOLUTE_POS == 0`:
/// - `out[0]` receives `input[0]` cast to `e2m3` and reinterpreted as `u8`;
/// - `out[1]` receives `input[0]` cast to `e3m2` and reinterpreted as `u8`;
/// - `input[0]` is then overwritten with `out[0]` reinterpreted as `e2m3` and cast back to `F`.
#[cube(launch_unchecked)]
pub fn kernel_fp6<F: Float, N: Size>(
input: &mut Array<Vector<F, N>>,
out: &mut Array<Vector<u8, N>>,
) {
if ABSOLUTE_POS == 0 {
// Read the source vector before `input[0]` is overwritten at the end.
let value = input[0];
// Encode as e2m3 and e3m2; each encoded element occupies one `u8` of the output.
out[0] = Vector::reinterpret(Vector::<e2m3, N>::cast_from(value));
out[1] = Vector::reinterpret(Vector::<e3m2, N>::cast_from(value));
// Decode the e2m3 result back to `F` so the host can inspect the round-tripped values.
input[0] = Vector::cast_from(Vector::<e2m3, N>::reinterpret(out[0]));
}
}
/// Converts the first vector of `input` to packed `e2m1x2` and back.
///
/// `N` is the vector size of the `F` input and `N2` the vector size of the `e2m1x2`/`u8`
/// output; the host driver in this file launches it with `N2 = N / 2`.
///
/// All work is guarded by `ABSOLUTE_POS == 0`:
/// - `out[0]` receives `input[0]` cast to `Vector<e2m1x2, N2>` and reinterpreted as `u8`;
/// - `input[0]` is then overwritten with `out[0]` reinterpreted as `e2m1x2` and cast back to `F`.
#[cube(launch_unchecked)]
pub fn kernel_fp4<F: Float, N: Size, N2: Size>(
input: &mut Array<Vector<F, N>>,
out: &mut Array<Vector<u8, N2>>,
) {
if ABSOLUTE_POS == 0 {
// Read the source vector before `input[0]` is overwritten at the end.
let value = input[0];
// The cast changes the vector size from `N` to `N2` as well as the element type.
out[0] = Vector::reinterpret(Vector::<e2m1x2, N2>::cast_from(value));
// Decode the packed result back to `F` so the host can inspect the round-tripped values.
input[0] = Vector::cast_from(Vector::<e2m1x2, N2>::reinterpret(out[0]));
}
}
/// Converts the first `f32` vector of `input` to `ue8m0` and back.
///
/// All work is guarded by `ABSOLUTE_POS == 0`:
/// - `out[0]` receives `input[0]` cast to `ue8m0` (stored directly, no reinterpretation);
/// - `input[0]` is then overwritten with `out[0]` cast back to `f32`.
#[cube(launch_unchecked)]
pub fn kernel_scale<N: Size>(input: &mut Array<Vector<f32, N>>, out: &mut Array<Vector<ue8m0, N>>) {
if ABSOLUTE_POS == 0 {
// Read the source vector before `input[0]` is overwritten at the end.
let value = input[0];
out[0] = Vector::<ue8m0, N>::cast_from(value);
// Decode the scale back to `f32` so the host can inspect the round-tripped values.
input[0] = Vector::cast_from(out[0]);
}
}
/// Checks the `F` -> FP8 conversions and the `e4m3` round trip done by `kernel_fp8`.
///
/// The kernel is launched on the first `vector_size` entries of a fixed sample set. The
/// output buffer must then hold the expected `e4m3` bit patterns followed by the expected
/// `e5m2` bit patterns, and the input buffer must hold the values decoded from `e4m3`.
///
/// Prints a notice and returns without testing when the client does not report
/// `TypeUsage::Conversion` for `e4m3`.
///
/// # Panics
///
/// Panics when either comparison fails, or when `vector_size` is larger than the four
/// available samples.
#[allow(clippy::unusual_byte_groupings, reason = "Split by float components")]
pub fn test_fp8<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size: VectorSize,
) {
    let can_convert = e4m3::supported_uses(&client).contains(TypeUsage::Conversion);
    if !can_convert {
        println!("Unsupported, skipping");
        return;
    }

    let lanes = vector_size;
    let samples = as_type![F: -2.1, 1.8, 0.4, 1.2];
    let input_handle = client.create_from_slice(F::as_bytes(&samples[..lanes]));
    // The kernel writes two `u8` vectors: one per FP8 format.
    let out_len = 2 * lanes;
    let output_handle = client.empty(out_len * size_of::<u8>());

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);
    // SAFETY: the kernel only touches `input[0]`, `out[0]` and `out[1]`, and the lengths
    // passed here match the buffers created above.
    unsafe {
        kernel_fp8::launch_unchecked::<F, R>(
            &client,
            cube_count,
            cube_dim,
            vector_size,
            ArrayArg::from_raw_parts(input_handle.clone(), lanes),
            ArrayArg::from_raw_parts(output_handle.clone(), out_len),
        )
    };

    let encoded_bytes = client.read_one_unchecked(output_handle);
    let encoded = u8::from_bytes(&encoded_bytes);

    // Literals are grouped as sign_exponent_mantissa.
    let e4m3_bits: Vec<u8> = vec![0b1_1000_000, 0b0_0111_110, 0b0_0101_101, 0b0_0111_010];
    let e5m2_bits: Vec<u8> = vec![0b1_10000_00, 0b0_01111_11, 0b0_01101_10, 0b0_01111_01];
    let expected: Vec<u8> = e4m3_bits[..lanes]
        .iter()
        .chain(&e5m2_bits[..lanes])
        .copied()
        .collect();

    let decoded_bytes = client.read_one_unchecked(input_handle);
    let actual_2 = F::from_bytes(&decoded_bytes);
    println!("actual_2: {actual_2:?}");

    // Values the samples become after passing through `e4m3`.
    let decoded_expected = as_type![F: -2.0, 1.75, 0.40625, 1.25];

    assert_eq!(encoded, &expected);
    assert_eq!(&actual_2[..lanes], &decoded_expected[..lanes]);
}
/// Checks the `F` -> FP6 conversions and the `e2m3` round trip done by `kernel_fp6`.
///
/// The kernel is launched on the first `vector_size` entries of a fixed sample set. The
/// output buffer must then hold the expected `e2m3` bit patterns followed by the expected
/// `e3m2` bit patterns, and the input buffer must hold the values decoded from `e2m3`.
///
/// Prints a notice and returns without testing when the client does not report
/// `TypeUsage::Conversion` for `e2m3`.
///
/// # Panics
///
/// Panics when either comparison fails, or when `vector_size` is larger than the four
/// available samples.
#[allow(clippy::unusual_byte_groupings, reason = "Split by float components")]
pub fn test_fp6<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size: VectorSize,
) {
    let can_convert = e2m3::supported_uses(&client).contains(TypeUsage::Conversion);
    if !can_convert {
        println!("Unsupported, skipping");
        return;
    }

    let lanes = vector_size;
    let samples = as_type![F: -2.1, 1.8, 0.4, 1.2];
    let input_handle = client.create_from_slice(F::as_bytes(&samples[..lanes]));
    // The kernel writes two `u8` vectors: one per FP6 format.
    let out_len = 2 * lanes;
    let output_handle = client.empty(out_len * size_of::<u8>());

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);
    // SAFETY: the kernel only touches `input[0]`, `out[0]` and `out[1]`, and the lengths
    // passed here match the buffers created above.
    unsafe {
        kernel_fp6::launch_unchecked::<F, R>(
            &client,
            cube_count,
            cube_dim,
            vector_size,
            ArrayArg::from_raw_parts(input_handle.clone(), lanes),
            ArrayArg::from_raw_parts(output_handle.clone(), out_len),
        )
    };

    let encoded_bytes = client.read_one_unchecked(output_handle);
    let encoded = u8::from_bytes(&encoded_bytes);

    // Literals are grouped as sign_exponent_mantissa.
    let e2m3_bits: Vec<u8> = vec![0b1_10_000, 0b0_01_110, 0b0_00_011, 0b0_01_010];
    let e3m2_bits: Vec<u8> = vec![0b1_100_00, 0b0_011_11, 0b0_001_10, 0b0_011_01];
    let expected: Vec<u8> = e2m3_bits[..lanes]
        .iter()
        .chain(&e3m2_bits[..lanes])
        .copied()
        .collect();

    let decoded_bytes = client.read_one_unchecked(input_handle);
    let actual_2 = F::from_bytes(&decoded_bytes);
    println!("actual_2: {actual_2:?}");

    // Values the samples become after passing through `e2m3`.
    let decoded_expected = as_type![F: -2.0, 1.75, 0.375, 1.25];

    assert_eq!(encoded, &expected);
    assert_eq!(&actual_2[..lanes], &decoded_expected[..lanes]);
}
/// Checks the `F` -> packed FP4 conversion and the round trip done by `kernel_fp4`.
///
/// The kernel is launched on the first `vector_size` entries of a fixed sample set, with
/// an output vector size of `vector_size / 2`. The output buffer must then hold the
/// expected packed `e2m1x2` bytes, and the input buffer must hold the decoded values.
///
/// `vector_size` is expected to be even, because the output length and output vector size
/// are both computed as `vector_size / 2`.
///
/// Prints a notice and returns without testing when the client does not report
/// `TypeUsage::Conversion` for `e2m1x2`.
///
/// # Panics
///
/// Panics when either comparison fails, or when `vector_size` is larger than the four
/// available samples.
#[allow(clippy::unusual_byte_groupings, reason = "Split by float components")]
pub fn test_fp4<R: Runtime, F: Float + CubeElement>(
    client: ComputeClient<R>,
    vector_size: VectorSize,
) {
    if !e2m1x2::supported_uses(&client).contains(TypeUsage::Conversion) {
        println!("Unsupported, skipping");
        return;
    }

    let data = as_type![F: -2.1, 1.8, 0.4, 1.2];
    let num_out = vector_size;
    // Each output byte holds two 4-bit values, so the output is half as long as the input.
    let num_packed = num_out / 2;
    let handle1 = client.create_from_slice(F::as_bytes(&data[..num_out]));
    let handle2 = client.empty(num_packed * size_of::<u8>());

    // SAFETY: the kernel only touches `input[0]` and `out[0]`, and the lengths passed here
    // match the buffers created above. The output length must be `num_packed`, not the
    // `2 * num_out` used by the FP8/FP6 tests, because only `num_packed` bytes are allocated.
    unsafe {
        kernel_fp4::launch_unchecked::<F, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(1),
            vector_size,
            vector_size / 2,
            ArrayArg::from_raw_parts(handle1.clone(), num_out),
            ArrayArg::from_raw_parts(handle2.clone(), num_packed),
        )
    };

    let actual = client.read_one_unchecked(handle2);
    let actual = u8::from_bytes(&actual);
    // Each literal is two sign_exponent_mantissa groups: the later sample of the pair is in
    // the high nibble and the earlier one in the low nibble.
    let expect_0: Vec<u8> = vec![0b0_10_0__1_10_0, 0b0_01_0__0_00_1];
    let expected = expect_0[..num_packed].to_vec();

    let actual_2 = client.read_one_unchecked(handle1);
    let actual_2 = F::from_bytes(&actual_2);
    println!("actual_2: {actual_2:?}");

    // Values the samples become after passing through `e2m1`.
    let expected_data = as_type![F: -2.0, 2.0, 0.5, 1.0];

    assert_eq!(actual, &expected);
    assert_eq!(&actual_2[..num_out], &expected_data[..num_out]);
}
/// Checks the `f32` -> `ue8m0` conversion done by `kernel_scale`.
///
/// The kernel is launched on the first `vector_size` entries of a fixed sample set and the
/// output buffer is compared against the expected `ue8m0` bytes. The values converted back
/// to `f32` are read and printed, but not asserted.
///
/// Prints a notice and returns without testing when the client does not report
/// `TypeUsage::Conversion` for `ue8m0`.
///
/// # Panics
///
/// Panics when the encoded bytes differ from the expected ones, or when `vector_size` is
/// larger than the four available samples.
pub fn test_scale<R: Runtime>(client: ComputeClient<R>, vector_size: VectorSize) {
    let can_convert = ue8m0::supported_uses(&client).contains(TypeUsage::Conversion);
    if !can_convert {
        println!("Unsupported, skipping");
        return;
    }

    let lanes = vector_size;
    let samples = [2.0, 1024.0, 57312.0, f32::from_bits(0x7F000000)];
    let input_handle = client.create_from_slice(f32::as_bytes(&samples[..lanes]));
    let output_handle = client.empty(lanes * size_of::<u8>());

    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::new_1d(1);
    // SAFETY: the kernel only touches `input[0]` and `out[0]`, and the lengths passed here
    // match the buffers created above.
    unsafe {
        kernel_scale::launch_unchecked(
            &client,
            cube_count,
            cube_dim,
            vector_size,
            ArrayArg::from_raw_parts(input_handle.clone(), lanes),
            ArrayArg::from_raw_parts(output_handle.clone(), lanes),
        )
    };

    let encoded_bytes = client.read_one_unchecked(output_handle);
    let encoded = u8::from_bytes(&encoded_bytes);
    let expected_bits: Vec<u8> = vec![0b1000_0000, 0b1000_1001, 0b1000_1111, 0b1111_1110];

    // The round-tripped values are only logged for inspection.
    let decoded_bytes = client.read_one_unchecked(input_handle);
    let actual_2 = f32::from_bytes(&decoded_bytes);
    println!("actual_2: {actual_2:?}");

    assert_eq!(encoded, &expected_bits[..lanes]);
}
/// Generates the minifloat runtime tests: `test_fp8`, `test_fp6`, `test_fp4` and `test_scale`.
///
/// The expansion begins with `use super::*;` and refers to `TestRuntime` and `FloatType`
/// unqualified, so both names must resolve where the macro is invoked. Each generated test
/// creates a default client and runs the matching driver once per vector size:
/// 1, 2 and 4 for FP8, FP6 and scale; 2 and 4 for FP4.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_minifloat {
    () => {
        use super::*;

        #[$crate::runtime_tests::test_log::test]
        fn test_fp8() {
            let client = TestRuntime::client(&Default::default());
            for vector_size in [1, 2, 4] {
                cubecl_core::runtime_tests::minifloat::test_fp8::<TestRuntime, FloatType>(
                    client.clone(),
                    vector_size,
                );
            }
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_fp6() {
            let client = TestRuntime::client(&Default::default());
            for vector_size in [1, 2, 4] {
                cubecl_core::runtime_tests::minifloat::test_fp6::<TestRuntime, FloatType>(
                    client.clone(),
                    vector_size,
                );
            }
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_fp4() {
            let client = TestRuntime::client(&Default::default());
            // Only even sizes: the driver halves the vector size for the packed output.
            for vector_size in [2, 4] {
                cubecl_core::runtime_tests::minifloat::test_fp4::<TestRuntime, FloatType>(
                    client.clone(),
                    vector_size,
                );
            }
        }

        #[$crate::runtime_tests::test_log::test]
        fn test_scale() {
            let client = TestRuntime::client(&Default::default());
            for vector_size in [1, 2, 4] {
                cubecl_core::runtime_tests::minifloat::test_scale::<TestRuntime>(
                    client.clone(),
                    vector_size,
                );
            }
        }
    };
}