use crate::riir::backend::cpu::cpu_matvec::{project_4bit_cpu, CpuMatvecError};
use crate::riir::moe::moe_router::softmax;
use crate::riir::attn::rms_norm::{rms_norm_per_head_cpu, RmsNormError};
use crate::riir::attn::rope::{apply_rotary_emb_yarn, YarnError};
use crate::riir::snapshot::state::MlaKvCacheGpu;
use crate::riir::variants::{MAX_SEQ_LEN, VARIANT};
use crate::riir::io::weight_file::WeightFile;
/// Failure modes of the CPU MLA attention forward pass in this file.
///
/// The first five variants are precondition failures detected before any
/// computation starts; the remaining four wrap errors propagated from the
/// matvec, RMS-norm, YaRN RoPE and softmax helpers via `?`.
#[derive(Debug, thiserror::Error)]
pub enum MlaForwardError {
    /// The compiled-in `VARIANT` does not use MLA attention.
    #[error("called on non-MLA variant (attn_kind = {kind:?})")]
    NotMlaVariant {
        kind: crate::riir::variants::AttnKind,
    },

    /// The `hidden` input slice is not exactly `hidden_dim` elements long.
    #[error("hidden buffer length {got} != hidden_dim ({expected})")]
    HiddenLen {
        got: usize,
        expected: usize,
    },

    /// The `out` slice is not exactly `hidden_dim` elements long.
    #[error("output buffer length {got} != hidden_dim ({expected})")]
    OutLen {
        got: usize,
        expected: usize,
    },

    /// The requested position is not the next free slot of the KV cache.
    #[error("position {pos} != kv_cache.len {cache_len} (single-step decode)")]
    PosMismatch {
        pos: i32,
        cache_len: i32,
    },

    /// Appending one more position would not fit within `MAX_SEQ_LEN`.
    #[error("kv_cache.len {len} would exceed MAX_SEQ_LEN={max} after append")]
    CacheFull {
        len: i32,
        max: usize,
    },

