use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source text, embedded into the binary at compile time from
/// `../shaders/gated_delta_net_chunk.metal`. [`register`] hands this string
/// to the kernel registry.
pub static GATED_DELTA_NET_CHUNK_SHADER_SOURCE: &str =
include_str!("../shaders/gated_delta_net_chunk.metal");
/// The only `k` value `validate` accepts: `k` must equal this exactly, not
/// merely stay below it.
pub const MAX_K: u32 = 128;
/// Largest `v` value `validate` accepts (inclusive).
pub const MAX_V: u32 = 256;
/// Tile width along `v`. `validate` requires `v` to be a multiple of it, and
/// the dispatcher launches `v / DEFAULT_BV` threadgroups along the grid's
/// first axis and sizes the threadgroup memory from it.
pub const DEFAULT_BV: u32 = 32;
/// Registers the embedded shader source with `registry` under the kernel
/// name `gated_delta_net_chunk_inter_state_bf16`, the same name the
/// dispatcher later uses to look up its pipeline.
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAME: &str = "gated_delta_net_chunk_inter_state_bf16";
    registry.register_source(KERNEL_NAME, GATED_DELTA_NET_CHUNK_SHADER_SOURCE);
}
/// Shape parameters for the gated-delta-net chunk kernel.
///
/// The field notes below describe how `validate` uses each value when it
/// computes the expected element count of every tensor buffer.
#[derive(Debug, Clone, Copy)]
pub struct GatedDeltaNetChunkParams {
    /// Leading dimension of every tensor (presumably the batch size).
    pub b: u32,
    /// Second dimension of `k`, `w`, `u`, `g` (presumably sequence length).
    /// Must be a multiple of `bt`.
    pub t: u32,
    /// Head-group count used for the `k` tensor; `h` must be a multiple of it.
    pub hg: u32,
    /// Head count used for `w`, `u`, `g` and the state tensors.
    pub h: u32,
    /// Innermost dimension of `k` and `w`.
    pub k: u32,
    /// Innermost dimension of `u` / `v_new`.
    pub v: u32,
    /// Chunk length along `t`.
    pub bt: u32,
}

impl GatedDeltaNetChunkParams {
    /// Number of `bt`-sized chunks needed to cover `t`, rounded up.
    ///
    /// # Panics
    ///
    /// Panics if `bt` is zero.
    pub fn num_chunks(&self) -> u32 {
        let Self { t, bt, .. } = *self;
        t.div_ceil(bt)
    }

    /// Integer quotient `h / hg` (truncating).
    ///
    /// # Panics
    ///
    /// Panics if `hg` is zero.
    pub fn group_ratio(&self) -> u32 {
        let Self { h, hg, .. } = *self;
        h / hg
    }
}
/// Multiplies `dims` together as a `usize`, yielding `None` when the product
/// does not fit.
fn checked_elem_count(dims: &[u32]) -> Option<usize> {
    dims.iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(usize::try_from(dim).ok()?))
}

