use rayon::prelude::*;
use crate::riir::io::embedding::bf16_to_f32;
use crate::riir::variants::GROUP_SIZE;
use crate::riir::io::weight_file::WeightFile;
/// Errors reported by the CPU matvec kernels and their weight-file wrappers
/// in this module.
#[derive(Debug, thiserror::Error)]
pub enum CpuMatvecError {
/// The dequantizing kernels were given an `in_dim` that is not a whole
/// number of `GROUP_SIZE`-element groups.
#[error("in_dim {in_dim} is not a multiple of GROUP_SIZE={group_size}")]
InDimNotMultiple { in_dim: usize, group_size: usize },
/// One of the input/output slices does not have the element count implied
/// by `in_dim` / `out_dim`. `field` names the offending argument
/// (`"packed"`, `"scales"`, `"biases"`, `"weight"`, `"x"` or `"out"`).
#[error(
"slice length mismatch on '{field}': got {got} elements, expected {expected}"
)]
SliceLen {
field: &'static str,
got: usize,
expected: usize,
},
/// The weight file returned nothing for the requested tensor name.
#[error("missing tensor '{name}'")]
MissingTensor { name: String },
/// The tensor exists but its recorded shape differs from the shape the
/// caller's `in_dim` / `out_dim` imply.
#[error(
"tensor '{name}' has unexpected shape {shape:?} (expected {expected:?})"
)]
ShapeMismatch {
name: String,
shape: Vec<usize>,
expected: Vec<usize>,
},
/// A byte buffer could not be viewed in full as a slice of `align`-byte
/// integers: either its start is not suitably aligned or its length leaves
/// trailing bytes (both cases are reported through this variant).
#[error("tensor '{name}' not aligned to {align}-byte boundary")]
Misaligned { name: String, align: usize },
}
/// Computes `out = W · x` on the CPU for a row-major matrix `W` stored as
/// 4-bit group-quantized weights.
///
/// Layout, as consumed by this function:
/// * `packed` holds `out_dim` rows of `in_dim / 8` words; each `u32` carries
///   eight 4-bit values, lowest nibble first.
/// * `scales` / `biases` hold `out_dim` rows of `in_dim / GROUP_SIZE` raw
///   16-bit values that are decoded with `bf16_to_f32`; every run of
///   `GROUP_SIZE` consecutive weights in a row shares one scale and one bias.
/// * A weight is reconstructed as `nibble * scale + bias`.
///
/// Output rows are computed independently on the rayon thread pool; within a
/// row the accumulation runs strictly left to right with fused
/// multiply-adds.
///
/// # Errors
/// * [`CpuMatvecError::InDimNotMultiple`] if `in_dim` is not a multiple of
///   `GROUP_SIZE`.
/// * [`CpuMatvecError::SliceLen`] if `packed`, `scales`, `biases`, `x` or
///   `out` (checked in that order) has the wrong number of elements.
pub fn dequant_matvec_4bit_cpu(
    packed: &[u32],
    scales: &[u16],
    biases: &[u16],
    in_dim: usize,
    out_dim: usize,
    x: &[f32],
    out: &mut [f32],
) -> Result<(), CpuMatvecError> {
    const NIBBLES_PER_WORD: usize = 8;

    if in_dim % GROUP_SIZE != 0 {
        return Err(CpuMatvecError::InDimNotMultiple {
            in_dim,
            group_size: GROUP_SIZE,
        });
    }

    let words_per_row = in_dim / NIBBLES_PER_WORD;
    let groups_per_row = in_dim / GROUP_SIZE;
    let words_per_group = GROUP_SIZE / NIBBLES_PER_WORD;
    let want_packed = out_dim * words_per_row;
    let want_affine = out_dim * groups_per_row;

    // One shared shape for every length check keeps the error construction
    // in a single place.
    let check_len = |field: &'static str, got: usize, expected: usize| {
        if got == expected {
            Ok(())
        } else {
            Err(CpuMatvecError::SliceLen {
                field,
                got,
                expected,
            })
        }
    };
    check_len("packed", packed.len(), want_packed)?;
    check_len("scales", scales.len(), want_affine)?;
    check_len("biases", biases.len(), want_affine)?;
    check_len("x", x.len(), in_dim)?;
    check_len("out", out.len(), out_dim)?;

    out.par_iter_mut().enumerate().for_each(|(row, dst)| {
        let row_words = &packed[row * words_per_row..][..words_per_row];
        let row_scales = &scales[row * groups_per_row..][..groups_per_row];
        let row_biases = &biases[row * groups_per_row..][..groups_per_row];

        let mut sum = 0.0f32;
        for (group, (&scale_bits, &bias_bits)) in
            row_scales.iter().zip(row_biases).enumerate()
        {
            let scale = bf16_to_f32(scale_bits);
            let bias = bf16_to_f32(bias_bits);
            let group_words =
                &row_words[group * words_per_group..][..words_per_group];
            let group_x = &x[group * GROUP_SIZE..][..GROUP_SIZE];

            // Each packed word lines up with eight consecutive inputs.
            for (&word, lanes) in
                group_words.iter().zip(group_x.chunks_exact(NIBBLES_PER_WORD))
            {
                for (lane, &x_val) in lanes.iter().enumerate() {
                    let quant = ((word >> (lane * 4)) & 0xF) as f32;
                    let weight = quant.mul_add(scale, bias);
                    sum = weight.mul_add(x_val, sum);
                }
            }
        }
        *dst = sum;
    });

    Ok(())
}
/// Computes `out = W · x` on the CPU for a row-major matrix `W` stored as
/// 8-bit group-quantized weights.
///
/// Layout, as consumed by this function:
/// * `packed` holds `out_dim` rows of `in_dim / 4` words; each `u32` carries
///   four 8-bit values, lowest byte first.
/// * `scales` / `biases` hold `out_dim` rows of `in_dim / GROUP_SIZE` raw
///   16-bit values that are decoded with `bf16_to_f32`; every run of
///   `GROUP_SIZE` consecutive weights in a row shares one scale and one bias.
/// * A weight is reconstructed as `byte * scale + bias`.
///
/// Output rows are computed independently on the rayon thread pool; within a
/// row the accumulation runs strictly left to right with fused
/// multiply-adds.
///
/// # Errors
/// * [`CpuMatvecError::InDimNotMultiple`] if `in_dim` is not a multiple of
///   `GROUP_SIZE`.
/// * [`CpuMatvecError::SliceLen`] if `packed`, `scales`, `biases`, `x` or
///   `out` (checked in that order) has the wrong number of elements.
pub fn dequant_matvec_8bit_v3_cpu(
    packed: &[u32],
    scales: &[u16],
    biases: &[u16],
    in_dim: usize,
    out_dim: usize,
    x: &[f32],
    out: &mut [f32],
) -> Result<(), CpuMatvecError> {
    const BYTES_PER_WORD: usize = 4;

    if in_dim % GROUP_SIZE != 0 {
        return Err(CpuMatvecError::InDimNotMultiple {
            in_dim,
            group_size: GROUP_SIZE,
        });
    }

    let words_per_row = in_dim / BYTES_PER_WORD;
    let groups_per_row = in_dim / GROUP_SIZE;
    let words_per_group = GROUP_SIZE / BYTES_PER_WORD;
    let want_packed = out_dim * words_per_row;
    let want_affine = out_dim * groups_per_row;

    // One shared shape for every length check keeps the error construction
    // in a single place.
    let check_len = |field: &'static str, got: usize, expected: usize| {
        if got == expected {
            Ok(())
        } else {
            Err(CpuMatvecError::SliceLen {
                field,
                got,
                expected,
            })
        }
    };
    check_len("packed", packed.len(), want_packed)?;
    check_len("scales", scales.len(), want_affine)?;
    check_len("biases", biases.len(), want_affine)?;
    check_len("x", x.len(), in_dim)?;
    check_len("out", out.len(), out_dim)?;

    out.par_iter_mut().enumerate().for_each(|(row, dst)| {
        let row_words = &packed[row * words_per_row..][..words_per_row];
        let row_scales = &scales[row * groups_per_row..][..groups_per_row];
        let row_biases = &biases[row * groups_per_row..][..groups_per_row];

        let mut sum = 0.0f32;
        for (group, (&scale_bits, &bias_bits)) in
            row_scales.iter().zip(row_biases).enumerate()
        {
            let scale = bf16_to_f32(scale_bits);
            let bias = bf16_to_f32(bias_bits);
            let group_words =
                &row_words[group * words_per_group..][..words_per_group];
            let group_x = &x[group * GROUP_SIZE..][..GROUP_SIZE];

            // Each packed word lines up with four consecutive inputs.
            for (&word, lanes) in
                group_words.iter().zip(group_x.chunks_exact(BYTES_PER_WORD))
            {
                for (lane, &x_val) in lanes.iter().enumerate() {
                    let quant = ((word >> (lane * 8)) & 0xFF) as f32;
                    let weight = quant.mul_add(scale, bias);
                    sum = weight.mul_add(x_val, sum);
                }
            }
        }
        *dst = sum;
    });

    Ok(())
}
/// Byte-slice front end for [`dequant_matvec_4bit_cpu`].
///
/// Reinterprets `weight_bytes` as `u32` words and `scales_bytes` /
/// `biases_bytes` as `u16` values (no copy), then forwards to the typed
/// kernel with the remaining arguments unchanged.
///
/// # Errors
/// * [`CpuMatvecError::Misaligned`] if a buffer cannot be viewed in full as
///   the target integer type; the reported names are `"packed"`, `"scales"`
///   and `"biases"`, checked in that order.
/// * Any error returned by [`dequant_matvec_4bit_cpu`].
pub fn dequant_matvec_4bit_bytes_cpu(
    weight_bytes: &[u8],
    scales_bytes: &[u8],
    biases_bytes: &[u8],
    in_dim: usize,
    out_dim: usize,
    x: &[f32],
    out: &mut [f32],
) -> Result<(), CpuMatvecError> {
    // Arguments are evaluated left to right, so the first buffer that fails
    // its cast is the one reported.
    dequant_matvec_4bit_cpu(
        bytes_as_u32("packed", weight_bytes)?,
        bytes_as_u16("scales", scales_bytes)?,
        bytes_as_u16("biases", biases_bytes)?,
        in_dim,
        out_dim,
        x,
        out,
    )
}
/// Computes `out = W · x` on the CPU for an unquantized row-major matrix.
///
/// `weight` holds `out_dim` rows of `in_dim` raw 16-bit values, each decoded
/// with `bf16_to_f32` before use. Output rows are computed independently on
/// the rayon thread pool; each row is a left-to-right chain of fused
/// multiply-adds starting from `0.0`.
///
/// # Errors
/// [`CpuMatvecError::SliceLen`] if `weight`, `x` or `out` (checked in that
/// order) does not have `in_dim * out_dim`, `in_dim` or `out_dim` elements
/// respectively.
pub fn bf16_matvec_cpu(
    weight: &[u16],
    in_dim: usize,
    out_dim: usize,
    x: &[f32],
    out: &mut [f32],
) -> Result<(), CpuMatvecError> {
    let want_weight = in_dim * out_dim;

    let check_len = |field: &'static str, got: usize, expected: usize| {
        if got == expected {
            Ok(())
        } else {
            Err(CpuMatvecError::SliceLen {
                field,
                got,
                expected,
            })
        }
    };
    check_len("weight", weight.len(), want_weight)?;
    check_len("x", x.len(), in_dim)?;
    check_len("out", out.len(), out_dim)?;

    out.par_iter_mut().enumerate().for_each(|(row, dst)| {
        let row_bits = &weight[row * in_dim..][..in_dim];
        *dst = row_bits
            .iter()
            .zip(x)
            .fold(0.0f32, |sum, (&bits, &x_val)| {
                bf16_to_f32(bits).mul_add(x_val, sum)
            });
    });

    Ok(())
}
/// Runs [`bf16_matvec_cpu`] against the tensor called `name` in `wf`.
///
/// The tensor must be recorded with shape `[out_dim, in_dim]`; its bytes are
/// viewed in place as 16-bit values.
///
/// # Errors
/// * [`CpuMatvecError::MissingTensor`] if `wf` has no such tensor.
/// * [`CpuMatvecError::ShapeMismatch`] if the recorded shape differs.
/// * [`CpuMatvecError::Misaligned`] if the bytes cannot be viewed as `u16`.
/// * Any error returned by [`bf16_matvec_cpu`].
pub fn project_bf16_cpu(
    wf: &WeightFile,
    name: &str,
    in_dim: usize,
    out_dim: usize,
    x: &[f32],
    out: &mut [f32],
) -> Result<(), CpuMatvecError> {
    let raw = match wf.tensor_bytes(name) {
        Some(found) => found,
        None => {
            return Err(CpuMatvecError::MissingTensor {
                name: name.to_string(),
            })
        }
    };

    expect_shape(wf, name, &[out_dim, in_dim])?;

    bf16_matvec_cpu(bytes_as_u16(name, raw)?, in_dim, out_dim, x, out)
}
/// Runs [`dequant_matvec_4bit_cpu`] against the quantized projection stored
/// under the prefix `name` in `wf`.
///
/// Three tensors are looked up: `{name}.weight` (packed 4-bit values, shape
/// `[out_dim, in_dim / 8]`), `{name}.scales` and `{name}.biases` (both shape
/// `[out_dim, in_dim / GROUP_SIZE]`). Their bytes are viewed in place as
/// `u32` / `u16` values and handed to the kernel.
///
/// # Errors
/// Checks run in this order, and the first failure is returned:
/// * [`CpuMatvecError::MissingTensor`] for weight, scales, then biases.
/// * [`CpuMatvecError::ShapeMismatch`] for weight, scales, then biases.
/// * [`CpuMatvecError::Misaligned`] for weight, scales, then biases.
/// * Any error returned by [`dequant_matvec_4bit_cpu`] (which is also where
///   an `in_dim` that is not a multiple of `GROUP_SIZE` is rejected).
pub fn project_4bit_cpu(
    wf: &WeightFile,
    name: &str,
    in_dim: usize,
    out_dim: usize,
    x: &[f32],
    out: &mut [f32],
) -> Result<(), CpuMatvecError> {
    let w_name = format!("{name}.weight");
    let s_name = format!("{name}.scales");
    let b_name = format!("{name}.biases");

    let w_bytes = wf
        .tensor_bytes(&w_name)
        .ok_or_else(|| CpuMatvecError::MissingTensor {
            name: w_name.clone(),
        })?;
    let s_bytes = wf
        .tensor_bytes(&s_name)
        .ok_or_else(|| CpuMatvecError::MissingTensor {
            name: s_name.clone(),
        })?;
    let b_bytes = wf
        .tensor_bytes(&b_name)
        .ok_or_else(|| CpuMatvecError::MissingTensor {
            name: b_name.clone(),
        })?;

    // Expected shapes live on the stack: `expect_shape` only borrows them as
    // slices, so a heap-allocated `Vec` per call would be wasted work.
    let expected_w = [out_dim, in_dim / 8];
    let expected_s = [out_dim, in_dim / GROUP_SIZE];
    expect_shape(wf, &w_name, &expected_w)?;
    expect_shape(wf, &s_name, &expected_s)?;
    expect_shape(wf, &b_name, &expected_s)?;

    let packed = bytes_as_u32(&w_name, w_bytes)?;
    let scales = bytes_as_u16(&s_name, s_bytes)?;
    let biases = bytes_as_u16(&b_name, b_bytes)?;
    dequant_matvec_4bit_cpu(packed, scales, biases, in_dim, out_dim, x, out)
}
/// Verifies that the tensor `name` in `wf` is recorded with exactly the
/// shape `expected`.
///
/// # Errors
/// * [`CpuMatvecError::MissingTensor`] if `wf` has no info for `name`.
/// * [`CpuMatvecError::ShapeMismatch`] (carrying both the recorded and the
///   expected shape) if they differ.
fn expect_shape(
    wf: &WeightFile,
    name: &str,
    expected: &[usize],
) -> Result<(), CpuMatvecError> {
    let info = match wf.tensor_info(name) {
        Some(found) => found,
        None => {
            return Err(CpuMatvecError::MissingTensor {
                name: name.to_string(),
            })
        }
    };

    if info.shape == expected {
        return Ok(());
    }

    Err(CpuMatvecError::ShapeMismatch {
        name: name.to_string(),
        shape: info.shape.clone(),
        expected: expected.to_vec(),
    })
}
/// Views `bytes` in place as a slice of `u32`, without copying.
///
/// Succeeds only when the whole buffer maps onto `u32` elements, i.e. no
/// bytes are left over before or after the aligned middle part. `name` is
/// used solely to label the error.
///
/// # Errors
/// [`CpuMatvecError::Misaligned`] (with `align: 4`) if any leading or
/// trailing bytes remain.
fn bytes_as_u32<'a>(
    name: &str,
    bytes: &'a [u8],
) -> Result<&'a [u32], CpuMatvecError> {
    static_assertions::assert_eq_size!(u32, [u8; 4]);
    static_assertions::const_assert_eq!(std::mem::align_of::<u32>(), 4);

    // SAFETY: every bit pattern is a valid `u32`, and `align_to` only places
    // correctly aligned, in-bounds elements in the middle slice.
    match unsafe { bytes.align_to::<u32>() } {
        ([], words, []) => Ok(words),
        _ => Err(CpuMatvecError::Misaligned {
            name: name.to_string(),
            align: 4,
        }),
    }
}
/// Views `bytes` in place as a slice of `u16`, without copying.
///
/// Succeeds only when the whole buffer maps onto `u16` elements, i.e. no
/// bytes are left over before or after the aligned middle part. `name` is
/// used solely to label the error.
///
/// # Errors
/// [`CpuMatvecError::Misaligned`] (with `align: 2`) if any leading or
/// trailing bytes remain.
fn bytes_as_u16<'a>(
    name: &str,
    bytes: &'a [u8],
) -> Result<&'a [u16], CpuMatvecError> {
    static_assertions::assert_eq_size!(u16, [u8; 2]);
    static_assertions::const_assert_eq!(std::mem::align_of::<u16>(), 2);

    // SAFETY: every bit pattern is a valid `u16`, and `align_to` only places
    // correctly aligned, in-bounds elements in the middle slice.
    match unsafe { bytes.align_to::<u16>() } {
        ([], halves, []) => Ok(halves),
        _ => Err(CpuMatvecError::Misaligned {
            name: name.to_string(),
            align: 2,
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw 16-bit pattern the kernels decode to `1.0`.
    const BF16_ONE: u16 = 0x3F80;
    /// Raw 16-bit pattern the kernels decode to `0.0`.
    const BF16_ZERO: u16 = 0x0000;

    /// Every nibble is 1 with scale 1 and bias 0, so a row of 64 ones sums
    /// to 64.
    #[test]
    fn dequant_matvec_all_ones() {
        let (in_dim, out_dim) = (64, 1);
        let packed = [0x1111_1111u32; 8];
        let scales = [BF16_ONE];
        let biases = [BF16_ZERO];
        let x = vec![1.0f32; in_dim];
        let mut out = vec![0.0f32; out_dim];

        dequant_matvec_4bit_cpu(
            &packed, &scales, &biases, in_dim, out_dim, &x, &mut out,
        )
        .unwrap();

        assert_eq!(out[0], 64.0);
    }

    /// 8-bit counterpart of the all-ones check: every byte is 1.
    #[test]
    fn dequant_matvec_8bit_all_ones() {
        let (in_dim, out_dim) = (64, 1);
        let packed = vec![0x0101_0101u32; in_dim / 4];
        let scales = [BF16_ONE];
        let biases = [BF16_ZERO];
        let x = vec![1.0f32; in_dim];
        let mut out = vec![0.0f32; out_dim];

        dequant_matvec_8bit_v3_cpu(
            &packed, &scales, &biases, in_dim, out_dim, &x, &mut out,
        )
        .unwrap();

        assert_eq!(out[0], 64.0);
    }

    /// Whatever the weights are, a zero input must overwrite the sentinel
    /// values in `out` with zeros.
    #[test]
    fn zero_input_yields_zero_output() {
        let (in_dim, out_dim) = (64, 2);
        let packed = vec![0xFFFF_FFFFu32; 8 * out_dim];
        let scales = vec![BF16_ONE; out_dim];
        let biases = vec![BF16_ONE; out_dim];
        let x = vec![0.0f32; in_dim];
        let mut out = vec![999.0f32; out_dim];

        dequant_matvec_4bit_cpu(
            &packed, &scales, &biases, in_dim, out_dim, &x, &mut out,
        )
        .unwrap();

        assert_eq!(out, vec![0.0; out_dim]);
    }

    /// All nibbles are 0, so the scale (5.0) drops out and each weight is
    /// just the bias (2.0): 64 * 2.0 = 128.
    #[test]
    fn bias_only_path() {
        let (in_dim, out_dim) = (64, 1);
        let packed = [0u32; 8];
        let scales = [0x40A0u16];
        let biases = [0x4000u16];
        let x = vec![1.0f32; in_dim];
        let mut out = vec![0.0f32; out_dim];

        dequant_matvec_4bit_cpu(
            &packed, &scales, &biases, in_dim, out_dim, &x, &mut out,
        )
        .unwrap();

        assert_eq!(out[0], 128.0);
    }

    /// A `packed` slice that is one word short must be rejected with a
    /// `SliceLen` error naming the field and both lengths.
    #[test]
    fn slice_length_mismatch_errors() {
        let (in_dim, out_dim) = (64, 1);
        let packed = [0u32; 7];
        let scales = [BF16_ONE];
        let biases = [BF16_ZERO];
        let x = vec![1.0f32; in_dim];
        let mut out = vec![0.0f32; out_dim];

        let err = dequant_matvec_4bit_cpu(
            &packed, &scales, &biases, in_dim, out_dim, &x, &mut out,
        )
        .unwrap_err();

        match err {
            CpuMatvecError::SliceLen {
                field,
                got,
                expected,
            } => {
                assert_eq!(field, "packed");
                assert_eq!(got, 7);
                assert_eq!(expected, 8);
            }
            _ => panic!("wrong error variant: {err:?}"),
        }
    }

    /// A 4x4 identity matrix must return the input vector unchanged.
    #[test]
    fn bf16_matvec_identity() {
        let (in_dim, out_dim) = (4, 4);
        // Ones on the diagonal, zeros elsewhere.
        let weight: Vec<u16> = (0..in_dim * out_dim)
            .map(|idx| {
                if idx / in_dim == idx % in_dim {
                    BF16_ONE
                } else {
                    BF16_ZERO
                }
            })
            .collect();
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let mut out = vec![0.0f32; out_dim];

        bf16_matvec_cpu(&weight, in_dim, out_dim, &x, &mut out).unwrap();

        assert_eq!(out, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Smoke test against real weights: a one-hot input at index 3 selects
    /// column 3 of the dequantized matrix, which must be finite and of a
    /// plausible magnitude.
    #[cfg(feature = "model-cogito-v2-671b")]
    #[test]
    #[ignore = "needs Cogito-V2 weights mmap'd from /Volumes/Temp Backup"]
    fn q_a_proj_smoke_against_real_weights() {
        use std::path::Path;

        let bin = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.bin",
        );
        let manifest = Path::new(
            "/Volumes/Temp Backup/models/blallama/cogito-v2-671b/artifacts/model_weights.json",
        );
        let wf = WeightFile::open(bin, manifest).expect("open weights");

        let (in_dim, out_dim) = (7168, 1536);
        let mut x = vec![0.0f32; in_dim];
        x[3] = 1.0;
        let mut out = vec![0.0f32; out_dim];

        project_4bit_cpu(
            &wf,
            "model.layers.0.self_attn.q_a_proj",
            in_dim,
            out_dim,
            &x,
            &mut out,
        )
        .unwrap();

        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite values in q_a_proj output: \
             first nonfinite at index {:?}",
            out.iter().position(|v| !v.is_finite()),
        );

        let max_abs = out.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        assert!(
            max_abs > 1e-4 && max_abs < 1e3,
            "q_a_proj column-3 magnitude {max_abs} outside sane range — \
             possible wrong layout or wrong tensor",
        );
    }

    /// An `in_dim` that is not a multiple of `GROUP_SIZE` is rejected before
    /// any slice length is examined.
    #[test]
    fn in_dim_alignment_check() {
        let (in_dim, out_dim) = (65, 1);
        let packed = [0u32; 8];
        let scales = [BF16_ONE];
        let biases = [BF16_ZERO];
        let x = vec![1.0f32; in_dim];
        let mut out = vec![0.0f32; out_dim];

        let err = dequant_matvec_4bit_cpu(
            &packed, &scales, &biases, in_dim, out_dim, &x, &mut out,
        )
        .unwrap_err();

        assert!(matches!(err, CpuMatvecError::InDimNotMultiple { .. }));
    }
}