    /// A 4-bit projection (matvec) failed.
    #[error("matvec error in MLA: {0}")]
    Matvec(#[from] CpuMatvecError),

    /// An RMS-norm step failed.
    #[error("rms-norm error in MLA: {0}")]
    Norm(#[from] RmsNormError),

    /// Applying the YaRN rotary embedding failed.
    #[error("YaRN RoPE error in MLA: {0}")]
    Rope(#[from] YarnError),

    /// The softmax over attention scores failed.
    #[error("softmax error in MLA: {0}")]
    Softmax(#[from] crate::riir::moe::moe_router::MoeRouterError),
}
/// Runs one single-token decode step of an MLA attention layer on the CPU.
///
/// Steps, in order:
/// 1. Validate the variant, buffer lengths, `pos == kv_cache.len` and the
///    remaining cache capacity.
/// 2. Query path: `q_a_proj` -> RMS norm (`q_a_layernorm`) -> `q_b_proj`,
///    then rotate the trailing `qk_rope_head_dim` lanes of every head.
/// 3. Key/value path: `kv_a_proj_with_mqa`, RMS-norm the first
///    `kv_lora_rank` lanes (`kv_a_layernorm`), rotate the trailing
///    `qk_rope_head_dim` lanes, and append both parts to `kv_cache` at
///    slot `pos`.
/// 4. Decode every cached latent through `kv_b_proj`, run per-head softmax
///    attention over all cached positions, and project the concatenated head
///    outputs through `o_proj` into `out`.
///
/// # Parameters
/// * `wf` - weight file the named tensors are read from.
/// * `layer_idx` - layer index used to build the tensor names.
/// * `pos` - position of the token being decoded; must equal `kv_cache.len`.
/// * `hidden` - layer input, exactly `hidden_dim` long.
/// * `kv_cache` - per-layer cache; on success its `len` becomes `pos + 1`.
/// * `yarn_inv_freq`, `yarn_mscale` - forwarded to `apply_rotary_emb_yarn`;
///   `yarn_mscale` is additionally squared into the softmax scale.
/// * `out` - layer output, exactly `hidden_dim` long.
///
/// # Errors
/// Returns the matching [`MlaForwardError`] precondition variant, or wraps the
/// error of whichever projection / norm / RoPE / softmax helper failed. If a
/// step fails after the cache append, `kv_cache.len` is restored to `pos`, so
/// an `Err` never leaves the cache logically advanced (the slot contents at
/// `pos` may have been overwritten, but that slot lies beyond `len`).
// The parameter list is the public interface of this layer; bundling it into a
// struct would break existing callers.
#[allow(clippy::too_many_arguments)]
pub fn mla_attn_layer_forward_cpu(
    wf: &WeightFile,
    layer_idx: usize,
    pos: i32,
    hidden: &[f32],
    kv_cache: &mut MlaKvCacheGpu,
    yarn_inv_freq: &[f32],
    yarn_mscale: f32,
    out: &mut [f32],
) -> Result<(), MlaForwardError> {
    use crate::riir::variants::AttnKind;

    // ---- Preconditions -----------------------------------------------------
    if VARIANT.attn_kind != AttnKind::Mla {
        return Err(MlaForwardError::NotMlaVariant {
            kind: VARIANT.attn_kind,
        });
    }
    let v = VARIANT;
    if hidden.len() != v.hidden_dim {
        return Err(MlaForwardError::HiddenLen {
            got: hidden.len(),
            expected: v.hidden_dim,
        });
    }
    if out.len() != v.hidden_dim {
        return Err(MlaForwardError::OutLen {
            got: out.len(),
            expected: v.hidden_dim,
        });
    }
    if pos != kv_cache.len {
        return Err(MlaForwardError::PosMismatch {
            pos,
            cache_len: kv_cache.len,
        });
    }
    // A negative length casts to a very large `usize`, so it is rejected here
    // as well; past this point `pos` is a valid, non-negative slot index.
    if (kv_cache.len as usize) >= MAX_SEQ_LEN {
        return Err(MlaForwardError::CacheFull {
            len: kv_cache.len,
            max: MAX_SEQ_LEN,
        });
    }

    let hidden_dim = v.hidden_dim;
    let num_heads = v.num_attn_heads;
    let q_lora_rank = v.q_lora_rank;
    let kv_lora_rank = v.kv_lora_rank;
    let nope = v.qk_nope_head_dim;
    let rope = v.qk_rope_head_dim;
    let qk_head_dim = nope + rope;

    // ---- Query path: down-project, normalise, up-project --------------------
    let q_a_name = format!("model.layers.{layer_idx}.self_attn.q_a_proj");
    let q_a_norm = format!("model.layers.{layer_idx}.self_attn.q_a_layernorm.weight");
    let q_b_name = format!("model.layers.{layer_idx}.self_attn.q_b_proj");

    let mut q_lat = vec![0.0f32; q_lora_rank];
    project_4bit_cpu(wf, &q_a_name, hidden_dim, q_lora_rank, hidden, &mut q_lat)?;
    rms_norm_per_head_cpu(wf, &q_a_norm, 1, q_lora_rank, &mut q_lat)?;

    let mut q_full = vec![0.0f32; num_heads * qk_head_dim];
    project_4bit_cpu(
        wf,
        &q_b_name,
        q_lora_rank,
        num_heads * qk_head_dim,
        &q_lat,
        &mut q_full,
    )?;

    // Gather the rotary lanes of every head into one contiguous buffer so the
    // rotary embedding can be applied to all heads in a single call.
    let mut q_pe = vec![0.0f32; num_heads * rope];
    for h in 0..num_heads {
        let src = &q_full[h * qk_head_dim + nope..(h + 1) * qk_head_dim];
        q_pe[h * rope..(h + 1) * rope].copy_from_slice(src);
    }

    // ---- Key/value path: compressed latent + one shared rotary key ---------
    let kv_a_name = format!("model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa");
    let kv_a_norm = format!("model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight");

    // Layout of `kv_pre`: [ latent (kv_lora_rank) | rotary key (rope) ].
    let mut kv_pre = vec![0.0f32; kv_lora_rank + rope];
    project_4bit_cpu(
        wf,
        &kv_a_name,
        hidden_dim,
        kv_lora_rank + rope,
        hidden,
        &mut kv_pre,
    )?;
    // Only the latent part is normalised; the rotary key is left untouched.
    rms_norm_per_head_cpu(
        wf,
        &kv_a_norm,
        1,
        kv_lora_rank,
        &mut kv_pre[..kv_lora_rank],
    )?;

    // ---- Rotary embedding on the query lanes and the new key ---------------
    apply_rotary_emb_yarn(pos, &mut q_pe, rope, yarn_inv_freq, yarn_mscale)?;
    apply_rotary_emb_yarn(
        pos,
        &mut kv_pre[kv_lora_rank..],
        rope,
        yarn_inv_freq,
        yarn_mscale,
    )?;

    // Scatter the rotated lanes back into the per-head query layout.
    for h in 0..num_heads {
        let dst = &mut q_full[h * qk_head_dim + nope..(h + 1) * qk_head_dim];
        dst.copy_from_slice(&q_pe[h * rope..(h + 1) * rope]);
    }

    // ---- Append the new position to the cache ------------------------------
    let new_idx = pos as usize;
    // SAFETY: the checks above guarantee `new_idx < MAX_SEQ_LEN`, so the slot
    // range `[new_idx, new_idx + 1)` stays within the cache's position limit.
    // NOTE(review): the accessors' full contract is not visible from this
    // file — presumably the caller must already have allocated the cache
    // buffers; confirm against `MlaKvCacheGpu`.
    unsafe {
        let l_dst = kv_cache.latent_slice_mut(new_idx, new_idx + 1);
        l_dst.copy_from_slice(&kv_pre[..kv_lora_rank]);
        let r_dst = kv_cache.rope_k_slice_mut(new_idx, new_idx + 1);
        r_dst.copy_from_slice(&kv_pre[kv_lora_rank..]);
    }
    kv_cache.len = pos + 1;
    let cache_len = new_idx + 1;

    // ---- Attend over the cache and project the result ----------------------
    let attended = {
        // SAFETY: `cache_len <= MAX_SEQ_LEN` and slots `0..cache_len` are the
        // ones covered by `kv_cache.len`. NOTE(review): same caveat as above
        // regarding buffer allocation — confirm against `MlaKvCacheGpu`.
        let latent_cache_view: &[f32] = unsafe { kv_cache.latent_slice(cache_len) };
        // SAFETY: as for `latent_slice` directly above.
        let rope_k_cache_view: &[f32] = unsafe { kv_cache.rope_k_slice(cache_len) };
        attend_over_cache(
            wf,
            layer_idx,
            cache_len,
            &q_full,
            latent_cache_view,
            rope_k_cache_view,
            yarn_mscale,
            out,
        )
    };

    // Undo the append on failure so that `Err` leaves `kv_cache.len` untouched
    // and the caller can retry the same position without a `PosMismatch`.
    if attended.is_err() {
        kv_cache.len = pos;
    }
    attended
}

/// Decodes the first `cache_len` cached latents through `kv_b_proj`, runs
/// per-head softmax attention of `q_full` over them, and writes the `o_proj`
/// projection of the concatenated head outputs into `out`.
///
/// `latent_cache` and `rope_k_cache` are read as `cache_len` consecutive
/// entries of `kv_lora_rank` and `qk_rope_head_dim` floats respectively.
// Private helper; the arguments are exactly the intermediate state of the
// entry point above, so a wrapper struct would add noise without benefit.
#[allow(clippy::too_many_arguments)]
fn attend_over_cache(
    wf: &WeightFile,
    layer_idx: usize,
    cache_len: usize,
    q_full: &[f32],
    latent_cache: &[f32],
    rope_k_cache: &[f32],
    yarn_mscale: f32,
    out: &mut [f32],
) -> Result<(), MlaForwardError> {
    let v = VARIANT;
    let num_heads = v.num_attn_heads;
    let kv_lora_rank = v.kv_lora_rank;
    let nope = v.qk_nope_head_dim;
    let rope = v.qk_rope_head_dim;
    let v_head_dim = v.v_head_dim;
    let qk_head_dim = nope + rope;
    // Per head, `kv_b_proj` yields [ k_nope (nope) | value (v_head_dim) ].
    let kv_b_per_head = nope + v_head_dim;
    let decoded_per_pos = num_heads * kv_b_per_head;

    // Decode every cached latent into per-head keys and values. This redoes
    // the projection for all positions on every call.
    let kv_b_name = format!("model.layers.{layer_idx}.self_attn.kv_b_proj");
    let mut decoded_all = vec![0.0f32; cache_len * decoded_per_pos];
    for j in 0..cache_len {
        let latent_j = &latent_cache[j * kv_lora_rank..(j + 1) * kv_lora_rank];
        let dec_j = &mut decoded_all[j * decoded_per_pos..(j + 1) * decoded_per_pos];
        project_4bit_cpu(
            wf,
            &kv_b_name,
            kv_lora_rank,
            decoded_per_pos,
            latent_j,
            dec_j,
        )?;
    }

    // 1/sqrt(qk_head_dim), additionally scaled by yarn_mscale squared.
    let softmax_scale = (1.0 / (qk_head_dim as f32).sqrt()) * yarn_mscale * yarn_mscale;

    // Zero-initialised, and each head slice is accumulated into exactly once,
    // so no per-head reset is needed.
    let mut head_out = vec![0.0f32; num_heads * v_head_dim];
    // Reused across heads; every entry is overwritten before each softmax.
    let mut scores = vec![0.0f32; cache_len];

    for h in 0..num_heads {
        let q_h = &q_full[h * qk_head_dim..(h + 1) * qk_head_dim];
        let (q_nope_h, q_pe_h) = q_h.split_at(nope);

        for (j, score) in scores.iter_mut().enumerate() {
            let head_base = (j * num_heads + h) * kv_b_per_head;
            let k_nope_jh = &decoded_all[head_base..head_base + nope];
            // The rotary key is stored once per position and shared by all heads.
            let rope_k_j = &rope_k_cache[j * rope..(j + 1) * rope];

            let mut s = 0.0f32;
            for (&q, &k) in q_nope_h.iter().zip(k_nope_jh) {
                s = q.mul_add(k, s);
            }
            for (&q, &k) in q_pe_h.iter().zip(rope_k_j) {
                s = q.mul_add(k, s);
            }
            *score = s * softmax_scale;
        }
        softmax(&mut scores)?;

        // Probability-weighted sum of this head's value vectors.
        let head_out_h = &mut head_out[h * v_head_dim..(h + 1) * v_head_dim];
        for (j, &w) in scores.iter().enumerate() {
            let value_base = (j * num_heads + h) * kv_b_per_head + nope;
            let v_jh = &decoded_all[value_base..value_base + v_head_dim];
            for (acc, &x) in head_out_h.iter_mut().zip(v_jh) {
                *acc = w.mul_add(x, *acc);
            }
        }
    }

    let o_name = format!("model.layers.{layer_idx}.self_attn.o_proj");
    project_4bit_cpu(
        wf,
        &o_name,
        num_heads * v_head_dim,
        v.hidden_dim,
        &head_out,
        out,
    )?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder for the non-MLA variants.
    #[cfg(any(
        feature = "model-qwen3-5-a17b",
        feature = "model-qwen3-6-35b-a3b",
    ))]
    #[test]
    fn rejects_non_mla_variant() {
        // NOTE(review): the body is empty, so this test asserts nothing and
        // always passes. Exercising the `NotMlaVariant` path would need a
        // `WeightFile` and a KV cache to pass to the forward function.
    }

    /// Smoke test: one decode step of layer 0 at position 0 with a one-hot
    /// hidden vector. Checks that the cache advances and that the output is
    /// finite, non-zero and of plausible magnitude. Ignored by default because
    /// it needs the real weight artifacts and a Metal device.
    #[cfg(feature = "model-cogito-v2-671b")]
    #[test]
    #[ignore = "needs Cogito-V2 weights mmap'd from /Volumes/Temp Backup"]
    fn mla_layer0_pos0_smoke() {
        use crate::riir::attn::rope::{compute_yarn_inv_freq, yarn_get_mscale_full};
        use crate::riir::variants::ROPE_THETA;
        use std::path::Path;

        let weights_bin = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.bin",
        );
        let weights_manifest = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.json",
        );
        let wf = WeightFile::open(weights_bin, weights_manifest).expect("open weights");