/// Checks `p` and every tensor buffer before the inter-state kernel is encoded.
///
/// Checks run in this order, and the first failure is returned:
/// 1. every dimension is non-zero;
/// 2. `h` is a multiple of `hg`;
/// 3. `k == MAX_K` and `v <= MAX_V`;
/// 4. the threadgroup memory request (`BV*K*4 + BV*BT*2` bytes) fits in 32 KiB;
/// 5. `t` is a multiple of `bt`, `bt == 64`, and `v` is a multiple of `DEFAULT_BV`;
/// 6. each buffer has the expected element count and dtype
///    (`k`, `w`, `u`, `v_new`, `h_out`: bf16; `g`, `h0`, `final_state`: f32).
///
/// All size arithmetic is widened before multiplying, so oversized shapes
/// produce an error rather than a debug-build panic or a silently wrapped
/// element count in release builds.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] describing the first violated rule,
/// including the case where an expected element count overflows `usize`.
// One reference per kernel binding; bundling them would only hide the mapping.
#[allow(clippy::too_many_arguments)]
fn validate(
    p: &GatedDeltaNetChunkParams,
    k: &MlxBuffer,
    w: &MlxBuffer,
    u: &MlxBuffer,
    g: &MlxBuffer,
    h0: &MlxBuffer,
    h_out: &MlxBuffer,
    v_new: &MlxBuffer,
    final_state: &MlxBuffer,
) -> Result<()> {
    if [p.b, p.t, p.hg, p.h, p.k, p.v, p.bt].contains(&0) {
        return Err(MlxError::InvalidArgument(
            "gated_delta_net_chunk: all dims must be > 0".into(),
        ));
    }
    if p.h % p.hg != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk: h ({}) must be a multiple of hg ({})",
            p.h, p.hg
        )));
    }
    if p.k != MAX_K {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk: K ({}) must equal MAX_K = {} exactly. \
             Section 2f's simdgroup_matrix MMA K-tile loop is compile-time \
             hard-coded at 16 (= K=128/8); K<128 would read OOB. To support \
             K=32/64/96/192/256, port FLA's b_h1..b_h4 bank-split.",
            p.k, MAX_K
        )));
    }
    if p.v > MAX_V {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk: v ({}) must be <= MAX_V ({})",
            p.v, MAX_V
        )));
    }

    // `bt` has not been pinned to 64 yet, so widen to u64 *before* multiplying;
    // `bv * bt` in u32 would overflow for very large `bt`.
    let bv = DEFAULT_BV;
    let shared_bytes =
        u64::from(bv) * u64::from(p.k) * 4 + u64::from(bv) * u64::from(p.bt) * 2;
    const M5_MAX_TG_MEM_BYTES: u64 = 32 * 1024;
    if shared_bytes > M5_MAX_TG_MEM_BYTES {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk: threadgroup memory {} bytes exceeds M5 Max \
             cap of {} bytes (bt={}, bv={}, k={})",
            shared_bytes, M5_MAX_TG_MEM_BYTES, p.bt, bv, p.k
        )));
    }
    if p.t % p.bt != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk (iter 1): t ({}) must be a multiple of bt ({})",
            p.t, p.bt
        )));
    }
    if p.bt != 64 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk (iter 1): bt must be 64 (got {})",
            p.bt
        )));
    }
    if p.v % DEFAULT_BV != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk (iter 1): v ({}) must be a multiple of BV ({})",
            p.v, DEFAULT_BV
        )));
    }

    // Expected element counts, computed with checked `usize` arithmetic.
    let elems = |name: &str, dims: &[u32]| -> Result<usize> {
        checked_elem_count(dims).ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "gated_delta_net_chunk: {} element count overflows usize \
                 (b={}, t={}, hg={}, h={}, k={}, v={})",
                name, p.b, p.t, p.hg, p.h, p.k, p.v
            ))
        })
    };
    let k_elems = elems("k", &[p.b, p.t, p.hg, p.k])?;
    let w_elems = elems("w", &[p.b, p.t, p.h, p.k])?;
    let u_elems = elems("u", &[p.b, p.t, p.h, p.v])?;
    let g_elems = elems("g", &[p.b, p.t, p.h])?;
    let h0_elems = elems("h0", &[p.b, p.h, p.v, p.k])?;
    let h_out_elems = elems("h_out", &[p.b, p.num_chunks(), p.h, p.v, p.k])?;
    // `v_new` mirrors `u`; `final_state` mirrors `h0`.
    let v_new_elems = u_elems;
    let final_elems = h0_elems;

    // bf16 tensors whose size and dtype are checked together.
    let bf16_inputs: [(&str, &MlxBuffer, usize); 4] = [
        ("k", k, k_elems),
        ("w", w, w_elems),
        ("u", u, u_elems),
        ("v_new", v_new, v_new_elems),
    ];
    for (name, buf, exp) in bf16_inputs {
        if buf.element_count() != exp {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net_chunk: {} element count {} != expected {}",
                name,
                buf.element_count(),
                exp
            )));
        }
        if buf.dtype() != DType::BF16 {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net_chunk: {} must be bf16 (got {})",
                name,
                buf.dtype()
            )));
        }
    }

    // Size-only pass. `h_out` rides along here for its element count even
    // though it is bf16; its dtype is checked separately just below.
    let size_checked: [(&str, &MlxBuffer, usize); 4] = [
        ("g", g, g_elems),
        ("h0", h0, h0_elems),
        ("final_state", final_state, final_elems),
        ("h_out_check", h_out, h_out_elems),
    ];
    for (name, buf, exp) in size_checked {
        if buf.element_count() != exp {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net_chunk: {} element count {} != expected {}",
                name,
                buf.element_count(),
                exp
            )));
        }
    }
    if h_out.dtype() != DType::BF16 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_chunk: h_out must be bf16 (got {})",
            h_out.dtype()
        )));
    }
    for (name, buf) in [("g", g), ("h0", h0), ("final_state", final_state)] {
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "gated_delta_net_chunk: {} must be f32 (got {})",
                name,
                buf.dtype()
            )));
        }
    }
    Ok(())
}
/// Validates the arguments and encodes one launch of the
/// `gated_delta_net_chunk_inter_state_bf16` kernel onto `encoder`.
///
/// Launch shape:
/// * grid: `(v / DEFAULT_BV, h, b)` threadgroups;
/// * threadgroup: 128 threads along x;
/// * threadgroup memory slot 0: `DEFAULT_BV * k * 4 + DEFAULT_BV * bt * 2` bytes.
///
/// Buffers are bound at indices 0 through 8 in the order `k`, `w`, `u`, `g`,
/// `h0`, `h_out`, `v_new`, `final_state`, `params_buf`. `params_buf` is not
/// inspected here — presumably it comes from
/// `build_gated_delta_net_chunk_params` with the same `p`; confirm at the
/// call site.
///
/// # Errors
///
/// Propagates any error from `validate` (nothing is encoded in that case) or
/// from the registry's pipeline lookup.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gated_delta_net_chunk_inter_state(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    k: &MlxBuffer,
    w: &MlxBuffer,
    u: &MlxBuffer,
    g: &MlxBuffer,
    h0: &MlxBuffer,
    h_out: &MlxBuffer,
    v_new: &MlxBuffer,
    final_state: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: GatedDeltaNetChunkParams,
) -> Result<()> {
    // Reject bad shapes/dtypes before touching the pipeline cache.
    validate(&p, k, w, u, g, h0, h_out, v_new, final_state)?;

    let pipeline = registry.get_pipeline("gated_delta_net_chunk_inter_state_bf16", device)?;

    // Threadgroup memory: two regions, sized at 4 and 2 bytes per element
    // respectively (presumably an f32 state tile plus a bf16 staging tile).
    let state_tile_bytes: u64 = (DEFAULT_BV * p.k) as u64 * 4;
    let staging_tile_bytes: u64 = (DEFAULT_BV * p.bt) as u64 * 2;
    let threadgroup_mem_bytes = state_tile_bytes + staging_tile_bytes;

    // One threadgroup per (V tile, head, leading-dim index).
    let v_tile_count = (p.v / DEFAULT_BV) as u64;
    let threadgroups = MTLSize::new(v_tile_count, p.h as u64, p.b as u64);
    let threads_per_threadgroup = MTLSize::new(128, 1, 1);

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, k),
            (1, w),
            (2, u),
            (3, g),
            (4, h0),
            (5, h_out),
            (6, v_new),
            (7, final_state),
            (8, params_buf),
        ],
        &[(0, threadgroup_mem_bytes)],
        threadgroups,
        threads_per_threadgroup,
    );
    Ok(())
}
pub fn build_gated_delta_net_chunk_params(
device: &crate::MlxDevice,
p: GatedDeltaNetChunkParams,
) -> Result<MlxBuffer> {
let mut buf = device.alloc_buffer(8 * 4, DType::U32, vec![8])?;
{
let s = buf.as_mut_slice::<u32>()?;
s[0] = p.b;
s[1] = p.t;
s[2] = p.hg;
s[3] = p.h;
s[4] = p.k;
s[5] = p.v;
s[6] = p.bt;
s[7] = p.num_chunks();
}
Ok(buf)
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::MlxDevice;

    /// Allocates a minimal placeholder buffer of the requested dtype. The
    /// test below fails on the `k` check, which runs before any buffer is
    /// inspected, so the contents and size are irrelevant.
    fn dummy_buf(device: &MlxDevice, dtype: DType) -> MlxBuffer {
        let allocation = device.alloc_buffer(2, dtype, vec![1]);
        allocation.expect("alloc dummy")
    }

    /// `k = 256` must be rejected with a message that names the offending
    /// value, the required `MAX_K`, and the exact-equality requirement.
    #[test]
    fn validate_rejects_k_above_max() {
        let device = MlxDevice::new().expect("MlxDevice::new");
        let bf16 = || dummy_buf(&device, DType::BF16);
        let f32_buf = || dummy_buf(&device, DType::F32);

        // Allocation order: k, w, u, g, h0, h_out, v_new, final_state.
        let (k_buf, w_buf, u_buf) = (bf16(), bf16(), bf16());
        let (g_buf, h0_buf) = (f32_buf(), f32_buf());
        let (h_out_buf, v_new_buf) = (bf16(), bf16());
        let final_state_buf = f32_buf();

        let params = GatedDeltaNetChunkParams {
            b: 1,
            t: 128,
            hg: 2,
            h: 4,
            k: 256,
            v: 128,
            bt: 64,
        };

        let outcome = validate(
            &params,
            &k_buf,
            &w_buf,
            &u_buf,
            &g_buf,
            &h0_buf,
            &h_out_buf,
            &v_new_buf,
            &final_state_buf,
        );
        let err = outcome.expect_err("validate must reject K=256");
        let msg = err.to_string();
        let mentions = |needle: &str| msg.contains(needle);

        assert!(
            mentions("256"),
            "expected K=256 in error message, got: {msg}"
        );
        assert!(
            mentions("MAX_K = 128") || mentions("MAX_K=128"),
            "expected explicit MAX_K=128 in error, got: {msg}"
        );
        assert!(
            mentions("must equal") || mentions("hard-coded"),
            "expected exact-equality wording in error (post iter-2.5 narrow), got: {msg}"
        );
    }
}