use crate::riir::io::embedding::bf16_to_f32;
use crate::riir::variants::{GROUP_SIZE, VARIANT};
use crate::riir::io::weight_file::WeightFile;
#[derive(Debug, thiserror::Error)]
pub enum LmHeadError {
#[error("LM head tensor '{name}' missing from manifest")]
MissingTensor { name: &'static str },
#[error(
"LM head tensor '{name}' has unexpected shape {shape:?} (expected {expected:?})"
)]
ShapeMismatch {
name: &'static str,
shape: Vec<usize>,
expected: Vec<usize>,
},
#[error("input length {got} != HIDDEN_DIM {expected}")]
InputLen { got: usize, expected: usize },
#[error("output length {got} != VOCAB_SIZE {expected}")]
OutputLen { got: usize, expected: usize },
}
/// Manifest name of the packed 4-bit LM-head weight matrix.
const WEIGHT_NAME: &str = "lm_head.weight";
/// Manifest name of the per-group bf16 dequantization scales.
const SCALES_NAME: &str = "lm_head.scales";
/// Manifest name of the per-group bf16 dequantization biases.
const BIASES_NAME: &str = "lm_head.biases";
/// Computes LM-head logits on the CPU.
///
/// `out[row]` is the dot product of `x` with row `row` of the 4-bit quantized
/// `lm_head` weight matrix, which is dequantized on the fly.
///
/// Tensors read from `wf` (shapes derived from `VARIANT`):
/// - `lm_head.weight`: `[vocab_size, hidden_dim / 8]` little-endian `u32`
///   words, each packing eight 4-bit values, lowest nibble first;
/// - `lm_head.scales`, `lm_head.biases`: `[vocab_size, hidden_dim / GROUP_SIZE]`
///   little-endian bf16 values, one per group of `GROUP_SIZE` columns.
///
/// A 4-bit value `q` in group `g` dequantizes to `q * scale[g] + bias[g]`.
///
/// # Errors
/// - [`LmHeadError::InputLen`] if `x.len() != VARIANT.hidden_dim`;
/// - [`LmHeadError::OutputLen`] if `out.len() != VARIANT.vocab_size`;
/// - [`LmHeadError::MissingTensor`] if any of the three tensors is absent;
/// - [`LmHeadError::ShapeMismatch`] if a tensor's manifest shape is not the
///   expected one.
///
/// `out` is left untouched whenever an error is returned.
///
/// # Panics
/// Panics, before writing to `out`, if a tensor's byte buffer is too short
/// for its validated shape. An example is a manifest whose shape is right but
/// whose stored element size is not. `LmHeadError` has no variant for this
/// case, so it is treated as a corrupt-file invariant violation.
pub fn lm_head_cpu(
    wf: &WeightFile,
    x: &[f32],
    out: &mut [f32],
) -> Result<(), LmHeadError> {
    let hidden_dim = VARIANT.hidden_dim;
    let vocab_size = VARIANT.vocab_size;
    if x.len() != hidden_dim {
        return Err(LmHeadError::InputLen {
            got: x.len(),
            expected: hidden_dim,
        });
    }
    if out.len() != vocab_size {
        return Err(LmHeadError::OutputLen {
            got: out.len(),
            expected: vocab_size,
        });
    }
    // If the layout did not divide evenly, the trailing columns of `x` beyond
    // `num_groups * GROUP_SIZE` would be silently skipped by the dot product.
    debug_assert!(
        GROUP_SIZE % 8 == 0 && hidden_dim % GROUP_SIZE == 0,
        "hidden_dim ({}) must be a multiple of GROUP_SIZE ({}), itself a multiple of 8",
        hidden_dim,
        GROUP_SIZE
    );
    let num_groups = hidden_dim / GROUP_SIZE;
    let packed_cols = hidden_dim / 8;

    let w_bytes = tensor_or_missing(wf, WEIGHT_NAME)?;
    let s_bytes = tensor_or_missing(wf, SCALES_NAME)?;
    let b_bytes = tensor_or_missing(wf, BIASES_NAME)?;
    expect_shape(wf, WEIGHT_NAME, &[vocab_size, packed_cols])?;
    expect_shape(wf, SCALES_NAME, &[vocab_size, num_groups])?;
    expect_shape(wf, BIASES_NAME, &[vocab_size, num_groups])?;

    // Row strides in bytes: packed weights are 4-byte words, scales/biases 2-byte bf16.
    let w_stride = packed_cols * 4;
    let sb_stride = num_groups * 2;
    // Validate every buffer once, up front, so a corrupt file fails before
    // `out` is modified rather than panicking part-way through the rows.
    assert_min_bytes(WEIGHT_NAME, w_bytes, vocab_size * w_stride);
    assert_min_bytes(SCALES_NAME, s_bytes, vocab_size * sb_stride);
    assert_min_bytes(BIASES_NAME, b_bytes, vocab_size * sb_stride);

    for (row, logit) in out.iter_mut().enumerate() {
        let w_row = &w_bytes[row * w_stride..(row + 1) * w_stride];
        let s_row = &s_bytes[row * sb_stride..(row + 1) * sb_stride];
        let b_row = &b_bytes[row * sb_stride..(row + 1) * sb_stride];
        *logit = dequant_row_dot(w_row, s_row, b_row, x, num_groups);
    }
    Ok(())
}

