use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Text of `../shaders/gated_delta_net_chunk_o.metal`, embedded into the
/// binary at compile time. `register` hands this string to the kernel
/// registry under the name `gated_delta_net_chunk_o_bf16`.
pub static GATED_DELTA_NET_CHUNK_O_SHADER_SOURCE: &str =
include_str!("../shaders/gated_delta_net_chunk_o.metal");
/// The only `k` value `validate` accepts: it rejects any `k != MAX_K`.
pub const MAX_K: u32 = 128;
/// Upper bound on `v`; `validate` rejects any `v > MAX_V`.
pub const MAX_V: u32 = 256;
/// Tile size for `k`: `validate` requires `k` to be a multiple of it, and it
/// is written to slot 8 of the params buffer.
pub const DEFAULT_BK: u32 = 32;
/// Tile size for `v`: `validate` requires `v` to be a multiple of it, it sets
/// the number of V tiles (`v / DEFAULT_BV`), and it is written to slot 9 of
/// the params buffer.
pub const DEFAULT_BV: u32 = 32;
/// Registers the embedded chunk-O shader source with `registry` under the
/// kernel name that `dispatch_gated_delta_net_chunk_o` later looks up.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "gated_delta_net_chunk_o_bf16";
    let shader_source = GATED_DELTA_NET_CHUNK_O_SHADER_SOURCE;
    registry.register_source(kernel_name, shader_source);
}
/// Shape and scaling parameters for one chunk-O kernel dispatch.
///
/// Expected buffer sizes derived from these fields (see `validate`):
/// `q` and `k` hold `b * t * hg * k` elements, `v` and `o` hold
/// `b * t * h * v`, `g` holds `b * t * h`, and the state buffer `h` holds
/// `b * num_chunks() * h * v * k`.
#[derive(Debug, Clone, Copy)]
pub struct GatedDeltaNetChunkOParams {
/// Leading factor of every buffer size (presumably the batch size — confirm).
pub b: u32,
/// Second factor of the q/k/v/g/o sizes (presumably the sequence length —
/// confirm). Must be a multiple of `bt`.
pub t: u32,
/// Head count used to size `q` and `k`; `h` must be a multiple of it.
pub hg: u32,
/// Head count used to size `v`, `g`, `o` and the state buffer.
pub h: u32,
/// Innermost dimension of `q` and `k`. Must equal `MAX_K`.
pub k: u32,
/// Innermost dimension of `v` and `o`. Must be `<= MAX_V` and a multiple of
/// `DEFAULT_BV`.
pub v: u32,
/// Chunk length along `t`. `validate` currently accepts only 64.
pub bt: u32,
/// Finite scale factor; passed to the shader as raw `f32` bits.
pub scale: f32,
}
impl GatedDeltaNetChunkOParams {
    /// Number of `bt`-sized chunks needed to cover `t`, rounding up.
    ///
    /// # Panics
    ///
    /// Panics with a division by zero when `bt == 0`.
    pub fn num_chunks(&self) -> u32 {
        let whole_chunks = self.t / self.bt;
        let leftover = self.t % self.bt;
        if leftover == 0 {
            whole_chunks
        } else {
            // A non-zero remainder implies `bt >= 2`, so this cannot overflow.
            whole_chunks + 1
        }
    }

