use super::{AttnConfig, Backend, SrcDtype};
use ferrum_attention::metal::pipelines::MetalPipelines;
use ferrum_attention::AttentionParams;
use ferrum_types::{FerrumError, Result};
use half::{bf16, f16};
use metal::{Device, MTLResourceOptions};
use std::ffi::c_void;
use std::sync::{Arc, Mutex, OnceLock};
/// One memory-mapped GGUF byte range registered via `register_gguf_mmap`;
/// tensor slices that fall inside it may be wrapped by Metal without copying.
struct MetalMmapEntry {
// Start address of the registered slice (checked to be 16 KiB aligned at registration).
base_addr: usize,
// Length in bytes of the registered slice.
len: usize,
// Caller-supplied owner handle, retained only to keep it alive while the entry exists.
// NOTE(review): presumably the mmap owner; nothing visible here removes entries — confirm.
_keeper: Arc<dyn std::any::Any + Send + Sync>,
}
/// Process-wide Metal state: the compiled pipelines (this file also reaches the
/// device and command queue through them) plus the registry of zero-copy mmaps.
struct MetalState {
pipes: MetalPipelines,
mmaps: Mutex<Vec<MetalMmapEntry>>,
}
// Lazily initialised on first use by `st()`.
static METAL_STATE: OnceLock<MetalState> = OnceLock::new();
/// Returns the process-wide Metal state, creating it on first call.
///
/// # Panics
/// Panics if the system has no default Metal device. The backend cannot do
/// anything useful without one, so this is treated as a fatal configuration error.
fn st() -> &'static MetalState {
    METAL_STATE.get_or_init(|| {
        let device = Device::system_default()
            .expect("Metal backend: no system default Metal device is available");
        MetalState {
            pipes: MetalPipelines::new(&device),
            mmaps: Mutex::new(Vec::new()),
        }
    })
}
/// Registers a memory-mapped GGUF byte range so that later quantized-tensor
/// uploads lying inside it can be wrapped zero-copy instead of copied.
///
/// `keeper` is stored with the entry, keeping the caller's owner object alive.
/// Registering an identical `(address, length)` pair again is a no-op.
/// Set `FERRUM_MMAP_TRACE` to log each registration.
///
/// # Errors
/// - the slice does not start on a 16 KiB boundary;
/// - the registry mutex is poisoned.
pub fn register_gguf_mmap(
    slice: &[u8],
    keeper: Arc<dyn std::any::Any + Send + Sync>,
) -> Result<()> {
    const PAGE: usize = 16384;
    let (base_addr, byte_len) = (slice.as_ptr() as usize, slice.len());
    if !base_addr.is_multiple_of(PAGE) {
        return Err(FerrumError::model(format!(
            "register_gguf_mmap: base pointer 0x{base_addr:x} not page-aligned (need {PAGE})"
        )));
    }
    if std::env::var("FERRUM_MMAP_TRACE").is_ok() {
        eprintln!(
            "[mmap] register file at 0x{base_addr:x} len={} ({:.2} GB)",
            byte_len,
            byte_len as f64 / 1e9
        );
    }
    let mut registry = st().mmaps.lock().map_err(|e| {
        FerrumError::model(format!("register_gguf_mmap: registry poisoned: {e}"))
    })?;
    let already_registered = registry
        .iter()
        .any(|entry| entry.base_addr == base_addr && entry.len == byte_len);
    if !already_registered {
        registry.push(MetalMmapEntry {
            base_addr,
            len: byte_len,
            _keeper: keeper,
        });
    }
    Ok(())
}
/// Reports whether `bytes` lies entirely inside one registered mmap range.
///
/// Returns `false` if the slice's end address would overflow, if the registry
/// lock is poisoned, or if no entry fully contains the slice. Entries whose own
/// end address overflows are skipped.
#[inline(never)]
fn slice_is_in_registered_mmap(bytes: &[u8]) -> bool {
    let start = bytes.as_ptr() as usize;
    let Some(stop) = start.checked_add(bytes.len()) else {
        return false;
    };
    let Ok(registry) = st().mmaps.lock() else {
        return false;
    };
    registry.iter().any(|entry| {
        entry
            .base_addr
            .checked_add(entry.len)
            .is_some_and(|entry_stop| start >= entry.base_addr && stop <= entry_stop)
    })
}
/// Starts a Metal GPU frame capture when `FERRUM_METAL_CAPTURE` names an
/// output `.gputrace` path.
///
/// Returns `true` only when the capture actually started. Returns `false` (with a
/// diagnostic on stderr where relevant) when `FERRUM_METAL_CAPTURE` is unset,
/// when `MTL_CAPTURE_ENABLED` is unset, when the trace-document destination is
/// unsupported, or when the capture manager refuses to start.
pub fn maybe_begin_frame_capture() -> bool {
    use metal::{CaptureDescriptor, CaptureManager, MTLCaptureDestination};
    let out_path = match std::env::var("FERRUM_METAL_CAPTURE") {
        Ok(path) => path,
        Err(_) => return false,
    };
    let capture_enabled = std::env::var("MTL_CAPTURE_ENABLED").is_ok();
    if !capture_enabled {
        eprintln!(
            "[capture] FERRUM_METAL_CAPTURE set but MTL_CAPTURE_ENABLED is not — Metal will reject. Re-launch with MTL_CAPTURE_ENABLED=1."
        );
        return false;
    }
    let manager = CaptureManager::shared();
    if !manager.supports_destination(MTLCaptureDestination::GpuTraceDocument) {
        eprintln!("[capture] device does not support GpuTraceDocument");
        return false;
    }
    let descriptor = CaptureDescriptor::new();
    descriptor.set_capture_device(&st().pipes.device);
    descriptor.set_destination(MTLCaptureDestination::GpuTraceDocument);
    descriptor.set_output_url(&out_path);
    if let Err(e) = manager.start_capture(&descriptor) {
        eprintln!("[capture] start_capture failed: {e}");
        return false;
    }
    eprintln!("[capture] started → {out_path}");
    true
}
/// Stops the shared capture manager's current capture and prints a reminder
/// on stderr about where to open the trace.
pub fn end_frame_capture() {
    let manager = metal::CaptureManager::shared();
    manager.stop_capture();
    eprintln!("[capture] stopped — open .gputrace in Xcode");
}
/// Element type tag carried by a Metal buffer wrapper.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Dtype {
    /// 32-bit float elements.
    F32,
    /// 16-bit (half precision) float elements.
    F16,
}
impl Dtype {
    /// Number of bytes occupied by one element of this type.
    pub const fn bytes_per_elem(self) -> usize {
        match self {
            Dtype::F16 => std::mem::size_of::<u16>(),
            Dtype::F32 => std::mem::size_of::<f32>(),
        }
    }
}
/// A Metal buffer tagged with the element type it holds and its logical
/// length in elements.
pub struct MetalBuf {
    raw: metal::Buffer,
    dtype: Dtype,
    n: usize,
}
impl MetalBuf {
    /// Shared access to the underlying Metal buffer.
    pub fn raw(&self) -> &metal::Buffer {
        &self.raw
    }
    /// Mutable access to the underlying Metal buffer.
    pub fn raw_mut(&mut self) -> &mut metal::Buffer {
        &mut self.raw
    }
    /// Element type this buffer is tagged with.
    pub fn dtype(&self) -> Dtype {
        self.dtype
    }
    /// Logical length in elements (not bytes).
    pub fn len(&self) -> usize {
        self.n
    }
    /// `true` when the buffer holds zero logical elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// `true` when the buffer is tagged as half precision.
    pub fn is_f16(&self) -> bool {
        self.dtype == Dtype::F16
    }
    /// Returns the raw buffer for reading. In debug builds this asserts that the
    /// buffer is tagged `F32`; `what` names the call site in the panic message.
    #[inline]
    fn expect_f32(&self, what: &str) -> &metal::Buffer {
        debug_assert!(
            self.dtype == Dtype::F32,
            "{what}: expected F32 buffer, got {:?}",
            self.dtype
        );
        &self.raw
    }
    /// Mutable counterpart of `expect_f32`, with the same debug-only dtype check.
    #[inline]
    fn expect_f32_mut(&mut self, what: &str) -> &mut metal::Buffer {
        debug_assert!(
            self.dtype == Dtype::F32,
            "{what}: expected F32 buffer, got {:?}",
            self.dtype
        );
        &mut self.raw
    }
}
// SAFETY(review): `metal::Buffer` is a retained Objective-C object handle. These impls
// assert it may be moved and shared across threads. Soundness relies on callers not
// writing the same buffer's contents concurrently — confirm against all users.
unsafe impl Send for MetalBuf {}
unsafe impl Sync for MetalBuf {}
/// Encoding state for a sequence of GPU ops: a lazily created command buffer
/// and the currently open compute encoder. Every op encoded before `flush` shares both.
pub struct MetalContext {
// `None` until the first op needs a command buffer, and again after `flush`.
cmd: Option<&'static metal::CommandBufferRef>,
// The open compute encoder, if any; closed by `compute_encoder_end`.
encoder: Option<&'static metal::ComputeCommandEncoderRef>,
}
impl MetalContext {
/// Returns the current command buffer, creating one on the shared queue if none is open.
fn cmd(&mut self) -> &'static metal::CommandBufferRef {
match self.cmd {
Some(c) => c,
None => {
let c = st().pipes.queue.new_command_buffer();
// NOTE(review): the borrow is widened to 'static with `transmute`. This is only sound
// if the command-buffer object stays alive until `flush` commits it (e.g. it is not
// released by an autorelease-pool drain in between) — confirm.
let c_static: &'static metal::CommandBufferRef =
unsafe { std::mem::transmute::<&metal::CommandBufferRef, _>(c) };
self.cmd = Some(c_static);
c_static
}
}
}
/// Returns the open compute encoder, creating one on the current command buffer
/// if none is open.
fn compute_encoder(&mut self) -> &'static metal::ComputeCommandEncoderRef {
if let Some(enc) = self.encoder {
return enc;
}
let cmd = self.cmd();
let enc = cmd.new_compute_command_encoder();
// NOTE(review): same 'static lifetime widening as in `cmd()`. It assumes the encoder
// outlives every use until `compute_encoder_end` — confirm.
let enc_static: &'static metal::ComputeCommandEncoderRef =
unsafe { std::mem::transmute::<&metal::ComputeCommandEncoderRef, _>(enc) };
self.encoder = Some(enc_static);
enc_static
}
/// Ends the open compute encoder, if any, and forgets it.
fn compute_encoder_end(&mut self) {
if let Some(enc) = self.encoder.take() {
enc.end_encoding();
}
}
/// Closes any open encoder, then commits the command buffer (if one exists) and
/// blocks until the GPU has completed it.
/// NOTE(review): no `Drop` impl is visible here, so a context dropped without a
/// `flush` does not commit its pending work — confirm that is intended.
fn flush(&mut self) {
self.compute_encoder_end();
if let Some(cmd) = self.cmd.take() {
cmd.commit();
cmd.wait_until_completed();
}
}
}
// Quant-GEMM profiling counters; by name: accumulated time in µs, number of calls,
// and the m/n/k dimensions of the most recent call.
// NOTE(review): none of these are touched in the visible part of this file. They are
// presumably updated by `gemm_quant` when FERRUM_METAL_QUANT_PROFILE is set — confirm.
static QUANT_GEMM_TIME_US: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
static QUANT_GEMM_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
static QUANT_GEMM_LAST_M: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
static QUANT_GEMM_LAST_N: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
static QUANT_GEMM_LAST_K: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
/// Whether `FERRUM_METAL_QUANT_PROFILE` is set to a valid UTF-8 value.
/// The environment is read once; later calls return the cached answer.
fn debug_per_call_flush() -> bool {
    static PROFILE_ENABLED: OnceLock<bool> = OnceLock::new();
    let enabled = PROFILE_ENABLED
        .get_or_init(|| matches!(std::env::var("FERRUM_METAL_QUANT_PROFILE"), Ok(_)));
    *enabled
}
/// Whether `FERRUM_METAL_DTYPE` equals `f16` (ASCII case-insensitive).
/// Unset or non-UTF-8 values count as "no". The environment is read once and cached.
fn prefer_f16_weights() -> bool {
    static WANT_F16: OnceLock<bool> = OnceLock::new();
    *WANT_F16.get_or_init(|| match std::env::var("FERRUM_METAL_DTYPE") {
        Ok(value) => value.eq_ignore_ascii_case("f16"),
        Err(_) => false,
    })
}
// Element-count threshold (2^20). By name, presumably the minimum tensor size for which
// F16 storage is used; it has no visible use in this chunk — confirm at the call sites.
const F16_MIN_ELEMS: usize = 1_048_576;
/// Allocates a buffer through `MetalPipelines::buffer_empty`.
/// NOTE(review): `n` is presumably an f32 element count — confirm in `buffer_empty`.
fn alloc_f32_raw(n: usize) -> metal::Buffer {
    let pipes = &st().pipes;
    pipes.buffer_empty(n)
}
/// Creates a buffer initialised from `data` through `MetalPipelines::buffer_from_data`.
fn buffer_from_f32_slice(data: &[f32]) -> metal::Buffer {
    let pipes = &st().pipes;
    pipes.buffer_from_data(data)
}
/// Copies raw little-endian f16 bytes into a new shared-storage Metal buffer.
/// In debug builds the byte count is asserted to be even.
fn buffer_from_f16_bytes(bytes: &[u8]) -> metal::Buffer {
    debug_assert_eq!(bytes.len() % 2, 0);
    let device = &st().pipes.device;
    let byte_len = bytes.len() as u64;
    device.new_buffer_with_data(
        bytes.as_ptr().cast::<c_void>(),
        byte_len,
        MTLResourceOptions::StorageModeShared,
    )
}
/// Converts `data` to IEEE half precision (round-to-nearest via `half::f16::from_f32`)
/// and uploads the little-endian bytes to a new shared-storage Metal buffer.
fn buffer_f16_from_f32(data: &[f32]) -> metal::Buffer {
    // One pass and one allocation, with no per-element index arithmetic or
    // bounds checks.
    let mut f16_bytes = Vec::with_capacity(data.len() * 2);
    for &x in data {
        f16_bytes.extend_from_slice(&f16::from_f32(x).to_le_bytes());
    }
    buffer_from_f16_bytes(&f16_bytes)
}
/// Marker type that implements the `Backend` trait on Apple Metal.
pub struct MetalBackend;
/// GGUF-quantized weights held in Metal buffers.
///
/// In every leaf variant, `byte_offset` is where the tensor's first byte sits inside
/// `blocks`. It is non-zero when `blocks` wraps a page-aligned window of a registered
/// mmap, and 0 when the bytes were copied.
pub enum MetalQuantStore {
    /// One Q4_K matrix of `n_rows × n_cols` elements stored as `n_blocks` super-blocks.
    Q4K {
        blocks: metal::Buffer,
        byte_offset: u64,
        n_rows: usize,
        n_cols: usize,
        n_blocks: usize,
    },
    /// One Q6_K matrix of `n_rows × n_cols` elements stored as `n_blocks` super-blocks.
    Q6K {
        blocks: metal::Buffer,
        byte_offset: u64,
        n_rows: usize,
        n_cols: usize,
        n_blocks: usize,
    },
    /// Leaf parts stacked row-wise; `total_rows` is the sum of the parts' rows.
    Fused {
        parts: Vec<MetalQuantStore>,
        total_rows: usize,
        n_cols: usize,
    },
    /// `num_experts` Q4_K matrices, each `n_rows × n_cols`, in one buffer.
    Q4KExperts {
        blocks: metal::Buffer,
        byte_offset: u64,
        num_experts: usize,
        n_rows: usize,
        n_cols: usize,
    },
    /// `num_experts` Q6_K matrices, each `n_rows × n_cols`, in one buffer.
    Q6KExperts {
        blocks: metal::Buffer,
        byte_offset: u64,
        num_experts: usize,
        n_rows: usize,
        n_cols: usize,
    },
}
impl MetalQuantStore {
    /// Output row count: the summed rows for `Fused`, and the per-expert rows for
    /// the expert-stacked variants.
    fn n_rows(&self) -> usize {
        let rows = match self {
            Self::Fused { total_rows, .. } => total_rows,
            Self::Q4K { n_rows, .. }
            | Self::Q6K { n_rows, .. }
            | Self::Q4KExperts { n_rows, .. }
            | Self::Q6KExperts { n_rows, .. } => n_rows,
        };
        *rows
    }
}
/// Places quantized tensor bytes in a Metal buffer and returns `(buffer, byte_offset)`.
///
/// If `bytes` lies inside a mapping registered with `register_gguf_mmap`, the
/// enclosing 16 KiB-aligned window is wrapped without copying, and `byte_offset`
/// locates the tensor inside that window. Otherwise, or if Metal refuses the no-copy
/// wrap (a zero-length buffer comes back), the bytes are copied into a new shared
/// buffer and the offset is 0. `FERRUM_MMAP_TRACE` logs which path was taken.
fn buffer_for_quant_bytes(bytes: &[u8]) -> (metal::Buffer, u64) {
    const PAGE: usize = 16384;
    let trace = std::env::var("FERRUM_MMAP_TRACE").is_ok();
    let ptr_addr = bytes.as_ptr() as usize;
    let tensor_len = bytes.len();
    if slice_is_in_registered_mmap(bytes) {
        // Round the window out to whole pages on both ends, as no-copy wrapping requires.
        let aligned_start = ptr_addr & !(PAGE - 1);
        let aligned_len = (ptr_addr + tensor_len).div_ceil(PAGE) * PAGE - aligned_start;
        let byte_offset = (ptr_addr - aligned_start) as u64;
        let wrapped = st().pipes.device.new_buffer_with_bytes_no_copy(
            aligned_start as *const c_void,
            aligned_len as u64,
            MTLResourceOptions::StorageModeShared,
            None,
        );
        let accepted = wrapped.length() != 0;
        if trace && accepted {
            eprintln!(
                "[mmap] zero-copy: tensor ptr=0x{ptr_addr:x} len={} -> buf @0x{aligned_start:x} len={aligned_len} off={byte_offset}",
                tensor_len
            );
        } else if trace {
            eprintln!(
                "[mmap] zero-copy refused for tensor ptr=0x{ptr_addr:x} len={} aligned_len={aligned_len} — copying",
                tensor_len
            );
        }
        if accepted {
            return (wrapped, byte_offset);
        }
    }
    if trace {
        eprintln!("[mmap] copy: ptr={:p} len={}", bytes.as_ptr(), tensor_len);
    }
    let copied = st().pipes.device.new_buffer_with_data(
        bytes.as_ptr() as *const c_void,
        tensor_len as u64,
        MTLResourceOptions::StorageModeShared,
    );
    (copied, 0)
}
/// Validates a stacked Q4_K expert tensor (`num_experts` matrices of
/// `n_rows × n_cols` each) and places it in a single Metal buffer.
///
/// # Errors
/// - `n_cols` is not a multiple of 256, the Q4_K super-block size;
/// - the expected byte size does not fit in `usize`;
/// - `bytes.len()` differs from `num_experts * n_rows * (n_cols / 256) * 144`.
pub fn load_q4k_experts(
    bytes: &[u8],
    num_experts: usize,
    n_rows: usize,
    n_cols: usize,
) -> Result<MetalQuantStore> {
    const QK_K: usize = 256;
    const BLOCK_BYTES: usize = 144;
    if !n_cols.is_multiple_of(QK_K) {
        return Err(FerrumError::model(format!(
            "load_q4k_experts: n_cols {n_cols} not a multiple of {QK_K}"
        )));
    }
    // Checked so that bogus header dimensions cannot wrap the product around
    // (release) or panic (debug) instead of producing a clean error.
    let expected = num_experts
        .checked_mul(n_rows)
        .and_then(|v| v.checked_mul(n_cols / QK_K))
        .and_then(|v| v.checked_mul(BLOCK_BYTES))
        .ok_or_else(|| {
            FerrumError::model(format!(
                "load_q4k_experts: size of {num_experts}E × {n_rows}R × {n_cols}C overflows usize"
            ))
        })?;
    if bytes.len() != expected {
        return Err(FerrumError::model(format!(
            "load_q4k_experts: bytes {} != expected {expected} ({num_experts}E × {n_rows}R × {n_cols}C)",
            bytes.len()
        )));
    }
    let (blocks, byte_offset) = buffer_for_quant_bytes(bytes);
    Ok(MetalQuantStore::Q4KExperts {
        blocks,
        byte_offset,
        num_experts,
        n_rows,
        n_cols,
    })
}
/// Validates a stacked Q6_K expert tensor (`num_experts` matrices of
/// `n_rows × n_cols` each) and places it in a single Metal buffer.
///
/// # Errors
/// - `n_cols` is not a multiple of 256, the Q6_K super-block size;
/// - the expected byte size does not fit in `usize`;
/// - `bytes.len()` differs from `num_experts * n_rows * (n_cols / 256) * Q6_K_BLOCK_BYTES`.
pub fn load_q6k_experts(
    bytes: &[u8],
    num_experts: usize,
    n_rows: usize,
    n_cols: usize,
) -> Result<MetalQuantStore> {
    const QK_K: usize = 256;
    const BLOCK_BYTES: usize = crate::q6_k_gemv::Q6_K_BLOCK_BYTES;
    if !n_cols.is_multiple_of(QK_K) {
        return Err(FerrumError::model(format!(
            "load_q6k_experts: n_cols {n_cols} not a multiple of {QK_K}"
        )));
    }
    // Checked so that bogus header dimensions cannot wrap the product around
    // (release) or panic (debug) instead of producing a clean error.
    let expected = num_experts
        .checked_mul(n_rows)
        .and_then(|v| v.checked_mul(n_cols / QK_K))
        .and_then(|v| v.checked_mul(BLOCK_BYTES))
        .ok_or_else(|| {
            FerrumError::model(format!(
                "load_q6k_experts: size of {num_experts}E × {n_rows}R × {n_cols}C overflows usize"
            ))
        })?;
    if bytes.len() != expected {
        return Err(FerrumError::model(format!(
            "load_q6k_experts: bytes {} != expected {expected}",
            bytes.len()
        )));
    }
    let (blocks, byte_offset) = buffer_for_quant_bytes(bytes);
    Ok(MetalQuantStore::Q6KExperts {
        blocks,
        byte_offset,
        num_experts,
        n_rows,
        n_cols,
    })
}
/// Encodes an expert-indexed ("moe_id") quantized GEMV on `enc`. The kernel is
/// chosen by the weights' variant (Q4_K or Q6_K expert stack) and reads expert
/// indices from `ids`.
///
/// # Errors
/// Fails without encoding anything if `weights` is not `Q4KExperts` or `Q6KExperts`.
pub fn dispatch_gemv_moe_id(
    enc: &metal::ComputeCommandEncoderRef,
    a: &metal::Buffer,
    weights: &MetalQuantStore,
    ids: &metal::Buffer,
    out: &metal::Buffer,
    n_selected: usize,
    src1_stride: usize,
) -> Result<()> {
    let (use_q6k, blocks, byte_offset, n_rows, n_cols) = match weights {
        MetalQuantStore::Q4KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (false, blocks, *byte_offset, *n_rows, *n_cols),
        MetalQuantStore::Q6KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (true, blocks, *byte_offset, *n_rows, *n_cols),
        _ => {
            return Err(FerrumError::model(
                "dispatch_gemv_moe_id: weights must be Q4KExperts or Q6KExperts variant".to_string(),
            ));
        }
    };
    let device = &st().pipes.device;
    if use_q6k {
        crate::q6_k_moe_id_gemv::dispatch_gemv_q6k_moe_id_on_encoder(
            device, enc, a, blocks, byte_offset, ids, out, n_rows, n_cols, n_selected, src1_stride,
        );
    } else {
        crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_on_encoder(
            device, enc, a, blocks, byte_offset, ids, out, n_rows, n_cols, n_selected, src1_stride,
        );
    }
    Ok(())
}
/// Same as `dispatch_gemv_moe_id`, except the activations and ids are read
/// starting at the given byte offsets into `a` and `ids`.
///
/// # Errors
/// Fails without encoding anything if `weights` is not `Q4KExperts` or `Q6KExperts`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gemv_moe_id_offset(
    enc: &metal::ComputeCommandEncoderRef,
    a: &metal::Buffer,
    a_byte_offset: u64,
    weights: &MetalQuantStore,
    ids: &metal::Buffer,
    ids_byte_offset: u64,
    out: &metal::Buffer,
    n_selected: usize,
    src1_stride: usize,
) -> Result<()> {
    let (use_q6k, blocks, byte_offset, n_rows, n_cols) = match weights {
        MetalQuantStore::Q4KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (false, blocks, *byte_offset, *n_rows, *n_cols),
        MetalQuantStore::Q6KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (true, blocks, *byte_offset, *n_rows, *n_cols),
        _ => {
            return Err(FerrumError::model(
                "dispatch_gemv_moe_id_offset: weights must be Q4KExperts or Q6KExperts variant"
                    .to_string(),
            ));
        }
    };
    let device = &st().pipes.device;
    if use_q6k {
        crate::q6_k_moe_id_gemv::dispatch_gemv_q6k_moe_id_offset_on_encoder(
            device,
            enc,
            a,
            a_byte_offset,
            blocks,
            byte_offset,
            ids,
            ids_byte_offset,
            out,
            n_rows,
            n_cols,
            n_selected,
            src1_stride,
        );
    } else {
        crate::q4_k_moe_id_gemv::dispatch_gemv_q4k_moe_id_offset_on_encoder(
            device,
            enc,
            a,
            a_byte_offset,
            blocks,
            byte_offset,
            ids,
            ids_byte_offset,
            out,
            n_rows,
            n_cols,
            n_selected,
            src1_stride,
        );
    }
    Ok(())
}
// SAFETY(review): MetalQuantStore holds only `metal::Buffer` handles and plain integers.
// These impls assert the buffers may be moved and shared across threads. That relies on
// the weight buffers not being mutated concurrently after load — confirm.
unsafe impl Send for MetalQuantStore {}
unsafe impl Sync for MetalQuantStore {}
/// Encodes one leaf part of a `Fused` weight as a GEMM over `m` input rows.
///
/// The part contributes `part_rows` outputs per input row. They are written into
/// `out_buf` starting at column `c_offset_cols`, with `stride_c` as the output row
/// stride. NOTE(review): presumably `out_buf` is a row-major `m × stride_c` matrix;
/// the `gemm_quant` caller passes the fused total row count as `stride_c`.
///
/// # Errors
/// - `part_rows` is not a multiple of 4;
/// - `part` is not a `Q4K` / `Q6K` leaf.
#[allow(clippy::too_many_arguments)] // mirrors the kernel launchers' parameter lists
fn dispatch_part_gemm(
    enc: &metal::ComputeCommandEncoderRef,
    a_buf: &metal::Buffer,
    part: &MetalQuantStore,
    out_buf: &metal::Buffer,
    c_offset_cols: usize,
    m: usize,
    part_rows: usize,
    stride_c: usize,
    n_cols: usize,
) -> Result<()> {
    if !part_rows.is_multiple_of(4) {
        return Err(FerrumError::model(format!(
            "gemm_quant Fused: part rows {part_rows} not divisible by 4"
        )));
    }
    match part {
        MetalQuantStore::Q4K {
            blocks,
            byte_offset,
            ..
        } => {
            crate::q4_k_gemm::dispatch_gemm_q4k_part(
                &st().pipes.device,
                enc,
                a_buf,
                blocks,
                *byte_offset,
                out_buf,
                c_offset_cols,
                m,
                part_rows,
                stride_c,
                n_cols,
            );
        }
        MetalQuantStore::Q6K {
            blocks,
            byte_offset,
            ..
        } => {
            crate::q6_k_gemm::dispatch_gemm_q6k_part(
                &st().pipes.device,
                enc,
                a_buf,
                blocks,
                *byte_offset,
                out_buf,
                c_offset_cols,
                m,
                part_rows,
                stride_c,
                n_cols,
            );
        }
        MetalQuantStore::Fused { .. }
        | MetalQuantStore::Q4KExperts { .. }
        | MetalQuantStore::Q6KExperts { .. } => {
            return Err(FerrumError::model(
                "gemm_quant Fused: only Q4K/Q6K leaf parts supported here".to_string(),
            ));
        }
    }
    Ok(())
}
/// Encodes one leaf part of a `Fused` weight as a single-row GEMV. Input is read
/// from `a_buf` at `a_offset_bytes`, and this part's outputs are written to `out_buf`
/// at `c_offset_bytes`.
///
/// A Q4K part whose row count is not a multiple of 4 falls back to the v1 kernel,
/// which has no offset parameters. Such a part is therefore only valid at zero
/// offsets (in practice, the first part of a fused store).
///
/// # Errors
/// - a Q4K part needs the v1 kernel but a non-zero offset was requested;
/// - a Q6K part's row count is not a multiple of 4;
/// - `part` is not a `Q4K` / `Q6K` leaf.
///
/// No work is encoded when an error is returned.
fn dispatch_part_gemv_offset(
    enc: &metal::ComputeCommandEncoderRef,
    a_buf: &metal::Buffer,
    a_offset_bytes: u64,
    part: &MetalQuantStore,
    out_buf: &metal::Buffer,
    c_offset_bytes: u64,
    n_cols: usize,
) -> Result<()> {
    match part {
        MetalQuantStore::Q4K {
            blocks,
            byte_offset,
            n_rows,
            ..
        } => {
            if !n_rows.is_multiple_of(4) {
                // Reject BEFORE encoding. The v1 kernel ignores offsets and would
                // otherwise write this part's rows at offset 0, clobbering another
                // part's output even though the caller gets an error.
                if a_offset_bytes != 0 || c_offset_bytes != 0 {
                    return Err(FerrumError::model(
                        "gemm_quant Fused: q4k v1 path doesn't support offsets yet".to_string(),
                    ));
                }
                crate::q4_k_gemv::dispatch_gemv_q4k_on_encoder(
                    &st().pipes.device,
                    enc,
                    a_buf,
                    blocks,
                    *byte_offset,
                    out_buf,
                    *n_rows,
                    n_cols,
                );
                return Ok(());
            }
            crate::q4_k_gemv_v2::dispatch_gemv_q4k_v2_offset(
                &st().pipes.device,
                enc,
                a_buf,
                a_offset_bytes,
                blocks,
                *byte_offset,
                out_buf,
                c_offset_bytes,
                *n_rows,
                n_cols,
            );
        }
        MetalQuantStore::Q6K {
            blocks,
            byte_offset,
            n_rows,
            ..
        } => {
            if !n_rows.is_multiple_of(4) {
                return Err(FerrumError::model(format!(
                    "gemm_quant Fused: Q6K part n_rows={n_rows} not divisible by 4"
                )));
            }
            crate::q6_k_gemv::dispatch_gemv_q6k_v2_offset(
                &st().pipes.device,
                enc,
                a_buf,
                a_offset_bytes,
                blocks,
                *byte_offset,
                out_buf,
                c_offset_bytes,
                *n_rows,
                n_cols,
            );
        }
        MetalQuantStore::Fused { .. }
        | MetalQuantStore::Q4KExperts { .. }
        | MetalQuantStore::Q6KExperts { .. } => {
            return Err(FerrumError::model(
                "gemm_quant Fused: only Q4K/Q6K leaf parts supported here".to_string(),
            ));
        }
    }
    Ok(())
}
impl Backend for MetalBackend {
type Buffer = MetalBuf;
type Context = MetalContext;
type GptqStore = (); type QuantStore = MetalQuantStore;
fn new_context() -> Self::Context {
MetalContext {
cmd: None,
encoder: None,
}
}
fn sync(ctx: &mut Self::Context) {
ctx.flush();
}
/// Validates a single Q4_K or Q6_K weight matrix of `n_rows × n_cols` elements
/// and places it in a Metal buffer (zero-copy when it lies in a registered mmap).
///
/// # Errors
/// - `n_rows * n_cols` (or the derived byte size) overflows `usize`;
/// - the element count is not a multiple of 256, the K-quant super-block size;
/// - `bytes.len()` does not match the expected block byte count;
/// - `kind` is any other quantization type (unsupported on Metal).
fn load_quant(
    kind: super::GgufQuantType,
    bytes: &[u8],
    n_rows: usize,
    n_cols: usize,
) -> Result<Self::QuantStore> {
    use super::GgufQuantType;
    const Q4K_BLOCK_BYTES: usize = 144;

    /// Shared K-quant size check: returns the number of 256-element super-blocks
    /// when `bytes` holds exactly `n_rows * n_cols` elements at `block_bytes` per block.
    fn check_kquant_len(
        label: &str,
        bytes: &[u8],
        n_rows: usize,
        n_cols: usize,
        block_bytes: usize,
    ) -> Result<usize> {
        const QK_K: usize = 256;
        // Checked so that bogus header dimensions produce an error rather than a
        // wrapped (release) or panicking (debug) product.
        let total_elems = n_rows.checked_mul(n_cols).ok_or_else(|| {
            FerrumError::model(format!(
                "load_quant {label}: element count {n_rows} x {n_cols} overflows usize"
            ))
        })?;
        if total_elems % QK_K != 0 {
            return Err(FerrumError::model(format!(
                "load_quant {label}: elements {total_elems} not multiple of {QK_K}"
            )));
        }
        let n_blocks = total_elems / QK_K;
        let expected = n_blocks.checked_mul(block_bytes).ok_or_else(|| {
            FerrumError::model(format!(
                "load_quant {label}: byte size of {n_blocks} blocks overflows usize"
            ))
        })?;
        if bytes.len() != expected {
            return Err(FerrumError::model(format!(
                "load_quant {label}: bytes {} != expected {} ({n_blocks} blocks)",
                bytes.len(),
                expected
            )));
        }
        Ok(n_blocks)
    }

    match kind {
        GgufQuantType::Q4K => {
            let n_blocks = check_kquant_len("Q4K", bytes, n_rows, n_cols, Q4K_BLOCK_BYTES)?;
            let (blocks, byte_offset) = buffer_for_quant_bytes(bytes);
            Ok(MetalQuantStore::Q4K {
                blocks,
                byte_offset,
                n_rows,
                n_cols,
                n_blocks,
            })
        }
        GgufQuantType::Q6K => {
            let n_blocks = check_kquant_len(
                "Q6K",
                bytes,
                n_rows,
                n_cols,
                crate::q6_k_gemv::Q6_K_BLOCK_BYTES,
            )?;
            let (blocks, byte_offset) = buffer_for_quant_bytes(bytes);
            Ok(MetalQuantStore::Q6K {
                blocks,
                byte_offset,
                n_rows,
                n_cols,
                n_blocks,
            })
        }
        other => Err(FerrumError::unsupported(format!(
            "Metal load_quant: {other:?} not yet implemented"
        ))),
    }
}
fn load_quant_experts(
kind: super::GgufQuantType,
bytes: &[u8],
num_experts: usize,
n_rows: usize,
n_cols: usize,
) -> Result<Self::QuantStore> {
use super::GgufQuantType;
match kind {
GgufQuantType::Q4K => load_q4k_experts(bytes, num_experts, n_rows, n_cols),
GgufQuantType::Q6K => load_q6k_experts(bytes, num_experts, n_rows, n_cols),
other => Err(FerrumError::unsupported(format!(
"Metal load_quant_experts: {other:?} not implemented (only Q4K / Q6K)"
))),
}
}
/// Encodes an expert-indexed quantized GEMV for `n_selected` experts onto the
/// context's compute encoder; see `dispatch_gemv_moe_id`.
fn gemv_quant_moe_id(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    weight: &Self::QuantStore,
    ids: &Self::Buffer,
    out: &mut Self::Buffer,
    n_selected: usize,
    src1_stride: usize,
) -> Result<()> {
    let input = a.expect_f32("gemv_quant_moe_id a");
    let expert_ids = &ids.raw;
    let output = out.expect_f32_mut("gemv_quant_moe_id out");
    dispatch_gemv_moe_id(
        ctx.compute_encoder(),
        input,
        weight,
        expert_ids,
        output,
        n_selected,
        src1_stride,
    )
}
/// The Metal backend implements the batched expert-indexed GEMV.
fn supports_batched_moe_gemv() -> bool {
    const SUPPORTED: bool = true;
    SUPPORTED
}
/// The Metal backend implements the batched fused gate/up + SiLU expert kernel.
fn supports_batched_moe_gate_up_silu() -> bool {
    const SUPPORTED: bool = true;
    SUPPORTED
}
fn gemv_quant_moe_id_gate_up_silu_batched(
ctx: &mut Self::Context,
a: &Self::Buffer,
gate_w: &Self::QuantStore,
up_w: &Self::QuantStore,
ids: &Self::Buffer,
silu_out: &mut Self::Buffer,
m: usize,
top_k: usize,
src1_outer_stride: usize,
src1_inner_stride: usize,
) -> Result<()> {
let (gate_blocks, gate_byte_offset, gate_n_rows, gate_n_cols) = match gate_w {
MetalQuantStore::Q4KExperts {
blocks,
byte_offset,
n_rows,
n_cols,
..
} => (blocks, *byte_offset, *n_rows, *n_cols),
_ => {
return Err(FerrumError::model(
"gemv_quant_moe_id_gate_up_silu_batched: gate_w must be Q4KExperts".to_string(),
));
}
};
let (up_blocks, up_byte_offset, up_n_rows, up_n_cols) = match up_w {
MetalQuantStore::Q4KExperts {
blocks,
byte_offset,
n_rows,
n_cols,
..
} => (blocks, *byte_offset, *n_rows, *n_cols),
_ => {
return Err(FerrumError::model(
"gemv_quant_moe_id_gate_up_silu_batched: up_w must be Q4KExperts".to_string(),
));
}
};
if gate_n_rows != up_n_rows || gate_n_cols != up_n_cols {
return Err(FerrumError::model(format!(
"gemv_quant_moe_id_gate_up_silu_batched: gate/up shape mismatch — \
gate=({gate_n_rows}, {gate_n_cols}) up=({up_n_rows}, {up_n_cols})"
)));
}
let a_buf = a.expect_f32("gemv_quant_moe_id_gate_up_silu_batched a");
let ids_buf = &ids.raw;
let out_buf = silu_out.expect_f32_mut("gemv_quant_moe_id_gate_up_silu_batched silu_out");
let enc = ctx.compute_encoder();
crate::q4_k_moe_id_gate_up_silu_batched::dispatch_gemv_q4k_moe_id_gate_up_silu_batched_on_encoder(
&st().pipes.device,
enc,
a_buf,
gate_blocks,
gate_byte_offset,
up_blocks,
up_byte_offset,
ids_buf,
out_buf,
gate_n_rows,
gate_n_cols,
m,
top_k,
src1_outer_stride,
src1_inner_stride,
);
Ok(())
}
/// Encodes the batched expert-indexed quantized GEMV (`m` tokens × `top_k` experts),
/// choosing the Q4_K or Q6_K kernel from the weight's variant.
///
/// # Errors
/// Fails if `weight` is not `Q4KExperts` or `Q6KExperts`. The compute encoder has
/// already been opened at that point.
fn gemv_quant_moe_id_batched(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    weight: &Self::QuantStore,
    ids: &Self::Buffer,
    out: &mut Self::Buffer,
    m: usize,
    top_k: usize,
    src1_outer_stride: usize,
    src1_inner_stride: usize,
) -> Result<()> {
    let input = a.expect_f32("gemv_quant_moe_id_batched a");
    let expert_ids = &ids.raw;
    let output = out.expect_f32_mut("gemv_quant_moe_id_batched out");
    let enc = ctx.compute_encoder();
    let (is_q6k, blocks, byte_offset, n_rows, n_cols) = match weight {
        MetalQuantStore::Q4KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (false, blocks, *byte_offset, *n_rows, *n_cols),
        MetalQuantStore::Q6KExperts {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (true, blocks, *byte_offset, *n_rows, *n_cols),
        _ => {
            return Err(FerrumError::model(
                "gemv_quant_moe_id_batched: weight must be Q4KExperts or Q6KExperts".to_string(),
            ));
        }
    };
    let device = &st().pipes.device;
    if is_q6k {
        crate::q6_k_moe_id_gemv_batched::dispatch_gemv_q6k_moe_id_batched_on_encoder(
            device,
            enc,
            input,
            blocks,
            byte_offset,
            expert_ids,
            output,
            n_rows,
            n_cols,
            m,
            top_k,
            src1_outer_stride,
            src1_inner_stride,
        );
    } else {
        crate::q4_k_moe_id_gemv_batched::dispatch_gemv_q4k_moe_id_batched_on_encoder(
            device,
            enc,
            input,
            blocks,
            byte_offset,
            expert_ids,
            output,
            n_rows,
            n_cols,
            m,
            top_k,
            src1_outer_stride,
            src1_inner_stride,
        );
    }
    Ok(())
}
/// Offset variant of `gemv_quant_moe_id`. `a_offset` counts f32 elements and
/// `ids_offset` counts i32 elements; both are converted to byte offsets here.
fn gemv_quant_moe_id_offset(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    a_offset: usize,
    weight: &Self::QuantStore,
    ids: &Self::Buffer,
    ids_offset: usize,
    out: &mut Self::Buffer,
    n_selected: usize,
    src1_stride: usize,
) -> Result<()> {
    let input = a.expect_f32("gemv_quant_moe_id_offset a");
    let expert_ids = &ids.raw;
    let output = out.expect_f32_mut("gemv_quant_moe_id_offset out");
    let enc = ctx.compute_encoder();
    let to_bytes = |elems: usize, width: usize| (elems * width) as u64;
    dispatch_gemv_moe_id_offset(
        enc,
        input,
        to_bytes(a_offset, std::mem::size_of::<f32>()),
        weight,
        expert_ids,
        to_bytes(ids_offset, std::mem::size_of::<i32>()),
        output,
        n_selected,
        src1_stride,
    )
}
/// Copies `data` into a new shared-storage Metal buffer.
///
/// The buffer is tagged `Dtype::F32` because `Dtype` has no integer variant; the
/// tag only carries the element width, which is the same 4 bytes for i32 and f32.
fn from_slice_i32(data: &[i32]) -> Self::Buffer {
    let n = data.len();
    let byte_len = std::mem::size_of_val(data) as u64;
    let raw = st().pipes.device.new_buffer_with_data(
        data.as_ptr().cast::<c_void>(),
        byte_len,
        MTLResourceOptions::StorageModeShared,
    );
    MetalBuf {
        raw,
        dtype: Dtype::F32,
        n,
    }
}
/// Encodes the expert-grouped quantized GEMM, choosing the Q4_K or Q6_K kernel
/// from the weight's variant.
///
/// # Errors
/// Fails if `weight` is not `Q4KExperts` or `Q6KExperts`. The compute encoder has
/// already been opened at that point.
fn gemm_quant_moe_id(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    weight: &Self::QuantStore,
    ids: &Self::Buffer,
    tpe: &Self::Buffer,
    out: &mut Self::Buffer,
    ne11: usize,
    top_k: usize,
    max_per_expert: usize,
    batch: usize,
) -> Result<()> {
    let input = a.expect_f32("gemm_quant_moe_id a");
    let (ids_buf, tpe_buf) = (&ids.raw, &tpe.raw);
    let output = out.expect_f32_mut("gemm_quant_moe_id out");
    let enc = ctx.compute_encoder();
    let (is_q6k, blocks, byte_offset, num_experts, n_rows, n_cols) = match weight {
        MetalQuantStore::Q4KExperts {
            blocks,
            byte_offset,
            num_experts,
            n_rows,
            n_cols,
        } => (false, blocks, *byte_offset, *num_experts, *n_rows, *n_cols),
        MetalQuantStore::Q6KExperts {
            blocks,
            byte_offset,
            num_experts,
            n_rows,
            n_cols,
        } => (true, blocks, *byte_offset, *num_experts, *n_rows, *n_cols),
        _ => {
            return Err(FerrumError::model(
                "gemm_quant_moe_id: weight must be Q4KExperts or Q6KExperts".to_string(),
            ));
        }
    };
    let device = &st().pipes.device;
    if is_q6k {
        crate::q6_k_moe_id_gemm::dispatch_gemm_q6k_moe_id_on_encoder(
            device,
            enc,
            blocks,
            byte_offset,
            input,
            ids_buf,
            tpe_buf,
            output,
            num_experts,
            n_rows,
            n_cols,
            ne11,
            top_k,
            max_per_expert,
            batch,
        );
    } else {
        crate::q4_k_moe_id_gemm::dispatch_gemm_q4k_moe_id_on_encoder(
            device,
            enc,
            blocks,
            byte_offset,
            input,
            ids_buf,
            tpe_buf,
            output,
            num_experts,
            n_rows,
            n_cols,
            ne11,
            top_k,
            max_per_expert,
            batch,
        );
    }
    Ok(())
}
/// Like `gemm_quant_moe_id`, but also passes `args_buf` to the indirect kernel
/// launcher (presumably for GPU-computed dispatch arguments — confirm in the kernel module).
///
/// # Errors
/// Fails if `weight` is not `Q4KExperts` or `Q6KExperts`. The compute encoder has
/// already been opened at that point.
fn gemm_quant_moe_id_indirect(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    weight: &Self::QuantStore,
    ids: &Self::Buffer,
    tpe: &Self::Buffer,
    out: &mut Self::Buffer,
    args_buf: &Self::Buffer,
    ne11: usize,
    top_k: usize,
    max_per_expert: usize,
    batch: usize,
) -> Result<()> {
    let input = a.expect_f32("gemm_quant_moe_id_indirect a");
    let (ids_raw, tpe_raw) = (&ids.raw, &tpe.raw);
    let output = out.expect_f32_mut("gemm_quant_moe_id_indirect out");
    let args_raw = &args_buf.raw;
    let enc = ctx.compute_encoder();
    let (is_q6k, blocks, byte_offset, num_experts, n_rows, n_cols) = match weight {
        MetalQuantStore::Q4KExperts {
            blocks,
            byte_offset,
            num_experts,
            n_rows,
            n_cols,
        } => (false, blocks, *byte_offset, *num_experts, *n_rows, *n_cols),
        MetalQuantStore::Q6KExperts {
            blocks,
            byte_offset,
            num_experts,
            n_rows,
            n_cols,
        } => (true, blocks, *byte_offset, *num_experts, *n_rows, *n_cols),
        _ => {
            return Err(FerrumError::model(
                "gemm_quant_moe_id_indirect: weight must be Q4KExperts or Q6KExperts".to_string(),
            ));
        }
    };
    let device = &st().pipes.device;
    if is_q6k {
        crate::q6_k_moe_id_gemm::dispatch_gemm_q6k_moe_id_indirect_on_encoder(
            device,
            enc,
            blocks,
            byte_offset,
            input,
            ids_raw,
            tpe_raw,
            output,
            args_raw,
            num_experts,
            n_rows,
            n_cols,
            ne11,
            top_k,
            max_per_expert,
            batch,
        );
    } else {
        crate::q4_k_moe_id_gemm::dispatch_gemm_q4k_moe_id_indirect_on_encoder(
            device,
            enc,
            blocks,
            byte_offset,
            input,
            ids_raw,
            tpe_raw,
            output,
            args_raw,
            num_experts,
            n_rows,
            n_cols,
            ne11,
            top_k,
            max_per_expert,
            batch,
        );
    }
    Ok(())
}
/// Encodes the MoE router kernel. From `logits` it writes the selected expert ids
/// into `out_ids` and their routing weights into `out_weights`.
fn route_topk_softmax(
    ctx: &mut Self::Context,
    logits: &Self::Buffer,
    out_ids: &mut Self::Buffer,
    out_weights: &mut Self::Buffer,
    batch: usize,
    num_experts: usize,
    top_k: usize,
    norm_topk_prob: bool,
) -> Result<()> {
    let scores = logits.expect_f32("route_topk_softmax logits");
    let selected = &out_ids.raw;
    let routing_weights = out_weights.expect_f32_mut("route_topk_softmax out_weights");
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_router::dispatch_route_topk_softmax(
        device,
        enc,
        scores,
        selected,
        routing_weights,
        batch,
        num_experts,
        top_k,
        norm_topk_prob,
    );
    Ok(())
}
/// Encodes the batched SiLU(gate) × up element-wise kernel over `total_pairs`
/// rows of width `ffn`.
fn silu_mul_batched(
    ctx: &mut Self::Context,
    gate: &Self::Buffer,
    up: &Self::Buffer,
    out: &mut Self::Buffer,
    total_pairs: usize,
    ffn: usize,
) -> Result<()> {
    let gate_raw = gate.expect_f32("silu_mul_batched gate");
    let up_raw = up.expect_f32("silu_mul_batched up");
    let out_raw = out.expect_f32_mut("silu_mul_batched out");
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_post_ops_batched::dispatch_silu_mul_batched(
        device,
        enc,
        gate_raw,
        up_raw,
        out_raw,
        total_pairs,
        ffn,
    );
    Ok(())
}
/// Encodes the GPU pass that derives the expert-grouping buffers (`tpe`, `ids`)
/// and the gate/up and down GEMM argument buffers from `selected_ids`.
fn compute_ids_tpe_gpu(
    ctx: &mut Self::Context,
    selected_ids: &Self::Buffer,
    tpe: &mut Self::Buffer,
    ids: &mut Self::Buffer,
    gate_up_args: &mut Self::Buffer,
    down_args: &mut Self::Buffer,
    batch: usize,
    num_experts: usize,
    top_k: usize,
    m_gate_up: usize,
    m_down: usize,
) -> Result<()> {
    let (selected_raw, tpe_raw, ids_raw) = (&selected_ids.raw, &tpe.raw, &ids.raw);
    let (gate_up_raw, down_raw) = (&gate_up_args.raw, &down_args.raw);
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_router::dispatch_compute_ids_tpe(
        device,
        enc,
        selected_raw,
        tpe_raw,
        ids_raw,
        gate_up_raw,
        down_raw,
        batch,
        num_experts,
        top_k,
        m_gate_up,
        m_down,
    );
    Ok(())
}
/// Encodes the kernel that combines `n_slots` stacked expert outputs by `weights`
/// into `residual` (width `hidden`).
fn weighted_sum_residual_stacked(
    ctx: &mut Self::Context,
    slots: &Self::Buffer,
    weights: &Self::Buffer,
    residual: &mut Self::Buffer,
    n_slots: usize,
    hidden: usize,
) -> Result<()> {
    let slot_raw = slots.expect_f32("weighted_sum_residual_stacked slots");
    let weight_raw = weights.expect_f32("weighted_sum_residual_stacked weights");
    let residual_raw = residual.expect_f32_mut("weighted_sum_residual_stacked residual");
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_post_ops::dispatch_weighted_sum_residual_stacked(
        device,
        enc,
        slot_raw,
        weight_raw,
        residual_raw,
        n_slots,
        hidden,
    );
    Ok(())
}
/// Encodes the fused kernel that combines stacked expert outputs into `residual`
/// and writes a normalised copy (using `next_norm_w` and `eps`) into `normed_out`.
fn weighted_sum_residual_norm_stacked(
    ctx: &mut Self::Context,
    slots: &Self::Buffer,
    weights: &Self::Buffer,
    residual: &mut Self::Buffer,
    next_norm_w: &Self::Buffer,
    normed_out: &mut Self::Buffer,
    n_slots: usize,
    hidden: usize,
    eps: f32,
) -> Result<()> {
    let slot_raw = slots.expect_f32("weighted_sum_residual_norm_stacked slots");
    let weight_raw = weights.expect_f32("weighted_sum_residual_norm_stacked weights");
    let residual_raw = residual.expect_f32_mut("weighted_sum_residual_norm_stacked residual");
    let norm_w_raw = next_norm_w.expect_f32("weighted_sum_residual_norm_stacked next_norm_w");
    let normed_raw = normed_out.expect_f32_mut("weighted_sum_residual_norm_stacked normed_out");
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_post_ops::dispatch_weighted_sum_residual_norm_stacked(
        device,
        enc,
        slot_raw,
        weight_raw,
        residual_raw,
        norm_w_raw,
        normed_raw,
        n_slots,
        hidden,
        eps,
    );
    Ok(())
}
/// Encodes the batched weighted combination of `top_k` expert outputs per token
/// (for `batch` tokens, width `hidden`) into `out`.
fn weighted_sum_batched(
    ctx: &mut Self::Context,
    slots: &Self::Buffer,
    weights: &Self::Buffer,
    out: &mut Self::Buffer,
    batch: usize,
    top_k: usize,
    hidden: usize,
) -> Result<()> {
    let slot_raw = slots.expect_f32("weighted_sum_batched slots");
    let weight_raw = weights.expect_f32("weighted_sum_batched weights");
    let out_raw = out.expect_f32_mut("weighted_sum_batched out");
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_post_ops_batched::dispatch_weighted_sum_batched(
        device, enc, slot_raw, weight_raw, out_raw, batch, top_k, hidden,
    );
    Ok(())
}
/// Offset variant of `weighted_sum_batched`. `weights_offset` and `out_offset`
/// count f32 elements and are converted to byte offsets; `slots` is read from byte 0.
fn weighted_sum_batched_offset(
    ctx: &mut Self::Context,
    slots: &Self::Buffer,
    weights: &Self::Buffer,
    weights_offset: usize,
    out: &mut Self::Buffer,
    out_offset: usize,
    batch: usize,
    top_k: usize,
    hidden: usize,
) -> Result<()> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();
    let slot_raw = slots.expect_f32("weighted_sum_batched_offset slots");
    let weight_raw = weights.expect_f32("weighted_sum_batched_offset weights");
    let out_raw = out.expect_f32_mut("weighted_sum_batched_offset out");
    let enc = ctx.compute_encoder();
    let weight_bytes = (weights_offset * F32_BYTES) as u64;
    let out_bytes = (out_offset * F32_BYTES) as u64;
    let device = &st().pipes.device;
    crate::moe_post_ops_batched::dispatch_weighted_sum_batched_offset(
        device,
        enc,
        slot_raw,
        0,
        weight_raw,
        weight_bytes,
        out_raw,
        out_bytes,
        batch,
        top_k,
        hidden,
    );
    Ok(())
}
/// Encodes the SiLU(gate) × up element-wise kernel over `n_slots` stacked rows of
/// width `ffn`.
fn silu_mul_stacked(
    ctx: &mut Self::Context,
    gate: &Self::Buffer,
    up: &Self::Buffer,
    out: &mut Self::Buffer,
    n_slots: usize,
    ffn: usize,
) -> Result<()> {
    let gate_raw = gate.expect_f32("silu_mul_stacked gate");
    let up_raw = up.expect_f32("silu_mul_stacked up");
    let out_raw = out.expect_f32_mut("silu_mul_stacked out");
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_post_ops::dispatch_silu_mul_stacked(
        device, enc, gate_raw, up_raw, out_raw, n_slots, ffn,
    );
    Ok(())
}
/// The Metal backend implements the fused gate/up + SiLU expert kernel.
fn supports_fused_moe_gate_up_silu() -> bool {
    const SUPPORTED: bool = true;
    SUPPORTED
}
fn gemv_quant_moe_id_gate_up_silu(
ctx: &mut Self::Context,
a: &Self::Buffer,
gate_w: &Self::QuantStore,
up_w: &Self::QuantStore,
ids: &Self::Buffer,
silu_out: &mut Self::Buffer,
n_selected: usize,
) -> Result<()> {
let (gate_blocks, gate_byte_offset, gate_n_rows, gate_n_cols) = match gate_w {
MetalQuantStore::Q4KExperts {
blocks,
byte_offset,
n_rows,
n_cols,
..
} => (blocks, *byte_offset, *n_rows, *n_cols),
_ => {
return Err(FerrumError::model(
"gemv_quant_moe_id_gate_up_silu: gate_w must be Q4KExperts".to_string(),
));
}
};
let (up_blocks, up_byte_offset, up_n_rows, up_n_cols) = match up_w {
MetalQuantStore::Q4KExperts {
blocks,
byte_offset,
n_rows,
n_cols,
..
} => (blocks, *byte_offset, *n_rows, *n_cols),
_ => {
return Err(FerrumError::model(
"gemv_quant_moe_id_gate_up_silu: up_w must be Q4KExperts".to_string(),
));
}
};
if gate_n_rows != up_n_rows || gate_n_cols != up_n_cols {
return Err(FerrumError::model(format!(
"gemv_quant_moe_id_gate_up_silu: gate/up shape mismatch — \
gate=({gate_n_rows}, {gate_n_cols}) up=({up_n_rows}, {up_n_cols})"
)));
}
let a_buf = a.expect_f32("gemv_quant_moe_id_gate_up_silu a");
let ids_buf = &ids.raw;
let out_buf = silu_out.expect_f32_mut("gemv_quant_moe_id_gate_up_silu silu_out");
let enc = ctx.compute_encoder();
crate::q4_k_moe_id_gate_up_silu::dispatch_gemv_q4k_moe_id_gate_up_silu_on_encoder(
&st().pipes.device,
enc,
a_buf,
gate_blocks,
gate_byte_offset,
up_blocks,
up_byte_offset,
ids_buf,
out_buf,
gate_n_rows,
gate_n_cols,
n_selected,
);
Ok(())
}
/// Encodes the weighted combination of `n_slots` stacked expert outputs
/// (width `hidden`) into `out`.
fn weighted_sum_stacked(
    ctx: &mut Self::Context,
    slots: &Self::Buffer,
    weights: &Self::Buffer,
    out: &mut Self::Buffer,
    n_slots: usize,
    hidden: usize,
) -> Result<()> {
    let slot_raw = slots.expect_f32("weighted_sum_stacked slots");
    let weight_raw = weights.expect_f32("weighted_sum_stacked weights");
    let out_raw = out.expect_f32_mut("weighted_sum_stacked out");
    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    crate::moe_post_ops::dispatch_weighted_sum_stacked(
        device, enc, slot_raw, weight_raw, out_raw, n_slots, hidden,
    );
    Ok(())
}
/// Copies `data` as `i32`s into the start of `buf`'s CPU-visible contents.
///
/// The number of elements written is `min(data.len(), buf.len())`, additionally
/// capped by the buffer's physical byte length. A buffer tagged with a narrower
/// dtype (F16: 2 bytes per logical element) therefore can never be written past
/// its allocation. Surplus input elements are ignored.
fn write_i32_into(buf: &mut Self::Buffer, data: &[i32]) {
    let capacity = buf.raw.length() as usize / std::mem::size_of::<i32>();
    let n = data.len().min(buf.n).min(capacity);
    if n == 0 {
        return;
    }
    let dst = buf.raw.contents() as *mut i32;
    // SAFETY: `n * 4` bytes fit inside the buffer's allocation (clamped by `length()`
    // above), and `data` provides at least `n` initialised elements. Assumes the Metal
    // contents pointer is non-null and at least 4-byte aligned (CPU-accessible storage),
    // and that `data` does not alias the buffer's own contents.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr(), dst, n);
    }
}
/// Copies `data` as `f32`s into the start of `buf`'s CPU-visible contents.
///
/// The number of elements written is `min(data.len(), buf.len())`, additionally
/// capped by the buffer's physical byte length. An F16-tagged buffer (2 bytes per
/// logical element) therefore can never be written past its allocation. Surplus
/// input elements are ignored.
fn write_f32_into(buf: &mut Self::Buffer, data: &[f32]) {
    let capacity = buf.raw.length() as usize / std::mem::size_of::<f32>();
    let n = data.len().min(buf.n).min(capacity);
    if n == 0 {
        return;
    }
    let dst = buf.raw.contents() as *mut f32;
    // SAFETY: `n * 4` bytes fit inside the buffer's allocation (clamped by `length()`
    // above), and `data` provides at least `n` initialised elements. Assumes the Metal
    // contents pointer is non-null and at least 4-byte aligned (CPU-accessible storage),
    // and that `data` does not alias the buffer's own contents.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr(), dst, n);
    }
}
/// Loads each `(kind, bytes, n_rows)` part with `load_quant` and stacks the parts
/// row-wise into one `Fused` store that shares `n_cols`.
///
/// # Errors
/// Propagates the first failing part's load error, and rejects any part that
/// loads as a nested `Fused` store.
fn load_quant_fused(
    parts: &[(super::GgufQuantType, &[u8], usize)],
    n_cols: usize,
) -> Result<Self::QuantStore> {
    let mut total_rows = 0;
    let sub_stores = parts
        .iter()
        .map(|&(kind, bytes, n_rows)| {
            let store = Self::load_quant(kind, bytes, n_rows, n_cols)?;
            if let MetalQuantStore::Fused { .. } = store {
                return Err(FerrumError::model(
                    "Metal load_quant_fused: nested Fused not supported".to_string(),
                ));
            }
            total_rows += n_rows;
            Ok(store)
        })
        .collect::<Result<Vec<MetalQuantStore>>>()?;
    Ok(MetalQuantStore::Fused {
        parts: sub_stores,
        total_rows,
        n_cols,
    })
}
fn gemm_quant(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    weight: &Self::QuantStore,
    out: &mut Self::Buffer,
    m: usize,
) -> Result<()> {
    // Quantized matmul of the f32 activations `a` (`m` rows) against a K-quant
    // weight store, writing f32 results into `out`.
    //
    // Dispatch:
    // - `Fused` stores run each sub-store in turn. With `m != 1` this uses
    //   `dispatch_part_gemm`, advancing a column offset by each part's row
    //   count (`total_rows` is forwarded, presumably as the output row stride).
    //   With `m == 1` it uses `dispatch_part_gemv_offset`, writing each part at
    //   byte offset (rows written so far) * 4 of `out`.
    // - `Q4K` / `Q6K` with `m == 1` use GEMV kernels. Q4K uses v2 when `n_rows`
    //   is a multiple of 4, else v1. Q6K only has v2 and requires
    //   `n_rows % 4 == 0`. Any other `m` uses the GEMM kernels.
    // - Stacked-expert stores are rejected; they go through `gemv_moe_id`.
    //
    // When `debug_per_call_flush()` is on, non-fused calls are flushed, timed
    // into the `QUANT_GEMM_*` counters, and the first 16 are logged.
    //
    // Errors: expert-stacked weights; Q6K with m == 1 and n_rows % 4 != 0;
    // any error from the per-part helpers of a `Fused` store.
    if let MetalQuantStore::Fused {
        parts,
        total_rows,
        n_cols,
    } = weight
    {
        if m != 1 {
            let a_buf = a.expect_f32("gemm_quant a (fused)");
            let out_buf = out.expect_f32_mut("gemm_quant out (fused)");
            let enc = ctx.compute_encoder();
            let mut col_off = 0usize;
            for part in parts {
                let part_rows = part.n_rows();
                dispatch_part_gemm(
                    enc,
                    a_buf,
                    part,
                    out_buf,
                    col_off,
                    m,
                    part_rows,
                    *total_rows,
                    *n_cols,
                )?;
                col_off += part_rows;
            }
            return Ok(());
        }
        let a_buf = a.expect_f32("gemm_quant a (fused m=1)");
        let out_buf = out.expect_f32_mut("gemm_quant out (fused m=1)");
        let enc = ctx.compute_encoder();
        let mut row_off_elems = 0usize;
        for part in parts {
            // With one input row the output is a single contiguous f32 row, so
            // each part starts right after the rows produced by earlier parts.
            let c_off = (row_off_elems * std::mem::size_of::<f32>()) as u64;
            dispatch_part_gemv_offset(enc, a_buf, 0, part, out_buf, c_off, *n_cols)?;
            row_off_elems += part.n_rows();
        }
        return Ok(());
    }

    let (blocks, blocks_off, n_rows, n_cols, is_q6k) = match weight {
        MetalQuantStore::Q4K {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (blocks, *byte_offset, *n_rows, *n_cols, false),
        MetalQuantStore::Q6K {
            blocks,
            byte_offset,
            n_rows,
            n_cols,
            ..
        } => (blocks, *byte_offset, *n_rows, *n_cols, true),
        MetalQuantStore::Fused { .. } => unreachable!("handled above"),
        MetalQuantStore::Q4KExperts { .. } | MetalQuantStore::Q6KExperts { .. } => {
            return Err(FerrumError::model(
                "gemm_quant: ExpertsStacked must be dispatched via gemv_moe_id".to_string(),
            ));
        }
    };

    let t0 = debug_per_call_flush().then(std::time::Instant::now);
    let a_buf = a.expect_f32("gemm_quant a");
    let out_buf = out.expect_f32_mut("gemm_quant out");
    let gemv = m == 1;
    // Validate before touching the encoder so a rejected call does not open one.
    if is_q6k && gemv && !n_rows.is_multiple_of(4) {
        return Err(FerrumError::model(format!(
            "gemm_quant Q6K: n_rows={n_rows} not divisible by 4"
        )));
    }

    let enc = ctx.compute_encoder();
    let device = &st().pipes.device;
    match (is_q6k, gemv) {
        (true, true) => {
            crate::q6_k_gemv::dispatch_gemv_q6k_v2_on_encoder(
                device, enc, a_buf, blocks, blocks_off, out_buf, n_rows, n_cols,
            );
        }
        (false, true) if n_rows.is_multiple_of(4) => {
            crate::q4_k_gemv_v2::dispatch_gemv_q4k_v2_on_encoder(
                device, enc, a_buf, blocks, blocks_off, out_buf, n_rows, n_cols,
            );
        }
        (false, true) => {
            // The v2 Q4K GEMV kernel needs n_rows % 4 == 0; fall back to v1.
            crate::q4_k_gemv::dispatch_gemv_q4k_on_encoder(
                device, enc, a_buf, blocks, blocks_off, out_buf, n_rows, n_cols,
            );
        }
        (true, false) => {
            crate::q6_k_gemm::dispatch_gemm_q6k_on_encoder(
                device, enc, a_buf, blocks, blocks_off, out_buf, m, n_rows, n_cols,
            );
        }
        (false, false) => {
            crate::q4_k_gemm::dispatch_gemm_q4k_on_encoder(
                device, enc, a_buf, blocks, blocks_off, out_buf, m, n_rows, n_cols,
            );
        }
    }

    if let Some(t0) = t0 {
        use std::sync::atomic::Ordering::Relaxed;
        ctx.flush();
        // `as_micros` is u128; saturate instead of silently truncating.
        let elapsed_us = u64::try_from(t0.elapsed().as_micros()).unwrap_or(u64::MAX);
        QUANT_GEMM_TIME_US.fetch_add(elapsed_us, Relaxed);
        // Use the value fetch_add returns (not a separate load) so concurrent
        // callers each get a distinct call number for the log cut-off.
        let calls = QUANT_GEMM_CALLS.fetch_add(1, Relaxed) + 1;
        QUANT_GEMM_LAST_M.store(m as u64, Relaxed);
        QUANT_GEMM_LAST_N.store(n_rows as u64, Relaxed);
        QUANT_GEMM_LAST_K.store(n_cols as u64, Relaxed);
        if calls <= 16 {
            eprintln!(
                "[gemm_quant] m={} n={} k={} took {} us",
                m, n_rows, n_cols, elapsed_us
            );
        }
    }
    Ok(())
}
fn gemm(
    ctx: &mut Self::Context,
    a: &Self::Buffer,
    b: &Self::Buffer,
    out: &mut Self::Buffer,
    m: usize,
    n: usize,
    k: usize,
) {
    // Dense matmul of f32 `a` (m x k) with weight `b`, into f32 `out`.
    // `b` may be stored as f16 or f32, which selects the `_f16w` pipeline
    // variants. A single input row (m == 1) uses the GEMV pipelines, otherwise
    // the GEMM v2 pipelines.
    let lhs = a.expect_f32("gemm a");
    let dst = out.expect_f32_mut("gemm out");
    let encoder = ctx.compute_encoder();
    let single_row = m == 1;
    match (&b.dtype, single_row) {
        (Dtype::F16, true) => {
            st().pipes.gemv_enc_f16w(encoder, lhs, &b.raw, dst, n, k);
        }
        (Dtype::F16, false) => {
            st().pipes.gemm_v2_f16w(encoder, lhs, &b.raw, dst, m, n, k);
        }
        (Dtype::F32, true) => {
            st().pipes.gemv_enc(encoder, lhs, &b.raw, dst, n, k);
        }
        (Dtype::F32, false) => {
            st().pipes.gemm_v2(encoder, lhs, &b.raw, dst, m, n, k);
        }
    }
}
fn rms_norm(
    ctx: &mut Self::Context,
    x: &Self::Buffer,
    w: &Self::Buffer,
    eps: f32,
    out: &mut Self::Buffer,
    tokens: usize,
    dim: usize,
) {
    // RMS-normalise `tokens` rows of width `dim` from `x` with weight `w` into
    // `out`, via the `rms_norm_enc` pipeline on the current compute encoder.
    let input_buf = x.expect_f32("rms_norm x");
    let gamma_buf = w.expect_f32("rms_norm w");
    let output_buf = out.expect_f32_mut("rms_norm out");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.rms_norm_enc(encoder, input_buf, gamma_buf, output_buf, tokens, dim, eps);
}
fn fused_add_rms_norm(
    ctx: &mut Self::Context,
    residual: &mut Self::Buffer,
    x: &Self::Buffer,
    w: &Self::Buffer,
    eps: f32,
    out: &mut Self::Buffer,
    tokens: usize,
    dim: usize,
) {
    // Single `fused_residual_norm_enc` dispatch combining a residual add with
    // an RMS norm. `residual` is passed both as the first input and as the
    // residual-output argument, so it is presumably updated in place
    // (residual += x); the normalised result goes to `out`. The optional
    // operand is `None` and the trailing argument is 0.
    let res_buf = residual.expect_f32_mut("fused_add_rms_norm residual");
    let delta_buf = x.expect_f32("fused_add_rms_norm x");
    let gamma_buf = w.expect_f32("fused_add_rms_norm w");
    let norm_buf = out.expect_f32_mut("fused_add_rms_norm out");
    let encoder = ctx.compute_encoder();
    st().pipes.fused_residual_norm_enc(
        encoder, res_buf, delta_buf, None, gamma_buf, res_buf, norm_buf, tokens, dim, eps, 0,
    );
}
fn flash_attention(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k: &Self::Buffer,
    v: &Self::Buffer,
    out: &mut Self::Buffer,
    batch: usize,
    q_len: usize,
    kv_len: usize,
    pos_offset: usize,
    cfg: &AttnConfig,
) {
    // Run the `flash_attn_v2` pipeline over f32 q/k/v into `out`. That pipeline
    // takes the command buffer rather than an encoder, so any open compute
    // encoder is ended first. `cfg.kv_seq_stride` is forwarded alongside the
    // shape parameters.
    let q_buf = q.expect_f32("flash_attention q");
    let k_buf = k.expect_f32("flash_attention k");
    let v_buf = v.expect_f32("flash_attention v");
    let out_buf = out.expect_f32_mut("flash_attention out");
    let params = AttentionParams {
        batch,
        q_len,
        kv_len,
        pos_offset,
        num_heads: cfg.num_heads,
        num_kv_heads: cfg.num_kv_heads,
        head_dim: cfg.head_dim,
        causal: cfg.causal,
        sliding_window: cfg.sliding_window,
    };
    ctx.compute_encoder_end();
    let command_buffer = ctx.cmd();
    st().pipes.flash_attn_v2(
        command_buffer,
        q_buf,
        k_buf,
        v_buf,
        out_buf,
        &params,
        cfg.kv_seq_stride,
    );
}
fn copy_slice(
    ctx: &mut Self::Context,
    src: &Self::Buffer,
    src_offset: usize,
    dst: &mut Self::Buffer,
    dst_offset: usize,
    len: usize,
) {
    // GPU-side copy of `len` f32 elements from `src[src_offset..]` to
    // `dst[dst_offset..]` using a blit encoder. Offsets and length are element
    // counts, converted to bytes here. The compute encoder is closed first
    // because the blit is encoded directly on the command buffer.
    let to_bytes = |elems: usize| (elems * std::mem::size_of::<f32>()) as u64;
    let from_buf = src.expect_f32("copy_slice src");
    let to_buf = dst.expect_f32_mut("copy_slice dst");
    ctx.compute_encoder_end();
    let blit = ctx.cmd().new_blit_command_encoder();
    blit.copy_from_buffer(
        from_buf,
        to_bytes(src_offset),
        to_buf,
        to_bytes(dst_offset),
        to_bytes(len),
    );
    blit.end_encoding();
}
fn embedding_lookup(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &[u32],
out: &mut Self::Buffer,
dim: usize,
) {
let out = out.expect_f32_mut("embedding_lookup out");
ctx.flush();
unsafe {
let o = std::slice::from_raw_parts_mut(out.contents() as *mut f32, ids.len() * dim);
match table.dtype {
Dtype::F32 => {
let t = std::slice::from_raw_parts(
table.raw.contents() as *const f32,
table.raw.length() as usize / 4,
);
for (i, &id) in ids.iter().enumerate() {
let s = id as usize * dim;
o[i * dim..(i + 1) * dim].copy_from_slice(&t[s..s + dim]);
}
}
Dtype::F16 => {
let t = std::slice::from_raw_parts(
table.raw.contents() as *const f16,
table.raw.length() as usize / 2,
);
for (i, &id) in ids.iter().enumerate() {
let s = id as usize * dim;
for j in 0..dim {
o[i * dim + j] = t[s + j].to_f32();
}
}
}
}
}
}
fn split_qkv(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    q: &mut Self::Buffer,
    k: &mut Self::Buffer,
    v: &mut Self::Buffer,
    tokens: usize,
    q_dim: usize,
    kv_dim: usize,
) {
    // Forward a packed QKV activation and the three destination buffers to the
    // `split_qkv_enc` pipeline (per its name, this splits `qkv` into q/k/v
    // using widths `q_dim` and `kv_dim` for `tokens` rows).
    let packed = qkv.expect_f32("split_qkv qkv");
    let q_dst = q.expect_f32_mut("split_qkv q");
    let k_dst = k.expect_f32_mut("split_qkv k");
    let v_dst = v.expect_f32_mut("split_qkv v");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.split_qkv_enc(encoder, packed, q_dst, k_dst, v_dst, tokens, q_dim, kv_dim);
}
fn fused_silu_mul_split(
    ctx: &mut Self::Context,
    gu: &Self::Buffer,
    out: &mut Self::Buffer,
    tokens: usize,
    im: usize,
) {
    // Dispatch `silu_mul_split_enc` on the packed gate/up activation `gu`,
    // writing `tokens` rows of width `im` (presumably the intermediate size)
    // into `out`.
    let gate_up_buf = gu.expect_f32("fused_silu_mul_split gate_up");
    let act_buf = out.expect_f32_mut("fused_silu_mul_split out");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.silu_mul_split_enc(encoder, gate_up_buf, act_buf, tokens, im);
}
fn qk_norm_rope(
    ctx: &mut Self::Context,
    input: &Self::Buffer,
    norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    output: &mut Self::Buffer,
    tokens: usize,
    heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    mode: i32,
) {
    // Forward to the `qk_norm_rope` pipeline: a norm with weight `norm_w`
    // followed by rotary embedding from the `cos`/`sin` tables, per the pipeline
    // name. `mode` is passed through unchanged; its meaning is defined by the
    // kernel.
    let src = input.expect_f32("qk_norm_rope input");
    let gamma = norm_w.expect_f32("qk_norm_rope norm_w");
    let cos_tab = cos.expect_f32("qk_norm_rope cos");
    let sin_tab = sin.expect_f32("qk_norm_rope sin");
    let dst = output.expect_f32_mut("qk_norm_rope output");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.qk_norm_rope(
        encoder, src, gamma, cos_tab, sin_tab, dst, tokens, heads, head_dim, pos_offset, eps,
        mode,
    );
}
fn split_qkv_norm_rope(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    k_out: &mut Self::Buffer,
    v_out: &mut Self::Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
) -> Result<()> {
    // Single fused dispatch (`split_qkv_norm_rope` pipeline): split the packed
    // `qkv`, normalise q and k with their own weights, apply rotary embedding
    // from `cos`/`sin`, and write q/k/v to separate buffers. Always returns Ok.
    let packed = qkv.expect_f32("split_qkv_norm_rope qkv");
    let q_gamma = q_norm_w.expect_f32("split_qkv_norm_rope q_norm_w");
    let k_gamma = k_norm_w.expect_f32("split_qkv_norm_rope k_norm_w");
    let cos_tab = cos.expect_f32("split_qkv_norm_rope cos");
    let sin_tab = sin.expect_f32("split_qkv_norm_rope sin");
    let q_dst = q_out.expect_f32_mut("split_qkv_norm_rope q_out");
    let k_dst = k_out.expect_f32_mut("split_qkv_norm_rope k_out");
    let v_dst = v_out.expect_f32_mut("split_qkv_norm_rope v_out");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.split_qkv_norm_rope(
        encoder, packed, q_gamma, k_gamma, cos_tab, sin_tab, q_dst, k_dst, v_dst, tokens,
        q_heads, kv_heads, head_dim, pos_offset, eps, qk_mode,
    );
    Ok(())
}
fn split_qkv_norm_rope_into_cache(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
    cache_len: usize,
    cache_capacity: usize,
) -> Result<()> {
    // Same fused split/norm/RoPE as `split_qkv_norm_rope`, but k and v are
    // written straight into the contiguous KV cache buffers. `cache_len` and
    // `cache_capacity` are forwarded so the kernel can locate the write
    // position. Always returns Ok.
    let packed = qkv.expect_f32("split_qkv_norm_rope_kvc qkv");
    let q_gamma = q_norm_w.expect_f32("split_qkv_norm_rope_kvc q_norm_w");
    let k_gamma = k_norm_w.expect_f32("split_qkv_norm_rope_kvc k_norm_w");
    let cos_tab = cos.expect_f32("split_qkv_norm_rope_kvc cos");
    let sin_tab = sin.expect_f32("split_qkv_norm_rope_kvc sin");
    let q_dst = q_out.expect_f32_mut("split_qkv_norm_rope_kvc q_out");
    let k_cache = cache_k.expect_f32_mut("split_qkv_norm_rope_kvc cache_k");
    let v_cache = cache_v.expect_f32_mut("split_qkv_norm_rope_kvc cache_v");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.split_qkv_norm_rope_into_cache(
        encoder,
        packed,
        q_gamma,
        k_gamma,
        cos_tab,
        sin_tab,
        q_dst,
        k_cache,
        v_cache,
        tokens,
        q_heads,
        kv_heads,
        head_dim,
        pos_offset,
        eps,
        qk_mode,
        cache_len,
        cache_capacity,
    );
    Ok(())
}
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache(
    ctx: &mut Self::Context,
    qkv: &Self::Buffer,
    qkv_byte_offset: u64,
    q_norm_w: &Self::Buffer,
    k_norm_w: &Self::Buffer,
    cos: &Self::Buffer,
    sin: &Self::Buffer,
    q_out: &mut Self::Buffer,
    q_out_byte_offset: u64,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    block_table: &Self::Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
    cache_len: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
) -> Result<()> {
    // Paged-KV variant of the fused split/norm/RoPE. `qkv` and `q_out` are
    // addressed at explicit byte offsets, and `block_table`'s raw buffer is
    // passed through without a dtype check (it is not an f32 tensor). Always
    // returns Ok.
    let packed = qkv.expect_f32("split_qkv_norm_rope_paged qkv");
    let q_gamma = q_norm_w.expect_f32("split_qkv_norm_rope_paged q_norm_w");
    let k_gamma = k_norm_w.expect_f32("split_qkv_norm_rope_paged k_norm_w");
    let cos_tab = cos.expect_f32("split_qkv_norm_rope_paged cos");
    let sin_tab = sin.expect_f32("split_qkv_norm_rope_paged sin");
    let q_dst = q_out.expect_f32_mut("split_qkv_norm_rope_paged q_out");
    let k_pool = cache_k.expect_f32_mut("split_qkv_norm_rope_paged cache_k");
    let v_pool = cache_v.expect_f32_mut("split_qkv_norm_rope_paged cache_v");
    let block_map = &block_table.raw;
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.split_qkv_norm_rope_into_paged_cache(
        encoder,
        packed,
        qkv_byte_offset,
        q_gamma,
        k_gamma,
        cos_tab,
        sin_tab,
        q_dst,
        q_out_byte_offset,
        k_pool,
        v_pool,
        block_map,
        tokens,
        q_heads,
        kv_heads,
        head_dim,
        pos_offset,
        eps,
        qk_mode,
        cache_len,
        block_size,
        max_num_blocks_per_seq,
    );
    Ok(())
}
#[allow(clippy::too_many_arguments)]
fn paged_decode_attention(
    ctx: &mut Self::Context,
    q: &Self::Buffer,
    k_pool: &Self::Buffer,
    v_pool: &Self::Buffer,
    out: &mut Self::Buffer,
    block_tables: &Self::Buffer,
    context_lens: &Self::Buffer,
    num_seqs: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
    q_len: usize,
) -> Result<()> {
    // Paged attention over the K/V pools, encoded on the current compute
    // encoder. `block_tables` and `context_lens` are passed as raw buffers.
    // The Q layout given to the kernel depends only on `q_len`: TokenMajor for
    // a single query position, HeadMajor otherwise. Always returns Ok.
    use ferrum_attention::metal::pipelines::{PagedAttnDispatchParams, PagedAttnQLayout};

    let q_buf = q.expect_f32("paged_decode_attention q");
    let k_buf = k_pool.expect_f32("paged_decode_attention k_pool");
    let v_buf = v_pool.expect_f32("paged_decode_attention v_pool");
    let out_buf = out.expect_f32_mut("paged_decode_attention out");
    let tables = &block_tables.raw;
    let lens = &context_lens.raw;
    let params = PagedAttnDispatchParams {
        num_seqs,
        num_heads,
        num_kv_heads,
        head_dim,
        block_size,
        max_num_blocks_per_seq,
        q_len,
        q_layout: match q_len {
            1 => PagedAttnQLayout::TokenMajor,
            _ => PagedAttnQLayout::HeadMajor,
        },
    };
    let encoder = ctx.compute_encoder();
    st().pipes.paged_decode_attention_on_encoder(
        encoder, q_buf, k_buf, v_buf, out_buf, tables, lens, &params,
    );
    Ok(())
}
fn alloc_u32(n: usize) -> Self::Buffer {
    // Allocate a shared-storage Metal buffer for `n` u32 values. `Dtype` has
    // only F16/F32 variants here, so the buffer is tagged F32, which has the
    // same 4-byte element width as u32.
    let byte_len = (n * std::mem::size_of::<u32>()) as u64;
    let device = &st().pipes.device;
    MetalBuf {
        raw: device.new_buffer(byte_len, MTLResourceOptions::StorageModeShared),
        dtype: Dtype::F32,
        n,
    }
}
fn write_u32(_ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[u32]) {
    // Host-side copy of `data` into the start of `dst` through its
    // `contents()` pointer. The context is not used, so no GPU synchronisation
    // happens here.
    //
    // A too-long `data` is a caller bug (debug builds assert). Previously only
    // that debug_assert guarded the copy, so a release build wrote past the end
    // of the Metal buffer. The count is now also clamped, like
    // `write_i32_into` / `write_f32_into`, to both `dst.n` and the allocation's
    // real u32 capacity.
    debug_assert!(data.len() <= dst.n, "write_u32: src too long");
    let capacity = dst
        .n
        .min(dst.raw.length() as usize / std::mem::size_of::<u32>());
    let count = data.len().min(capacity);
    if count == 0 {
        return;
    }
    let ptr = dst.raw.contents() as *mut u32;
    // SAFETY: `ptr` is the start of the buffer's contents, which spans
    // `raw.length()` bytes >= `count * size_of::<u32>()` per the clamp above;
    // `data` is a separate host slice, so the regions do not overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, count);
    }
}
fn kv_cache_append_head_major(
    ctx: &mut Self::Context,
    cache_k: &mut Self::Buffer,
    cache_v: &mut Self::Buffer,
    cache_len: usize,
    cache_capacity: usize,
    new_k_head_major: &Self::Buffer,
    new_v_head_major: &Self::Buffer,
    new_tokens: usize,
    nkv: usize,
    hd: usize,
) {
    // Append `new_tokens` head-major K and V rows to their caches after the
    // existing `cache_len` entries. Two `kv_cache_append` dispatches are encoded
    // on the same compute encoder, K first, then V. Overflowing
    // `cache_capacity` is only checked in debug builds.
    debug_assert!(cache_len + new_tokens <= cache_capacity);
    let k_dst = cache_k.expect_f32_mut("kv_cache_append cache_k");
    let v_dst = cache_v.expect_f32_mut("kv_cache_append cache_v");
    let k_src = new_k_head_major.expect_f32("kv_cache_append new_k");
    let v_src = new_v_head_major.expect_f32("kv_cache_append new_v");
    let encoder = ctx.compute_encoder();
    for (src, dst) in [(k_src, k_dst), (v_src, v_dst)] {
        st().pipes.kv_cache_append(
            encoder,
            src,
            dst,
            nkv,
            hd,
            cache_len,
            new_tokens,
            cache_capacity,
        );
    }
}
fn transpose_head_to_token(
    ctx: &mut Self::Context,
    src: &Self::Buffer,
    dst: &mut Self::Buffer,
    tokens: usize,
    heads: usize,
    dim: usize,
) {
    // Convert head-major activations to token-major order (per the method name)
    // using the `transpose_out` pipeline.
    let head_major = src.expect_f32("transpose_head_to_token src");
    let token_major = dst.expect_f32_mut("transpose_head_to_token dst");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.transpose_out(encoder, head_major, token_major, tokens, heads, dim);
}
fn add_bias(
    ctx: &mut Self::Context,
    data: &mut Self::Buffer,
    bias: &Self::Buffer,
    rows: usize,
    cols: usize,
) {
    // Add `bias` to the `rows x cols` f32 matrix `data` in place, via the
    // `add_bias_enc` pipeline.
    let matrix = data.expect_f32_mut("add_bias data");
    let bias_row = bias.expect_f32("add_bias bias");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.add_bias_enc(encoder, matrix, bias_row, rows, cols);
}
fn layer_norm(
    ctx: &mut Self::Context,
    x: &Self::Buffer,
    gamma: &Self::Buffer,
    beta: &Self::Buffer,
    eps: f32,
    out: &mut Self::Buffer,
    tokens: usize,
    dim: usize,
) {
    // Layer-normalise `tokens` rows of width `dim` with scale `gamma` and
    // shift `beta` into `out`, via the `layer_norm_enc` pipeline.
    let input_buf = x.expect_f32("layer_norm x");
    let scale_buf = gamma.expect_f32("layer_norm gamma");
    let shift_buf = beta.expect_f32("layer_norm beta");
    let output_buf = out.expect_f32_mut("layer_norm out");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.layer_norm_enc(encoder, input_buf, scale_buf, shift_buf, output_buf, tokens, dim, eps);
}
fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
    // Element-wise GELU over `len` f32 values of `x`, written to `out`.
    let input_buf = x.expect_f32("gelu x");
    let output_buf = out.expect_f32_mut("gelu out");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.gelu_enc(encoder, input_buf, output_buf, len);
}
fn add_inplace(ctx: &mut Self::Context, r: &mut Self::Buffer, x: &Self::Buffer, len: usize) {
    // `r` is passed as both the first operand and the output of `add_enc`, so
    // this presumably computes `r = r + x` in place over `len` elements.
    let acc = r.expect_f32_mut("add_inplace r");
    let addend = x.expect_f32("add_inplace x");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.add_enc(encoder, acc, addend, acc, len);
}
fn scaled_add_inplace(
    ctx: &mut Self::Context,
    dst: &mut Self::Buffer,
    src: &Self::Buffer,
    scale: f32,
    len: usize,
) {
    // Accumulate `scale`-weighted `src` into `dst` over `len` elements (per the
    // pipeline name), via `scaled_add_inplace_enc`.
    let acc = dst.expect_f32_mut("scaled_add_inplace dst");
    let addend = src.expect_f32("scaled_add_inplace src");
    let encoder = ctx.compute_encoder();
    let pipes = &st().pipes;
    pipes.scaled_add_inplace_enc(encoder, acc, addend, scale, len);
}
fn alloc(len: usize) -> Self::Buffer {
    // Fresh f32 buffer of `len` elements, allocated by `alloc_f32_raw`.
    let storage = alloc_f32_raw(len);
    MetalBuf {
        n: len,
        dtype: Dtype::F32,
        raw: storage,
    }
}
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
    // Copy the first `len` elements of `buf` to a host Vec<f32>, widening F16
    // storage. This reads `contents()` directly and flushes nothing, so GPU
    // writes to `buf` must already be complete (caller's responsibility —
    // TODO confirm all callers flush first).
    //
    // Panics if `len` exceeds the number of elements the Metal allocation holds.
    // Previously the slice was built unchecked, reading past the buffer end.
    let byte_len = buf.raw.length() as usize;
    match buf.dtype {
        Dtype::F32 => {
            let available = byte_len / std::mem::size_of::<f32>();
            assert!(
                len <= available,
                "to_vec: requested {len} f32 elements but the buffer holds {available}"
            );
            let ptr = buf.raw.contents() as *const f32;
            // SAFETY: the buffer's contents span `byte_len` bytes, which the
            // assert above shows covers `len` f32 values.
            unsafe { std::slice::from_raw_parts(ptr, len) }.to_vec()
        }
        Dtype::F16 => {
            let available = byte_len / std::mem::size_of::<f16>();
            assert!(
                len <= available,
                "to_vec: requested {len} f16 elements but the buffer holds {available}"
            );
            let ptr = buf.raw.contents() as *const f16;
            // SAFETY: as above, with 2-byte f16 elements.
            let halves = unsafe { std::slice::from_raw_parts(ptr, len) };
            halves.iter().map(|h| h.to_f32()).collect()
        }
    }
}
fn from_slice(data: &[f32]) -> Self::Buffer {
    // Upload a host f32 slice into a new f32 buffer of the same length.
    let element_count = data.len();
    let storage = buffer_from_f32_slice(data);
    MetalBuf {
        raw: storage,
        dtype: Dtype::F32,
        n: element_count,
    }
}
fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
    // Build a device buffer from raw little-endian weight bytes.
    //
    // The element count is `raw.len() / src_dtype.bytes_per_elem()`. If f16
    // weights are disabled (`prefer_f16_weights()` is false) or the tensor has
    // fewer than `F16_MIN_ELEMS` elements, the data is stored as f32. Otherwise
    // it is stored as f16: F16 bytes are uploaded as-is, BF16 goes through f32
    // to f16, and F32 is narrowed by `buffer_f16_from_f32`.
    let n = raw.len() / src_dtype.bytes_per_elem();
    if !prefer_f16_weights() || n < F16_MIN_ELEMS {
        let widened = src_dtype.to_f32_vec(raw);
        return MetalBuf {
            raw: buffer_from_f32_slice(&widened),
            dtype: Dtype::F32,
            n,
        };
    }
    let storage = match src_dtype {
        SrcDtype::F16 => buffer_from_f16_bytes(raw),
        SrcDtype::BF16 => {
            let mut f16_bytes: Vec<u8> = Vec::with_capacity(n * 2);
            for pair in raw.chunks_exact(2).take(n) {
                let value = bf16::from_le_bytes([pair[0], pair[1]]).to_f32();
                f16_bytes.extend_from_slice(&f16::from_f32(value).to_le_bytes());
            }
            buffer_from_f16_bytes(&f16_bytes)
        }
        SrcDtype::F32 => buffer_f16_from_f32(&src_dtype.to_f32_vec(raw)),
    };
    MetalBuf {
        raw: storage,
        dtype: Dtype::F16,
        n,
    }
}
}