use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source for the `gated_delta_net_kkt_bf16` kernel, embedded into the
/// binary at compile time and handed to the registry by `register`.
pub static GATED_DELTA_NET_KKT_SHADER_SOURCE: &str =
include_str!("../shaders/gated_delta_net_kkt.metal");
/// Largest `k` that `validate` accepts; larger values are rejected with an
/// `InvalidArgument` error that cites the iteration-2 32 KB threadgroup
/// memory budget.
/// NOTE(review): the threadgroup-memory total computed in `validate`
/// (`bt * DEFAULT_BK * 2 + bt * bt * 4`) does not involve `k`, so the link
/// between this cap and the 32 KB budget is not visible from this file —
/// confirm against the shader.
pub const MAX_K: u32 = 192;
/// Slice width along `k` used by the kernel: `validate` requires `k` to be a
/// multiple of it, it is written to word 7 of the params buffer, and it sizes
/// threadgroup-memory allocation 0 (`bt * DEFAULT_BK * 2` bytes) in the dispatch.
pub const DEFAULT_BK: u32 = 64;
/// Registers the embedded `gated_delta_net_kkt_bf16` Metal source with `registry`.
///
/// The name registered here is the same one `dispatch_gated_delta_net_kkt`
/// passes to `KernelRegistry::get_pipeline`.
pub fn register(registry: &mut KernelRegistry) {
    let source = GATED_DELTA_NET_KKT_SHADER_SOURCE;
    registry.register_source("gated_delta_net_kkt_bf16", source);
}
/// Problem dimensions for one `gated_delta_net_kkt` dispatch.
///
/// The same values are validated by `validate` and serialized into the
/// kernel's params buffer by `build_gated_delta_net_kkt_params`.
#[derive(Debug, Clone, Copy)]
pub struct GatedDeltaNetKktParams {
    /// Leading dimension of every tensor (presumably batch); grid depth of the dispatch.
    pub b: u32,
    /// Dimension split into `bt`-sized chunks by [`Self::num_chunks`] (presumably time).
    pub t: u32,
    /// Head count of `k` (`k` holds `b * t * hg * k` elements); `h` must be a multiple of it.
    pub hg: u32,
    /// Head count of `beta`, `g` and `A`; grid height of the dispatch.
    pub h: u32,
    /// Per-head width of `k`; `validate` requires `k <= MAX_K` and a multiple of `DEFAULT_BK`.
    pub k: u32,
    /// Chunk length along `t`; `validate` currently requires exactly 64.
    pub bt: u32,
}

