use super::{metal_mmap_trace_enabled, slice_is_in_registered_mmap, st, MetalBackend};
use crate::backend::{Backend, GgufQuantType};
use ferrum_types::{FerrumError, Result};
use metal::MTLResourceOptions;
use std::ffi::c_void;
use std::sync::OnceLock;
/// Accumulated wall-clock time, in microseconds, of profiled `metal_gemm_quant_dispatch`
/// calls (only updated when `debug_per_call_flush()` is true).
static QUANT_GEMM_TIME_US: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0_u64);
/// Number of profiled `metal_gemm_quant_dispatch` calls recorded so far.
static QUANT_GEMM_CALLS: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0_u64);
/// `m` (activation rows) of the most recent profiled call.
static QUANT_GEMM_LAST_M: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0_u64);
/// `n` (weight rows) of the most recent profiled call.
static QUANT_GEMM_LAST_N: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0_u64);
/// `k` (weight columns) of the most recent profiled call.
static QUANT_GEMM_LAST_K: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(0_u64);
/// Returns `true` when the `FERRUM_METAL_QUANT_PROFILE` environment variable is set
/// (to any value, including an empty or non-UTF-8 one).
///
/// The answer is computed once and cached for the lifetime of the process.
fn debug_per_call_flush() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    // `var_os` only checks the one name we care about. `std::env::vars()` would
    // iterate every variable and panic if any of them is not valid Unicode.
    *FLAG.get_or_init(|| std::env::var_os("FERRUM_METAL_QUANT_PROFILE").is_some())
}
/// GGUF-quantized weights uploaded to Metal buffers.
///
/// * `Q4K` / `Q6K` — a single `n_rows × n_cols` matrix.
/// * `Fused` — several leaf matrices sharing `n_cols`, stacked along the row axis;
///   `total_rows` is the sum of the parts' row counts.
/// * `Q4KExperts` / `Q6KExperts` — `num_experts` matrices of `n_rows × n_cols` each,
///   held in one buffer.
///
/// For every buffer-backed variant, `byte_offset` is the position inside `blocks` where
/// the quantized data starts. It can be non-zero when `blocks` wraps a page-aligned
/// window over mmapped data instead of a copy.
pub enum MetalQuantStore {
    /// A single Q4_K matrix; `n_blocks` counts its 256-element quant blocks.
    Q4K {
        blocks: metal::Buffer,
        byte_offset: u64,
        n_rows: usize,
        n_cols: usize,
        n_blocks: usize,
    },
    /// A single Q6_K matrix; `n_blocks` counts its 256-element quant blocks.
    Q6K {
        blocks: metal::Buffer,
        byte_offset: u64,
        n_rows: usize,
        n_cols: usize,
        n_blocks: usize,
    },
    /// Leaf matrices concatenated along the output (row) dimension.
    Fused {
        parts: Vec<MetalQuantStore>,
        total_rows: usize,
        n_cols: usize,
    },
    /// `num_experts` stacked Q4_K matrices, each `n_rows × n_cols`.
    Q4KExperts {
        blocks: metal::Buffer,
        byte_offset: u64,
        num_experts: usize,
        n_rows: usize,
        n_cols: usize,
    },
    /// `num_experts` stacked Q6_K matrices, each `n_rows × n_cols`.
    Q6KExperts {
        blocks: metal::Buffer,
        byte_offset: u64,
        num_experts: usize,
        n_rows: usize,
        n_cols: usize,
    },
}
impl MetalQuantStore {
    /// Output row count of this store.
    ///
    /// This is the matrix row count for leaf and per-expert variants, and the summed
    /// row count (`total_rows`) for `Fused`.
    fn n_rows(&self) -> usize {
        match *self {
            MetalQuantStore::Fused { total_rows, .. } => total_rows,
            MetalQuantStore::Q4K { n_rows, .. }
            | MetalQuantStore::Q6K { n_rows, .. }
            | MetalQuantStore::Q4KExperts { n_rows, .. }
            | MetalQuantStore::Q6KExperts { n_rows, .. } => n_rows,
        }
    }
}
// SAFETY: NOTE(review): these impls let a `MetalQuantStore` be moved to and shared
// between threads. A store holds only `metal::Buffer` handles plus plain integers and a
// `Vec` of nested stores. The code visible here only ever reads a store through `&self`
// after construction. Soundness therefore rests on Metal buffer objects being safe to
// reference from several threads at once — confirm against Apple's Metal threading
// guidance.
unsafe impl Send for MetalQuantStore {}
unsafe impl Sync for MetalQuantStore {}
/// Returns a Metal buffer containing `bytes`, plus the byte offset at which `bytes`
/// begins inside that buffer.
///
/// If `bytes` lies inside a registered mmap, the enclosing `PAGE`-aligned window is
/// wrapped with `new_buffer_with_bytes_no_copy`, so no data is copied. The returned
/// offset is then the distance from the window start to `bytes`. If Metal returns a
/// zero-length buffer for that request, or the slice is not mmap-backed, the data is
/// copied into a new shared buffer and the offset is `0`.
///
/// With `metal_mmap_trace_enabled()`, every decision is logged to stderr.
fn buffer_for_quant_bytes(bytes: &[u8]) -> (metal::Buffer, u64) {
    // Alignment of the no-copy window. NOTE(review): assumed to match the host VM page
    // size required by Metal's no-copy API — confirm on non-16 KiB-page hosts.
    const PAGE: usize = 16384;
    let trace = metal_mmap_trace_enabled();
    let len = bytes.len();

    if slice_is_in_registered_mmap(bytes) {
        let ptr_addr = bytes.as_ptr() as usize;
        // Widen [ptr, ptr + len) outward to whole pages.
        let aligned_start = ptr_addr - ptr_addr % PAGE;
        let aligned_len = (ptr_addr + len).next_multiple_of(PAGE) - aligned_start;
        let byte_offset = (ptr_addr - aligned_start) as u64;

        let wrapped = st().pipes.device.new_buffer_with_bytes_no_copy(
            aligned_start as *const c_void,
            aligned_len as u64,
            MTLResourceOptions::StorageModeShared,
            None,
        );
        // A zero length signals that Metal did not accept the no-copy request.
        let accepted = wrapped.length() != 0;
        if trace {
            if accepted {
                eprintln!(
                    "[mmap] zero-copy: tensor ptr=0x{ptr_addr:x} len={} -> buf @0x{aligned_start:x} len={aligned_len} off={byte_offset}",
                    len
                );
            } else {
                eprintln!(
                    "[mmap] zero-copy refused for tensor ptr=0x{ptr_addr:x} len={} aligned_len={aligned_len} — copying",
                    len
                );
            }
        }
        if accepted {
            return (wrapped, byte_offset);
        }
    }

    if trace {
        eprintln!("[mmap] copy: ptr={:p} len={}", bytes.as_ptr(), len);
    }
    let copied = st().pipes.device.new_buffer_with_data(
        bytes.as_ptr().cast::<c_void>(),
        len as u64,
        MTLResourceOptions::StorageModeShared,
    );
    (copied, 0)
}
/// Validates a stacked set of Q4_K expert matrices and wraps them in a Metal buffer.
///
/// `bytes` must be exactly `num_experts × n_rows × (n_cols / 256) × 144` bytes long
/// (144 bytes per 256-element Q4_K block).
///
/// # Errors
/// Returns a model error if `n_cols` is not a multiple of 256, if the expected size
/// overflows `usize`, or if `bytes.len()` differs from the expected size.
pub fn load_q4k_experts(
    bytes: &[u8],
    num_experts: usize,
    n_rows: usize,
    n_cols: usize,
) -> Result<MetalQuantStore> {
    const QK_K: usize = 256;
    const BLOCK_BYTES: usize = 144;
    if n_cols % QK_K != 0 {
        return Err(FerrumError::model(format!(
            "load_q4k_experts: n_cols {n_cols} not a multiple of {QK_K}"
        )));
    }
    // The shape comes from model metadata. Use checked arithmetic so a malformed
    // header cannot wrap the size (release builds do not trap) and slip past the
    // length check below.
    let expected = num_experts
        .checked_mul(n_rows)
        .and_then(|v| v.checked_mul(n_cols / QK_K))
        .and_then(|v| v.checked_mul(BLOCK_BYTES))
        .ok_or_else(|| {
            FerrumError::model(format!(
                "load_q4k_experts: size of {num_experts}E × {n_rows}R × {n_cols}C overflows usize"
            ))
        })?;
    if bytes.len() != expected {
        return Err(FerrumError::model(format!(
            "load_q4k_experts: bytes {} != expected {expected} ({num_experts}E × {n_rows}R × {n_cols}C)",
            bytes.len()
        )));
    }
    let (blocks, byte_offset) = buffer_for_quant_bytes(bytes);
    Ok(MetalQuantStore::Q4KExperts {
        blocks,
        byte_offset,
        num_experts,
        n_rows,
        n_cols,
    })
}
/// Validates a stacked set of Q6_K expert matrices and wraps them in a Metal buffer.
///
/// `bytes` must be exactly `num_experts × n_rows × (n_cols / 256) × Q6_K_BLOCK_BYTES`
/// bytes long.
///
/// # Errors
/// Returns a model error if `n_cols` is not a multiple of 256, if the expected size
/// overflows `usize`, or if `bytes.len()` differs from the expected size.
pub fn load_q6k_experts(
    bytes: &[u8],
    num_experts: usize,
    n_rows: usize,
    n_cols: usize,
) -> Result<MetalQuantStore> {
    const QK_K: usize = 256;
    const BLOCK_BYTES: usize = crate::q6_k_gemv::Q6_K_BLOCK_BYTES;
    if n_cols % QK_K != 0 {
        return Err(FerrumError::model(format!(
            "load_q6k_experts: n_cols {n_cols} not a multiple of {QK_K}"
        )));
    }
    // The shape comes from model metadata. Use checked arithmetic so a malformed
    // header cannot wrap the size and slip past the length check below.
    let expected = num_experts
        .checked_mul(n_rows)
        .and_then(|v| v.checked_mul(n_cols / QK_K))
        .and_then(|v| v.checked_mul(BLOCK_BYTES))
        .ok_or_else(|| {
            FerrumError::model(format!(
                "load_q6k_experts: size of {num_experts}E × {n_rows}R × {n_cols}C overflows usize"
            ))
        })?;
    if bytes.len() != expected {
        // Report the shape too, matching the Q4K loader's diagnostics.
        return Err(FerrumError::model(format!(
            "load_q6k_experts: bytes {} != expected {expected} ({num_experts}E × {n_rows}R × {n_cols}C)",
            bytes.len()
        )));
    }
    let (blocks, byte_offset) = buffer_for_quant_bytes(bytes);
    Ok(MetalQuantStore::Q6KExperts {
        blocks,
        byte_offset,
        num_experts,
        n_rows,
        n_cols,
    })
}
/// Encodes an expert-indexed GEMV for a stacked-expert store onto `enc`.
///
/// `ids`, `n_selected` and `src1_stride` are forwarded unchanged to the per-format
/// kernel launcher. NOTE(review): presumably `ids` lists the selected expert indices
/// and `n_selected` is how many there are — confirm against the kernel.
///
/// # Errors
/// Returns a model error unless `weights` is `Q4KExperts` or `Q6KExperts`.
pub fn dispatch_gemv_moe_id(
    enc: &metal::ComputeCommandEncoderRef,
    a: &metal::Buffer,
    weights: &MetalQuantStore,
    ids: &metal::Buffer,
    out: &metal::Buffer,
    n_selected: usize,
    src1_stride: usize,
) -> Result<()> {
    let (blocks, byte_offset, n_rows, n_cols, q6k) = match weights {
        MetalQuantStore::Q4KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (blocks, *byte_offset, *n_rows, *n_cols, false),
        MetalQuantStore::Q6KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (blocks, *byte_offset, *n_rows, *n_cols, true),
        MetalQuantStore::Q4K { .. } | MetalQuantStore::Q6K { .. } | MetalQuantStore::Fused { .. } => {
            return Err(FerrumError::model(
                "dispatch_gemv_moe_id: weights must be Q4KExperts or Q6KExperts variant"
                    .to_string(),
            ));
        }
    };

    if q6k {
        crate::q6_k_moe_id_gemv::dispatch_gemv_q6k_moe_id_on_encoder(
            &st().pipes.device,
            enc,
            a,
            blocks,
            byte_offset,
            ids,
            out,
            n_rows,
            n_cols,
            n_selected,
            src1_stride,
        );
    } else {
        crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_on_encoder(
            &st().pipes.device,
            enc,
            a,
            blocks,
            byte_offset,
            ids,
            out,
            n_rows,
            n_cols,
            n_selected,
            src1_stride,
        );
    }
    Ok(())
}
/// Variant of `dispatch_gemv_moe_id` that also forwards byte offsets into the
/// activation buffer (`a_byte_offset`) and the id buffer (`ids_byte_offset`).
///
/// # Errors
/// Returns a model error unless `weights` is `Q4KExperts` or `Q6KExperts`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gemv_moe_id_offset(
    enc: &metal::ComputeCommandEncoderRef,
    a: &metal::Buffer,
    a_byte_offset: u64,
    weights: &MetalQuantStore,
    ids: &metal::Buffer,
    ids_byte_offset: u64,
    out: &metal::Buffer,
    n_selected: usize,
    src1_stride: usize,
) -> Result<()> {
    let (blocks, byte_offset, n_rows, n_cols, q6k) = match weights {
        MetalQuantStore::Q4KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (blocks, *byte_offset, *n_rows, *n_cols, false),
        MetalQuantStore::Q6KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (blocks, *byte_offset, *n_rows, *n_cols, true),
        MetalQuantStore::Q4K { .. } | MetalQuantStore::Q6K { .. } | MetalQuantStore::Fused { .. } => {
            return Err(FerrumError::model(
                "dispatch_gemv_moe_id_offset: weights must be Q4KExperts or Q6KExperts variant"
                    .to_string(),
            ));
        }
    };

    if q6k {
        crate::q6_k_moe_id_gemv::dispatch_gemv_q6k_moe_id_offset_on_encoder(
            &st().pipes.device,
            enc,
            a,
            a_byte_offset,
            blocks,
            byte_offset,
            ids,
            ids_byte_offset,
            out,
            n_rows,
            n_cols,
            n_selected,
            src1_stride,
        );
    } else {
        crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_offset_on_encoder(
            &st().pipes.device,
            enc,
            a,
            a_byte_offset,
            blocks,
            byte_offset,
            ids,
            ids_byte_offset,
            out,
            n_rows,
            n_cols,
            n_selected,
            src1_stride,
        );
    }
    Ok(())
}
/// Encodes the multi-row GEMM for one leaf part of a `Fused` store.
///
/// The part writes `part_rows` output columns starting at `c_offset_cols` inside an
/// output whose row stride is `stride_c` (the fused store's total row count).
///
/// # Errors
/// Returns a model error if `part_rows` is not a multiple of 4 (checked first), or if
/// `part` is not a `Q4K`/`Q6K` leaf.
#[allow(clippy::too_many_arguments)]
fn dispatch_part_gemm(
    enc: &metal::ComputeCommandEncoderRef,
    a_buf: &metal::Buffer,
    part: &MetalQuantStore,
    out_buf: &metal::Buffer,
    c_offset_cols: usize,
    m: usize,
    part_rows: usize,
    stride_c: usize,
    n_cols: usize,
) -> Result<()> {
    if part_rows % 4 != 0 {
        return Err(FerrumError::model(format!(
            "gemm_quant Fused: part rows {part_rows} not divisible by 4"
        )));
    }

    let (blocks, byte_offset, is_q6k) = match part {
        MetalQuantStore::Q4K {
            blocks,
            byte_offset,
            ..
        } => (blocks, *byte_offset, false),
        MetalQuantStore::Q6K {
            blocks,
            byte_offset,
            ..
        } => (blocks, *byte_offset, true),
        MetalQuantStore::Fused { .. }
        | MetalQuantStore::Q4KExperts { .. }
        | MetalQuantStore::Q6KExperts { .. } => {
            return Err(FerrumError::model(
                "gemm_quant Fused: only Q4K/Q6K leaf parts supported here".to_string(),
            ));
        }
    };

    if is_q6k {
        crate::q6_k_gemm::dispatch_gemm_q6k_part(
            &st().pipes.device,
            enc,
            a_buf,
            blocks,
            byte_offset,
            out_buf,
            c_offset_cols,
            m,
            part_rows,
            stride_c,
            n_cols,
        );
    } else {
        crate::q4_k_gemm::dispatch_gemm_q4k_part(
            &st().pipes.device,
            enc,
            a_buf,
            blocks,
            byte_offset,
            out_buf,
            c_offset_cols,
            m,
            part_rows,
            stride_c,
            n_cols,
        );
    }
    Ok(())
}
/// Encodes the single-row (GEMV) product for one leaf part of a `Fused` store.
///
/// `a_offset_bytes` and `c_offset_bytes` are byte offsets into `a_buf` and `out_buf`.
/// They let each part of a fused store write its rows into its own slice of the output.
///
/// # Errors
/// - A `Q4K` part whose row count is not a multiple of 4 can only use the v1 kernel,
///   which takes no offsets. Such a part is rejected unless both offsets are 0.
/// - A `Q6K` part whose row count is not a multiple of 4.
/// - Non-leaf parts (`Fused`, `Q4KExperts`, `Q6KExperts`).
fn dispatch_part_gemv_offset(
    enc: &metal::ComputeCommandEncoderRef,
    a_buf: &metal::Buffer,
    a_offset_bytes: u64,
    part: &MetalQuantStore,
    out_buf: &metal::Buffer,
    c_offset_bytes: u64,
    n_cols: usize,
) -> Result<()> {
    match part {
        MetalQuantStore::Q4K {
            blocks,
            byte_offset,
            n_rows,
            ..
        } => {
            if *n_rows % 4 == 0 {
                crate::q4_k_gemv_v2::dispatch_gemv_q4k_v2_offset(
                    &st().pipes.device,
                    enc,
                    a_buf,
                    a_offset_bytes,
                    blocks,
                    *byte_offset,
                    out_buf,
                    c_offset_bytes,
                    *n_rows,
                    n_cols,
                );
            } else {
                // The v1 kernel has no offset arguments. Reject *before* encoding:
                // encoding first would queue a write to the start of `out_buf`
                // (clobbering another part's rows) even though an error is returned.
                if a_offset_bytes != 0 || c_offset_bytes != 0 {
                    return Err(FerrumError::model(
                        "gemm_quant Fused: q4k v1 path doesn't support offsets yet".to_string(),
                    ));
                }
                crate::q4_k_gemv::dispatch_gemv_q4k_on_encoder(
                    &st().pipes.device,
                    enc,
                    a_buf,
                    blocks,
                    *byte_offset,
                    out_buf,
                    *n_rows,
                    n_cols,
                );
            }
        }
        MetalQuantStore::Q6K {
            blocks,
            byte_offset,
            n_rows,
            ..
        } => {
            if *n_rows % 4 != 0 {
                return Err(FerrumError::model(format!(
                    "gemm_quant Fused: Q6K part n_rows={n_rows} not divisible by 4"
                )));
            }
            crate::q6_k_gemv::dispatch_gemv_q6k_v2_offset(
                &st().pipes.device,
                enc,
                a_buf,
                a_offset_bytes,
                blocks,
                *byte_offset,
                out_buf,
                c_offset_bytes,
                *n_rows,
                n_cols,
            );
        }
        MetalQuantStore::Fused { .. }
        | MetalQuantStore::Q4KExperts { .. }
        | MetalQuantStore::Q6KExperts { .. } => {
            return Err(FerrumError::model(
                "gemm_quant Fused: only Q4K/Q6K leaf parts supported here".to_string(),
            ));
        }
    }
    Ok(())
}
/// Validates one GGUF-quantized weight matrix and wraps it in a Metal buffer.
///
/// Only Q4_K (144 bytes per 256-element block) and Q6_K
/// (`crate::q6_k_gemv::Q6_K_BLOCK_BYTES` per block) are supported. `bytes` must hold
/// exactly `n_rows * n_cols / 256` blocks.
///
/// # Errors
/// - `FerrumError::unsupported` for any other `kind`.
/// - `FerrumError::model` if the element count is not a multiple of 256, if the size
///   computation overflows `usize`, or if `bytes.len()` does not match.
fn metal_load_quant_store_helper(
    kind: GgufQuantType,
    bytes: &[u8],
    n_rows: usize,
    n_cols: usize,
) -> Result<MetalQuantStore> {
    const QK_K: usize = 256;
    const Q4K_BLOCK_BYTES: usize = 144;

    // Both supported formats share the same validation; only the block size and the
    // label used in error messages differ.
    let (label, block_bytes, is_q6k) = match kind {
        GgufQuantType::Q4K => ("Q4K", Q4K_BLOCK_BYTES, false),
        GgufQuantType::Q6K => ("Q6K", crate::q6_k_gemv::Q6_K_BLOCK_BYTES, true),
        other => {
            return Err(FerrumError::unsupported(format!(
                "Metal load_quant: {other:?} not yet implemented"
            )));
        }
    };

    // Dimensions come from model metadata. Use checked arithmetic so a malformed
    // header cannot wrap the size (release builds do not trap) and pass validation.
    let total_elems = n_rows.checked_mul(n_cols).ok_or_else(|| {
        FerrumError::model(format!(
            "load_quant {label}: element count {n_rows} x {n_cols} overflows usize"
        ))
    })?;
    if total_elems % QK_K != 0 {
        return Err(FerrumError::model(format!(
            "load_quant {label}: elements {total_elems} not multiple of {QK_K}"
        )));
    }
    let n_blocks = total_elems / QK_K;
    let expected = n_blocks.checked_mul(block_bytes).ok_or_else(|| {
        FerrumError::model(format!(
            "load_quant {label}: byte size of {n_blocks} blocks overflows usize"
        ))
    })?;
    if bytes.len() != expected {
        return Err(FerrumError::model(format!(
            "load_quant {label}: bytes {} != expected {} ({n_blocks} blocks)",
            bytes.len(),
            expected
        )));
    }

    let (blocks, byte_offset) = buffer_for_quant_bytes(bytes);
    Ok(if is_q6k {
        MetalQuantStore::Q6K {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            n_blocks,
        }
    } else {
        MetalQuantStore::Q4K {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            n_blocks,
        }
    })
}
impl crate::backend::BackendQuantGguf for MetalBackend {
    /// Loads a single Q4_K / Q6_K matrix (`n_rows` outputs × `n_cols` inputs) as a
    /// `MetalGgufLinear`.
    fn load_quant(
        kind: GgufQuantType,
        bytes: &[u8],
        n_rows: usize,
        n_cols: usize,
    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
        let linear = crate::quant_linear::metal_gguf::MetalGgufLinear {
            store: metal_load_quant_store_helper(kind, bytes, n_rows, n_cols)?,
            in_features: n_cols,
            out_features: n_rows,
        };
        Ok(Box::new(linear))
    }

