use cubecl::prelude::*;
/// Element-wise addition: `out[i] = a[i] + b[i]`.
///
/// Each invocation handles the element at its `ABSOLUTE_POS`; invocations
/// whose position is not below `out.len()` do nothing. Only `out.len()` is
/// checked, so `a` and `b` must hold at least as many elements as `out`.
#[cube(launch_unchecked)]
pub fn kernel_add<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let lhs = a[idx];
        let rhs = b[idx];
        out[idx] = lhs + rhs;
    }
}
/// Element-wise subtraction: `out[i] = a[i] - b[i]`.
///
/// Each invocation handles the element at its `ABSOLUTE_POS`; invocations
/// whose position is not below `out.len()` do nothing. Only `out.len()` is
/// checked, so `a` and `b` must hold at least as many elements as `out`.
#[cube(launch_unchecked)]
pub fn kernel_sub<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let minuend = a[idx];
        let subtrahend = b[idx];
        out[idx] = minuend - subtrahend;
    }
}
/// Element-wise multiplication: `out[i] = a[i] * b[i]`.
///
/// Each invocation handles the element at its `ABSOLUTE_POS`; invocations
/// whose position is not below `out.len()` do nothing. Only `out.len()` is
/// checked, so `a` and `b` must hold at least as many elements as `out`.
#[cube(launch_unchecked)]
pub fn kernel_mul<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let lhs = a[idx];
        let rhs = b[idx];
        out[idx] = lhs * rhs;
    }
}
/// Element-wise division: `out[i] = a[i] / b[i]`.
///
/// Each invocation handles the element at its `ABSOLUTE_POS`; invocations
/// whose position is not below `out.len()` do nothing. Only `out.len()` is
/// checked, so `a` and `b` must hold at least as many elements as `out`.
/// No special handling is applied to zero divisors.
#[cube(launch_unchecked)]
pub fn kernel_div<F: Float>(a: &Array<F>, b: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let numerator = a[idx];
        let denominator = b[idx];
        out[idx] = numerator / denominator;
    }
}
/// Element-wise rectified linear unit: `out[i] = max(x[i], 0)`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_relu<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let value = x[idx];
        let floor = F::new(0.0);
        out[idx] = F::max(value, floor);
    }
}
/// Element-wise negation, computed as `out[i] = 0 - x[i]`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_neg<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let value = x[idx];
        // Negate by subtracting from zero.
        out[idx] = F::new(0.0) - value;
    }
}
/// Element-wise absolute value: `out[i] = |x[i]|`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_abs<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let value = x[idx];
        out[idx] = F::abs(value);
    }
}
/// Element-wise exponential: `out[i] = exp(x[i])`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_exp<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let value = x[idx];
        out[idx] = F::exp(value);
    }
}
/// Element-wise natural logarithm: `out[i] = ln(x[i])`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
/// Inputs are passed to `F::ln` as-is; no domain check is performed here.
#[cube(launch_unchecked)]
pub fn kernel_ln<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let value = x[idx];
        out[idx] = F::ln(value);
    }
}
/// Element-wise square root: `out[i] = sqrt(x[i])`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
/// Inputs are passed to `F::sqrt` as-is; no domain check is performed here.
#[cube(launch_unchecked)]
pub fn kernel_sqrt<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let value = x[idx];
        out[idx] = F::sqrt(value);
    }
}
/// Element-wise sine: `out[i] = sin(x[i])`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_sin<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let angle = x[idx];
        out[idx] = F::sin(angle);
    }
}
/// Element-wise cosine: `out[i] = cos(x[i])`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_cos<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let angle = x[idx];
        out[idx] = F::cos(angle);
    }
}
/// Element-wise hyperbolic tangent: `out[i] = tanh(x[i])`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_tanh<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let value = x[idx];
        out[idx] = F::tanh(value);
    }
}
/// Element-wise logistic sigmoid: `out[i] = 1 / (1 + exp(-x[i]))`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_sigmoid<F: Float>(x: &Array<F>, out: &mut Array<F>) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        // `-x` is formed by subtracting from zero.
        let negated = F::new(0.0) - x[idx];
        let denominator = F::new(1.0) + F::exp(negated);
        out[idx] = F::new(1.0) / denominator;
    }
}
/// Evaluates the degree-`n` Chebyshev polynomial of the first kind, `T_n`,
/// at every element of `x`.
///
/// Uses the seeds `T_0 = 1`, `T_1 = x` and the three-term recurrence
/// `T_k = 2x * T_{k-1} - T_{k-2}`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_chebyshev_t<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let twice_arg = F::new(2.0) * arg;
        let degree = n as usize;
        // `older` / `newer` hold T_{k-2} / T_{k-1}; they start as T_0 / T_1.
        let mut older = F::new(1.0);
        let mut newer = arg;
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // One recurrence step per degree from 2 up to `degree`.
            for _ in 2..=degree {
                let stepped = twice_arg * newer - older;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Evaluates the degree-`n` Chebyshev polynomial of the second kind, `U_n`,
/// at every element of `x`.
///
/// Uses the seeds `U_0 = 1`, `U_1 = 2x` and the three-term recurrence
/// `U_k = 2x * U_{k-1} - U_{k-2}`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_chebyshev_u<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let twice_arg = F::new(2.0) * arg;
        let degree = n as usize;
        // `older` / `newer` hold U_{k-2} / U_{k-1}; they start as U_0 / U_1.
        let mut older = F::new(1.0);
        let mut newer = twice_arg;
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // One recurrence step per degree from 2 up to `degree`.
            for _ in 2..=degree {
                let stepped = twice_arg * newer - older;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Evaluates the degree-`n` Chebyshev polynomial of the third kind, `V_n`,
/// at every element of `x`.
///
/// Uses the seeds `V_0 = 1`, `V_1 = 2x - 1` and the three-term recurrence
/// `V_k = 2x * V_{k-1} - V_{k-2}`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_chebyshev_v<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let twice_arg = F::new(2.0) * arg;
        let degree = n as usize;
        // `older` / `newer` hold V_{k-2} / V_{k-1}; they start as V_0 / V_1.
        let mut older = F::new(1.0);
        let mut newer = twice_arg - F::new(1.0);
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // One recurrence step per degree from 2 up to `degree`.
            for _ in 2..=degree {
                let stepped = twice_arg * newer - older;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Evaluates the degree-`n` Chebyshev polynomial of the fourth kind, `W_n`,
/// at every element of `x`.
///
/// Uses the seeds `W_0 = 1`, `W_1 = 2x + 1` and the three-term recurrence
/// `W_k = 2x * W_{k-1} - W_{k-2}`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_chebyshev_w<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let twice_arg = F::new(2.0) * arg;
        let degree = n as usize;
        // `older` / `newer` hold W_{k-2} / W_{k-1}; they start as W_0 / W_1.
        let mut older = F::new(1.0);
        let mut newer = twice_arg + F::new(1.0);
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // One recurrence step per degree from 2 up to `degree`.
            for _ in 2..=degree {
                let stepped = twice_arg * newer - older;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Evaluates the degree-`n` physicists' Hermite polynomial `H_n` at every
/// element of `x`.
///
/// Uses the seeds `H_0 = 1`, `H_1 = 2x` and the recurrence
/// `H_{k+1} = 2x * H_k - 2k * H_{k-1}`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_hermite_h<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let twice_arg = F::new(2.0) * arg;
        let degree = n as usize;
        // `older` / `newer` hold H_{k-1} / H_k; they start as H_0 / H_1.
        let mut older = F::new(1.0);
        let mut newer = twice_arg;
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // Step `order` produces H_{order + 1}, so the last step yields H_degree.
            for order in 1..degree {
                let order_f = F::cast_from(order as u32);
                let stepped = twice_arg * newer - F::new(2.0) * order_f * older;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Evaluates the degree-`n` probabilists' Hermite polynomial `He_n` at every
/// element of `x`.
///
/// Uses the seeds `He_0 = 1`, `He_1 = x` and the recurrence
/// `He_{k+1} = x * He_k - k * He_{k-1}`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_hermite_he<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let degree = n as usize;
        // `older` / `newer` hold He_{k-1} / He_k; they start as He_0 / He_1.
        let mut older = F::new(1.0);
        let mut newer = arg;
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // Step `order` produces He_{order + 1}, so the last step yields He_degree.
            for order in 1..degree {
                let order_f = F::cast_from(order as u32);
                let stepped = arg * newer - order_f * older;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Evaluates the degree-`n` Laguerre polynomial `L_n` at every element of `x`.
///
/// Uses the seeds `L_0 = 1`, `L_1 = 1 - x` and the recurrence
/// `L_{k+1} = ((2k + 1 - x) * L_k - k * L_{k-1}) / (k + 1)`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_laguerre_l<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let degree = n as usize;
        // `older` / `newer` hold L_{k-1} / L_k; they start as L_0 / L_1.
        let mut older = F::new(1.0);
        let mut newer = F::new(1.0) - arg;
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // Step `order` produces L_{order + 1}, so the last step yields L_degree.
            for order in 1..degree {
                let order_f = F::cast_from(order as u32);
                let odd_coeff = F::new(2.0) * order_f + F::new(1.0);
                let divisor = order_f + F::new(1.0);
                let stepped = ((odd_coeff - arg) * newer - order_f * older) / divisor;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Evaluates the degree-`n` Legendre polynomial `P_n` at every element of `x`.
///
/// Uses the seeds `P_0 = 1`, `P_1 = x` and the recurrence
/// `P_{k+1} = ((2k + 1) * x * P_k - k * P_{k-1}) / (k + 1)`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `out.len()` do nothing.
/// Only `out.len()` is checked, so `x` must be at least as long as `out`.
#[cube(launch_unchecked)]
pub fn kernel_legendre_p<F: Float>(x: &Array<F>, out: &mut Array<F>, n: u32) {
    let idx = ABSOLUTE_POS;
    if idx < out.len() {
        let arg = x[idx];
        let degree = n as usize;
        // `older` / `newer` hold P_{k-1} / P_k; they start as P_0 / P_1.
        let mut older = F::new(1.0);
        let mut newer = arg;
        if degree == 0 {
            out[idx] = older;
        } else if degree == 1 {
            out[idx] = newer;
        } else {
            // Step `order` produces P_{order + 1}, so the last step yields P_degree.
            for order in 1..degree {
                let order_f = F::cast_from(order as u32);
                let odd_coeff = F::new(2.0) * order_f + F::new(1.0);
                let divisor = order_f + F::new(1.0);
                let stepped = (odd_coeff * arg * newer - order_f * older) / divisor;
                older = newer;
                newer = stepped;
            }
            out[idx] = newer;
        }
    }
}
/// Naive matrix multiplication with one output element per invocation.
///
/// `a` is indexed as `a[row * k + i]`, `b` as `b[i * n + col]` and `out` as
/// `out[row * n + col]`, i.e. all three are treated as row-major matrices of
/// shapes `m x k`, `k x n` and `m x n`.
///
/// Invocations whose `ABSOLUTE_POS` is not below `m * n` do nothing. The
/// guard uses `m * n` rather than any array length, so the caller must make
/// sure all three arrays are large enough for the given dimensions.
#[cube(launch_unchecked)]
pub fn kernel_matmul_naive<F: Float>(
    a: &Array<F>,
    b: &Array<F>,
    out: &mut Array<F>,
    m: u32,
    k: u32,
    n: u32,
) {
    let rows = m as usize;
    let inner = k as usize;
    let cols = n as usize;
    let idx = ABSOLUTE_POS;
    if idx < rows * cols {
        // Recover the (row, col) output coordinate from the flat index.
        let row = idx / cols;
        let col = idx % cols;
        let mut sum = F::new(0.0);
        // Dot product of row `row` of `a` with column `col` of `b`.
        for step in 0..inner {
            sum += a[row * inner + step] * b[step * cols + col];
        }
        out[idx] = sum;
    }
}
/// Chooses a one-dimensional launch configuration covering `n` elements.
///
/// Returns a cube dimension of 256 units and a static cube count of
/// `ceil(n / 256)`, clamped to at least one cube so that `n == 0` still
/// yields a non-zero cube count.
fn elementwise_launch_dims(n: u32) -> (CubeCount, CubeDim) {
    const UNITS_PER_CUBE: u32 = 256;
    // Round up so a partially filled trailing cube is still launched.
    let cubes_needed = std::cmp::max(n.div_ceil(UNITS_PER_CUBE), 1);
    let cube_count = CubeCount::Static(cubes_needed, 1, 1);
    let cube_dim = CubeDim::new_1d(UNITS_PER_CUBE);
    (cube_count, cube_dim)
}
fn run_unary<R, L>(client: &ComputeClient<R>, x: &[f32], launcher: L) -> Vec<f32>
where
R: Runtime,
L: FnOnce(&ComputeClient<R>, CubeCount, CubeDim, ArrayArg<R>, ArrayArg<R>),
{
let n = x.len();
let size_bytes = std::mem::size_of_val(x);
let x_handle = client.create_from_slice(f32::as_bytes(x));
let out_handle = client.empty(size_bytes);
let (count, dim) = elementwise_launch_dims(n as u32);
let in_arg = unsafe { ArrayArg::from_raw_parts(x_handle, n) };
let out_arg = unsafe { ArrayArg::from_raw_parts(out_handle.clone(), n) };
launcher(client, count, dim, in_arg, out_arg);
let bytes = client.read_one(out_handle).expect("cubecl read_one failed");
f32::from_bytes(&bytes)[..n].to_vec()
}
fn run_binary<R, L>(client: &ComputeClient<R>, a: &[f32], b: &[f32], launcher: L) -> Vec<f32>
where
R: Runtime,
L: FnOnce(&ComputeClient<R>, CubeCount, CubeDim, ArrayArg<R>, ArrayArg<R>, ArrayArg<R>),
{
let n = a.len();
debug_assert_eq!(n, b.len());
let size_bytes = std::mem::size_of_val(a);
let a_handle = client.create_from_slice(f32::as_bytes(a));
let b_handle = client.create_from_slice(f32::as_bytes(b));
let out_handle = client.empty(size_bytes);
let (count, dim) = elementwise_launch_dims(n as u32);
let a_arg = unsafe { ArrayArg::from_raw_parts(a_handle, n) };
let b_arg = unsafe { ArrayArg::from_raw_parts(b_handle, n) };
let out_arg = unsafe { ArrayArg::from_raw_parts(out_handle.clone(), n) };
launcher(client, count, dim, a_arg, b_arg, out_arg);
let bytes = client.read_one(out_handle).expect("cubecl read_one failed");
f32::from_bytes(&bytes)[..n].to_vec()
}
/// Generates a public `f32` host wrapper named `$run_fn` around the
/// single-input kernel `$kernel`, delegating upload, launch configuration
/// and read-back to `run_unary`.
macro_rules! define_unary_runner {
    ($run_fn:ident, $kernel:ident) => {
        #[doc = concat!("Upload `x`, run `", stringify!($kernel), "`, read back the result.")]
        pub fn $run_fn<R: Runtime>(client: &ComputeClient<R>, x: &[f32]) -> Vec<f32> {
            run_unary(client, x, |device, cubes, units, src, dst| {
                // SAFETY: `run_unary` builds `src` and `dst` with the same
                // element count, which is the only bound the kernel checks.
                unsafe {
                    $kernel::launch_unchecked::<f32, R>(device, cubes, units, src, dst);
                }
            })
        }
    };
}
/// Generates a public `f32` host wrapper named `$run_fn` around the
/// two-input kernel `$kernel`, delegating upload, launch configuration and
/// read-back to `run_binary`.
macro_rules! define_binary_runner {
    ($run_fn:ident, $kernel:ident) => {
        #[doc = concat!("Upload `a` and `b`, run `", stringify!($kernel), "`, read back the result.")]
        pub fn $run_fn<R: Runtime>(client: &ComputeClient<R>, a: &[f32], b: &[f32]) -> Vec<f32> {
            run_binary(client, a, b, |device, cubes, units, lhs, rhs, dst| {
                // SAFETY: `run_binary` builds all three array arguments with
                // the same element count; the kernel only checks the output length.
                unsafe {
                    $kernel::launch_unchecked::<f32, R>(device, cubes, units, lhs, rhs, dst);
                }
            })
        }
    };
}
// Host-side `f32` wrappers for the two-input element-wise kernels.
define_binary_runner!(run_add, kernel_add);
define_binary_runner!(run_sub, kernel_sub);
define_binary_runner!(run_mul, kernel_mul);
define_binary_runner!(run_div, kernel_div);
// Host-side `f32` wrappers for the single-input element-wise kernels.
define_unary_runner!(run_relu, kernel_relu);
define_unary_runner!(run_neg, kernel_neg);
define_unary_runner!(run_abs, kernel_abs);
define_unary_runner!(run_exp, kernel_exp);
define_unary_runner!(run_ln, kernel_ln);
define_unary_runner!(run_sqrt, kernel_sqrt);
define_unary_runner!(run_sin, kernel_sin);
define_unary_runner!(run_cos, kernel_cos);
define_unary_runner!(run_tanh, kernel_tanh);
define_unary_runner!(run_sigmoid, kernel_sigmoid);
fn run_unary_with_n<R, L>(client: &ComputeClient<R>, x: &[f32], n: u32, launcher: L) -> Vec<f32>
where
R: Runtime,
L: FnOnce(&ComputeClient<R>, CubeCount, CubeDim, ArrayArg<R>, ArrayArg<R>, u32),
{
let count_elems = x.len();
let size_bytes = std::mem::size_of_val(x);
let x_handle = client.create_from_slice(f32::as_bytes(x));
let out_handle = client.empty(size_bytes);
let (count, dim) = elementwise_launch_dims(count_elems as u32);
let in_arg = unsafe { ArrayArg::from_raw_parts(x_handle, count_elems) };
let out_arg = unsafe { ArrayArg::from_raw_parts(out_handle.clone(), count_elems) };
launcher(client, count, dim, in_arg, out_arg, n);
let bytes = client.read_one(out_handle).expect("cubecl read_one failed");
f32::from_bytes(&bytes)[..count_elems].to_vec()
}
/// Generates a public `f32` host wrapper named `$run_fn` around the
/// single-input kernel `$kernel` that also takes a `u32` degree, delegating
/// upload, launch configuration and read-back to `run_unary_with_n`.
macro_rules! define_unary_with_n_runner {
    ($run_fn:ident, $kernel:ident) => {
        #[doc = concat!("Upload `x`, run `", stringify!($kernel), "` with degree `n`, read back the result.")]
        pub fn $run_fn<R: Runtime>(client: &ComputeClient<R>, x: &[f32], n: u32) -> Vec<f32> {
            run_unary_with_n(client, x, n, |device, cubes, units, src, dst, degree| {
                // SAFETY: `run_unary_with_n` builds `src` and `dst` with the
                // same element count, which is the only bound the kernel checks.
                unsafe {
                    $kernel::launch_unchecked::<f32, R>(device, cubes, units, src, dst, degree);
                }
            })
        }
    };
}
// Host-side `f32` wrappers for the polynomial kernels; each takes the
// polynomial degree as its `n` argument.
define_unary_with_n_runner!(run_chebyshev_t, kernel_chebyshev_t);
define_unary_with_n_runner!(run_chebyshev_u, kernel_chebyshev_u);
define_unary_with_n_runner!(run_chebyshev_v, kernel_chebyshev_v);
define_unary_with_n_runner!(run_chebyshev_w, kernel_chebyshev_w);
define_unary_with_n_runner!(run_hermite_h, kernel_hermite_h);
define_unary_with_n_runner!(run_hermite_he, kernel_hermite_he);
define_unary_with_n_runner!(run_laguerre_l, kernel_laguerre_l);
define_unary_with_n_runner!(run_legendre_p, kernel_legendre_p);
pub fn run_matmul<R: Runtime>(
client: &ComputeClient<R>,
a: &[f32],
b: &[f32],
m: usize,
k: usize,
n: usize,
) -> Vec<f32> {
debug_assert_eq!(a.len(), m * k);
debug_assert_eq!(b.len(), k * n);
let out_len = m * n;
let size_bytes = out_len * std::mem::size_of::<f32>();
let a_handle = client.create_from_slice(f32::as_bytes(a));
let b_handle = client.create_from_slice(f32::as_bytes(b));
let out_handle = client.empty(size_bytes);
let (count, dim) = elementwise_launch_dims(out_len as u32);
unsafe {
kernel_matmul_naive::launch_unchecked::<f32, R>(
client,
count,
dim,
ArrayArg::from_raw_parts(a_handle, a.len()),
ArrayArg::from_raw_parts(b_handle, b.len()),
ArrayArg::from_raw_parts(out_handle.clone(), out_len),
m as u32,
k as u32,
n as u32,
);
}
let bytes = client.read_one(out_handle).expect("cubecl read_one failed");
f32::from_bytes(&bytes)[..out_len].to_vec()
}