use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the gated-delta-net kernel, embedded at compile
/// time from `../shaders/gated_delta_net.metal` (relative to this file).
///
/// [`register`] hands this string to the kernel registry under the name
/// `"gated_delta_net_f32"`.
pub static GATED_DELTA_NET_SHADER_SOURCE: &str =
include_str!("../shaders/gated_delta_net.metal");
/// Upper bound (inclusive) on `d_k` and `d_v` accepted by this module's
/// argument validation.
///
/// NOTE(review): presumably mirrors a fixed-size limit inside the Metal
/// shader — confirm against the `.metal` source before changing it.
pub const MAX_STATE_D: u32 = 128;
/// Registers the gated-delta-net shader source with `registry`.
///
/// The source is stored under the kernel name `"gated_delta_net_f32"`, which
/// is the same name the dispatch function later uses to look up its pipeline.
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAME: &str = "gated_delta_net_f32";
    registry.register_source(KERNEL_NAME, GATED_DELTA_NET_SHADER_SOURCE);
}
/// Shape description for one gated-delta-net invocation.
///
/// The buffer sizes implied by these fields (as enforced by the validation
/// in this module) are:
///
/// * `q`, `k`: `d_k * n_k_heads * n_tokens * n_seqs` elements
/// * `v`, `output`: `d_v * n_v_heads * n_tokens * n_seqs` elements
/// * `g`, `beta`: `n_v_heads * n_tokens * n_seqs` elements
/// * `state_in`, `state_out`: `d_k * d_v * n_v_heads * n_seqs` elements
#[derive(Clone, Copy, Debug)]
pub struct GatedDeltaNetParams {
    /// Per-head length of each query / key vector.
    pub d_k: u32,
    /// Per-head length of each value / output vector.
    pub d_v: u32,
    /// Number of query/key heads; must evenly divide `n_v_heads`.
    pub n_k_heads: u32,
    /// Number of value heads (also the number of state matrices per sequence).
    pub n_v_heads: u32,
    /// Number of tokens processed per sequence.
    pub n_tokens: u32,
    /// Number of independent sequences.
    pub n_seqs: u32,
}
/// Checks that `p` and every data buffer are consistent before dispatch.
///
/// The checks, in order:
///
/// 1. All head dimensions and head counts are non-zero.
/// 2. `n_tokens` and `n_seqs` are non-zero.
/// 3. `n_v_heads` is a multiple of `n_k_heads`.
/// 4. `d_k` and `d_v` do not exceed [`MAX_STATE_D`].
/// 5. The expected element count of every buffer fits in `usize`.
/// 6. Each buffer has exactly the expected element count and is `f32`.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] describing the first check that
/// fails.
#[allow(clippy::too_many_arguments)]
fn validate(
    p: &GatedDeltaNetParams,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    g: &MlxBuffer,
    beta: &MlxBuffer,
    state_in: &MlxBuffer,
    output: &MlxBuffer,
    state_out: &MlxBuffer,
) -> Result<()> {
    if p.d_k == 0 || p.d_v == 0 || p.n_k_heads == 0 || p.n_v_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "gated_delta_net: dims must all be > 0".into(),
        ));
    }
    if p.n_tokens == 0 || p.n_seqs == 0 {
        return Err(MlxError::InvalidArgument(
            "gated_delta_net: n_tokens and n_seqs must be > 0".into(),
        ));
    }
    if p.n_v_heads % p.n_k_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net: n_v_heads ({}) must be a multiple of n_k_heads ({})",
            p.n_v_heads, p.n_k_heads
        )));
    }
    if p.d_k > MAX_STATE_D || p.d_v > MAX_STATE_D {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net: d_k ({}) and d_v ({}) must be <= MAX_STATE_D ({})",
            p.d_k, p.d_v, MAX_STATE_D
        )));
    }

    // The head counts, token count and sequence count are arbitrary u32
    // values, so their product can exceed `usize`. A wrapped product could
    // coincidentally equal the size of a much smaller buffer and let an
    // undersized allocation through, so multiply with overflow checking.
    let checked_elems = |what: &str, dims: &[u32]| -> Result<usize> {
        dims.iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d as usize))
            .ok_or_else(|| {
                MlxError::InvalidArgument(format!(
                    "gated_delta_net: {} element count overflows usize",
                    what
                ))
            })
    };

    let qk_elems = checked_elems("q/k", &[p.d_k, p.n_k_heads, p.n_tokens, p.n_seqs])?;
    let v_elems = checked_elems("v/output", &[p.d_v, p.n_v_heads, p.n_tokens, p.n_seqs])?;
    let scalar_elems = checked_elems("g/beta", &[p.n_v_heads, p.n_tokens, p.n_seqs])?;
    let state_elems = checked_elems("state", &[p.d_k, p.d_v, p.n_v_heads, p.n_seqs])?;

    for (name, buf, exp) in [
        ("q", q, qk_elems),
        ("k", k, qk_elems),
        ("v", v, v_elems),
        ("output", output, v_elems),
        ("g", g, scalar_elems),
        ("beta", beta, scalar_elems),
        ("state_in", state_in, state_elems),
        ("state_out", state_out, state_elems),
    ] {
        if buf.element_count() != exp {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net: {} element count {} != expected {}",
                name,
                buf.element_count(),
                exp
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net: {} must be f32 (got {})",
                name,
                buf.dtype()
            )));
        }
    }
    Ok(())
}
/// Validates the arguments and encodes one gated-delta-net kernel launch.
///
/// Buffers are bound at indices 0..=8 in this order: `q`, `k`, `v`, `g`,
/// `beta`, `state_in`, `output`, `state_out`, `params_buf`. The launch uses
/// an `n_v_heads x n_seqs` grid of threadgroups, each `d_v` threads wide,
/// with `(2 * d_k + 2 * d_v)` `f32` slots of threadgroup memory bound at
/// threadgroup-memory index 0.
///
/// `params_buf` is passed through without inspection; presumably it is the
/// buffer produced by `build_gated_delta_net_params` for the same `p` —
/// verify at the call site.
///
/// # Errors
///
/// Returns an error if argument validation fails or if the
/// `"gated_delta_net_f32"` pipeline cannot be obtained from `registry`.
// The argument list mirrors the kernel's buffer bindings one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gated_delta_net(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    g: &MlxBuffer,
    beta: &MlxBuffer,
    state_in: &MlxBuffer,
    output: &MlxBuffer,
    state_out: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: GatedDeltaNetParams,
) -> Result<()> {
    validate(&p, q, k, v, g, beta, state_in, output, state_out)?;
    let pipeline = registry.get_pipeline("gated_delta_net_f32", device)?;

    // One threadgroup per (value head, sequence) pair, d_v threads in each.
    let threadgroup_count = MTLSize::new(u64::from(p.n_v_heads), u64::from(p.n_seqs), 1);
    let threads_per_group = MTLSize::new(u64::from(p.d_v), 1, 1);

    // Threadgroup scratch: 2 * d_k + 2 * d_v f32 values.
    let scratch_slots = 2 * u64::from(p.d_k) + 2 * u64::from(p.d_v);
    let scratch_bytes = scratch_slots * std::mem::size_of::<f32>() as u64;

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, q),
            (1, k),
            (2, v),
            (3, g),
            (4, beta),
            (5, state_in),
            (6, output),
            (7, state_out),
            (8, params_buf),
        ],
        &[(0, scratch_bytes)],
        threadgroup_count,
        threads_per_group,
    );
    Ok(())
}
/// Allocates an 8-word `u32` buffer on `device` and fills it from `p`.
///
/// Word layout: `[d_k, d_v, n_k_heads, n_v_heads, n_tokens, n_seqs, 0, 0]`.
/// The last two words are written as zero; presumably they are padding for
/// the shader-side parameter struct — confirm against the `.metal` source.
///
/// # Errors
///
/// Propagates any error from buffer allocation or from obtaining the
/// mutable `u32` view of the buffer.
pub fn build_gated_delta_net_params(
    device: &crate::MlxDevice,
    p: GatedDeltaNetParams,
) -> Result<MlxBuffer> {
    let words: [u32; 8] = [
        p.d_k,
        p.d_v,
        p.n_k_heads,
        p.n_v_heads,
        p.n_tokens,
        p.n_seqs,
        0,
        0,
    ];
    let mut params = device.alloc_buffer(8 * 4, DType::U32, vec![8])?;
    let slots = params.as_mut_slice::<u32>()?;
    slots[..8].copy_from_slice(&words);
    Ok(params)
}
/// Scalar CPU implementation of the gated-delta-net recurrence.
///
/// For every sequence `s`, value head `vh` and token `t` (tokens in order),
/// with `S` the `d_v x d_k` row-major state of that head, `alpha = exp(-g)`
/// and `kh = vh / (n_v_heads / n_k_heads)` the shared key/query head:
///
/// 1. `delta[i] = v[i] - sum_j S[i][j] * k[j]` (using the state *before* decay)
/// 2. `S[i][j]  = alpha * S[i][j] + beta * delta[i] * k[j]`
/// 3. `out[i]   = sum_j S[i][j] * q[j]` (using the updated state)
///
/// Index layout, as used by the code below:
///
/// * `q`, `k`: `[seq][token][k_head][d_k]`
/// * `v`, output: `[seq][token][v_head][d_v]`
/// * `g`, `beta`: `[seq][token][v_head]`
/// * state: `[seq][v_head][d_v][d_k]`
///
/// Returns `(output, final_state)`. `final_state` starts as a copy of
/// `state_in` and has the same length.
///
/// # Panics
///
/// Panics if `n_k_heads` is zero, if `n_v_heads` is not a multiple of
/// `n_k_heads`, or if any input slice is shorter than the shape in `p`
/// requires.
pub fn cpu_reference_f32(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    g: &[f32],
    beta: &[f32],
    state_in: &[f32],
    p: GatedDeltaNetParams,
) -> (Vec<f32>, Vec<f32>) {
    // Strict left-to-right accumulation starting from +0.0, so results are
    // bit-identical to a plain `acc += a[j] * b[j]` loop.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y)
    }

    let d_k = p.d_k as usize;
    let d_v = p.d_v as usize;
    let nh_k = p.n_k_heads as usize;
    let nh_v = p.n_v_heads as usize;
    let n_t = p.n_tokens as usize;
    let n_s = p.n_seqs as usize;

    // Fail loudly on head layouts the kernel-side validation also rejects,
    // instead of dividing by zero or silently reading the wrong k/q head.
    assert!(nh_k > 0, "gated_delta_net: n_k_heads must be > 0");
    assert!(
        nh_v % nh_k == 0,
        "gated_delta_net: n_v_heads ({}) must be a multiple of n_k_heads ({})",
        nh_v,
        nh_k
    );
    let group_ratio = nh_v / nh_k;

    let kq_token_stride = nh_k * d_k;
    let kq_seq_stride = n_t * kq_token_stride;
    let v_token_stride = nh_v * d_v;
    let v_seq_stride = n_t * v_token_stride;
    let scalar_seq_stride = n_t * nh_v;
    let state_head_stride = d_v * d_k;
    let state_seq_stride = nh_v * state_head_stride;

    let mut output = vec![0.0f32; n_s * v_seq_stride];
    let mut state = state_in.to_vec();
    // Scratch for the per-token residual; fully overwritten for every token,
    // so one allocation is shared across all iterations.
    let mut delta = vec![0.0f32; d_v];

    for s in 0..n_s {
        for vh in 0..nh_v {
            // Consecutive groups of `group_ratio` value heads share one k/q head.
            let kh = vh / group_ratio;
            let state_base = s * state_seq_stride + vh * state_head_stride;

            for t in 0..n_t {
                let kq_base = s * kq_seq_stride + t * kq_token_stride + kh * d_k;
                let v_base = s * v_seq_stride + t * v_token_stride + vh * d_v;
                let sc_idx = s * scalar_seq_stride + t * nh_v + vh;

                let beta_val = beta[sc_idx];
                let alpha = (-g[sc_idx]).exp();

                let k_row = &k[kq_base..kq_base + d_k];
                let q_row = &q[kq_base..kq_base + d_k];
                let v_row = &v[v_base..v_base + d_v];
                let head_state = &mut state[state_base..state_base + state_head_stride];

                // Step 1: residual against the pre-decay state.
                for (i, d) in delta.iter_mut().enumerate() {
                    *d = v_row[i] - dot(&head_state[i * d_k..(i + 1) * d_k], k_row);
                }

                // Step 2: decay the state and add the rank-1 correction.
                for (i, &d) in delta.iter().enumerate() {
                    let beta_delta = beta_val * d;
                    let row = &mut head_state[i * d_k..(i + 1) * d_k];
                    for (s_ij, &k_j) in row.iter_mut().zip(k_row) {
                        *s_ij = alpha * *s_ij + beta_delta * k_j;
                    }
                }

                // Step 3: read out with the updated state.
                let out_row = &mut output[v_base..v_base + d_v];
                for (i, o) in out_row.iter_mut().enumerate() {
                    *o = dot(&head_state[i * d_k..(i + 1) * d_k], q_row);
                }
            }
        }
    }
    (output, state)
}