use super::pipelines::MetalPipelines;
use crate::attention::AttentionParams;
use metal::*;
/// GPU-resident weights for a single transformer layer.
///
/// Every buffer here is consumed by `metal_layer_forward_v2`; the per-field
/// notes describe which kernel call each one is handed to. Element type and
/// memory layout are defined by the kernels in `MetalPipelines` and are not
/// visible from this file.
pub struct MetalLayerWeights {
    /// Norm weight passed to `rms_norm_enc` for the layer input.
    pub input_ln_w: Buffer,

    /// Weight of the GEMM that produces the Q projection
    /// (called with dims `tokens, num_heads * head_dim, hidden_size`).
    pub q_proj_w: Buffer,
    /// Weight of the GEMM that produces the K projection
    /// (called with dims `tokens, num_kv_heads * head_dim, hidden_size`).
    pub k_proj_w: Buffer,
    /// Weight of the GEMM that produces the V projection
    /// (same dims as the K projection).
    pub v_proj_w: Buffer,
    /// Weight of the GEMM applied to the flattened attention output
    /// (called with dims `tokens, hidden_size, num_heads * head_dim`).
    pub o_proj_w: Buffer,

    /// Weight handed to `qk_norm_rope` when processing Q.
    pub q_norm_w: Buffer,
    /// Weight handed to `qk_norm_rope` when processing K. The forward pass
    /// also passes it for the V call, which runs with mode `0`.
    pub k_norm_w: Buffer,

    /// Norm weight passed to `fused_residual_norm_enc` after attention.
    pub post_ln_w: Buffer,

    /// Weight of the MLP "gate" GEMM (`tokens, intermediate_size, hidden_size`).
    pub gate_proj_w: Buffer,
    /// Weight of the MLP "up" GEMM (`tokens, intermediate_size, hidden_size`).
    pub up_proj_w: Buffer,
    /// Weight of the MLP "down" GEMM (`tokens, hidden_size, intermediate_size`).
    pub down_proj_w: Buffer,

    /// Selects the `qk_norm_rope` mode for Q and K: `1` when true, `2` otherwise.
    pub has_qk_norm: bool,

    /// Optional scale buffer forwarded to `fused_residual_norm_enc`.
    pub attn_scale: Option<Buffer>,
    /// Optional scale buffer for the final residual: when present the forward
    /// pass uses `fused_scale_add_enc`, otherwise a plain `add_enc`.
    pub mlp_scale: Option<Buffer>,
}
/// Shape and numeric hyper-parameters shared by every layer of the model.
///
/// This is plain data (only `usize` and `f32` fields), so it derives the
/// usual value-type traits: it can be logged with `{:?}`, copied freely and
/// compared in tests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MetalTransformerConfig {
    /// Width of a hidden-state row; used as the row length for the norms and
    /// as the inner/outer dimension of the projection GEMMs.
    pub hidden_size: usize,
    /// Width of the MLP intermediate activations (gate/up outputs).
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Epsilon forwarded to the RMS-norm kernels.
    pub rms_norm_eps: f32,
}
/// Key/value cache for one layer, backed by two GPU buffers.
pub struct MetalKvCache {
    /// Cached keys; allocated by `MetalKvCache::new` with
    /// `nkv * max_len * hd` elements.
    pub k_buf: Buffer,
    /// Cached values; same allocation size as `k_buf`.
    pub v_buf: Buffer,
    /// Number of positions currently stored. The forward pass advances it by
    /// the number of tokens it appends; `reset` sets it back to zero.
    pub len: usize,
    /// Capacity in positions that the buffers were sized for.
    pub max_len: usize,
}
impl MetalKvCache {
pub fn new(pipes: &MetalPipelines, nkv: usize, hd: usize, max_len: usize) -> Self {
let size = nkv * max_len * hd;
Self {
k_buf: pipes.buffer_empty(size),
v_buf: pipes.buffer_empty(size),
len: 0,
max_len,
}
}
pub fn reset(&mut self) {
self.len = 0;
}
}
/// Intermediate buffers used by one invocation of `metal_layer_forward_v2`.
///
/// The buffers are listed in the order the forward pass fills them.
pub struct LayerScratch {
    /// Output of the input RMS-norm.
    pub ln_out: Buffer,

    /// Raw Q projection (GEMM output).
    pub q_buf: Buffer,
    /// Raw K projection (GEMM output).
    pub k_buf: Buffer,
    /// Raw V projection (GEMM output).
    pub v_buf: Buffer,

    /// Q after `qk_norm_rope`; this is what attention reads.
    pub q_ready: Buffer,
    /// K after `qk_norm_rope`; appended to the K cache.
    pub k_ready: Buffer,
    /// V after `qk_norm_rope` (mode `0`); appended to the V cache.
    pub v_ready: Buffer,

    /// Attention output as written by `flash_attn_v2`.
    pub attn_out: Buffer,
    /// Attention output after `transpose_out`; input of the O projection.
    pub attn_flat: Buffer,
    /// Output of the O projection.
    pub o_out: Buffer,

