use half::f16;
use oxibonsai_core::quant_ternary::BlockTQ2_0_g128;
use super::error::PackError;
/// Input columns that share one scale/bias pair; each such group becomes one output block.
const GROUP_SIZE: usize = 128;
/// 2-bit codes packed into each `u32` weight word (32 bits / 2 bits per code).
const CODES_PER_U32: usize = (u32::BITS / 2) as usize;
/// `u32` weight words that hold the codes of a single group.
const U32_WORDS_PER_GROUP: usize = GROUP_SIZE / CODES_PER_U32;
/// Widen raw bfloat16 bits to an `f32`.
///
/// bf16 is the upper 16 bits of an IEEE-754 binary32. Placing `bits` in the high
/// half of a `u32` (with a zero low half) therefore reproduces the value exactly,
/// including signed zeros, infinities and NaN payloads.
#[inline]
pub fn bf16_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
/// Repack one 2-bit quantized linear module into `BlockTQ2_0_g128` blocks.
///
/// All buffers are row-major for a module with `out_features` rows and
/// `in_features` columns:
/// - `weight`: `in_features / 16` `u32` words per row. Each word holds sixteen
///   2-bit codes; lane `i` occupies bits `2*i` and `2*i + 1`.
/// - `scales` / `biases`: one raw bf16 value per 128-column group per row.
///
/// Each group becomes one block. Its `qs` is the group's eight words in
/// little-endian byte order, and its `d` is the group scale converted to `f16`.
///
/// Only the ternary subset is accepted: codes must be 0, 1 or 2, and every
/// bias must equal `-scale` exactly. With that bias, `scale * q + bias`
/// (presumably an affine dequant on the producer side — confirm) becomes
/// `scale * (q - 1)`.
///
/// Blocks are returned in row-major group order.
///
/// NOTE(review): the bf16 -> f16 conversion of `scale` is not range-checked.
/// Scales beyond the f16 range become infinite, and tiny ones lose precision
/// or flush toward zero, without raising an error.
///
/// # Errors
/// - `PackError::InFeaturesNotAligned` if `in_features` is 0 or not a multiple of 128.
/// - `PackError::BufferLengthMismatch` for the first of `weight`, `scales`,
///   `biases` (checked in that order) whose length is wrong.
/// - `PackError::CodeOutOfRange` (with `value == 3`) for the first group containing code 3.
/// - `PackError::AsymmetricBias` for the first group whose bias is not exactly `-scale`.
///
/// Groups are visited row by row. Within a group the codes are checked before
/// the bias, so the error reported is the first failure in that order.
pub fn pack_quantized_module(
    module: &str,
    weight: &[u32],
    scales: &[u16],
    biases: &[u16],
    out_features: usize,
    in_features: usize,
) -> Result<Vec<BlockTQ2_0_g128>, PackError> {
    // The only 2-bit value outside {0, 1, 2}.
    const INVALID_CODE: u8 = 0b11;
    // Selects the low bit of each of the sixteen 2-bit lanes in a word.
    const LANE_LOW_BITS: u32 = 0x5555_5555;

    let aligned = in_features != 0 && in_features % GROUP_SIZE == 0;
    if !aligned {
        return Err(PackError::InFeaturesNotAligned {
            module: module.to_string(),
            in_features,
        });
    }

    let words_per_row = in_features / CODES_PER_U32;
    let groups_per_row = in_features / GROUP_SIZE;
    let expected_words = out_features * words_per_row;
    let expected_groups = out_features * groups_per_row;

    // Fixed order: the first mismatching buffer in this table is reported.
    let length_checks = [
        ("weight", weight.len(), expected_words),
        ("scales", scales.len(), expected_groups),
        ("biases", biases.len(), expected_groups),
    ];
    if let Some(&(which, got, expected)) = length_checks
        .iter()
        .find(|&&(_, got, expected)| got != expected)
    {
        return Err(PackError::BufferLengthMismatch {
            module: module.to_string(),
            which,
            got,
            expected,
        });
    }

    let mut blocks = Vec::with_capacity(expected_groups);

    // Lengths are validated above, so these chunk iterators yield exactly
    // `out_features` rows and `groups_per_row` groups per row.
    let rows = weight
        .chunks_exact(words_per_row)
        .zip(scales.chunks_exact(groups_per_row))
        .zip(biases.chunks_exact(groups_per_row));

    for (row, ((row_words, row_scales), row_biases)) in rows.enumerate() {
        let groups = row_words
            .chunks_exact(U32_WORDS_PER_GROUP)
            .zip(row_scales.iter().zip(row_biases));

        for (group, (group_words, (&scale_bits, &bias_bits))) in groups.enumerate() {
            // A lane holds code 3 exactly when both of its bits are set. Shifting
            // right by one lines each lane's high bit up with its low bit.
            let has_invalid_code = group_words
                .iter()
                .any(|&w| (w & (w >> 1) & LANE_LOW_BITS) != 0);
            if has_invalid_code {
                return Err(PackError::CodeOutOfRange {
                    module: module.to_string(),
                    row,
                    group,
                    value: INVALID_CODE,
                });
            }

            let scale = bf16_to_f32(scale_bits);
            let bias = bf16_to_f32(bias_bits);
            // Exact IEEE comparison: 0.0 == -0.0, so a zero scale accepts either
            // zero bias; a NaN scale never matches.
            if bias != -scale {
                return Err(PackError::AsymmetricBias {
                    module: module.to_string(),
                    row,
                    group,
                    bias,
                    scale,
                });
            }

            // 8 words * 4 bytes = 32 bytes of packed codes, kept in word order.
            let mut qs = [0u8; 32];
            for (bytes, word) in qs.chunks_exact_mut(4).zip(group_words) {
                bytes.copy_from_slice(&word.to_le_bytes());
            }

            blocks.push(BlockTQ2_0_g128 {
                qs,
                d: f16::from_f32(scale),
            });
        }
    }

    Ok(blocks)
}
/// Test helper: encode an `f32` as raw bf16 bits via `half::bf16::from_f32`.
///
/// This is the inverse direction of `bf16_to_f32`, used to build scale/bias
/// fixtures.
#[cfg(test)]
pub fn f32_to_bf16(value: f32) -> u16 {
    let converted = half::bf16::from_f32(value);
    half::bf16::to_bits(converted)
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::quant_ternary::BlockTQ2_0_g128;

    /// Pack 128 codes (each expected to fit in 2 bits) into 8 `u32` words,
    /// with code `j` in word `j / 16`, lane `j % 16`.
    fn words_from_codes(codes: &[u8; 128]) -> [u32; 8] {
        let mut words = [0u32; 8];
        for (j, &q) in codes.iter().enumerate() {
            let word = j / 16;
            let lane = j % 16;
            words[word] |= (q as u32) << (lane * 2);
        }
        words
    }

    /// Build a 1x128 module whose every code is `q`, with a symmetric bias (`-scale`).
    fn single_group_uniform(q: u8, scale: f32) -> (Vec<u32>, Vec<u16>, Vec<u16>) {
        let codes = [q; 128];
        let words = words_from_codes(&codes);
        let weight = words.to_vec();
        let scales = vec![f32_to_bf16(scale)];
        let biases = vec![f32_to_bf16(-scale)];
        (weight, scales, biases)
    }

    #[test]
    fn all_q0_gives_negative_scale() {
        let scale = 0.125_f32;
        let (w, s, b) = single_group_uniform(0, scale);
        let blocks =
            pack_quantized_module("test.q0", &w, &s, &b, 1, 128).expect("pack should succeed");
        assert_eq!(blocks.len(), 1);
        let mut out = vec![0.0f32; 128];
        BlockTQ2_0_g128::dequant(&blocks, &mut out).expect("dequant");
        for &v in &out {
            assert_eq!(v, -scale, "q=0 → -scale");
        }
    }

    #[test]
    fn all_q2_gives_positive_scale() {
        let scale = 0.0625_f32;
        let (w, s, b) = single_group_uniform(2, scale);
        let blocks =
            pack_quantized_module("test.q2", &w, &s, &b, 1, 128).expect("pack should succeed");
        let mut out = vec![0.0f32; 128];
        BlockTQ2_0_g128::dequant(&blocks, &mut out).expect("dequant");
        for &v in &out {
            assert_eq!(v, scale, "q=2 → +scale");
        }
    }

    #[test]
    fn all_q1_gives_zero() {
        let scale = 0.5_f32;
        let (w, s, b) = single_group_uniform(1, scale);
        let blocks =
            pack_quantized_module("test.q1", &w, &s, &b, 1, 128).expect("pack should succeed");
        let mut out = vec![0.0f32; 128];
        BlockTQ2_0_g128::dequant(&blocks, &mut out).expect("dequant");
        for &v in &out {
            assert_eq!(v, 0.0, "q=1 → 0");
        }
    }

    #[test]
    fn mixed_pattern_within_group() {
        let scale = 0.25_f32;
        let mut codes = [0u8; 128];
        for (j, c) in codes.iter_mut().enumerate() {
            *c = (j % 3) as u8;
        }
        let words = words_from_codes(&codes);
        let weight = words.to_vec();
        let scales = vec![f32_to_bf16(scale)];
        let biases = vec![f32_to_bf16(-scale)];
        let blocks = pack_quantized_module("test.mixed", &weight, &scales, &biases, 1, 128)
            .expect("pack should succeed");
        let mut out = vec![0.0f32; 128];
        BlockTQ2_0_g128::dequant(&blocks, &mut out).expect("dequant");
        for (j, &v) in out.iter().enumerate() {
            let expected = match j % 3 {
                0 => -scale,
                1 => 0.0,
                _ => scale,
            };
            assert_eq!(v, expected, "index {j}: q={} mismatch", j % 3);
        }
    }

    #[test]
    fn multi_row_multi_group_roundtrip_exact() {
        let out = 4usize;
        let in_features = 256usize;
        let group_cols = in_features / 128;
        let weight_cols = in_features / 16;
        let mut weight = vec![0u32; out * weight_cols];
        let mut scales = vec![0u16; out * group_cols];
        let mut biases = vec![0u16; out * group_cols];
        let mut expected = vec![0.0f32; out * in_features];
        for row in 0..out {
            for g in 0..group_cols {
                // Powers of two are exact in both bf16 and f16.
                let scale = 1.0_f32 / ((1 << (row + g + 1)) as f32);
                scales[row * group_cols + g] = f32_to_bf16(scale);
                biases[row * group_cols + g] = f32_to_bf16(-scale);
                let mut codes = [0u8; 128];
                for (j, c) in codes.iter_mut().enumerate() {
                    let q = ((row + g + j) % 3) as u8;
                    *c = q;
                    let col = g * 128 + j;
                    expected[row * in_features + col] = scale * (q as f32 - 1.0);
                }
                let words = words_from_codes(&codes);
                let word_base = row * weight_cols + g * 8;
                weight[word_base..word_base + 8].copy_from_slice(&words);
            }
        }
        let blocks =
            pack_quantized_module("test.multi", &weight, &scales, &biases, out, in_features)
                .expect("pack should succeed");
        assert_eq!(blocks.len(), out * group_cols);
        let mut deq = vec![0.0f32; out * in_features];
        for row in 0..out {
            for g in 0..group_cols {
                let blk = &blocks[row * group_cols + g..row * group_cols + g + 1];
                let mut tmp = vec![0.0f32; 128];
                BlockTQ2_0_g128::dequant(blk, &mut tmp).expect("dequant");
                let base = row * in_features + g * 128;
                deq[base..base + 128].copy_from_slice(&tmp);
            }
        }
        for (idx, (&a, &e)) in deq.iter().zip(expected.iter()).enumerate() {
            assert_eq!(a, e, "element {idx}: dequant {a} != expected {e}");
        }
    }

    #[test]
    fn errors_on_code_value_3() {
        let scale = 0.25_f32;
        let mut codes = [1u8; 128];
        codes[5] = 3;
        let words = words_from_codes(&codes);
        let weight = words.to_vec();
        let scales = vec![f32_to_bf16(scale)];
        let biases = vec![f32_to_bf16(-scale)];
        let err = pack_quantized_module("test.bad_code", &weight, &scales, &biases, 1, 128)
            .expect_err("q=3 must error");
        match err {
            PackError::CodeOutOfRange { value, .. } => assert_eq!(value, 3),
            other => panic!("expected CodeOutOfRange, got {other:?}"),
        }
    }

    #[test]
    fn code_error_reports_row_and_group() {
        // 2 rows x 2 groups; only row 1, group 1 contains a 3, placed in the
        // highest lane (bits 30..=31) of the group's last word.
        let in_features = 256usize;
        let mut weight: Vec<u32> = Vec::with_capacity(2 * 16);
        for row in 0..2 {
            for g in 0..2 {
                let mut codes = [1u8; 128];
                if row == 1 && g == 1 {
                    codes[127] = 3;
                }
                weight.extend_from_slice(&words_from_codes(&codes));
            }
        }
        let scales = vec![f32_to_bf16(0.5); 4];
        let biases = vec![f32_to_bf16(-0.5); 4];
        let err = pack_quantized_module("test.loc", &weight, &scales, &biases, 2, in_features)
            .expect_err("q=3 must error");
        match err {
            PackError::CodeOutOfRange {
                row, group, value, ..
            } => {
                assert_eq!(row, 1);
                assert_eq!(group, 1);
                assert_eq!(value, 3);
            }
            other => panic!("expected CodeOutOfRange, got {other:?}"),
        }
    }

    #[test]
    fn errors_on_asymmetric_bias() {
        let scale = 0.25_f32;
        let (w, s, _b) = single_group_uniform(0, scale);
        let biases = vec![f32_to_bf16(0.5_f32)];
        let err = pack_quantized_module("test.bad_bias", &w, &s, &biases, 1, 128)
            .expect_err("asymmetric bias must error");
        match err {
            PackError::AsymmetricBias {
                bias, scale: s_out, ..
            } => {
                assert_eq!(bias, 0.5);
                assert_eq!(s_out, scale);
            }
            other => panic!("expected AsymmetricBias, got {other:?}"),
        }
    }

    #[test]
    fn scale_zero_bias_negzero_is_accepted() {
        let codes = [1u8; 128];
        let words = words_from_codes(&codes);
        let weight = words.to_vec();
        // Raw bf16 bits: +0.0 scale and -0.0 bias.
        let scales = vec![0x0000u16];
        let biases = vec![0x8000u16];
        let blocks = pack_quantized_module("test.zero", &weight, &scales, &biases, 1, 128)
            .expect("scale=0 / bias=-0 must be accepted");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].d, f16::from_f32(0.0));
    }

    #[test]
    fn errors_on_wrong_weight_columns() {
        let weight = vec![0u32; 8];
        let scales = vec![f32_to_bf16(0.1); 2];
        let biases = vec![f32_to_bf16(-0.1); 2];
        let err = pack_quantized_module("test.shape", &weight, &scales, &biases, 1, 256)
            .expect_err("wrong weight length must error");
        assert!(matches!(
            err,
            PackError::BufferLengthMismatch {
                which: "weight",
                ..
            }
        ));
    }

    #[test]
    fn errors_on_wrong_scales_length() {
        let (w, _s, b) = single_group_uniform(1, 0.25);
        let scales = vec![f32_to_bf16(0.25); 2];
        let err = pack_quantized_module("test.scales", &w, &scales, &b, 1, 128)
            .expect_err("wrong scales length must error");
        assert!(matches!(
            err,
            PackError::BufferLengthMismatch {
                which: "scales",
                got: 2,
                expected: 1,
                ..
            }
        ));
    }

    #[test]
    fn errors_on_wrong_biases_length() {
        let (w, s, _b) = single_group_uniform(1, 0.25);
        let biases: Vec<u16> = Vec::new();
        let err = pack_quantized_module("test.biases", &w, &s, &biases, 1, 128)
            .expect_err("wrong biases length must error");
        assert!(matches!(
            err,
            PackError::BufferLengthMismatch {
                which: "biases",
                got: 0,
                expected: 1,
                ..
            }
        ));
    }

    #[test]
    fn errors_on_unaligned_in_features() {
        let weight = vec![0u32; 8];
        let scales = vec![0u16; 1];
        let biases = vec![0u16; 1];
        let err = pack_quantized_module("test.align", &weight, &scales, &biases, 1, 100)
            .expect_err("in not multiple of 128 must error");
        assert!(matches!(err, PackError::InFeaturesNotAligned { .. }));
    }

    #[test]
    fn errors_on_zero_in_features() {
        let err = pack_quantized_module("test.zero_in", &[], &[], &[], 1, 0)
            .expect_err("in_features == 0 must error");
        assert!(matches!(
            err,
            PackError::InFeaturesNotAligned { in_features: 0, .. }
        ));
    }

    #[test]
    fn zero_out_features_yields_no_blocks() {
        let blocks = pack_quantized_module("test.empty", &[], &[], &[], 0, 256)
            .expect("an empty module should pack");
        assert!(blocks.is_empty());
    }

    #[test]
    fn bf16_to_f32_roundtrip() {
        for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.0625, 2000.0] {
            let bits = f32_to_bf16(v);
            assert_eq!(bf16_to_f32(bits), v, "value {v}");
        }
    }
}