    /// Loads `num_experts` stacked matrices of one format (Q4_K or Q6_K only).
    fn load_quant_experts(
        kind: GgufQuantType,
        bytes: &[u8],
        num_experts: usize,
        n_rows: usize,
        n_cols: usize,
    ) -> Result<Box<dyn crate::StackedExpertGgufLinear<Self>>> {
        let store = match kind {
            GgufQuantType::Q4K => load_q4k_experts(bytes, num_experts, n_rows, n_cols),
            GgufQuantType::Q6K => load_q6k_experts(bytes, num_experts, n_rows, n_cols),
            other => Err(FerrumError::unsupported(format!(
                "Metal load_quant_experts: {other:?} not implemented (only Q4K / Q6K)"
            ))),
        }?;
        let stacked =
            crate::quant_linear::metal_gguf_moe::MetalStackedExpertGgufLinear::new(store)?;
        Ok(Box::new(stacked))
    }

    /// Loads several matrices sharing `n_cols` and concatenates them along the output
    /// (row) dimension into one `Fused` store. Parts are loaded in order; the first
    /// failure is returned.
    fn load_quant_fused(
        parts: &[(GgufQuantType, &[u8], usize)],
        n_cols: usize,
    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
        let mut total_rows = 0;
        let sub_stores = parts
            .iter()
            .map(|&(kind, bytes, rows)| -> Result<MetalQuantStore> {
                let store = metal_load_quant_store_helper(kind, bytes, rows, n_cols)?;
                if matches!(store, MetalQuantStore::Fused { .. }) {
                    return Err(FerrumError::model(
                        "Metal load_quant_fused: nested Fused not supported".to_string(),
                    ));
                }
                total_rows += rows;
                Ok(store)
            })
            .collect::<Result<Vec<MetalQuantStore>>>()?;

        let fused = MetalQuantStore::Fused {
            parts: sub_stores,
            total_rows,
            n_cols,
        };
        let linear = crate::quant_linear::metal_gguf::MetalGgufLinear {
            store: fused,
            in_features: n_cols,
            out_features: total_rows,
        };
        Ok(Box::new(linear))
    }
}
pub fn metal_gemm_quant_dispatch(
ctx: &mut <MetalBackend as Backend>::Context,
a: &<MetalBackend as Backend>::Buffer,
weight: &MetalQuantStore,
out: &mut <MetalBackend as Backend>::Buffer,
m: usize,
) -> Result<()> {
if let MetalQuantStore::Fused {
parts,
total_rows,
n_cols,
} = weight
{
if m != 1 {
let a_buf = a.expect_f32("gemm_quant a (fused)");
let out_buf = out.expect_f32_mut("gemm_quant out (fused)");
let enc = ctx.compute_encoder();
let mut col_off = 0usize;
for part in parts {
let part_rows = part.n_rows();
dispatch_part_gemm(
enc,
a_buf,
part,
out_buf,
col_off,
m,
part_rows,
*total_rows,
*n_cols,
)?;
col_off += part_rows;
}
return Ok(());
}
let a_buf = a.expect_f32("gemm_quant a (fused m=1)");
let out_buf = out.expect_f32_mut("gemm_quant out (fused m=1)");
let enc = ctx.compute_encoder();
let mut row_off_elems = 0usize;
for part in parts {
let part_rows = part.n_rows();
let c_off = (row_off_elems * 4) as u64;
dispatch_part_gemv_offset(enc, a_buf, 0, part, out_buf, c_off, *n_cols)?;
row_off_elems += part_rows;
}
return Ok(());
}
let (blocks, blocks_off, n_rows, n_cols, n_blocks, is_q6k) = match weight {
MetalQuantStore::Q4K {
blocks,
byte_offset,
n_rows,
n_cols,
n_blocks,
} => (blocks, *byte_offset, *n_rows, *n_cols, *n_blocks, false),
MetalQuantStore::Q6K {
blocks,
byte_offset,
n_rows,
n_cols,
n_blocks,
} => (blocks, *byte_offset, *n_rows, *n_cols, *n_blocks, true),
MetalQuantStore::Fused { .. } => unreachable!("handled above"),
MetalQuantStore::Q4KExperts { .. } | MetalQuantStore::Q6KExperts { .. } => {
return Err(FerrumError::model(
"gemm_quant: ExpertsStacked must be dispatched via gemv_moe_id".to_string(),
));
}
};
let _t0 = if debug_per_call_flush() {
Some(std::time::Instant::now())
} else {
None
};
let a_buf = a.expect_f32("gemm_quant a");
let out_buf = out.expect_f32_mut("gemm_quant out");
if m == 1 {
let enc = ctx.compute_encoder();
if is_q6k {
if n_rows % 4 != 0 {
return Err(FerrumError::model(format!(
"gemm_quant Q6K: n_rows={n_rows} not divisible by 4"
)));
}
crate::q6_k_gemv::dispatch_gemv_q6k_v2_on_encoder(
&st().pipes.device,
enc,
a_buf,
blocks,
blocks_off,
out_buf,
n_rows,
n_cols,
);
} else if n_rows % 4 == 0 {
crate::q4_k_gemv_v2::dispatch_gemv_q4k_v2_on_encoder(
&st().pipes.device,
enc,
a_buf,
blocks,
blocks_off,
out_buf,
n_rows,
n_cols,
);
} else {
crate::q4_k_gemv::dispatch_gemv_q4k_on_encoder(
&st().pipes.device,
enc,
a_buf,
blocks,
blocks_off,
out_buf,
n_rows,
n_cols,
);
}
} else if is_q6k {
let enc = ctx.compute_encoder();
crate::q6_k_gemm::dispatch_gemm_q6k_on_encoder(
&st().pipes.device,
enc,
a_buf,
blocks,
blocks_off,
out_buf,
m,
n_rows,
n_cols,
);
} else {
let enc = ctx.compute_encoder();
crate::q4_k_gemm::dispatch_gemm_q4k_on_encoder(
&st().pipes.device,
enc,
a_buf,
blocks,
blocks_off,
out_buf,
m,
n_rows,
n_cols,
);
let _ = n_blocks; }
if let Some(t0) = _t0 {
ctx.flush();
let elapsed_us = t0.elapsed().as_micros();
QUANT_GEMM_TIME_US.fetch_add(elapsed_us as u64, std::sync::atomic::Ordering::Relaxed);
QUANT_GEMM_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
QUANT_GEMM_LAST_M.store(m as u64, std::sync::atomic::Ordering::Relaxed);
QUANT_GEMM_LAST_N.store(n_rows as u64, std::sync::atomic::Ordering::Relaxed);
QUANT_GEMM_LAST_K.store(n_cols as u64, std::sync::atomic::Ordering::Relaxed);
if QUANT_GEMM_CALLS.load(std::sync::atomic::Ordering::Relaxed) <= 16 {
eprintln!(
"[gemm_quant] m={} n={} k={} took {} us",
m, n_rows, n_cols, elapsed_us
);
}
}
Ok(())
}