use metal::NSUInteger;
use super::expert_forward::MoeBuffers;
use super::expert_io::ExpertFiles;
use super::gpu_matvec::{encode_matvec, MatvecPipelines, MatvecSpec};
use super::gpu_norm::{encode_rms_norm_bf16_into, RmsNormBf16Pipelines};
use super::layer_weight_cache::LayerWeightCache;
use super::linear_attn_forward::{
bits_of, full_attn_layer_idx_for, post_attention_tail,
read_buffer_to_vec, GpuAttnEncodeArgs, LayerForwardBuffers,
LayerForwardError, OProj,
};
use super::metal::MetalBackend;
use super::mtl_weight_buf::MtlWeightBuf;
use super::rms_norm::rms_norm_per_head_cpu;
use super::rope::apply_rotary_emb;
use super::sdpa::sdpa_cpu;
use super::state::KvCache;
use super::variants::VARIANT;
use super::weight_file::WeightFile;
/// Runs one full-attention decoder layer for a single token at position `pos`.
///
/// Steps, in order:
/// 1. Unless `prev_layer_chained` is true, encodes the input RMSNorm
///    (`buffers.input` -> `buffers.normed`) on the GPU; when it is true,
///    `buffers.normed` is consumed as already written.
/// 2. Encodes the Q, K and V projection matvecs into the same command buffer,
///    commits it and blocks until it completes.
/// 3. On the CPU: splits the Q projection output into per-head query and gate
///    halves, applies per-head RMSNorm (`q_norm` / `k_norm`) to Q and K, then
///    applies rotary embedding at `pos`.
/// 4. Appends the new K/V row to `kv_state` and, for rows below `GPU_KV_SEQ`,
///    mirrors it into this layer's GPU KV buffers.
/// 5. Either stages Q and gate for GPU attention (passed on as
///    `GpuAttnEncodeArgs`) or runs `sdpa_cpu` and writes the attention output
///    into `buffers.batch_out[6]`.
/// 6. Hands off to `post_attention_tail` together with the o_proj description.
///
/// # Errors
/// Returns `LayerForwardError::MissingTensor` when `layer_idx` is not a
/// full-attention layer, when the layer cache holds no full-attention weights,
/// or when the KV cache is already at `MAX_SEQ_LEN` (the same variant is reused
/// for that overflow). Errors from pipeline lookup, the CPU norm/RoPE helpers,
/// `sdpa_cpu` and `post_attention_tail` are propagated.
///
/// Note that `kv_state.len` is incremented before attention runs, so an error
/// from `sdpa_cpu` or `post_attention_tail` leaves this token in the cache.
// The layer driver threads many independent per-call resources through here.
#[allow(clippy::too_many_arguments)]
pub fn full_attn_layer_forward(
    metal: &mut MetalBackend,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    layer_cache: &LayerWeightCache,
    buffers: &mut LayerForwardBuffers,
    moe: &mut MoeBuffers,
    deferred: &mut super::deferred::DeferredRing,
    layer_idx: usize,
    pos: i32,
    k_active: usize,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    prefetch: &mut super::PrefetchState,
    prefetch_set: usize,
    kv_state: &mut KvCache,
    gpu_combine: bool,
    prev_layer_chained: bool,
    chain_next_norm_off: Option<u64>,
) -> Result<(), LayerForwardError> {
    let v = VARIANT;
    if v.layer_kind(layer_idx) != super::variants::LayerKind::FullAttn {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "full_attn_layer_forward called on linear-attn layer",
        });
    }

    // Quantization bit-width of each projection, looked up by tensor name.
    let q_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.q_proj.weight"),
    );
    let k_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.k_proj.weight"),
    );
    let v_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.v_proj.weight"),
    );
    let o_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.o_proj.weight"),
    );

    let attn = layer_cache.attn.full().ok_or(LayerForwardError::MissingTensor {
        layer: layer_idx,
        tensor: "full_attn weights (called on linear-attn layer)",
    })?;
    let q_w = attn.q_proj_w;
    let q_s = attn.q_proj_s;
    let q_b = attn.q_proj_b;
    let k_w = attn.k_proj_w;
    let k_s = attn.k_proj_s;
    let k_b = attn.k_proj_b;
    let v_w = attn.v_proj_w;
    let v_s = attn.v_proj_s;
    let v_b = attn.v_proj_b;
    let o_w = attn.o_proj_w;
    let o_s = attn.o_proj_s;
    let o_b = attn.o_proj_b;

    let q_dim = v.num_attn_heads * v.head_dim;
    // q_proj emits a query half and a gate half per head (split below), hence 2x.
    let q_proj_dim = q_dim * 2;
    let kv_dim = v.num_kv_heads * v.head_dim;

    let mv = MatvecPipelines::fetch(metal)?;
    let rms_pipes = RmsNormBf16Pipelines::fetch(metal)?;
    {
        let cmdbuf = metal.queue().new_command_buffer();
        // When the previous layer was chained, `buffers.normed` already holds
        // this layer's normalized input, so the norm is not re-encoded.
        if !prev_layer_chained {
            encode_rms_norm_bf16_into(
                cmdbuf,
                &rms_pipes,
                &buffers.input,
                wf_buf.buffer(),
                layer_cache.input_layernorm_w,
                &buffers.sum_sq,
                &buffers.normed,
                v.hidden_dim as u32,
                super::variants::RMS_NORM_EPS,
            );
        }
        let specs = [
            MatvecSpec {
                w_off: q_w,
                s_off: q_s,
                b_off: q_b,
                input: &buffers.normed,
                output: &buffers.q_proj_out,
                out_dim: q_proj_dim as u32,
                in_dim: v.hidden_dim as u32,
                bits: q_bits,
            },
            MatvecSpec {
                w_off: k_w,
                s_off: k_s,
                b_off: k_b,
                input: &buffers.normed,
                output: &buffers.k_out,
                out_dim: kv_dim as u32,
                in_dim: v.hidden_dim as u32,
                bits: k_bits,
            },
            MatvecSpec {
                w_off: v_w,
                s_off: v_s,
                b_off: v_b,
                input: &buffers.normed,
                output: &buffers.v_out,
                out_dim: kv_dim as u32,
                in_dim: v.hidden_dim as u32,
                bits: v_bits,
            },
        ];
        for s in &specs {
            encode_matvec(cmdbuf, &mv, wf_buf, s);
        }
        // The projections are read back on the host right below, so wait.
        cmdbuf.commit();
        cmdbuf.wait_until_completed();
    }

    let q_proj_host = read_buffer_to_vec(&buffers.q_proj_out, q_proj_dim);
    let mut k_host = read_buffer_to_vec(&buffers.k_out, kv_dim);
    let v_host = read_buffer_to_vec(&buffers.v_out, kv_dim);

    // De-interleave q_proj: per head, [query (head_dim) | gate (head_dim)].
    let mut q_host = vec![0.0f32; q_dim];
    let mut q_gate_host = vec![0.0f32; q_dim];
    for h in 0..v.num_attn_heads {
        let src_off = h * (2 * v.head_dim);
        let dst_off = h * v.head_dim;
        q_host[dst_off..dst_off + v.head_dim]
            .copy_from_slice(&q_proj_host[src_off..src_off + v.head_dim]);
        q_gate_host[dst_off..dst_off + v.head_dim]
            .copy_from_slice(&q_proj_host[src_off + v.head_dim..src_off + 2 * v.head_dim]);
    }

    let q_norm_name = format!("model.layers.{layer_idx}.self_attn.q_norm.weight");
    rms_norm_per_head_cpu(wf, &q_norm_name, v.num_attn_heads, v.head_dim, &mut q_host)?;
    let k_norm_name = format!("model.layers.{layer_idx}.self_attn.k_norm.weight");
    rms_norm_per_head_cpu(wf, &k_norm_name, v.num_kv_heads, v.head_dim, &mut k_host)?;
    apply_rotary_emb(pos, &mut q_host, &mut k_host)?;

    // Append this token's K/V row to the host-side cache.
    let cache_pos = kv_state.len as usize;
    if cache_pos + 1 > super::variants::MAX_SEQ_LEN {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "kv cache overflow",
        });
    }
    let row_start = cache_pos * kv_dim;
    let row_end = row_start + kv_dim;
    // These copies also guarantee k_host.len() == v_host.len() == kv_dim
    // (copy_from_slice panics on a length mismatch); the raw copies below rely on it.
    kv_state.k_cache[row_start..row_end].copy_from_slice(&k_host);
    kv_state.v_cache[row_start..row_end].copy_from_slice(&v_host);
    kv_state.len += 1;

    // Mirror the row into this layer's GPU KV buffers while it fits.
    let fa_idx = full_attn_layer_idx_for(layer_idx);
    if let Some(fa_idx) = fa_idx {
        if cache_pos < super::variants::GPU_KV_SEQ {
            let needed_bytes = row_end * std::mem::size_of::<f32>();
            debug_assert!(
                buffers.gpu_kv_k[fa_idx].length() as usize >= needed_bytes
                    && buffers.gpu_kv_v[fa_idx].length() as usize >= needed_bytes,
                "gpu_kv buffers for full-attn slot {fa_idx} too small for row ending at {row_end}",
            );
            // SAFETY: k_host/v_host hold exactly kv_dim f32s (enforced by the
            // copy_from_slice above). The destination range
            // [row_start, row_end) must lie inside each GPU KV buffer; this is
            // checked in debug builds above and assumes the buffers were sized
            // for GPU_KV_SEQ * kv_dim floats. Host and GPU memory never overlap.
            // NOTE(review): also assumes no in-flight command buffer (e.g. from
            // the deferred ring) is reading these buffers — confirm.
            unsafe {
                let k_dst = buffers.gpu_kv_k[fa_idx].contents() as *mut f32;
                let v_dst = buffers.gpu_kv_v[fa_idx].contents() as *mut f32;
                std::ptr::copy_nonoverlapping(k_host.as_ptr(), k_dst.add(row_start), kv_dim);
                std::ptr::copy_nonoverlapping(v_host.as_ptr(), v_dst.add(row_start), kv_dim);
            }
        }
    }

    let kv_len = kv_state.len;
    // GPU attention is used only for full-attn layers with a GPU KV slot, once
    // at least 32 tokens are cached, and while every cached row fits in the GPU
    // KV buffers. NOTE(review): rows up to GPU_KV_SEQ - 1 are mirrored above, so
    // kv_len == GPU_KV_SEQ would also fit; the strict `<` may be deliberate
    // for the GPU kernel — confirm before relaxing.
    let gpu_attn_ready =
        fa_idx.is_some() && kv_len >= 32 && (kv_len as usize) < super::variants::GPU_KV_SEQ;

    let gpu_attn_args = if gpu_attn_ready {
        let fa_idx = fa_idx.expect("gpu_attn_ready ⇒ Some(fa_idx)");
        let needed_bytes = q_dim * std::mem::size_of::<f32>();
        debug_assert!(
            buffers.gpu_attn_q.length() as usize >= needed_bytes
                && buffers.gpu_attn_gate.length() as usize >= needed_bytes,
            "gpu_attn_q / gpu_attn_gate too small: need {needed_bytes} bytes each",
        );
        // SAFETY: q_host and q_gate_host were allocated with q_dim elements
        // (the norm/RoPE helpers are assumed to mutate them in place without
        // resizing). The destination buffers are checked to hold q_dim f32s in
        // debug builds above, and host vectors never alias GPU buffer memory.
        unsafe {
            let q_dst = buffers.gpu_attn_q.contents() as *mut f32;
            let g_dst = buffers.gpu_attn_gate.contents() as *mut f32;
            std::ptr::copy_nonoverlapping(q_host.as_ptr(), q_dst, q_dim);
            std::ptr::copy_nonoverlapping(q_gate_host.as_ptr(), g_dst, q_dim);
        }
        Some(GpuAttnEncodeArgs {
            fa_idx,
            kv_len: kv_len as u32,
        })
    } else {
        let kv_total = (kv_len as usize) * kv_dim;
        let mut attn_out = vec![0.0f32; q_dim];
        sdpa_cpu(
            kv_len,
            &q_host,
            &q_gate_host,
            &kv_state.k_cache[..kv_total],
            &kv_state.v_cache[..kv_total],
            &mut attn_out,
        )?;
        // Check capacity BEFORE the raw write: asserting afterwards cannot
        // prevent an out-of-bounds write from having already happened.
        debug_assert!(
            buffers.batch_out[6].length() as usize
                >= q_dim * std::mem::size_of::<f32>(),
            "batch_out[6] sized {} bytes, need {} for full-attn o_proj input",
            buffers.batch_out[6].length() as NSUInteger,
            q_dim * std::mem::size_of::<f32>(),
        );
        let dst = buffers.batch_out[6].contents() as *mut f32;
        // SAFETY: attn_out has exactly q_dim elements. batch_out[6] is checked
        // to hold at least q_dim f32s in debug builds above, and a freshly
        // allocated Vec cannot alias GPU buffer memory.
        unsafe {
            std::ptr::copy_nonoverlapping(attn_out.as_ptr(), dst, q_dim);
        }
        None
    };

    post_attention_tail(
        metal,
        wf,
        wf_buf,
        layer_cache,
        buffers,
        moe,
        deferred,
        layer_idx,
        k_active,
        expert_files,
        pool,
        prefetch,
        prefetch_set,
        OProj {
            w_off: o_w,
            s_off: o_s,
            b_off: o_b,
            bits: o_bits,
            in_dim: q_dim as u32,
        },
        gpu_combine,
        gpu_attn_args,
        chain_next_norm_off,
    )
}