use metal::*;
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::OnceLock;
/// Runtime switches for flash-attention kernel selection, read once from the
/// process environment (see [`metal_attention_runtime_config`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MetalAttentionRuntimeConfig {
    /// Set by `FERRUM_FA_LEGACY=1`: always use the generic `flash_attn_f32` kernel.
    force_legacy: bool,
    /// Cleared by `FERRUM_FA_DECODE=0`: disables the widened single-token decode kernel.
    allow_decode_widen: bool,
}

impl MetalAttentionRuntimeConfig {
    /// Builds the config from the raw values of `FERRUM_FA_LEGACY` and
    /// `FERRUM_FA_DECODE` (`None` means the variable is unset).
    ///
    /// `force_legacy` is enabled only by the exact value `"1"`, and
    /// `allow_decode_widen` is disabled only by the exact value `"0"`. Any other
    /// value, including one that is not valid UTF-8, keeps the default
    /// (`false` / `true` respectively).
    fn from_env_values(
        legacy: Option<&std::ffi::OsStr>,
        decode: Option<&std::ffi::OsStr>,
    ) -> Self {
        Self {
            force_legacy: legacy.is_some_and(|v| v == "1"),
            allow_decode_widen: !decode.is_some_and(|v| v == "0"),
        }
    }
}

