use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::gated_delta_net::GatedDeltaNetParams;
/// Metal source for the gated-delta-net decode kernels, embedded into the binary at
/// compile time from `../shaders/gated_delta_net_decode.metal` (path relative to this
/// Rust source file).
///
/// `register` installs this same string under each of the
/// `gated_delta_net_decode_f32_{1,2,4}` kernel names.
pub static GATED_DELTA_NET_DECODE_SHADER_SOURCE: &str =
include_str!("../shaders/gated_delta_net_decode.metal");
/// Largest NSG (`D_k / 32`, the threadgroup's y-extent; each threadgroup is `32 x NSG`
/// threads) accepted by the decode path. `validate` rejects any `D_k` whose NSG exceeds
/// this value. Kernel variants are registered for NSG = 1, 2 and 4 only.
pub const MAX_NSG: u32 = 4;
/// Registers the decode shader source under every kernel name that
/// `dispatch_gated_delta_net_decode` can request (the NSG = 1, 2 and 4 variants).
///
/// All names map to the same Metal source string. How a particular NSG variant is
/// selected from the name is not visible in this file (TODO confirm in `KernelRegistry`
/// or the shader).
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAMES: [&str; 3] = [
        "gated_delta_net_decode_f32_1",
        "gated_delta_net_decode_f32_2",
        "gated_delta_net_decode_f32_4",
    ];
    for name in KERNEL_NAMES {
        registry.register_source(name, GATED_DELTA_NET_DECODE_SHADER_SOURCE);
    }
}
/// Validates decode parameters and buffers before any Metal work is encoded.
///
/// Enforces, in order:
/// - `d_k`, `d_v`, `n_k_heads`, `n_v_heads`, `n_tokens` and `n_seqs` are all non-zero;
/// - `n_v_heads` is a multiple of `n_k_heads`;
/// - `d_k` is a multiple of 32, and `NSG = d_k / 32` is a power of two no larger than
///   [`MAX_NSG`] (matching the `_1`/`_2`/`_4` kernel variants registered by `register`);
/// - `d_v` is a multiple of `NSG`;
/// - each buffer holds exactly the expected number of elements and has dtype `f32`:
///   - `q`/`k`: `d_k * n_k_heads * n_tokens * n_seqs`
///   - `v`/`output`: `d_v * n_v_heads * n_tokens * n_seqs`
///   - `g`/`beta`: `n_v_heads * n_tokens * n_seqs`
///   - `state_in`/`state_out`: `d_k * d_v * n_v_heads * n_seqs`
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] describing the first violated rule. This includes
/// the case where an expected element count does not fit in `usize`.
fn validate(
    p: &GatedDeltaNetParams,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    g: &MlxBuffer,
    beta: &MlxBuffer,
    state_in: &MlxBuffer,
    output: &MlxBuffer,
    state_out: &MlxBuffer,
) -> Result<()> {
    /// Product of `dims`, or an error if it overflows `usize`. Without the check the
    /// product would wrap in release builds and could accept a wrong-sized buffer.
    fn elem_count(dims: &[usize]) -> Result<usize> {
        dims.iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .ok_or_else(|| {
                MlxError::InvalidArgument(format!(
                    "gated_delta_net_decode: element count for dims {:?} overflows usize",
                    dims
                ))
            })
    }

    if p.d_k == 0 || p.d_v == 0 || p.n_k_heads == 0 || p.n_v_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "gated_delta_net_decode: dims must all be > 0".into(),
        ));
    }
    if p.n_tokens == 0 || p.n_seqs == 0 {
        return Err(MlxError::InvalidArgument(
            "gated_delta_net_decode: n_tokens and n_seqs must be > 0".into(),
        ));
    }
    // n_k_heads != 0 is guaranteed above, so this modulo cannot panic.
    if p.n_v_heads % p.n_k_heads != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_decode: n_v_heads ({}) must be a multiple of n_k_heads ({})",
            p.n_v_heads, p.n_k_heads
        )));
    }
    if p.d_k % 32 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_decode: D_k ({}) must be a multiple of 32 (simdgroup width)",
            p.d_k
        )));
    }
    // d_k > 0 and a multiple of 32, so nsg >= 1.
    let nsg = p.d_k / 32;
    if nsg > MAX_NSG {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_decode: D_k ({}) implies NSG > {} — extend the .metal kernel",
            p.d_k, MAX_NSG
        )));
    }
    // Kernel variants exist only for NSG = 1, 2, 4. Without this check a D_k such as 96
    // (NSG = 3) would pass validation and fail only later, at kernel selection.
    if !nsg.is_power_of_two() {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_decode: D_k ({}) implies NSG={}, but kernels exist only for \
             power-of-two NSG <= {}",
            p.d_k, nsg, MAX_NSG
        )));
    }
    if p.d_v % nsg != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_decode: D_v ({}) must be a multiple of NSG ({}=D_k/32)",
            p.d_v, nsg
        )));
    }

    let qk_elems = elem_count(&[
        p.d_k as usize,
        p.n_k_heads as usize,
        p.n_tokens as usize,
        p.n_seqs as usize,
    ])?;
    let v_elems = elem_count(&[
        p.d_v as usize,
        p.n_v_heads as usize,
        p.n_tokens as usize,
        p.n_seqs as usize,
    ])?;
    let scalar_elems = elem_count(&[
        p.n_v_heads as usize,
        p.n_tokens as usize,
        p.n_seqs as usize,
    ])?;
    let state_elems = elem_count(&[
        p.d_k as usize,
        p.d_v as usize,
        p.n_v_heads as usize,
        p.n_seqs as usize,
    ])?;

    for (name, buf, exp) in [
        ("q", q, qk_elems),
        ("k", k, qk_elems),
        ("v", v, v_elems),
        ("output", output, v_elems),
        ("g", g, scalar_elems),
        ("beta", beta, scalar_elems),
        ("state_in", state_in, state_elems),
        ("state_out", state_out, state_elems),
    ] {
        if buf.element_count() != exp {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net_decode: {} element count {} != expected {}",
                name,
                buf.element_count(),
                exp
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net_decode: {} must be f32 (got {})",
                name,
                buf.dtype()
            )));
        }
    }
    Ok(())
}
/// Encodes one gated-delta-net decode dispatch onto `encoder`.
///
/// The inputs are first checked by `validate`. The kernel variant is then chosen from
/// `NSG = D_k / 32`. It is launched over a `(D_v / NSG, n_v_heads, n_seqs)` grid of
/// threadgroups, each `32 x NSG x 1` threads. Buffers are bound at indices 0..=8 in the
/// order `q, k, v, g, beta, state_in, output, state_out, params_buf`.
///
/// `params_buf` is bound as-is and is neither written nor checked here. Presumably it
/// holds the device-side copy of `p` (TODO confirm with callers).
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when validation fails or no kernel variant exists
/// for the computed NSG. Propagates any error from `registry.get_pipeline`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gated_delta_net_decode(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    g: &MlxBuffer,
    beta: &MlxBuffer,
    state_in: &MlxBuffer,
    output: &MlxBuffer,
    state_out: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: GatedDeltaNetParams,
) -> Result<()> {
    // NSG -> kernel name; must stay in sync with the names installed by `register`.
    const VARIANTS: [(u32, &str); 3] = [
        (1, "gated_delta_net_decode_f32_1"),
        (2, "gated_delta_net_decode_f32_2"),
        (4, "gated_delta_net_decode_f32_4"),
    ];
    // Simdgroup width; the threadgroup's x-extent.
    const SIMD_WIDTH: u32 = 32;

    validate(&p, q, k, v, g, beta, state_in, output, state_out)?;

    let simdgroups = p.d_k / SIMD_WIDTH;
    let kernel_name = VARIANTS
        .iter()
        .find(|&&(n, _)| n == simdgroups)
        .map(|&(_, name)| name)
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "gated_delta_net_decode: unsupported NSG={} (D_k={})",
                simdgroups, p.d_k
            ))
        })?;
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    let threadgroup = MTLSize::new(SIMD_WIDTH as u64, simdgroups as u64, 1);
    let grid = MTLSize::new(
        (p.d_v / simdgroups) as u64,
        p.n_v_heads as u64,
        p.n_seqs as u64,
    );
    let bindings = [
        (0, q),
        (1, k),
        (2, v),
        (3, g),
        (4, beta),
        (5, state_in),
        (6, output),
        (7, state_out),
        (8, params_buf),
    ];
    encoder.encode_threadgroups(pipeline, &bindings, grid, threadgroup);
    Ok(())
}