use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::MlxDevice;
use crate::ops::chunk_gated_delta_rule_tri_solve_invert::{
build_chunk_tri_solve_invert_params, dispatch_chunk_tri_solve_invert,
ChunkTriSolveInvertParams,
};
use crate::ops::gated_delta_net_chunk::{
build_gated_delta_net_chunk_params, dispatch_gated_delta_net_chunk_inter_state,
GatedDeltaNetChunkParams,
};
use crate::ops::gated_delta_net_chunk_o::{
build_gated_delta_net_chunk_o_params, dispatch_gated_delta_net_chunk_o,
GatedDeltaNetChunkOParams,
};
use crate::ops::gated_delta_net_kkt::{
build_gated_delta_net_kkt_params, dispatch_gated_delta_net_kkt,
GatedDeltaNetKktParams,
};
use crate::ops::gated_delta_net_recompute_wu::{
build_gated_delta_net_recompute_wu_params, dispatch_gated_delta_net_recompute_wu,
GatedDeltaNetRecomputeWuParams,
};
use crate::ops::l2_norm::dispatch_l2_norm;
/// Metal source text for the chunk-local cumsum kernel, embedded at compile
/// time from `../shaders/chunk_local_cumsum_g.metal`. `register` files it
/// under the kernel name `chunk_local_cumsum_g_f32`.
pub static CHUNK_LOCAL_CUMSUM_G_SHADER_SOURCE: &str =
include_str!("../shaders/chunk_local_cumsum_g.metal");
/// The only key dimension `k` accepted: both `validate` and
/// `ChunkInternalArena::new` reject any `k` that is not exactly this value.
pub const MAX_K: u32 = 128;
/// Largest value dimension `v` accepted by `validate` and
/// `ChunkInternalArena::new`.
pub const MAX_V: u32 = 256;
/// The only chunk size `bt` accepted by `validate` and
/// `ChunkInternalArena::new`.
pub const FIXED_BT: u32 = 64;
/// Epsilon written into the L2-norm parameter buffer when
/// `use_qk_l2norm` is set on the non-arena dispatch.
pub const L2_NORM_EPS: f32 = 1.0e-6;
/// Registers the chunk-local cumsum shader source with `registry` under the
/// kernel name `chunk_local_cumsum_g_f32`.
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAME: &str = "chunk_local_cumsum_g_f32";
    registry.register_source(KERNEL_NAME, CHUNK_LOCAL_CUMSUM_G_SHADER_SOURCE);
}
/// Shape and option bundle for the chunked gated-delta-rule forward pass.
///
/// The dispatch entry points validate these values before encoding anything;
/// the struct itself enforces nothing.
#[derive(Debug, Clone, Copy)]
pub struct ChunkGatedDeltaRuleParams {
    /// Leading factor of every buffer's expected element count
    /// (presumably the batch size — confirm against callers).
    pub b: u32,
    /// Length of the axis that is split into chunks of `bt`
    /// (presumably sequence length in tokens — confirm).
    pub t: u32,
    /// Head count used for the `q` / `k` element counts (`b * t * hg * k`).
    pub hg: u32,
    /// Head count used for the `v`, `g_log_decay`, `beta`, `o` and state
    /// element counts. Validation requires it to be a multiple of `hg`.
    pub h: u32,
    /// Per-head key dimension. Validation requires it to equal `MAX_K`.
    pub k: u32,
    /// Per-head value dimension. Validation requires it to be at most `MAX_V`.
    pub v: u32,
    /// Chunk size. Validation requires it to equal `FIXED_BT` and to divide `t`.
    pub bt: u32,
    /// Scale factor forwarded to the final `chunk_o` stage; must be finite.
    pub scale: f32,
    /// When `true`, the non-arena dispatch L2-normalises `q` and `k` into
    /// scratch buffers first. The arena dispatch rejects `true`.
    pub use_qk_l2norm: bool,
}

