use crate::{self as cubecl};
use cubecl::prelude::*;
use cubecl_ir::StorageType;
use cubecl_runtime::TypeUsage;
/// Atomically adds 5 to `output[0]`.
///
/// Only the unit whose `UNIT_POS` is 0 performs the addition; no other
/// element of `output` is touched.
#[cube(launch)]
pub fn kernel_atomic_add<I: Numeric>(output: &mut Array<Atomic<I>>) {
    // Restrict the write to a single unit so the expected result does not
    // depend on the cube dimension chosen by the caller.
    if UNIT_POS == 0 {
        Atomic::add(&output[0], I::from_int(5));
    }
}
/// Reports whether `client` lists `feat` among the usages it supports for
/// atomics over `F`'s element type.
fn supports_feature<R: Runtime, F: Numeric>(
    client: &ComputeClient<R::Server>,
    feat: TypeUsage,
) -> bool {
    let elem = F::as_type_native_unchecked().elem_type();
    let atomic_storage = StorageType::Atomic(elem);
    let usages = client.properties().type_usage(atomic_storage);
    usages.contains(feat)
}
/// Launches [`kernel_atomic_add`] on the buffer `[12, 1]` and asserts that
/// the first element becomes `12 + 5 = 17`.
///
/// When the runtime does not report `TypeUsage::AtomicAdd` for atomics of
/// `F`, a notice is printed and the test returns without launching anything.
pub fn test_kernel_atomic_add<R: Runtime, F: Numeric + CubeElement>(
    client: ComputeClient<R::Server>,
) {
    let supported = supports_feature::<R, F>(&client, TypeUsage::AtomicAdd);
    if !supported {
        let atomic_ty = Atomic::<F>::as_type_native_unchecked();
        println!("{} Add not supported - skipped", atomic_ty);
        return;
    }

    let initial = [F::from_int(12), F::from_int(1)];
    let buffer = client.create(F::as_bytes(&initial));

    // SAFETY: `buffer` was just created from exactly two `F` values; the
    // trailing `2, 1` arguments (assumed: element count and vectorization
    // factor) describe that same allocation.
    let output = unsafe { ArrayArg::from_raw_parts::<F>(&buffer, 2, 1) };
    kernel_atomic_add::launch::<F, R>(
        &client,
        CubeCount::Static(1, 1, 1),
        CubeDim::default(),
        output,
    );

    let raw = client.read_one(buffer);
    let values = F::from_bytes(&raw);
    // 12 + 5
    assert_eq!(values[0], F::from_int(17));
}
/// Atomically replaces `output[0]` with the minimum of its current value
/// and 5.
///
/// Only the unit whose `UNIT_POS` is 0 performs the operation; no other
/// element of `output` is touched.
#[cube(launch)]
pub fn kernel_atomic_min<I: Numeric>(output: &mut Array<Atomic<I>>) {
    // Single writer, so the outcome does not depend on the cube dimension.
    if UNIT_POS == 0 {
        Atomic::min(&output[0], I::from_int(5));
    }
}
/// Launches [`kernel_atomic_min`] on the buffer `[12, 1]` and asserts that
/// the first element becomes `min(12, 5) = 5`.
///
/// When the runtime does not report `TypeUsage::AtomicMinMax` for atomics of
/// `F`, a notice is printed and the test returns without launching anything.
pub fn test_kernel_atomic_min<R: Runtime, F: Numeric + CubeElement>(
    client: ComputeClient<R::Server>,
) {
    let supported = supports_feature::<R, F>(&client, TypeUsage::AtomicMinMax);
    if !supported {
        let atomic_ty = Atomic::<F>::as_type_native_unchecked();
        println!("{} Min not supported - skipped", atomic_ty);
        return;
    }

    let initial = [F::from_int(12), F::from_int(1)];
    let buffer = client.create(F::as_bytes(&initial));

    // SAFETY: `buffer` was just created from exactly two `F` values; the
    // trailing `2, 1` arguments (assumed: element count and vectorization
    // factor) describe that same allocation.
    let output = unsafe { ArrayArg::from_raw_parts::<F>(&buffer, 2, 1) };
    kernel_atomic_min::launch::<F, R>(
        &client,
        CubeCount::Static(1, 1, 1),
        CubeDim::default(),
        output,
    );

    let raw = client.read_one(buffer);
    let values = F::from_bytes(&raw);
    // min(12, 5)
    assert_eq!(values[0], F::from_int(5));
}
/// Atomically replaces `output[0]` with the maximum of its current value
/// and 5.
///
/// Only the unit whose `UNIT_POS` is 0 performs the operation; no other
/// element of `output` is touched.
#[cube(launch)]
pub fn kernel_atomic_max<I: Numeric>(output: &mut Array<Atomic<I>>) {
    // Single writer, so the outcome does not depend on the cube dimension.
    if UNIT_POS == 0 {
        Atomic::max(&output[0], I::from_int(5));
    }
}
/// Launches [`kernel_atomic_max`] on the buffer `[12, 1]` and asserts that
/// the first element stays at `max(12, 5) = 12`.
///
/// When the runtime does not report `TypeUsage::AtomicMinMax` for atomics of
/// `F`, a notice is printed and the test returns without launching anything.
pub fn test_kernel_atomic_max<R: Runtime, F: Numeric + CubeElement>(
    client: ComputeClient<R::Server>,
) {
    let supported = supports_feature::<R, F>(&client, TypeUsage::AtomicMinMax);
    if !supported {
        let atomic_ty = Atomic::<F>::as_type_native_unchecked();
        println!("{} Max not supported - skipped", atomic_ty);
        return;
    }

    let initial = [F::from_int(12), F::from_int(1)];
    let buffer = client.create(F::as_bytes(&initial));

    // SAFETY: `buffer` was just created from exactly two `F` values; the
    // trailing `2, 1` arguments (assumed: element count and vectorization
    // factor) describe that same allocation.
    let output = unsafe { ArrayArg::from_raw_parts::<F>(&buffer, 2, 1) };
    kernel_atomic_max::launch::<F, R>(
        &client,
        CubeCount::Static(1, 1, 1),
        CubeDim::default(),
        output,
    );

    let raw = client.read_one(buffer);
    let values = F::from_bytes(&raw);
    // max(12, 5)
    assert_eq!(values[0], F::from_int(12));
}
/// Generates `#[test]` functions that run the atomic add/min/max runtime
/// tests with `IntType` on `TestRuntime`.
///
/// Emits `use super::*;`; `TestRuntime` and `IntType` must be in scope at
/// the expansion site (for example through that glob import).
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_atomic_int {
    () => {
        use super::*;

        #[test]
        fn test_atomic_add_int() {
            cubecl_core::runtime_tests::atomic::test_kernel_atomic_add::<TestRuntime, IntType>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_atomic_min_int() {
            cubecl_core::runtime_tests::atomic::test_kernel_atomic_min::<TestRuntime, IntType>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_atomic_max_int() {
            cubecl_core::runtime_tests::atomic::test_kernel_atomic_max::<TestRuntime, IntType>(
                TestRuntime::client(&Default::default()),
            );
        }
    };
}
/// Generates `#[test]` functions that run the atomic add/min/max runtime
/// tests with `FloatType` on `TestRuntime`.
///
/// Emits `use super::*;`; `TestRuntime` and `FloatType` must be in scope at
/// the expansion site (for example through that glob import).
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_atomic_float {
    () => {
        use super::*;

        #[test]
        fn test_atomic_add_float() {
            cubecl_core::runtime_tests::atomic::test_kernel_atomic_add::<TestRuntime, FloatType>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_atomic_min_float() {
            cubecl_core::runtime_tests::atomic::test_kernel_atomic_min::<TestRuntime, FloatType>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_atomic_max_float() {
            cubecl_core::runtime_tests::atomic::test_kernel_atomic_max::<TestRuntime, FloatType>(
                TestRuntime::client(&Default::default()),
            );
        }
    };
}