use crate::{self as cubecl};
use alloc::vec;
use cubecl::prelude::*;
#[cube(launch_unchecked)]
pub fn kernel_saturating_add<I: Int, N: Size>(
lhs: &Array<Vector<I, N>>,
rhs: &Array<Vector<I, N>>,
output: &mut Array<Vector<I, N>>,
) {
if (UNIT_POS as usize) < output.len() {
output[UNIT_POS as usize] =
Vector::<I, N>::saturating_add(lhs[UNIT_POS as usize], rhs[UNIT_POS as usize]);
}
}
#[cube(launch_unchecked)]
pub fn kernel_saturating_sub<I: Int, N: Size>(
lhs: &Array<Vector<I, N>>,
rhs: &Array<Vector<I, N>>,
output: &mut Array<Vector<I, N>>,
) {
if (UNIT_POS as usize) < output.len() {
output[UNIT_POS as usize] =
Vector::<I, N>::saturating_sub(lhs[UNIT_POS as usize], rhs[UNIT_POS as usize]);
}
}
#[allow(clippy::needless_range_loop)]
pub fn test_saturating_add_unsigned<R: Runtime, I: Int + CubeElement>(
client: ComputeClient<R>,
vector_size: VectorSize,
) {
if I::cube_type() == u64::cube_type() {
return;
}
let lhs = vec![
I::new(2),
I::max_value(),
I::max_value() - I::new(10),
I::new(20),
];
let rhs = vec![I::new(10), I::new(1), I::new(9), I::max_value()];
let out = vec![
I::new(12),
I::max_value(),
I::max_value() - I::new(1),
I::max_value(),
];
let lhs_handle = client.create_from_slice(I::as_bytes(&lhs));
let rhs_handle = client.create_from_slice(I::as_bytes(&rhs));
let out_handle = client.empty(4 * size_of::<I>());
unsafe {
kernel_saturating_add::launch_unchecked::<I, R>(
&client,
CubeCount::new_single(),
CubeDim::new_1d(out.len() as u32),
vector_size,
ArrayArg::from_raw_parts(lhs_handle, 4),
ArrayArg::from_raw_parts(rhs_handle, 4),
ArrayArg::from_raw_parts(out_handle.clone(), 4),
)
}
let actual = client.read_one_unchecked(out_handle);
let actual = I::from_bytes(&actual);
assert_eq!(actual, out);
}
#[allow(clippy::needless_range_loop)]
pub fn test_saturating_sub_unsigned<R: Runtime, I: Int + CubeElement>(
client: ComputeClient<R>,
vector_size: VectorSize,
) {
if I::cube_type() == u64::cube_type() {
return;
}
let lhs = vec![
I::new(2),
I::new(4),
I::new(10),
I::max_value() - I::new(10),
];
let rhs = vec![I::new(1), I::new(6), I::new(8), I::max_value()];
let out = vec![I::new(1), I::new(0), I::new(2), I::new(0)];
let lhs_handle = client.create_from_slice(I::as_bytes(&lhs));
let rhs_handle = client.create_from_slice(I::as_bytes(&rhs));
let out_handle = client.empty(4 * size_of::<I>());
unsafe {
kernel_saturating_sub::launch_unchecked::<I, R>(
&client,
CubeCount::new_single(),
CubeDim::new_1d(out.len() as u32),
vector_size,
ArrayArg::from_raw_parts(lhs_handle, 4),
ArrayArg::from_raw_parts(rhs_handle, 4),
ArrayArg::from_raw_parts(out_handle.clone(), 4),
)
}
let actual = client.read_one_unchecked(out_handle);
let actual = I::from_bytes(&actual);
assert_eq!(actual, out);
}
#[allow(clippy::needless_range_loop)]
pub fn test_saturating_add_signed<R: Runtime, I: Int + CubeElement>(
client: ComputeClient<R>,
vector_size: VectorSize,
) {
let lhs = vec![
I::new(0),
I::new(0),
I::new(0),
I::new(5),
I::new(-5),
I::new(10),
I::new(-10),
I::new(50),
I::new(30),
I::new(10),
I::max_value(),
I::new(1),
I::min_value(),
I::new(-1),
I::max_value() - I::new(1),
I::min_value() + I::new(1),
];
let rhs = vec![
I::new(0),
I::new(5),
I::new(-5),
I::new(0),
I::new(0),
I::new(20),
I::new(-20),
I::new(-30),
I::new(-50),
I::new(-10),
I::new(1),
I::max_value(),
I::new(-1),
I::min_value(),
I::new(1),
I::new(-1),
];
let out = vec![
I::new(0),
I::new(5),
I::new(-5),
I::new(5),
I::new(-5),
I::new(30),
I::new(-30),
I::new(20),
I::new(-20),
I::new(0),
I::max_value(),
I::max_value(),
I::min_value(),
I::min_value(),
I::max_value(),
I::min_value(),
];
let lhs_handle = client.create_from_slice(I::as_bytes(&lhs));
let rhs_handle = client.create_from_slice(I::as_bytes(&rhs));
let out_handle = client.empty(16 * size_of::<I>());
unsafe {
kernel_saturating_add::launch_unchecked::<I, R>(
&client,
CubeCount::new_single(),
CubeDim::new_1d(out.len() as u32),
vector_size,
ArrayArg::from_raw_parts(lhs_handle, 16),
ArrayArg::from_raw_parts(rhs_handle, 16),
ArrayArg::from_raw_parts(out_handle.clone(), 16),
)
}
let actual = client.read_one_unchecked(out_handle);
let actual = I::from_bytes(&actual);
assert_eq!(actual, out);
}
#[allow(clippy::needless_range_loop)]
pub fn test_saturating_sub_signed<R: Runtime, I: Int + CubeElement>(
client: ComputeClient<R>,
vector_size: VectorSize,
) {
let lhs = vec![
I::new(0), I::new(0), I::new(0), I::new(10), I::new(-10), I::new(20), I::new(5), I::new(-5), I::new(-20), I::max_value(), I::max_value(), I::min_value(), I::min_value(), I::max_value() - I::new(1), I::min_value() + I::new(1), I::new(50), ];
let rhs = vec![
I::new(0),
I::new(5),
I::new(-5),
I::new(3),
I::new(-3),
I::new(15),
I::new(10),
I::new(-10),
I::new(-5),
I::new(-1),
I::new(1),
I::new(1),
I::new(-1),
I::new(-1),
I::new(1),
I::new(-30),
];
let out = vec![
I::new(0),
I::new(-5),
I::new(5),
I::new(7),
I::new(-7),
I::new(5),
I::new(-5),
I::new(5),
I::new(-15),
I::max_value(), I::max_value() - I::new(1),
I::min_value(), I::min_value() + I::new(1),
I::max_value(), I::min_value(), I::new(80), ];
let lhs_handle = client.create_from_slice(I::as_bytes(&lhs));
let rhs_handle = client.create_from_slice(I::as_bytes(&rhs));
let out_handle = client.empty(16 * size_of::<I>());
unsafe {
kernel_saturating_sub::launch_unchecked::<I, R>(
&client,
CubeCount::new_single(),
CubeDim::new_1d(out.len() as u32),
vector_size,
ArrayArg::from_raw_parts(lhs_handle, 16),
ArrayArg::from_raw_parts(rhs_handle, 16),
ArrayArg::from_raw_parts(out_handle.clone(), 16),
)
}
let actual = client.read_one_unchecked(out_handle);
let actual = I::from_bytes(&actual);
assert_eq!(actual, out);
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_saturating_uint {
() => {
use super::*;
#[$crate::runtime_tests::test_log::test]
fn test_saturating_add_unsigned() {
let client = TestRuntime::client(&Default::default());
let test = cubecl_core::runtime_tests::saturating::test_saturating_add_unsigned::<
TestRuntime,
UintType,
>;
test(client.clone(), 1);
test(client.clone(), 2);
test(client, 4);
}
#[$crate::runtime_tests::test_log::test]
fn test_saturating_sub_unsigned() {
let client = TestRuntime::client(&Default::default());
let test = cubecl_core::runtime_tests::saturating::test_saturating_sub_unsigned::<
TestRuntime,
UintType,
>;
test(client.clone(), 1);
test(client.clone(), 2);
test(client, 4);
}
};
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_saturating_int {
() => {
use super::*;
#[$crate::runtime_tests::test_log::test]
fn test_saturating_add_signed() {
let client = TestRuntime::client(&Default::default());
let test = cubecl_core::runtime_tests::saturating::test_saturating_add_signed::<
TestRuntime,
IntType,
>;
test(client.clone(), 1);
test(client.clone(), 2);
test(client, 4);
}
#[$crate::runtime_tests::test_log::test]
fn test_saturating_sub_signed() {
let client = TestRuntime::client(&Default::default());
let test = cubecl_core::runtime_tests::saturating::test_saturating_sub_signed::<
TestRuntime,
IntType,
>;
test(client.clone(), 1);
test(client.clone(), 2);
test(client, 4);
}
};
}