impl ChunkGatedDeltaRuleParams {
    /// Number of `bt`-sized chunks needed to cover `t`, rounding up.
    ///
    /// # Panics
    /// Panics if `bt` is zero.
    pub fn num_chunks(&self) -> u32 {
        let Self { t, bt, .. } = *self;
        t.div_ceil(bt)
    }
}
#[allow(clippy::too_many_arguments)]
fn validate(
p: &ChunkGatedDeltaRuleParams,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
g_log_decay: &MlxBuffer,
beta: &MlxBuffer,
h0: &MlxBuffer,
o: &MlxBuffer,
final_state: &MlxBuffer,
) -> Result<()> {
if p.b == 0 || p.t == 0 || p.hg == 0 || p.h == 0 || p.k == 0 || p.v == 0 || p.bt == 0 {
return Err(MlxError::InvalidArgument(
"chunk_gated_delta_rule_fwd: all dims must be > 0".into(),
));
}
if p.h % p.hg != 0 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: H ({}) must be a multiple of Hg ({})",
p.h, p.hg
)));
}
if p.k != MAX_K {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: K ({}) must equal MAX_K = {} exactly. \
Sub-kernels inter_state and chunk_o have compile-time-fixed 16 \
K-tiles in their simdgroup_matrix MMA loops; runtime K bounds \
defeat MMA scheduling (3.15× regression measured). To support \
other K values, port FLA's b_h1..b_h4 bank-split.",
p.k, MAX_K
)));
}
if p.v > MAX_V {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: V ({}) exceeds chunk-pipeline cap \
(MAX_V = {})",
p.v, MAX_V
)));
}
if p.bt != FIXED_BT {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd (iter 4): bt must be {} (got {})",
FIXED_BT, p.bt
)));
}
if p.t % p.bt != 0 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd (iter 4): t ({}) must be a multiple of bt ({})",
p.t, p.bt
)));
}
if !p.scale.is_finite() {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: scale must be finite (got {})",
p.scale
)));
}
let q_elems = (p.b * p.t * p.hg * p.k) as usize;
let k_elems = (p.b * p.t * p.hg * p.k) as usize;
let v_elems = (p.b * p.t * p.h * p.v) as usize;
let g_elems = (p.b * p.t * p.h) as usize;
let beta_elems = (p.b * p.t * p.h) as usize;
let h0_elems = (p.b * p.h * p.v * p.k) as usize;
let o_elems = (p.b * p.t * p.h * p.v) as usize;
let final_state_elems = (p.b * p.h * p.v * p.k) as usize;
let bf16_inputs: [(&str, &MlxBuffer, usize); 4] = [
("q", q, q_elems),
("k", k, k_elems),
("v", v, v_elems),
("o", o, o_elems),
];
for (name, buf, exp) in bf16_inputs {
if buf.element_count() != exp {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} element count {} != expected {}",
name,
buf.element_count(),
exp
)));
}
if buf.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} must be bf16 (got {})",
name,
buf.dtype()
)));
}
}
let f32_inputs: [(&str, &MlxBuffer, usize); 4] = [
("g_log_decay", g_log_decay, g_elems),
("beta", beta, beta_elems),
("h0", h0, h0_elems),
("final_state", final_state, final_state_elems),
];
for (name, buf, exp) in f32_inputs {
if buf.element_count() != exp {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} element count {} != expected {}",
name,
buf.element_count(),
exp
)));
}
if buf.dtype() != DType::F32 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} must be f32 (got {})",
name,
buf.dtype()
)));
}
}
Ok(())
}
/// Allocates the five-word `u32` parameter buffer for the chunk-local cumsum
/// kernel and fills it with `[b, t, h, bt, num_chunks]`, in that order.
///
/// # Errors
/// Propagates any error from buffer allocation or from obtaining the
/// mutable `u32` view.
fn build_chunk_local_cumsum_g_params(
    device: &MlxDevice,
    p: &ChunkGatedDeltaRuleParams,
) -> Result<MlxBuffer> {
    const WORDS: usize = 5;
    let mut params = device.alloc_buffer(WORDS * 4, DType::U32, vec![WORDS])?;
    let slots = params.as_mut_slice::<u32>()?;
    let fields: [u32; WORDS] = [p.b, p.t, p.h, p.bt, p.num_chunks()];
    for (idx, value) in fields.into_iter().enumerate() {
        slots[idx] = value;
    }
    Ok(params)
}
/// Allocates the two-element `f32` parameter buffer for the L2-norm kernel
/// and fills it with `[eps, dim]`.
///
/// `dim` is stored as a floating-point value (numeric conversion, not a bit
/// cast).
///
/// # Errors
/// Propagates any error from buffer allocation or from obtaining the
/// mutable `f32` view.
fn build_l2_norm_params(device: &MlxDevice, eps: f32, dim: u32) -> Result<MlxBuffer> {
    let fields = [eps, dim as f32];
    let mut params = device.alloc_buffer(fields.len() * 4, DType::F32, vec![fields.len()])?;
    let slots = params.as_mut_slice::<f32>()?;
    for (idx, value) in fields.into_iter().enumerate() {
        slots[idx] = value;
    }
    Ok(params)
}
/// Encodes one dispatch of the `chunk_local_cumsum_g_f32` kernel.
///
/// Buffer bindings: index 0 = `g_in`, index 1 = `g_out`, index 2 =
/// `params_buf`. The grid is `(1, h, b * num_chunks)` threadgroups, each
/// `(bt, 1, 1)` threads.
///
/// # Errors
/// Propagates a failure to obtain the pipeline from `registry`.
fn dispatch_chunk_local_cumsum_g(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    g_in: &MlxBuffer,
    g_out: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: &ChunkGatedDeltaRuleParams,
) -> Result<()> {
    let pipeline = registry.get_pipeline("chunk_local_cumsum_g_f32", device)?;

    // One threadgroup per (head, batch x chunk) pair.
    let chunk_count = u64::from(p.num_chunks());
    let threadgroups = MTLSize::new(1, u64::from(p.h), u64::from(p.b) * chunk_count);
    let threads_per_group = MTLSize::new(u64::from(p.bt), 1, 1);

    encoder.encode_threadgroups(
        pipeline,
        &[(0, g_in), (1, g_out), (2, params_buf)],
        threadgroups,
        threads_per_group,
    );
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_chunk_gated_delta_rule_fwd(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
g_log_decay: &MlxBuffer,
beta: &MlxBuffer,
h0: &MlxBuffer,
o: &MlxBuffer,
final_state: &MlxBuffer,
p: ChunkGatedDeltaRuleParams,
) -> Result<()> {
validate(&p, q, k, v, g_log_decay, beta, h0, o, final_state)?;
let metal_device = device.metal_device();
let nt = p.num_chunks();
let q_qk_elems = (p.b * p.t * p.hg * p.k) as usize;
let q_normed_buf;
let k_normed_buf;
let q_for_pipeline: &MlxBuffer;
let k_for_pipeline: &MlxBuffer;
if p.use_qk_l2norm {
q_normed_buf =
device.alloc_buffer(q_qk_elems * 2, DType::BF16, vec![q_qk_elems])?;
k_normed_buf =
device.alloc_buffer(q_qk_elems * 2, DType::BF16, vec![q_qk_elems])?;
let l2_params = build_l2_norm_params(device, L2_NORM_EPS, p.k)?;
let rows = p.b * p.t * p.hg;
dispatch_l2_norm(
encoder, registry, metal_device, q, &q_normed_buf, &l2_params, rows, p.k,
)?;
encoder.memory_barrier();
dispatch_l2_norm(
encoder, registry, metal_device, k, &k_normed_buf, &l2_params, rows, p.k,
)?;
encoder.memory_barrier();
q_for_pipeline = &q_normed_buf;
k_for_pipeline = &k_normed_buf;
} else {
q_for_pipeline = q;
k_for_pipeline = k;
}
let g_elems = (p.b * p.t * p.h) as usize;
let g_cumsum_buf =
device.alloc_buffer(g_elems * 4, DType::F32, vec![g_elems])?;
let cumsum_params = build_chunk_local_cumsum_g_params(device, &p)?;
dispatch_chunk_local_cumsum_g(
encoder,
registry,
metal_device,
g_log_decay,
&g_cumsum_buf,
&cumsum_params,
&p,
)?;
encoder.memory_barrier();
let a_elems = (p.b * p.t * p.h * p.bt) as usize;
let a_strict_buf =
device.alloc_buffer(a_elems * 4, DType::F32, vec![a_elems])?;
let kkt_params_value = GatedDeltaNetKktParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
bt: p.bt,
};
let kkt_params = build_gated_delta_net_kkt_params(device, kkt_params_value)?;
dispatch_gated_delta_net_kkt(
encoder,
registry,
metal_device,
k_for_pipeline,
beta,
&g_cumsum_buf,
&a_strict_buf,
&kkt_params,
kkt_params_value,
)?;
encoder.memory_barrier();
let a_inv_buf =
device.alloc_buffer(a_elems * 4, DType::F32, vec![a_elems])?;
let invert_params_value = ChunkTriSolveInvertParams {
b: p.b,
t: p.t,
h: p.h,
bt: p.bt,
};
let invert_params = build_chunk_tri_solve_invert_params(device, invert_params_value)?;
dispatch_chunk_tri_solve_invert(
encoder,
registry,
metal_device,
&a_strict_buf,
&a_inv_buf,
&invert_params,
invert_params_value,
)?;
encoder.memory_barrier();
let w_elems = (p.b * p.t * p.h * p.k) as usize;
let u_elems = (p.b * p.t * p.h * p.v) as usize;
let w_buf = device.alloc_buffer(w_elems * 2, DType::BF16, vec![w_elems])?;
let u_buf = device.alloc_buffer(u_elems * 2, DType::BF16, vec![u_elems])?;
let recompute_wu_params_value = GatedDeltaNetRecomputeWuParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
v: p.v,
bt: p.bt,
};
let recompute_wu_params =
build_gated_delta_net_recompute_wu_params(device, recompute_wu_params_value)?;
dispatch_gated_delta_net_recompute_wu(
encoder,
registry,
metal_device,
k_for_pipeline,
v,
beta,
&g_cumsum_buf,
&a_inv_buf,
&w_buf,
&u_buf,
&recompute_wu_params,
recompute_wu_params_value,
)?;
encoder.memory_barrier();
let h_elems = (p.b * nt * p.h * p.v * p.k) as usize;
let v_new_elems = (p.b * p.t * p.h * p.v) as usize;
let h_buf = device.alloc_buffer(h_elems * 2, DType::BF16, vec![h_elems])?;
let v_new_buf =
device.alloc_buffer(v_new_elems * 2, DType::BF16, vec![v_new_elems])?;
let chunk_params_value = GatedDeltaNetChunkParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
v: p.v,
bt: p.bt,
};
let chunk_params = build_gated_delta_net_chunk_params(device, chunk_params_value)?;
dispatch_gated_delta_net_chunk_inter_state(
encoder,
registry,
metal_device,
k_for_pipeline,
&w_buf,
&u_buf,
&g_cumsum_buf,
h0,
&h_buf,
&v_new_buf,
final_state,
&chunk_params,
chunk_params_value,
)?;
encoder.memory_barrier();
let chunk_o_params_value = GatedDeltaNetChunkOParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
v: p.v,
bt: p.bt,
scale: p.scale,
};
let chunk_o_params =
build_gated_delta_net_chunk_o_params(device, chunk_o_params_value)?;
dispatch_gated_delta_net_chunk_o(
encoder,
registry,
metal_device,
q_for_pipeline,
k_for_pipeline,
&v_new_buf,
&h_buf,
&g_cumsum_buf,
o,
&chunk_o_params,
chunk_o_params_value,
)?;
Ok(())
}
/// Pre-allocated intermediate and parameter buffers for
/// `dispatch_chunk_gated_delta_rule_fwd_with_arena`, sized for one fixed
/// shape `(b, t, hg, h, k, v, bt)` chosen in `ChunkInternalArena::new`.
pub struct ChunkInternalArena {
/// F32, `b * t * h` elements. Bound as the output of the chunk-local cumsum
/// kernel and passed to every later stage.
pub g_cumsum_buf: MlxBuffer,
/// F32, `b * t * h * bt` elements. Passed to the kkt dispatch and then to
/// the tri-solve-invert dispatch.
pub a_strict_buf: MlxBuffer,
/// F32, `b * t * h * bt` elements. Passed to the tri-solve-invert dispatch
/// and then to the recompute-wu dispatch.
pub a_inv_buf: MlxBuffer,
/// BF16, `b * t * h * k` elements. Passed to the recompute-wu dispatch and
/// then to the inter-state dispatch.
pub w_buf: MlxBuffer,
/// BF16, `b * t * h * v` elements. Passed to the recompute-wu dispatch and
/// then to the inter-state dispatch.
pub u_buf: MlxBuffer,
/// BF16, `b * nt * h * v * k` elements with `nt = t.div_ceil(bt)`. Passed to
/// the inter-state dispatch and then to the chunk-o dispatch.
pub h_buf: MlxBuffer,
/// BF16, `b * t * h * v` elements. Passed to the inter-state dispatch and
/// then to the chunk-o dispatch.
pub v_new_buf: MlxBuffer,
/// Five `u32` words `[b, t, h, bt, nt]` for the chunk-local cumsum kernel.
pub cumsum_params_buf: MlxBuffer,
/// Parameter buffer built by `build_gated_delta_net_kkt_params`.
pub kkt_params_buf: MlxBuffer,
/// Parameter buffer built by `build_chunk_tri_solve_invert_params`.
pub invert_params_buf: MlxBuffer,
/// Parameter buffer built by `build_gated_delta_net_recompute_wu_params`.
pub recompute_wu_params_buf: MlxBuffer,
/// Parameter buffer built by `build_gated_delta_net_chunk_params`.
pub chunk_params_buf: MlxBuffer,
/// Parameter buffer built by `build_gated_delta_net_chunk_o_params` with a
/// placeholder scale of 1.0; the arena dispatch rewrites its first eleven
/// words (including the scale) on every call.
pub chunk_o_params_buf: MlxBuffer,
/// The `b` the arena was built for. `validate_fits` requires an exact match
/// on this and every other `*_capacity` field.
pub b_capacity: u32,
/// The `t` the arena was built for.
pub t_capacity: u32,
/// The `hg` the arena was built for.
pub hg_capacity: u32,
/// The `h` the arena was built for.
pub h_capacity: u32,
/// The `k` the arena was built for.
pub k_capacity: u32,
/// The `v` the arena was built for.
pub v_capacity: u32,
/// The `bt` the arena was built for.
pub bt_capacity: u32,
}
impl ChunkInternalArena {
    /// Allocates every intermediate and parameter buffer needed by the arena
    /// dispatch for the shape `(b, t, hg, h, k, v, bt)`.
    ///
    /// The shape is checked against the same constraints the dispatch
    /// enforces: all dimensions non-zero, `h % hg == 0`, `k == MAX_K`,
    /// `v <= MAX_V`, `bt == FIXED_BT`, `t % bt == 0`. Buffer sizes are
    /// computed with overflow-checked `usize` arithmetic so an oversized
    /// shape is rejected instead of wrapping to a too-small allocation.
    ///
    /// `chunk_o_params_buf` is built with a placeholder scale of 1.0; the
    /// arena dispatch rewrites it on each call.
    ///
    /// # Errors
    /// Returns `MlxError::InvalidArgument` for an invalid or overflowing
    /// shape, and propagates allocation / parameter-builder errors.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        device: &MlxDevice,
        b: u32,
        t: u32,
        hg: u32,
        h: u32,
        k: u32,
        v: u32,
        bt: u32,
    ) -> Result<Self> {
        const BF16_BYTES: usize = 2;
        const F32_BYTES: usize = 4;
        const CUMSUM_PARAM_WORDS: usize = 5;

        /// Returns `(element_count, byte_len)` for the given dimensions and
        /// element width, or an error if either overflows `usize`.
        fn sized(what: &str, dims: &[u32], elem_bytes: usize) -> Result<(usize, usize)> {
            let elems = dims
                .iter()
                .try_fold(1usize, |acc, &d| acc.checked_mul(usize::try_from(d).ok()?));
            let sizes = elems.and_then(|n| n.checked_mul(elem_bytes).map(|bytes| (n, bytes)));
            sizes.ok_or_else(|| {
                MlxError::InvalidArgument(format!(
                    "ChunkInternalArena::new: size of {} overflows usize (dims {:?})",
                    what, dims
                ))
            })
        }

        if b == 0 || t == 0 || hg == 0 || h == 0 || k == 0 || v == 0 || bt == 0 {
            return Err(MlxError::InvalidArgument(format!(
                "ChunkInternalArena::new: zero dim b={} t={} hg={} h={} k={} v={} bt={}",
                b, t, hg, h, k, v, bt,
            )));
        }
        if h % hg != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "ChunkInternalArena::new: H ({}) must be a multiple of Hg ({})",
                h, hg
            )));
        }
        if k != MAX_K {
            return Err(MlxError::InvalidArgument(format!(
                "ChunkInternalArena::new: K ({}) must equal MAX_K = {} \
                 (sub-kernel hard constraint)",
                k, MAX_K
            )));
        }
        if v > MAX_V {
            return Err(MlxError::InvalidArgument(format!(
                "ChunkInternalArena::new: V ({}) > MAX_V ({})",
                v, MAX_V
            )));
        }
        if bt != FIXED_BT {
            return Err(MlxError::InvalidArgument(format!(
                "ChunkInternalArena::new: bt ({}) must equal FIXED_BT ({})",
                bt, FIXED_BT
            )));
        }
        if t % bt != 0 {
            return Err(MlxError::InvalidArgument(format!(
                "ChunkInternalArena::new: t ({}) must be a multiple of bt ({})",
                t, bt
            )));
        }

        let nt = t.div_ceil(bt);

        // Resolve all sizes before allocating anything.
        let (g_elems, g_bytes) = sized("g_cumsum_buf", &[b, t, h], F32_BYTES)?;
        let (a_elems, a_bytes) = sized("a_strict_buf / a_inv_buf", &[b, t, h, bt], F32_BYTES)?;
        let (w_elems, w_bytes) = sized("w_buf", &[b, t, h, k], BF16_BYTES)?;
        let (u_elems, u_bytes) = sized("u_buf / v_new_buf", &[b, t, h, v], BF16_BYTES)?;
        let (h_elems, h_bytes) = sized("h_buf", &[b, nt, h, v, k], BF16_BYTES)?;

        let g_cumsum_buf = device.alloc_buffer(g_bytes, DType::F32, vec![g_elems])?;
        let a_strict_buf = device.alloc_buffer(a_bytes, DType::F32, vec![a_elems])?;
        let a_inv_buf = device.alloc_buffer(a_bytes, DType::F32, vec![a_elems])?;
        let w_buf = device.alloc_buffer(w_bytes, DType::BF16, vec![w_elems])?;
        let u_buf = device.alloc_buffer(u_bytes, DType::BF16, vec![u_elems])?;
        let h_buf = device.alloc_buffer(h_bytes, DType::BF16, vec![h_elems])?;
        // `v_new` has the same shape as `u`.
        let v_new_buf = device.alloc_buffer(u_bytes, DType::BF16, vec![u_elems])?;

        // Cumsum kernel parameters: [b, t, h, bt, nt].
        let mut cumsum_params_buf = device.alloc_buffer(
            CUMSUM_PARAM_WORDS * 4,
            DType::U32,
            vec![CUMSUM_PARAM_WORDS],
        )?;
        {
            let slots = cumsum_params_buf.as_mut_slice::<u32>()?;
            let fields: [u32; CUMSUM_PARAM_WORDS] = [b, t, h, bt, nt];
            for (idx, value) in fields.into_iter().enumerate() {
                slots[idx] = value;
            }
        }

        let kkt_params_buf = build_gated_delta_net_kkt_params(
            device,
            GatedDeltaNetKktParams { b, t, hg, h, k, bt },
        )?;
        let invert_params_buf = build_chunk_tri_solve_invert_params(
            device,
            ChunkTriSolveInvertParams { b, t, h, bt },
        )?;
        let recompute_wu_params_buf = build_gated_delta_net_recompute_wu_params(
            device,
            GatedDeltaNetRecomputeWuParams { b, t, hg, h, k, v, bt },
        )?;
        let chunk_params_buf = build_gated_delta_net_chunk_params(
            device,
            GatedDeltaNetChunkParams { b, t, hg, h, k, v, bt },
        )?;
        // Scale is a placeholder here; the arena dispatch overwrites it.
        let chunk_o_params_buf = build_gated_delta_net_chunk_o_params(
            device,
            GatedDeltaNetChunkOParams {
                b,
                t,
                hg,
                h,
                k,
                v,
                bt,
                scale: 1.0,
            },
        )?;

        Ok(Self {
            g_cumsum_buf,
            a_strict_buf,
            a_inv_buf,
            w_buf,
            u_buf,
            h_buf,
            v_new_buf,
            cumsum_params_buf,
            kkt_params_buf,
            invert_params_buf,
            recompute_wu_params_buf,
            chunk_params_buf,
            chunk_o_params_buf,
            b_capacity: b,
            t_capacity: t,
            hg_capacity: hg,
            h_capacity: h,
            k_capacity: k,
            v_capacity: v,
            bt_capacity: bt,
        })
    }

    /// Confirms that a call's shape is exactly the shape this arena was
    /// built for.
    ///
    /// Despite the "capacity" naming this is an equality check on every
    /// dimension, not a "fits within" check: the parameter buffers baked in
    /// `new` encode the original shape.
    ///
    /// # Errors
    /// Returns `MlxError::InvalidArgument` listing both shapes on any mismatch.
    #[allow(clippy::too_many_arguments)]
    pub fn validate_fits(
        &self,
        b: u32,
        t: u32,
        hg: u32,
        h: u32,
        k: u32,
        v: u32,
        bt: u32,
    ) -> Result<()> {
        let built_for = (
            self.b_capacity,
            self.t_capacity,
            self.hg_capacity,
            self.h_capacity,
            self.k_capacity,
            self.v_capacity,
            self.bt_capacity,
        );
        let requested = (b, t, hg, h, k, v, bt);
        if requested == built_for {
            return Ok(());
        }
        Err(MlxError::InvalidArgument(format!(
            "ChunkInternalArena::validate_fits: shape mismatch — \
             capacity (b={}, t={}, hg={}, h={}, k={}, v={}, bt={}) \
             vs call (b={}, t={}, hg={}, h={}, k={}, v={}, bt={})",
            self.b_capacity,
            self.t_capacity,
            self.hg_capacity,
            self.h_capacity,
            self.k_capacity,
            self.v_capacity,
            self.bt_capacity,
            b,
            t,
            hg,
            h,
            k,
            v,
            bt,
        )))
    }
}
/// Encodes the chunked gated-delta-rule forward pipeline using the
/// pre-allocated buffers in `arena` instead of allocating per call.
///
/// Stages are the same as the non-arena dispatch minus the optional L2-norm:
/// chunk-local cumsum, kkt, tri-solve-invert, recompute-wu, inter-state and
/// chunk-o, with `encoder.memory_barrier()` between consecutive stages (none
/// after the last).
///
/// `p` must match the arena's shape exactly (`ChunkInternalArena::validate_fits`)
/// and `p.use_qk_l2norm` must be `false`. The first eleven `u32` words of
/// `arena.chunk_o_params_buf` are rewritten on the CPU during this call so
/// the current `p.scale` is used.
///
/// NOTE(review): because that write happens at encode time, encoding two
/// calls with different scales against the same arena before the first has
/// executed would presumably make both see the later scale — confirm how
/// callers sequence arena reuse.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` if `use_qk_l2norm` is set, if
/// `validate` or `validate_fits` rejects the inputs, or if
/// `chunk_o_params_buf` is too small for the eleven-word layout; propagates
/// sub-dispatch errors.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_chunk_gated_delta_rule_fwd_with_arena(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    g_log_decay: &MlxBuffer,
    beta: &MlxBuffer,
    h0: &MlxBuffer,
    o: &MlxBuffer,
    final_state: &MlxBuffer,
    arena: &mut ChunkInternalArena,
    p: ChunkGatedDeltaRuleParams,
) -> Result<()> {
    if p.use_qk_l2norm {
        return Err(MlxError::InvalidArgument(
            "dispatch_chunk_gated_delta_rule_fwd_with_arena: use_qk_l2norm=true is \
             reserved for the non-arena variant. Pre-apply l2-norm in the wrapper \
             and pass use_qk_l2norm=false."
                .into(),
        ));
    }
    validate(&p, q, k, v, g_log_decay, beta, h0, o, final_state)?;
    arena.validate_fits(p.b, p.t, p.hg, p.h, p.k, p.v, p.bt)?;

    // The arena's fields are public, so the params buffer may have been
    // swapped for something smaller. Check before encoding anything, and as a
    // real error rather than a debug-only assertion.
    if arena.chunk_o_params_buf.byte_len() < 11 * 4 {
        return Err(MlxError::InvalidArgument(format!(
            "dispatch_chunk_gated_delta_rule_fwd_with_arena: chunk_o_params_buf too small: \
             {} bytes (need at least 44)",
            arena.chunk_o_params_buf.byte_len()
        )));
    }

    let metal_device = device.metal_device();

    // Stage: chunk-local cumsum of the gate values.
    dispatch_chunk_local_cumsum_g(
        encoder,
        registry,
        metal_device,
        g_log_decay,
        &arena.g_cumsum_buf,
        &arena.cumsum_params_buf,
        &p,
    )?;
    encoder.memory_barrier();

    // Stage: kkt.
    let kkt_params_value = GatedDeltaNetKktParams {
        b: p.b,
        t: p.t,
        hg: p.hg,
        h: p.h,
        k: p.k,
        bt: p.bt,
    };
    dispatch_gated_delta_net_kkt(
        encoder,
        registry,
        metal_device,
        k,
        beta,
        &arena.g_cumsum_buf,
        &arena.a_strict_buf,
        &arena.kkt_params_buf,
        kkt_params_value,
    )?;
    encoder.memory_barrier();

    // Stage: triangular solve / invert.
    let invert_params_value = ChunkTriSolveInvertParams {
        b: p.b,
        t: p.t,
        h: p.h,
        bt: p.bt,
    };
    dispatch_chunk_tri_solve_invert(
        encoder,
        registry,
        metal_device,
        &arena.a_strict_buf,
        &arena.a_inv_buf,
        &arena.invert_params_buf,
        invert_params_value,
    )?;
    encoder.memory_barrier();

    // Stage: recompute w / u.
    let recompute_wu_params_value = GatedDeltaNetRecomputeWuParams {
        b: p.b,
        t: p.t,
        hg: p.hg,
        h: p.h,
        k: p.k,
        v: p.v,
        bt: p.bt,
    };
    dispatch_gated_delta_net_recompute_wu(
        encoder,
        registry,
        metal_device,
        k,
        v,
        beta,
        &arena.g_cumsum_buf,
        &arena.a_inv_buf,
        &arena.w_buf,
        &arena.u_buf,
        &arena.recompute_wu_params_buf,
        recompute_wu_params_value,
    )?;
    encoder.memory_barrier();

    // Stage: inter-chunk state.
    let chunk_params_value = GatedDeltaNetChunkParams {
        b: p.b,
        t: p.t,
        hg: p.hg,
        h: p.h,
        k: p.k,
        v: p.v,
        bt: p.bt,
    };
    dispatch_gated_delta_net_chunk_inter_state(
        encoder,
        registry,
        metal_device,
        k,
        &arena.w_buf,
        &arena.u_buf,
        &arena.g_cumsum_buf,
        h0,
        &arena.h_buf,
        &arena.v_new_buf,
        final_state,
        &arena.chunk_params_buf,
        chunk_params_value,
    )?;
    encoder.memory_barrier();

    // Stage: chunk output. The arena's params buffer was built with a
    // placeholder scale, so refresh all eleven words for this call.
    let chunk_o_params_value = GatedDeltaNetChunkOParams {
        b: p.b,
        t: p.t,
        hg: p.hg,
        h: p.h,
        k: p.k,
        v: p.v,
        bt: p.bt,
        scale: p.scale,
    };
    // NOTE(review): this layout is duplicated by hand and presumably must
    // stay in sync with `build_gated_delta_net_chunk_o_params` — confirm.
    let chunk_o_words: [u32; 11] = [
        chunk_o_params_value.b,
        chunk_o_params_value.t,
        chunk_o_params_value.hg,
        chunk_o_params_value.h,
        chunk_o_params_value.k,
        chunk_o_params_value.v,
        chunk_o_params_value.bt,
        chunk_o_params_value.num_chunks(),
        crate::ops::gated_delta_net_chunk_o::DEFAULT_BK,
        crate::ops::gated_delta_net_chunk_o::DEFAULT_BV,
        // The scale travels as its raw f32 bit pattern inside the u32 buffer.
        chunk_o_params_value.scale.to_bits(),
    ];
    {
        let slots = arena.chunk_o_params_buf.as_mut_slice::<u32>()?;
        for (idx, word) in chunk_o_words.into_iter().enumerate() {
            slots[idx] = word;
        }
    }
    dispatch_gated_delta_net_chunk_o(
        encoder,
        registry,
        metal_device,
        q,
        k,
        &arena.v_new_buf,
        &arena.h_buf,
        &arena.g_cumsum_buf,
        o,
        &arena.chunk_o_params_buf,
        chunk_o_params_value,
    )?;
    Ok(())
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::MlxDevice;

    /// Reference arena shape as `[b, t, hg, h, k, v, bt]`.
    const APEX: [u32; 7] = [1, 4096, 32, 32, 128, 128, 64];

    /// Opens the default device or fails the test.
    fn gpu() -> MlxDevice {
        MlxDevice::new().expect("MlxDevice::new")
    }

    /// Tiny placeholder buffer of the requested dtype; its size never matches
    /// a real shape, so it is only useful for tests that fail earlier.
    fn dummy_buf(device: &MlxDevice, dtype: DType) -> MlxBuffer {
        device.alloc_buffer(2, dtype, vec![1]).expect("alloc dummy")
    }

    /// Calls `ChunkInternalArena::new` with dimensions packed as an array.
    fn arena_from(device: &MlxDevice, dims: [u32; 7]) -> Result<ChunkInternalArena> {
        let [b, t, hg, h, k, v, bt] = dims;
        ChunkInternalArena::new(device, b, t, hg, h, k, v, bt)
    }

    /// Returns the error text from a construction that must fail, panicking
    /// with `why` if it unexpectedly succeeds.
    fn rejection(device: &MlxDevice, dims: [u32; 7], why: &str) -> String {
        match arena_from(device, dims) {
            Err(e) => e.to_string(),
            Ok(_) => panic!("{why}"),
        }
    }

    /// Calls `validate_fits` with dimensions packed as an array.
    fn fits(arena: &ChunkInternalArena, dims: [u32; 7]) -> Result<()> {
        let [b, t, hg, h, k, v, bt] = dims;
        arena.validate_fits(b, t, hg, h, k, v, bt)
    }

    /// `APEX` with one axis replaced.
    fn apex_with(axis: usize, value: u32) -> [u32; 7] {
        let mut dims = APEX;
        dims[axis] = value;
        dims
    }

    #[test]
    fn validate_rejects_k_above_max() {
        let device = gpu();
        let q_buf = dummy_buf(&device, DType::BF16);
        let k_buf = dummy_buf(&device, DType::BF16);
        let v_buf = dummy_buf(&device, DType::BF16);
        let g_buf = dummy_buf(&device, DType::F32);
        let beta_buf = dummy_buf(&device, DType::F32);
        let h0_buf = dummy_buf(&device, DType::F32);
        let o_buf = dummy_buf(&device, DType::BF16);
        let final_state_buf = dummy_buf(&device, DType::F32);
        let p = ChunkGatedDeltaRuleParams {
            b: 1,
            t: 128,
            hg: 2,
            h: 4,
            k: 256,
            v: 128,
            bt: 64,
            scale: (128f32).powf(-0.5),
            use_qk_l2norm: true,
        };
        let msg = validate(
            &p,
            &q_buf,
            &k_buf,
            &v_buf,
            &g_buf,
            &beta_buf,
            &h0_buf,
            &o_buf,
            &final_state_buf,
        )
        .expect_err("validate must reject K=256")
        .to_string();
        assert!(
            msg.contains("256"),
            "expected K=256 in error message, got: {msg}"
        );
        assert!(
            msg.contains("MAX_K = 128") || msg.contains("MAX_K=128"),
            "expected explicit MAX_K=128 in error (orchestrator inherits sub-kernel \
             K==128-exact constraint per Wave 5b.2 iter 2.5), got: {msg}"
        );
        assert!(
            msg.contains("must equal") || msg.contains("hard-coded"),
            "expected exact-equality wording in error, got: {msg}"
        );
    }

    #[test]
    fn arena_new_apex_pp4096_shape_succeeds() {
        let device = gpu();
        let arena = arena_from(&device, APEX).expect("arena alloc apex pp4096");
        let recorded = [
            arena.b_capacity,
            arena.t_capacity,
            arena.hg_capacity,
            arena.h_capacity,
            arena.k_capacity,
            arena.v_capacity,
            arena.bt_capacity,
        ];
        assert_eq!(recorded, APEX);
        assert_eq!(arena.h_buf.byte_len(), 1 * 64 * 32 * 128 * 128 * 2);
    }

    #[test]
    fn arena_new_rejects_zero_dim() {
        let device = gpu();
        // Zero out each dimension in turn; every variant must be refused.
        for axis in 0..APEX.len() {
            assert!(arena_from(&device, apex_with(axis, 0)).is_err());
        }
    }

    #[test]
    fn arena_new_enforces_k_eq_max() {
        let device = gpu();
        let msg = rejection(&device, apex_with(4, 64), "arena must reject K!=128");
        assert!(msg.contains("MAX_K"), "got: {msg}");
    }

    #[test]
    fn arena_new_enforces_bt_eq_fixed() {
        let device = gpu();
        let msg = rejection(&device, apex_with(6, 32), "arena must reject bt!=FIXED_BT");
        assert!(msg.contains("FIXED_BT"), "got: {msg}");
    }

    #[test]
    fn arena_new_enforces_t_multiple_of_bt() {
        let device = gpu();
        let msg = rejection(&device, apex_with(1, 4123), "arena must reject t%bt!=0");
        assert!(msg.contains("multiple"), "got: {msg}");
    }

    #[test]
    fn arena_new_enforces_h_multiple_of_hg() {
        let device = gpu();
        let msg = rejection(&device, apex_with(2, 3), "arena must reject H%Hg!=0");
        assert!(msg.contains("multiple"), "got: {msg}");
    }

    #[test]
    fn arena_validate_fits_accepts_exact_match() {
        let device = gpu();
        let arena = arena_from(&device, APEX).expect("arena alloc");
        fits(&arena, APEX).expect("exact match accepts");
    }

    #[test]
    fn arena_validate_fits_rejects_drift() {
        let device = gpu();
        let arena = arena_from(&device, APEX).expect("arena alloc");
        // (axis, drifted value): one dimension changed per probe.
        let drifts: [(usize, u32); 7] = [
            (0, 2),
            (1, 2048),
            (2, 16),
            (3, 16),
            (4, 64),
            (5, 64),
            (6, 32),
        ];
        for (axis, value) in drifts {
            assert!(fits(&arena, apex_with(axis, value)).is_err());
        }
    }

    #[test]
    fn arena_dispatch_rejects_use_qk_l2norm_true() {
        let device = gpu();
        let mut arena = arena_from(&device, APEX).expect("arena alloc");
        let mut registry = crate::KernelRegistry::default();
        let mut enc = device.command_encoder().expect("encoder");
        let q = dummy_buf(&device, DType::BF16);
        let k = dummy_buf(&device, DType::BF16);
        let v = dummy_buf(&device, DType::BF16);
        let g = dummy_buf(&device, DType::F32);
        let beta = dummy_buf(&device, DType::F32);
        let h0 = dummy_buf(&device, DType::F32);
        let o = dummy_buf(&device, DType::BF16);
        let final_state = dummy_buf(&device, DType::F32);
        let p = ChunkGatedDeltaRuleParams {
            b: 1,
            t: 4096,
            hg: 32,
            h: 32,
            k: 128,
            v: 128,
            bt: 64,
            scale: 1.0,
            use_qk_l2norm: true,
        };
        let msg = dispatch_chunk_gated_delta_rule_fwd_with_arena(
            &mut enc,
            &mut registry,
            &device,
            &q,
            &k,
            &v,
            &g,
            &beta,
            &h0,
            &o,
            &final_state,
            &mut arena,
            p,
        )
        .expect_err("must reject use_qk_l2norm=true")
        .to_string();
        assert!(msg.contains("use_qk_l2norm"), "got: {msg}");
        assert!(msg.contains("non-arena"), "got: {msg}");
    }

    #[test]
    fn arena_chunk_o_params_layout_matches_helper() {
        let device = gpu();
        let arena = arena_from(&device, APEX).expect("arena alloc");
        let params_bytes = arena.chunk_o_params_buf.byte_len();
        assert!(
            params_bytes >= 11 * 4,
            "chunk_o_params_buf byte_len = {}, expected >= 44",
            params_bytes
        );
    }
}