#![cfg(all(target_os = "macos", feature = "metal"))]
use std::ffi::c_void;
use std::sync::OnceLock;
use metal::{
Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLSize,
};
const SHADER_SRC: &str = include_str!("q4_k_moe_id_gate_up_silu.metal");
const KERNEL_NAME: &str = "gemv_q4kw_moe_id_gate_up_silu_f32";
static PIPELINE: OnceLock<ComputePipelineState> = OnceLock::new();
fn pipeline(device: &Device) -> &'static ComputePipelineState {
PIPELINE.get_or_init(|| {
let lib = device
.new_library_with_source(SHADER_SRC, &CompileOptions::new())
.expect("compile q4_k_moe_id_gate_up_silu.metal");
let function = lib
.get_function(KERNEL_NAME, None)
.expect("find gemv_q4kw_moe_id_gate_up_silu_f32 function");
device
.new_compute_pipeline_state_with_function(&function)
.expect("build gemv_q4kw_moe_id_gate_up_silu_f32 pipeline")
})
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gemv_q4k_moe_id_gate_up_silu_on_encoder(
device: &Device,
enc: &ComputeCommandEncoderRef,
a: &Buffer,
gate_w_stacked: &Buffer,
gate_w_byte_offset: u64,
up_w_stacked: &Buffer,
up_w_byte_offset: u64,
ids: &Buffer,
silu_out: &Buffer,
n: usize,
k: usize,
n_selected: usize,
) {
debug_assert!(k % 256 == 0, "K must be a multiple of 256 (got {k})");
debug_assert!(n % 4 == 0, "N must be a multiple of 4 (got {n})");
let nb01_bytes = (k / 256) * 144;
let nb02_bytes = n * nb01_bytes;
#[repr(C)]
struct P {
n: i32,
k: i32,
nb01: i32,
nb02: i32,
n_selected: i32,
}
let params = P {
n: n as i32,
k: k as i32,
nb01: nb01_bytes as i32,
nb02: nb02_bytes as i32,
n_selected: n_selected as i32,
};
let pipe = pipeline(device);
enc.set_compute_pipeline_state(pipe);
enc.set_buffer(0, Some(gate_w_stacked), gate_w_byte_offset);
enc.set_buffer(1, Some(up_w_stacked), up_w_byte_offset);
enc.set_buffer(2, Some(a), 0);
enc.set_buffer(3, Some(ids), 0);
enc.set_buffer(4, Some(silu_out), 0);
enc.set_bytes(
5,
std::mem::size_of::<P>() as u64,
¶ms as *const _ as *const c_void,
);
const TILE_ROWS: u64 = 4;
let grid = MTLSize::new((n as u64).div_ceil(TILE_ROWS), 1, n_selected as u64);
let tg = MTLSize::new(32, 2, 1);
enc.dispatch_thread_groups(grid, tg);
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device as CDevice, Tensor};
use metal::MTLResourceOptions;
#[test]
fn fused_matches_unfused_q4k_moe_gate_up_silu() {
const NUM_EXPERTS: usize = 4;
const N: usize = 64; const K: usize = 256; const TOP_K: usize = 2;
let cpu = CDevice::Cpu;
let pack_stack = |seed: f32| -> Vec<u8> {
let mut buf = Vec::new();
for e in 0..NUM_EXPERTS {
let raw: Vec<f32> = (0..N * K)
.map(|i| ((((i + e * 313) % 251) as f32) * 0.013).sin() * 0.4 + seed)
.collect();
let t = Tensor::from_vec(raw, (N, K), &cpu).unwrap();
let qt = QTensor::quantize(&t, GgmlDType::Q4K).unwrap();
buf.extend_from_slice(&qt.data().unwrap());
}
buf
};
let gate_bytes = pack_stack(0.05);
let up_bytes = pack_stack(-0.07);
const QK_K: usize = 256;
const BLOCK_BYTES: usize = 144;
let expected = NUM_EXPERTS * N * (K / QK_K) * BLOCK_BYTES;
assert_eq!(gate_bytes.len(), expected);
assert_eq!(up_bytes.len(), expected);
let act: Vec<f32> = (0..K).map(|i| ((i as f32) * 0.0021).cos() * 0.7).collect();
let ids: Vec<i32> = vec![1, 3];
let Some(device) = metal::Device::system_default() else {
eprintln!("no Metal device — skipping");
return;
};
let queue = device.new_command_queue();
let a_buf = device.new_buffer_with_data(
act.as_ptr() as *const _,
(act.len() * 4) as u64,
MTLResourceOptions::StorageModeShared,
);
let ids_buf = device.new_buffer_with_data(
ids.as_ptr() as *const _,
(ids.len() * 4) as u64,
MTLResourceOptions::StorageModeShared,
);
let gate_buf = device.new_buffer_with_data(
gate_bytes.as_ptr() as *const _,
gate_bytes.len() as u64,
MTLResourceOptions::StorageModeShared,
);
let up_buf = device.new_buffer_with_data(
up_bytes.as_ptr() as *const _,
up_bytes.len() as u64,
MTLResourceOptions::StorageModeShared,
);
let out_size = (TOP_K * N * 4) as u64;
let gate_out = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);
let up_out = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);
let silu_unfused = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);
let silu_fused = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);
let cmd1 = queue.new_command_buffer();
let enc1 = cmd1.new_compute_command_encoder();
crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_on_encoder(
&device, enc1, &a_buf, &gate_buf, 0, &ids_buf, &gate_out, N, K, TOP_K, 0,
);
crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_on_encoder(
&device, enc1, &a_buf, &up_buf, 0, &ids_buf, &up_out, N, K, TOP_K, 0,
);
crate::moe_post_ops::dispatch_silu_mul_stacked(
&device,
enc1,
&gate_out,
&up_out,
&silu_unfused,
TOP_K,
N,
);
enc1.end_encoding();
cmd1.commit();
cmd1.wait_until_completed();
let cmd2 = queue.new_command_buffer();
let enc2 = cmd2.new_compute_command_encoder();
dispatch_gemv_q4k_moe_id_gate_up_silu_on_encoder(
&device,
enc2,
&a_buf,
&gate_buf,
0,
&up_buf,
0,
&ids_buf,
&silu_fused,
N,
K,
TOP_K,
);
enc2.end_encoding();
cmd2.commit();
cmd2.wait_until_completed();
let len = TOP_K * N;
let unfused: &[f32] =
unsafe { std::slice::from_raw_parts(silu_unfused.contents() as *const f32, len) };
let fused: &[f32] =
unsafe { std::slice::from_raw_parts(silu_fused.contents() as *const f32, len) };
let mut max_abs = 0f32;
let mut mismatches = 0usize;
for (i, (&u, &f)) in unfused.iter().zip(fused.iter()).enumerate() {
let diff = (u - f).abs();
if diff > max_abs {
max_abs = diff;
}
let denom = u.abs().max(f.abs()).max(1e-3);
let rel = diff / denom;
if rel > 1e-5 && diff > 1e-5 {
mismatches += 1;
if mismatches < 5 {
eprintln!("[{i}] unfused={u} fused={f} diff={diff} rel={rel}");
}
}
}
eprintln!("max_abs={max_abs:.6} mismatches={mismatches}/{len}");
assert!(
mismatches == 0,
"fused output diverges from unfused — max_abs={max_abs:.6} \
({mismatches}/{len} mismatches)"
);
}
}