#![cfg(all(target_os = "macos", feature = "metal"))]
use std::sync::OnceLock;
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, MTLSize,
};
/// Size in bytes of one packed Q4_K super-block in the kernel's input buffer.
///
/// The kernel is dispatched with one thread per super-block. The tests check that
/// candle's `GgmlDType::Q4K` encoding produces exactly this many bytes per block.
pub const Q4_K_BLOCK_BYTES: usize = 144;
/// Number of values one Q4_K super-block dequantizes to (written as f16 to the output buffer).
pub const Q4_K_BLOCK_ELEMENTS: usize = 256;
/// Metal Shading Language source of the dequantization kernel, embedded at build time.
const SHADER_SRC: &str = include_str!("q4_k_dequant.metal");
/// Name of the kernel function looked up in the library compiled from `SHADER_SRC`.
const KERNEL_NAME: &str = "dequantize_q4_k_f16";
/// Compiled dequantization pipelines, keyed by the `registryID` of their device.
///
/// A `ComputePipelineState` belongs to the `MTLDevice` that built it. A single
/// process-wide pipeline would therefore be wrong as soon as a second device is used,
/// for example an integrated and a discrete GPU. Each pipeline is leaked exactly once so
/// `pipeline` can keep returning `&'static` references. The number of leaked entries is
/// bounded by the number of distinct devices seen.
static PIPELINES: OnceLock<
    std::sync::Mutex<std::collections::HashMap<u64, &'static ComputePipelineState>>,
> = OnceLock::new();

