use cubecl::prelude::*;
/// Element-wise addition: `out[i] = a[i] + b[i]`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing, so the launch grid may be rounded up past the data length.
/// `a` and `b` are read at the same index with no length check of their own, so the
/// caller must make them at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_add<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = a[ABSOLUTE_POS] + b[ABSOLUTE_POS];
}
}
/// Element-wise subtraction: `out[i] = a[i] - b[i]`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `a` and `b` are read at the same index with no length
/// check of their own, so the caller must make them at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_sub<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = a[ABSOLUTE_POS] - b[ABSOLUTE_POS];
}
}
/// Element-wise multiplication: `out[i] = a[i] * b[i]`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `a` and `b` are read at the same index with no length
/// check of their own, so the caller must make them at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_mul<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = a[ABSOLUTE_POS] * b[ABSOLUTE_POS];
}
}
/// Element-wise division: `out[i] = a[i] / b[i]`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `a` and `b` are read at the same index with no length
/// check of their own, so the caller must make them at least as long as `out`.
/// Division by zero is not special-cased; the result is whatever the backend's float
/// division produces.
#[cube(launch_unchecked)]
pub fn kernel_div<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = a[ABSOLUTE_POS] / b[ABSOLUTE_POS];
}
}
/// Element-wise ReLU: `out[i] = max(x[i], 0)`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_relu<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
let v = x[ABSOLUTE_POS];
// NOTE(review): NaN handling follows the backend's `max` — confirm whether NaN
// inputs should propagate rather than clamp to 0.
out[ABSOLUTE_POS] = F::max(v, F::new(0.0));
}
}
/// Element-wise negation: `out[i] = -x[i]`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_neg<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    // Guard against the tail units of the last (partially filled) cube.
    if ABSOLUTE_POS < out.len() {
        let v = x[ABSOLUTE_POS];
        // True negation flips the sign bit. The previous `0.0 - v` mapped +0.0 to
        // +0.0 instead of -0.0, which changes results downstream (e.g. `1 / -0.0`).
        out[ABSOLUTE_POS] = -v;
    }
}
/// Element-wise absolute value: `out[i] = |x[i]|`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_abs<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = F::abs(x[ABSOLUTE_POS]);
}
}
/// Element-wise exponential: `out[i] = exp(x[i])`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_exp<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = F::exp(x[ABSOLUTE_POS]);
}
}
/// Element-wise natural logarithm: `out[i] = ln(x[i])`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`. Inputs outside the domain (<= 0) are
/// not checked here; the result is whatever the backend's `ln` returns (presumably
/// NaN / -inf — confirm per backend).
#[cube(launch_unchecked)]
pub fn kernel_ln<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = F::ln(x[ABSOLUTE_POS]);
}
}
/// Element-wise square root: `out[i] = sqrt(x[i])`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`. Negative inputs are not checked here;
/// the result is whatever the backend's `sqrt` returns (presumably NaN — confirm).
#[cube(launch_unchecked)]
pub fn kernel_sqrt<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = F::sqrt(x[ABSOLUTE_POS]);
}
}
/// Element-wise sine: `out[i] = sin(x[i])`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_sin<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = F::sin(x[ABSOLUTE_POS]);
}
}
/// Element-wise cosine: `out[i] = cos(x[i])`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_cos<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = F::cos(x[ABSOLUTE_POS]);
}
}
/// Element-wise hyperbolic tangent: `out[i] = tanh(x[i])`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_tanh<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
out[ABSOLUTE_POS] = F::tanh(x[ABSOLUTE_POS]);
}
}
/// Element-wise logistic sigmoid: `out[i] = 1 / (1 + exp(-x[i]))`.
///
/// Each unit handles the single output position `ABSOLUTE_POS`; units at or beyond
/// `out.len()` do nothing. `x` is read at the same index with no length check of its
/// own, so it must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_sigmoid<F: Float>(x: &Array<F>, out: &mut Array<F>) {
// Guard against the tail units of the last (partially filled) cube.
if ABSOLUTE_POS < out.len() {
// `0 - x` yields +0.0 (not -0.0) for x = +0.0; harmless here since exp(±0) = 1.
let neg_x = F::new(0.0) - x[ABSOLUTE_POS];
// For very negative x, exp(-x) may overflow to +inf and 1 / inf gives 0, which is
// the correct limit, so no explicit clamping is done.
out[ABSOLUTE_POS] = F::new(1.0) / (F::new(1.0) + F::exp(neg_x));
}
}
/// Naive dense matrix multiply: `out = a · b`.
///
/// From the indexing below, `a` is read as a row-major `m x k` matrix, `b` as a
/// row-major `k x n` matrix, and `out` is written as a row-major `m x n` matrix.
/// Each unit computes one output element, selected by `ABSOLUTE_POS`. It does so with
/// a sequential dot product over `k`, without tiling or shared-memory reuse. Units at
/// or beyond `m * n` do nothing.
///
/// Launched unchecked, so the caller must guarantee `a` holds at least `m * k`
/// elements, `b` at least `k * n`, and `out` at least `m * n`.
#[cube(launch_unchecked)]
pub fn kernel_matmul_naive<F: Float>(
a: &Array<F>,
b: &Array<F>,
out: &mut Array<F>,
m: u32,
k: u32,
n: u32,
) {
let m_u = m as usize;
let k_u = k as usize;
let n_u = n as usize;
let total = m_u * n_u;
// Guard against the tail units of the last (partially filled) cube; this also keeps
// the division/modulo by `n_u` below from running when `n == 0`.
if ABSOLUTE_POS < total {
// Decompose the flat output index into (row, col) of the row-major m x n result.
let row = ABSOLUTE_POS / n_u;
let col = ABSOLUTE_POS % n_u;
let mut acc = F::new(0.0);
// Dot product of row `row` of `a` with column `col` of `b`; k == 0 leaves acc at 0.
for i in 0..k_u {
acc += a[row * k_u + i] * b[i * n_u + col];
}
out[ABSOLUTE_POS] = acc;
}
}
/// Build a 1-D launch configuration that covers `n` elements with one unit each.
///
/// Cubes hold 256 units. The cube count is `ceil(n / 256)` and never drops below one,
/// so a launch still happens when `n == 0`. The last cube may contain units past `n`,
/// which the kernels must skip with their own bounds guard.
fn elementwise_launch_dims(n: u32) -> (CubeCount, CubeDim) {
    const UNITS_PER_CUBE: u32 = 256;

    let cube_total = match n.div_ceil(UNITS_PER_CUBE) {
        0 => 1,
        cubes => cubes,
    };
    let cube_dim = CubeDim::new_1d(UNITS_PER_CUBE);
    let cube_count = CubeCount::Static(cube_total, 1, 1);

    (cube_count, cube_dim)
}
/// Upload `x`, launch a unary element-wise kernel over it, and read the result back.
///
/// `launcher` receives:
/// - the client;
/// - a 1-D launch configuration covering `x.len()` elements;
/// - array arguments for the input and for a freshly allocated output of the same length.
///
/// The output buffer is then read back with `client.read_one`. An empty `x` returns an
/// empty vector without touching the device.
///
/// # Panics
///
/// Panics if `x.len()` does not fit in a `u32`, because the launch grid is sized in `u32`.
fn run_unary<R, L>(client: &ComputeClient<R>, x: &[f32], launcher: L) -> Vec<f32>
where
    R: Runtime,
    L: FnOnce(&ComputeClient<R>, CubeCount, CubeDim, ArrayArg<R>, ArrayArg<R>),
{
    let n = x.len();
    // Nothing to compute: skip the allocation / dispatch / readback round-trip.
    if n == 0 {
        return Vec::new();
    }
    // A plain `as u32` would silently truncate and launch a grid that misses elements.
    let grid_len = u32::try_from(n).expect("run_unary: input length exceeds u32::MAX");
    let size_bytes = n * std::mem::size_of::<f32>();
    let x_handle = client.create_from_slice(f32::as_bytes(x));
    let out_handle = client.empty(size_bytes);
    let (count, dim) = elementwise_launch_dims(grid_len);
    // SAFETY: `x_handle` was created from exactly `n` f32 values and `out_handle` was
    // allocated with `n * size_of::<f32>()` bytes, so both hold `n` f32 elements
    // (vectorization factor 1).
    let in_arg = unsafe { ArrayArg::from_raw_parts::<f32>(&x_handle, n, 1) };
    let out_arg = unsafe { ArrayArg::from_raw_parts::<f32>(&out_handle, n, 1) };
    launcher(client, count, dim, in_arg, out_arg);
    let bytes = client.read_one(out_handle);
    f32::from_bytes(&bytes)[..n].to_vec()
}
/// Upload `a` and `b`, launch a binary element-wise kernel, and read the result back.
///
/// `launcher` receives:
/// - the client;
/// - a 1-D launch configuration covering `a.len()` elements;
/// - array arguments for `a`, for `b`, and for a freshly allocated output of the same length.
///
/// The output buffer is then read back with `client.read_one`. Empty inputs return an
/// empty vector without touching the device.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length, or if their length does not fit in a `u32`.
fn run_binary<R, L>(client: &ComputeClient<R>, a: &[f32], b: &[f32], launcher: L) -> Vec<f32>
where
    R: Runtime,
    L: FnOnce(&ComputeClient<R>, CubeCount, CubeDim, ArrayArg<R>, ArrayArg<R>, ArrayArg<R>),
{
    let n = a.len();
    // Hard assert, not `debug_assert`: `b_arg` below claims `n` elements and the kernels
    // are launched unchecked. A shorter `b` would otherwise be read out of bounds on the
    // device in release builds.
    assert_eq!(n, b.len(), "run_binary: operand lengths differ");
    // Nothing to compute: skip the allocation / dispatch / readback round-trip.
    if n == 0 {
        return Vec::new();
    }
    // A plain `as u32` would silently truncate and launch a grid that misses elements.
    let grid_len = u32::try_from(n).expect("run_binary: operand length exceeds u32::MAX");
    let size_bytes = n * std::mem::size_of::<f32>();
    let a_handle = client.create_from_slice(f32::as_bytes(a));
    let b_handle = client.create_from_slice(f32::as_bytes(b));
    let out_handle = client.empty(size_bytes);
    let (count, dim) = elementwise_launch_dims(grid_len);
    // SAFETY: `a_handle` and `b_handle` were created from slices of exactly `n` f32
    // values (lengths checked above), and `out_handle` was allocated with
    // `n * size_of::<f32>()` bytes. Every handle therefore holds `n` f32 elements
    // (vectorization factor 1).
    let a_arg = unsafe { ArrayArg::from_raw_parts::<f32>(&a_handle, n, 1) };
    let b_arg = unsafe { ArrayArg::from_raw_parts::<f32>(&b_handle, n, 1) };
    let out_arg = unsafe { ArrayArg::from_raw_parts::<f32>(&out_handle, n, 1) };
    launcher(client, count, dim, a_arg, b_arg, out_arg);
    let bytes = client.read_one(out_handle);
    f32::from_bytes(&bytes)[..n].to_vec()
}
/// Generate `pub fn $run_fn(client, x) -> Vec<f32>`, a host-side runner that pushes `x`
/// through the unary cube kernel `$kernel` (instantiated for `f32`) via `run_unary`.
/// The generated function panics with a descriptive message if the launch fails.
macro_rules! define_unary_runner {
    ($run_fn:ident, $kernel:ident) => {
        #[doc = concat!("Upload `x`, run `", stringify!($kernel), "`, read back the result.")]
        pub fn $run_fn<R: Runtime>(client: &ComputeClient<R>, x: &[f32]) -> Vec<f32> {
            run_unary::<R, _>(client, x, |cl, cube_count, cube_dim, src, dst| {
                // SAFETY: `run_unary` builds `src`/`dst` over buffers of `x.len()` f32
                // elements; the kernel is expected to guard its own index range.
                let launched = unsafe {
                    $kernel::launch_unchecked::<f32, R>(cl, cube_count, cube_dim, src, dst)
                };
                launched.expect(concat!("cubecl ", stringify!($kernel), " launch failed"));
            })
        }
    };
}
/// Generate `pub fn $run_fn(client, a, b) -> Vec<f32>`, a host-side runner that pushes
/// `a` and `b` through the binary cube kernel `$kernel` (instantiated for `f32`) via
/// `run_binary`. The generated function panics with a descriptive message if the launch
/// fails.
macro_rules! define_binary_runner {
    ($run_fn:ident, $kernel:ident) => {
        #[doc = concat!("Upload `a` and `b`, run `", stringify!($kernel), "`, read back the result.")]
        pub fn $run_fn<R: Runtime>(
            client: &ComputeClient<R>,
            a: &[f32],
            b: &[f32],
        ) -> Vec<f32> {
            run_binary::<R, _>(client, a, b, |cl, cube_count, cube_dim, lhs, rhs, dst| {
                // SAFETY: `run_binary` builds `lhs`/`rhs`/`dst` over buffers sized from
                // the input length; the kernel is expected to guard its own index range.
                let launched = unsafe {
                    $kernel::launch_unchecked::<f32, R>(cl, cube_count, cube_dim, lhs, rhs, dst)
                };
                launched.expect(concat!("cubecl ", stringify!($kernel), " launch failed"));
            })
        }
    };
}
// Unary element-wise runners: one input slice in, an output of the same length back.
define_unary_runner!(run_relu, kernel_relu);
define_unary_runner!(run_neg, kernel_neg);
define_unary_runner!(run_abs, kernel_abs);

