use super::{i4_packed_words, I4_LANES_PER_WORD};
pub fn pack_i4x8_cpu(values: &[i32]) -> Vec<u32> {
let mut out = Vec::new();
try_pack_i4x8_cpu_into(values, &mut out).unwrap_or_else(|error| panic!("{error}"));
out
}
pub fn pack_i4x8_cpu_into(values: &[i32], out: &mut Vec<u32>) {
try_pack_i4x8_cpu_into(values, out).unwrap_or_else(|error| panic!("{error}"));
}
/// Packs `values` as 4-bit two's-complement lanes into `out`, reusing its allocation.
///
/// Each value is clamped to `-8..=7` and stored as one nibble. Lane `i` lands in
/// word `i / I4_LANES_PER_WORD` at bit offset `(i % I4_LANES_PER_WORD) * 4`.
/// `out` is cleared and resized to `i4_packed_words(lane_count)` zeroed words
/// before any lane is written, so unused nibbles stay zero.
///
/// # Errors
///
/// Returns an error string when `values.len()` does not fit in a `u32`, or when
/// `reserve_graph_items` reports a failure while growing `out`.
pub fn try_pack_i4x8_cpu_into(values: &[i32], out: &mut Vec<u32>) -> Result<(), String> {
    let lane_count = match u32::try_from(values.len()) {
        Ok(count) => count,
        Err(_) => {
            return Err(format!(
                "pack_i4x8 CPU oracle received {} lanes, exceeding u32 lane count. Fix: shard quantized activations before parity evaluation.",
                values.len()
            ));
        }
    };
    let word_count = i4_packed_words(lane_count) as usize;

    // Grow only when the existing allocation is too small; the request is
    // expressed relative to the current length.
    if out.capacity() < word_count {
        let additional = word_count - out.len();
        crate::graph::scratch::reserve_graph_items(
            out,
            additional,
            "quantized INT4 CPU oracle",
            "pack_i4x8 output words",
        )?;
    }

    out.clear();
    out.resize(word_count, 0);

    // Build each output word from the run of lanes that belongs to it.
    let lanes_per_word = I4_LANES_PER_WORD as usize;
    for (word_index, lanes) in values.chunks(lanes_per_word).enumerate() {
        out[word_index] = lanes
            .iter()
            .enumerate()
            .fold(0_u32, |word, (slot, &value)| {
                // Masking the clamped value keeps its low two's-complement nibble.
                let nibble = (value.clamp(-8, 7) & 0x0F) as u32;
                word | (nibble << (slot * 4))
            });
    }
    Ok(())
}
/// Unpacks the first `lane_count` 4-bit lanes of `packed` into a new vector.
///
/// Convenience wrapper that allocates the output and delegates to
/// [`try_unpack_i4x8_cpu_into`].
///
/// # Panics
///
/// Panics with the error text if [`try_unpack_i4x8_cpu_into`] returns an error.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn unpack_i4x8_cpu(packed: &[u32], lane_count: u32) -> Vec<i32> {
    let mut lanes = Vec::new();
    if let Err(error) = try_unpack_i4x8_cpu_into(packed, lane_count, &mut lanes) {
        panic!("{error}");
    }
    lanes
}
/// Unpacks the first `lane_count` 4-bit lanes of `packed` into the caller's `out`.
///
/// Delegates to [`try_unpack_i4x8_cpu_into`], which replaces the previous
/// contents of `out`.
///
/// # Panics
///
/// Panics with the error text if [`try_unpack_i4x8_cpu_into`] returns an error.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn unpack_i4x8_cpu_into(packed: &[u32], lane_count: u32, out: &mut Vec<i32>) {
    if let Err(error) = try_unpack_i4x8_cpu_into(packed, lane_count, out) {
        panic!("{error}");
    }
}
/// Unpacks the first `lane_count` 4-bit lanes of `packed` into `out` as `i32`s.
///
/// `out` is cleared and then filled with exactly `lane_count` values, each
/// decoded by `extract_i4_lane`. Lanes whose word lies past the end of
/// `packed` decode as `0`.
///
/// # Errors
///
/// Returns an error string when `reserve_graph_items` reports a failure while
/// growing `out`.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn try_unpack_i4x8_cpu_into(
    packed: &[u32],
    lane_count: u32,
    out: &mut Vec<i32>,
) -> Result<(), String> {
    let lanes = lane_count as usize;

    // Grow only when the existing allocation is too small; the request is
    // expressed relative to the current length.
    if out.capacity() < lanes {
        let additional = lanes - out.len();
        crate::graph::scratch::reserve_graph_items(
            out,
            additional,
            "quantized INT4 CPU oracle",
            "unpack_i4x8 output lanes",
        )?;
    }

    out.clear();
    out.extend((0..lanes).map(|lane| extract_i4_lane(packed, lane)));
    Ok(())
}
/// Integer dot product of two packed 4-bit vectors over the first `lane_count` lanes.
///
/// Products and the running sum use wrapping `i32` arithmetic. Lanes whose word
/// lies past the end of either slice contribute `0`.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn i4x8_dot_i32_cpu(lhs_packed: &[u32], rhs_packed: &[u32], lane_count: u32) -> i32 {
    (0..lane_count as usize).fold(0_i32, |sum, lane| {
        let product =
            extract_i4_lane(lhs_packed, lane).wrapping_mul(extract_i4_lane(rhs_packed, lane));
        sum.wrapping_add(product)
    })
}
/// Floating-point dot product of two packed 4-bit vectors, scaled by both scales.
///
/// Lane products are accumulated in `f32` in ascending lane order, and the sum
/// is then multiplied by `lhs_scale` and `rhs_scale` in that order. Lanes whose
/// word lies past the end of either slice contribute `0`.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn i4x8_dot_f32_scaled_cpu(
    lhs_packed: &[u32],
    rhs_packed: &[u32],
    lhs_scale: f32,
    rhs_scale: f32,
    lane_count: u32,
) -> f32 {
    // An explicit sequential fold keeps the summation order fixed.
    let unscaled = (0..lane_count as usize).fold(0.0_f32, |sum, lane| {
        let left = extract_i4_lane(lhs_packed, lane) as f32;
        let right = extract_i4_lane(rhs_packed, lane) as f32;
        sum + left * right
    });
    unscaled * lhs_scale * rhs_scale
}
/// Matrix-vector product of packed 4-bit weights with an `f32` vector.
///
/// `weights_packed` is read as `rows` consecutive rows of
/// `i4_packed_words(cols)` words each. Output element `row` is the sum over
/// `col` of `weight[row][col] * x[col]`, accumulated in ascending column order
/// and then multiplied by `row_scales[row]`.
///
/// Missing input is treated as zero instead of panicking: elements past the end
/// of `x`, scales past the end of `row_scales`, and weight words past the end
/// of `weights_packed`. That last case now includes rows that begin beyond the
/// buffer, which previously panicked on the row slice.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn i4x8_matvec_f32_scaled_cpu(
    weights_packed: &[u32],
    x: &[f32],
    row_scales: &[f32],
    rows: u32,
    cols: u32,
) -> Vec<f32> {
    let words_per_row = i4_packed_words(cols) as usize;
    let col_count = cols as usize;
    (0..rows as usize)
        .map(|row| {
            // Slice the row once per row; a row that starts past the end of the
            // buffer is an empty row, matching how missing words are handled.
            let row_words = weights_packed.get(row * words_per_row..).unwrap_or(&[]);
            let mut acc = 0.0_f32;
            for col in 0..col_count {
                acc += extract_i4_lane(row_words, col) as f32
                    * x.get(col).copied().unwrap_or(0.0);
            }
            acc * row_scales.get(row).copied().unwrap_or(0.0)
        })
        .collect()
}
/// Runs [`i4x8_matvec_f32_scaled_cpu`] once per batch and concatenates the results.
///
/// `x_batches` is read as `batch` consecutive vectors of `cols` elements. The
/// output holds `batch * rows` values, with batch `b` occupying
/// `b * rows..(b + 1) * rows`.
///
/// If the full `cols`-element window for a batch is not available in
/// `x_batches`, that batch is evaluated against an empty vector, so every
/// activation is treated as `0.0`.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn i4x8_batched_matvec_f32_scaled_cpu(
    weights_packed: &[u32],
    x_batches: &[f32],
    row_scales: &[f32],
    batch: u32,
    rows: u32,
    cols: u32,
) -> Vec<f32> {
    let batch_count = batch as usize;
    let row_count = rows as usize;
    let col_count = cols as usize;
    // Widen before multiplying: `batch * rows` in `u32` can overflow even though
    // the element count is representable as `usize` on 64-bit targets.
    let mut out = Vec::with_capacity(batch_count * row_count);
    for batch_index in 0..batch_count {
        let x_start = batch_index * col_count;
        let x = x_batches.get(x_start..x_start + col_count).unwrap_or(&[]);
        out.extend(i4x8_matvec_f32_scaled_cpu(
            weights_packed,
            x,
            row_scales,
            rows,
            cols,
        ));
    }
    out
}
/// Batched product of packed 4-bit weights with packed 4-bit activations.
///
/// Both `weights_packed` (one row per output row) and
/// `activation_batches_packed` (one row per batch) are read with a stride of
/// `i4_packed_words(cols)` words. Output element `batch_index * rows + row` is
/// the `f32` sum over `col` of `weight[row][col] * activation[batch_index][col]`,
/// multiplied by `row_scales[row]` and then by `batch_scales[batch_index]`.
///
/// Missing input is treated as zero instead of panicking: scales past the end
/// of either scale slice, and packed words past the end of either packed
/// buffer. That includes rows or batches that begin beyond the buffer, which
/// previously panicked on the row slice.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn i4x8_batched_matmul_f32_scaled_cpu(
    weights_packed: &[u32],
    activation_batches_packed: &[u32],
    row_scales: &[f32],
    batch_scales: &[f32],
    batch: u32,
    rows: u32,
    cols: u32,
) -> Vec<f32> {
    let words_per_row = i4_packed_words(cols) as usize;
    let batch_count = batch as usize;
    let row_count = rows as usize;
    let col_count = cols as usize;
    // Widen before multiplying: `batch * rows` in `u32` can overflow even though
    // the element count is representable as `usize` on 64-bit targets.
    let mut out = Vec::with_capacity(batch_count * row_count);
    for batch_index in 0..batch_count {
        // Slice each operand once per row rather than once per column.
        let activations = activation_batches_packed
            .get(batch_index * words_per_row..)
            .unwrap_or(&[]);
        let batch_scale = batch_scales.get(batch_index).copied().unwrap_or(0.0);
        for row in 0..row_count {
            let weights = weights_packed.get(row * words_per_row..).unwrap_or(&[]);
            let mut acc = 0.0_f32;
            for col in 0..col_count {
                acc += extract_i4_lane(weights, col) as f32
                    * extract_i4_lane(activations, col) as f32;
            }
            // Pushes happen in batch-major, row-minor order, so this lands at
            // index `batch_index * rows + row`.
            out.push(acc * row_scales.get(row).copied().unwrap_or(0.0) * batch_scale);
        }
    }
    out
}
/// Computes batched logits and returns the best score and row index per batch.
///
/// Logits come from [`i4x8_batched_matmul_f32_scaled_cpu`]. For each batch the
/// rows are scanned in ascending order and a row wins only when its logit is
/// strictly greater than the current best, so ties keep the lowest row index.
/// The scan starts from `f32::MIN` with index `0`; a batch in which no logit
/// compares greater than that (for example when `rows` is `0` or every logit is
/// NaN) keeps those initial values.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn i4x8_batched_matmul_top1_f32_scaled_cpu(
    weights_packed: &[u32],
    activation_batches_packed: &[u32],
    row_scales: &[f32],
    batch_scales: &[f32],
    batch: u32,
    rows: u32,
    cols: u32,
) -> (Vec<f32>, Vec<u32>) {
    let logits = i4x8_batched_matmul_f32_scaled_cpu(
        weights_packed,
        activation_batches_packed,
        row_scales,
        batch_scales,
        batch,
        rows,
        cols,
    );
    let batch_count = batch as usize;
    let row_count = rows as usize;
    let mut scores = vec![f32::MIN; batch_count];
    let mut indices = vec![0_u32; batch_count];

    let winners = scores.iter_mut().zip(indices.iter_mut());
    for (batch_index, (best_score, best_row)) in winners.enumerate() {
        let start = batch_index * row_count;
        let batch_logits = &logits[start..start + row_count];
        for (row, &candidate) in batch_logits.iter().enumerate() {
            if candidate > *best_score {
                *best_score = candidate;
                *best_row = row as u32;
            }
        }
    }
    (scores, indices)
}
/// Reads lane `lane` from `packed` as a sign-extended 4-bit integer.
///
/// The lane lives in word `lane / I4_LANES_PER_WORD` at bit offset
/// `(lane % I4_LANES_PER_WORD) * 4`. A word index past the end of `packed`
/// reads as `0`. The result is in `-8..=7`.
#[cfg(any(test, feature = "cpu-parity"))]
fn extract_i4_lane(packed: &[u32], lane: usize) -> i32 {
    let lanes_per_word = I4_LANES_PER_WORD as usize;
    let word = packed.get(lane / lanes_per_word).map_or(0, |&word| word);
    let bit_offset = 4 * (lane % lanes_per_word);
    let raw = (word >> bit_offset) & 0x0F;
    // Move the nibble to the top of the word, then arithmetic-shift it back so
    // bit 3 is replicated as the sign.
    ((raw << 28) as i32) >> 28
}