use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source for the kernel registered as `"qmm_affine_t_f32"` in `register`.
///
/// Embedded at compile time with `include_str!`, so the path is relative to this file.
pub static QMM_AFFINE_SHADER_SOURCE: &str = include_str!("../shaders/qmm_affine.metal");
/// Metal source for the kernel registered as `"qmm_affine_t_f32_tiled"` in `register`.
pub static QMM_AFFINE_TILED_SHADER_SOURCE: &str =
include_str!("../shaders/qmm_affine_tiled.metal");
/// Metal source for the kernel registered as `"qmm_affine_t_f32_simd"` in `register`.
pub static QMM_AFFINE_SIMD_SHADER_SOURCE: &str =
include_str!("../shaders/qmm_affine_simd.metal");
/// Metal source for the kernel registered as `"qmm_affine_t_f32_simd4"` in `register`.
pub static QMM_AFFINE_SIMD4_SHADER_SOURCE: &str =
include_str!("../shaders/qmm_affine_simd4.metal");
/// Metal source for the kernel registered as `"qmm_affine_t_f32_simd4_gs64"` in `register`.
pub static QMM_AFFINE_SIMD4_GS64_SHADER_SOURCE: &str =
include_str!("../shaders/qmm_affine_simd4_gs64.metal");
/// Registers the embedded Metal source of every `qmm_affine_t_f32*` kernel variant.
///
/// Each kernel name is the same string the matching `dispatch_*` function later passes to
/// `KernelRegistry::get_pipeline`. Sources are registered in the order listed below.
pub fn register(registry: &mut KernelRegistry) {
    let kernels = [
        ("qmm_affine_t_f32", QMM_AFFINE_SHADER_SOURCE),
        ("qmm_affine_t_f32_tiled", QMM_AFFINE_TILED_SHADER_SOURCE),
        ("qmm_affine_t_f32_simd", QMM_AFFINE_SIMD_SHADER_SOURCE),
        ("qmm_affine_t_f32_simd4", QMM_AFFINE_SIMD4_SHADER_SOURCE),
        ("qmm_affine_t_f32_simd4_gs64", QMM_AFFINE_SIMD4_GS64_SHADER_SOURCE),
    ];
    for (kernel_name, source) in kernels {
        registry.register_source(kernel_name, source);
    }
}
/// Encodes the per-element affine-quantized matmul kernel `qmm_affine_t_f32`.
///
/// Computes `y[M, N] = x[M, K] · Wᵀ`, where each weight is dequantized on the fly as
/// `W[n, k] = q_int[n, k] * scales[n, k / group_size] + biases[n, k / group_size]`.
/// This is the contract the CPU reference in this module's tests checks against.
///
/// Buffers are bound at indices 0..=5 in the order `x, q_int, scales, biases, y, meta`.
/// `meta` is expected to hold `[M, N, K, group_size]` as `u32`. Only its length is validated
/// here, so the caller must keep its contents consistent with `m`, `n`, `k` and `group_size`.
///
/// The grid uses threadgroups of `min(16, M) x min(16, N)` threads, enough of them to
/// cover `M x N`.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if any of the following holds:
/// - a dimension is zero;
/// - `group_size` is not a power of two in `[2, 1024]`;
/// - `K` is not a multiple of `group_size`;
/// - a buffer has the wrong dtype or element count;
/// - `meta` is shorter than 16 bytes.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_qmm_affine_t_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    q_int: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    y: &MlxBuffer,
    meta: &MlxBuffer,
    m: u32,
    n: u32,
    k: u32,
    group_size: u32,
) -> Result<()> {
    const OP: &str = "qmm_affine_t_f32";
    let invalid = |msg: String| -> Result<()> { Err(MlxError::InvalidArgument(msg)) };

    if [m, n, k].contains(&0) {
        return invalid(format!("{OP}: M, N, K must all be > 0; got ({m}, {n}, {k})"));
    }
    let group_size_ok = group_size.is_power_of_two() && (2..=1024).contains(&group_size);
    if !group_size_ok {
        return invalid(format!(
            "{OP}: group_size must be a power of two in [2, 1024]; got {group_size}"
        ));
    }
    if k % group_size != 0 {
        return invalid(format!(
            "{OP}: K ({k}) must be divisible by group_size ({group_size})"
        ));
    }

    // Expected dtype per buffer, checked in binding-slot order.
    let expected_dtypes = [
        ("x", x, DType::F32, "f32"),
        ("q_int", q_int, DType::U8, "u8"),
        ("scales", scales, DType::F32, "f32"),
        ("biases", biases, DType::F32, "f32"),
        ("y", y, DType::F32, "f32"),
    ];
    for (label, buf, want, want_name) in expected_dtypes {
        if buf.dtype() != want {
            return invalid(format!("{OP}: {label} dtype {} not {want_name}", buf.dtype()));
        }
    }

    // Expected element count per buffer; scales/biases hold one entry per (row, group).
    let (rows, cols, depth) = (m as usize, n as usize, k as usize);
    let group_entries = cols * (depth / group_size as usize);
    let expected_counts = [
        ("x", x, rows * depth, "M*K"),
        ("q_int", q_int, cols * depth, "N*K"),
        ("scales", scales, group_entries, "N * K/group_size"),
        ("biases", biases, group_entries, "N * K/group_size"),
        ("y", y, rows * cols, "M*N"),
    ];
    for (label, buf, want, formula) in expected_counts {
        let have = buf.element_count();
        if have != want {
            return invalid(format!("{OP}: {label} element_count {have} != {formula} = {want}"));
        }
    }

    let meta_len = meta.byte_len();
    if meta_len < 16 {
        return invalid(format!(
            "{OP}: meta must be >= 16 bytes ([M,N,K,group_size] u32); got {meta_len}"
        ));
    }

    let pipeline = registry.get_pipeline(OP, device)?;
    let (m64, n64) = (u64::from(m), u64::from(n));
    // Shrink the threadgroup on small problems so no dimension exceeds the output extent.
    let (tg_w, tg_h) = (m64.min(16), n64.min(16));
    encoder.encode_threadgroups(
        pipeline,
        &[(0, x), (1, q_int), (2, scales), (3, biases), (4, y), (5, meta)],
        MTLSize::new(m64.div_ceil(tg_w), n64.div_ceil(tg_h), 1),
        MTLSize::new(tg_w, tg_h, 1),
    );
    Ok(())
}
/// Encodes the threadgroup-tiled affine-quantized matmul kernel `qmm_affine_t_f32_tiled`.
///
/// Computes `y[M, N] = x[M, K] · Wᵀ` with
/// `W[n, k] = q_int[n, k] * scales[n, k / 32] + biases[n, k / 32]`.
/// This gives the same result as the per-element `dispatch_qmm_affine_t_f32`, and the
/// module tests compare the two. The shader hard-codes its K tile to one quantization
/// group, so `group_size` must be exactly 32.
///
/// Buffers are bound at indices 0..=5 as `x, q_int, scales, biases, y, meta`. `meta` is
/// expected to hold `[M, N, K, group_size]` as `u32`. Only its length is validated here, so
/// the caller must keep its contents consistent with the scalar arguments.
///
/// Launches `ceil(M/16) x ceil(N/16)` threadgroups of `16 x 16` threads, with 2688 bytes
/// passed as shared memory at index 0.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if any of the following holds:
/// - `group_size != 32`;
/// - a dimension is zero;
/// - `K` is not a multiple of `group_size`;
/// - a buffer has the wrong dtype or element count (the message names the buffer);
/// - `meta` is shorter than 16 bytes.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_qmm_affine_t_f32_tiled(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    q_int: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    y: &MlxBuffer,
    meta: &MlxBuffer,
    m: u32,
    n: u32,
    k: u32,
    group_size: u32,
) -> Result<()> {
    const OP: &str = "qmm_affine_t_f32_tiled";
    const TILED_BK: u32 = 32;
    if group_size != TILED_BK {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: group_size must equal {TILED_BK} (kernel BK is hard-coded); got {group_size}"
        )));
    }
    if m == 0 || n == 0 || k == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: M, N, K must all be > 0; got ({m}, {n}, {k})"
        )));
    }
    if k % group_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: K ({k}) must be divisible by group_size ({group_size})"
        )));
    }
    // Name the offending buffer and its dtype, matching the diagnostics of
    // `dispatch_qmm_affine_t_f32`, instead of a blanket message.
    for (name, buf) in [("x", x), ("scales", scales), ("biases", biases), ("y", y)] {
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} dtype {} not f32",
                buf.dtype()
            )));
        }
    }
    if q_int.dtype() != DType::U8 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: q_int dtype {} not u8",
            q_int.dtype()
        )));
    }
    let m_us = m as usize;
    let n_us = n as usize;
    let k_us = k as usize;
    let n_groups = n_us * (k_us / group_size as usize);
    // Report which buffer has the wrong size and what size was expected.
    for (name, buf, expected, formula) in [
        ("x", x, m_us * k_us, "M*K"),
        ("q_int", q_int, n_us * k_us, "N*K"),
        ("scales", scales, n_groups, "N * K/group_size"),
        ("biases", biases, n_groups, "N * K/group_size"),
        ("y", y, m_us * n_us, "M*N"),
    ] {
        if buf.element_count() != expected {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} element_count {} != {formula} = {expected}",
                buf.element_count()
            )));
        }
    }
    if meta.byte_len() < 16 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: meta must be >= 16 bytes ([M,N,K,group_size] u32); got {}",
            meta.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline(OP, device)?;
    // Each threadgroup is BM x BN threads and covers a BM x BN tile of y.
    const BM: u64 = 16;
    const BN: u64 = 16;
    let tg_count_x = u64::from(m).div_ceil(BM);
    let tg_count_y = u64::from(n).div_ceil(BN);
    // Must match the shared arrays declared in qmm_affine_tiled.metal (not visible here).
    const SHMEM_BYTES: u64 = 2688;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, x), (1, q_int), (2, scales), (3, biases), (4, y), (5, meta)],
        &[(0, SHMEM_BYTES)],
        MTLSize::new(tg_count_x, tg_count_y, 1),
        MTLSize::new(BM, BN, 1),
    );
    Ok(())
}
/// Encodes the SIMD affine-quantized matmul kernel `qmm_affine_t_f32_simd`.
///
/// Computes `y[M, N] = x[M, K] · Wᵀ` with
/// `W[n, k] = q_int[n, k] * scales[n, k / 32] + biases[n, k / 32]`.
/// This gives the same result as the per-element `dispatch_qmm_affine_t_f32`, and the
/// module tests compare the two. The shader hard-codes its K tile to one quantization
/// group, so `group_size` must be exactly 32.
///
/// Buffers are bound at indices 0..=5 as `x, q_int, scales, biases, y, meta`. `meta` is
/// expected to hold `[M, N, K, group_size]` as `u32`. Only its length is validated here, so
/// the caller must keep its contents consistent with the scalar arguments.
///
/// Launches `ceil(M/8) x ceil(N/8)` threadgroups of 32 threads (presumably one SIMD-group),
/// with 2048 bytes passed as shared memory at index 0.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if any of the following holds:
/// - `group_size != 32`;
/// - a dimension is zero;
/// - `K` is not a multiple of `group_size`;
/// - a buffer has the wrong dtype or element count (the message names the buffer);
/// - `meta` is shorter than 16 bytes.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_qmm_affine_t_f32_simd(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    q_int: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    y: &MlxBuffer,
    meta: &MlxBuffer,
    m: u32,
    n: u32,
    k: u32,
    group_size: u32,
) -> Result<()> {
    const OP: &str = "qmm_affine_t_f32_simd";
    const SIMD_BK: u32 = 32;
    if group_size != SIMD_BK {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: group_size must equal {SIMD_BK} (kernel BK is hard-coded); got {group_size}"
        )));
    }
    if m == 0 || n == 0 || k == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: M, N, K must all be > 0; got ({m}, {n}, {k})"
        )));
    }
    if k % group_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: K ({k}) must be divisible by group_size ({group_size})"
        )));
    }
    // Name the offending buffer and its dtype, matching the diagnostics of
    // `dispatch_qmm_affine_t_f32`, instead of a blanket message.
    for (name, buf) in [("x", x), ("scales", scales), ("biases", biases), ("y", y)] {
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} dtype {} not f32",
                buf.dtype()
            )));
        }
    }
    if q_int.dtype() != DType::U8 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: q_int dtype {} not u8",
            q_int.dtype()
        )));
    }
    let m_us = m as usize;
    let n_us = n as usize;
    let k_us = k as usize;
    let n_groups = n_us * (k_us / group_size as usize);
    // Report which buffer has the wrong size and what size was expected.
    for (name, buf, expected, formula) in [
        ("x", x, m_us * k_us, "M*K"),
        ("q_int", q_int, n_us * k_us, "N*K"),
        ("scales", scales, n_groups, "N * K/group_size"),
        ("biases", biases, n_groups, "N * K/group_size"),
        ("y", y, m_us * n_us, "M*N"),
    ] {
        if buf.element_count() != expected {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} element_count {} != {formula} = {expected}",
                buf.element_count()
            )));
        }
    }
    if meta.byte_len() < 16 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: meta must be >= 16 bytes ([M,N,K,group_size] u32); got {}",
            meta.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline(OP, device)?;
    // The grid tiles y into BM x BN blocks, one threadgroup per block.
    const BM: u64 = 8;
    const BN: u64 = 8;
    let tg_count_x = u64::from(m).div_ceil(BM);
    let tg_count_y = u64::from(n).div_ceil(BN);
    // Equals (BM + BN) * BK * 4 bytes.
    // NOTE(review): presumably one f32 x tile plus one dequantized W tile; confirm
    // against qmm_affine_simd.metal.
    const SHMEM_BYTES: u64 = 2048;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, x), (1, q_int), (2, scales), (3, biases), (4, y), (5, meta)],
        &[(0, SHMEM_BYTES)],
        MTLSize::new(tg_count_x, tg_count_y, 1),
        MTLSize::new(32, 1, 1),
    );
    Ok(())
}
/// Encodes the 4-wide SIMD affine-quantized matmul kernel `qmm_affine_t_f32_simd4`.
///
/// Computes `y[M, N] = x[M, K] · Wᵀ` with
/// `W[n, k] = q_int[n, k] * scales[n, k / 32] + biases[n, k / 32]`.
/// This gives the same result as the per-element `dispatch_qmm_affine_t_f32`, and the
/// module tests compare the two. The shader hard-codes its K tile to one quantization
/// group, so `group_size` must be exactly 32.
///
/// Buffers are bound at indices 0..=5 as `x, q_int, scales, biases, y, meta`. `meta` is
/// expected to hold `[M, N, K, group_size]` as `u32`. Only its length is validated here, so
/// the caller must keep its contents consistent with the scalar arguments.
///
/// Launches `ceil(M/32) x ceil(N/32)` threadgroups of 128 threads (presumably four
/// SIMD-groups), with 8192 bytes passed as shared memory at index 0.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if any of the following holds:
/// - `group_size != 32`;
/// - a dimension is zero;
/// - `K` is not a multiple of `group_size`;
/// - a buffer has the wrong dtype or element count (the message names the buffer);
/// - `meta` is shorter than 16 bytes.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_qmm_affine_t_f32_simd4(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    q_int: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    y: &MlxBuffer,
    meta: &MlxBuffer,
    m: u32,
    n: u32,
    k: u32,
    group_size: u32,
) -> Result<()> {
    const OP: &str = "qmm_affine_t_f32_simd4";
    const SIMD_BK: u32 = 32;
    if group_size != SIMD_BK {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: group_size must equal {SIMD_BK} (kernel BK is hard-coded); got {group_size}"
        )));
    }
    if m == 0 || n == 0 || k == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: M, N, K must all be > 0; got ({m}, {n}, {k})"
        )));
    }
    if k % group_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: K ({k}) must be divisible by group_size ({group_size})"
        )));
    }
    // Name the offending buffer and its dtype, matching the diagnostics of
    // `dispatch_qmm_affine_t_f32`, instead of a blanket message.
    for (name, buf) in [("x", x), ("scales", scales), ("biases", biases), ("y", y)] {
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} dtype {} not f32",
                buf.dtype()
            )));
        }
    }
    if q_int.dtype() != DType::U8 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: q_int dtype {} not u8",
            q_int.dtype()
        )));
    }
    let m_us = m as usize;
    let n_us = n as usize;
    let k_us = k as usize;
    let n_groups = n_us * (k_us / group_size as usize);
    // Report which buffer has the wrong size and what size was expected.
    for (name, buf, expected, formula) in [
        ("x", x, m_us * k_us, "M*K"),
        ("q_int", q_int, n_us * k_us, "N*K"),
        ("scales", scales, n_groups, "N * K/group_size"),
        ("biases", biases, n_groups, "N * K/group_size"),
        ("y", y, m_us * n_us, "M*N"),
    ] {
        if buf.element_count() != expected {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} element_count {} != {formula} = {expected}",
                buf.element_count()
            )));
        }
    }
    if meta.byte_len() < 16 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: meta must be >= 16 bytes ([M,N,K,group_size] u32); got {}",
            meta.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline(OP, device)?;
    // The grid tiles y into BM x BN blocks, one threadgroup per block.
    const BM: u64 = 32;
    const BN: u64 = 32;
    let tg_count_x = u64::from(m).div_ceil(BM);
    let tg_count_y = u64::from(n).div_ceil(BN);
    // Equals (BM + BN) * BK * 4 bytes.
    // NOTE(review): presumably one f32 x tile plus one dequantized W tile; confirm
    // against qmm_affine_simd4.metal.
    const SHMEM_BYTES: u64 = 8192;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, x), (1, q_int), (2, scales), (3, biases), (4, y), (5, meta)],
        &[(0, SHMEM_BYTES)],
        MTLSize::new(tg_count_x, tg_count_y, 1),
        MTLSize::new(128, 1, 1),
    );
    Ok(())
}
/// Encodes the 4-wide SIMD affine-quantized matmul kernel for 64-element groups,
/// `qmm_affine_t_f32_simd4_gs64`.
///
/// Computes `y[M, N] = x[M, K] · Wᵀ` with
/// `W[n, k] = q_int[n, k] * scales[n, k / 64] + biases[n, k / 64]`.
/// The shader hard-codes its K tile to one quantization group, so `group_size` must be
/// exactly 64.
///
/// Buffers are bound at indices 0..=5 as `x, q_int, scales, biases, y, meta`. `meta` is
/// expected to hold `[M, N, K, group_size]` as `u32`. Only its length is validated here, so
/// the caller must keep its contents consistent with the scalar arguments.
///
/// Launches `ceil(M/32) x ceil(N/32)` threadgroups of 128 threads (presumably four
/// SIMD-groups), with 16384 bytes passed as shared memory at index 0.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if any of the following holds:
/// - `group_size != 64`;
/// - a dimension is zero;
/// - `K` is not a multiple of `group_size`;
/// - a buffer has the wrong dtype or element count (the message names the buffer);
/// - `meta` is shorter than 16 bytes.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_qmm_affine_t_f32_simd4_gs64(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    q_int: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    y: &MlxBuffer,
    meta: &MlxBuffer,
    m: u32,
    n: u32,
    k: u32,
    group_size: u32,
) -> Result<()> {
    const OP: &str = "qmm_affine_t_f32_simd4_gs64";
    const SIMD_BK: u32 = 64;
    if group_size != SIMD_BK {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: group_size must equal {SIMD_BK} (kernel BK is hard-coded); got {group_size}"
        )));
    }
    if m == 0 || n == 0 || k == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: M, N, K must all be > 0; got ({m}, {n}, {k})"
        )));
    }
    if k % group_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: K ({k}) must be divisible by group_size ({group_size})"
        )));
    }
    // Name the offending buffer and its dtype, matching the diagnostics of
    // `dispatch_qmm_affine_t_f32`, instead of a blanket message.
    for (name, buf) in [("x", x), ("scales", scales), ("biases", biases), ("y", y)] {
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} dtype {} not f32",
                buf.dtype()
            )));
        }
    }
    if q_int.dtype() != DType::U8 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: q_int dtype {} not u8",
            q_int.dtype()
        )));
    }
    let m_us = m as usize;
    let n_us = n as usize;
    let k_us = k as usize;
    let n_groups = n_us * (k_us / group_size as usize);
    // Report which buffer has the wrong size and what size was expected.
    for (name, buf, expected, formula) in [
        ("x", x, m_us * k_us, "M*K"),
        ("q_int", q_int, n_us * k_us, "N*K"),
        ("scales", scales, n_groups, "N * K/group_size"),
        ("biases", biases, n_groups, "N * K/group_size"),
        ("y", y, m_us * n_us, "M*N"),
    ] {
        if buf.element_count() != expected {
            return Err(MlxError::InvalidArgument(format!(
                "{OP}: {name} element_count {} != {formula} = {expected}",
                buf.element_count()
            )));
        }
    }
    if meta.byte_len() < 16 {
        return Err(MlxError::InvalidArgument(format!(
            "{OP}: meta must be >= 16 bytes ([M,N,K,group_size] u32); got {}",
            meta.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline(OP, device)?;
    // The grid tiles y into BM x BN blocks, one threadgroup per block.
    const BM: u64 = 32;
    const BN: u64 = 32;
    let tg_count_x = u64::from(m).div_ceil(BM);
    let tg_count_y = u64::from(n).div_ceil(BN);
    // Equals (BM + BN) * BK * 4 bytes with BK = 64.
    // NOTE(review): presumably one f32 x tile plus one dequantized W tile; confirm
    // against qmm_affine_simd4_gs64.metal.
    const SHMEM_BYTES: u64 = 16384;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, x), (1, q_int), (2, scales), (3, biases), (4, y), (5, meta)],
        &[(0, SHMEM_BYTES)],
        MTLSize::new(tg_count_x, tg_count_y, 1),
        MTLSize::new(128, 1, 1),
    );
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::device::MlxDevice;
fn qmm_affine_t_cpu(
x: &[f32],
q_int: &[u8],
scales: &[f32],
biases: &[f32],
m: usize,
n: usize,
k: usize,
group_size: usize,
) -> Vec<f32> {
let groups_per_row = k / group_size;
let mut y = vec![0.0f32; m * n];
for r in 0..m {
for col in 0..n {
let mut acc = 0.0f64;
for g in 0..groups_per_row {
let s = scales[col * groups_per_row + g] as f64;
let b = biases[col * groups_per_row + g] as f64;
for i in 0..group_size {
let kk = g * group_size + i;
let q = q_int[col * k + kk] as f64;
let w_dq = q * s + b;
acc += (x[r * k + kk] as f64) * w_dq;
}
}
y[r * n + col] = acc as f32;
}
}
y
}
fn alloc_f32(device: &MlxDevice, n: usize, shape: Vec<usize>) -> MlxBuffer {
device
.alloc_buffer(n * 4, DType::F32, shape)
.expect("alloc f32")
}
fn alloc_u8(device: &MlxDevice, n: usize, shape: Vec<usize>) -> MlxBuffer {
device.alloc_buffer(n, DType::U8, shape).expect("alloc u8")
}
fn make_meta(device: &MlxDevice, m: u32, n: u32, k: u32, gs: u32) -> MlxBuffer {
let mut buf = device.alloc_buffer(16, DType::U32, vec![4]).unwrap();
let dst = buf.as_mut_slice::<u32>().unwrap();
dst.copy_from_slice(&[m, n, k, gs]);
buf
}
#[test]
fn qmm_affine_t_matches_cpu_oracle_4bit_g32() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 8usize;
let n = 16usize;
let k = 64usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k))
.map(|i| ((i as f32) * 0.013 - 0.4).sin() * 0.6)
.collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 11 + 3) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row))
.map(|i| 0.05 + (i as f32) * 0.003)
.collect();
let biases: Vec<f32> = (0..(n * groups_per_row))
.map(|i| -0.2 + (i as f32) * 0.011)
.collect();
let mut x_buf = alloc_f32(&device, m * k, vec![m, k]);
x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf
.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&scales);
let mut b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf
.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&biases);
let y_buf = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&x_buf,
&q_buf,
&s_buf,
&b_buf,
&y_buf,
&meta,
m as u32,
n as u32,
k as u32,
gs as u32,
)
.unwrap();
encoder.commit_and_wait().unwrap();
let gpu = y_buf.as_slice::<f32>().unwrap();
let cpu = qmm_affine_t_cpu(&x, &q_int, &scales, &biases, m, n, k, gs);
for i in 0..(m * n) {
assert!(
(gpu[i] - cpu[i]).abs() < 1e-3 * cpu[i].abs().max(1.0),
"y[{i}]: gpu={} cpu={}",
gpu[i],
cpu[i]
);
}
}
#[test]
fn qmm_affine_t_handles_unaligned_m_n() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 7usize;
let n = 13usize;
let k = 64usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k)).map(|i| (i as f32) * 0.011 - 0.5).collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 7) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row))
.map(|_| 0.07)
.collect();
let biases: Vec<f32> = (0..(n * groups_per_row))
.map(|_| -0.1)
.collect();
let mut x_buf = alloc_f32(&device, m * k, vec![m, k]);
x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf
.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&scales);
let mut b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf
.as_mut_slice::<f32>()
.unwrap()
.copy_from_slice(&biases);
let y_buf = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&x_buf,
&q_buf,
&s_buf,
&b_buf,
&y_buf,
&meta,
m as u32,
n as u32,
k as u32,
gs as u32,
)
.unwrap();
encoder.commit_and_wait().unwrap();
let gpu = y_buf.as_slice::<f32>().unwrap();
let cpu = qmm_affine_t_cpu(&x, &q_int, &scales, &biases, m, n, k, gs);
for i in 0..(m * n) {
assert!(
(gpu[i] - cpu[i]).abs() < 1e-3 * cpu[i].abs().max(1.0),
"unaligned y[{i}]: gpu={} cpu={}",
gpu[i],
cpu[i]
);
}
}
#[test]
fn qmm_affine_t_equals_qdq_then_matmul_composition() {
use crate::ops::qdq_affine::dispatch_qdq_affine_forward_f32;
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 5usize;
let n = 9usize;
let k = 96usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k)).map(|i| ((i as f32) * 0.017).cos() * 0.5).collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 13 + 5) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row))
.map(|i| 0.04 + (i as f32) * 0.005)
.collect();
let biases: Vec<f32> = (0..(n * groups_per_row))
.map(|i| -0.05 + (i as f32) * 0.013)
.collect();
let mut x_buf_a = alloc_f32(&device, m * k, vec![m, k]);
x_buf_a.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf_a = alloc_u8(&device, n * k, vec![n, k]);
q_buf_a.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf_a = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf_a.as_mut_slice::<f32>().unwrap().copy_from_slice(&scales);
let mut b_buf_a = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf_a.as_mut_slice::<f32>().unwrap().copy_from_slice(&biases);
let y_a = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&x_buf_a,
&q_buf_a,
&s_buf_a,
&b_buf_a,
&y_a,
&meta,
m as u32,
n as u32,
k as u32,
gs as u32,
)
.unwrap();
encoder.commit_and_wait().unwrap();
let n_total = n * k;
let mut q_buf_b = alloc_u8(&device, n_total, vec![n_total]);
q_buf_b
.as_mut_slice::<u8>()
.unwrap()
.copy_from_slice(&q_int);
let mut s_buf_b = alloc_f32(&device, n * groups_per_row, vec![n * groups_per_row]);
s_buf_b.as_mut_slice::<f32>().unwrap().copy_from_slice(&scales);
let mut b_buf_b = alloc_f32(&device, n * groups_per_row, vec![n * groups_per_row]);
b_buf_b.as_mut_slice::<f32>().unwrap().copy_from_slice(&biases);
let w_dq = alloc_f32(&device, n_total, vec![n_total]);
let mut fwd_meta = device.alloc_buffer(8, DType::U32, vec![2]).unwrap();
fwd_meta
.as_mut_slice::<u32>()
.unwrap()
.copy_from_slice(&[n_total as u32, gs as u32]);
let mut encoder = device.command_encoder().unwrap();
dispatch_qdq_affine_forward_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&q_buf_b,
&s_buf_b,
&b_buf_b,
&w_dq,
&fwd_meta,
gs as u32,
)
.unwrap();
encoder.commit_and_wait().unwrap();
let w_dq_host = w_dq.as_slice::<f32>().unwrap();
let mut y_b = vec![0.0f32; m * n];
for r in 0..m {
for col in 0..n {
let mut acc = 0.0f64;
for kk in 0..k {
acc += (x[r * k + kk] as f64) * (w_dq_host[col * k + kk] as f64);
}
y_b[r * n + col] = acc as f32;
}
}
let y_a_host = y_a.as_slice::<f32>().unwrap();
for i in 0..(m * n) {
assert!(
(y_a_host[i] - y_b[i]).abs() < 1e-3 * y_b[i].abs().max(1.0),
"fused vs composed at i={i}: fused={} composed={}",
y_a_host[i],
y_b[i]
);
}
}
#[test]
fn rejects_k_not_divisible_by_group_size() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let x_buf = alloc_f32(&device, 32, vec![1, 32]);
let q_buf = alloc_u8(&device, 32, vec![1, 32]);
let s_buf = alloc_f32(&device, 1, vec![1]);
let b_buf = alloc_f32(&device, 1, vec![1]);
let y_buf = alloc_f32(&device, 1, vec![1, 1]);
let meta = make_meta(&device, 1, 1, 32, 5);
let mut encoder = device.command_encoder().unwrap();
let res = dispatch_qmm_affine_t_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&x_buf,
&q_buf,
&s_buf,
&b_buf,
&y_buf,
&meta,
1,
1,
32,
5, );
assert!(res.is_err());
}
#[test]
fn qmm_affine_tiled_matches_per_element_kernel() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 32usize;
let n = 64usize;
let k = 128usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k))
.map(|i| ((i as f32) * 0.013 - 0.4).sin() * 0.6)
.collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 11 + 3) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row))
.map(|i| 0.05 + (i as f32) * 0.003)
.collect();
let biases: Vec<f32> = (0..(n * groups_per_row))
.map(|i| -0.2 + (i as f32) * 0.011)
.collect();
let mut x_buf = alloc_f32(&device, m * k, vec![m, k]);
x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&scales);
let mut b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&biases);
let y_pe = alloc_f32(&device, m * n, vec![m, n]);
let y_tl = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_pe, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
dispatch_qmm_affine_t_f32_tiled(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_tl, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
encoder.commit_and_wait().unwrap();
let pe = y_pe.as_slice::<f32>().unwrap();
let tl = y_tl.as_slice::<f32>().unwrap();
for i in 0..(m * n) {
assert!(
(pe[i] - tl[i]).abs() < 1e-4 * pe[i].abs().max(1.0),
"tile vs per-elem at i={i}: pe={} tiled={}",
pe[i], tl[i]
);
}
}
#[test]
fn qmm_affine_tiled_handles_unaligned_m_n() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 23usize;
let n = 47usize;
let k = 64usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k)).map(|i| (i as f32) * 0.011 - 0.5).collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 7) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row)).map(|i| 0.07 + i as f32 * 0.001).collect();
let biases: Vec<f32> = (0..(n * groups_per_row)).map(|i| -0.1 + i as f32 * 0.002).collect();
let mut x_buf = alloc_f32(&device, m * k, vec![m, k]);
x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&scales);
let mut b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&biases);
let y_buf = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32_tiled(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_buf, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
encoder.commit_and_wait().unwrap();
let gpu = y_buf.as_slice::<f32>().unwrap();
let cpu = qmm_affine_t_cpu(&x, &q_int, &scales, &biases, m, n, k, gs);
for i in 0..(m * n) {
assert!(
(gpu[i] - cpu[i]).abs() < 1e-3 * cpu[i].abs().max(1.0),
"tiled unaligned y[{i}]: gpu={} cpu={}",
gpu[i], cpu[i]
);
}
}
#[test]
fn qmm_affine_tiled_rejects_non_32_group_size() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let x_buf = alloc_f32(&device, 64, vec![1, 64]);
let q_buf = alloc_u8(&device, 64, vec![1, 64]);
let s_buf = alloc_f32(&device, 1, vec![1]);
let b_buf = alloc_f32(&device, 1, vec![1]);
let y_buf = alloc_f32(&device, 1, vec![1, 1]);
let meta = make_meta(&device, 1, 1, 64, 64);
let mut encoder = device.command_encoder().unwrap();
let res = dispatch_qmm_affine_t_f32_tiled(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_buf, &meta,
1, 1, 64, 64,
);
assert!(res.is_err(), "tiled must reject group_size != 32");
}
#[test]
fn rejects_dtype_mismatch() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let x_buf = alloc_f32(&device, 32, vec![1, 32]);
let wrong_q = alloc_f32(&device, 32, vec![1, 32]);
let s_buf = alloc_f32(&device, 1, vec![1]);
let b_buf = alloc_f32(&device, 1, vec![1]);
let y_buf = alloc_f32(&device, 1, vec![1, 1]);
let meta = make_meta(&device, 1, 1, 32, 32);
let mut encoder = device.command_encoder().unwrap();
let res = dispatch_qmm_affine_t_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&x_buf,
&wrong_q,
&s_buf,
&b_buf,
&y_buf,
&meta,
1,
1,
32,
32,
);
assert!(res.is_err());
}
#[test]
fn qmm_affine_simd_matches_per_element_kernel() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 32usize;
let n = 64usize;
let k = 128usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k))
.map(|i| ((i as f32) * 0.013 - 0.4).sin() * 0.6)
.collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 11 + 3) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row))
.map(|i| 0.05 + (i as f32) * 0.003)
.collect();
let biases: Vec<f32> = (0..(n * groups_per_row))
.map(|i| -0.2 + (i as f32) * 0.011)
.collect();
let mut x_buf = alloc_f32(&device, m * k, vec![m, k]);
x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&scales);
let mut b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&biases);
let y_pe = alloc_f32(&device, m * n, vec![m, n]);
let y_simd = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_pe, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
dispatch_qmm_affine_t_f32_simd(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_simd, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
encoder.commit_and_wait().unwrap();
let pe = y_pe.as_slice::<f32>().unwrap();
let sm = y_simd.as_slice::<f32>().unwrap();
for i in 0..(m * n) {
assert!(
(pe[i] - sm[i]).abs() < 1e-4 * pe[i].abs().max(1.0),
"simd vs per-elem at i={i}: pe={} simd={}",
pe[i], sm[i]
);
}
}
#[test]
fn qmm_affine_simd_handles_unaligned_m_n() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 13usize;
let n = 21usize;
let k = 64usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k)).map(|i| (i as f32) * 0.011 - 0.5).collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 7) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row)).map(|i| 0.07 + i as f32 * 0.001).collect();
let biases: Vec<f32> = (0..(n * groups_per_row)).map(|i| -0.1 + i as f32 * 0.002).collect();
let mut x_buf = alloc_f32(&device, m * k, vec![m, k]);
x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&scales);
let mut b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&biases);
let y_buf = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32_simd(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_buf, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
encoder.commit_and_wait().unwrap();
let gpu = y_buf.as_slice::<f32>().unwrap();
let cpu = qmm_affine_t_cpu(&x, &q_int, &scales, &biases, m, n, k, gs);
for i in 0..(m * n) {
assert!(
(gpu[i] - cpu[i]).abs() < 1e-3 * cpu[i].abs().max(1.0),
"simd unaligned y[{i}]: gpu={} cpu={}",
gpu[i], cpu[i]
);
}
}
#[test]
fn qmm_affine_simd4_matches_per_element_kernel() {
let device = MlxDevice::new().expect("device");
let mut registry = KernelRegistry::new();
let m = 32usize;
let n = 64usize;
let k = 128usize;
let gs = 32usize;
let groups_per_row = k / gs;
let x: Vec<f32> = (0..(m * k))
.map(|i| ((i as f32) * 0.013 - 0.4).sin() * 0.6)
.collect();
let q_int: Vec<u8> = (0..(n * k)).map(|i| ((i * 11 + 3) % 16) as u8).collect();
let scales: Vec<f32> = (0..(n * groups_per_row))
.map(|i| 0.05 + (i as f32) * 0.003)
.collect();
let biases: Vec<f32> = (0..(n * groups_per_row))
.map(|i| -0.2 + (i as f32) * 0.011)
.collect();
let mut x_buf = alloc_f32(&device, m * k, vec![m, k]);
x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
let mut s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
s_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&scales);
let mut b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&biases);
let y_pe = alloc_f32(&device, m * n, vec![m, n]);
let y_simd4 = alloc_f32(&device, m * n, vec![m, n]);
let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);
let mut encoder = device.command_encoder().unwrap();
dispatch_qmm_affine_t_f32(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_pe, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
dispatch_qmm_affine_t_f32_simd4(
&mut encoder, &mut registry, device.metal_device(),
&x_buf, &q_buf, &s_buf, &b_buf, &y_simd4, &meta,
m as u32, n as u32, k as u32, gs as u32,
).unwrap();
encoder.commit_and_wait().unwrap();
let pe = y_pe.as_slice::<f32>().unwrap();
let sm = y_simd4.as_slice::<f32>().unwrap();
for i in 0..(m * n) {
assert!(
(pe[i] - sm[i]).abs() < 1e-4 * pe[i].abs().max(1.0),
"simd4 vs per-elem at i={i}: pe={} simd4={}",
pe[i], sm[i]
);
}
}
/// Runs the simd4 kernel on a problem whose M (23) and N (47) are odd sizes
/// and checks every output element against the CPU reference
/// `qmm_affine_t_cpu`, using a relative tolerance of 1e-3 (absolute below 1.0).
#[test]
fn qmm_affine_simd4_handles_unaligned_m_n() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();
    let (m, n, k, gs) = (23usize, 47usize, 64usize, 32usize);
    let groups_per_row = k / gs;
    let group_count = n * groups_per_row;

    // Deterministic synthetic inputs; q_int values stay in [0, 16).
    let x: Vec<f32> = (0..m * k).map(|i| (i as f32) * 0.011 - 0.5).collect();
    let q_int: Vec<u8> = (0..n * k).map(|i| ((i * 7) % 16) as u8).collect();
    let scales: Vec<f32> = (0..group_count).map(|i| 0.07 + i as f32 * 0.001).collect();
    let biases: Vec<f32> = (0..group_count).map(|i| -0.1 + i as f32 * 0.002).collect();

    // Allocate an f32 device buffer of `data.len()` elements and fill it.
    let upload_f32 = |data: &[f32], shape: Vec<usize>| {
        let mut buf = alloc_f32(&device, data.len(), shape);
        buf.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        buf
    };
    let x_buf = upload_f32(&x[..], vec![m, k]);
    let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
    q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
    let s_buf = upload_f32(&scales[..], vec![n, groups_per_row]);
    let b_buf = upload_f32(&biases[..], vec![n, groups_per_row]);
    let y_buf = alloc_f32(&device, m * n, vec![m, n]);

    let (m32, n32, k32, gs32) = (m as u32, n as u32, k as u32, gs as u32);
    let meta = make_meta(&device, m32, n32, k32, gs32);
    let mut encoder = device.command_encoder().unwrap();
    dispatch_qmm_affine_t_f32_simd4(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &x_buf,
        &q_buf,
        &s_buf,
        &b_buf,
        &y_buf,
        &meta,
        m32,
        n32,
        k32,
        gs32,
    )
    .unwrap();
    encoder.commit_and_wait().unwrap();

    let gpu = y_buf.as_slice::<f32>().unwrap();
    let cpu = qmm_affine_t_cpu(&x, &q_int, &scales, &biases, m, n, k, gs);
    // Slicing to m*n up front panics on a short output, like direct indexing would.
    let total = m * n;
    for (i, (&g, &c)) in gpu[..total].iter().zip(&cpu[..total]).enumerate() {
        assert!(
            (g - c).abs() < 1e-3 * c.abs().max(1.0),
            "simd4 unaligned y[{i}]: gpu={} cpu={}",
            g, c
        );
    }
}
/// Checks that the simd4 dispatch returns an error for group_size = 64.
/// Going by this test's name, simd4 is expected to accept only group_size = 32.
/// K = 64 and the single-group scale/bias buffers are otherwise consistent,
/// so the group size is the only thing that should fail.
#[test]
fn qmm_affine_simd4_rejects_non_32_group_size() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();
    let (m, n, k, gs) = (1u32, 1u32, 64u32, 64u32);

    let x_buf = alloc_f32(&device, 64, vec![1, 64]);
    let q_buf = alloc_u8(&device, 64, vec![1, 64]);
    let s_buf = alloc_f32(&device, 1, vec![1]);
    let b_buf = alloc_f32(&device, 1, vec![1]);
    let y_buf = alloc_f32(&device, 1, vec![1, 1]);
    let meta = make_meta(&device, m, n, k, gs);

    let mut encoder = device.command_encoder().unwrap();
    let outcome = dispatch_qmm_affine_t_f32_simd4(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &x_buf,
        &q_buf,
        &s_buf,
        &b_buf,
        &y_buf,
        &meta,
        m,
        n,
        k,
        gs,
    );
    assert!(matches!(outcome, Err(_)));
}
/// Cross-checks the group_size = 64 simd4 kernel against the per-element
/// kernel (`dispatch_qmm_affine_t_f32`) on a 32x64x128 problem.
/// Both dispatches are encoded into the same encoder and committed together.
/// Results must agree to a relative 1e-4 (absolute below 1.0).
#[test]
fn qmm_affine_simd4_gs64_matches_per_element_kernel() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();
    let (m, n, k, gs) = (32usize, 64usize, 128usize, 64usize);
    let groups_per_row = k / gs;
    let group_count = n * groups_per_row;

    // Deterministic synthetic inputs; q_int values stay in [0, 16).
    let x: Vec<f32> = (0..m * k)
        .map(|i| ((i as f32) * 0.013 - 0.4).sin() * 0.6)
        .collect();
    let q_int: Vec<u8> = (0..n * k).map(|i| ((i * 11 + 3) % 16) as u8).collect();
    let scales: Vec<f32> = (0..group_count)
        .map(|i| 0.05 + (i as f32) * 0.003)
        .collect();
    let biases: Vec<f32> = (0..group_count)
        .map(|i| -0.2 + (i as f32) * 0.011)
        .collect();

    // Allocate an f32 device buffer of `data.len()` elements and fill it.
    let upload_f32 = |data: &[f32], shape: Vec<usize>| {
        let mut buf = alloc_f32(&device, data.len(), shape);
        buf.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        buf
    };
    let x_buf = upload_f32(&x[..], vec![m, k]);
    let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
    q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
    let s_buf = upload_f32(&scales[..], vec![n, groups_per_row]);
    let b_buf = upload_f32(&biases[..], vec![n, groups_per_row]);
    let y_pe = alloc_f32(&device, m * n, vec![m, n]);
    let y_simd4 = alloc_f32(&device, m * n, vec![m, n]);

    let (m32, n32, k32, gs32) = (m as u32, n as u32, k as u32, gs as u32);
    let meta = make_meta(&device, m32, n32, k32, gs32);
    let mut encoder = device.command_encoder().unwrap();
    dispatch_qmm_affine_t_f32(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &x_buf,
        &q_buf,
        &s_buf,
        &b_buf,
        &y_pe,
        &meta,
        m32,
        n32,
        k32,
        gs32,
    )
    .unwrap();
    dispatch_qmm_affine_t_f32_simd4_gs64(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &x_buf,
        &q_buf,
        &s_buf,
        &b_buf,
        &y_simd4,
        &meta,
        m32,
        n32,
        k32,
        gs32,
    )
    .unwrap();
    encoder.commit_and_wait().unwrap();

    let pe = y_pe.as_slice::<f32>().unwrap();
    let sm = y_simd4.as_slice::<f32>().unwrap();
    // Slicing to m*n up front panics on a short output, like direct indexing would.
    let total = m * n;
    for (i, (&reference, &candidate)) in pe[..total].iter().zip(&sm[..total]).enumerate() {
        assert!(
            (reference - candidate).abs() < 1e-4 * reference.abs().max(1.0),
            "simd4_gs64 vs per-elem at i={i}: pe={} simd4_gs64={}",
            reference, candidate
        );
    }
}
/// Runs the group_size = 64 simd4 kernel on a problem whose M (23) and
/// N (47) are odd sizes and compares every output element against the CPU
/// reference `qmm_affine_t_cpu`.
/// The tolerance is a relative 1e-3 (absolute below 1.0).
#[test]
fn qmm_affine_simd4_gs64_handles_unaligned_m_n() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();
    let (m, n, k, gs) = (23usize, 47usize, 128usize, 64usize);
    let groups_per_row = k / gs;
    let group_count = n * groups_per_row;

    // Deterministic synthetic inputs; q_int values stay in [0, 16).
    let x: Vec<f32> = (0..m * k).map(|i| (i as f32) * 0.011 - 0.5).collect();
    let q_int: Vec<u8> = (0..n * k).map(|i| ((i * 7) % 16) as u8).collect();
    let scales: Vec<f32> = (0..group_count).map(|i| 0.07 + i as f32 * 0.001).collect();
    let biases: Vec<f32> = (0..group_count).map(|i| -0.1 + i as f32 * 0.002).collect();

    // Allocate an f32 device buffer of `data.len()` elements and fill it.
    let upload_f32 = |data: &[f32], shape: Vec<usize>| {
        let mut buf = alloc_f32(&device, data.len(), shape);
        buf.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        buf
    };
    let x_buf = upload_f32(&x[..], vec![m, k]);
    let mut q_buf = alloc_u8(&device, n * k, vec![n, k]);
    q_buf.as_mut_slice::<u8>().unwrap().copy_from_slice(&q_int);
    let s_buf = upload_f32(&scales[..], vec![n, groups_per_row]);
    let b_buf = upload_f32(&biases[..], vec![n, groups_per_row]);
    let y_buf = alloc_f32(&device, m * n, vec![m, n]);

    let (m32, n32, k32, gs32) = (m as u32, n as u32, k as u32, gs as u32);
    let meta = make_meta(&device, m32, n32, k32, gs32);
    let mut encoder = device.command_encoder().unwrap();
    dispatch_qmm_affine_t_f32_simd4_gs64(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &x_buf,
        &q_buf,
        &s_buf,
        &b_buf,
        &y_buf,
        &meta,
        m32,
        n32,
        k32,
        gs32,
    )
    .unwrap();
    encoder.commit_and_wait().unwrap();

    let gpu = y_buf.as_slice::<f32>().unwrap();
    let cpu = qmm_affine_t_cpu(&x, &q_int, &scales, &biases, m, n, k, gs);
    // Slicing to m*n up front panics on a short output, like direct indexing would.
    let total = m * n;
    for (i, (&g, &c)) in gpu[..total].iter().zip(&cpu[..total]).enumerate() {
        assert!(
            (g - c).abs() < 1e-3 * c.abs().max(1.0),
            "simd4_gs64 unaligned y[{i}]: gpu={} cpu={}",
            g, c
        );
    }
}
/// Checks that the gs64-specialised simd4 dispatch returns an error when it
/// receives group_size = 32.
///
/// The problem is otherwise well formed:
/// - K = 64 is divisible by both 32 and 64, so a K-alignment check in the
///   gs64 path cannot be what produces the error.
/// - Scale and bias buffers hold K / 32 = 2 entries per row, the correct size
///   for group_size = 32.
///
/// That leaves the group size as the only violated precondition, so this test
/// fails if the group_size check is ever dropped.
#[test]
fn qmm_affine_simd4_gs64_rejects_non_64_group_size() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();
    let (m, n, k, gs) = (1usize, 1usize, 64usize, 32usize);
    let groups_per_row = k / gs;

    let x_buf = alloc_f32(&device, m * k, vec![m, k]);
    let q_buf = alloc_u8(&device, n * k, vec![n, k]);
    let s_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
    let b_buf = alloc_f32(&device, n * groups_per_row, vec![n, groups_per_row]);
    let y_buf = alloc_f32(&device, m * n, vec![m, n]);
    let meta = make_meta(&device, m as u32, n as u32, k as u32, gs as u32);

    let mut encoder = device.command_encoder().unwrap();
    let res = dispatch_qmm_affine_t_f32_simd4_gs64(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &x_buf,
        &q_buf,
        &s_buf,
        &b_buf,
        &y_buf,
        &meta,
        m as u32,
        n as u32,
        k as u32,
        gs as u32,
    );
    assert!(
        res.is_err(),
        "simd4_gs64 dispatch must reject group_size={gs} (K={k})"
    );
}
/// Checks that the simd dispatch returns an error for group_size = 64.
/// Going by this test's name, the simd path is expected to accept only
/// group_size = 32.
/// K = 64 and the single-group scale/bias buffers are otherwise consistent.
#[test]
fn qmm_affine_simd_rejects_non_32_group_size() {
    let device = MlxDevice::new().expect("device");
    let mut registry = KernelRegistry::new();
    let (m, n, k, gs) = (1u32, 1u32, 64u32, 64u32);

    let x_buf = alloc_f32(&device, 64, vec![1, 64]);
    let q_buf = alloc_u8(&device, 64, vec![1, 64]);
    let s_buf = alloc_f32(&device, 1, vec![1]);
    let b_buf = alloc_f32(&device, 1, vec![1]);
    let y_buf = alloc_f32(&device, 1, vec![1, 1]);
    let meta = make_meta(&device, m, n, k, gs);

    let mut encoder = device.command_encoder().unwrap();
    let rejected = dispatch_qmm_affine_t_f32_simd(
        &mut encoder,
        &mut registry,
        device.metal_device(),
        &x_buf,
        &q_buf,
        &s_buf,
        &b_buf,
        &y_buf,
        &meta,
        m,
        n,
        k,
        gs,
    )
    .is_err();
    assert!(rejected);
}
}