// Transcendental unary runners.
define_unary_runner!(run_exp, kernel_exp);
define_unary_runner!(run_ln, kernel_ln);
define_unary_runner!(run_sqrt, kernel_sqrt);
define_unary_runner!(run_sin, kernel_sin);
define_unary_runner!(run_cos, kernel_cos);
define_unary_runner!(run_tanh, kernel_tanh);
define_unary_runner!(run_sigmoid, kernel_sigmoid);

// Binary element-wise runners: two inputs of matching length.
define_binary_runner!(run_add, kernel_add);
define_binary_runner!(run_sub, kernel_sub);
define_binary_runner!(run_mul, kernel_mul);
define_binary_runner!(run_div, kernel_div);
/// Multiply `a` (row-major `m x k`) by `b` (row-major `k x n`) on the device using
/// the naive matmul kernel, returning the row-major `m x n` product.
///
/// An empty product (`m == 0` or `n == 0`) returns an empty vector without touching
/// the device.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `a.len() != m * k`;
/// - `b.len() != k * n`;
/// - any of `m`, `k`, `n`, or `m * n` does not fit in a `u32`.
pub fn run_matmul<R: Runtime>(
    client: &ComputeClient<R>,
    a: &[f32],
    b: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    // Hard asserts, not `debug_assert`: the kernel is launched unchecked and indexes
    // `a` up to `m * k` and `b` up to `k * n`. Mismatched shapes would otherwise be read
    // out of bounds on the device in release builds.
    assert_eq!(a.len(), m * k, "run_matmul: a.len() must equal m * k");
    assert_eq!(b.len(), k * n, "run_matmul: b.len() must equal k * n");
    let out_len = m * n;
    // Nothing to compute: skip the allocation / dispatch / readback round-trip.
    if out_len == 0 {
        return Vec::new();
    }
    // A plain `as u32` would silently truncate the dimensions passed to the kernel.
    let to_u32 = |value: usize, what: &str| -> u32 {
        u32::try_from(value)
            .unwrap_or_else(|_| panic!("run_matmul: {what} ({value}) exceeds u32::MAX"))
    };
    let (m_u32, k_u32, n_u32) = (to_u32(m, "m"), to_u32(k, "k"), to_u32(n, "n"));
    let grid_len = to_u32(out_len, "m * n");
    let size_bytes = out_len * std::mem::size_of::<f32>();
    let a_handle = client.create_from_slice(f32::as_bytes(a));
    let b_handle = client.create_from_slice(f32::as_bytes(b));
    let out_handle = client.empty(size_bytes);
    let (count, dim) = elementwise_launch_dims(grid_len);
    // SAFETY: the three handles hold exactly `a.len() == m * k`, `b.len() == k * n` and
    // `out_len == m * n` f32 values (shapes asserted above). Those are the largest
    // extents the kernel indexes into.
    unsafe {
        kernel_matmul_naive::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts::<f32>(&a_handle, a.len(), 1),
            ArrayArg::from_raw_parts::<f32>(&b_handle, b.len(), 1),
            ArrayArg::from_raw_parts::<f32>(&out_handle, out_len, 1),
            ScalarArg::new(m_u32),
            ScalarArg::new(k_u32),
            ScalarArg::new(n_u32),
        )
        .expect("cubecl matmul kernel launch failed");
    }
    let bytes = client.read_one(out_handle);
    f32::from_bytes(&bytes)[..out_len].to_vec()
}