use super::{AttnConfig, Backend};
use ferrum_attention::metal::pipelines::MetalPipelines;
use ferrum_attention::AttentionParams;
use metal::Device;
use std::sync::OnceLock;
/// Process-wide Metal state shared by every backend call in this file.
struct MetalState {
    /// Pipeline collection built once from the system default device; it also
    /// exposes the command queue (`pipes.queue`) used to create command buffers.
    pipes: MetalPipelines,
}
// Lazily-initialised global state; populated on the first call to `st()`.
static METAL_STATE: OnceLock<MetalState> = OnceLock::new();
/// Returns the shared [`MetalState`], building it on first use.
///
/// # Panics
///
/// Panics on the first call if the system reports no default Metal device.
fn st() -> &'static MetalState {
    METAL_STATE.get_or_init(|| {
        // Without a device nothing in this backend can work, so fail loudly
        // with a message that names the missing requirement.
        let device = Device::system_default()
            .expect("MetalBackend requires a Metal device, but no system default device was found");
        MetalState {
            pipes: MetalPipelines::new(&device),
        }
    })
}
/// Marker type that carries the Metal implementation of the `Backend` trait.
pub struct MetalBackend;
#[cfg(target_os = "macos")]
extern "C" {
    /// CBLAS single-precision matrix multiply:
    /// `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// `layout`, `trans_a` and `trans_b` take the integer values of the CBLAS
    /// layout/transpose enums; `lda`, `ldb` and `ldc` are the leading
    /// dimensions of the three matrices.
    ///
    /// NOTE(review): the symbol is presumably supplied by Accelerate on
    /// macOS — confirm in the build configuration.
    fn cblas_sgemm(
        layout: i32,
        trans_a: i32,
        trans_b: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );
}
/// Encoding context for the Metal backend.
///
/// Operations record into a single command buffer that is created lazily and
/// submitted by [`MetalContext::flush`].
pub struct MetalContext {
    /// Command buffer currently being recorded, or `None` when nothing is
    /// pending. Stored as an owned (retained) handle so the object stays alive
    /// for as long as the context holds it, instead of extending a borrowed
    /// reference to `'static` with `transmute`.
    cmd: Option<metal::CommandBuffer>,
}

impl MetalContext {
    /// Returns the command buffer being recorded, creating one from the shared
    /// queue on first use after construction or after a flush.
    fn cmd(&mut self) -> &metal::CommandBufferRef {
        let pending = self
            .cmd
            .get_or_insert_with(|| st().pipes.queue.new_command_buffer().to_owned());
        &**pending
    }

    /// Commits the pending command buffer, if any, and waits until it has
    /// completed. Does nothing when no work has been recorded.
    fn flush(&mut self) {
        if let Some(pending) = self.cmd.take() {
            pending.commit();
            pending.wait_until_completed();
        }
    }
}
/// Creates a fresh command buffer on the shared queue, lets `f` record into
/// it, then commits it and waits until it has completed.
fn run(f: impl FnOnce(&metal::CommandBufferRef)) {
    let queue = &st().pipes.queue;
    let buffer = queue.new_command_buffer();

    f(buffer);

    buffer.commit();
    buffer.wait_until_completed();
}
/// Metal implementation of [`Backend`].
///
/// Most operations only *record* work into the context's pending command
/// buffer; the work is submitted when the context is flushed (`sync`, or any
/// operation below that needs CPU access to buffer contents).
impl Backend for MetalBackend {
    type Buffer = metal::Buffer;
    type Context = MetalContext;
    type GptqStore = ();

    /// Creates a context with no command buffer pending.
    fn new_context() -> Self::Context {
        MetalContext { cmd: None }
    }

    /// Submits all recorded work and waits for it to complete.
    fn sync(ctx: &mut Self::Context) {
        ctx.flush();
    }

