use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use crate::MlxDevice;
use crate::ops::chunk_gated_delta_rule_tri_solve_invert::{
build_chunk_tri_solve_invert_params, dispatch_chunk_tri_solve_invert,
ChunkTriSolveInvertParams,
};
use crate::ops::gated_delta_net_chunk::{
build_gated_delta_net_chunk_params, dispatch_gated_delta_net_chunk_inter_state,
GatedDeltaNetChunkParams,
};
use crate::ops::gated_delta_net_chunk_o::{
build_gated_delta_net_chunk_o_params, dispatch_gated_delta_net_chunk_o,
GatedDeltaNetChunkOParams,
};
use crate::ops::gated_delta_net_kkt::{
build_gated_delta_net_kkt_params, dispatch_gated_delta_net_kkt,
GatedDeltaNetKktParams,
};
use crate::ops::gated_delta_net_recompute_wu::{
build_gated_delta_net_recompute_wu_params, dispatch_gated_delta_net_recompute_wu,
GatedDeltaNetRecomputeWuParams,
};
use crate::ops::l2_norm::dispatch_l2_norm;
/// Metal source text for the chunk-local cumulative-sum kernel, embedded at
/// compile time from `../shaders/chunk_local_cumsum_g.metal`. `register`
/// installs it under the kernel name `chunk_local_cumsum_g_f32`.
pub static CHUNK_LOCAL_CUMSUM_G_SHADER_SOURCE: &str =
include_str!("../shaders/chunk_local_cumsum_g.metal");
/// The only `k` accepted by `validate`: `ChunkGatedDeltaRuleParams::k` must
/// equal this value exactly (it is not merely an upper bound).
pub const MAX_K: u32 = 128;
/// Upper bound (inclusive) on `ChunkGatedDeltaRuleParams::v` enforced by
/// `validate`.
pub const MAX_V: u32 = 256;
/// The only `bt` accepted by `validate`; `t` must additionally be a multiple
/// of it.
pub const FIXED_BT: u32 = 64;
/// Epsilon handed to the L2-norm stage when `use_qk_l2norm` is set.
pub const L2_NORM_EPS: f32 = 1.0e-6;
/// Registers the shader source owned by this module with `registry`.
///
/// Only the chunk-local cumsum kernel is registered here; the other kernels
/// used by the forward pipeline are not registered by this function.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "chunk_local_cumsum_g_f32";
    registry.register_source(kernel_name, CHUNK_LOCAL_CUMSUM_G_SHADER_SOURCE);
}
/// Shape and behaviour knobs for `dispatch_chunk_gated_delta_rule_fwd`.
///
/// The constraints noted on each field are the ones enforced by `validate`
/// in this file.
#[derive(Debug, Clone, Copy)]
pub struct ChunkGatedDeltaRuleParams {
    /// Leading extent shared by every buffer (presumably the batch size —
    /// confirm against callers). Must be non-zero.
    pub b: u32,
    /// Second extent of the per-token buffers (presumably sequence length —
    /// confirm). Must be a non-zero multiple of `bt`.
    pub t: u32,
    /// Extent used to size `q` and `k` (`b * t * hg * k` elements each).
    /// `h` must be a multiple of it.
    pub hg: u32,
    /// Extent used to size `v`, `o`, `g_log_decay`, `beta` and the state
    /// buffers. Must be a multiple of `hg`.
    pub h: u32,
    /// Innermost extent of `q` / `k`. Must equal `MAX_K`.
    pub k: u32,
    /// Innermost extent of `v` / `o`. Must not exceed `MAX_V`.
    pub v: u32,
    /// Chunk length along `t`. Must equal `FIXED_BT`.
    pub bt: u32,
    /// Finite scalar forwarded unchanged to the chunk-output stage.
    pub scale: f32,
    /// When `true`, `q` and `k` are L2-normalized into scratch buffers before
    /// the rest of the pipeline consumes them.
    pub use_qk_l2norm: bool,
}