impl GatedDeltaNetKktParams {
    /// Number of `bt`-sized chunks needed to cover `t`, rounding up.
    ///
    /// # Panics
    /// Panics if `bt` is zero.
    pub fn num_chunks(&self) -> u32 {
        let Self { t, bt, .. } = *self;
        t.div_ceil(bt)
    }
}
/// Checks `p` and the tensor buffers against the constraints of the
/// iteration-2 `gated_delta_net_kkt_bf16` kernel.
///
/// Checks run in this order, and the first failure is returned:
/// 1. every field of `p` is non-zero;
/// 2. `h` is a multiple of `hg`;
/// 3. `k <= MAX_K`;
/// 4. `bt == 64`;
/// 5. `t` is a multiple of `bt`;
/// 6. `k` is a multiple of `DEFAULT_BK`;
/// 7. the threadgroup memory the dispatch requests
///    (`bt * DEFAULT_BK * 2 + bt * bt * 4` bytes) fits in 32 KiB;
/// 8. each expected element count fits in `usize`;
/// 9. `k` is bf16 with `b*t*hg*k` elements, `beta` and `g` are f32 with
///    `b*t*h` elements, and `a` is f32 with `b*t*h*bt` elements.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` describing the first violated check.
fn validate(
    p: &GatedDeltaNetKktParams,
    k: &MlxBuffer,
    beta: &MlxBuffer,
    g: &MlxBuffer,
    a: &MlxBuffer,
) -> Result<()> {
    // Product of `dims` as a `usize`, or an `InvalidArgument` error naming
    // `tensor` if the product overflows.
    fn elem_count(tensor: &str, dims: &[u32]) -> Result<usize> {
        dims.iter()
            // `u32 as usize` is lossless on every target with >= 32-bit pointers.
            .try_fold(1usize, |acc, &d| acc.checked_mul(d as usize))
            .ok_or_else(|| {
                MlxError::InvalidArgument(format!(
                    "gated_delta_net_kkt: {tensor} element count overflows usize (dims {dims:?})"
                ))
            })
    }

    if p.b == 0 || p.t == 0 || p.hg == 0 || p.h == 0 || p.k == 0 || p.bt == 0 {
        return Err(MlxError::InvalidArgument(
            "gated_delta_net_kkt: all dims must be > 0".into(),
        ));
    }
    if p.h % p.hg != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt: h ({}) must be a multiple of hg ({})",
            p.h, p.hg
        )));
    }
    if p.k > MAX_K {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt: K ({}) exceeds iter-2 32 KB threadgroup memory \
             budget (MAX_K = {}); iter-3 will autotune past this",
            p.k, MAX_K
        )));
    }
    if p.bt != 64 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt (iter 2): bt must be 64 (got {})",
            p.bt
        )));
    }
    if p.t % p.bt != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt (iter 2): t ({}) must be a multiple of bt ({})",
            p.t, p.bt
        )));
    }
    if p.k % DEFAULT_BK != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt (iter 2): k ({}) must be a multiple of BK ({})",
            p.k, DEFAULT_BK
        )));
    }

    // Mirrors the two threadgroup-memory allocations made by the dispatch:
    // bt * BK * 2 bytes plus bt * bt * 4 bytes. bt is 64 here (checked above).
    let bt = u64::from(p.bt);
    let shared_bytes: u64 = bt * u64::from(DEFAULT_BK) * 2 + bt * bt * 4;
    const M5_MAX_TG_MEM_BYTES: u64 = 32 * 1024;
    if shared_bytes > M5_MAX_TG_MEM_BYTES {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt: threadgroup memory {} bytes exceeds M5 Max \
             cap of {} bytes (bt={}, bk={}, k={})",
            shared_bytes, M5_MAX_TG_MEM_BYTES, p.bt, DEFAULT_BK, p.k
        )));
    }

    // Expected element counts, computed in usize with overflow checks. The
    // plain u32 products (e.g. b*t*h*bt) exceed u32::MAX for long sequences,
    // which panics in debug builds and silently wraps in release — a wrapped
    // count could let an undersized buffer pass the checks below.
    let k_elems = elem_count("k", &[p.b, p.t, p.hg, p.k])?;
    let beta_elems = elem_count("beta", &[p.b, p.t, p.h])?;
    let g_elems = elem_count("g", &[p.b, p.t, p.h])?;
    let a_elems = elem_count("A", &[p.b, p.t, p.h, p.bt])?;

    if k.element_count() != k_elems || k.dtype() != DType::BF16 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt: k must be bf16[{}] (got {} {})",
            k_elems,
            k.element_count(),
            k.dtype()
        )));
    }
    if beta.element_count() != beta_elems || beta.dtype() != DType::F32 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt: beta must be f32[{}] (got {} {})",
            beta_elems,
            beta.element_count(),
            beta.dtype()
        )));
    }
    if g.element_count() != g_elems || g.dtype() != DType::F32 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt: g must be f32[{}] (got {} {})",
            g_elems,
            g.element_count(),
            g.dtype()
        )));
    }
    if a.element_count() != a_elems || a.dtype() != DType::F32 {
        return Err(MlxError::InvalidArgument(format!(
            "gated_delta_net_kkt: A must be f32[{}] (got {} {})",
            a_elems,
            a.element_count(),
            a.dtype()
        )));
    }
    Ok(())
}
/// Validates the inputs, then encodes one `gated_delta_net_kkt_bf16` dispatch.
///
/// The grid is `(num_chunks, h, b)` threadgroups of 256 threads each.
/// Buffers are bound as 0=`k`, 1=`beta`, 2=`g`, 3=`a`, 4=`params_buf`.
/// Two threadgroup-memory allocations are requested: index 0 of
/// `bt * DEFAULT_BK * 2` bytes and index 1 of `bt * bt * 4` bytes.
///
/// `params_buf` is not checked here; presumably it was produced by
/// `build_gated_delta_net_kkt_params` from the same `p` — TODO confirm.
///
/// # Errors
/// Propagates failures from `validate` and from `KernelRegistry::get_pipeline`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gated_delta_net_kkt(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    k: &MlxBuffer,
    beta: &MlxBuffer,
    g: &MlxBuffer,
    a: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: GatedDeltaNetKktParams,
) -> Result<()> {
    validate(&p, k, beta, g, a)?;
    let pipeline = registry.get_pipeline("gated_delta_net_kkt_bf16", device)?;

    let bt = u64::from(p.bt);
    let bk = u64::from(DEFAULT_BK);
    let stage_bytes = bt * bk * 2;
    let acc_bytes = bt * bt * 4;

    let threadgroups = MTLSize::new(u64::from(p.num_chunks()), u64::from(p.h), u64::from(p.b));
    let threads_per_threadgroup = MTLSize::new(256, 1, 1);

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, k), (1, beta), (2, g), (3, a), (4, params_buf)],
        &[(0, stage_bytes), (1, acc_bytes)],
        threadgroups,
        threads_per_threadgroup,
    );
    Ok(())
}
pub fn build_gated_delta_net_kkt_params(
device: &crate::MlxDevice,
p: GatedDeltaNetKktParams,
) -> Result<MlxBuffer> {
let mut buf = device.alloc_buffer(8 * 4, DType::U32, vec![8])?;
{
let s = buf.as_mut_slice::<u32>()?;
s[0] = p.b;
s[1] = p.t;
s[2] = p.hg;
s[3] = p.h;
s[4] = p.k;
s[5] = p.bt;
s[6] = p.num_chunks();
s[7] = DEFAULT_BK;
}
Ok(buf)
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::MlxDevice;

    /// Tiny placeholder buffer (2 bytes, shape `[1]`) of the given dtype.
    /// It matches no real tensor size, so it only suits checks that fail
    /// before `validate` inspects buffers.
    fn dummy_buf(device: &MlxDevice, dtype: DType) -> MlxBuffer {
        device.alloc_buffer(2, dtype, vec![1]).expect("alloc dummy")
    }

    #[test]
    fn validate_rejects_k_above_max() {
        let device = MlxDevice::new().expect("MlxDevice::new");
        let [k_buf, beta_buf, g_buf, a_buf] = [DType::BF16, DType::F32, DType::F32, DType::F32]
            .map(|dtype| dummy_buf(&device, dtype));
        let oversized = GatedDeltaNetKktParams {
            b: 1,
            t: 128,
            hg: 2,
            h: 4,
            k: 256,
            bt: 64,
        };

        let msg = validate(&oversized, &k_buf, &beta_buf, &g_buf, &a_buf)
            .expect_err("validate must reject K=256")
            .to_string();

        assert!(
            msg.contains("256"),
            "expected K=256 in error message, got: {msg}"
        );
        let mentions_budget = msg.contains("32 KB") || msg.contains("threadgroup");
        assert!(
            mentions_budget,
            "expected threadgroup-memory-budget context in error, got: {msg}"
        );
        let names_cap = msg.contains("MAX_K = 192") || msg.contains("MAX_K=192");
        assert!(
            names_cap,
            "expected explicit MAX_K cap in error, got: {msg}"
        );
    }
}