/// Returns the process-wide attention runtime config.
///
/// The environment is read on the first call, and the result is cached for the
/// lifetime of the process. Only `FERRUM_FA_LEGACY` and `FERRUM_FA_DECODE` are
/// queried, through `std::env::var_os`. Unrelated environment variables with
/// non-UTF-8 contents therefore cannot cause a panic, as iterating
/// `std::env::vars()` would.
fn metal_attention_runtime_config() -> &'static MetalAttentionRuntimeConfig {
    static CONFIG: OnceLock<MetalAttentionRuntimeConfig> = OnceLock::new();
    CONFIG.get_or_init(|| {
        let legacy = std::env::var_os("FERRUM_FA_LEGACY");
        let decode = std::env::var_os("FERRUM_FA_DECODE");
        MetalAttentionRuntimeConfig::from_env_values(legacy.as_deref(), decode.as_deref())
    })
}
/// Owns a Metal device handle, a command queue, and the compute pipeline
/// states built from the bundled shader sources, keyed by kernel name.
pub struct MetalPipelines {
/// Device on which pipelines and buffers are created.
pub device: Device,
/// Command queue created from `device` in `MetalPipelines::new`.
pub queue: CommandQueue,
/// Pipeline states keyed by kernel function name; filled once in `new` and
/// looked up by `pipeline`.
pipelines: HashMap<&'static str, ComputePipelineState>,
}
impl MetalPipelines {
pub fn new(device: &Device) -> Self {
let queue = device.new_command_queue();
let opts = CompileOptions::new();
let fa_src = include_str!("shaders/flash_attn.metal");
let ops_src = include_str!("shaders/transformer_ops.metal");
let gemm_src = include_str!("shaders/gemm_f32.metal");
let gemm_f16w_src = include_str!("shaders/gemm_f16w.metal");
let nr_src = include_str!("shaders/norm_rope.metal");
let sm_src = include_str!("shaders/softmax.metal");
let fa_lib = device
.new_library_with_source(fa_src, &opts)
.expect("failed to compile flash_attn.metal");
let ops_lib = device
.new_library_with_source(ops_src, &opts)
.expect("failed to compile transformer_ops.metal");
let gemm_lib = device
.new_library_with_source(gemm_src, &opts)
.expect("failed to compile gemm_f32.metal");
let gemm_f16w_lib = device
.new_library_with_source(gemm_f16w_src, &opts)
.expect("failed to compile gemm_f16w.metal");
let nr_lib = device
.new_library_with_source(nr_src, &opts)
.expect("failed to compile norm_rope.metal");
let sm_lib = device
.new_library_with_source(sm_src, &opts)
.expect("failed to compile softmax.metal");
let mut pipelines = HashMap::new();
for (lib, names) in [
(
&fa_lib,
&[
"flash_attn_f32",
"flash_attn_q_tiled_f32",
"flash_attn_decode_f32",
"flash_attn_decode_paged_f32",
][..],
),
(
&ops_lib,
&[
"rms_norm_f32",
"silu_mul_f32",
"add_f32",
"scaled_add_inplace_f32",
"mul_scale_f32",
"fused_scale_add_f32",
"fused_residual_norm_f32",
"gemm_f32",
"argmax_f32",
"argmax_rows_f32",
"embedding_lookup_f32",
"split_qkv_f32",
"silu_mul_split_f32",
"gemv_f32",
"layer_norm_f32",
"gelu_f32",
"add_bias_f32",
][..],
),
(&gemm_lib, &["gemm_f32_v2"][..]),
(&gemm_f16w_lib, &["gemm_f32a_f16w_v2", "gemv_f32a_f16w"][..]),
(
&nr_lib,
&[
"qk_norm_rope_transpose_f32",
"transpose_out_f32",
"kv_cache_append_f32",
"split_qkv_norm_rope_f32",
"split_qkv_norm_rope_kvc_f32",
"split_qkv_norm_rope_paged_kvc_f32",
][..],
),
(
&sm_lib,
&["softmax_last_dim_f32", "softmax_last_dim_f32_out"][..],
),
] {
for name in names {
let func = lib
.get_function(name, None)
.unwrap_or_else(|e| panic!("kernel {name} not found: {e}"));
let pso = device
.new_compute_pipeline_state_with_function(&func)
.unwrap_or_else(|e| panic!("pipeline {name} failed: {e}"));
pipelines.insert(*name, pso);
}
}
MetalPipelines {
device: device.clone(),
queue,
pipelines,
}
}
/// Returns the pre-built pipeline state for kernel `name`.
///
/// # Panics
/// Panics if `name` was not registered by `MetalPipelines::new`.
pub fn pipeline(&self, name: &str) -> &ComputePipelineState {
    match self.pipelines.get(name) {
        Some(state) => state,
        None => panic!("pipeline {name} not found"),
    }
}
/// Allocates a CPU/GPU-shared buffer on `self.device` initialised with a copy
/// of `data`.
pub fn buffer_from_data(&self, data: &[f32]) -> Buffer {
    let source = data.as_ptr().cast::<c_void>();
    let byte_len = std::mem::size_of_val(data) as u64;
    self.device
        .new_buffer_with_data(source, byte_len, MTLResourceOptions::StorageModeShared)
}
/// Allocates a CPU/GPU-shared buffer large enough for `num_floats` `f32`
/// values. Contents are not initialised here.
pub fn buffer_empty(&self, num_floats: usize) -> Buffer {
    let byte_len = num_floats * std::mem::size_of::<f32>();
    self.device
        .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared)
}
/// Copies the first `len` `f32` values out of a CPU-visible Metal buffer.
///
/// NOTE: no GPU synchronisation is performed here. The caller must ensure that
/// any command buffer writing `buf` has completed.
///
/// # Panics
/// Panics if `buf` is smaller than `len` floats, or if it exposes no
/// CPU-visible contents (e.g. a private-storage buffer).
pub fn read_buffer(buf: &Buffer, len: usize) -> Vec<f32> {
    if len == 0 {
        return Vec::new();
    }
    let needed = (len as u64).checked_mul(std::mem::size_of::<f32>() as u64);
    assert!(
        needed.is_some_and(|bytes| bytes <= buf.length()),
        "read_buffer: {len} f32 values exceed the {}-byte buffer",
        buf.length()
    );
    let ptr = buf.contents() as *const f32;
    assert!(
        !ptr.is_null(),
        "read_buffer: buffer has no CPU-visible contents"
    );
    debug_assert_eq!(ptr as usize % std::mem::align_of::<f32>(), 0);
    // SAFETY: `ptr` is non-null, aligned for f32 (checked in debug builds) and
    // points at the start of `buf`'s storage; the length assertion guarantees
    // `len` f32 values lie inside that allocation.
    unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
}
/// Copies the first `len` `u32` values out of a CPU-visible Metal buffer.
///
/// NOTE: no GPU synchronisation is performed here. The caller must ensure that
/// any command buffer writing `buf` has completed.
///
/// # Panics
/// Panics if `buf` is smaller than `len` `u32` values, or if it exposes no
/// CPU-visible contents (e.g. a private-storage buffer).
pub fn read_buffer_u32(buf: &Buffer, len: usize) -> Vec<u32> {
    if len == 0 {
        return Vec::new();
    }
    let needed = (len as u64).checked_mul(std::mem::size_of::<u32>() as u64);
    assert!(
        needed.is_some_and(|bytes| bytes <= buf.length()),
        "read_buffer_u32: {len} u32 values exceed the {}-byte buffer",
        buf.length()
    );
    let ptr = buf.contents() as *const u32;
    assert!(
        !ptr.is_null(),
        "read_buffer_u32: buffer has no CPU-visible contents"
    );
    debug_assert_eq!(ptr as usize % std::mem::align_of::<u32>(), 0);
    // SAFETY: `ptr` is non-null, aligned for u32 (checked in debug builds) and
    // points at the start of `buf`'s storage; the length assertion guarantees
    // `len` u32 values lie inside that allocation.
    unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
}
/// Computes `C = A · Bᵀ` in single precision on the CPU via `cblas_sgemm`.
///
/// Layouts (row-major, tightly packed):
/// - `a` is `m × k`;
/// - `b` is `n × k` and is used transposed;
/// - `c` is `m × n` and is overwritten (`alpha = 1`, `beta = 0`).
///
/// NOTE: this runs synchronously on the calling thread. `_cmd` is ignored, so
/// nothing is encoded on the GPU and no ordering with work in `_cmd` is
/// established. The caller must ensure that pending GPU writes to `a`/`b` (and
/// later GPU reads of `c`) are ordered around this call.
///
/// # Panics
/// Panics in any of these cases:
/// - a dimension exceeds `i32::MAX`;
/// - a buffer is smaller than its matrix;
/// - a buffer exposes no CPU-visible contents.
pub fn gemm(
    &self,
    _cmd: &CommandBufferRef,
    a: &Buffer,
    b: &Buffer,
    c: &Buffer,
    m: usize,
    n: usize,
    k: usize,
) {
    extern "C" {
        fn cblas_sgemm(
            order: i32,
            ta: i32,
            tb: i32,
            m: i32,
            n: i32,
            k: i32,
            alpha: f32,
            a: *const f32,
            lda: i32,
            b: *const f32,
            ldb: i32,
            beta: f32,
            c: *mut f32,
            ldc: i32,
        );
    }
    // Values of the CBLAS_ORDER / CBLAS_TRANSPOSE enums from cblas.h.
    const CBLAS_ROW_MAJOR: i32 = 101;
    const CBLAS_NO_TRANS: i32 = 111;
    const CBLAS_TRANS: i32 = 112;

    let to_i32 = |value: usize, what: &str| {
        i32::try_from(value)
            .unwrap_or_else(|_| panic!("gemm: {what} = {value} exceeds i32::MAX"))
    };
    let (mi, ni, ki) = (to_i32(m, "m"), to_i32(n, "n"), to_i32(k, "k"));
    if m == 0 || n == 0 {
        // Empty output: nothing to write, and BLAS would reject ldc = 0.
        return;
    }

    let f32_bytes = std::mem::size_of::<f32>() as u64;
    let check = |buf: &Buffer, rows: usize, cols: usize, what: &str| {
        let needed = (rows as u64) * (cols as u64) * f32_bytes;
        assert!(
            buf.length() >= needed,
            "gemm: buffer `{what}` holds {} bytes, needs {needed}",
            buf.length()
        );
        assert!(
            !buf.contents().is_null(),
            "gemm: buffer `{what}` has no CPU-visible contents"
        );
    };
    check(a, m, k, "a");
    check(b, n, k, "b");
    check(c, m, n, "c");

    // BLAS requires leading dimensions >= 1 even when k == 0; in that case the
    // product is empty and, with beta = 0, C is simply zeroed.
    let ld_k = ki.max(1);
    // SAFETY: all three pointers are non-null, and the assertions above
    // guarantee each buffer covers its full row-major matrix with the leading
    // dimensions passed here.
    unsafe {
        cblas_sgemm(
            CBLAS_ROW_MAJOR,
            CBLAS_NO_TRANS,
            CBLAS_TRANS,
            mi,
            ni,
            ki,
            1.0,
            a.contents() as *const f32,
            ld_k,
            b.contents() as *const f32,
            ld_k,
            0.0,
            c.contents() as *mut f32,
            ni,
        );
    }
}
/// Encodes `rms_norm_f32` onto an existing compute encoder.
///
/// Buffer slots:
/// - 0: `input`
/// - 1: `weight`
/// - 2: `output`
/// - 3: the `{ dim, eps }` parameter block
///
/// Dispatches one 32-thread threadgroup per row. It also reserves 128 bytes
/// (32 `f32`s) of threadgroup memory at index 0 — presumably one reduction slot
/// per thread; confirm against the shader.
pub fn rms_norm_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    input: &Buffer,
    weight: &Buffer,
    output: &Buffer,
    rows: usize,
    dim: usize,
    eps: f32,
) {
    #[repr(C)]
    struct P {
        dim: i32,
        eps: f32,
    }
    let params = P {
        dim: dim as i32,
        eps,
    };
    enc.set_compute_pipeline_state(self.pipeline("rms_norm_f32"));
    enc.set_buffer(0, Some(input), 0);
    enc.set_buffer(1, Some(weight), 0);
    enc.set_buffer(2, Some(output), 0);
    // Length comes from the struct itself so it cannot drift from the layout.
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    enc.set_threadgroup_memory_length(0, 128);
    let grid = MTLSize::new(rows as u64, 1, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `silu_mul_f32` onto an existing compute encoder. By name, this is
/// presumably `output[i] = silu(gate[i]) * up[i]`; confirm in the shader.
///
/// Buffer slots:
/// - 0: `gate`
/// - 1: `up`
/// - 2: `output`
/// - 3: the `{ n }` parameter block
///
/// Dispatches `ceil(n / 256)` threadgroups of 256 threads.
pub fn silu_mul_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    gate: &Buffer,
    up: &Buffer,
    output: &Buffer,
    n: usize,
) {
    #[repr(C)]
    struct P {
        n: i32,
    }
    let params = P { n: n as i32 };
    enc.set_compute_pipeline_state(self.pipeline("silu_mul_f32"));
    enc.set_buffer(0, Some(gate), 0);
    enc.set_buffer(1, Some(up), 0);
    enc.set_buffer(2, Some(output), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
    let tg = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes the element-wise `add_f32` kernel onto an existing compute encoder.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `b`
/// - 2: `output`
/// - 3: the `{ n }` parameter block
///
/// Dispatches `ceil(n / 256)` threadgroups of 256 threads.
pub fn add_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    b: &Buffer,
    output: &Buffer,
    n: usize,
) {
    #[repr(C)]
    struct P {
        n: i32,
    }
    let params = P { n: n as i32 };
    enc.set_compute_pipeline_state(self.pipeline("add_f32"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b), 0);
    enc.set_buffer(2, Some(output), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
    let tg = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `scaled_add_inplace_f32` onto an existing compute encoder. By name,
/// this is presumably `dst[i] += scale * src[i]`; confirm in the shader.
///
/// Buffer slots:
/// - 0: `dst`
/// - 1: `src`
/// - 2: the `{ n, scale }` parameter block
///
/// Dispatches `ceil(n / 256)` threadgroups of 256 threads.
pub fn scaled_add_inplace_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    dst: &Buffer,
    src: &Buffer,
    scale: f32,
    n: usize,
) {
    #[repr(C)]
    struct P {
        n: i32,
        scale: f32,
    }
    let params = P { n: n as i32, scale };
    enc.set_compute_pipeline_state(self.pipeline("scaled_add_inplace_f32"));
    enc.set_buffer(0, Some(dst), 0);
    enc.set_buffer(1, Some(src), 0);
    enc.set_bytes(
        2,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
    let tg = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
pub fn rms_norm(
&self,
cmd: &CommandBufferRef,
input: &Buffer,
weight: &Buffer,
output: &Buffer,
rows: usize,
dim: usize,
eps: f32,
) {
#[repr(C)]
struct P {
dim: i32,
eps: f32,
}
let params = P {
dim: dim as i32,
eps,
};
let enc = cmd.new_compute_command_encoder();
enc.set_compute_pipeline_state(self.pipeline("rms_norm_f32"));
enc.set_buffer(0, Some(input), 0);
enc.set_buffer(1, Some(weight), 0);
enc.set_buffer(2, Some(output), 0);
enc.set_bytes(3, 8, ¶ms as *const _ as *const c_void as *const _);
enc.set_threadgroup_memory_length(0, 128); let grid = MTLSize::new(rows as u64, 1, 1);
let tg = MTLSize::new(32, 1, 1); enc.dispatch_thread_groups(grid, tg);
enc.end_encoding();
}
pub fn silu_mul(
&self,
cmd: &CommandBufferRef,
gate: &Buffer,
up: &Buffer,
output: &Buffer,
n: usize,
) {
#[repr(C)]
struct P {
n: i32,
}
let params = P { n: n as i32 };
let enc = cmd.new_compute_command_encoder();
enc.set_compute_pipeline_state(self.pipeline("silu_mul_f32"));
enc.set_buffer(0, Some(gate), 0);
enc.set_buffer(1, Some(up), 0);
enc.set_buffer(2, Some(output), 0);
enc.set_bytes(3, 4, ¶ms as *const _ as *const c_void as *const _);
let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
let tg = MTLSize::new(256, 1, 1);
enc.dispatch_thread_groups(grid, tg);
enc.end_encoding();
}
pub fn add(&self, cmd: &CommandBufferRef, a: &Buffer, b: &Buffer, output: &Buffer, n: usize) {
#[repr(C)]
struct P {
n: i32,
}
let params = P { n: n as i32 };
let enc = cmd.new_compute_command_encoder();
enc.set_compute_pipeline_state(self.pipeline("add_f32"));
enc.set_buffer(0, Some(a), 0);
enc.set_buffer(1, Some(b), 0);
enc.set_buffer(2, Some(output), 0);
enc.set_bytes(3, 4, ¶ms as *const _ as *const c_void as *const _);
let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
let tg = MTLSize::new(256, 1, 1);
enc.dispatch_thread_groups(grid, tg);
enc.end_encoding();
}
/// Encodes the `gemv_f32` kernel (a single-row product, `m = 1`) onto an
/// existing compute encoder.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `b`
/// - 2: `c`
/// - 3: the `{ m, n, k }` parameter block
///
/// Dispatches one 32-thread threadgroup per output column (`n` threadgroups).
pub fn gemv_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    b: &Buffer,
    c: &Buffer,
    n: usize,
    k: usize,
) {
    #[repr(C)]
    struct P {
        m: i32,
        n: i32,
        k: i32,
    }
    let params = P {
        m: 1,
        n: n as i32,
        k: k as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("gemv_f32"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b), 0);
    enc.set_buffer(2, Some(c), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(n as u64, 1, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes the tiled `gemm_f32_v2` kernel onto an existing compute encoder.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `b`
/// - 2: `c`
/// - 3: the `{ m, n, k }` parameter block
///
/// The grid is `(ceil(n / 32), ceil(m / 64))` threadgroups of 128 threads,
/// with 12288 bytes of threadgroup memory at index 0. From the grid, each
/// threadgroup presumably covers a 64-row × 32-column tile of `c`; confirm in
/// the shader.
pub fn gemm_v2(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    b: &Buffer,
    c: &Buffer,
    m: usize,
    n: usize,
    k: usize,
) {
    #[repr(C)]
    struct P {
        m: i32,
        n: i32,
        k: i32,
    }
    let params = P {
        m: m as i32,
        n: n as i32,
        k: k as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("gemm_f32_v2"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b), 0);
    enc.set_buffer(2, Some(c), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    enc.set_threadgroup_memory_length(0, 12288);
    let grid = MTLSize::new(n.div_ceil(32) as u64, m.div_ceil(64) as u64, 1);
    let tg = MTLSize::new(128, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes the tiled `gemm_f32a_f16w_v2` kernel (f32 activations, weights in
/// `b_f16` — presumably half precision, per the kernel name) onto an existing
/// compute encoder.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `b_f16`
/// - 2: `c`
/// - 3: the `{ m, n, k }` parameter block
///
/// Launch geometry matches `gemm_v2`: a `(ceil(n / 32), ceil(m / 64))` grid of
/// 128-thread threadgroups, with 12288 bytes of threadgroup memory at index 0.
pub fn gemm_v2_f16w(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    b_f16: &Buffer,
    c: &Buffer,
    m: usize,
    n: usize,
    k: usize,
) {
    #[repr(C)]
    struct P {
        m: i32,
        n: i32,
        k: i32,
    }
    let params = P {
        m: m as i32,
        n: n as i32,
        k: k as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("gemm_f32a_f16w_v2"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b_f16), 0);
    enc.set_buffer(2, Some(c), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    enc.set_threadgroup_memory_length(0, 12288);
    let grid = MTLSize::new(n.div_ceil(32) as u64, m.div_ceil(64) as u64, 1);
    let tg = MTLSize::new(128, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes the `gemv_f32a_f16w` kernel (single-row product, `m = 1`, with
/// weights in `b_f16` — presumably half precision, per the kernel name) onto
/// an existing compute encoder.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `b_f16`
/// - 2: `c`
/// - 3: the `{ m, n, k }` parameter block
///
/// Dispatches one 32-thread threadgroup per output column (`n` threadgroups).
pub fn gemv_enc_f16w(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    b_f16: &Buffer,
    c: &Buffer,
    n: usize,
    k: usize,
) {
    #[repr(C)]
    struct P {
        m: i32,
        n: i32,
        k: i32,
    }
    let params = P {
        m: 1,
        n: n as i32,
        k: k as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("gemv_f32a_f16w"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b_f16), 0);
    enc.set_buffer(2, Some(c), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(n as u64, 1, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `qk_norm_rope_transpose_f32` onto an existing compute encoder.
///
/// Buffer slots:
/// - 0: `input`
/// - 1: `weight`
/// - 2: `cos`
/// - 3: `sin`
/// - 4: `output`
/// - 5: the parameter block
///
/// `half_dim` is derived as `head_dim / 2`, and `norm_mode` is forwarded
/// verbatim as `apply_norm`. The grid is `(tokens, heads)` threadgroups of
/// 32 threads.
#[allow(clippy::too_many_arguments)]
pub fn qk_norm_rope(
    &self,
    enc: &ComputeCommandEncoderRef,
    input: &Buffer,
    weight: &Buffer,
    cos: &Buffer,
    sin: &Buffer,
    output: &Buffer,
    tokens: usize,
    heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    norm_mode: i32,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        heads: i32,
        head_dim: i32,
        half_dim: i32,
        pos_offset: i32,
        eps: f32,
        apply_norm: i32,
    }
    let params = P {
        tokens: tokens as i32,
        heads: heads as i32,
        head_dim: head_dim as i32,
        half_dim: (head_dim / 2) as i32,
        pos_offset: pos_offset as i32,
        eps,
        apply_norm: norm_mode,
    };
    enc.set_compute_pipeline_state(self.pipeline("qk_norm_rope_transpose_f32"));
    enc.set_buffer(0, Some(input), 0);
    enc.set_buffer(1, Some(weight), 0);
    enc.set_buffer(2, Some(cos), 0);
    enc.set_buffer(3, Some(sin), 0);
    enc.set_buffer(4, Some(output), 0);
    enc.set_bytes(
        5,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(tokens as u64, heads as u64, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `split_qkv_norm_rope_f32` onto an existing compute encoder. It
/// splits the fused `qkv` buffer into `q_out`, `k_out` and `v_out`.
///
/// Buffer slots:
/// - 0: `qkv`
/// - 1: `q_norm_w`
/// - 2: `k_norm_w`
/// - 3: `cos`
/// - 4: `sin`
/// - 5: `q_out`
/// - 6: `k_out`
/// - 7: `v_out`
/// - 8: the parameter block
///
/// The grid is `(tokens, q_heads + 2 * kv_heads)` threadgroups of 32 threads,
/// i.e. one row per Q, K and V head. `qk_mode` is forwarded verbatim.
#[allow(clippy::too_many_arguments)]
pub fn split_qkv_norm_rope(
    &self,
    enc: &ComputeCommandEncoderRef,
    qkv: &Buffer,
    q_norm_w: &Buffer,
    k_norm_w: &Buffer,
    cos: &Buffer,
    sin: &Buffer,
    q_out: &Buffer,
    k_out: &Buffer,
    v_out: &Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        q_heads: i32,
        kv_heads: i32,
        head_dim: i32,
        half_dim: i32,
        pos_offset: i32,
        eps: f32,
        qk_mode: i32,
    }
    let params = P {
        tokens: tokens as i32,
        q_heads: q_heads as i32,
        kv_heads: kv_heads as i32,
        head_dim: head_dim as i32,
        half_dim: (head_dim / 2) as i32,
        pos_offset: pos_offset as i32,
        eps,
        qk_mode,
    };
    enc.set_compute_pipeline_state(self.pipeline("split_qkv_norm_rope_f32"));
    enc.set_buffer(0, Some(qkv), 0);
    enc.set_buffer(1, Some(q_norm_w), 0);
    enc.set_buffer(2, Some(k_norm_w), 0);
    enc.set_buffer(3, Some(cos), 0);
    enc.set_buffer(4, Some(sin), 0);
    enc.set_buffer(5, Some(q_out), 0);
    enc.set_buffer(6, Some(k_out), 0);
    enc.set_buffer(7, Some(v_out), 0);
    enc.set_bytes(
        8,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(tokens as u64, (q_heads + 2 * kv_heads) as u64, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `split_qkv_norm_rope_kvc_f32` onto an existing compute encoder.
/// Like `split_qkv_norm_rope`, but K and V are written into the contiguous
/// caches `cache_k` / `cache_v`.
///
/// Buffer slots:
/// - 0: `qkv`
/// - 1: `q_norm_w`
/// - 2: `k_norm_w`
/// - 3: `cos`
/// - 4: `sin`
/// - 5: `q_out`
/// - 6: `cache_k`
/// - 7: `cache_v`
/// - 8: the parameter block
///
/// `cache_len` and `cache_capacity` are forwarded verbatim (presumably the
/// current fill level and the per-head capacity — confirm in the shader). The
/// grid is `(tokens, q_heads + 2 * kv_heads)` threadgroups of 32 threads.
#[allow(clippy::too_many_arguments)]
pub fn split_qkv_norm_rope_into_cache(
    &self,
    enc: &ComputeCommandEncoderRef,
    qkv: &Buffer,
    q_norm_w: &Buffer,
    k_norm_w: &Buffer,
    cos: &Buffer,
    sin: &Buffer,
    q_out: &Buffer,
    cache_k: &Buffer,
    cache_v: &Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
    cache_len: usize,
    cache_capacity: usize,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        q_heads: i32,
        kv_heads: i32,
        head_dim: i32,
        half_dim: i32,
        pos_offset: i32,
        eps: f32,
        qk_mode: i32,
        cache_len: i32,
        cache_capacity: i32,
    }
    let params = P {
        tokens: tokens as i32,
        q_heads: q_heads as i32,
        kv_heads: kv_heads as i32,
        head_dim: head_dim as i32,
        half_dim: (head_dim / 2) as i32,
        pos_offset: pos_offset as i32,
        eps,
        qk_mode,
        cache_len: cache_len as i32,
        cache_capacity: cache_capacity as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("split_qkv_norm_rope_kvc_f32"));
    enc.set_buffer(0, Some(qkv), 0);
    enc.set_buffer(1, Some(q_norm_w), 0);
    enc.set_buffer(2, Some(k_norm_w), 0);
    enc.set_buffer(3, Some(cos), 0);
    enc.set_buffer(4, Some(sin), 0);
    enc.set_buffer(5, Some(q_out), 0);
    enc.set_buffer(6, Some(cache_k), 0);
    enc.set_buffer(7, Some(cache_v), 0);
    enc.set_bytes(
        8,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(tokens as u64, (q_heads + 2 * kv_heads) as u64, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `split_qkv_norm_rope_paged_kvc_f32` onto an existing compute
/// encoder. Like `split_qkv_norm_rope_into_cache`, but K and V are written
/// into block-paged caches addressed through `block_table`.
///
/// Buffer slots:
/// - 0: `qkv`, at `qkv_byte_offset`
/// - 1: `q_norm_w`
/// - 2: `k_norm_w`
/// - 3: `cos`
/// - 4: `sin`
/// - 5: `q_out`, at `q_out_byte_offset`
/// - 6: `cache_k`
/// - 7: `cache_v`
/// - 8: `block_table`
/// - 9: the parameter block
///
/// The grid is `(tokens, q_heads + 2 * kv_heads)` threadgroups of 32 threads.
#[allow(clippy::too_many_arguments)]
pub fn split_qkv_norm_rope_into_paged_cache(
    &self,
    enc: &ComputeCommandEncoderRef,
    qkv: &Buffer,
    qkv_byte_offset: u64,
    q_norm_w: &Buffer,
    k_norm_w: &Buffer,
    cos: &Buffer,
    sin: &Buffer,
    q_out: &Buffer,
    q_out_byte_offset: u64,
    cache_k: &Buffer,
    cache_v: &Buffer,
    block_table: &Buffer,
    tokens: usize,
    q_heads: usize,
    kv_heads: usize,
    head_dim: usize,
    pos_offset: usize,
    eps: f32,
    qk_mode: i32,
    cache_len: usize,
    block_size: usize,
    max_num_blocks_per_seq: usize,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        q_heads: i32,
        kv_heads: i32,
        head_dim: i32,
        half_dim: i32,
        pos_offset: i32,
        eps: f32,
        qk_mode: i32,
        cache_len: i32,
        block_size: i32,
        max_num_blocks_per_seq: i32,
    }
    let params = P {
        tokens: tokens as i32,
        q_heads: q_heads as i32,
        kv_heads: kv_heads as i32,
        head_dim: head_dim as i32,
        half_dim: (head_dim / 2) as i32,
        pos_offset: pos_offset as i32,
        eps,
        qk_mode,
        cache_len: cache_len as i32,
        block_size: block_size as i32,
        max_num_blocks_per_seq: max_num_blocks_per_seq as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("split_qkv_norm_rope_paged_kvc_f32"));
    enc.set_buffer(0, Some(qkv), qkv_byte_offset);
    enc.set_buffer(1, Some(q_norm_w), 0);
    enc.set_buffer(2, Some(k_norm_w), 0);
    enc.set_buffer(3, Some(cos), 0);
    enc.set_buffer(4, Some(sin), 0);
    enc.set_buffer(5, Some(q_out), q_out_byte_offset);
    enc.set_buffer(6, Some(cache_k), 0);
    enc.set_buffer(7, Some(cache_v), 0);
    enc.set_buffer(8, Some(block_table), 0);
    enc.set_bytes(
        9,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(tokens as u64, (q_heads + 2 * kv_heads) as u64, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `transpose_out_f32` onto an existing compute encoder. By name, it
/// presumably transposes a head-major attention output back to token-major;
/// confirm in the shader.
///
/// Buffer slots:
/// - 0: `input`
/// - 1: `output`
/// - 2: the `{ tokens, heads, head_dim }` parameter block
///
/// Dispatches `ceil(tokens * heads * head_dim / 256)` threadgroups of 256
/// threads.
pub fn transpose_out(
    &self,
    enc: &ComputeCommandEncoderRef,
    input: &Buffer,
    output: &Buffer,
    tokens: usize,
    heads: usize,
    head_dim: usize,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        heads: i32,
        head_dim: i32,
    }
    let params = P {
        tokens: tokens as i32,
        heads: heads as i32,
        head_dim: head_dim as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("transpose_out_f32"));
    enc.set_buffer(0, Some(input), 0);
    enc.set_buffer(1, Some(output), 0);
    enc.set_bytes(
        2,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let n = tokens * heads * head_dim;
    let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
    let tg = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `kv_cache_append_f32` onto an existing compute encoder. By name, it
/// presumably copies `new_len` tokens from `new_data` into `cache` after the
/// first `old_len` positions, in a per-head capacity of `max_len`; confirm in
/// the shader.
///
/// Buffer slots:
/// - 0: `new_data`
/// - 1: `cache`
/// - 2: the parameter block
///
/// Dispatches `ceil(heads * new_len * head_dim / 256)` threadgroups of 256
/// threads.
pub fn kv_cache_append(
    &self,
    enc: &ComputeCommandEncoderRef,
    new_data: &Buffer,
    cache: &Buffer,
    heads: usize,
    head_dim: usize,
    old_len: usize,
    new_len: usize,
    max_len: usize,
) {
    #[repr(C)]
    struct P {
        heads: i32,
        head_dim: i32,
        old_len: i32,
        new_len: i32,
        max_len: i32,
    }
    let params = P {
        heads: heads as i32,
        head_dim: head_dim as i32,
        old_len: old_len as i32,
        new_len: new_len as i32,
        max_len: max_len as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("kv_cache_append_f32"));
    enc.set_buffer(0, Some(new_data), 0);
    enc.set_buffer(1, Some(cache), 0);
    enc.set_bytes(
        2,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let n = heads * new_len * head_dim;
    let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
    let tg = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `softmax_last_dim_f32`, applied in place to `data`, onto an
/// existing compute encoder.
///
/// Buffer slots:
/// - 0: `data`
/// - 1: the `{ rows, cols }` parameter block
///
/// Dispatches one 32-thread threadgroup per row, with 128 bytes of threadgroup
/// memory at index 0.
pub fn softmax_last_dim_inplace(
    &self,
    enc: &ComputeCommandEncoderRef,
    data: &Buffer,
    rows: usize,
    cols: usize,
) {
    #[repr(C)]
    struct P {
        rows: i32,
        cols: i32,
    }
    let params = P {
        rows: rows as i32,
        cols: cols as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("softmax_last_dim_f32"));
    enc.set_buffer(0, Some(data), 0);
    enc.set_bytes(
        1,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    enc.set_threadgroup_memory_length(0, 128);
    let grid = MTLSize::new(rows as u64, 1, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `mul_scale_f32` onto an existing compute encoder. By name, it
/// presumably multiplies `a` by a `scale` vector of length `scale_len`
/// (broadcast); confirm in the shader.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `scale`
/// - 2: `output`
/// - 3: the `{ n, scale_len }` parameter block
///
/// Dispatches `ceil(n / 256)` threadgroups of 256 threads.
pub fn mul_scale_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    scale: &Buffer,
    output: &Buffer,
    n: usize,
    scale_len: usize,
) {
    #[repr(C)]
    struct P {
        n: i32,
        scale_len: i32,
    }
    let params = P {
        n: n as i32,
        scale_len: scale_len as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("mul_scale_f32"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(scale), 0);
    enc.set_buffer(2, Some(output), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
    let tg = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `fused_scale_add_f32` onto an existing compute encoder. By name, it
/// presumably computes `a + b * scale`, with `scale` of length `scale_len`;
/// confirm in the shader.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `b`
/// - 2: `scale`
/// - 3: `output`
/// - 4: the `{ n, scale_len }` parameter block
///
/// Dispatches `ceil(n / 256)` threadgroups of 256 threads.
pub fn fused_scale_add_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    b: &Buffer,
    scale: &Buffer,
    output: &Buffer,
    n: usize,
    scale_len: usize,
) {
    #[repr(C)]
    struct P {
        n: i32,
        scale_len: i32,
    }
    let params = P {
        n: n as i32,
        scale_len: scale_len as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("fused_scale_add_f32"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b), 0);
    enc.set_buffer(2, Some(scale), 0);
    enc.set_buffer(3, Some(output), 0);
    enc.set_bytes(
        4,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(n.div_ceil(256) as u64, 1, 1);
    let tg = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes `fused_residual_norm_f32` onto an existing compute encoder. By name,
/// it presumably writes a residual sum to `out_res` and its normalised form
/// to `out_norm`; confirm in the shader.
///
/// Buffer slots:
/// - 0: `a`
/// - 1: `b`
/// - 2: `scale`, or a 4-byte placeholder when `scale` is `None`
/// - 3: `weight`
/// - 4: `out_res`
/// - 5: `out_norm`
/// - 6: the parameter block
///
/// `has_scale` tells the kernel whether slot 2 holds real data. The dispatch
/// is one 32-thread threadgroup per token, with 128 bytes of threadgroup
/// memory at index 0.
#[allow(clippy::too_many_arguments)]
pub fn fused_residual_norm_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    a: &Buffer,
    b: &Buffer,
    scale: Option<&Buffer>,
    weight: &Buffer,
    out_res: &Buffer,
    out_norm: &Buffer,
    tokens: usize,
    dim: usize,
    eps: f32,
    scale_len: usize,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        dim: i32,
        eps: f32,
        has_scale: i32,
        scale_len: i32,
    }
    let params = P {
        tokens: tokens as i32,
        dim: dim as i32,
        eps,
        has_scale: i32::from(scale.is_some()),
        scale_len: scale_len as i32,
    };
    // Slot 2 must always be bound; allocate the placeholder only when there
    // is no real scale buffer (previously it was allocated on every call).
    let placeholder;
    let scale_buf: &Buffer = match scale {
        Some(buf) => buf,
        None => {
            placeholder = self
                .device
                .new_buffer(4, MTLResourceOptions::StorageModeShared);
            &placeholder
        }
    };
    enc.set_compute_pipeline_state(self.pipeline("fused_residual_norm_f32"));
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b), 0);
    enc.set_buffer(2, Some(scale_buf), 0);
    enc.set_buffer(3, Some(weight), 0);
    enc.set_buffer(4, Some(out_res), 0);
    enc.set_buffer(5, Some(out_norm), 0);
    enc.set_bytes(
        6,
        std::mem::size_of::<P>() as u64,
        (&params as *const P).cast::<c_void>(),
    );
    enc.set_threadgroup_memory_length(0, 128);
    let grid = MTLSize::new(tokens as u64, 1, 1);
    let tg = MTLSize::new(32, 1, 1);
    enc.dispatch_thread_groups(grid, tg);
}
/// Encodes one flash-attention dispatch in a fresh compute encoder on `cmd`.
///
/// Kernel selection, unless `FERRUM_FA_LEGACY=1` forces the generic path:
/// - `flash_attn_q_tiled_f32` when `head_dim == 128`, `sliding_window == 0`
///   and `q_len >= 8`. One 128-thread threadgroup per 8-row query tile.
/// - `flash_attn_decode_f32` when `head_dim == 128`, `sliding_window == 0`,
///   `q_len == 1`, and decode widening is allowed (`FERRUM_FA_DECODE` not
///   `"0"`). One 32×32 threadgroup per (head, batch).
/// - otherwise `flash_attn_f32`. One 32-thread threadgroup per query row.
///
/// All kernels receive the buffers `q`, `k`, `v`, `o` in slots 0–3 and the
/// parameter block in slot 4. The scale is fixed at `1 / sqrt(head_dim)`, and
/// `kv_seq_stride` is forwarded verbatim.
pub fn flash_attn_v2(
    &self,
    cmd: &CommandBufferRef,
    q: &Buffer,
    k: &Buffer,
    v: &Buffer,
    o: &Buffer,
    params: &crate::attention::AttentionParams,
    kv_seq_stride: usize,
) {
    #[repr(C)]
    struct P {
        batch: i32,
        num_heads: i32,
        num_kv_heads: i32,
        q_len: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        pos_offset: i32,
        kv_seq_stride: i32,
        sliding_window: i32,
    }
    let kernel_params = P {
        batch: params.batch as i32,
        num_heads: params.num_heads as i32,
        num_kv_heads: params.num_kv_heads as i32,
        q_len: params.q_len as i32,
        kv_len: params.kv_len as i32,
        head_dim: params.head_dim as i32,
        scale: 1.0 / (params.head_dim as f32).sqrt(),
        causal: if params.causal { 1 } else { 0 },
        pos_offset: params.pos_offset as i32,
        kv_seq_stride: kv_seq_stride as i32,
        sliding_window: params.sliding_window as i32,
    };
    // Query rows handled per threadgroup by the q-tiled kernel.
    const Q_TILE_R: usize = 8;
    let config = metal_attention_runtime_config();
    // Both specialised kernels share the same shape prerequisites.
    let specialised_shape =
        !config.force_legacy && params.head_dim == 128 && params.sliding_window == 0;
    let heads = params.num_heads as u64;
    let batch = params.batch as u64;

    let enc = cmd.new_compute_command_encoder();
    let (kernel, grid, threads) = if specialised_shape && params.q_len >= Q_TILE_R {
        (
            "flash_attn_q_tiled_f32",
            MTLSize::new(params.q_len.div_ceil(Q_TILE_R) as u64, heads, batch),
            MTLSize::new(128, 1, 1),
        )
    } else if specialised_shape && config.allow_decode_widen && params.q_len == 1 {
        (
            "flash_attn_decode_f32",
            MTLSize::new(1, heads, batch),
            MTLSize::new(32, 32, 1),
        )
    } else {
        (
            "flash_attn_f32",
            MTLSize::new(params.q_len as u64, heads, batch),
            MTLSize::new(32, 1, 1),
        )
    };
    enc.set_compute_pipeline_state(self.pipeline(kernel));
    enc.set_buffer(0, Some(q), 0);
    enc.set_buffer(1, Some(k), 0);
    enc.set_buffer(2, Some(v), 0);
    enc.set_buffer(3, Some(o), 0);
    enc.set_bytes(
        4,
        std::mem::size_of::<P>() as u64,
        (&kernel_params as *const P).cast::<c_void>(),
    );
    enc.dispatch_thread_groups(grid, threads);
    enc.end_encoding();
}
/// Encodes flash attention with `kv_seq_stride = 0`; see `flash_attn_v2`.
///
/// NOTE(review): a zero stride is forwarded verbatim to the kernel, which
/// presumably then derives the stride from `kv_len` — confirm in
/// `flash_attn.metal`.
pub fn flash_attn(
    &self,
    cmd: &CommandBufferRef,
    q: &Buffer,
    k: &Buffer,
    v: &Buffer,
    o: &Buffer,
    params: &crate::attention::AttentionParams,
) {
    const DEFAULT_KV_SEQ_STRIDE: usize = 0;
    self.flash_attn_v2(cmd, q, k, v, o, params, DEFAULT_KV_SEQ_STRIDE)
}
/// Encodes `split_qkv_f32` onto an existing compute encoder. It splits a fused
/// `qkv` buffer into separate `q`, `k`, `v` buffers.
///
/// Buffer slots:
/// - 0: `qkv`
/// - 1: `q`
/// - 2: `k`
/// - 3: `v`
/// - 4: the `{ tokens, q_dim, kv_dim }` parameter block
///
/// The dispatch covers `tokens * (q_dim + 2 * kv_dim)` elements, 256 threads
/// per threadgroup.
pub fn split_qkv_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    qkv: &Buffer,
    q: &Buffer,
    k: &Buffer,
    v: &Buffer,
    tokens: usize,
    q_dim: usize,
    kv_dim: usize,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        q_dim: i32,
        kv_dim: i32,
    }
    let layout = P {
        tokens: tokens as i32,
        q_dim: q_dim as i32,
        kv_dim: kv_dim as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("split_qkv_f32"));
    enc.set_buffer(0, Some(qkv), 0);
    enc.set_buffer(1, Some(q), 0);
    enc.set_buffer(2, Some(k), 0);
    enc.set_buffer(3, Some(v), 0);
    enc.set_bytes(
        4,
        std::mem::size_of::<P>() as u64,
        (&layout as *const P).cast::<c_void>(),
    );
    let row_width = q_dim + 2 * kv_dim;
    let element_count = tokens * row_width;
    enc.dispatch_thread_groups(
        MTLSize::new(element_count.div_ceil(256) as u64, 1, 1),
        MTLSize::new(256, 1, 1),
    );
}
/// Encodes `layer_norm_f32` onto an existing compute encoder.
///
/// Buffer slots:
/// - 0: `x`
/// - 1: `gamma`
/// - 2: `beta`
/// - 3: `out`
/// - 4: the `{ dim, eps }` parameter block
///
/// Dispatches one 32-thread threadgroup per token.
pub fn layer_norm_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    x: &Buffer,
    gamma: &Buffer,
    beta: &Buffer,
    out: &Buffer,
    tokens: usize,
    dim: usize,
    eps: f32,
) {
    #[repr(C)]
    struct P {
        dim: i32,
        eps: f32,
    }
    let settings = P {
        dim: dim as i32,
        eps,
    };
    enc.set_compute_pipeline_state(self.pipeline("layer_norm_f32"));
    enc.set_buffer(0, Some(x), 0);
    enc.set_buffer(1, Some(gamma), 0);
    enc.set_buffer(2, Some(beta), 0);
    enc.set_buffer(3, Some(out), 0);
    enc.set_bytes(
        4,
        std::mem::size_of::<P>() as u64,
        (&settings as *const P).cast::<c_void>(),
    );
    enc.dispatch_thread_groups(MTLSize::new(tokens as u64, 1, 1), MTLSize::new(32, 1, 1));
}
/// Encodes the element-wise `gelu_f32` kernel onto an existing compute encoder.
///
/// Buffer slots:
/// - 0: `x`
/// - 1: `out`
/// - 2: the `{ n }` parameter block
///
/// Dispatches `ceil(len / 256)` threadgroups of 256 threads.
pub fn gelu_enc(&self, enc: &ComputeCommandEncoderRef, x: &Buffer, out: &Buffer, len: usize) {
    #[repr(C)]
    struct P {
        n: i32,
    }
    let count = P { n: len as i32 };
    enc.set_compute_pipeline_state(self.pipeline("gelu_f32"));
    enc.set_buffer(0, Some(x), 0);
    enc.set_buffer(1, Some(out), 0);
    enc.set_bytes(
        2,
        std::mem::size_of::<P>() as u64,
        (&count as *const P).cast::<c_void>(),
    );
    let groups = len.div_ceil(256) as u64;
    enc.dispatch_thread_groups(MTLSize::new(groups, 1, 1), MTLSize::new(256, 1, 1));
}
/// Encodes `add_bias_f32` onto an existing compute encoder. By name, it
/// presumably adds the `cols`-long `bias` to every row of `data` in place;
/// confirm in the shader.
///
/// Buffer slots:
/// - 0: `data`
/// - 1: `bias`
/// - 2: the `{ rows, cols }` parameter block
///
/// The dispatch covers `rows * cols` elements, 256 threads per threadgroup.
pub fn add_bias_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    data: &Buffer,
    bias: &Buffer,
    rows: usize,
    cols: usize,
) {
    #[repr(C)]
    struct P {
        rows: i32,
        cols: i32,
    }
    let shape = P {
        rows: rows as i32,
        cols: cols as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("add_bias_f32"));
    enc.set_buffer(0, Some(data), 0);
    enc.set_buffer(1, Some(bias), 0);
    enc.set_bytes(
        2,
        std::mem::size_of::<P>() as u64,
        (&shape as *const P).cast::<c_void>(),
    );
    let element_count = rows * cols;
    enc.dispatch_thread_groups(
        MTLSize::new(element_count.div_ceil(256) as u64, 1, 1),
        MTLSize::new(256, 1, 1),
    );
}
/// Encodes `silu_mul_split_f32` onto an existing compute encoder. It reads a
/// fused gate/up buffer and writes `tokens * im` outputs.
///
/// Buffer slots:
/// - 0: `gate_up`
/// - 1: `out`
/// - 2: the `{ tokens, im }` parameter block
///
/// The dispatch covers `tokens * im` elements, 256 threads per threadgroup.
pub fn silu_mul_split_enc(
    &self,
    enc: &ComputeCommandEncoderRef,
    gate_up: &Buffer,
    out: &Buffer,
    tokens: usize,
    im: usize,
) {
    #[repr(C)]
    struct P {
        tokens: i32,
        im: i32,
    }
    let dims = P {
        tokens: tokens as i32,
        im: im as i32,
    };
    enc.set_compute_pipeline_state(self.pipeline("silu_mul_split_f32"));
    enc.set_buffer(0, Some(gate_up), 0);
    enc.set_buffer(1, Some(out), 0);
    enc.set_bytes(
        2,
        std::mem::size_of::<P>() as u64,
        (&dims as *const P).cast::<c_void>(),
    );
    let element_count = tokens * im;
    enc.dispatch_thread_groups(
        MTLSize::new(element_count.div_ceil(256) as u64, 1, 1),
        MTLSize::new(256, 1, 1),
    );
}
/// Encodes the paged-KV attention kernel `flash_attn_decode_paged_f32` onto
/// an existing compute encoder.
///
/// Buffer slots:
/// - 0: `q`
/// - 1: `k_cache`
/// - 2: `v_cache`
/// - 3: `o`
/// - 4: `block_tables`
/// - 5: `context_lens`
/// - 6: the parameter block
///
/// The grid is `(q_len, num_heads, num_seqs)` threadgroups of 32×32 threads,
/// and the softmax scale is fixed at `1 / sqrt(head_dim)`.
///
/// Strides passed to the kernel, in elements:
/// - one KV head inside a block spans `block_size * head_dim`;
/// - one block spans `num_kv_heads` such heads;
/// - the per-head Q/O stride is `head_dim` for
///   [`PagedAttnQLayout::TokenMajor`] and `q_len * head_dim` for
///   [`PagedAttnQLayout::HeadMajor`].
///
/// # Panics
/// Panics if `head_dim != 128`, if `num_kv_heads` is zero or does not divide
/// `num_heads`, or if `q_len == 0`.
#[allow(clippy::too_many_arguments)]
pub fn paged_decode_attention_on_encoder(
    &self,
    enc: &metal::ComputeCommandEncoderRef,
    q: &Buffer,
    k_cache: &Buffer,
    v_cache: &Buffer,
    o: &Buffer,
    block_tables: &Buffer,
    context_lens: &Buffer,
    params: &PagedAttnDispatchParams,
) {
    // These checks run in release builds too. They are cheap compared with
    // the dispatch, and an unsupported shape would otherwise reach a kernel
    // that cannot handle it. The divisibility check guards against a zero
    // divisor first.
    assert_eq!(
        params.head_dim, 128,
        "paged_decode_attention currently only supports head_dim=128"
    );
    assert!(
        params.num_kv_heads != 0 && params.num_heads % params.num_kv_heads == 0,
        "GQA: num_heads must be divisible by num_kv_heads"
    );
    assert!(params.q_len >= 1, "q_len must be ≥ 1");
    let (q_head_stride, o_head_stride) = match params.q_layout {
        PagedAttnQLayout::TokenMajor => (params.head_dim as i32, params.head_dim as i32),
        PagedAttnQLayout::HeadMajor => {
            let s = (params.q_len * params.head_dim) as i32;
            (s, s)
        }
    };
    #[repr(C)]
    struct P {
        num_heads: i32,
        num_kv_heads: i32,
        head_dim: i32,
        scale: f32,
        block_size: i32,
        max_num_blocks_per_seq: i32,
        kv_block_stride: i32,
        kv_head_stride: i32,
        q_len: i32,
        q_head_stride: i32,
        o_head_stride: i32,
    }
    let kv_head_stride = (params.block_size * params.head_dim) as i32;
    let kv_block_stride = (params.num_kv_heads as i32) * kv_head_stride;
    let p = P {
        num_heads: params.num_heads as i32,
        num_kv_heads: params.num_kv_heads as i32,
        head_dim: params.head_dim as i32,
        scale: 1.0 / (params.head_dim as f32).sqrt(),
        block_size: params.block_size as i32,
        max_num_blocks_per_seq: params.max_num_blocks_per_seq as i32,
        kv_block_stride,
        kv_head_stride,
        q_len: params.q_len as i32,
        q_head_stride,
        o_head_stride,
    };
    enc.set_compute_pipeline_state(self.pipeline("flash_attn_decode_paged_f32"));
    enc.set_buffer(0, Some(q), 0);
    enc.set_buffer(1, Some(k_cache), 0);
    enc.set_buffer(2, Some(v_cache), 0);
    enc.set_buffer(3, Some(o), 0);
    enc.set_buffer(4, Some(block_tables), 0);
    enc.set_buffer(5, Some(context_lens), 0);
    enc.set_bytes(
        6,
        std::mem::size_of::<P>() as u64,
        (&p as *const P).cast::<c_void>(),
    );
    let grid = MTLSize::new(
        params.q_len as u64,
        params.num_heads as u64,
        params.num_seqs as u64,
    );
    let tg = MTLSize::new(32, 32, 1);
    enc.dispatch_thread_groups(grid, tg);
}
}
/// Memory layout of the query and output tensors given to
/// `MetalPipelines::paged_decode_attention_on_encoder`. It selects the
/// per-head stride passed to the kernel.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PagedAttnQLayout {
/// Stride between consecutive heads is `head_dim` elements.
TokenMajor,
/// Stride between consecutive heads is `q_len * head_dim` elements.
HeadMajor,
}
/// Shape parameters for `MetalPipelines::paged_decode_attention_on_encoder`.
#[derive(Clone, Copy, Debug)]
pub struct PagedAttnDispatchParams {
/// Number of sequences; the grid's depth dimension.
pub num_seqs: usize,
/// Number of query heads; must be a multiple of `num_kv_heads`.
pub num_heads: usize,
/// Number of KV heads stored per cache block.
pub num_kv_heads: usize,
/// Per-head dimension; the kernel currently requires 128.
pub head_dim: usize,
/// Token slots per KV-cache block (a KV head spans `block_size * head_dim`
/// elements inside a block).
pub block_size: usize,
/// Forwarded verbatim to the kernel — presumably the row width of
/// `block_tables`; confirm against the caller.
pub max_num_blocks_per_seq: usize,
/// Query tokens per sequence; the grid's width dimension.
pub q_len: usize,
/// Layout of `q` / `o`, which determines their per-head stride.
pub q_layout: PagedAttnQLayout,
}