use super::error::DequantError;
/// The only `block_size` accepted by [`dequantize_matmul_nbits`]; any other value is rejected
/// with an "unsupported" error.
pub const EXPECTED_BLOCK_SIZE: usize = 128;
/// The only `bits` value accepted by [`dequantize_matmul_nbits`]; any other value is rejected
/// with an "unsupported" error.
pub const EXPECTED_BITS: u32 = 2;
/// Splits one byte into its four 2-bit fields, least-significant field first.
///
/// Element `i` of the result holds bits `2 * i` and `2 * i + 1` of `byte`, so every element is
/// in `0..=3`. For example `0b11_10_01_00` unpacks to `[0, 1, 2, 3]`.
#[inline]
pub fn unpack_2bit_le(byte: u8) -> [u8; 4] {
    // Lane `i` lives `2 * i` bits up from the bottom; shift it down and mask off the rest.
    std::array::from_fn(|lane| (byte >> (2 * lane)) & 0b11)
}
/// Dequantizes a block-quantized 2-bit weight matrix into a dense, row-major `n x k` buffer.
///
/// Every output element is reconstructed as `(code - zero_point) * scale`, where the scale and
/// zero point are shared by all codes in the same block of `block_size` consecutive `k`
/// positions within one row.
///
/// # Layout
///
/// * `packed` — `n` rows of `k.div_ceil(block_size) * (block_size / 4)` bytes each. A byte holds
///   four codes, lowest two bits first (see [`unpack_2bit_le`]). Codes at positions `>= k` in
///   the last block of a row are padding and are ignored.
/// * `scales` — one `f32` per (row, block), stored at index `row * n_blocks + block`.
/// * `zero_points` — optional. One 2-bit value per (row, block), packed four per byte and
///   addressed with the same flat `row * n_blocks + block` index (rows are *not* individually
///   byte-aligned). When `None`, every block uses a zero point of 2.
///
/// # Returns
///
/// A `Vec<f32>` of length `n * k` where element `row * k + col` is the dequantized weight.
/// If `n == 0` or `k == 0` an empty vector is returned without inspecting the input buffers.
///
/// # Errors
///
/// * [`DequantError::Unsupported`] if `bits != EXPECTED_BITS`, if
///   `block_size != EXPECTED_BLOCK_SIZE`, or if the buffer sizes implied by `n` and `k` do not
///   fit in `usize`.
/// * [`DequantError::LengthMismatch`] if `packed`, `scales` or `zero_points` does not have
///   exactly the length implied by `n`, `k` and `block_size`.
pub fn dequantize_matmul_nbits(
    packed: &[u8],
    scales: &[f32],
    zero_points: Option<&[u8]>,
    n: usize,
    k: usize,
    bits: u32,
    block_size: usize,
) -> Result<Vec<f32>, DequantError> {
    // Zero point applied to every block when the caller supplies none.
    const DEFAULT_ZERO_POINT: u8 = 2;
    // Number of 2-bit values stored in one byte (codes and zero points alike).
    const FIELDS_PER_BYTE: usize = 4;

    // Builds the error returned when a buffer size implied by the shape overflows `usize`.
    fn size_overflow(n: usize, k: usize) -> DequantError {
        DequantError::Unsupported(format!(
            "n={n}, k={k} is too large: buffer size overflows usize"
        ))
    }

    // Parameter checks come first so unsupported configurations are reported even for
    // empty shapes.
    if bits != EXPECTED_BITS {
        return Err(DequantError::Unsupported(format!(
            "unsupported bits={bits} (only 2 is implemented)"
        )));
    }
    if block_size != EXPECTED_BLOCK_SIZE {
        return Err(DequantError::Unsupported(format!(
            "unsupported block_size={block_size} (only 128 is implemented)"
        )));
    }
    if n == 0 || k == 0 {
        return Ok(Vec::new());
    }

    let n_blocks = k.div_ceil(block_size);
    let bytes_per_block = block_size / FIELDS_PER_BYTE;
    // Cannot overflow: `n_blocks` is at most `usize::MAX / 128 + 1` and `bytes_per_block` is 32.
    let bytes_per_row = n_blocks * bytes_per_block;

    // Checked so that a huge `n` cannot wrap around in release builds and make a short buffer
    // look correctly sized.
    let expected_packed = n
        .checked_mul(bytes_per_row)
        .ok_or_else(|| size_overflow(n, k))?;
    if packed.len() != expected_packed {
        return Err(DequantError::LengthMismatch {
            what: "packed B",
            expected: expected_packed,
            got: packed.len(),
        });
    }

    // Cannot overflow: `n * n_blocks <= n * bytes_per_row`, which was checked above.
    let expected_scales = n * n_blocks;
    if scales.len() != expected_scales {
        return Err(DequantError::LengthMismatch {
            what: "scales",
            expected: expected_scales,
            got: scales.len(),
        });
    }

    if let Some(zp) = zero_points {
        let expected_zp_bytes = expected_scales.div_ceil(FIELDS_PER_BYTE);
        if zp.len() != expected_zp_bytes {
            return Err(DequantError::LengthMismatch {
                what: "zero_points",
                expected: expected_zp_bytes,
                got: zp.len(),
            });
        }
    }

    let out_len = n.checked_mul(k).ok_or_else(|| size_overflow(n, k))?;
    let mut out = vec![0.0_f32; out_len];

    // Walk output rows and packed rows in lockstep; both contain exactly `n` rows.
    let rows = out
        .chunks_exact_mut(k)
        .zip(packed.chunks_exact(bytes_per_row))
        .enumerate();
    for (row, (out_row, packed_row)) in rows {
        // `chunks_mut` leaves the final output block short when `k` is not a multiple of
        // `block_size`; the zips below then stop at the real `k`, skipping padding codes.
        let blocks = out_row
            .chunks_mut(block_size)
            .zip(packed_row.chunks_exact(bytes_per_block))
            .enumerate();
        for (block_idx, (out_block, packed_block)) in blocks {
            // Flat (row, block) index shared by `scales` and the packed zero points.
            let group = row * n_blocks + block_idx;
            let scale = scales[group];
            let zero_point = match zero_points {
                Some(zp) => {
                    (zp[group / FIELDS_PER_BYTE] >> (2 * (group % FIELDS_PER_BYTE))) & 0b11
                }
                None => DEFAULT_ZERO_POINT,
            };
            let zero_point = f32::from(zero_point);

            for (out_quad, &byte) in out_block.chunks_mut(FIELDS_PER_BYTE).zip(packed_block) {
                let codes = unpack_2bit_le(byte);
                for (dst, &code) in out_quad.iter_mut().zip(&codes) {
                    *dst = (f32::from(code) - zero_point) * scale;
                }
            }
        }
    }

    Ok(out)
}
/// Repacks zero points stored as 4-bit nibbles into the 2-bit, four-per-byte layout.
///
/// The input holds two values per byte, low nibble first; the output holds four values per
/// byte, lowest two bits first. Only the first `total_codes` nibbles are consumed, so `zp_4bit`
/// may be longer than strictly necessary. Unused high bits of the final output byte are zero.
///
/// # Returns
///
/// A vector of `total_codes.div_ceil(4)` bytes.
///
/// # Errors
///
/// * [`DequantError::LengthMismatch`] if `zp_4bit` has fewer than `total_codes.div_ceil(2)`
///   bytes.
/// * [`DequantError::NibbleOutOfRange`] for the first consumed nibble whose value exceeds 3 and
///   therefore cannot be represented in two bits; `index` is the nibble's position.
pub fn repack_4bit_zp_to_2bit(zp_4bit: &[u8], total_codes: usize) -> Result<Vec<u8>, DequantError> {
    let needed_bytes = total_codes.div_ceil(2);
    if zp_4bit.len() < needed_bytes {
        return Err(DequantError::LengthMismatch {
            what: "4-bit ZP input",
            expected: needed_bytes,
            got: zp_4bit.len(),
        });
    }

    // Flatten the input into a stream of nibbles (low half of each byte first) and stop after
    // the requested number of values.
    let nibbles = zp_4bit
        .iter()
        .flat_map(|&byte| [byte & 0x0F, byte >> 4])
        .take(total_codes);

    let mut repacked = vec![0u8; total_codes.div_ceil(4)];
    for (index, value) in nibbles.enumerate() {
        if value > 3 {
            return Err(DequantError::NibbleOutOfRange { index, value });
        }
        // `value` is already known to fit in two bits, so it can be OR-ed straight into its lane.
        repacked[index / 4] |= value << (2 * (index % 4));
    }
    Ok(repacked)
}
// Unit tests for the 2-bit unpacking, dequantization and zero-point repacking helpers.
#[cfg(test)]
// The nibble fixtures below spell out every field explicitly (e.g. `(1u8 << 4) | 0`,
// `0u8 << 4`) so the packed layout is readable; that trips clippy's `identity_op` lint.
#[allow(clippy::identity_op)]
mod tests {
use super::*;
// Lanes come out lowest-bits-first: 0b11_10_01_00 holds codes 0, 1, 2, 3 in that order.
#[test]
fn unpack_2bit_le_works() {
assert_eq!(unpack_2bit_le(0b11_10_01_00), [0, 1, 2, 3]);
assert_eq!(unpack_2bit_le(0xFF), [3, 3, 3, 3]);
assert_eq!(unpack_2bit_le(0), [0, 0, 0, 0]);
}
// One row, one full block, explicit zero point of 1 and scale 1.0: codes 0/1/2 must
// dequantize to -1/0/+1.
#[test]
fn dequant_single_block_zp1_matches_ternary() {
let block_size = 128;
let n = 1;
let k = 128;
let n_blocks = 1;
let bytes_per_row = n_blocks * (block_size / 4);
let mut packed = vec![0u8; n * bytes_per_row];
// Fill the block with the repeating code pattern 0, 1, 2, packing four codes per byte,
// lowest bits first.
for code_idx in 0..(block_size) {
let code = (code_idx % 3) as u8;
let byte_idx = code_idx / 4;
let lane = code_idx % 4;
packed[byte_idx] |= code << (2 * lane);
}
let scales = vec![1.0_f32; n * n_blocks];
// A single packed zero-point byte; its lowest two bits (value 1) belong to block 0.
let zero_points = vec![0b01u8];
let out =
dequantize_matmul_nbits(&packed, &scales, Some(&zero_points), n, k, 2, block_size)
.expect("dequantize ok");
assert_eq!(out.len(), n * k);
for (i, v) in out.iter().enumerate() {
// Expected value is (code - 1) * 1.0 with code = i % 3.
let expected = match i % 3 {
0 => -1.0_f32,
1 => 0.0,
_ => 1.0,
};
assert!(
(*v - expected).abs() < 1e-6,
"mismatch at {i}: got {v}, want {expected}"
);
}
}
// Without zero points the implicit zero point is 2, so an all-2 code stream must
// dequantize to exactly zero regardless of the scale.
#[test]
fn dequant_default_zp_is_2() {
let block_size = 128;
let n = 1;
let k = 128;
let n_blocks = 1;
let bytes_per_row = n_blocks * (block_size / 4);
let mut packed = vec![0u8; n * bytes_per_row];
// 0xAA == 0b10_10_10_10: four codes of value 2 per byte.
for byte in packed.iter_mut() {
*byte = 0xAA;
}
let scales = vec![0.5_f32; n * n_blocks];
let out = dequantize_matmul_nbits(&packed, &scales, None, n, k, 2, block_size)
.expect("dequantize ok");
assert!(out.iter().all(|&v| v.abs() < 1e-6));
}
// The `bits` check must fire even for an empty shape (n == 0, k == 0), i.e. before the
// empty-output early return.
#[test]
fn dequant_rejects_unsupported_bits() {
let err = dequantize_matmul_nbits(&[], &[], None, 0, 0, 4, 128).unwrap_err();
assert!(matches!(err, DequantError::Unsupported(_)));
}
// Same as above for `block_size`: 64 is rejected even though the shape is empty.
#[test]
fn dequant_rejects_unsupported_block_size() {
let err = dequantize_matmul_nbits(&[], &[], None, 0, 0, 2, 64).unwrap_err();
assert!(matches!(err, DequantError::Unsupported(_)));
}
// Two rows of one 128-code block need 2 * 32 = 64 packed bytes; 63 must be reported as a
// length mismatch on the packed buffer.
#[test]
fn dequant_rejects_packed_length_mismatch() {
let packed = vec![0u8; 63];
let scales = vec![1.0_f32; 2];
let err = dequantize_matmul_nbits(&packed, &scales, None, 2, 128, 2, 128).unwrap_err();
match err {
DequantError::LengthMismatch {
what,
expected,
got,
} => {
assert_eq!(what, "packed B");
assert_eq!(expected, 64);
assert_eq!(got, 63);
}
_ => panic!("expected LengthMismatch, got {:?}", err),
}
}
// k = 120 is not a multiple of the block size: the packed row still spans a full block,
// but only the first 120 codes may appear in the output.
#[test]
fn dequant_k_padding_truncates_to_real_k() {
let block_size = 128;
let n = 1;
let k = 120;
let n_blocks = 1;
let bytes_per_row = n_blocks * (block_size / 4);
// 0xFF: every code is 3, so with the default zero point 2 and scale 1.0 each output is 1.0.
let packed = vec![0xFFu8; n * bytes_per_row];
let scales = vec![1.0_f32; n * n_blocks];
let out = dequantize_matmul_nbits(&packed, &scales, None, n, k, 2, block_size).expect("ok");
assert_eq!(out.len(), n * k);
assert!(out.iter().all(|&v| (v - 1.0).abs() < 1e-6));
}
// Eight nibbles 0, 1, 2, 3, 0, 1, 2, 3 (low nibble first within each input byte) repack into
// two bytes of 0b11_10_01_00 == 0xE4.
#[test]
fn repack_happy_path_eight_codes() {
let input: Vec<u8> = vec![
(1u8 << 4) | 0, (3u8 << 4) | 2, (1u8 << 4) | 0, (3u8 << 4) | 2, ];
let out = repack_4bit_zp_to_2bit(&input, 8).expect("happy path ok");
assert_eq!(out, vec![0xE4, 0xE4]);
}
// Six nibbles 1, 2, 3, 0, 1, 2 fill one output byte completely and a second one halfway;
// the unused upper bits of the second byte must stay zero.
#[test]
fn repack_trailing_partial_byte() {
let input: Vec<u8> = vec![
(2u8 << 4) | 1, (0u8 << 4) | 3, (2u8 << 4) | 1, ];
let out = repack_4bit_zp_to_2bit(&input, 6).expect("partial tail ok");
assert_eq!(out.len(), 2);
// 1 | 2 << 2 | 3 << 4 | 0 << 6 == 0x39
assert_eq!(out[0], 0x39);
// 1 | 2 << 2 == 0x09
assert_eq!(out[1], 0x09);
assert_eq!(out[1] & 0xF0, 0);
}
// Nibbles above 3 do not fit in two bits and must be rejected with their position and value,
// for both the low nibble (index 0) and the high nibble (index 1) of a byte.
#[test]
fn repack_rejects_nibble_above_three() {
// Low nibble of the first byte is 4.
let input: Vec<u8> = vec![0x04, 0x00];
let err = repack_4bit_zp_to_2bit(&input, 4).expect_err("should reject");
match err {
DequantError::NibbleOutOfRange { index, value } => {
assert_eq!(index, 0);
assert_eq!(value, 0x4);
}
other => panic!("expected NibbleOutOfRange, got {other:?}"),
}
// Low nibble is a valid 0; the high nibble (second value) is 0xF.
let input_hi: Vec<u8> = vec![0xF0];
let err_hi = repack_4bit_zp_to_2bit(&input_hi, 2).expect_err("should reject");
match err_hi {
DequantError::NibbleOutOfRange { index, value } => {
assert_eq!(index, 1);
assert_eq!(value, 0xF);
}
other => panic!("expected NibbleOutOfRange, got {other:?}"),
}
}
// Zero codes from an empty input yield an empty output rather than an error.
#[test]
fn repack_empty_input() {
let out = repack_4bit_zp_to_2bit(&[], 0).expect("empty ok");
assert!(out.is_empty());
}
// Five nibbles need ceil(5 / 2) = 3 input bytes; supplying only 2 is a length mismatch.
#[test]
fn repack_rejects_short_input() {
let input: Vec<u8> = vec![0x00, 0x00];
let err = repack_4bit_zp_to_2bit(&input, 5).expect_err("short input");
match err {
DequantError::LengthMismatch {
what,
expected,
got,
} => {
assert_eq!(what, "4-bit ZP input");
assert_eq!(expected, 3);
assert_eq!(got, 2);
}
other => panic!("expected LengthMismatch, got {other:?}"),
}
}
}