use std::fmt::Display;
use crate::{self as cubecl, as_type};
use cubecl::prelude::*;
use cubecl_runtime::server::Handle;
/// Reads `output` back through `client` and asserts that its leading
/// `expected.len()` elements match `expected` within a tolerance.
///
/// The tolerance starts from `epsilon` and is widened for element types whose
/// machine epsilon is larger than `f32`'s; it is never made smaller than the
/// value passed in. For each element the allowed error is
/// `max(epsilon * |expected|, epsilon)`, i.e. relative for large magnitudes
/// with an absolute floor of `epsilon`.
///
/// An element is also accepted when it compares exactly equal to the expected
/// value (this covers matching infinities, whose difference is NaN and would
/// otherwise fail the comparison) or when both values are NaN.
///
/// # Panics
///
/// Panics if fewer than `expected.len()` elements are read back, or if any
/// element differs from its expected value by the allowed error or more.
#[track_caller]
pub(crate) fn assert_equals_approx<
    R: Runtime,
    F: Float + num_traits::Float + CubeElement + Display,
>(
    client: &ComputeClient<R::Server>,
    output: Handle,
    expected: &[F],
    epsilon: f32,
) {
    let actual = client.read_one(output);
    let actual = F::from_bytes(&actual);

    // Fail with a readable message instead of an opaque slice-index panic when
    // the buffer read back is shorter than the expectation.
    assert!(
        actual.len() >= expected.len(),
        "Output holds {} elements but {} were expected",
        actual.len(),
        expected.len()
    );

    // Scale the tolerance by the ratio of `F`'s machine epsilon to `f32`'s so
    // that lower-precision element types get a proportionally looser bound.
    let epsilon = (epsilon / f32::EPSILON * F::EPSILON.to_f32().unwrap()).max(epsilon);

    for (i, (a, e)) in actual[0..expected.len()]
        .iter()
        .zip(expected.iter())
        .enumerate()
    {
        // Relative error for large expected values, absolute floor otherwise.
        let allowed_error = F::new((epsilon * e.to_f32().unwrap().abs()).max(epsilon));
        assert!(
            // Exact equality first: `inf - inf` is NaN, so equal infinities
            // would never satisfy the difference check below.
            *a == *e || (*a - *e).abs() < allowed_error || (a.is_nan() && e.is_nan()),
            "Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={}
index: {}
actual: {:?}
expected: {:?}",
            a,
            e,
            (*a - *e).abs(),
            epsilon,
            i,
            actual,
            expected
        );
    }
}
/// Generates a public test function named `$test_name`, generic over the
/// runtime and over a float type parameter named `$float_type`.
///
/// The generated function defines a small kernel that applies `$binary_func`
/// to the elements of `lhs` and `rhs` at `ABSOLUTE_POS`, then, for every listed
/// case, uploads the inputs, launches the kernel on a single cube and compares
/// the output with `expected` via `assert_equals_approx` (tolerance `0.001`).
///
/// Each case supplies the vectorization factor used for both inputs, the
/// vectorization factor used for the output, and the three data slices.
macro_rules! test_binary_impl {
    (
        $test_name:ident,
        $float_type:ident,
        $binary_func:expr,
        [$({
            input_vectorization: $input_vectorization:expr,
            out_vectorization: $out_vectorization:expr,
            lhs: $lhs:expr,
            rhs: $rhs:expr,
            expected: $expected:expr
        }),*]) => {
        pub fn $test_name<R: Runtime, $float_type: Float + num_traits::Float + CubeElement + Display>(client: ComputeClient<R::Server>) {
            #[cube(launch_unchecked, fast_math = FastMath::AllowTransform | FastMath::UnsignedZero)]
            fn test_function<$float_type: Float>(lhs: &Array<$float_type>, rhs: &Array<$float_type>, output: &mut Array<$float_type>) {
                if ABSOLUTE_POS < rhs.len() {
                    output[ABSOLUTE_POS] = $binary_func(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
                }
            }

            $(
            {
                let lhs = $lhs;
                let rhs = $rhs;
                // Bind once so the expectation expression is not expanded at
                // every use below.
                let expected = $expected;

                let output_handle = client.empty(expected.len() * core::mem::size_of::<$float_type>());
                let lhs_handle = client.create($float_type::as_bytes(lhs));
                let rhs_handle = client.create($float_type::as_bytes(rhs));

                // SAFETY: the kernel only indexes at `ABSOLUTE_POS` when it is
                // below `rhs.len()`; each case is expected to supply `lhs`,
                // `rhs` and `expected` buffers that cover that range.
                unsafe {
                    test_function::launch_unchecked::<$float_type, R>(
                        &client,
                        CubeCount::Static(1, 1, 1),
                        // One unit per input line.
                        CubeDim::new((lhs.len() / $input_vectorization as usize) as u32, 1, 1),
                        ArrayArg::from_raw_parts::<$float_type>(&lhs_handle, lhs.len(), $input_vectorization),
                        ArrayArg::from_raw_parts::<$float_type>(&rhs_handle, rhs.len(), $input_vectorization),
                        ArrayArg::from_raw_parts::<$float_type>(&output_handle, expected.len(), $out_vectorization),
                    )
                };

                // Use the caller-chosen type parameter name rather than a
                // hard-coded `F`, so the macro works for any `$float_type`.
                assert_equals_approx::<R, $float_type>(&client, output_handle, expected, 0.001);
            }
            )*
        }
    };
}
// Dot-product cases. The first three cases run the same four-element inputs
// with input vectorization 1, 2 and 4, so the expected output goes from four
// per-element products, to two pairwise sums, to one sum over all four
// products. The last case feeds eight elements at vectorization 4 and expects
// two results.
test_binary_impl!(
test_dot,
F,
F::dot,
[
{
// Vectorization 1: each output is the product of one element pair.
input_vectorization: 1,
out_vectorization: 1,
lhs: as_type![F: 1., -3.1, -2.4, 15.1],
rhs: as_type![F: -1., 23.1, -1.4, 5.1],
expected: as_type![F: -1.0, -71.61, 3.36, 77.01]
},
{
// Vectorization 2: each output sums two neighbouring products.
input_vectorization: 2,
out_vectorization: 1,
lhs: as_type![F: 1., -3.1, -2.4, 15.1],
rhs: as_type![F: -1., 23.1, -1.4, 5.1],
expected: as_type![F: -72.61, 80.37]
},
{
// Vectorization 4: a single output sums all four products.
input_vectorization: 4,
out_vectorization: 1,
lhs: as_type![F: 1., -3.1, -2.4, 15.1],
rhs: as_type![F: -1., 23.1, -1.4, 5.1],
expected: as_type![F: 7.76]
},
{
// Two groups of four; the second group swaps the lhs/rhs values of the
// first, so both results are the same.
input_vectorization: 4,
out_vectorization: 1,
lhs: as_type![F: 1., -3.1, -2.4, 15.1, -1., 23.1, -1.4, 5.1],
rhs: as_type![F: -1., 23.1, -1.4, 5.1, 1., -3.1, -2.4, 15.1],
expected: as_type![F: 7.76, 7.76]
}
]
);
// Float power cases. The inputs cover a positive integer exponent, a negative
// base with an even exponent, a negative exponent and a fractional exponent
// (0.5). The same data is run at vectorization 2 and 4; the output uses the
// same vectorization as the inputs, so the expected values are identical.
test_binary_impl!(
test_powf,
F,
F::powf,
[
{
input_vectorization: 2,
out_vectorization: 2,
lhs: as_type![F: 2., -3., 2., 81.],
rhs: as_type![F: 3., 2., -1., 0.5],
expected: as_type![F: 8., 9., 0.5, 9.]
},
{
input_vectorization: 4,
out_vectorization: 4,
lhs: as_type![F: 2., -3., 2., 81.],
rhs: as_type![F: 3., 2., -1., 0.5],
expected: as_type![F: 8., 9., 0.5, 9.]
}
]
);
/// Kernel exercising `Powi::powi`.
///
/// For the position given by `ABSOLUTE_POS`, raises the float line read from
/// `lhs` to the `i32` exponents in the matching line of `rhs` and stores the
/// result in `output` at the same position.
#[cube(launch_unchecked)]
fn test_powi_kernel<F: Float>(
lhs: &Array<Line<F>>,
rhs: &Array<Line<i32>>,
output: &mut Array<Line<F>>,
) {
// Positions at or beyond the length of `rhs` do nothing.
if ABSOLUTE_POS < rhs.len() {
output[ABSOLUTE_POS] = Powi::powi(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
}
}
/// Generates a public test function named `$test_name`, generic over the
/// runtime and over a float type parameter named `$float_type`.
///
/// For every listed case the generated function uploads a float `lhs` slice
/// and an `i32` `rhs` slice, launches `test_powi_kernel` on a single cube and
/// compares the output with `expected` via `assert_equals_approx` (tolerance
/// `0.001`).
///
/// Each case supplies the vectorization factor used for both inputs, the
/// vectorization factor used for the output, and the three data slices.
macro_rules! test_powi_impl {
    (
        $test_name:ident,
        $float_type:ident,
        [$({
            input_vectorization: $input_vectorization:expr,
            out_vectorization: $out_vectorization:expr,
            lhs: $lhs:expr,
            rhs: $rhs:expr,
            expected: $expected:expr
        }),*]) => {
        pub fn $test_name<R: Runtime, $float_type: Float + num_traits::Float + CubeElement + Display>(client: ComputeClient<R::Server>) {
            $(
            {
                let lhs = $lhs;
                let rhs = $rhs;
                // Bind once so the expectation expression is not expanded at
                // every use below.
                let expected = $expected;

                let output_handle = client.empty(expected.len() * core::mem::size_of::<$float_type>());
                let lhs_handle = client.create($float_type::as_bytes(lhs));
                let rhs_handle = client.create(i32::as_bytes(rhs));

                // SAFETY: the kernel only indexes at `ABSOLUTE_POS` when it is
                // below `rhs.len()`; each case is expected to supply `lhs`,
                // `rhs` and `expected` buffers that cover that range.
                unsafe {
                    // Use the caller-chosen type parameter name rather than a
                    // hard-coded `F`, so the macro works for any `$float_type`.
                    test_powi_kernel::launch_unchecked::<$float_type, R>(
                        &client,
                        CubeCount::Static(1, 1, 1),
                        // One unit per input line.
                        CubeDim::new((lhs.len() / $input_vectorization as usize) as u32, 1, 1),
                        ArrayArg::from_raw_parts::<$float_type>(&lhs_handle, lhs.len(), $input_vectorization),
                        ArrayArg::from_raw_parts::<i32>(&rhs_handle, rhs.len(), $input_vectorization),
                        ArrayArg::from_raw_parts::<$float_type>(&output_handle, expected.len(), $out_vectorization),
                    )
                };

                assert_equals_approx::<R, $float_type>(&client, output_handle, expected, 0.001);
            }
            )*
        }
    };
}
// Integer power cases. The exponents cover a positive power, an even power of
// a negative base, a negative power and the identity power 1. The same data is
// run at vectorization 2 and 4; the output uses the same vectorization as the
// inputs, so the expected values are identical.
test_powi_impl!(
test_powi,
F,
[
{
input_vectorization: 2,
out_vectorization: 2,
lhs: as_type![F: 2., -3., 2., 81.],
rhs: as_type![i32: 3, 2, -1, 1],
expected: as_type![F: 8., 9., 0.5, 81.]
},
{
input_vectorization: 4,
out_vectorization: 4,
lhs: as_type![F: 2., -3., 2., 81.],
rhs: as_type![i32: 3, 2, -1, 1],
expected: as_type![F: 8., 9., 0.5, 81.]
}
]
);
/// Kernel exercising `MulHi::mul_hi` on `u32` lines.
///
/// For the position given by `ABSOLUTE_POS`, passes the matching lines of
/// `lhs` and `rhs` to `MulHi::mul_hi` and stores the result in `output` at the
/// same position.
#[cube(launch_unchecked)]
fn test_mulhi_kernel(
lhs: &Array<Line<u32>>,
rhs: &Array<Line<u32>>,
output: &mut Array<Line<u32>>,
) {
// Positions at or beyond the length of `rhs` do nothing.
if ABSOLUTE_POS < rhs.len() {
output[ABSOLUTE_POS] = MulHi::mul_hi(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
}
}
/// Generates a public test function named `$test_name`, generic over the
/// runtime only.
///
/// For every listed case the generated function uploads the `u32` inputs,
/// launches `test_mulhi_kernel` on a single cube, reads the output back and
/// requires it to equal `expected` exactly.
macro_rules! test_mulhi_impl {
    (
        $test_name:ident,
        [$({
            input_vectorization: $input_vectorization:expr,
            out_vectorization: $out_vectorization:expr,
            lhs: $lhs:expr,
            rhs: $rhs:expr,
            expected: $expected:expr
        }),*]
    ) => {
        pub fn $test_name<R: Runtime>(client: ComputeClient<R::Server>) {
            $({
                let left = $lhs;
                let right = $rhs;
                let want: &[u32] = $expected;

                // Device buffers: uninitialised output first, then both inputs.
                let out_buf = client.empty(want.len() * core::mem::size_of::<u32>());
                let left_buf = client.create(u32::as_bytes(left));
                let right_buf = client.create(u32::as_bytes(right));

                // A single cube with one unit per input line.
                let cubes = CubeCount::Static(1, 1, 1);
                let units = CubeDim::new((left.len() / $input_vectorization as usize) as u32, 1, 1);

                unsafe {
                    test_mulhi_kernel::launch_unchecked::<R>(
                        &client,
                        cubes,
                        units,
                        ArrayArg::from_raw_parts::<u32>(&left_buf, left.len(), $input_vectorization),
                        ArrayArg::from_raw_parts::<u32>(&right_buf, right.len(), $input_vectorization),
                        ArrayArg::from_raw_parts::<u32>(&out_buf, want.len(), $out_vectorization),
                    )
                };

                // Integer results must match exactly.
                let raw = client.read_one(out_buf);
                let got = u32::from_bytes(&raw);
                assert_eq!(got, want);
            })*
        }
    };
}
// `mul_hi` cases, all at vectorization 1. Every expected value equals the
// upper 32 bits of the 64-bit product of the corresponding lhs/rhs pair.
test_mulhi_impl!(
test_mulhi,
[
{
// Products that fit in 32 bits: the high word is zero.
input_vectorization: 1,
out_vectorization: 1,
lhs: &[1, 2, 3, 4],
rhs: &[5, 6, 7, 8],
expected: &[0, 0, 0, 0]
},
{
// Products that spill into the high word.
input_vectorization: 1,
out_vectorization: 1,
lhs: &[0xFFFFFFFF, 0x80000000, 0x55555555, 0x10000000],
rhs: &[0x10000, 2, 4, 0x20000000],
expected: &[0x0000FFFF, 1, 1, 0x2000000]
},
{
// Extremes, including u32::MAX squared.
input_vectorization: 1,
out_vectorization: 1,
lhs: &[0xFFFFFFFF, 0xFFFFFFFF, 0x80000000, 0x10000],
rhs: &[0xFFFFFFFF, 2, 0x80000000, 0x10000],
// NOTE(review): the `u32` suffix is presumably required because the macro
// also expands this expression in `$expected.len()`, where nothing fixes the
// literal type and 0xFFFFFFFE would not fit the default `i32` - confirm.
expected: &[0xFFFFFFFEu32, 1, 0x40000000, 1]
}
]
);
/// Expands to a `binary` test module containing one `#[test]` per float
/// binary-op test (`test_dot`, `test_powf`, `test_powi`).
///
/// The expansion site must have `TestRuntime` and `FloatType` in scope; each
/// generated test builds a client for the default device and forwards it to
/// the function of the same name in `cubecl_core::runtime_tests::binary`.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_binary {
    () => {
        mod binary {
            use super::*;

            // Declares one `#[test]` that forwards to the runtime test of the
            // same name.
            macro_rules! float_case {
                ($name:ident) => {
                    #[test]
                    fn $name() {
                        let runtime_client = TestRuntime::client(&Default::default());
                        cubecl_core::runtime_tests::binary::$name::<TestRuntime, FloatType>(
                            runtime_client,
                        );
                    }
                };
            }

            float_case!(test_dot);
            float_case!(test_powf);
            float_case!(test_powi);
        }
    };
}
/// Expands to a `binary_untyped` test module containing the binary-op tests
/// that do not depend on a float type (currently `test_mulhi`).
///
/// The expansion site must have `TestRuntime` in scope; each generated test
/// builds a client for the default device and forwards it to the function of
/// the same name in `cubecl_core::runtime_tests::binary`.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_binary_untyped {
    () => {
        mod binary_untyped {
            use super::*;

            // One-argument form: the runtime test is generic over the runtime
            // only. Two-argument form: an element type is passed through too.
            macro_rules! untyped_case {
                ($name:ident) => {
                    #[test]
                    fn $name() {
                        let runtime_client = TestRuntime::client(&Default::default());
                        cubecl_core::runtime_tests::binary::$name::<TestRuntime>(runtime_client);
                    }
                };
                ($name:ident, $elem:ty) => {
                    #[test]
                    fn $name() {
                        let runtime_client = TestRuntime::client(&Default::default());
                        cubecl_core::runtime_tests::binary::$name::<TestRuntime, $elem>(runtime_client);
                    }
                };
            }

            untyped_case!(test_mulhi);
        }
    };
}