use cubecl::prelude::*;
use cubecl::wgpu::WgpuRuntime;
use cubecl::server::Handle;
use burn::backend::wgpu::CubeTensor;
use burn_backend::{DType, Shape};
use burn_cubecl::kernel::into_contiguous;
/// Byte width of a single element of `dtype`.
///
/// `QFloat` is treated as 4 bytes here (its packed storage unit);
/// `Bool` is stored as one byte.
fn elem_size(dtype: DType) -> usize {
match dtype {
DType::I8 | DType::U8 | DType::Bool => 1,
DType::F16 | DType::BF16 | DType::I16 | DType::U16 => 2,
DType::F32 | DType::Flex32 | DType::I32 | DType::U32 | DType::QFloat(_) => 4,
DType::F64 | DType::I64 | DType::U64 => 8,
}
}
/// Row-major (C-contiguous) strides for a tensor with dimensions `dims`:
/// the last axis has stride 1 and each earlier stride is the product of
/// all later dimension sizes. Returns an empty vec for an empty `dims`.
fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
// Walk the dims back-to-front, accumulating the running product,
// then flip the result into front-to-back order.
let mut strides: Vec<usize> = dims
.iter()
.rev()
.scan(1usize, |acc, &d| {
let stride = *acc;
*acc *= d;
Some(stride)
})
.collect();
strides.reverse();
strides
}
/// Allocates an uninitialized, contiguous tensor of `shape`, inheriting the
/// client, device, and dtype from `src`. The buffer contents are undefined
/// until a kernel writes them.
fn empty_cube(src: &CubeTensor<WgpuRuntime>, shape: Shape) -> CubeTensor<WgpuRuntime> {
let strides = contiguous_strides(&shape.dims);
let num_elems: usize = shape.dims.iter().product();
let handle: Handle = src.client.empty(num_elems * elem_size(src.dtype));
CubeTensor {
client: src.client.clone(),
device: src.device.clone(),
handle,
strides,
shape,
dtype: src.dtype,
qparams: None,
}
}
/// ScaleNorm kernel: out[row, :] = x[row, :] * g[0] * scale * rsqrt(sum(x[row, :]^2)).
///
/// One cube per row (CUBE_POS_X); each unit handles every PLANE_DIM-th column,
/// and `plane_sum` combines the per-lane partial sums of squares.
///
/// NOTE(review): assumes `d` is an exact multiple of PLANE_DIM — a trailing
/// remainder of columns would be neither summed nor written. Confirm callers
/// guarantee this.
/// NOTE(review): `plane_sum` reduces within a single plane (subgroup) only;
/// the reduction is the full row sum only if the cube contains exactly one
/// plane. The launcher uses a fixed cube_dim.x of 64 — verify the device's
/// plane size is 64, otherwise lanes overlap or the sum is partial.
#[cube(launch)]
fn scalenorm_kernel<F: Float>(
x: &Tensor<F>,
g: &Tensor<F>, out: &mut Tensor<F>,
d: u32, scale: f32, ) {
// row: which row this cube normalizes; lane: this unit's column offset.
let row = CUBE_POS_X; let lane = UNIT_POS_X;
let n_per_lane = d / PLANE_DIM;
let base = row * d + lane;
// Per-lane partial sum of squares over this lane's strided columns.
let mut sq = F::new(0.0);
for i in 0u32..n_per_lane {
let v = x[(base + i * PLANE_DIM) as usize];
sq += v * v;
}
// Combine partial sums across the plane into the full row sum of squares.
let row_sq = plane_sum(sq);
// g[0] is a single learned gain; 1e-12 floors the sum to avoid division
// by zero on an all-zero row; powf(.., -0.5) is 1/sqrt.
let factor = g[0usize]
* F::cast_from(scale)
* F::powf(F::max(row_sq, F::new(1e-12)), F::new(-0.5));
// Second pass: scale every element of the row by the shared factor.
for i in 0u32..n_per_lane {
let idx = (base + i * PLANE_DIM) as usize;
out[idx] = x[idx] * factor;
}
}
/// Launches the ScaleNorm kernel over the trailing dimension of `x`.
///
/// `x` is normalized row-wise over its last dim; `g` supplies a single
/// learned gain (only element 0 is read by the kernel); `scale` is a host-side
/// scalar factor. Returns a new contiguous tensor with `x`'s shape and dtype.
///
/// # Panics
/// Panics for dtypes other than F32/Flex32/F16, and if a kernel launch fails.
///
/// NOTE(review): cube_dim.x is hard-coded to 64 while the kernel strides by
/// PLANE_DIM and reduces with `plane_sum`; this is only correct when the
/// device plane (subgroup) size is 64 — confirm for the target adapter.
/// NOTE(review): assumes the trailing dim `d` is divisible by the plane size;
/// leftover columns would be left unwritten in `out` (uninitialized memory).
pub fn launch_scalenorm(
x: CubeTensor<WgpuRuntime>,
g: CubeTensor<WgpuRuntime>,
scale: f32,
) -> CubeTensor<WgpuRuntime> {
// Kernel indexes with flat row-major arithmetic, so force both inputs
// contiguous first.
let x = into_contiguous(x);
let g = into_contiguous(g);
let out = empty_cube(&x, x.shape.clone());
// d: trailing (normalized) dim; rows: product of all leading dims.
let d = *x.shape.dims.last().unwrap() as u32;
let rows = x.shape.dims[..x.shape.num_dims() - 1].iter().product::<usize>() as u32;
// One cube per row, 64 units per cube (see plane-size note above).
let cube_dim = CubeDim { x: 64, y: 1, z: 1 };
let cube_count = CubeCount::Static(rows, 1, 1);
match x.dtype {
// Flex32 shares f32 storage here, so the f32 kernel is reused —
// presumably safe; confirm against the backend's Flex32 layout.
DType::F32 | DType::Flex32 => unsafe {
scalenorm_kernel::launch::<f32, WgpuRuntime>(
&x.client, cube_count, cube_dim,
TensorArg::from_raw_parts::<f32>(&x.handle, &x.strides, &x.shape.dims, 1),
TensorArg::from_raw_parts::<f32>(&g.handle, &g.strides, &g.shape.dims, 1),
TensorArg::from_raw_parts::<f32>(&out.handle, &out.strides, &out.shape.dims, 1),
ScalarArg::new(d),
ScalarArg::new(scale),
).expect("scalenorm f32 launch");
},
DType::F16 => unsafe {
scalenorm_kernel::launch::<half::f16, WgpuRuntime>(
&x.client, cube_count, cube_dim,
TensorArg::from_raw_parts::<half::f16>(&x.handle, &x.strides, &x.shape.dims, 1),
TensorArg::from_raw_parts::<half::f16>(&g.handle, &g.strides, &g.shape.dims, 1),
TensorArg::from_raw_parts::<half::f16>(&out.handle, &out.strides, &out.shape.dims, 1),
ScalarArg::new(d),
ScalarArg::new(scale),
).expect("scalenorm f16 launch");
},
dt => panic!("scalenorm: unsupported dtype {dt:?}"),
}
out
}
/// Rotary position embedding (rotate-half form), one unit per element of `x`.
///
/// For flattened index `pos`, with column `d_idx = pos % d` and sequence
/// position `n_idx = (pos / d) % n_seq`:
///   - first half  (d_idx <  half): out = x[pos]*cos - x[pos+half]*sin
///   - second half (d_idx >= half): out = x[pos-half]*sin + x[pos]*cos
///
/// NOTE(review): cos/sin are indexed as `n_idx * half + (d_idx mod half)`,
/// i.e. assumed contiguous with trailing dim `half` and one row per sequence
/// position, shared across any leading batch/head dims of `x` — confirm the
/// caller builds them that way.
/// NOTE(review): assumes `d` is even (`half == d / 2` exactly); an odd `d`
/// makes the second-half `cs` index run past the intended cos/sin row.
#[cube(launch)]
fn rope_kernel<F: Float>(
x: &Tensor<F>,
cos: &Tensor<F>, sin: &Tensor<F>, out: &mut Tensor<F>,
d: usize, half: usize, n_seq: usize, total: usize, ) {
let pos = ABSOLUTE_POS;
// Tail guard: the grid is rounded up to the cube size, so trailing
// units past `total` must exit without touching memory.
if pos >= total { terminate!(); }
let d_idx = pos % d;
let n_idx = (pos / d) % n_seq;
let result: F = if d_idx < half {
// cs: offset into the (n_seq, half) cos/sin tables for this position.
let cs = n_idx * half + d_idx;
x[pos] * cos[cs] - x[pos + half] * sin[cs]
} else {
let cs = n_idx * half + d_idx - half;
x[pos - half] * sin[cs] + x[pos] * cos[cs]
};
out[pos] = result;
}
/// Launches the RoPE kernel over `x`, returning a new tensor of `x`'s shape.
///
/// `x` is treated as (..., n_seq, d): the last dim is rotated pairwise, the
/// second-to-last is the sequence axis. `cos`/`sin` are the precomputed
/// rotation tables (see the layout note on `rope_kernel`).
///
/// # Panics
/// Panics for dtypes other than F32/Flex32/F16, if `x` has fewer than two
/// dims (the `dims[rank - 2]` access), and if a kernel launch fails.
///
/// NOTE(review): assumes the trailing dim is even; see `rope_kernel`.
pub fn launch_rope(
x: CubeTensor<WgpuRuntime>,
cos: CubeTensor<WgpuRuntime>,
sin: CubeTensor<WgpuRuntime>,
) -> CubeTensor<WgpuRuntime> {
// The kernel uses flat row-major indexing, so all inputs must be contiguous.
let x = into_contiguous(x);
let cos = into_contiguous(cos);
let sin = into_contiguous(sin);
let out = empty_cube(&x, x.shape.clone());
let dims = &x.shape.dims;
let rank = dims.len();
let d = dims[rank - 1];
let half = d / 2;
let n_seq = dims[rank - 2];
let total = dims.iter().product::<usize>();
// One unit per element; round the grid up and let the kernel's tail
// guard skip the excess.
let cube_dim_x: u32 = 256;
let cube_count_x = ((total as u32) + cube_dim_x - 1) / cube_dim_x;
let cube_dim = CubeDim { x: cube_dim_x, y: 1, z: 1 };
let cube_count = CubeCount::Static(cube_count_x, 1, 1);
match x.dtype {
// Flex32 shares f32 storage here, so the f32 kernel is reused —
// presumably safe; confirm against the backend's Flex32 layout.
DType::F32 | DType::Flex32 => unsafe {
rope_kernel::launch::<f32, WgpuRuntime>(
&x.client, cube_count, cube_dim,
TensorArg::from_raw_parts::<f32>(&x.handle, &x.strides, &x.shape.dims, 1),
TensorArg::from_raw_parts::<f32>(&cos.handle, &cos.strides, &cos.shape.dims, 1),
TensorArg::from_raw_parts::<f32>(&sin.handle, &sin.strides, &sin.shape.dims, 1),
TensorArg::from_raw_parts::<f32>(&out.handle, &out.strides, &out.shape.dims, 1),
ScalarArg::new(d),
ScalarArg::new(half),
ScalarArg::new(n_seq),
ScalarArg::new(total),
).expect("rope f32 launch");
},
DType::F16 => unsafe {
rope_kernel::launch::<half::f16, WgpuRuntime>(
&x.client, cube_count, cube_dim,
TensorArg::from_raw_parts::<half::f16>(&x.handle, &x.strides, &x.shape.dims, 1),
TensorArg::from_raw_parts::<half::f16>(&cos.handle, &cos.strides, &cos.shape.dims, 1),
TensorArg::from_raw_parts::<half::f16>(&sin.handle, &sin.strides, &sin.shape.dims, 1),
TensorArg::from_raw_parts::<half::f16>(&out.handle, &out.strides, &out.shape.dims, 1),
ScalarArg::new(d),
ScalarArg::new(half),
ScalarArg::new(n_seq),
ScalarArg::new(total),
).expect("rope f16 launch");
},
dt => panic!("rope: unsupported dtype {dt:?}"),
}
out
}