use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Shader source text embedded at compile time from
/// `../shaders/qdq_legacy.metal`.
///
/// [`register`] associates this same source string with both the
/// `qdq_q4_0_f32` and `qdq_q8_0_f32` kernel names.
pub static QDQ_LEGACY_SHADER_SOURCE: &str = include_str!("../shaders/qdq_legacy.metal");
/// Registers the legacy quantize-dequantize kernels with `registry`.
///
/// Both kernel names (`qdq_q4_0_f32`, then `qdq_q8_0_f32`) are registered
/// against the same shader source, [`QDQ_LEGACY_SHADER_SOURCE`].
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in ["qdq_q4_0_f32", "qdq_q8_0_f32"] {
        registry.register_source(kernel_name, QDQ_LEGACY_SHADER_SOURCE);
    }
}
/// Number of elements in one quantization block.
///
/// Input element counts must be a multiple of this value. The dispatch
/// functions in this module launch one threadgroup per block and also use this
/// value as the threadgroup width.
pub const QDQ_BLOCK_SIZE: u32 = 32;
/// Encodes a `qdq_q4_0_f32` kernel dispatch that reads `input` and writes
/// `output`.
///
/// The dispatch uses one threadgroup per [`QDQ_BLOCK_SIZE`]-element block,
/// with [`QDQ_BLOCK_SIZE`] threads in each threadgroup. This function only
/// encodes work on `encoder`; it does not commit it.
///
/// # Errors
///
/// Returns the error produced by `validate` when the buffers do not satisfy
/// the shape/dtype preconditions, or the error from
/// `KernelRegistry::get_pipeline` when the pipeline cannot be obtained.
pub fn dispatch_qdq_q4_0_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
) -> Result<()> {
    const KERNEL_NAME: &str = "qdq_q4_0_f32";
    // 256 bytes of shared memory at index 0 — presumably 64 four-byte scratch
    // slots for the kernel's per-block reductions; confirm against the shader.
    const SHARED_MEM_BYTES: u64 = 64 * 4;

    validate(input, output, KERNEL_NAME)?;

    // `validate` guarantees the element count is a non-zero multiple of the
    // block size, so this division is exact.
    let threads_per_group = u64::from(QDQ_BLOCK_SIZE);
    let group_count = input.element_count() as u64 / threads_per_group;
    let grid = MTLSize::new(group_count, 1, 1);
    let group = MTLSize::new(threads_per_group, 1, 1);

    let pipeline = registry.get_pipeline(KERNEL_NAME, device)?;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, input), (1, output)],
        &[(0, SHARED_MEM_BYTES)],
        grid,
        group,
    );
    Ok(())
}
/// Encodes a `qdq_q8_0_f32` kernel dispatch that reads `input` and writes
/// `output`.
///
/// The dispatch uses one threadgroup per [`QDQ_BLOCK_SIZE`]-element block,
/// with [`QDQ_BLOCK_SIZE`] threads in each threadgroup. This function only
/// encodes work on `encoder`; it does not commit it.
///
/// # Errors
///
/// Returns the error produced by `validate` when the buffers do not satisfy
/// the shape/dtype preconditions, or the error from
/// `KernelRegistry::get_pipeline` when the pipeline cannot be obtained.
pub fn dispatch_qdq_q8_0_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
) -> Result<()> {
    const KERNEL_NAME: &str = "qdq_q8_0_f32";
    // 128 bytes of shared memory at index 0 — presumably 32 four-byte scratch
    // slots for the kernel's per-block reduction; confirm against the shader.
    const SHARED_MEM_BYTES: u64 = 32 * 4;

    validate(input, output, KERNEL_NAME)?;

    // `validate` guarantees the element count is a non-zero multiple of the
    // block size, so this division is exact.
    let threads_per_group = u64::from(QDQ_BLOCK_SIZE);
    let group_count = input.element_count() as u64 / threads_per_group;
    let grid = MTLSize::new(group_count, 1, 1);
    let group = MTLSize::new(threads_per_group, 1, 1);

    let pipeline = registry.get_pipeline(KERNEL_NAME, device)?;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, input), (1, output)],
        &[(0, SHARED_MEM_BYTES)],
        grid,
        group,
    );
    Ok(())
}
/// Checks the buffer preconditions shared by both QDQ dispatch functions.
///
/// The checks run in the order below; the first one that fails is reported as
/// [`MlxError::InvalidArgument`] with a message prefixed by `op_name`:
///
/// 1. `input` has at least one element.
/// 2. `input`'s element count is a multiple of [`QDQ_BLOCK_SIZE`].
/// 3. `output` has the same element count as `input`.
/// 4. Both buffers have dtype [`DType::F32`].
fn validate(input: &MlxBuffer, output: &MlxBuffer, op_name: &str) -> Result<()> {
    let n = input.element_count();
    let block_len = QDQ_BLOCK_SIZE as usize;

    // Build the message for the first violated precondition, if any.
    let problem = if n == 0 {
        Some(format!("{op_name}: input must have at least one element"))
    } else if n % block_len != 0 {
        Some(format!(
            "{op_name}: input element count ({n}) must be divisible by block size ({})",
            QDQ_BLOCK_SIZE
        ))
    } else if output.element_count() != n {
        Some(format!(
            "{op_name}: output element count {} != input element count {}",
            output.element_count(),
            n
        ))
    } else if !(input.dtype() == DType::F32 && output.dtype() == DType::F32) {
        Some(format!(
            "{op_name}: only f32 supported; got input={} output={}",
            input.dtype(),
            output.dtype()
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(MlxError::InvalidArgument(message)),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    //! GPU-vs-CPU tests for the QDQ kernels.
    //!
    //! Each numeric test runs a kernel on the device and requires the result to
    //! match a scalar CPU oracle bit-for-bit. The remaining tests check that
    //! the dispatch functions reject malformed buffers.

    use super::*;
    use crate::device::MlxDevice;

    /// Signature shared by [`dispatch_qdq_q4_0_f32`] and
    /// [`dispatch_qdq_q8_0_f32`], so one harness can drive either kernel.
    type DispatchFn = fn(
        &mut CommandEncoder,
        &mut KernelRegistry,
        &metal::DeviceRef,
        &MlxBuffer,
        &MlxBuffer,
    ) -> Result<()>;

    /// Scalar CPU oracle for the Q4_0 quantize-dequantize round trip.
    ///
    /// Per 32-element block: pick the element with the largest magnitude
    /// (keeping its sign), derive the scale `d = max / -8`, quantize each value
    /// to `0..=15` by truncating `v / d + 8.5`, then dequantize with the scale
    /// rounded through `f16`.
    fn qdq_q4_0_cpu_oracle(input: &[f32]) -> Vec<f32> {
        const QK: usize = 32;
        assert!(input.len() % QK == 0);
        let mut out = vec![0f32; input.len()];
        for (block, out_block) in input.chunks_exact(QK).zip(out.chunks_exact_mut(QK)) {
            // `max` is the signed value whose magnitude is largest; ties keep
            // the earliest element because the comparison is strict.
            let mut amax = 0.0f32;
            let mut max = 0.0f32;
            for &v in block {
                let av = v.abs();
                if av > amax {
                    amax = av;
                    max = v;
                }
            }
            let d = max / -8.0;
            // An all-zero block has d == 0; use id = 0 to avoid dividing by zero.
            let id = if d == 0.0 { 0.0 } else { 1.0 / d };
            // The dequantization scale is rounded through half precision.
            let d_h = half::f16::from_f32(d).to_f32();
            for (o, &v) in out_block.iter_mut().zip(block) {
                let scaled = v * id + 8.5;
                let q = (scaled as i32).clamp(0, 15);
                *o = (q - 8) as f32 * d_h;
            }
        }
        out
    }

    /// Scalar CPU oracle for the Q8_0 quantize-dequantize round trip.
    ///
    /// Per 32-element block: `d = amax / 127`, quantize each value with
    /// round-half-away-from-zero of `v / d` clamped to `-128..=127`, then
    /// dequantize with the scale rounded through `f16`.
    fn qdq_q8_0_cpu_oracle(input: &[f32]) -> Vec<f32> {
        const QK: usize = 32;
        assert!(input.len() % QK == 0);
        let mut out = vec![0f32; input.len()];
        for (block, out_block) in input.chunks_exact(QK).zip(out.chunks_exact_mut(QK)) {
            let amax = block.iter().fold(0.0f32, |a, &v| a.max(v.abs()));
            let d = amax / 127.0;
            // An all-zero block has d == 0; use id = 0 to avoid dividing by zero.
            let id = if d == 0.0 { 0.0 } else { 1.0 / d };
            let d_h = half::f16::from_f32(d).to_f32();
            for (o, &v) in out_block.iter_mut().zip(block) {
                let q = ((v * id).round() as i32).clamp(-128, 127);
                *o = (q as f32) * d_h;
            }
        }
        out
    }

    /// Runs `dispatch` over `input` on a freshly created device and returns a
    /// copy of the output buffer.
    ///
    /// `dispatch_name` is used as the panic message if the dispatch call fails.
    fn run_qdq(input: &[f32], dispatch: DispatchFn, dispatch_name: &str) -> Vec<f32> {
        let device = MlxDevice::new().expect("MlxDevice::new");
        let n_bytes = input.len() * std::mem::size_of::<f32>();
        let mut in_buf = device
            .alloc_buffer(n_bytes, DType::F32, vec![input.len()])
            .expect("alloc input");
        let out_buf = device
            .alloc_buffer(n_bytes, DType::F32, vec![input.len()])
            .expect("alloc output");
        {
            let slice: &mut [f32] = in_buf.as_mut_slice().expect("input as_mut_slice");
            slice.copy_from_slice(input);
        }
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("command_encoder");
        dispatch(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
        )
        .expect(dispatch_name);
        encoder.commit_and_wait().expect("commit_and_wait");
        out_buf.as_slice::<f32>().expect("output as_slice").to_vec()
    }

    /// Runs the Q4_0 kernel over `input` and returns the device output.
    fn run_qdq_q4_0(input: &[f32]) -> Vec<f32> {
        run_qdq(input, dispatch_qdq_q4_0_f32, "dispatch_qdq_q4_0_f32")
    }

    /// Runs the Q8_0 kernel over `input` and returns the device output.
    fn run_qdq_q8_0(input: &[f32]) -> Vec<f32> {
        run_qdq(input, dispatch_qdq_q8_0_f32, "dispatch_qdq_q8_0_f32")
    }

    /// One block of evenly spaced values, symmetric around zero.
    fn ramp_block() -> Vec<f32> {
        (0..32).map(|i| (i as f32 - 15.5) * 0.137).collect()
    }

    /// Eight blocks of sine samples, each block with a different amplitude.
    fn multi_block_sin() -> Vec<f32> {
        let mut input = Vec::with_capacity(256);
        for blk in 0..8 {
            let scale = (blk as f32 + 1.0) * 0.5;
            for i in 0..32 {
                let v = ((i as f32 * 17.0 + blk as f32 * 31.0).sin()) * scale;
                input.push(v);
            }
        }
        input
    }

    /// Panics with the first differing index unless `gpu` and `cpu` have the
    /// same length and identical bit patterns element by element.
    fn assert_byte_identical(label: &str, gpu: &[f32], cpu: &[f32]) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            if g.to_bits() != c.to_bits() {
                panic!(
                    "{label}: bit-mismatch at index {i}: gpu={} (0x{:08x}) cpu={} (0x{:08x})",
                    g,
                    g.to_bits(),
                    c,
                    c.to_bits()
                );
            }
        }
    }

    #[test]
    fn qdq_q4_0_byte_identical_one_block() {
        let input = ramp_block();
        let gpu = run_qdq_q4_0(&input);
        let cpu = qdq_q4_0_cpu_oracle(&input);
        assert_byte_identical("q4_0 single-block ramp", &gpu, &cpu);
    }

    #[test]
    fn qdq_q4_0_byte_identical_random_multi_block() {
        let input = multi_block_sin();
        let gpu = run_qdq_q4_0(&input);
        let cpu = qdq_q4_0_cpu_oracle(&input);
        assert_byte_identical("q4_0 multi-block sin", &gpu, &cpu);
    }

    #[test]
    fn qdq_q4_0_zero_block() {
        let input = vec![0.0_f32; 32];
        let gpu = run_qdq_q4_0(&input);
        let cpu = qdq_q4_0_cpu_oracle(&input);
        assert_byte_identical("q4_0 all-zero block", &gpu, &cpu);
        assert!(gpu.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn qdq_q4_0_single_outlier_block() {
        let mut input = vec![0.0_f32; 32];
        input[7] = 5.25;
        let gpu = run_qdq_q4_0(&input);
        let cpu = qdq_q4_0_cpu_oracle(&input);
        assert_byte_identical("q4_0 single outlier", &gpu, &cpu);
    }

    #[test]
    fn qdq_q4_0_negative_amax_block() {
        // The largest-magnitude element is negative, which makes the scale
        // `d = max / -8` positive.
        let mut input: Vec<f32> = (0..32).map(|i| (i as f32 - 15.5) * 0.05).collect();
        input[3] = -2.0;
        let gpu = run_qdq_q4_0(&input);
        let cpu = qdq_q4_0_cpu_oracle(&input);
        assert_byte_identical("q4_0 negative amax", &gpu, &cpu);
    }

    #[test]
    fn qdq_q8_0_byte_identical_one_block() {
        let input = ramp_block();
        let gpu = run_qdq_q8_0(&input);
        let cpu = qdq_q8_0_cpu_oracle(&input);
        assert_byte_identical("q8_0 single-block ramp", &gpu, &cpu);
    }

    #[test]
    fn qdq_q8_0_byte_identical_random_multi_block() {
        let input = multi_block_sin();
        let gpu = run_qdq_q8_0(&input);
        let cpu = qdq_q8_0_cpu_oracle(&input);
        assert_byte_identical("q8_0 multi-block sin", &gpu, &cpu);
    }

    #[test]
    fn qdq_q8_0_zero_block() {
        let input = vec![0.0_f32; 32];
        let gpu = run_qdq_q8_0(&input);
        let cpu = qdq_q8_0_cpu_oracle(&input);
        assert_byte_identical("q8_0 all-zero block", &gpu, &cpu);
        assert!(gpu.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn qdq_q8_0_signed_extremes() {
        // With amax == 1.0 the scale is 1/127, so +/-0.5/127 sit at or near
        // the +/-0.5 rounding boundary after scaling.
        let mut input = vec![0.0_f32; 32];
        input[0] = 1.0;
        input[1] = 0.5 / 127.0;
        input[2] = -0.5 / 127.0;
        let gpu = run_qdq_q8_0(&input);
        let cpu = qdq_q8_0_cpu_oracle(&input);
        assert_byte_identical("q8_0 signed extremes", &gpu, &cpu);
    }

    #[test]
    fn input_size_must_be_block_aligned() {
        // 33 elements is not a multiple of the 32-element block size.
        let device = MlxDevice::new().expect("MlxDevice");
        let in_buf = device
            .alloc_buffer(33 * 4, DType::F32, vec![33])
            .expect("alloc");
        let out_buf = device
            .alloc_buffer(33 * 4, DType::F32, vec![33])
            .expect("alloc");
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("command_encoder");
        let err = dispatch_qdq_q4_0_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
        )
        .expect_err("must reject non-block-aligned input");
        let msg = format!("{err}");
        assert!(
            msg.contains("divisible by block size"),
            "wrong error message: {msg}"
        );
    }

    #[test]
    fn output_size_must_match_input() {
        // The input is block-aligned, but the output holds twice as many
        // elements.
        let device = MlxDevice::new().expect("MlxDevice");
        let in_buf = device
            .alloc_buffer(32 * 4, DType::F32, vec![32])
            .expect("alloc input");
        let out_buf = device
            .alloc_buffer(64 * 4, DType::F32, vec![64])
            .expect("alloc output");
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("command_encoder");
        let err = dispatch_qdq_q4_0_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
        )
        .expect_err("must reject mismatched output size");
        let msg = format!("{err}");
        assert!(
            msg.contains("output element count"),
            "wrong error message: {msg}"
        );
    }
}