use metaltile::{bench_kernel, kernel};
// Single decode step of a per-channel causal 1-D convolution with a rolling
// history buffer.
//
// One program instance (`program_id::<0>()`) handles one channel `d`.
//
// Indexing used by this kernel:
//   * `x[d]`, `b[d]`, `y[d]`            - one element per channel.
//   * `w[k * n_channels + d]`           - tap `k` in `0..kernel_size`.
//   * `state[k * n_channels + d]`       - history row `k` in `0..kernel_size - 1`;
//                                         row 0 is the oldest sample, because the
//                                         update below shifts row `k + 1` into row `k`.
//
// Output: `y[d] = b[d] + w[last] * x[d] + sum_k w[k] * state[k]`, accumulated in
// f32 and cast back to `T`.
//
// Side effect: `state` is shifted by one row and the current `x[d]` is written
// into the newest row (`kernel_size - 2`).
//
// NOTE(review): `kernel_size >= 1` is assumed; with `kernel_size == 0` the
// index of the last weight tap underflows. Confirm callers never pass 0.
#[bench_kernel(
    op="ssm",
    subop="conv1d_causal_step",
    class=GenericEmpty,
    tol=0.0,
    kernel_mode=Grid3D,
)]
#[kernel]
pub fn conv1d_causal_step<T>(
    x: Tensor<T>,
    w: Tensor<T>,
    b: Tensor<T>,
    mut state: Tensor<T>,
    mut y: Tensor<T>,
    #[constexpr] n_channels: u32,
    #[constexpr] kernel_size: u32,
) {
    let d = program_id::<0>();
    let x_d = load(x[d]).cast::<f32>();
    let b_d = load(b[d]).cast::<f32>();

    // The last weight tap multiplies the current input sample.
    let w_last = load(w[(kernel_size - 1u32) * n_channels + d]).cast::<f32>();
    let mut acc = b_d + w_last * x_d;

    // Remaining taps pair with the stored history rows; the `select` keeps the
    // trip count at 0 instead of wrapping when `kernel_size <= 1`.
    let conv_taps = select(kernel_size > 1u32, kernel_size - 1u32, 0u32);
    for k in range(0u32, conv_taps, 1u32) {
        let s_kd = load(state[k * n_channels + d]).cast::<f32>();
        let w_kd = load(w[k * n_channels + d]).cast::<f32>();
        acc = acc + w_kd * s_kd;
    }
    store(y[d], acc.cast::<T>());

    // Slide the history one row towards index 0 (row k <- row k + 1).
    let shift_taps = select(kernel_size > 2u32, kernel_size - 2u32, 0u32);
    for k in range(0u32, shift_taps, 1u32) {
        let next = load(state[(k + 1u32) * n_channels + d]);
        store(state[k * n_channels + d], next);
    }

    // Append the current sample as the newest history row. With
    // `kernel_size == 1` there is no history row at all and
    // `kernel_size - 2u32` would wrap around, so the store must be skipped.
    if kernel_size > 1u32 {
        store(state[(kernel_size - 2u32) * n_channels + d], load(x[d]));
    }
}
// One recurrent step of a state-space update with a single decay value per
// head.
//
// Each program instance (`program_id::<0>()`) handles one flat index, split
// into `head = idx / head_dim` and `col = idx - head * head_dim`.
//
// Indexing used by this kernel:
//   * `dt[head]`, `a[head]`                              - one value per head.
//   * `x[head * head_dim + col]`, `y[...]` (same index)  - one value per (head, col).
//   * `b[s]`, `c[s]` for `s` in `0..state_dim`           - not offset by head.
//   * `h[head * state_dim * head_dim + s * head_dim + col]` - f32 state, updated in place.
//
// Per state slot: `h_new = exp(a * dt) * h_old + dt * b[s] * x`, and the
// output is `y = sum_s c[s] * h_new`, accumulated in f32 and cast to `T`.
#[bench_kernel(
    op="ssm",
    subop="step",
    class=GenericEmpty,
    tol=0.0,
    kernel_mode=Grid3D,
)]
#[kernel]
pub fn ssm_step<T>(
    x: Tensor<T>,
    a: Tensor<T>,
    b: Tensor<T>,
    c: Tensor<T>,
    dt: Tensor<T>,
    mut h: Tensor<f32>,
    mut y: Tensor<T>,
    #[constexpr] head_dim: u32,
    #[constexpr] state_dim: u32,
) {
    // Split the flat program index into (head, column-within-head).
    let flat_id = program_id::<0>();
    let head = flat_id / head_dim;
    let col = flat_id - head * head_dim;
    // Shared position of this element in both `x` and `y`.
    let io_idx = head * head_dim + col;

    // Per-head scalars.
    let dt_h = load(dt[head]).cast::<f32>();
    let a_h = load(a[head]).cast::<f32>();
    let decay = exp(a_h * dt_h);

    let x_in = load(x[io_idx]).cast::<f32>();

    let state_base = head * state_dim * head_dim;
    let mut out_acc = 0.0f32;
    for s in range(0u32, state_dim, 1u32) {
        let slot = state_base + s * head_dim + col;
        let prev = load(h[slot]);
        let b_s = load(b[s]).cast::<f32>();
        // Decay the old state and inject the scaled input.
        let updated = decay * prev + dt_h * b_s * x_in;
        store(h[slot], updated);
        let c_s = load(c[s]).cast::<f32>();
        out_acc = out_acc + c_s * updated;
    }

    store(y[io_idx], out_acc.cast::<T>());
}
// One recurrent step of a state-space update where the decay is derived from
// a separate `a_log` entry for every (flat index, state slot) pair.
//
// Each program instance (`program_id::<0>()`) handles one flat index, split
// into `head = idx / head_dim` and `col = idx - head * head_dim`.
//
// Indexing used by this kernel:
//   * `dt[head]`                                         - one value per head.
//   * `x[head * head_dim + col]`, `y[...]` (same index)  - one value per (head, col).
//   * `a_log[idx * state_dim + s]`                       - one value per (idx, s).
//   * `b[s]`, `c[s]` for `s` in `0..state_dim`           - not offset by head.
//   * `h[head * state_dim * head_dim + s * head_dim + col]` - f32 state, updated in place.
//
// Per state slot: `a = -exp(a_log)`, `h_new = exp(a * dt) * h_old + dt * b[s] * x`,
// and the output is `y = sum_s c[s] * h_new`, accumulated in f32 and cast to `T`.
#[bench_kernel(
    op="ssm",
    subop="step_a2d",
    class=GenericEmpty,
    tol=0.0,
    kernel_mode=Grid3D,
)]
#[kernel]
pub fn ssm_step_a2d<T>(
    x: Tensor<T>,
    a_log: Tensor<T>,
    b: Tensor<T>,
    c: Tensor<T>,
    dt: Tensor<T>,
    mut h: Tensor<f32>,
    mut y: Tensor<T>,
    #[constexpr] head_dim: u32,
    #[constexpr] state_dim: u32,
) {
    // Split the flat program index into (head, column-within-head).
    let flat_id = program_id::<0>();
    let head = flat_id / head_dim;
    let col = flat_id - head * head_dim;
    // Shared position of this element in both `x` and `y`.
    let io_idx = head * head_dim + col;

    let dt_h = load(dt[head]).cast::<f32>();
    let x_in = load(x[io_idx]).cast::<f32>();

    // Start of this flat index's run of `state_dim` entries in `a_log`.
    let a_row = flat_id * state_dim;
    let state_base = head * state_dim * head_dim;

    let mut out_acc = 0.0f32;
    for s in range(0u32, state_dim, 1u32) {
        // a = -exp(a_log); written as a subtraction from zero.
        let a_neg = 0.0f32 - exp(load(a_log[a_row + s]).cast::<f32>());
        let decay = exp(a_neg * dt_h);

        let slot = state_base + s * head_dim + col;
        let prev = load(h[slot]);
        let b_s = load(b[s]).cast::<f32>();
        // Decay the old state and inject the scaled input.
        let updated = decay * prev + dt_h * b_s * x_in;
        store(h[slot], updated);

        let c_s = load(c[s]).cast::<f32>();
        out_acc = out_acc + c_s * updated;
    }

    store(y[io_idx], out_acc.cast::<T>());
}
// State-space step where the threads of a threadgroup split the state
// dimension and their partial outputs are combined with `simd_sum`.
//
// Work assignment (from the implicit kernel indices):
//   * `tgid_y` -> `row`   : indexes `dt`, and selects the `dh`-wide slice of `x`/`out`.
//   * `tgid_x` -> `col`   : position inside that `dh`-wide slice.
//   * `tid`    -> `worker`: owns `ds / 32` consecutive state entries.
//
// Derived indices:
//   * `head  = row - (row / n_heads) * n_heads` - indexes `a_log` and `d_skip`.
//   * `group = row / heads_per_group`           - selects the `ds`-wide slice of
//                                                 `b_mat` / `c_mat`.
//
// Per state entry: `s_new = exp(-exp(a_log) * dt) * s_old + x * dt * b`, written
// to `state_out` as `T`. The output is `sum(s_new * c) + x * d_skip[head]`,
// written only by the thread with `tid == 0`.
//
// NOTE(review): the divisor 32 presumably matches the SIMD-group width and the
// threadgroup is presumably launched with exactly 32 threads so that
// `simd_sum` covers all of `ds` - confirm against the launch configuration.
#[bench_kernel(
    op="ssm",
    subop="mt_step",
    class=GenericEmpty,
    tol=0.0,
    kernel_mode=Reduction,
)]
#[kernel]
pub fn mt_ssm_step<T>(
    x: Tensor<T>,
    a_log: Tensor<T>,
    b_mat: Tensor<T>,
    c_mat: Tensor<T>,
    d_skip: Tensor<T>,
    dt: Tensor<T>,
    state_in: Tensor<T>,
    mut state_out: Tensor<T>,
    mut out: Tensor<T>,
    #[constexpr] dh: u32,
    #[constexpr] ds: u32,
    #[constexpr] n_heads: u32,
    #[constexpr] heads_per_group: u32,
) {
    let col = tgid_x;
    let row = tgid_y;
    let worker = tid;

    // `row` modulo `n_heads`, spelled with divide/multiply/subtract.
    let head = row - (row / n_heads) * n_heads;
    let group = row / heads_per_group;

    let dt_val = load(dt[row]).cast::<f32>();
    let a_neg = 0.0f32 - exp(load(a_log[head]).cast::<f32>());
    let decay = exp(a_neg * dt_val);

    let io_idx = row * dh + col;
    let x_val = load(x[io_idx]).cast::<f32>();
    // Loop-invariant factor of the input injection term.
    let x_dt = x_val * dt_val;

    let per_worker = ds / 32u32;
    let bc_offset = group * ds;
    let state_offset = row * dh * ds + col * ds;

    let mut partial = 0.0f32;
    for i in range(0u32, per_worker, 1u32) {
        // This worker's i-th entry along the state dimension.
        let s = per_worker * worker + i;
        let slot = state_offset + s;

        let inject = x_dt * load(b_mat[bc_offset + s]).cast::<f32>();
        let updated = decay * load(state_in[slot]).cast::<f32>() + inject;
        store(state_out[slot], updated.cast::<T>());

        partial = partial + updated * load(c_mat[bc_offset + s]).cast::<f32>();
    }

    // Combine the per-thread partial sums.
    let total = simd_sum(partial);

    // Only one thread adds the skip term and writes the result.
    if worker == 0u32 {
        let skip = load(d_skip[head]).cast::<f32>();
        store(out[io_idx], (total + x_val * skip).cast::<T>());
    }
}