    /// First output of `fused_residual_norm_enc`; later used as the residual
    /// operand of the final add.
    pub hidden: Buffer,
    /// Second output of `fused_residual_norm_enc`; input of the MLP GEMMs.
    pub post_ln: Buffer,

    /// Output of the MLP gate projection.
    pub gate_buf: Buffer,
    /// Output of the MLP up projection.
    pub up_buf: Buffer,
    /// Output of `silu_mul_enc` over `gate_buf` and `up_buf`.
    pub silu_out: Buffer,
    /// Output of the MLP down projection.
    pub mlp_out: Buffer,

    /// Final result of the layer.
    pub output: Buffer,
}
impl LayerScratch {
pub fn new(
pipes: &MetalPipelines,
tokens: usize,
h: usize,
im: usize,
nh: usize,
nkv: usize,
hd: usize,
) -> Self {
Self {
ln_out: pipes.buffer_empty(tokens * h),
q_buf: pipes.buffer_empty(tokens * nh * hd),
k_buf: pipes.buffer_empty(tokens * nkv * hd),
v_buf: pipes.buffer_empty(tokens * nkv * hd),
q_ready: pipes.buffer_empty(nh * tokens * hd),
k_ready: pipes.buffer_empty(nkv * tokens * hd),
v_ready: pipes.buffer_empty(nkv * tokens * hd),
attn_out: pipes.buffer_empty(nh * tokens * hd),
attn_flat: pipes.buffer_empty(tokens * nh * hd),
o_out: pipes.buffer_empty(tokens * h),
hidden: pipes.buffer_empty(tokens * h),
post_ln: pipes.buffer_empty(tokens * h),
gate_buf: pipes.buffer_empty(tokens * im),
up_buf: pipes.buffer_empty(tokens * im),
silu_out: pipes.buffer_empty(tokens * im),
mlp_out: pipes.buffer_empty(tokens * h),
output: pipes.buffer_empty(tokens * h),
}
}
}
/// Encodes one full transformer layer (attention block followed by MLP block)
/// into `cmd`.
///
/// The function only records work: it opens and closes compute command
/// encoders on `cmd` (and hands `cmd` itself to `flash_attn_v2`), but never
/// commits or waits on the command buffer. The layer result is written to
/// `s.output`.
///
/// # Parameters
/// * `cmd` – command buffer that receives every pass.
/// * `pipes` – kernel collection used to encode each pass.
/// * `input` – layer input; read by the input norm and again as the residual
///   operand of `fused_residual_norm_enc`.
/// * `tokens` – number of token rows processed in this call.
/// * `w` – weights of this layer.
/// * `cfg` – model dimensions and norm epsilon.
/// * `kv_cache` – cache for this layer; the new K/V rows are appended at
///   `kv_cache.len`, and `kv_cache.len` is advanced by `tokens`.
/// * `pos_offset` – forwarded to `qk_norm_rope` and to `AttentionParams`.
/// * `cos_buf`, `sin_buf` – tables forwarded to `qk_norm_rope`
///   (presumably rotary-embedding cos/sin — confirm against the kernel).
/// * `s` – scratch buffers; NOTE(review): they must have been allocated for at
///   least `tokens` rows, which cannot be checked here.
///
/// # Panics
/// In debug builds, panics if appending `tokens` positions would exceed
/// `kv_cache.max_len`.
pub fn metal_layer_forward_v2(
    cmd: &CommandBufferRef,
    pipes: &MetalPipelines,
    input: &Buffer,
    tokens: usize,
    w: &MetalLayerWeights,
    cfg: &MetalTransformerConfig,
    kv_cache: &mut MetalKvCache,
    pos_offset: usize,
    cos_buf: &Buffer,
    sin_buf: &Buffer,
    s: &LayerScratch,
) {
    let h = cfg.hidden_size;
    let nh = cfg.num_heads;
    let nkv = cfg.num_kv_heads;
    let hd = cfg.head_dim;
    let im = cfg.intermediate_size;

    // The cache buffers are sized from `max_len`, and attention is told to
    // read `len + tokens` positions, so the append must fit.
    debug_assert!(
        kv_cache.len + tokens <= kv_cache.max_len,
        "KV cache overflow: len {} + tokens {} exceeds max_len {}",
        kv_cache.len,
        tokens,
        kv_cache.max_len
    );

    // ---- Attention block -------------------------------------------------

    // Input RMS-norm: input -> ln_out.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.rms_norm_enc(
            enc,
            input,
            &w.input_ln_w,
            &s.ln_out,
            tokens,
            h,
            cfg.rms_norm_eps,
        );
        enc.end_encoding();
    }

    // Q/K/V projections share one encoder; all three read ln_out.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.gemm_v2(enc, &s.ln_out, &w.q_proj_w, &s.q_buf, tokens, nh * hd, h);
        pipes.gemm_v2(enc, &s.ln_out, &w.k_proj_w, &s.k_buf, tokens, nkv * hd, h);
        pipes.gemm_v2(enc, &s.ln_out, &w.v_proj_w, &s.v_buf, tokens, nkv * hd, h);
        enc.end_encoding();
    }

    // Per-head post-processing of Q, K and V through the same kernel.
    // Q and K use mode 1 (layer has QK-norm weights) or mode 2 (it does not);
    // V always uses mode 0. The exact meaning of each mode is defined by the
    // kernel — presumably mode 0 performs only the layout change; confirm.
    {
        let enc = cmd.new_compute_command_encoder();
        let qk_mode: i32 = if w.has_qk_norm { 1 } else { 2 };
        pipes.qk_norm_rope(
            enc,
            &s.q_buf,
            &w.q_norm_w,
            cos_buf,
            sin_buf,
            &s.q_ready,
            tokens,
            nh,
            hd,
            pos_offset,
            cfg.rms_norm_eps,
            qk_mode,
        );
        pipes.qk_norm_rope(
            enc,
            &s.k_buf,
            &w.k_norm_w,
            cos_buf,
            sin_buf,
            &s.k_ready,
            tokens,
            nkv,
            hd,
            pos_offset,
            cfg.rms_norm_eps,
            qk_mode,
        );
        // `k_norm_w` is passed here as well; presumably it only fills the
        // weight slot and is ignored in mode 0 — confirm against the kernel.
        pipes.qk_norm_rope(
            enc,
            &s.v_buf,
            &w.k_norm_w,
            cos_buf,
            sin_buf,
            &s.v_ready,
            tokens,
            nkv,
            hd,
            pos_offset,
            cfg.rms_norm_eps,
            0,
        );
        enc.end_encoding();
    }

    // Append the new K and V rows to the cache, starting at the current
    // fill level.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.kv_cache_append(
            enc,
            &s.k_ready,
            &kv_cache.k_buf,
            nkv,
            hd,
            kv_cache.len,
            tokens,
            kv_cache.max_len,
        );
        pipes.kv_cache_append(
            enc,
            &s.v_ready,
            &kv_cache.v_buf,
            nkv,
            hd,
            kv_cache.len,
            tokens,
            kv_cache.max_len,
        );
        enc.end_encoding();
    }

    // Advance the fill level before attention so that it covers the rows
    // that were just appended.
    let kv_len = kv_cache.len + tokens;
    kv_cache.len = kv_len;

    // Attention over the whole cache. `flash_attn_v2` takes the command
    // buffer itself rather than an encoder.
    {
        let params = AttentionParams {
            batch: 1,
            num_heads: nh,
            num_kv_heads: nkv,
            q_len: tokens,
            kv_len,
            head_dim: hd,
            // Causal masking is requested only when more than one query
            // token is processed in this call.
            causal: tokens > 1,
            pos_offset,
            sliding_window: 0,
        };
        pipes.flash_attn_v2(
            cmd,
            &s.q_ready,
            &kv_cache.k_buf,
            &kv_cache.v_buf,
            &s.attn_out,
            &params,
            kv_cache.max_len,
        );
    }

    // Rearrange the attention output: attn_out -> attn_flat.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.transpose_out(enc, &s.attn_out, &s.attn_flat, tokens, nh, hd);
        enc.end_encoding();
    }

    // Output projection: attn_flat -> o_out.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.gemm_v2(enc, &s.attn_flat, &w.o_proj_w, &s.o_out, tokens, h, nh * hd);
        enc.end_encoding();
    }

    // Fused residual + norm: combines `input` and `o_out` (with the optional
    // attention scale) and writes two outputs, `hidden` and `post_ln`.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.fused_residual_norm_enc(
            enc,
            input,
            &s.o_out,
            w.attn_scale.as_ref(),
            &w.post_ln_w,
            &s.hidden,
            &s.post_ln,
            tokens,
            h,
            cfg.rms_norm_eps,
            h,
        );
        enc.end_encoding();
    }

    // ---- MLP block -------------------------------------------------------

    // Gate and up projections share one encoder; both read post_ln.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.gemm_v2(enc, &s.post_ln, &w.gate_proj_w, &s.gate_buf, tokens, im, h);
        pipes.gemm_v2(enc, &s.post_ln, &w.up_proj_w, &s.up_buf, tokens, im, h);
        enc.end_encoding();
    }

    // Element-wise combine of gate and up over tokens * im elements.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.silu_mul_enc(enc, &s.gate_buf, &s.up_buf, &s.silu_out, tokens * im);
        enc.end_encoding();
    }

    // Down projection: silu_out -> mlp_out.
    {
        let enc = cmd.new_compute_command_encoder();
        pipes.gemm_v2(enc, &s.silu_out, &w.down_proj_w, &s.mlp_out, tokens, h, im);
        enc.end_encoding();
    }

    // Final residual: hidden (+ optionally scaled) mlp_out -> output.
    {
        let enc = cmd.new_compute_command_encoder();
        if let Some(scale) = &w.mlp_scale {
            pipes.fused_scale_add_enc(enc, &s.hidden, &s.mlp_out, scale, &s.output, tokens * h, h);
        } else {
            pipes.add_enc(enc, &s.hidden, &s.mlp_out, &s.output, tokens * h);
        }
        enc.end_encoding();
    }
}