use metal::{
Buffer, Device, MTLResourceOptions, MTLSize, NSUInteger,
};
use crate::riir::backend::gpu::gpu_matvec::{
encode_matvec, MatvecPipelines, MatvecSpec,
};
use crate::riir::attn::gpu_mla::{
encode_mla_kv_cache_append, encode_mla_out_per_head_4bit,
encode_mla_q_prime_4bit, encode_mla_sdpa_folded,
encode_mla_split_q_kv, GpuMlaError, MlaPipelines,
};
use crate::riir::backend::gpu::gpu_norm::{
encode_rms_norm_bf16_into, RmsNormBf16Pipelines,
};
use crate::riir::attn::gpu_rope::encode_yarn_rope_apply;
use crate::riir::backend::gpu::metal::{MetalContext, MetalError};
use crate::riir::io::mtl_weight_buf::MtlWeightBuf;
use crate::riir::snapshot::state::MlaKvCacheGpu;
use crate::riir::variants::{RMS_NORM_EPS, VARIANT};
use crate::riir::io::weight_file::WeightFile;
/// Scratch buffers used by `mla_attn_layer_forward_gpu`.
///
/// Every buffer holds `f32` elements, is allocated with
/// `MTLResourceOptions::StorageModeShared`, and is zero-filled when created.
/// Element counts are derived from the compile-time [`VARIANT`] dimensions.
pub struct MlaForwardBuffers {
    /// `q_lora_rank` elements: `q_a_proj` output, RMS-normalized in place.
    pub q_lat: Buffer,
    /// `num_attn_heads * (qk_nope_head_dim + qk_rope_head_dim)` elements: `q_b_proj` output.
    pub q_full: Buffer,
    /// `num_attn_heads * qk_nope_head_dim` elements split out of `q_full`.
    pub q_nope: Buffer,
    /// `num_attn_heads * qk_rope_head_dim` elements split out of `q_full`; YaRN RoPE is applied in place.
    pub q_pe: Buffer,
    /// `kv_lora_rank + qk_rope_head_dim` elements: `kv_a_proj_with_mqa` output.
    pub kv_pre: Buffer,
    /// `kv_lora_rank` elements split out of `kv_pre`, RMS-normalized in place.
    pub kv_lat: Buffer,
    /// `qk_rope_head_dim` elements split out of `kv_pre`; YaRN RoPE is applied in place.
    pub k_pe: Buffer,
    /// `num_attn_heads * kv_lora_rank` elements: output of the q-prime kernel.
    pub q_prime: Buffer,
    /// `num_attn_heads * kv_lora_rank` elements: output of the folded SDPA kernel.
    pub v_combine: Buffer,
    /// `num_attn_heads * v_head_dim` elements: per-head attention output fed to `o_proj`.
    pub out_per_head: Buffer,
    /// `hidden_dim` elements: layer input read by `q_a_proj` and `kv_a_proj_with_mqa`.
    /// The forward pass only reads it, so the caller must fill it.
    pub pre_norm: Buffer,
    /// `hidden_dim` elements: `o_proj` output (the layer's attention result).
    pub out: Buffer,
    /// Single-element scratch passed to the RMS norm over `q_lat`.
    pub q_a_sum_sq: Buffer,
    /// Single-element scratch passed to the RMS norm over `kv_lat`.
    pub kv_lat_sum_sq: Buffer,
}

impl MlaForwardBuffers {
    /// Allocates every scratch buffer on `device` and zero-fills it.
    pub fn new(device: &Device) -> Self {
        let dims = VARIANT;
        let heads = dims.num_attn_heads;
        let head_qk = dims.qk_nope_head_dim + dims.qk_rope_head_dim;

        // Allocate `count` f32 slots in shared storage and clear them to zero.
        let zeroed_f32 = |count: usize| {
            let byte_len = count * std::mem::size_of::<f32>();
            let buffer = device.new_buffer(
                byte_len as NSUInteger,
                MTLResourceOptions::StorageModeShared,
            );
            // SAFETY: `buffer` was just created with `StorageModeShared` and a
            // length of `byte_len` bytes, so `contents()` points to at least
            // `byte_len` CPU-writable bytes. No other reference to the new
            // buffer exists yet.
            unsafe {
                std::ptr::write_bytes(buffer.contents() as *mut u8, 0, byte_len);
            }
            buffer
        };

        Self {
            q_lat: zeroed_f32(dims.q_lora_rank),
            q_full: zeroed_f32(heads * head_qk),
            q_nope: zeroed_f32(heads * dims.qk_nope_head_dim),
            q_pe: zeroed_f32(heads * dims.qk_rope_head_dim),
            kv_pre: zeroed_f32(dims.kv_lora_rank + dims.qk_rope_head_dim),
            kv_lat: zeroed_f32(dims.kv_lora_rank),
            k_pe: zeroed_f32(dims.qk_rope_head_dim),
            q_prime: zeroed_f32(heads * dims.kv_lora_rank),
            v_combine: zeroed_f32(heads * dims.kv_lora_rank),
            out_per_head: zeroed_f32(heads * dims.v_head_dim),
            pre_norm: zeroed_f32(dims.hidden_dim),
            out: zeroed_f32(dims.hidden_dim),
            q_a_sum_sq: zeroed_f32(1),
            kv_lat_sum_sq: zeroed_f32(1),
        }
    }
}
/// YaRN rotary-embedding tables built once from the compile-time [`VARIANT`].
pub struct MlaYarnTables {
    /// Output of `compute_yarn_inv_freq`, uploaded as `f32` values into a
    /// `StorageModeShared` buffer.
    pub inv_freq: Buffer,
    /// Result of `yarn_get_mscale_full(yarn_factor, yarn_mscale, yarn_mscale_all_dim)`.
    pub mscale: f32,
}