/// Returns the Q4_K → f16 pipeline for `device`, compiling it on first use for that device.
///
/// # Panics
///
/// Panics if the embedded shader fails to compile, if `KERNEL_NAME` is missing from the
/// compiled library, or if the pipeline state cannot be created. The shader is embedded
/// at build time, so any of these failures is a bug rather than a recoverable condition.
fn pipeline(device: &Device) -> &'static ComputePipelineState {
    let cache = PIPELINES
        .get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()));
    // A panic during compilation poisons the lock, but it happens before anything is
    // inserted, so the map is still consistent. Recover instead of cascading panics.
    let mut cache = cache.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    let pipe = *cache.entry(device.registry_id()).or_insert_with(|| {
        let lib = device
            .new_library_with_source(SHADER_SRC, &CompileOptions::new())
            .expect("compile q4_k_dequant.metal");
        let function = lib
            .get_function(KERNEL_NAME, None)
            .expect("find dequantize_q4_k_f16 function in library");
        let built = device
            .new_compute_pipeline_state_with_function(&function)
            .expect("build compute pipeline");
        let leaked: &'static ComputePipelineState = Box::leak(Box::new(built));
        leaked
    });
    pipe
}
/// Encodes a Q4_K → f16 dequantization pass into `cmd` on a fresh compute encoder.
///
/// `blocks_buf` supplies `n_blocks` packed Q4_K super-blocks and `out_buf` receives the
/// dequantized f16 values. The command buffer is only encoded, never committed, so the
/// caller decides when to submit it. With `n_blocks == 0` no encoder is created at all.
pub fn encode_dequant_q4_k_to_f16(
    device: &Device,
    cmd: &CommandBufferRef,
    blocks_buf: &Buffer,
    out_buf: &Buffer,
    n_blocks: usize,
) {
    if n_blocks > 0 {
        // Resolve (and possibly compile) the pipeline before opening the encoder.
        let pipeline_state = pipeline(device);
        let encoder = cmd.new_compute_command_encoder();
        encode_dispatch(encoder, pipeline_state, blocks_buf, out_buf, n_blocks);
        encoder.end_encoding();
    }
}
/// Records a Q4_K → f16 dequantization dispatch on an encoder the caller already owns.
///
/// Unlike [`encode_dequant_q4_k_to_f16`], this neither creates nor ends `enc`. The caller
/// remains responsible for calling `end_encoding`. The dispatch rebinds the encoder's
/// pipeline and buffer slots 0 and 1. Nothing is recorded when `n_blocks` is zero.
pub fn dispatch_dequant_q4_k_on_encoder(
    device: &Device,
    enc: &ComputeCommandEncoderRef,
    blocks_buf: &Buffer,
    out_buf: &Buffer,
    n_blocks: usize,
) {
    match n_blocks {
        0 => {}
        count => encode_dispatch(enc, pipeline(device), blocks_buf, out_buf, count),
    }
}
fn encode_dispatch(
enc: &ComputeCommandEncoderRef,
pipe: &ComputePipelineState,
blocks_buf: &Buffer,
out_buf: &Buffer,
n_blocks: usize,
) {
enc.set_compute_pipeline_state(pipe);
enc.set_buffer(0, Some(blocks_buf), 0);
enc.set_buffer(1, Some(out_buf), 0);
let threads_per_group = pipe.thread_execution_width().min(64) as u64;
let total_threads = n_blocks as u64;
let tg = MTLSize::new(threads_per_group, 1, 1);
let grid = MTLSize::new(total_threads, 1, 1);
enc.dispatch_threads(grid, tg);
}
/// Dequantizes `n_blocks` Q4_K super-blocks into `out_buf` and waits for the GPU.
///
/// Creates a new command buffer on `queue` and encodes the pass via
/// [`encode_dequant_q4_k_to_f16`]. It then commits the buffer and blocks the calling
/// thread until the buffer completes. With `n_blocks == 0` an empty command buffer is
/// still committed and awaited.
///
/// NOTE(review): the command buffer's final status is not inspected, so a GPU-side
/// failure is not reported to the caller.
pub fn dequant_q4_k_to_f16_blocking(
    device: &Device,
    queue: &metal::CommandQueue,
    blocks_buf: &Buffer,
    out_buf: &Buffer,
    n_blocks: usize,
) {
    let command_buffer = queue.new_command_buffer();
    encode_dequant_q4_k_to_f16(device, command_buffer, blocks_buf, out_buf, n_blocks);
    // Submit, then park this thread until the GPU reports completion.
    command_buffer.commit();
    command_buffer.wait_until_completed();
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::quantized::{GgmlDType, QTensor};
    use candle_core::{Device as CandleDevice, Tensor};
    use half::f16;
    use metal::MTLResourceOptions;

    /// True when `a` and `b` agree within `abs` absolutely, or within `rel` relative to
    /// the larger magnitude of the two.
    fn close_enough(a: f32, b: f32, rel: f32, abs: f32) -> bool {
        let diff = (a - b).abs();
        diff <= abs || diff <= rel * b.abs().max(a.abs())
    }

    /// Quantizes `raw` to Q4_K on the CPU with candle.
    ///
    /// Returns the packed super-block bytes together with candle's own dequantization
    /// of them. That dequantization is the reference the Metal kernel must reproduce.
    fn quantize_with_candle(raw: Vec<f32>, n_blocks: usize) -> (Vec<u8>, Vec<f32>) {
        let n_elem = raw.len();
        let cpu = CandleDevice::Cpu;
        let t = Tensor::from_vec(raw, n_elem, &cpu).unwrap();
        let qt = QTensor::quantize(&t, GgmlDType::Q4K).unwrap();
        let reference: Vec<f32> = qt
            .dequantize(&cpu)
            .unwrap()
            .flatten_all()
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();
        assert_eq!(reference.len(), n_elem);
        let bytes = qt.data().expect("read QTensor bytes").to_vec();
        assert_eq!(
            bytes.len(),
            n_blocks * Q4_K_BLOCK_BYTES,
            "expected {n_blocks} super-blocks × 144 bytes"
        );
        (bytes, reference)
    }

    /// Runs the Metal kernel over `bytes` (`n_blocks` packed Q4_K super-blocks).
    ///
    /// Returns the f16 output widened to f32, or `None` (after logging) when the system
    /// has no Metal device.
    fn dequant_on_metal(bytes: &[u8], n_blocks: usize) -> Option<Vec<f32>> {
        let Some(device) = Device::system_default() else {
            eprintln!("no Metal device available — skipping");
            return None;
        };
        let n_elem = n_blocks * Q4_K_BLOCK_ELEMENTS;
        let queue = device.new_command_queue();
        let blocks_buf = device.new_buffer_with_data(
            bytes.as_ptr() as *const _,
            bytes.len() as u64,
            MTLResourceOptions::StorageModeShared,
        );
        let out_buf = device.new_buffer(
            (n_elem * std::mem::size_of::<f16>()) as u64,
            MTLResourceOptions::StorageModeShared,
        );
        dequant_q4_k_to_f16_blocking(&device, &queue, &blocks_buf, &out_buf, n_blocks);
        // SAFETY: `out_buf` uses shared (CPU-visible) storage sized for exactly `n_elem`
        // f16 values. The blocking call above waited for the command buffer to complete,
        // so the GPU is no longer writing to it. The slice does not outlive `out_buf`.
        let out: &[f16] =
            unsafe { std::slice::from_raw_parts(out_buf.contents() as *const f16, n_elem) };
        Some(out.iter().map(|h| h.to_f32()).collect())
    }

    #[test]
    fn metal_q4_k_dequant_matches_candle_cpu_reference() {
        let n_blocks: usize = 4;
        let n_elem: usize = n_blocks * Q4_K_BLOCK_ELEMENTS;
        let raw: Vec<f32> = (0..n_elem)
            .map(|i| ((i as f32 * 0.0173).sin() + (i as f32 * 0.0091).cos()) * 0.5)
            .collect();
        let (bytes, ref_f32) = quantize_with_candle(raw, n_blocks);
        let Some(our_f32) = dequant_on_metal(&bytes, n_blocks) else {
            return;
        };

        let mut max_abs = 0.0_f32;
        let mut max_rel = 0.0_f32;
        let mut mismatches = 0usize;
        for (i, (&our, &refv)) in our_f32.iter().zip(&ref_f32).enumerate() {
            let diff = (our - refv).abs();
            max_abs = max_abs.max(diff);
            // Floor the denominator so exact zeros don't yield inf/NaN in the report.
            let rel = diff / refv.abs().max(our.abs()).max(1e-6);
            max_rel = max_rel.max(rel);
            if !close_enough(our, refv, 1e-2, 1e-3) {
                mismatches += 1;
                // Print only the first few offenders to keep the log readable.
                if mismatches < 5 {
                    eprintln!("[{i}] our={our} ref={refv} diff={diff}");
                }
            }
        }
        eprintln!(
            "max_abs_diff={max_abs:.6} max_rel_diff={max_rel:.6} mismatches={mismatches}/{n_elem}"
        );
        assert!(
            mismatches == 0,
            "{mismatches}/{n_elem} elements outside fp16 tolerance"
        );
    }

    #[test]
    fn metal_q4_k_dequant_handles_thousands_of_blocks() {
        let n_blocks: usize = 4096;
        let n_elem: usize = n_blocks * Q4_K_BLOCK_ELEMENTS;
        let raw: Vec<f32> = (0..n_elem)
            .map(|i| (i as f32 * 1.7e-4).sin() * 0.7)
            .collect();
        let (bytes, ref_f32) = quantize_with_candle(raw, n_blocks);
        let Some(our_f32) = dequant_on_metal(&bytes, n_blocks) else {
            return;
        };

        // An element diverges only if it misses both the absolute and relative bound.
        let mismatches = our_f32
            .iter()
            .zip(&ref_f32)
            .filter(|&(&our, &r)| {
                let diff = (our - r).abs();
                diff > 1e-2 && diff / r.abs().max(1e-6) > 1e-2
            })
            .count();
        assert_eq!(
            mismatches, 0,
            "{mismatches}/{n_elem} elements diverged from candle CPU reference"
        );
    }
}