#![cfg(all(target_os = "macos", feature = "metal"))]
use std::ffi::c_void;
use std::sync::OnceLock;
use metal::{
Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, MTLSize,
};
/// Metal source of the batched kernel, embedded into the binary at compile time.
const SHADER_SRC: &str = include_str!("q4_k_moe_id_gemv_batched.metal");
/// Name of the kernel entry point looked up in the compiled shader library.
const KERNEL_NAME: &str = "gemv_q4kw_moe_id_batched_f32";
/// Process-wide cache for the compute pipeline; filled on first use by `pipeline`.
static PIPELINE: OnceLock<ComputePipelineState> = OnceLock::new();
/// Returns the process-wide compute pipeline for the batched kernel, building it on first use.
///
/// The pipeline is compiled from `SHADER_SRC` with whichever `device` reaches this function
/// first and is then cached in `PIPELINE`; later calls return the cached state and do not
/// consult `device` again.
///
/// # Panics
///
/// Panics if the shader fails to compile, the kernel function cannot be found, or the
/// pipeline state cannot be created.
fn pipeline(device: &Device) -> &'static ComputePipelineState {
    /// Compiles the shader source and turns the kernel entry point into a pipeline state.
    fn build(device: &Device) -> ComputePipelineState {
        let options = CompileOptions::new();
        let library = device
            .new_library_with_source(SHADER_SRC, &options)
            .expect("compile q4_k_moe_id_gemv_batched.metal");
        let kernel = library
            .get_function(KERNEL_NAME, None)
            .expect("find gemv_q4kw_moe_id_batched_f32 function");
        device
            .new_compute_pipeline_state_with_function(&kernel)
            .expect("build gemv_q4kw_moe_id_batched_f32 pipeline")
    }

    PIPELINE.get_or_init(|| build(device))
}
/// Encodes one dispatch of the batched Q4_K MoE-id GEMV kernel onto `enc`.
///
/// The function selects the cached pipeline, binds the four buffers and the parameter block,
/// and dispatches `ceil(n / 4) x 1 x (m * top_k)` threadgroups of `32 x 2 x 1` threads. It
/// neither ends the encoder nor commits anything; that stays with the caller.
///
/// # Arguments
///
/// * `device` - device used to build the pipeline the first time it is needed.
/// * `enc` - compute encoder that receives the pipeline state, bindings and dispatch.
/// * `a` - bound at buffer index 1 (presumably the f32 activations; confirm against the
///   shader).
/// * `weights_stacked`, `weights_byte_offset` - bound at buffer index 0 at that byte offset.
/// * `ids` - bound at buffer index 2 (presumably one i32 expert id per (token, slot) pair;
///   confirm against the shader).
/// * `out` - bound at buffer index 3.
/// * `n`, `k` - forwarded to the kernel; `k` must be a multiple of 256 and `n` a multiple of
///   4 (both checked in debug builds only).
/// * `m`, `top_k` - must both be non-zero (debug-checked); their product is the grid depth.
/// * `src1_outer_stride`, `src1_inner_stride` - forwarded to the kernel unchanged
///   (presumably activation strides in elements; confirm against the shader).
///
/// # Panics
///
/// Panics if a derived size overflows `usize`, if any kernel parameter does not fit in
/// `i32`, or if the pipeline cannot be built.
// Every argument is a distinct kernel input; bundling them would only move the list elsewhere.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gemv_q4k_moe_id_batched_on_encoder(
    device: &Device,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    weights_stacked: &Buffer,
    weights_byte_offset: u64,
    ids: &Buffer,
    out: &Buffer,
    n: usize,
    k: usize,
    m: usize,
    top_k: usize,
    src1_outer_stride: usize,
    src1_inner_stride: usize,
) {
    /// Elements covered by one Q4_K block.
    const QK_K: usize = 256;
    /// Bytes occupied by one Q4_K block.
    const BLOCK_BYTES: usize = 144;
    /// Divisor for the grid width: one threadgroup per `TILE_ROWS` values of `n`.
    const TILE_ROWS: u64 = 4;

    /// Parameter block uploaded at buffer index 4. `#[repr(C)]` fixes the field order; the
    /// layout presumably mirrors the struct declared in the shader (confirm there).
    #[repr(C)]
    struct P {
        n: i32,
        k: i32,
        nb01: i32,
        nb02: i32,
        top_k: i32,
        n_pairs: i32,
        src1_outer_stride: i32,
        src1_inner_stride: i32,
    }

    /// Converts a host-side size to the kernel's `i32`, panicking instead of truncating.
    fn param_i32(name: &str, value: usize) -> i32 {
        <i32 as std::convert::TryFrom<usize>>::try_from(value).unwrap_or_else(|_| {
            panic!("{name} = {value} does not fit the kernel's i32 parameter")
        })
    }

    debug_assert!(k % 256 == 0, "K must be a multiple of 256 (got {k})");
    debug_assert!(n % 4 == 0, "N must be a multiple of 4 (got {n})");
    debug_assert!(top_k > 0 && m > 0);

    // Byte size of one weight row (k / 256 blocks of 144 bytes) and of one expert (n rows).
    let nb01_bytes = (k / QK_K) * BLOCK_BYTES;
    let nb02_bytes = n
        .checked_mul(nb01_bytes)
        .expect("per-expert byte stride overflows usize");
    // One grid slice per (token, expert slot) pair.
    let n_pairs = m
        .checked_mul(top_k)
        .expect("m * top_k overflows usize");

    let params = P {
        n: param_i32("n", n),
        k: param_i32("k", k),
        nb01: param_i32("nb01", nb01_bytes),
        nb02: param_i32("nb02", nb02_bytes),
        top_k: param_i32("top_k", top_k),
        n_pairs: param_i32("n_pairs", n_pairs),
        src1_outer_stride: param_i32("src1_outer_stride", src1_outer_stride),
        src1_inner_stride: param_i32("src1_inner_stride", src1_inner_stride),
    };

    let pipe = pipeline(device);
    enc.set_compute_pipeline_state(pipe);
    enc.set_buffer(0, Some(weights_stacked), weights_byte_offset);
    enc.set_buffer(1, Some(a), 0);
    enc.set_buffer(2, Some(ids), 0);
    enc.set_buffer(3, Some(out), 0);
    // `params` is only borrowed for the duration of this call.
    enc.set_bytes(
        4,
        std::mem::size_of::<P>() as u64,
        &params as *const P as *const c_void,
    );

    let grid = MTLSize::new((n as u64).div_ceil(TILE_ROWS), 1, n_pairs as u64);
    let tg = MTLSize::new(32, 2, 1);
    enc.dispatch_thread_groups(grid, tg);
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{Device as CDevice, Tensor};
use metal::MTLResourceOptions;
/// Checks that one batched dispatch reproduces the per-token kernel when the batched call
/// uses activation strides `(K, 0)` and the per-token call uses stride `0`.
///
/// The reference result is produced token by token: each token is dispatched into a scratch
/// buffer and then blitted to its slot in `dst_per_token`, because the per-token dispatcher
/// is called without a destination offset. The test is skipped when no Metal device exists.
#[test]
fn batched_matches_per_token_q4k_moe_gemv_gate_up() {
    const NUM_EXPERTS: usize = 4;
    const N: usize = 64;
    const K: usize = 256;
    const M: usize = 3;
    const TOP_K: usize = 2;
    const QK_K: usize = 256;
    const BLOCK_BYTES: usize = 144;

    let cpu = CDevice::Cpu;

    // Quantize NUM_EXPERTS synthetic [N, K] matrices and pack them back to back.
    let mut weights_bytes = Vec::new();
    for e in 0..NUM_EXPERTS {
        let raw: Vec<f32> = (0..N * K)
            .map(|i| ((((i + e * 313) % 251) as f32) * 0.013).sin() * 0.4 + 0.05)
            .collect();
        let t = Tensor::from_vec(raw, (N, K), &cpu).unwrap();
        let qt = QTensor::quantize(&t, GgmlDType::Q4K).unwrap();
        weights_bytes.extend_from_slice(&qt.data().unwrap());
    }
    assert_eq!(
        weights_bytes.len(),
        NUM_EXPERTS * N * (K / QK_K) * BLOCK_BYTES
    );

    let act: Vec<f32> = (0..M * K)
        .map(|i| ((i as f32) * 0.0021).cos() * 0.7)
        .collect();
    let ids: Vec<i32> = vec![1, 3, 0, 2, 3, 1];
    assert_eq!(ids.len(), M * TOP_K);

    let Some(device) = metal::Device::system_default() else {
        eprintln!("no Metal device — skipping");
        return;
    };
    let queue = device.new_command_queue();

    let a_buf = device.new_buffer_with_data(
        act.as_ptr() as *const _,
        (act.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let ids_buf = device.new_buffer_with_data(
        ids.as_ptr() as *const _,
        (ids.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let w_buf = device.new_buffer_with_data(
        weights_bytes.as_ptr() as *const _,
        weights_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let out_size = (M * TOP_K * N * 4) as u64;
    let dst_per_token = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);
    let dst_batched = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);

    // Reference path: one command buffer per token, compute into `scratch`, then copy the
    // TOP_K * N results to that token's slot.
    let scratch = device.new_buffer(
        (TOP_K * N * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    for t in 0..M {
        let cmd_t = queue.new_command_buffer();
        let enc_t = cmd_t.new_compute_command_encoder();
        crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_offset_on_encoder(
            &device,
            enc_t,
            &a_buf,
            (t * K * std::mem::size_of::<f32>()) as u64,
            &w_buf,
            0,
            &ids_buf,
            (t * TOP_K * std::mem::size_of::<i32>()) as u64,
            &scratch,
            N,
            K,
            TOP_K,
            0,
        );
        enc_t.end_encoding();
        let blit = cmd_t.new_blit_command_encoder();
        blit.copy_from_buffer(
            &scratch,
            0,
            &dst_per_token,
            (t * TOP_K * N * 4) as u64,
            (TOP_K * N * 4) as u64,
        );
        blit.end_encoding();
        cmd_t.commit();
        cmd_t.wait_until_completed();
    }

    // Path under test: a single dispatch covering all M * TOP_K pairs.
    let cmd2 = queue.new_command_buffer();
    let enc2 = cmd2.new_compute_command_encoder();
    dispatch_gemv_q4k_moe_id_batched_on_encoder(
        &device,
        enc2,
        &a_buf,
        &w_buf,
        0,
        &ids_buf,
        &dst_batched,
        N,
        K,
        M,
        TOP_K,
        K,
        0,
    );
    enc2.end_encoding();
    cmd2.commit();
    cmd2.wait_until_completed();

    let len = M * TOP_K * N;
    // SAFETY: `dst_per_token` was allocated with `len * 4` bytes of shared storage, stays
    // alive for the rest of the function, and all GPU work writing it has completed.
    let per_token: &[f32] =
        unsafe { std::slice::from_raw_parts(dst_per_token.contents() as *const f32, len) };
    // SAFETY: same allocation size and lifetime argument as above, for `dst_batched`.
    let batched: &[f32] =
        unsafe { std::slice::from_raw_parts(dst_batched.contents() as *const f32, len) };

    let mut max_abs = 0f32;
    let mut mismatches = 0usize;
    for (i, (&a, &b)) in per_token.iter().zip(batched.iter()).enumerate() {
        let diff = (a - b).abs();
        max_abs = max_abs.max(diff);
        // A value counts as a mismatch only when both the relative and the absolute error
        // exceed 1e-5; the denominator is floored so near-zero outputs do not blow up.
        let denom = a.abs().max(b.abs()).max(1e-3);
        let rel = diff / denom;
        if rel > 1e-5 && diff > 1e-5 {
            mismatches += 1;
            if mismatches < 5 {
                eprintln!("[{i}] per_token={a} batched={b} diff={diff} rel={rel}");
            }
        }
    }
    eprintln!("max_abs={max_abs:.6} mismatches={mismatches}/{len}");
    assert!(
        mismatches == 0,
        "batched output diverges from per-token — max_abs={max_abs:.6} \
         ({mismatches}/{len} mismatches)"
    );
}
/// Checks that one batched dispatch reproduces the per-token kernel when every
/// (token, slot) pair has its own activation row: the batched call uses strides
/// `(TOP_K * K, K)` and the per-token call uses stride `K`.
///
/// The reference result is produced token by token through a scratch buffer that is blitted
/// to the token's slot in `dst_per_token`. The test is skipped when no Metal device exists.
#[test]
fn batched_matches_per_token_q4k_moe_gemv_down() {
    const NUM_EXPERTS: usize = 4;
    const N: usize = 64;
    const K: usize = 256;
    const M: usize = 3;
    const TOP_K: usize = 2;
    const QK_K: usize = 256;
    const BLOCK_BYTES: usize = 144;

    let cpu = CDevice::Cpu;

    // Quantize NUM_EXPERTS synthetic [N, K] matrices and pack them back to back.
    let mut weights_bytes = Vec::new();
    for e in 0..NUM_EXPERTS {
        let raw: Vec<f32> = (0..N * K)
            .map(|i| ((((i + e * 251) % 199) as f32) * 0.011).cos() * 0.3 - 0.1)
            .collect();
        let t = Tensor::from_vec(raw, (N, K), &cpu).unwrap();
        let qt = QTensor::quantize(&t, GgmlDType::Q4K).unwrap();
        weights_bytes.extend_from_slice(&qt.data().unwrap());
    }
    // Same sanity checks as the gate/up test: packed size and one id per (token, slot).
    assert_eq!(
        weights_bytes.len(),
        NUM_EXPERTS * N * (K / QK_K) * BLOCK_BYTES
    );

    let act: Vec<f32> = (0..M * TOP_K * K)
        .map(|i| ((i as f32) * 0.0017).sin() * 0.5)
        .collect();
    let ids: Vec<i32> = vec![1, 3, 0, 2, 3, 1];
    assert_eq!(ids.len(), M * TOP_K);

    let Some(device) = metal::Device::system_default() else {
        eprintln!("no Metal device — skipping");
        return;
    };
    let queue = device.new_command_queue();

    let a_buf = device.new_buffer_with_data(
        act.as_ptr() as *const _,
        (act.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let ids_buf = device.new_buffer_with_data(
        ids.as_ptr() as *const _,
        (ids.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let w_buf = device.new_buffer_with_data(
        weights_bytes.as_ptr() as *const _,
        weights_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let out_size = (M * TOP_K * N * 4) as u64;
    let dst_per_token = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);
    let dst_batched = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);

    // Reference path: one command buffer per token, compute into `scratch`, then copy the
    // TOP_K * N results to that token's slot.
    let scratch = device.new_buffer(
        (TOP_K * N * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    for t in 0..M {
        let cmd_t = queue.new_command_buffer();
        let enc_t = cmd_t.new_compute_command_encoder();
        crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_offset_on_encoder(
            &device,
            enc_t,
            &a_buf,
            (t * TOP_K * K * std::mem::size_of::<f32>()) as u64,
            &w_buf,
            0,
            &ids_buf,
            (t * TOP_K * std::mem::size_of::<i32>()) as u64,
            &scratch,
            N,
            K,
            TOP_K,
            K,
        );
        enc_t.end_encoding();
        let blit = cmd_t.new_blit_command_encoder();
        blit.copy_from_buffer(
            &scratch,
            0,
            &dst_per_token,
            (t * TOP_K * N * 4) as u64,
            (TOP_K * N * 4) as u64,
        );
        blit.end_encoding();
        cmd_t.commit();
        cmd_t.wait_until_completed();
    }

    // Path under test: a single dispatch covering all M * TOP_K pairs.
    let cmd2 = queue.new_command_buffer();
    let enc2 = cmd2.new_compute_command_encoder();
    dispatch_gemv_q4k_moe_id_batched_on_encoder(
        &device,
        enc2,
        &a_buf,
        &w_buf,
        0,
        &ids_buf,
        &dst_batched,
        N,
        K,
        M,
        TOP_K,
        TOP_K * K,
        K,
    );
    enc2.end_encoding();
    cmd2.commit();
    cmd2.wait_until_completed();

    let len = M * TOP_K * N;
    // SAFETY: `dst_per_token` was allocated with `len * 4` bytes of shared storage, stays
    // alive for the rest of the function, and all GPU work writing it has completed.
    let per_token: &[f32] =
        unsafe { std::slice::from_raw_parts(dst_per_token.contents() as *const f32, len) };
    // SAFETY: same allocation size and lifetime argument as above, for `dst_batched`.
    let batched: &[f32] =
        unsafe { std::slice::from_raw_parts(dst_batched.contents() as *const f32, len) };

    let mut max_abs = 0f32;
    let mut mismatches = 0usize;
    for (i, (&a, &b)) in per_token.iter().zip(batched.iter()).enumerate() {
        let diff = (a - b).abs();
        max_abs = max_abs.max(diff);
        // A value counts as a mismatch only when both the relative and the absolute error
        // exceed 1e-5; the denominator is floored so near-zero outputs do not blow up.
        let denom = a.abs().max(b.abs()).max(1e-3);
        let rel = diff / denom;
        if rel > 1e-5 && diff > 1e-5 {
            mismatches += 1;
            if mismatches < 5 {
                eprintln!("[{i}] per_token={a} batched={b} diff={diff} rel={rel}");
            }
        }
    }
    eprintln!("max_abs={max_abs:.6} mismatches={mismatches}/{len}");
    assert!(
        mismatches == 0,
        "batched(down) diverges from per-token — max_abs={max_abs:.6}"
    );
}
/// Manual microbenchmark: times the batched kernel for several token counts `m` and prints
/// the cost per (token, slot) pair relative to `m = 1`.
///
/// Each iteration submits one command buffer and waits for it, so the figures include
/// submission and synchronisation overhead. Skipped when no Metal device exists.
#[test]
#[ignore = "microbench — only run manually with --release"]
fn moe_gemv_scaling_microbench() {
    const NUM_EXPERTS: usize = 8;
    const N: usize = 768;
    const K: usize = 2048;
    const TOP_K: usize = 8;
    const M_VALUES: &[usize] = &[1, 2, 4, 8, 16];
    const WARMUP_ITERS: usize = 3;
    const ITERS: usize = 500;

    // Look for the device first so a machine without Metal skips the costly quantization.
    let Some(device) = metal::Device::system_default() else {
        eprintln!("no Metal device — skipping");
        return;
    };
    let queue = device.new_command_queue();

    let cpu = CDevice::Cpu;
    eprintln!(
        "preparing synthetic Q4_K stack: {} experts × {} × {} ...",
        NUM_EXPERTS, N, K
    );
    let mut weights_bytes = Vec::new();
    for e in 0..NUM_EXPERTS {
        let raw: Vec<f32> = (0..N * K)
            .map(|i| ((((i + e * 313) % 251) as f32) * 0.013).sin() * 0.4 + 0.05)
            .collect();
        let t = Tensor::from_vec(raw, (N, K), &cpu).unwrap();
        let qt = QTensor::quantize(&t, GgmlDType::Q4K).unwrap();
        weights_bytes.extend_from_slice(&qt.data().unwrap());
    }
    eprintln!(" stack size: {} MB", weights_bytes.len() / (1024 * 1024));

    let w_buf = device.new_buffer_with_data(
        weights_bytes.as_ptr() as *const _,
        weights_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );

    // Inputs and output are sized for the largest m; smaller runs use a prefix of them.
    let max_m = *M_VALUES.iter().max().unwrap();
    let act: Vec<f32> = (0..max_m * K)
        .map(|i| ((i as f32) * 0.0021).cos() * 0.7)
        .collect();
    let ids: Vec<i32> = (0..max_m * TOP_K)
        .map(|i| (i as i32) % NUM_EXPERTS as i32)
        .collect();
    let a_buf = device.new_buffer_with_data(
        act.as_ptr() as *const _,
        (act.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let ids_buf = device.new_buffer_with_data(
        ids.as_ptr() as *const _,
        (ids.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let out_size = (max_m * TOP_K * N * 4) as u64;
    let out_buf = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);

    // Submits one batched dispatch for `m` tokens and blocks until the GPU has finished.
    let run_once = |m: usize| {
        let cmd = queue.new_command_buffer();
        let enc = cmd.new_compute_command_encoder();
        dispatch_gemv_q4k_moe_id_batched_on_encoder(
            &device, enc, &a_buf, &w_buf, 0, &ids_buf, &out_buf, N, K, m, TOP_K, K, 0,
        );
        enc.end_encoding();
        cmd.commit();
        cmd.wait_until_completed();
    };

    eprintln!();
    eprintln!(
        "{:<8} {:<10} {:<14} {:<14} {:<14}",
        "m", "n_pairs", "ms_per_iter", "us_per_pair", "scaling_vs_m1"
    );
    eprintln!(
        "{:<8} {:<10} {:<14} {:<14} {:<14}",
        "-", "-", "-", "-", "-"
    );

    // Baseline is whatever the first entry of M_VALUES measures (m = 1 in this table).
    let mut m1_us_per_pair: Option<f64> = None;
    for &m in M_VALUES {
        let n_pairs = m * TOP_K;
        for _ in 0..WARMUP_ITERS {
            run_once(m);
        }
        let t0 = std::time::Instant::now();
        for _ in 0..ITERS {
            run_once(m);
        }
        let elapsed_us = t0.elapsed().as_micros() as f64;
        let us_per_iter = elapsed_us / ITERS as f64;
        let us_per_pair = us_per_iter / n_pairs as f64;
        let scaling = if let Some(m1) = m1_us_per_pair {
            us_per_pair / m1
        } else {
            m1_us_per_pair = Some(us_per_pair);
            1.0
        };
        eprintln!(
            "{:<8} {:<10} {:<14.3} {:<14.3} {:<14.3}",
            m,
            n_pairs,
            us_per_iter / 1000.0,
            us_per_pair,
            scaling
        );
    }

    eprintln!();
    eprintln!("scaling_vs_m1 = us_per_pair(m) / us_per_pair(m=1)");
    eprintln!(" = 1.0 → kernel scales perfectly with m (each pair takes same time as m=1)");
    eprintln!(" < 1.0 → kernel benefits from batching (cache reuse, occupancy)");
    eprintln!(" > 1.0 → kernel REGRESSES with m (cache thrashing, register spill)");
}
/// Manual helper: records a GPU trace of five fused gate/up/SiLU dispatches at `M = 16`.
///
/// Returns early unless `MTL_CAPTURE_ENABLED` is exactly `1`, and also when no Metal device
/// exists. Three untraced warm-up dispatches run before the capture starts; any previous
/// trace at the output path is removed first.
#[test]
#[ignore = "manually-run capture — needs MTL_CAPTURE_ENABLED=1"]
fn moe_gemv_capture_one_iter() {
    use crate::q4_k_moe_id_gate_up_silu_batched::dispatch_gemv_q4k_moe_id_gate_up_silu_batched_on_encoder;
    use metal::{CaptureDescriptor, MTLCaptureDestination};

    const NUM_EXPERTS: usize = 8;
    const N: usize = 768;
    const K: usize = 2048;
    const TOP_K: usize = 8;
    const M: usize = 16;

    /// Copies `data` into a freshly allocated shared-storage Metal buffer.
    fn shared_buffer<T>(device: &metal::Device, data: &[T]) -> metal::Buffer {
        device.new_buffer_with_data(
            data.as_ptr() as *const _,
            std::mem::size_of_val(data) as u64,
            MTLResourceOptions::StorageModeShared,
        )
    }

    let capture_enabled = matches!(std::env::var("MTL_CAPTURE_ENABLED").as_deref(), Ok("1"));
    if !capture_enabled {
        eprintln!("MTL_CAPTURE_ENABLED=1 not set — capture would silently no-op");
        return;
    }

    let cpu = CDevice::Cpu;
    eprintln!("preparing 2 synthetic Q4_K stacks (gate + up) ...");
    // Quantizes NUM_EXPERTS synthetic [N, K] matrices, shifted by `bias`, back to back.
    let quantized_stack = |bias: f32| -> Vec<u8> {
        let mut bytes = Vec::new();
        for expert in 0..NUM_EXPERTS {
            let values: Vec<f32> = (0..N * K)
                .map(|i| ((((i + expert * 313) % 251) as f32) * 0.013).sin() * 0.4 + bias)
                .collect();
            let tensor = Tensor::from_vec(values, (N, K), &cpu).unwrap();
            let quantized = QTensor::quantize(&tensor, GgmlDType::Q4K).unwrap();
            bytes.extend_from_slice(&quantized.data().unwrap());
        }
        bytes
    };
    let gate_bytes = quantized_stack(0.05);
    let up_bytes = quantized_stack(-0.07);

    let Some(device) = metal::Device::system_default() else {
        eprintln!("no Metal device");
        return;
    };
    let queue = device.new_command_queue();

    let gate_buf = shared_buffer(&device, &gate_bytes);
    let up_buf = shared_buffer(&device, &up_bytes);

    let activations: Vec<f32> = (0..M * K)
        .map(|i| ((i as f32) * 0.0021).cos() * 0.7)
        .collect();
    let expert_ids: Vec<i32> = (0..M * TOP_K)
        .map(|i| (i as i32) % NUM_EXPERTS as i32)
        .collect();
    let a_buf = shared_buffer(&device, &activations);
    let ids_buf = shared_buffer(&device, &expert_ids);

    let out_bytes = (M * TOP_K * N * 4) as u64;
    let out_buf = device.new_buffer(out_bytes, MTLResourceOptions::StorageModeShared);

    // Submits one fused dispatch and blocks until the GPU has finished it.
    let run_once = || {
        let cmd = queue.new_command_buffer();
        let enc = cmd.new_compute_command_encoder();
        dispatch_gemv_q4k_moe_id_gate_up_silu_batched_on_encoder(
            &device, enc, &a_buf, &gate_buf, 0, &up_buf, 0, &ids_buf, &out_buf, N, K, M, TOP_K,
            K, 0,
        );
        enc.end_encoding();
        cmd.commit();
        cmd.wait_until_completed();
    };

    // Warm-up outside the capture window.
    (0..3).for_each(|_| run_once());

    let trace_path = std::path::PathBuf::from("/tmp/ferrum_moe_gemv_m16.gputrace");
    // Best effort: a stale trace may or may not exist, so the result is deliberately dropped.
    std::fs::remove_dir_all(&trace_path).ok();

    let descriptor = CaptureDescriptor::new();
    descriptor.set_capture_command_queue(&queue);
    descriptor.set_destination(MTLCaptureDestination::GpuTraceDocument);
    descriptor.set_output_url(&trace_path);

    let manager = metal::CaptureManager::shared();
    manager.start_capture(&descriptor).expect("start_capture");
    (0..5).for_each(|_| run_once());
    manager.stop_capture();

    eprintln!();
    eprintln!("===========================================================");
    eprintln!(" GPU trace saved: {}", trace_path.display());
    eprintln!();
    eprintln!(" open with: open '{}'", trace_path.display());
    eprintln!("===========================================================");
}
/// Manual microbenchmark: times the fused gate/up/SiLU batched kernel for several token
/// counts `m` and prints the cost per (token, slot) pair relative to `m = 1`.
///
/// Each iteration submits one command buffer and waits for it, so the figures include
/// submission and synchronisation overhead. Skipped when no Metal device exists.
#[test]
#[ignore = "microbench — only run manually with --release"]
fn moe_gemv_fused_scaling_microbench() {
    use crate::q4_k_moe_id_gate_up_silu_batched::dispatch_gemv_q4k_moe_id_gate_up_silu_batched_on_encoder;

    const NUM_EXPERTS: usize = 8;
    const N: usize = 768;
    const K: usize = 2048;
    const TOP_K: usize = 8;
    const M_VALUES: &[usize] = &[1, 2, 4, 8, 16];
    const WARMUP_ITERS: usize = 3;
    const ITERS: usize = 500;

    // Look for the device first so a machine without Metal skips the costly quantization,
    // and say so instead of passing silently.
    let Some(device) = metal::Device::system_default() else {
        eprintln!("no Metal device — skipping");
        return;
    };
    let queue = device.new_command_queue();

    let cpu = CDevice::Cpu;
    eprintln!("preparing 2 synthetic Q4_K stacks (gate + up) ...");
    // Quantizes NUM_EXPERTS synthetic [N, K] matrices, shifted by `seed`, back to back.
    let pack_stack = |seed: f32| -> Vec<u8> {
        let mut buf = Vec::new();
        for e in 0..NUM_EXPERTS {
            let raw: Vec<f32> = (0..N * K)
                .map(|i| ((((i + e * 313) % 251) as f32) * 0.013).sin() * 0.4 + seed)
                .collect();
            let t = Tensor::from_vec(raw, (N, K), &cpu).unwrap();
            let qt = QTensor::quantize(&t, GgmlDType::Q4K).unwrap();
            buf.extend_from_slice(&qt.data().unwrap());
        }
        buf
    };
    let gate_bytes = pack_stack(0.05);
    let up_bytes = pack_stack(-0.07);
    eprintln!(" each stack: {} MB", gate_bytes.len() / (1024 * 1024));

    let gate_buf = device.new_buffer_with_data(
        gate_bytes.as_ptr() as *const _,
        gate_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let up_buf = device.new_buffer_with_data(
        up_bytes.as_ptr() as *const _,
        up_bytes.len() as u64,
        MTLResourceOptions::StorageModeShared,
    );

    // Inputs and output are sized for the largest m; smaller runs use a prefix of them.
    let max_m = *M_VALUES.iter().max().unwrap();
    let act: Vec<f32> = (0..max_m * K)
        .map(|i| ((i as f32) * 0.0021).cos() * 0.7)
        .collect();
    let ids: Vec<i32> = (0..max_m * TOP_K)
        .map(|i| (i as i32) % NUM_EXPERTS as i32)
        .collect();
    let a_buf = device.new_buffer_with_data(
        act.as_ptr() as *const _,
        (act.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let ids_buf = device.new_buffer_with_data(
        ids.as_ptr() as *const _,
        (ids.len() * 4) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let out_size = (max_m * TOP_K * N * 4) as u64;
    let out_buf = device.new_buffer(out_size, MTLResourceOptions::StorageModeShared);

    // Submits one fused dispatch for `m` tokens and blocks until the GPU has finished.
    let run_once = |m: usize| {
        let cmd = queue.new_command_buffer();
        let enc = cmd.new_compute_command_encoder();
        dispatch_gemv_q4k_moe_id_gate_up_silu_batched_on_encoder(
            &device, enc, &a_buf, &gate_buf, 0, &up_buf, 0, &ids_buf, &out_buf, N, K, m,
            TOP_K, K, 0,
        );
        enc.end_encoding();
        cmd.commit();
        cmd.wait_until_completed();
    };

    eprintln!();
    eprintln!(
        "{:<8} {:<10} {:<14} {:<14} {:<14}",
        "m", "n_pairs", "ms_per_iter", "us_per_pair", "scaling_vs_m1"
    );

    // Baseline is whatever the first entry of M_VALUES measures (m = 1 in this table).
    let mut m1_us_per_pair: Option<f64> = None;
    for &m in M_VALUES {
        let n_pairs = m * TOP_K;
        for _ in 0..WARMUP_ITERS {
            run_once(m);
        }
        let t0 = std::time::Instant::now();
        for _ in 0..ITERS {
            run_once(m);
        }
        let elapsed_us = t0.elapsed().as_micros() as f64;
        let us_per_iter = elapsed_us / ITERS as f64;
        let us_per_pair = us_per_iter / n_pairs as f64;
        let scaling = if let Some(m1) = m1_us_per_pair {
            us_per_pair / m1
        } else {
            m1_us_per_pair = Some(us_per_pair);
            1.0
        };
        eprintln!(
            "{:<8} {:<10} {:<14.3} {:<14.3} {:<14.3}",
            m,
            n_pairs,
            us_per_iter / 1000.0,
            us_per_pair,
            scaling
        );
    }
}
}