impl MlaYarnTables {
    /// Computes the YaRN inverse frequencies and mscale, and uploads the
    /// frequencies to `device`.
    pub fn new(device: &Device) -> Self {
        use crate::riir::attn::rope::{compute_yarn_inv_freq, yarn_get_mscale_full};
        use crate::riir::variants::ROPE_THETA;

        let cfg = VARIANT;
        let freqs = compute_yarn_inv_freq(
            cfg.qk_rope_head_dim,
            ROPE_THETA,
            cfg.yarn_factor,
            cfg.yarn_original_max_pos as f32,
            cfg.yarn_beta_fast,
            cfg.yarn_beta_slow,
        );
        let mscale = yarn_get_mscale_full(cfg.yarn_factor, cfg.yarn_mscale, cfg.yarn_mscale_all_dim);

        // Byte length of the frequency table, sized as f32 elements.
        let byte_len = freqs.len() * std::mem::size_of::<f32>();
        let inv_freq = device.new_buffer_with_data(
            freqs.as_ptr().cast(),
            byte_len as NSUInteger,
            MTLResourceOptions::StorageModeShared,
        );

        Self { inv_freq, mscale }
    }
}
/// Compute pipelines needed by `mla_attn_layer_forward_gpu`.
pub struct MlaForwardPipelines {
    /// MLA kernels: Q/KV split, KV-cache append, q-prime, folded SDPA and per-head output.
    pub mla: MlaPipelines,
    /// Matrix-vector pipelines used for the attention projections.
    pub matvec: MatvecPipelines,
    /// RMS-norm pipelines (`RmsNormBf16Pipelines`).
    pub norms: RmsNormBf16Pipelines,
    /// The `yarn_rope_apply` kernel.
    pub yarn_rope: metal::ComputePipelineState,
}