        // YaRN parameters come straight from the compiled-in variant.
        let cfg = VARIANT;
        let inv_freq = compute_yarn_inv_freq(
            cfg.qk_rope_head_dim,
            ROPE_THETA,
            cfg.yarn_factor,
            cfg.yarn_original_max_pos as f32,
            cfg.yarn_beta_fast,
            cfg.yarn_beta_slow,
        );
        let mscale = yarn_get_mscale_full(
            cfg.yarn_factor,
            cfg.yarn_mscale,
            cfg.yarn_mscale_all_dim,
        );

        // One-hot input so the projections produce a non-trivial output.
        let mut hidden = vec![0.0f32; cfg.hidden_dim];
        hidden[7] = 1.0;
        let mut out = vec![0.0f32; cfg.hidden_dim];

        let device = metal::Device::system_default()
            .expect("Metal device for MLA KV cache buffers");
        let mut cache = MlaKvCacheGpu::new();
        cache.ensure_buffers(&device);

        mla_attn_layer_forward_cpu(
            &wf, 0, 0, &hidden, &mut cache, &inv_freq, mscale, &mut out,
        )
        .expect("MLA forward should succeed");

        assert_eq!(cache.len, 1, "cache should advance to 1");

        let first_non_finite = out.iter().position(|x| !x.is_finite());
        assert!(
            first_non_finite.is_none(),
            "out[i] non-finite at first index = {:?}",
            first_non_finite,
        );

        let max_abs = out.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        assert!(
            max_abs > 0.0,
            "output is all zeros — likely a wiring bug"
        );
        assert!(
            max_abs < 1e6,
            "output magnitude {max_abs} suspiciously large"
        );
    }
}