    /// Matrix multiply writing an `m x n` result into `out`.
    ///
    /// * `m == 1`: records a GEMV kernel into the pending command buffer.
    /// * otherwise: flushes pending GPU work and runs `cblas_sgemm` on the CPU
    ///   directly over the buffers' contents, treating `a` as row-major
    ///   `m x k` and `b` as row-major `n x k` (passed transposed).
    ///
    /// # Panics
    ///
    /// On the CPU path, panics if a dimension does not fit in `i32`, if a
    /// buffer has no CPU-visible contents, or if a buffer is too small for the
    /// requested dimensions.
    fn gemm(
        ctx: &mut Self::Context,
        a: &Self::Buffer,
        b: &Self::Buffer,
        out: &mut Self::Buffer,
        m: usize,
        n: usize,
        k: usize,
    ) {
        if m == 1 {
            let enc = ctx.cmd().new_compute_command_encoder();
            st().pipes.gemv_enc(enc, a, b, out, n, k);
            enc.end_encoding();
            return;
        }

        // Integer values of the CBLAS layout / transpose enums.
        const CBLAS_ROW_MAJOR: i32 = 101;
        const CBLAS_NO_TRANS: i32 = 111;
        const CBLAS_TRANS: i32 = 112;

        // CBLAS reads and writes the buffers from the CPU, so everything
        // recorded so far has to be submitted and finished first.
        ctx.flush();

        // CBLAS takes 32-bit dimensions; refuse to truncate silently.
        let to_i32 = |value: usize, what: &str| -> i32 {
            i32::try_from(value)
                .unwrap_or_else(|_| panic!("gemm: {what} = {value} does not fit in i32"))
        };
        let rows = to_i32(m, "m");
        let cols = to_i32(n, "n");
        let inner = to_i32(k, "k");

        // `length()` is in bytes; compare in u128 so the product cannot wrap.
        let fits = |buf: &metal::Buffer, r: usize, c: usize| -> bool {
            (r as u128) * (c as u128) * (std::mem::size_of::<f32>() as u128)
                <= buf.length() as u128
        };
        assert!(fits(a, m, k), "gemm: `a` is smaller than m*k f32 values");
        assert!(fits(b, n, k), "gemm: `b` is smaller than n*k f32 values");
        assert!(fits(&*out, m, n), "gemm: `out` is smaller than m*n f32 values");

        let a_ptr = a.contents() as *const f32;
        let b_ptr = b.contents() as *const f32;
        let out_ptr = out.contents() as *mut f32;
        assert!(
            !a_ptr.is_null() && !b_ptr.is_null() && !out_ptr.is_null(),
            "gemm: buffer has no CPU-visible contents"
        );

        // SAFETY: the pointers are non-null and each buffer was checked above
        // to hold at least the number of f32 values CBLAS will touch for these
        // dimensions and leading dimensions; pending GPU work was flushed.
        unsafe {
            cblas_sgemm(
                CBLAS_ROW_MAJOR,
                CBLAS_NO_TRANS,
                CBLAS_TRANS,
                rows,
                cols,
                inner,
                1.0,
                a_ptr,
                inner,
                b_ptr,
                inner,
                0.0,
                out_ptr,
                cols,
            );
        }
    }