impl MlaForwardPipelines {
    /// Fetches every required pipeline from `metal`.
    ///
    /// # Errors
    ///
    /// Returns the first [`MetalError`] raised while fetching a pipeline.
    pub fn new(metal: &mut MetalContext) -> Result<Self, MetalError> {
        let mla = MlaPipelines::fetch(metal)?;
        let matvec = MatvecPipelines::fetch(metal)?;
        let norms = RmsNormBf16Pipelines::fetch(metal)?;
        let yarn_rope = metal.pipeline("yarn_rope_apply")?.clone();
        Ok(Self {
            mla,
            matvec,
            norms,
            yarn_rope,
        })
    }
}
#[derive(Debug, thiserror::Error)]
pub enum MlaForwardGpuError {
#[error("MLA only valid on MLA variants (this build's attn_kind is {kind:?})")]
NotMlaVariant { kind: crate::riir::variants::AttnKind },
#[error("kv_cache.len {len} would exceed MAX_SEQ_LEN={max} after append")]
CacheFull { len: i32, max: usize },
#[error("pos {pos} != kv_cache.len {cache_len} (single-step decode)")]
PosMismatch { pos: i32, cache_len: i32 },
#[error("kv_cache buffers not allocated (call ensure_buffers first)")]
CacheNotReady,
#[error("Metal weight tensor: {name}")]
MissingTensor { name: String },
#[error("Metal: {0}")]
Metal(#[from] MetalError),
#[error("MLA dispatch: {0}")]
Mla(#[from] GpuMlaError),
}
#[allow(clippy::too_many_arguments)]
pub fn mla_attn_layer_forward_gpu(
metal: &mut MetalContext,
pipes: &MlaForwardPipelines,
wf: &WeightFile,
wf_buf: &MtlWeightBuf,
yarn: &MlaYarnTables,
bufs: &mut MlaForwardBuffers,
kv_cache: &mut MlaKvCacheGpu,
layer_idx: usize,
pos: i32,
) -> Result<(), MlaForwardGpuError> {
use crate::riir::variants::AttnKind;
if VARIANT.attn_kind != AttnKind::Mla {
return Err(MlaForwardGpuError::NotMlaVariant {
kind: VARIANT.attn_kind,
});
}
if pos != kv_cache.len {
return Err(MlaForwardGpuError::PosMismatch {
pos,
cache_len: kv_cache.len,
});
}
if (kv_cache.len as usize) >= crate::riir::variants::MAX_SEQ_LEN {
return Err(MlaForwardGpuError::CacheFull {
len: kv_cache.len,
max: crate::riir::variants::MAX_SEQ_LEN,
});
}
let latent_buf =
kv_cache.latent_cache.as_ref().ok_or(MlaForwardGpuError::CacheNotReady)?;
let rope_k_buf =
kv_cache.rope_k_cache.as_ref().ok_or(MlaForwardGpuError::CacheNotReady)?;
let v = VARIANT;
let hidden_dim = v.hidden_dim as u32;
let q_lora_rank = v.q_lora_rank as u32;
let kv_lora_rank = v.kv_lora_rank as u32;
let nope = v.qk_nope_head_dim as u32;
let rope_dim = v.qk_rope_head_dim as u32;
let qk_head_dim = nope + rope_dim;
let v_head_dim = v.v_head_dim as u32;
let num_heads = v.num_attn_heads as u32;
let kv_b_per_head = nope + v_head_dim;
let resolve_proj = |name: &str| -> Result<(u64, u64, u64), MlaForwardGpuError> {
let w = format!("{name}.weight");
let s = format!("{name}.scales");
let b = format!("{name}.biases");
let w_off = wf_buf
.tensor_offset(wf, &w)
.map_err(|_| MlaForwardGpuError::MissingTensor { name: w.clone() })?
.ok_or(MlaForwardGpuError::MissingTensor { name: w })?;
let s_off = wf_buf
.tensor_offset(wf, &s)
.map_err(|_| MlaForwardGpuError::MissingTensor { name: s.clone() })?
.ok_or(MlaForwardGpuError::MissingTensor { name: s })?;
let b_off = wf_buf
.tensor_offset(wf, &b)
.map_err(|_| MlaForwardGpuError::MissingTensor { name: b.clone() })?
.ok_or(MlaForwardGpuError::MissingTensor { name: b })?;
Ok((w_off, s_off, b_off))
};
let resolve_norm = |name: &str| -> Result<u64, MlaForwardGpuError> {
let n = format!("{name}.weight");
wf_buf
.tensor_offset(wf, &n)
.map_err(|_| MlaForwardGpuError::MissingTensor { name: n.clone() })?
.ok_or(MlaForwardGpuError::MissingTensor { name: n })
};
let layer_prefix = format!("model.layers.{layer_idx}.self_attn");
let q_a_off = resolve_proj(&format!("{layer_prefix}.q_a_proj"))?;
let q_a_norm_off = resolve_norm(&format!("{layer_prefix}.q_a_layernorm"))?;
let q_b_off = resolve_proj(&format!("{layer_prefix}.q_b_proj"))?;
let kv_a_off = resolve_proj(&format!("{layer_prefix}.kv_a_proj_with_mqa"))?;
let kv_a_norm_off = resolve_norm(&format!("{layer_prefix}.kv_a_layernorm"))?;
let kv_b_off = resolve_proj(&format!("{layer_prefix}.kv_b_proj"))?;
let o_off = resolve_proj(&format!("{layer_prefix}.o_proj"))?;
let pipe_qprime = pipes.mla.q_prime.clone();
let pipe_sdpa = pipes.mla.sdpa.clone();
let pipe_outhead = pipes.mla.out_per_head.clone();
let pipe_split = pipes.mla.split_q_kv.clone();
let pipe_cache_append = pipes.mla.cache_append.clone();
let pipe_yarn = pipes.yarn_rope.clone();
let queue = metal.queue();
let cmdbuf = queue.new_command_buffer();
encode_matvec(
cmdbuf,
&pipes.matvec,
wf_buf,
&MatvecSpec {
w_off: q_a_off.0,
s_off: q_a_off.1,
b_off: q_a_off.2,
input: &bufs.pre_norm,
output: &bufs.q_lat,
out_dim: q_lora_rank,
in_dim: hidden_dim,
bits: 4,
},
);
encode_rms_norm_bf16_into(
cmdbuf,
&pipes.norms,
&bufs.q_lat,
wf_buf.buffer(),
q_a_norm_off,
&bufs.q_a_sum_sq,
&bufs.q_lat,
q_lora_rank,
RMS_NORM_EPS,
);
encode_matvec(
cmdbuf,
&pipes.matvec,
wf_buf,
&MatvecSpec {
w_off: q_b_off.0,
s_off: q_b_off.1,
b_off: q_b_off.2,
input: &bufs.q_lat,
output: &bufs.q_full,
out_dim: num_heads * qk_head_dim,
in_dim: q_lora_rank,
bits: 4,
},
);
encode_matvec(
cmdbuf,
&pipes.matvec,
wf_buf,
&MatvecSpec {
w_off: kv_a_off.0,
s_off: kv_a_off.1,
b_off: kv_a_off.2,
input: &bufs.pre_norm,
output: &bufs.kv_pre,
out_dim: kv_lora_rank + rope_dim,
in_dim: hidden_dim,
bits: 4,
},
);
encode_mla_split_q_kv(
cmdbuf,
&pipe_split,
&bufs.q_full,
&bufs.kv_pre,
&bufs.q_nope,
&bufs.q_pe,
&bufs.kv_lat,
&bufs.k_pe,
num_heads,
nope,
rope_dim,
kv_lora_rank,
);
encode_rms_norm_bf16_into(
cmdbuf,
&pipes.norms,
&bufs.kv_lat,
wf_buf.buffer(),
kv_a_norm_off,
&bufs.kv_lat_sum_sq,
&bufs.kv_lat,
kv_lora_rank,
RMS_NORM_EPS,
);
encode_yarn_rope_apply(
cmdbuf,
&pipe_yarn,
&bufs.q_pe,
&yarn.inv_freq,
num_heads,
rope_dim,
pos,
yarn.mscale,
)
.map_err(|_| MlaForwardGpuError::Metal(MetalError::NoDevice))?;
encode_yarn_rope_apply(
cmdbuf,
&pipe_yarn,
&bufs.k_pe,
&yarn.inv_freq,
1, rope_dim,
pos,
yarn.mscale,
)
.map_err(|_| MlaForwardGpuError::Metal(MetalError::NoDevice))?;
encode_mla_kv_cache_append(
cmdbuf,
&pipe_cache_append,
&bufs.kv_lat,
&bufs.k_pe,
latent_buf,
rope_k_buf,
kv_lora_rank,
rope_dim,
pos,
);
let cache_len = (pos + 1) as u32;
encode_mla_q_prime_4bit(
cmdbuf,
&pipe_qprime,
wf_buf.buffer(),
kv_b_off.0,
wf_buf.buffer(),
kv_b_off.1,
wf_buf.buffer(),
kv_b_off.2,
&bufs.q_nope,
&bufs.q_prime,
num_heads,
nope,
kv_lora_rank,
kv_b_per_head,
64, );
let softmax_scale =
(1.0 / (qk_head_dim as f32).sqrt()) * yarn.mscale * yarn.mscale;
encode_mla_sdpa_folded(
cmdbuf,
&pipe_sdpa,
&bufs.q_prime,
&bufs.q_pe,
latent_buf,
rope_k_buf,
&bufs.v_combine,
num_heads,
kv_lora_rank,
rope_dim,
cache_len,
softmax_scale,
)?;
encode_mla_out_per_head_4bit(
cmdbuf,
&pipe_outhead,
wf_buf.buffer(),
kv_b_off.0,
wf_buf.buffer(),
kv_b_off.1,
wf_buf.buffer(),
kv_b_off.2,
&bufs.v_combine,
&bufs.out_per_head,
num_heads,
nope,
kv_lora_rank,
v_head_dim,
kv_b_per_head,
64,
);
encode_matvec(
cmdbuf,
&pipes.matvec,
wf_buf,
&MatvecSpec {
w_off: o_off.0,
s_off: o_off.1,
b_off: o_off.2,
input: &bufs.out_per_head,
output: &bufs.out,
out_dim: hidden_dim,
in_dim: num_heads * v_head_dim,
bits: 4,
},
);
cmdbuf.commit();
cmdbuf.wait_until_completed();
kv_cache.len = pos + 1;
Ok(())
}
/// Encodes a one-dimensional dispatch on `enc`: `threadgroups` threadgroups of
/// `threads` threads each.
// Not referenced by the code in this module; `dead_code` is allowed so the
// helper can stay available without a warning.
#[allow(dead_code)]
fn dispatch_1d(enc: &metal::ComputeCommandEncoderRef, threadgroups: u64, threads: u64) {
    let grid = MTLSize::new(threadgroups, 1, 1);
    let group = MTLSize::new(threads, 1, 1);
    enc.dispatch_thread_groups(grid, group);
}