/// Returns the dot product of one dequantized 4-bit weight row with `x`.
///
/// `w_row` holds `num_groups * GROUP_SIZE / 8` little-endian `u32` words.
/// `s_row` and `b_row` each hold `num_groups` little-endian bf16 values.
/// Accumulation is strictly sequential (group, then word, then nibble), and
/// each step is a fused multiply-add, so the result does not depend on any
/// reordering.
fn dequant_row_dot(
    w_row: &[u8],
    s_row: &[u8],
    b_row: &[u8],
    x: &[f32],
    num_groups: usize,
) -> f32 {
    let packed_per_group = GROUP_SIZE / 8;
    let mut acc: f32 = 0.0;
    for g in 0..num_groups {
        let scale = bf16_to_f32(read_u16_le(s_row, g));
        let bias = bf16_to_f32(read_u16_le(b_row, g));
        // Slicing per group and per word lets the inner loop run without
        // per-element bounds checks on `x`.
        let x_group = &x[g * GROUP_SIZE..(g + 1) * GROUP_SIZE];
        for (p, x_word) in x_group.chunks_exact(8).enumerate() {
            let packed = read_u32_le(w_row, g * packed_per_group + p);
            for (n, &xv) in x_word.iter().enumerate() {
                // Nibble `n` (lowest first) pairs with column `n` of this word.
                let q = (packed >> (n * 4)) & 0xF;
                let w = (q as f32).mul_add(scale, bias);
                acc = w.mul_add(xv, acc);
            }
        }
    }
    acc
}

/// Panics unless the byte buffer of tensor `name` holds at least `needed` bytes.
fn assert_min_bytes(name: &str, bytes: &[u8], needed: usize) {
    assert!(
        bytes.len() >= needed,
        "LM head tensor '{}' has {} bytes, but its shape requires at least {}",
        name,
        bytes.len(),
        needed
    );
}
/// Fetches the raw bytes of tensor `name` from `wf`.
///
/// # Errors
/// Returns [`LmHeadError::MissingTensor`] when `wf` has no bytes for `name`.
fn tensor_or_missing<'a>(
    wf: &'a WeightFile,
    name: &'static str,
) -> Result<&'a [u8], LmHeadError> {
    match wf.tensor_bytes(name) {
        Some(bytes) => Ok(bytes),
        None => Err(LmHeadError::MissingTensor { name }),
    }
}
/// Checks that tensor `name` is listed in `wf` with exactly shape `expected`.
///
/// # Errors
/// - [`LmHeadError::MissingTensor`] if `wf` has no info for `name`;
/// - [`LmHeadError::ShapeMismatch`] if the recorded shape differs from `expected`.
fn expect_shape(
    wf: &WeightFile,
    name: &'static str,
    expected: &[usize],
) -> Result<(), LmHeadError> {
    let info = match wf.tensor_info(name) {
        Some(found) => found,
        None => return Err(LmHeadError::MissingTensor { name }),
    };
    if info.shape == expected {
        Ok(())
    } else {
        Err(LmHeadError::ShapeMismatch {
            name,
            shape: info.shape.clone(),
            expected: expected.to_vec(),
        })
    }
}
/// Decodes the `idx`-th little-endian `u32` from `buf`, which is treated as
/// a packed array of 4-byte words.
///
/// # Panics
/// Panics if `buf` is shorter than `(idx + 1) * 4` bytes.
fn read_u32_le(buf: &[u8], idx: usize) -> u32 {
    let start = idx * 4;
    let mut word = [0u8; 4];
    word.copy_from_slice(&buf[start..start + 4]);
    u32::from_le_bytes(word)
}
/// Decodes the `idx`-th little-endian `u16` from `buf`, which is treated as
/// a packed array of 2-byte values.
///
/// # Panics
/// Panics if `buf` is shorter than `(idx + 1) * 2` bytes.
fn read_u16_le(buf: &[u8], idx: usize) -> u16 {
    let start = idx * 2;
    let mut half = [0u8; 2];
    half.copy_from_slice(&buf[start..start + 2]);
    u16::from_le_bytes(half)
}