use crate::prelude::*;
/// Causal depthwise 1-D convolution followed by a SiLU activation.
///
/// Each invocation (`ABSOLUTE_POS`) produces one element of `out`.
/// The flat index is decomposed as `gid = row * seq_len + pos`, and the
/// channel used to pick the filter is `row % conv_dim`.
///
/// * `x`    - input, indexed as `row * seq_len + position`.
/// * `w`    - filter taps, indexed as `channel * k + tap`.
/// * `out`  - output, written at the same flat index as the input element.
/// * `dims` - `dims[0]` is `conv_dim`, `dims[1]` is `seq_len`.
/// * `k`    - compile-time number of filter taps.
///
/// NOTE(review): the kernel is declared `unchecked`; presumably no bounds
/// checks are generated, so callers must size `x`, `w` and `dims` correctly.
#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
pub fn gdn_conv1d<F: Float>(
x: &Array<F>,
w: &Array<F>,
out: &mut Array<F>,
dims: &Array<u32>, #[comptime] k: usize,
) {
let gid = ABSOLUTE_POS;
// Invocations past the end of `out` do nothing.
if gid < out.len() {
let conv_dim = dims[0] as usize;
let seq_len = dims[1] as usize;
// Position inside the sequence and index of the (batch, channel) row.
let pos = gid % seq_len;
let row = gid / seq_len; let ch = row % conv_dim;
let x_base = row * seq_len;
let w_base = ch * k;
let mut acc = F::new(0.0);
for i in 0..k {
// Tap `i` reads input position `pos + i - (k - 1)`; taps that would land
// before the start of the row are skipped (treated as zero).
if pos + i >= k - 1 {
let src = pos + i - (k - 1);
acc += x[x_base + src] * w[w_base + i];
}
}
// SiLU: acc * sigmoid(acc), written as acc / (1 + exp(-acc)).
out[x_base + pos] = acc / (F::new(1.0) + (-acc).exp());
}
}
/// Element-wise gating: turns the raw `b_in` / `a_in` values into `beta_out` and `g_out`.
///
/// For every flat index `idx` below `beta_out.len()`, with `head = idx % num_heads`:
///
/// * `beta_out[idx] = sigmoid(b_in[idx])`
/// * `g_out[idx]    = -exp(a_log[head]) * ln(1 + exp(a_in[idx] + dt_bias[head]))`
///
/// `dims[0]` holds `num_heads`. `a_log` and `dt_bias` are indexed per head.
///
/// NOTE(review): the kernel is declared `unchecked`; presumably no bounds
/// checks are generated, so callers must size every buffer correctly.
#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
pub fn gdn_gating<F: Float>(
b_in: &Array<F>,
a_in: &Array<F>,
a_log: &Array<F>,
dt_bias: &Array<F>,
beta_out: &mut Array<F>,
g_out: &mut Array<F>,
dims: &Array<u32>, ) {
let idx = ABSOLUTE_POS;
// Invocations past the end of `beta_out` do nothing.
if idx < beta_out.len() {
let num_heads = dims[0] as usize;
// The head index varies fastest in the flat layout.
let head = idx % num_heads;
// Logistic sigmoid of the raw beta value.
let beta = F::new(1.0) / (F::new(1.0) + (-b_in[idx]).exp());
// softplus(a + dt_bias) = ln(1 + exp(a + dt_bias)).
let softplus = (F::new(1.0) + (a_in[idx] + dt_bias[head]).exp()).ln();
// exp(a_log) is positive and softplus is non-negative, so g is never positive.
let g = -(a_log[head].exp()) * softplus;
beta_out[idx] = beta;
g_out[idx] = g;
}
}
/// Sequential recurrent scan with a decayed, delta-style state update
/// (presumably the Gated DeltaNet recurrence, going by the `gdn_` prefix).
///
/// One invocation handles one `(b, vi)` pair, where `b < dims[0]` and
/// `vi < v_dim`. It owns the `k_dim` state entries
/// `state[b * k_dim * v_dim + j * v_dim + vi]` and walks all `dims[1]` time
/// steps in order. Per step `t`:
///
/// 1. `s[j] *= exp(g[b, t])`
/// 2. `kv_mem = sum_j s[j] * k[b, t, j]`
/// 3. `delta = (v[b, t, vi] - kv_mem) * beta[b, t]`
/// 4. `s[j] += k[b, t, j] * delta`
/// 5. `out[b, t, vi] = sum_j s[j] * q[b, t, j]`
///
/// After the last step the updated entries are written back into `state`.
///
/// * `q`, `k`       - indexed as `b * seq * k_dim + t * k_dim + j`.
/// * `v`, `out`     - indexed as `b * seq * v_dim + t * v_dim + vi`.
/// * `g`, `beta`    - indexed as `b * seq + t`.
/// * `dims`         - `dims[0]` is the number of `b` slices, `dims[1]` is `seq`.
/// * `k_dim`, `v_dim` - compile-time sizes.
///
/// NOTE(review): the kernel is declared `unchecked`; presumably no bounds
/// checks are generated, so callers must size every buffer correctly.
#[kernel(targets(cuda, metal, vulkan, webgpu, cpu), unchecked)]
pub fn gdn_scan<F: Float>(
q: &Array<F>,
k: &Array<F>,
v: &Array<F>,
g: &Array<F>,
beta: &Array<F>,
state: &mut Array<F>,
out: &mut Array<F>,
dims: &Array<u32>, #[comptime] k_dim: usize,
#[comptime] v_dim: usize,
) {
let tid = ABSOLUTE_POS; let bh = dims[0] as usize;
let seq = dims[1] as usize;
// Invocations beyond the `bh * v_dim` work items do nothing.
if tid < bh * v_dim {
let b = tid / v_dim;
let vi = tid % v_dim;
// Base offsets of slice `b` in each buffer.
let qk_bh = b * seq * k_dim;
let v_bh = b * seq * v_dim;
let gb_bh = b * seq;
let state_base = b * k_dim * v_dim;
// Per-invocation working copy of this invocation's state entries.
let mut s = Array::<F>::new(k_dim);
for j in 0..k_dim {
s[j] = state[state_base + j * v_dim + vi];
}
for t in 0..seq {
let decay = g[gb_bh + t].exp();
let beta_t = beta[gb_bh + t];
let v_t = v[v_bh + t * v_dim + vi];
// Decay the state and, in the same pass, project it onto the key.
let mut kv_mem = F::new(0.0);
for j in 0..k_dim {
let sj = s[j] * decay;
s[j] = sj;
kv_mem += sj * k[qk_bh + t * k_dim + j];
}
// Error between the target value and what the state predicts, scaled by beta.
let delta = (v_t - kv_mem) * beta_t;
// Apply the correction, then read out with the query from the updated state.
let mut y_t = F::new(0.0);
for j in 0..k_dim {
let sj = s[j] + k[qk_bh + t * k_dim + j] * delta;
s[j] = sj;
y_t += sj * q[qk_bh + t * k_dim + j];
}
out[v_bh + t * v_dim + vi] = y_t;
}
// Persist the final state entries.
for j in 0..k_dim {
state[state_base + j * v_dim + vi] = s[j];
}
}
}
/// Host-side launcher for [`gdn_conv1d`] on `f32` data.
///
/// Uploads `x`, `w` and the `[conv_dim, seq_len]` dimension pair, launches one
/// invocation per output element (blocks of 64), and reads the result back.
///
/// * `x` - input with at least `batch * conv_dim * seq_len` elements.
/// * `w` - filter taps with at least `conv_dim * k` elements.
/// * `k` - number of filter taps (compile-time parameter of the kernel).
///
/// Returns the output buffer as read back from the device
/// (`batch * conv_dim * seq_len` elements were allocated for it).
///
/// # Panics
///
/// Panics if `x` or `w` is shorter than the range the kernel indexes, or if the
/// number of output elements does not fit in `u32`.
pub fn gdn_conv1d_run<R: Runtime>(
    client: &ComputeClient<R>,
    x: &[f32],
    w: &[f32],
    batch: usize,
    conv_dim: usize,
    seq_len: usize,
    k: usize,
) -> Vec<f32> {
    let total = batch * conv_dim * seq_len;

    // The launch below is unchecked, so the inputs are validated here. With zero
    // output elements no invocation touches any buffer, so nothing is required.
    if total > 0 {
        assert!(
            x.len() >= total,
            "gdn_conv1d_run: x has {} elements, kernel reads {total}",
            x.len()
        );
        assert!(
            w.len() >= conv_dim * k,
            "gdn_conv1d_run: w has {} elements, kernel reads {}",
            w.len(),
            conv_dim * k
        );
    }
    // Guards the `as u32` casts: when `total > 0`, `conv_dim` and `seq_len` are
    // both <= `total`, so they fit as well.
    assert!(
        total <= u32::MAX as usize,
        "gdn_conv1d_run: {total} output elements exceed the u32 range used for the launch"
    );

    let xh = client.create_from_slice(f32::as_bytes(x));
    let wh = client.create_from_slice(f32::as_bytes(w));
    let dh = client.create_from_slice(u32::as_bytes(&[conv_dim as u32, seq_len as u32]));
    let oh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; total]));

    let block = 64u32;
    let grid = (total as u32).div_ceil(block);

    // SAFETY: the kernel only indexes `x[..total]`, `w[..conv_dim * k]`,
    // `out[..total]` and `dims[..2]`; the asserts above and the allocations
    // just made guarantee those ranges exist.
    unsafe {
        gdn_conv1d::launch_unchecked::<f32, R>(
            client,
            Grid::Static(grid, 1, 1),
            Block::new_1d(block),
            ArrayArg::from_raw_parts(xh.clone(), x.len()),
            ArrayArg::from_raw_parts(wh.clone(), w.len()),
            ArrayArg::from_raw_parts(oh.clone(), total),
            ArrayArg::from_raw_parts(dh.clone(), 2),
            k,
        );
    }

    f32::from_bytes(&client.read_one_unchecked(oh)).to_vec()
}
/// Host-side launcher for [`gdn_gating`] on `f32` data.
///
/// Uploads the inputs, launches one invocation per element of `b` (blocks of
/// 64), and reads both outputs back.
///
/// * `b`, `a`          - per-element inputs; `a` must be at least as long as `b`.
/// * `a_log`, `dt_bias` - per-head parameters, indexed by `idx % num_heads`.
/// * `num_heads`       - modulus used to derive the head index.
///
/// Returns `(beta, g)`, each allocated with `b.len()` elements.
///
/// # Panics
///
/// When `b` is non-empty, panics if `num_heads` is zero, if `a` is shorter than
/// `b`, or if `a_log` / `dt_bias` do not cover every head index the kernel can
/// produce. Also panics if `b.len()` or `num_heads` does not fit in `u32`.
pub fn gdn_gating_run<R: Runtime>(
    client: &ComputeClient<R>,
    b: &[f32],
    a: &[f32],
    a_log: &[f32],
    dt_bias: &[f32],
    num_heads: usize,
) -> (Vec<f32>, Vec<f32>) {
    let total = b.len();

    // The launch below is unchecked, so the inputs are validated here. With an
    // empty `b` no invocation touches any buffer, so nothing is required.
    if total > 0 {
        assert!(num_heads > 0, "gdn_gating_run: num_heads must be non-zero");
        assert!(
            a.len() >= total,
            "gdn_gating_run: a has {} elements, kernel reads {total}",
            a.len()
        );
        // `idx % num_heads` with `idx < total` never reaches past `min(num_heads, total)`.
        let heads_touched = num_heads.min(total);
        assert!(
            a_log.len() >= heads_touched,
            "gdn_gating_run: a_log has {} elements, kernel reads {heads_touched}",
            a_log.len()
        );
        assert!(
            dt_bias.len() >= heads_touched,
            "gdn_gating_run: dt_bias has {} elements, kernel reads {heads_touched}",
            dt_bias.len()
        );
        // A truncated `num_heads` would silently change the head mapping.
        assert!(
            num_heads <= u32::MAX as usize,
            "gdn_gating_run: num_heads {num_heads} does not fit in u32"
        );
    }
    assert!(
        total <= u32::MAX as usize,
        "gdn_gating_run: {total} elements exceed the u32 range used for the launch"
    );

    let bh = client.create_from_slice(f32::as_bytes(b));
    let ah = client.create_from_slice(f32::as_bytes(a));
    let alh = client.create_from_slice(f32::as_bytes(a_log));
    let dth = client.create_from_slice(f32::as_bytes(dt_bias));
    let betah = client.create_from_slice(f32::as_bytes(&vec![0.0f32; total]));
    let gh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; total]));
    let dh = client.create_from_slice(u32::as_bytes(&[num_heads as u32]));

    let block = 64u32;
    let grid = (total as u32).div_ceil(block);

    // SAFETY: the kernel indexes `b`, `a`, `beta_out` and `g_out` below `total`,
    // `a_log` / `dt_bias` below `min(num_heads, total)`, and `dims[0]`; the
    // asserts above and the allocations just made guarantee those ranges exist.
    unsafe {
        gdn_gating::launch_unchecked::<f32, R>(
            client,
            Grid::Static(grid, 1, 1),
            Block::new_1d(block),
            ArrayArg::from_raw_parts(bh.clone(), b.len()),
            ArrayArg::from_raw_parts(ah.clone(), a.len()),
            ArrayArg::from_raw_parts(alh.clone(), a_log.len()),
            ArrayArg::from_raw_parts(dth.clone(), dt_bias.len()),
            ArrayArg::from_raw_parts(betah.clone(), total),
            ArrayArg::from_raw_parts(gh.clone(), total),
            ArrayArg::from_raw_parts(dh.clone(), 1),
        );
    }

    let beta = f32::from_bytes(&client.read_one_unchecked(betah)).to_vec();
    let g = f32::from_bytes(&client.read_one_unchecked(gh)).to_vec();
    (beta, g)
}
/// Host-side launcher for [`gdn_scan`] on `f32` data.
///
/// Uploads all inputs plus the initial `state`, launches one invocation per
/// `(slice, value-column)` pair (`bh * v_dim` work items, blocks of 64), and
/// reads back the output and the updated state.
///
/// * `q`, `k`    - at least `bh * seq * k_dim` elements each.
/// * `v`         - at least `bh * seq * v_dim` elements.
/// * `g`, `beta` - at least `bh * seq` elements each.
/// * `state`     - at least `bh * k_dim * v_dim` elements; not modified on the
///   host, the updated copy is returned instead.
///
/// Returns `(out, new_state)`: `out` was allocated with `bh * seq * v_dim`
/// elements, `new_state` is the device copy of `state` after the scan.
///
/// # Panics
///
/// When there is at least one work item, panics if any input is shorter than
/// the range the kernel indexes, or if `seq` does not fit in `u32`. Also panics
/// if `bh * v_dim` does not fit in `u32`.
// The argument list mirrors the kernel's buffers and dimensions one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn gdn_scan_run<R: Runtime>(
    client: &ComputeClient<R>,
    q: &[f32],
    k: &[f32],
    v: &[f32],
    g: &[f32],
    beta: &[f32],
    state: &[f32],
    bh: usize,
    seq: usize,
    k_dim: usize,
    v_dim: usize,
) -> (Vec<f32>, Vec<f32>) {
    let out_len = bh * seq * v_dim;
    let work_items = bh * v_dim;

    // The launch below is unchecked, so the inputs are validated here. With no
    // work items no invocation touches any buffer, so nothing is required.
    if work_items > 0 {
        let qk_len = bh * seq * k_dim;
        let gate_len = bh * seq;
        let state_len = bh * k_dim * v_dim;
        assert!(q.len() >= qk_len, "gdn_scan_run: q has {} elements, kernel reads {qk_len}", q.len());
        assert!(k.len() >= qk_len, "gdn_scan_run: k has {} elements, kernel reads {qk_len}", k.len());
        assert!(v.len() >= out_len, "gdn_scan_run: v has {} elements, kernel reads {out_len}", v.len());
        assert!(g.len() >= gate_len, "gdn_scan_run: g has {} elements, kernel reads {gate_len}", g.len());
        assert!(
            beta.len() >= gate_len,
            "gdn_scan_run: beta has {} elements, kernel reads {gate_len}",
            beta.len()
        );
        assert!(
            state.len() >= state_len,
            "gdn_scan_run: state has {} elements, kernel touches {state_len}",
            state.len()
        );
        // A truncated `seq` would silently shorten the scan.
        assert!(seq <= u32::MAX as usize, "gdn_scan_run: seq {seq} does not fit in u32");
    }
    // Guards the `as u32` casts: when `work_items > 0`, `bh <= work_items`.
    assert!(
        work_items <= u32::MAX as usize,
        "gdn_scan_run: {work_items} work items exceed the u32 range used for the launch"
    );

    let qh = client.create_from_slice(f32::as_bytes(q));
    let kh = client.create_from_slice(f32::as_bytes(k));
    let vh = client.create_from_slice(f32::as_bytes(v));
    let gh = client.create_from_slice(f32::as_bytes(g));
    let betah = client.create_from_slice(f32::as_bytes(beta));
    let sh = client.create_from_slice(f32::as_bytes(state));
    let oh = client.create_from_slice(f32::as_bytes(&vec![0.0f32; out_len]));
    let dh = client.create_from_slice(u32::as_bytes(&[bh as u32, seq as u32]));

    let block = 64u32;
    let grid = (work_items as u32).div_ceil(block);

    // SAFETY: the kernel indexes `q`/`k` below `bh * seq * k_dim`, `v`/`out`
    // below `bh * seq * v_dim`, `g`/`beta` below `bh * seq`, `state` below
    // `bh * k_dim * v_dim`, and `dims[..2]`; the asserts above and the
    // allocations just made guarantee those ranges exist.
    unsafe {
        gdn_scan::launch_unchecked::<f32, R>(
            client,
            Grid::Static(grid, 1, 1),
            Block::new_1d(block),
            ArrayArg::from_raw_parts(qh.clone(), q.len()),
            ArrayArg::from_raw_parts(kh.clone(), k.len()),
            ArrayArg::from_raw_parts(vh.clone(), v.len()),
            ArrayArg::from_raw_parts(gh.clone(), g.len()),
            ArrayArg::from_raw_parts(betah.clone(), beta.len()),
            ArrayArg::from_raw_parts(sh.clone(), state.len()),
            ArrayArg::from_raw_parts(oh.clone(), out_len),
            ArrayArg::from_raw_parts(dh.clone(), 2),
            k_dim,
            v_dim,
        );
    }

    let out = f32::from_bytes(&client.read_one_unchecked(oh)).to_vec();
    let new_state = f32::from_bytes(&client.read_one_unchecked(sh)).to_vec();
    (out, new_state)
}
/// Scalar CPU reference for the causal depthwise convolution + SiLU.
///
/// `x` and the returned vector are laid out as `(batch, conv_dim, seq_len)`
/// flattened row-major; `w` holds `k` taps per channel (`channel * k + tap`).
/// Output position `pos` combines input positions `pos - (k - 1) ..= pos`;
/// positions before the start of the row contribute nothing.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `x` or `w` is too short for the given
/// dimensions.
pub fn gdn_conv1d_ref(
    x: &[f32],
    w: &[f32],
    batch: usize,
    conv_dim: usize,
    seq_len: usize,
    k: usize,
) -> Vec<f32> {
    let rows = batch * conv_dim;
    let mut out = vec![0.0f32; rows * seq_len];

    // A row is one (batch, channel) pair; the channel selects the filter.
    for row in 0..rows {
        let row_start = row * seq_len;
        let taps_start = (row % conv_dim) * k;

        for pos in 0..seq_len {
            let mut pre_act = 0.0f32;
            for tap in 0..k {
                // Source index is `pos + tap - (k - 1)`; `checked_sub` yields
                // `None` exactly when that index would be negative.
                if let Some(src) = (pos + tap + 1).checked_sub(k) {
                    pre_act += x[row_start + src] * w[taps_start + tap];
                }
            }
            // SiLU: pre_act * sigmoid(pre_act).
            out[row_start + pos] = pre_act / (1.0 + (-pre_act).exp());
        }
    }

    out
}
/// Scalar CPU reference for the gating computation.
///
/// For each index `idx` of `b`, with `head = idx % num_heads`:
///
/// * `beta[idx] = sigmoid(b[idx])`
/// * `g[idx]    = coeff * ln(1 + exp(a[idx] + dt_bias[head]))`
///
/// where `coeff` is `a_log[head]` taken as-is when `neg_exp_a_log` is `true`
/// (presumably the caller has already stored `-exp(A_log)` there; confirm
/// against callers), and `-exp(a_log[head])` otherwise.
///
/// Returns `(beta, g)`, each with `b.len()` elements.
///
/// # Panics
///
/// Panics if `b` is non-empty and `num_heads` is zero, or if `a`, `a_log` or
/// `dt_bias` is too short for the indices used.
pub fn gdn_gating_ref(
    b: &[f32],
    a: &[f32],
    a_log: &[f32],
    dt_bias: &[f32],
    num_heads: usize,
    neg_exp_a_log: bool,
) -> (Vec<f32>, Vec<f32>) {
    let mut beta = Vec::with_capacity(b.len());
    let mut g = Vec::with_capacity(b.len());

    for (idx, &raw_beta) in b.iter().enumerate() {
        let head = idx % num_heads;

        // Logistic sigmoid.
        beta.push(1.0 / (1.0 + (-raw_beta).exp()));

        // softplus(a + dt_bias) = ln(1 + exp(a + dt_bias)).
        let shifted = a[idx] + dt_bias[head];
        let softplus = (1.0 + shifted.exp()).ln();

        let coeff = match neg_exp_a_log {
            true => a_log[head],
            false => -(a_log[head].exp()),
        };
        g.push(coeff * softplus);
    }

    (beta, g)
}
/// Scalar CPU reference for the recurrent scan.
///
/// For every slice `b < bh` and value column `vi < v_dim`, the `k_dim` state
/// entries `state[b * k_dim * v_dim + j * v_dim + vi]` are advanced through all
/// `seq` steps. Per step `t`:
///
/// 1. `s[j] *= exp(g[b, t])`
/// 2. `kv_mem = sum_j s[j] * k[b, t, j]`
/// 3. `delta = (v[b, t, vi] - kv_mem) * beta[b, t]`
/// 4. `s[j] += k[b, t, j] * delta`
/// 5. `out[b, t, vi] = sum_j s[j] * q[b, t, j]`
///
/// `state` is updated in place with the final entries; the returned vector has
/// `bh * seq * v_dim` elements laid out as `(b, t, vi)`.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if any slice is too short for the given
/// dimensions.
// The argument list mirrors the scan's buffers and dimensions one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn gdn_scan_ref(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    g: &[f32],
    beta: &[f32],
    state: &mut [f32],
    bh: usize,
    seq: usize,
    k_dim: usize,
    v_dim: usize,
) -> Vec<f32> {
    let mut out = vec![0.0f32; bh * seq * v_dim];
    // Working copy of one state column; fully overwritten before each use.
    let mut column = vec![0.0f32; k_dim];

    for slice in 0..bh {
        let qk_off = slice * seq * k_dim;
        let v_off = slice * seq * v_dim;
        let gate_off = slice * seq;
        let state_off = slice * k_dim * v_dim;

        for vi in 0..v_dim {
            for (j, slot) in column.iter_mut().enumerate() {
                *slot = state[state_off + j * v_dim + vi];
            }

            for t in 0..seq {
                let decay = g[gate_off + t].exp();
                let beta_t = beta[gate_off + t];
                let v_t = v[v_off + t * v_dim + vi];

                let row = qk_off + t * k_dim;
                let k_t = &k[row..row + k_dim];
                let q_t = &q[row..row + k_dim];

                // Decay the column and project it onto the key in one pass.
                let mut kv_mem = 0.0f32;
                for (s_j, &k_j) in column.iter_mut().zip(k_t) {
                    *s_j *= decay;
                    kv_mem += *s_j * k_j;
                }

                // Scaled error between the target and the state's prediction.
                let delta = (v_t - kv_mem) * beta_t;

                // Correct the column, then read it out with the query.
                let mut y_t = 0.0f32;
                for ((s_j, &k_j), &q_j) in column.iter_mut().zip(k_t).zip(q_t) {
                    *s_j += k_j * delta;
                    y_t += *s_j * q_j;
                }
                out[v_off + t * v_dim + vi] = y_t;
            }

            for (j, &s_j) in column.iter().enumerate() {
                state[state_off + j * v_dim + vi] = s_j;
            }
        }
    }

    out
}
/// Parity tests: each kernel is run on the CPU runtime and compared against
/// its scalar reference with a relative tolerance of `2e-3`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random values in `[-1.0, 1.0)` from an
    /// xorshift-style generator; `seed` must be non-zero or every value is
    /// `-1.0`.
    fn rnd(n: usize, seed: u64) -> Vec<f32> {
        let mut s = seed;
        (0..n)
            .map(|_| {
                s ^= s << 13;
                s ^= s >> 7;
                s ^= s << 17;
                (s % 2000) as f32 / 1000.0 - 1.0
            })
            .collect()
    }

    /// Largest relative error of `b` against the reference `a`, using
    /// `max(|a|, 1e-6)` as the denominator.
    ///
    /// A NaN anywhere makes the result NaN, so a `rel < tol` assertion fails
    /// instead of silently passing (`f32::max` would drop the NaN). `b` must
    /// provide a value for every reference element, otherwise `zip` would
    /// quietly skip the missing tail.
    fn max_rel(a: &[f32], b: &[f32]) -> f32 {
        assert!(
            b.len() >= a.len(),
            "max_rel: got {} values for {} reference values",
            b.len(),
            a.len()
        );
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs() / x.abs().max(1e-6))
            // Once `worst` is NaN every comparison is false, so it sticks.
            .fold(0.0f32, |worst, rel| if rel.is_nan() || rel > worst { rel } else { worst })
    }

    /// Compares `gdn_conv1d` on the CPU runtime with `gdn_conv1d_ref`
    /// (tolerance-based, despite the name).
    #[cfg(feature = "cpu")]
    #[test]
    fn gdn_conv1d_cpu_bit_exact() {
        use cubecl::cpu::{CpuDevice, CpuRuntime};
        let c = CpuRuntime::client(&CpuDevice::default());
        for &(batch, conv_dim, seq_len) in
            &[(1usize, 512usize, 48usize), (1, 64, 3), (2, 128, 32)]
        {
            let x = rnd(batch * conv_dim * seq_len, 0x1111 + conv_dim as u64);
            let w = rnd(conv_dim * 4, 0x2222 + seq_len as u64);
            let got = gdn_conv1d_run::<CpuRuntime>(&c, &x, &w, batch, conv_dim, seq_len, 4);
            let want = gdn_conv1d_ref(&x, &w, batch, conv_dim, seq_len, 4);
            let rel = max_rel(&want, &got);
            eprintln!("[gdn_conv1d CPU] b{batch} c{conv_dim} s{seq_len} k4 max_rel={rel:.2e}");
            assert!(rel < 2e-3, "gdn_conv1d b{batch} c{conv_dim} s{seq_len} max_rel {rel}");
        }
    }

    /// Compares `gdn_gating` on the CPU runtime with `gdn_gating_ref`
    /// (tolerance-based, despite the name).
    #[cfg(feature = "cpu")]
    #[test]
    fn gdn_gating_cpu_bit_exact() {
        use cubecl::cpu::{CpuDevice, CpuRuntime};
        let c = CpuRuntime::client(&CpuDevice::default());
        for &(num_heads, tokens) in &[(32usize, 40usize), (16, 1), (8, 128)] {
            let total = tokens * num_heads;
            let b = rnd(total, 0x3333 + num_heads as u64);
            let a = rnd(total, 0x4444 + tokens as u64);
            let a_log = rnd(num_heads, 0x5555);
            let dt_bias = rnd(num_heads, 0x6666);
            let (gb, gg) = gdn_gating_run::<CpuRuntime>(&c, &b, &a, &a_log, &dt_bias, num_heads);
            // The kernel always applies `-exp(a_log)`, hence `false` here.
            let (wb, wg) = gdn_gating_ref(&b, &a, &a_log, &dt_bias, num_heads, false);
            let (rb, rg) = (max_rel(&wb, &gb), max_rel(&wg, &gg));
            eprintln!("[gdn_gating CPU] h{num_heads} tok{tokens} beta_rel={rb:.2e} g_rel={rg:.2e}");
            assert!(rb < 2e-3 && rg < 2e-3, "gdn_gating h{num_heads} beta={rb} g={rg}");
        }
    }

    /// Compares `gdn_scan` on the CPU runtime with `gdn_scan_ref`, checking
    /// both the output and the final state (tolerance-based, despite the name).
    #[cfg(feature = "cpu")]
    #[test]
    fn gdn_scan_cpu_bit_exact() {
        use cubecl::cpu::{CpuDevice, CpuRuntime};
        let c = CpuRuntime::client(&CpuDevice::default());
        for &(bh, seq, k_dim, v_dim) in &[
            (4usize, 32usize, 64usize, 64usize),
            (2, 64, 128, 128),
            (8, 1, 32, 48),
            (3, 40, 48, 32),
        ] {
            // Queries are scaled by 1/sqrt(k_dim).
            let scale = 1.0f32 / (k_dim as f32).sqrt();
            let q: Vec<f32> = rnd(bh * seq * k_dim, 0x7001 + k_dim as u64)
                .iter()
                .map(|x| x * scale)
                .collect();
            let k = rnd(bh * seq * k_dim, 0x7002 + seq as u64);
            let v = rnd(bh * seq * v_dim, 0x7003 + v_dim as u64);
            // Log-decay in [-1, 0) and beta in [0, 1).
            let g: Vec<f32> = rnd(bh * seq, 0x7004).iter().map(|x| x * 0.5 - 0.5).collect();
            let beta: Vec<f32> = rnd(bh * seq, 0x7005).iter().map(|x| (x + 1.0) * 0.5).collect();
            let state0 = rnd(bh * k_dim * v_dim, 0x7006);
            let (got_out, got_state) = gdn_scan_run::<CpuRuntime>(
                &c, &q, &k, &v, &g, &beta, &state0, bh, seq, k_dim, v_dim,
            );
            let mut ref_state = state0.clone();
            let want_out =
                gdn_scan_ref(&q, &k, &v, &g, &beta, &mut ref_state, bh, seq, k_dim, v_dim);
            let ro = max_rel(&want_out, &got_out);
            let rs = max_rel(&ref_state, &got_state);
            eprintln!(
                "[gdn_scan CPU] bh{bh} s{seq} k{k_dim} v{v_dim} out_rel={ro:.2e} state_rel={rs:.2e}"
            );
            assert!(ro < 2e-3 && rs < 2e-3, "gdn_scan bh{bh} s{seq} out={ro} state={rs}");
        }
    }
}