impl ChunkGatedDeltaRuleParams {
    /// Number of `bt`-sized chunks needed to cover `t`, rounding up.
    ///
    /// # Panics
    ///
    /// Panics if `bt` is zero (division by zero).
    pub fn num_chunks(&self) -> u32 {
        let Self { t, bt, .. } = *self;
        t.div_ceil(bt)
    }
}
#[allow(clippy::too_many_arguments)]
fn validate(
p: &ChunkGatedDeltaRuleParams,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
g_log_decay: &MlxBuffer,
beta: &MlxBuffer,
h0: &MlxBuffer,
o: &MlxBuffer,
final_state: &MlxBuffer,
) -> Result<()> {
if p.b == 0 || p.t == 0 || p.hg == 0 || p.h == 0 || p.k == 0 || p.v == 0 || p.bt == 0 {
return Err(MlxError::InvalidArgument(
"chunk_gated_delta_rule_fwd: all dims must be > 0".into(),
));
}
if p.h % p.hg != 0 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: H ({}) must be a multiple of Hg ({})",
p.h, p.hg
)));
}
if p.k != MAX_K {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: K ({}) must equal MAX_K = {} exactly. \
Sub-kernels inter_state and chunk_o have compile-time-fixed 16 \
K-tiles in their simdgroup_matrix MMA loops; runtime K bounds \
defeat MMA scheduling (3.15× regression measured). To support \
other K values, port FLA's b_h1..b_h4 bank-split.",
p.k, MAX_K
)));
}
if p.v > MAX_V {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: V ({}) exceeds chunk-pipeline cap \
(MAX_V = {})",
p.v, MAX_V
)));
}
if p.bt != FIXED_BT {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd (iter 4): bt must be {} (got {})",
FIXED_BT, p.bt
)));
}
if p.t % p.bt != 0 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd (iter 4): t ({}) must be a multiple of bt ({})",
p.t, p.bt
)));
}
if !p.scale.is_finite() {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: scale must be finite (got {})",
p.scale
)));
}
let q_elems = (p.b * p.t * p.hg * p.k) as usize;
let k_elems = (p.b * p.t * p.hg * p.k) as usize;
let v_elems = (p.b * p.t * p.h * p.v) as usize;
let g_elems = (p.b * p.t * p.h) as usize;
let beta_elems = (p.b * p.t * p.h) as usize;
let h0_elems = (p.b * p.h * p.v * p.k) as usize;
let o_elems = (p.b * p.t * p.h * p.v) as usize;
let final_state_elems = (p.b * p.h * p.v * p.k) as usize;
let bf16_inputs: [(&str, &MlxBuffer, usize); 4] = [
("q", q, q_elems),
("k", k, k_elems),
("v", v, v_elems),
("o", o, o_elems),
];
for (name, buf, exp) in bf16_inputs {
if buf.element_count() != exp {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} element count {} != expected {}",
name,
buf.element_count(),
exp
)));
}
if buf.dtype() != DType::BF16 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} must be bf16 (got {})",
name,
buf.dtype()
)));
}
}
let f32_inputs: [(&str, &MlxBuffer, usize); 4] = [
("g_log_decay", g_log_decay, g_elems),
("beta", beta, beta_elems),
("h0", h0, h0_elems),
("final_state", final_state, final_state_elems),
];
for (name, buf, exp) in f32_inputs {
if buf.element_count() != exp {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} element count {} != expected {}",
name,
buf.element_count(),
exp
)));
}
if buf.dtype() != DType::F32 {
return Err(MlxError::InvalidArgument(format!(
"chunk_gated_delta_rule_fwd: {} must be f32 (got {})",
name,
buf.dtype()
)));
}
}
Ok(())
}
/// Allocates and fills the five-word `u32` parameter buffer bound at index 2
/// of the chunk-local cumsum kernel.
///
/// Slot layout, in order: `b`, `t`, `h`, `bt`, `num_chunks()`.
///
/// # Errors
///
/// Propagates allocation failures and failures to view the buffer as `u32`.
fn build_chunk_local_cumsum_g_params(
    device: &MlxDevice,
    p: &ChunkGatedDeltaRuleParams,
) -> Result<MlxBuffer> {
    const FIELD_COUNT: usize = 5;
    let mut params = device.alloc_buffer(
        FIELD_COUNT * std::mem::size_of::<u32>(),
        DType::U32,
        vec![FIELD_COUNT],
    )?;
    // Scoped so the mutable view ends before the buffer is moved out.
    {
        let words = params.as_mut_slice::<u32>()?;
        words[..FIELD_COUNT].copy_from_slice(&[p.b, p.t, p.h, p.bt, p.num_chunks()]);
    }
    Ok(params)
}
/// Allocates and fills the two-slot `f32` parameter buffer handed to
/// `dispatch_l2_norm`.
///
/// Slot 0 holds `eps`; slot 1 holds `dim` converted to `f32`.
/// NOTE(review): storing `dim` as a float presumably mirrors what the L2-norm
/// kernel reads — confirm against the shader's parameter layout.
///
/// # Errors
///
/// Propagates allocation failures and failures to view the buffer as `f32`.
fn build_l2_norm_params(device: &MlxDevice, eps: f32, dim: u32) -> Result<MlxBuffer> {
    const FIELD_COUNT: usize = 2;
    let mut params = device.alloc_buffer(
        FIELD_COUNT * std::mem::size_of::<f32>(),
        DType::F32,
        vec![FIELD_COUNT],
    )?;
    // Scoped so the mutable view ends before the buffer is moved out.
    {
        let slots = params.as_mut_slice::<f32>()?;
        slots[..FIELD_COUNT].copy_from_slice(&[eps, dim as f32]);
    }
    Ok(params)
}
/// Encodes one launch of the `chunk_local_cumsum_g_f32` kernel.
///
/// Bindings: `g_in` at index 0, `g_out` at index 1, `params_buf` at index 2.
/// The launch uses a `(1, h, b * num_chunks)` grid of threadgroups with
/// `bt` threads along x in each threadgroup.
///
/// # Errors
///
/// Returns whatever `KernelRegistry::get_pipeline` reports if the pipeline
/// cannot be obtained.
fn dispatch_chunk_local_cumsum_g(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    g_in: &MlxBuffer,
    g_out: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: &ChunkGatedDeltaRuleParams,
) -> Result<()> {
    let cumsum_pipeline = registry.get_pipeline("chunk_local_cumsum_g_f32", device)?;

    let chunks_per_batch = u64::from(p.num_chunks());
    let threadgroup_grid = MTLSize::new(1, u64::from(p.h), u64::from(p.b) * chunks_per_batch);
    let threads_per_group = MTLSize::new(u64::from(p.bt), 1, 1);

    encoder.encode_threadgroups(
        cumsum_pipeline,
        &[(0, g_in), (1, g_out), (2, params_buf)],
        threadgroup_grid,
        threads_per_group,
    );
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_chunk_gated_delta_rule_fwd(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
g_log_decay: &MlxBuffer,
beta: &MlxBuffer,
h0: &MlxBuffer,
o: &MlxBuffer,
final_state: &MlxBuffer,
p: ChunkGatedDeltaRuleParams,
) -> Result<()> {
validate(&p, q, k, v, g_log_decay, beta, h0, o, final_state)?;
let metal_device = device.metal_device();
let nt = p.num_chunks();
let q_qk_elems = (p.b * p.t * p.hg * p.k) as usize;
let q_normed_buf;
let k_normed_buf;
let q_for_pipeline: &MlxBuffer;
let k_for_pipeline: &MlxBuffer;
if p.use_qk_l2norm {
q_normed_buf =
device.alloc_buffer(q_qk_elems * 2, DType::BF16, vec![q_qk_elems])?;
k_normed_buf =
device.alloc_buffer(q_qk_elems * 2, DType::BF16, vec![q_qk_elems])?;
let l2_params = build_l2_norm_params(device, L2_NORM_EPS, p.k)?;
let rows = p.b * p.t * p.hg;
dispatch_l2_norm(
encoder, registry, metal_device, q, &q_normed_buf, &l2_params, rows, p.k,
)?;
encoder.memory_barrier();
dispatch_l2_norm(
encoder, registry, metal_device, k, &k_normed_buf, &l2_params, rows, p.k,
)?;
encoder.memory_barrier();
q_for_pipeline = &q_normed_buf;
k_for_pipeline = &k_normed_buf;
} else {
q_for_pipeline = q;
k_for_pipeline = k;
}
let g_elems = (p.b * p.t * p.h) as usize;
let g_cumsum_buf =
device.alloc_buffer(g_elems * 4, DType::F32, vec![g_elems])?;
let cumsum_params = build_chunk_local_cumsum_g_params(device, &p)?;
dispatch_chunk_local_cumsum_g(
encoder,
registry,
metal_device,
g_log_decay,
&g_cumsum_buf,
&cumsum_params,
&p,
)?;
encoder.memory_barrier();
let a_elems = (p.b * p.t * p.h * p.bt) as usize;
let a_strict_buf =
device.alloc_buffer(a_elems * 4, DType::F32, vec![a_elems])?;
let kkt_params_value = GatedDeltaNetKktParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
bt: p.bt,
};
let kkt_params = build_gated_delta_net_kkt_params(device, kkt_params_value)?;
dispatch_gated_delta_net_kkt(
encoder,
registry,
metal_device,
k_for_pipeline,
beta,
&g_cumsum_buf,
&a_strict_buf,
&kkt_params,
kkt_params_value,
)?;
encoder.memory_barrier();
let a_inv_buf =
device.alloc_buffer(a_elems * 4, DType::F32, vec![a_elems])?;
let invert_params_value = ChunkTriSolveInvertParams {
b: p.b,
t: p.t,
h: p.h,
bt: p.bt,
};
let invert_params = build_chunk_tri_solve_invert_params(device, invert_params_value)?;
dispatch_chunk_tri_solve_invert(
encoder,
registry,
metal_device,
&a_strict_buf,
&a_inv_buf,
&invert_params,
invert_params_value,
)?;
encoder.memory_barrier();
let w_elems = (p.b * p.t * p.h * p.k) as usize;
let u_elems = (p.b * p.t * p.h * p.v) as usize;
let w_buf = device.alloc_buffer(w_elems * 2, DType::BF16, vec![w_elems])?;
let u_buf = device.alloc_buffer(u_elems * 2, DType::BF16, vec![u_elems])?;
let recompute_wu_params_value = GatedDeltaNetRecomputeWuParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
v: p.v,
bt: p.bt,
};
let recompute_wu_params =
build_gated_delta_net_recompute_wu_params(device, recompute_wu_params_value)?;
dispatch_gated_delta_net_recompute_wu(
encoder,
registry,
metal_device,
k_for_pipeline,
v,
beta,
&g_cumsum_buf,
&a_inv_buf,
&w_buf,
&u_buf,
&recompute_wu_params,
recompute_wu_params_value,
)?;
encoder.memory_barrier();
let h_elems = (p.b * nt * p.h * p.v * p.k) as usize;
let v_new_elems = (p.b * p.t * p.h * p.v) as usize;
let h_buf = device.alloc_buffer(h_elems * 2, DType::BF16, vec![h_elems])?;
let v_new_buf =
device.alloc_buffer(v_new_elems * 2, DType::BF16, vec![v_new_elems])?;
let chunk_params_value = GatedDeltaNetChunkParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
v: p.v,
bt: p.bt,
};
let chunk_params = build_gated_delta_net_chunk_params(device, chunk_params_value)?;
dispatch_gated_delta_net_chunk_inter_state(
encoder,
registry,
metal_device,
k_for_pipeline,
&w_buf,
&u_buf,
&g_cumsum_buf,
h0,
&h_buf,
&v_new_buf,
final_state,
&chunk_params,
chunk_params_value,
)?;
encoder.memory_barrier();
let chunk_o_params_value = GatedDeltaNetChunkOParams {
b: p.b,
t: p.t,
hg: p.hg,
h: p.h,
k: p.k,
v: p.v,
bt: p.bt,
scale: p.scale,
};
let chunk_o_params =
build_gated_delta_net_chunk_o_params(device, chunk_o_params_value)?;
dispatch_gated_delta_net_chunk_o(
encoder,
registry,
metal_device,
q_for_pipeline,
k_for_pipeline,
&v_new_buf,
&h_buf,
&g_cumsum_buf,
o,
&chunk_o_params,
chunk_o_params_value,
)?;
Ok(())
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::MlxDevice;

    /// Allocates a minimal placeholder buffer of the requested dtype. Its
    /// contents and size are irrelevant to the test below, which expects
    /// `validate` to fail before any buffer is inspected.
    fn dummy_buf(device: &MlxDevice, dtype: DType) -> MlxBuffer {
        device.alloc_buffer(2, dtype, vec![1]).expect("alloc dummy")
    }

    /// `validate` must reject `k = 256` with a message that names the
    /// offending value, the required `MAX_K`, and the exact-match wording.
    #[test]
    fn validate_rejects_k_above_max() {
        let device = MlxDevice::new().expect("MlxDevice::new");
        let bf16_placeholder = || dummy_buf(&device, DType::BF16);
        let f32_placeholder = || dummy_buf(&device, DType::F32);

        let q_buf = bf16_placeholder();
        let k_buf = bf16_placeholder();
        let v_buf = bf16_placeholder();
        let g_buf = f32_placeholder();
        let beta_buf = f32_placeholder();
        let h0_buf = f32_placeholder();
        let o_buf = bf16_placeholder();
        let final_state_buf = f32_placeholder();

        // Every field is acceptable except `k`, which is twice MAX_K.
        let p = ChunkGatedDeltaRuleParams {
            b: 1,
            t: 128,
            hg: 2,
            h: 4,
            k: 256,
            v: 128,
            bt: 64,
            scale: (128f32).powf(-0.5),
            use_qk_l2norm: true,
        };

        let outcome = validate(
            &p,
            &q_buf,
            &k_buf,
            &v_buf,
            &g_buf,
            &beta_buf,
            &h0_buf,
            &o_buf,
            &final_state_buf,
        );
        let msg = outcome
            .expect_err("validate must reject K=256")
            .to_string();

        assert!(
            msg.contains("256"),
            "expected K=256 in error message, got: {msg}"
        );
        let names_cap = msg.contains("MAX_K = 128") || msg.contains("MAX_K=128");
        assert!(
            names_cap,
            "expected explicit MAX_K=128 in error (orchestrator inherits sub-kernel \
             K==128-exact constraint per Wave 5b.2 iter 2.5), got: {msg}"
        );
        let exact_wording = msg.contains("must equal") || msg.contains("hard-coded");
        assert!(
            exact_wording,
            "expected exact-equality wording in error, got: {msg}"
        );
    }
}