    /// Records an RMS-norm kernel reading `x` and weights `w` and writing
    /// `out`; `tokens`, `dim` and `eps` are forwarded to the pipeline.
    fn rms_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes.rms_norm_enc(enc, x, w, out, tokens, dim, eps);
        enc.end_encoding();
    }

    /// Records the fused residual-add + norm kernel.
    ///
    /// `residual` is passed to the pipeline both as an input and as the
    /// residual output, so it is updated in place; the normalised result goes
    /// to `out`.
    fn fused_add_rms_norm(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        // NOTE(review): the `None` argument and the trailing `0` are passed
        // through unchanged; their meaning is defined by
        // `fused_residual_norm_enc` — confirm against that signature.
        st().pipes.fused_residual_norm_enc(
            enc, residual, x, None, w, residual, out, tokens, dim, eps, 0,
        );
        enc.end_encoding();
    }

    /// Records a flash-attention pass over `q`, `k`, `v` into `out`.
    ///
    /// The kernel parameters are assembled from the call arguments and `cfg`;
    /// `cfg.kv_seq_stride` is forwarded separately. The pipeline receives the
    /// command buffer itself rather than an encoder.
    fn flash_attention(
        ctx: &mut Self::Context,
        q: &Self::Buffer,
        k: &Self::Buffer,
        v: &Self::Buffer,
        out: &mut Self::Buffer,
        batch: usize,
        q_len: usize,
        kv_len: usize,
        pos_offset: usize,
        cfg: &AttnConfig,
    ) {
        let params = AttentionParams {
            batch,
            num_heads: cfg.num_heads,
            num_kv_heads: cfg.num_kv_heads,
            q_len,
            kv_len,
            head_dim: cfg.head_dim,
            causal: cfg.causal,
            pos_offset,
            sliding_window: cfg.sliding_window,
        };
        let cmd = ctx.cmd();
        st().pipes
            .flash_attn_v2(cmd, q, k, v, out, &params, cfg.kv_seq_stride);
    }

    /// Records a blit copy of `len` f32 values from `src` (starting at element
    /// `src_offset`) to `dst` (starting at element `dst_offset`).
    fn copy_slice(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        src_offset: usize,
        dst: &mut Self::Buffer,
        dst_offset: usize,
        len: usize,
    ) {
        // Offsets and length arrive in elements; the blit encoder wants bytes.
        let f32_bytes = std::mem::size_of::<f32>();
        let blit = ctx.cmd().new_blit_command_encoder();
        blit.copy_from_buffer(
            src,
            (src_offset * f32_bytes) as u64,
            dst,
            (dst_offset * f32_bytes) as u64,
            (len * f32_bytes) as u64,
        );
        blit.end_encoding();
    }

    /// Copies one `dim`-wide row of `table` per entry of `ids` into `out`, on
    /// the CPU, after flushing any pending GPU work.
    ///
    /// Row `i` of `out` receives row `ids[i]` of `table`.
    ///
    /// # Panics
    ///
    /// Panics if `out` is too small for `ids.len() * dim` f32 values, if a
    /// buffer has no CPU-visible contents, or if an id addresses a row beyond
    /// the end of `table`.
    fn embedding_lookup(
        ctx: &mut Self::Context,
        table: &Self::Buffer,
        ids: &[u32],
        out: &mut Self::Buffer,
        dim: usize,
    ) {
        // The copy goes through the buffers' CPU pointers, so recorded GPU
        // work must be finished before touching them.
        ctx.flush();
        if dim == 0 || ids.is_empty() {
            return;
        }

        let f32_bytes = std::mem::size_of::<f32>();
        let table_elems = table.length() as usize / f32_bytes;
        let out_capacity = out.length() as usize / f32_bytes;
        let out_elems = ids.len() * dim;
        assert!(
            out_capacity >= out_elems,
            "embedding_lookup: output buffer holds {out_capacity} f32 values but {out_elems} are needed"
        );

        let table_ptr = table.contents() as *const f32;
        let out_ptr = out.contents() as *mut f32;
        assert!(
            !table_ptr.is_null() && !out_ptr.is_null(),
            "embedding_lookup: buffer has no CPU-visible contents"
        );

        // SAFETY: both pointers are non-null; `table_elems` is derived from
        // the table's own byte length and `out_elems` was checked against the
        // output's byte length. Assumes `table` and `out` are distinct
        // buffers — TODO confirm no caller aliases them.
        let (rows, dst) = unsafe {
            (
                std::slice::from_raw_parts(table_ptr, table_elems),
                std::slice::from_raw_parts_mut(out_ptr, out_elems),
            )
        };

        for (dst_row, &id) in dst.chunks_exact_mut(dim).zip(ids) {
            let start = id as usize * dim;
            let src_row = rows.get(start..start + dim).unwrap_or_else(|| {
                panic!(
                    "embedding_lookup: token id {id} is out of range for a table of {} rows",
                    table_elems / dim
                )
            });
            dst_row.copy_from_slice(src_row);
        }
    }

    /// Records the kernel that splits the packed `qkv` buffer into `q`, `k`
    /// and `v`; `tokens`, `q_dim` and `kv_dim` are forwarded to the pipeline.
    fn split_qkv(
        ctx: &mut Self::Context,
        qkv: &Self::Buffer,
        q: &mut Self::Buffer,
        k: &mut Self::Buffer,
        v: &mut Self::Buffer,
        tokens: usize,
        q_dim: usize,
        kv_dim: usize,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes
            .split_qkv_enc(enc, qkv, q, k, v, tokens, q_dim, kv_dim);
        enc.end_encoding();
    }

    /// Records the fused SiLU-multiply-split kernel reading `gu` and writing
    /// `out`; `tokens` and `im` are forwarded to the pipeline.
    fn fused_silu_mul_split(
        ctx: &mut Self::Context,
        gu: &Self::Buffer,
        out: &mut Self::Buffer,
        tokens: usize,
        im: usize,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes.silu_mul_split_enc(enc, gu, out, tokens, im);
        enc.end_encoding();
    }

    /// Records the combined QK-norm + RoPE kernel from `input` to `output`.
    ///
    /// All scalar arguments, including `mode`, are forwarded unchanged; the
    /// accepted `mode` values are defined by the pipeline.
    fn qk_norm_rope(
        ctx: &mut Self::Context,
        input: &Self::Buffer,
        norm_w: &Self::Buffer,
        cos: &Self::Buffer,
        sin: &Self::Buffer,
        output: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        head_dim: usize,
        pos_offset: usize,
        eps: f32,
        mode: i32,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes.qk_norm_rope(
            enc, input, norm_w, cos, sin, output, tokens, heads, head_dim, pos_offset, eps, mode,
        );
        enc.end_encoding();
    }

    /// Records the append of `new_tokens` entries to the K and V caches.
    ///
    /// Both appends share one compute encoder: K first, then V, with
    /// identical geometry arguments. In debug builds, asserts that the append
    /// fits within `cache_capacity`.
    fn kv_cache_append_head_major(
        ctx: &mut Self::Context,
        cache_k: &mut Self::Buffer,
        cache_v: &mut Self::Buffer,
        cache_len: usize,
        cache_capacity: usize,
        new_k_head_major: &Self::Buffer,
        new_v_head_major: &Self::Buffer,
        new_tokens: usize,
        nkv: usize,
        hd: usize,
    ) {
        debug_assert!(cache_len + new_tokens <= cache_capacity);
        let enc = ctx.cmd().new_compute_command_encoder();
        let pipes = &st().pipes;
        pipes.kv_cache_append(
            enc,
            new_k_head_major,
            cache_k,
            nkv,
            hd,
            cache_len,
            new_tokens,
            cache_capacity,
        );
        pipes.kv_cache_append(
            enc,
            new_v_head_major,
            cache_v,
            nkv,
            hd,
            cache_len,
            new_tokens,
            cache_capacity,
        );
        enc.end_encoding();
    }

    /// Records the transpose kernel from `src` to `dst`; `tokens`, `heads`
    /// and `dim` are forwarded to the pipeline.
    fn transpose_head_to_token(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        dst: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        dim: usize,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes.transpose_out(enc, src, dst, tokens, heads, dim);
        enc.end_encoding();
    }

    /// Records the bias-add kernel over `data` using `bias`; `rows` and
    /// `cols` are forwarded to the pipeline.
    fn add_bias(
        ctx: &mut Self::Context,
        data: &mut Self::Buffer,
        bias: &Self::Buffer,
        rows: usize,
        cols: usize,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes.add_bias_enc(enc, data, bias, rows, cols);
        enc.end_encoding();
    }

    /// Records a layer-norm kernel reading `x`, `gamma` and `beta` and
    /// writing `out`; `tokens`, `dim` and `eps` are forwarded to the pipeline.
    fn layer_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        gamma: &Self::Buffer,
        beta: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    ) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes
            .layer_norm_enc(enc, x, gamma, beta, out, tokens, dim, eps);
        enc.end_encoding();
    }

    /// Records a GELU kernel over `len` elements from `x` into `out`.
    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes.gelu_enc(enc, x, out, len);
        enc.end_encoding();
    }

    /// Records an element-wise add over `len` elements where `r` is both the
    /// first operand and the destination, so `r` is updated in place.
    fn add_inplace(ctx: &mut Self::Context, r: &mut Self::Buffer, x: &Self::Buffer, len: usize) {
        let enc = ctx.cmd().new_compute_command_encoder();
        st().pipes.add_enc(enc, r, x, r, len);
        enc.end_encoding();
    }

    /// Allocates a buffer through the shared pipelines' `buffer_empty(len)`.
    fn alloc(len: usize) -> Self::Buffer {
        st().pipes.buffer_empty(len)
    }

    /// Reads `len` values from `buf` into a `Vec<f32>`.
    ///
    /// NOTE(review): this does not flush any context; callers presumably
    /// need to `sync` first to observe recorded GPU work — confirm.
    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
        MetalPipelines::read_buffer(buf, len)
    }

    /// Creates a buffer initialised from `data`.
    fn from_slice(data: &[f32]) -> Self::Buffer {
        st().pipes.buffer_from_data(data)
    }
}