    /// Number of `DEFAULT_BV`-wide tiles along `v`, rounding down.
    pub fn num_v_tiles(&self) -> u32 {
        let tile_width = DEFAULT_BV;
        self.v / tile_width
    }
}
#[allow(clippy::too_many_arguments)]
fn validate(
p: &GatedDeltaNetChunkOParams,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
h: &MlxBuffer,
g: &MlxBuffer,
o: &MlxBuffer,
) -> Result<()> {
if p.b == 0 || p.t == 0 || p.hg == 0 || p.h == 0 || p.k == 0 || p.v == 0 || p.bt == 0 {
return Err(MlxError::InvalidArgument(
"gated_delta_net_chunk_o: all dims must be > 0".into(),
));
}
if p.h % p.hg != 0 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: h ({}) must be a multiple of hg ({})",
p.h, p.hg
)));
}
if p.k != MAX_K {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: K ({}) must equal MAX_K = {} exactly. \
The simdgroup_matrix MMA K-tile loop is compile-time hard-coded \
at 16 (= K=128/8) in the shader; runtime loop bounds defeat the \
MMA scheduler (3.15× regression measured 2026-04-27). Future \
iters will lift this via FLA's b_h1..b_h4 bank-split.",
p.k, MAX_K
)));
}
if p.v > MAX_V {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: v ({}) must be <= MAX_V ({})",
p.v, MAX_V
)));
}
if p.bt != 64 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o (iter 3): bt must be 64 (got {})",
p.bt
)));
}
if p.t % p.bt != 0 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o (iter 3): t ({}) must be a multiple of bt ({})",
p.t, p.bt
)));
}
if p.k % DEFAULT_BK != 0 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o (iter 3): k ({}) must be a multiple of BK ({})",
p.k, DEFAULT_BK
)));
}
if p.v % DEFAULT_BV != 0 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o (iter 3): v ({}) must be a multiple of BV ({})",
p.v, DEFAULT_BV
)));
}
if !p.scale.is_finite() {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: scale must be finite (got {})",
p.scale
)));
}
let shared_bytes: u64 = ((p.bt * p.bt) as u64) * 2;
const M5_MAX_TG_MEM_BYTES: u64 = 32 * 1024;
if shared_bytes > M5_MAX_TG_MEM_BYTES {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: threadgroup memory {} bytes exceeds M5 Max \
cap of {} bytes (bt={})",
shared_bytes, M5_MAX_TG_MEM_BYTES, p.bt
)));
}
let nt = p.num_chunks();
let q_elems = (p.b * p.t * p.hg * p.k) as usize;
let k_elems = (p.b * p.t * p.hg * p.k) as usize;
let v_elems = (p.b * p.t * p.h * p.v) as usize;
let h_elems = (p.b * nt * p.h * p.v * p.k) as usize;
let g_elems = (p.b * p.t * p.h) as usize;
let o_elems = (p.b * p.t * p.h * p.v) as usize;
if q.element_count() != q_elems || q.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: q must be bf16[{}] (got {} {})",
q_elems,
q.element_count(),
q.dtype()
)));
}
if k.element_count() != k_elems || k.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: k must be bf16[{}] (got {} {})",
k_elems,
k.element_count(),
k.dtype()
)));
}
if v.element_count() != v_elems || v.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: v must be bf16[{}] (got {} {})",
v_elems,
v.element_count(),
v.dtype()
)));
}
if h.element_count() != h_elems || h.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: h must be bf16[{}] (got {} {})",
h_elems,
h.element_count(),
h.dtype()
)));
}
if g.element_count() != g_elems || g.dtype() != DType::F32 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: g must be f32[{}] (got {} {})",
g_elems,
g.element_count(),
g.dtype()
)));
}
if o.element_count() != o_elems || o.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"gated_delta_net_chunk_o: o must be bf16[{}] (got {} {})",
o_elems,
o.element_count(),
o.dtype()
)));
}
Ok(())
}
/// Validates the arguments and encodes one chunk-O kernel dispatch.
///
/// Buffers are bound at indices 0..=6 in the order `q`, `k`, `v`, `h`, `g`,
/// `o`, `params_buf`. The threadgroup grid is
/// `(num_v_tiles, num_chunks, b * h)` with 256 threads per threadgroup, and
/// `bt * bt * 2` bytes are requested at threadgroup-memory index 0.
///
/// `params_buf` is not inspected here; it is presumably the buffer produced
/// by `build_gated_delta_net_chunk_o_params` for the same `p` — confirm at
/// call sites.
///
/// # Errors
///
/// Returns the error from `validate` when the shape or any tensor buffer is
/// rejected, or the error from the registry when the
/// `gated_delta_net_chunk_o_bf16` pipeline cannot be obtained.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gated_delta_net_chunk_o(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    h: &MlxBuffer,
    g: &MlxBuffer,
    o: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: GatedDeltaNetChunkOParams,
) -> Result<()> {
    // Validation runs first; it guarantees `bt != 0`, so `num_chunks` below
    // cannot divide by zero.
    validate(&p, q, k, v, h, g, o)?;
    let pipeline = registry.get_pipeline("gated_delta_net_chunk_o_bf16", device)?;

    let v_tiles = u64::from(p.num_v_tiles());
    let chunks = u64::from(p.num_chunks());
    // Widen each factor before multiplying so `b * h` cannot wrap in u32.
    let batch_heads = u64::from(p.b) * u64::from(p.h);
    let grid_tgs = MTLSize::new(v_tiles, chunks, batch_heads);
    let tg = MTLSize::new(256, 1, 1);

    // bt x bt staging tile at 2 bytes per element.
    let ba_stage_bytes: u64 = u64::from(p.bt) * u64::from(p.bt) * 2;

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, q),
            (1, k),
            (2, v),
            (3, h),
            (4, g),
            (5, o),
            (6, params_buf),
        ],
        &[(0, ba_stage_bytes)],
        grid_tgs,
        tg,
    );
    Ok(())
}
pub fn build_gated_delta_net_chunk_o_params(
device: &crate::MlxDevice,
p: GatedDeltaNetChunkOParams,
) -> Result<MlxBuffer> {
let mut buf = device.alloc_buffer(11 * 4, DType::U32, vec![11])?;
{
let s = buf.as_mut_slice::<u32>()?;
s[0] = p.b;
s[1] = p.t;
s[2] = p.hg;
s[3] = p.h;
s[4] = p.k;
s[5] = p.v;
s[6] = p.bt;
s[7] = p.num_chunks();
s[8] = DEFAULT_BK;
s[9] = DEFAULT_BV;
s[10] = p.scale.to_bits();
}
Ok(buf)
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::MlxDevice;

    /// Allocates a tiny placeholder buffer of the requested dtype. Its size is
    /// irrelevant to the test below, which fails on a scalar check before any
    /// buffer is inspected.
    fn dummy_buf(device: &MlxDevice, dtype: DType) -> MlxBuffer {
        let allocation = device.alloc_buffer(2, dtype, vec![1]);
        allocation.expect("alloc dummy")
    }

    /// `k = 256` must be rejected, with an error that names the offending
    /// value, the MMA/compile-time rationale and the `MAX_K` cap.
    #[test]
    fn validate_rejects_k_above_max() {
        let device = MlxDevice::new().expect("MlxDevice::new");

        // Allocation order: q, k, v, h, g, o.
        let [q_buf, k_buf, v_buf, h_buf, g_buf, o_buf] = [
            DType::BF16,
            DType::BF16,
            DType::BF16,
            DType::BF16,
            DType::F32,
            DType::BF16,
        ]
        .map(|dtype| dummy_buf(&device, dtype));

        let params = GatedDeltaNetChunkOParams {
            b: 1,
            t: 128,
            hg: 2,
            h: 4,
            k: 256,
            v: 128,
            bt: 64,
            scale: (128f32).powf(-0.5),
        };

        let outcome = validate(&params, &q_buf, &k_buf, &v_buf, &h_buf, &g_buf, &o_buf);
        let msg = outcome
            .expect_err("validate must reject K=256")
            .to_string();

        let names_value = msg.contains("256");
        assert!(names_value, "expected K=256 in error message, got: {msg}");

        let explains_mma = msg.contains("compile-time") || msg.contains("MMA");
        assert!(
            explains_mma,
            "expected MMA-compile-time-bound context in error, got: {msg}"
        );

        let names_cap = msg.contains("MAX_K = 128") || msg.contains("MAX_K=128");
        assert!(names_cap, "expected explicit MAX_K cap in error, got: {msg}");
    }
}