use super::ternary_tensor::unpack_ternary;
/// Expands packed ternary weights into `f32` values, applying one scale per block.
///
/// The packed bytes are decoded with `unpack_ternary`, and the decoded values are
/// then walked in consecutive blocks of 256. Every value in block `i` is converted
/// to `f32` and multiplied by `scales[i]`.
///
/// # Arguments
/// * `packed` - packed ternary data, forwarded unchanged to `unpack_ternary`.
/// * `scales` - per-block scale factors, indexed by block number.
/// * `num_elements` - element count forwarded to `unpack_ternary`; also used to
///   pre-size the output vector.
///
/// # Returns
/// One `f32` per decoded value, in decoding order.
///
/// Blocks that have no matching entry in `scales` are scaled by `1.0` instead of
/// causing a panic.
pub fn dequantize_bitnet_t158(packed: &[u8], scales: &[f32], num_elements: usize) -> Vec<f32> {
    const BLOCK_LEN: usize = 256;

    let decoded = unpack_ternary(packed, num_elements);
    let mut result = Vec::with_capacity(num_elements);

    // Pad the scale sequence with 1.0 so blocks past the end of `scales`
    // pass through unscaled.
    let padded_scales = scales.iter().copied().chain(std::iter::repeat(1.0_f32));

    for (block, factor) in decoded.chunks(BLOCK_LEN).zip(padded_scales) {
        result.extend(block.iter().map(|&trit| (trit as f32) * factor));
    }

    result
}
/// Dequantizes a single packed block into a caller-provided buffer.
///
/// The 64-byte `packed_block` is decoded by `unpack_ternary` with a requested
/// count of 256 values. Each decoded value is converted to `f32`, multiplied by
/// `scale`, and stored at the same index in `output`. Elements of `output`
/// beyond the decoded count are left untouched.
///
/// # Arguments
/// * `packed_block` - exactly 64 bytes of packed ternary data.
/// * `scale` - scale factor applied to every value of the block.
/// * `output` - destination buffer with room for at least 256 values.
///
/// # Panics
/// Panics if `output` holds fewer than 256 elements, or if `packed_block` is not
/// exactly 64 bytes long. The output-size check runs first.
pub fn dequantize_bitnet_block(packed_block: &[u8], scale: f32, output: &mut [f32]) {
    const BLOCK_LEN: usize = 256;
    const PACKED_LEN: usize = 64;

    assert!(
        output.len() >= BLOCK_LEN,
        "Output buffer must hold at least 256 elements"
    );
    assert_eq!(
        packed_block.len(),
        PACKED_LEN,
        "Packed block must be exactly 64 bytes"
    );

    let decoded = unpack_ternary(packed_block, BLOCK_LEN);

    // Take the destination window up front: this still panics if more values
    // were decoded than the buffer can hold, and lets the loop run unindexed.
    let destination = &mut output[..decoded.len()];
    for (slot, &trit) in destination.iter_mut().zip(decoded.iter()) {
        *slot = (trit as f32) * scale;
    }
}
/// Measures how far a dequantized signal deviates from its reference.
///
/// # Arguments
/// * `original` - reference values.
/// * `dequantized` - reconstructed values, compared element by element.
///
/// # Returns
/// A tuple `(mae, mse, max_error)`:
/// * `mae` - mean of the absolute element-wise differences,
/// * `mse` - mean of the squared element-wise differences,
/// * `max_error` - largest absolute element-wise difference.
///
/// Empty inputs yield `(0.0, 0.0, 0.0)`.
///
/// # Panics
/// Panics if the two slices differ in length.
pub fn compute_dequant_error(original: &[f32], dequantized: &[f32]) -> (f32, f32, f32) {
    assert_eq!(
        original.len(),
        dequantized.len(),
        "Arrays must have same length"
    );

    let count = original.len();
    if count == 0 {
        // Avoid dividing by zero below.
        return (0.0, 0.0, 0.0);
    }

    // One pass that carries (sum of |e|, sum of e^2, max |e|) together.
    let (abs_total, sq_total, worst) = original.iter().zip(dequantized).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(abs_acc, sq_acc, peak), (&reference, &approx)| {
            let diff = (reference - approx).abs();
            (abs_acc + diff, sq_acc + diff * diff, peak.max(diff))
        },
    );

    let count = count as f32;
    (abs_total / count, sq_total / count, worst)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitnet::{absmean_ternary, pack_ternary};

    /// Absolute tolerance used for approximate float comparisons.
    const EPS: f32 = 1e-6;

    /// Returns true when every element of `values` lies within `EPS` of `expected`.
    fn all_close(values: &[f32], expected: f32) -> bool {
        values.iter().all(|&v| (v - expected).abs() < EPS)
    }

    /// Eight values sharing one scale: checks the length and the first four outputs.
    #[test]
    fn test_dequantize_bitnet_t158_simple() {
        let trits = vec![-1i8, 0, 1, -1, 1, 0, 0, 1];
        let packed = pack_ternary(&trits);
        let scales = vec![0.5f32];

        let values = dequantize_bitnet_t158(&packed, &scales, 8);

        assert_eq!(values.len(), 8);
        assert_eq!(values[..4], [-0.5f32, 0.0, 0.5, -0.5]);
    }

    /// A full block of +1 values scaled by 2.0 must come out as all 2.0.
    #[test]
    fn test_dequantize_bitnet_block() {
        let trits = vec![1i8; 256];
        let packed = pack_ternary(&trits);
        let scale = 2.0;
        let mut buffer = vec![0.0f32; 256];

        dequantize_bitnet_block(&packed, scale, &mut buffer);

        assert!(all_close(&buffer, 2.0));
    }

    /// Two blocks, each with its own scale: +1 * 1.0 followed by -1 * 2.0.
    #[test]
    fn test_dequantize_multiple_blocks() {
        // First 256 entries are +1, the next 256 are -1.
        let mut trits = vec![1i8; 256];
        trits.resize(512, -1i8);
        let packed = pack_ternary(&trits);
        let scales = vec![1.0, 2.0];

        let values = dequantize_bitnet_t158(&packed, &scales, 512);

        assert!(all_close(&values[..256], 1.0));
        assert!(all_close(&values[256..512], -2.0));
    }

    /// Quantize, pack, then dequantize: every reconstructed value must stay
    /// within `2 * scale` of its source value.
    #[test]
    fn test_roundtrip_quantize_dequantize() {
        let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.4, 0.2, -0.6];
        let (trits, scale) = absmean_ternary(&original);
        let packed = pack_ternary(&trits);

        let restored = dequantize_bitnet_t158(&packed, &[scale], original.len());

        assert_eq!(restored.len(), 8);
        for (expected, actual) in original.iter().zip(restored.iter()) {
            let deviation = (expected - actual).abs();
            assert!(deviation < scale * 2.0);
        }
    }

    /// Errors of 0.1, 0.1, 0.2, 0.2 give MAE 0.15, MSE 0.025 and max error 0.2.
    #[test]
    fn test_compute_dequant_error() {
        let reference = vec![1.0, 2.0, 3.0, 4.0];
        let approximated = vec![1.1, 1.9, 3.2, 3.8];

        let (mae, mse, max_error) = compute_dequant_error(&reference, &approximated);

        assert!((mae - 0.15).abs() < 1e-6);
        assert!((mse - 0.025).abs() < 1e-6);
        assert!((max_error - 0.2).abs() < 1e-6);
    }

    /// An output buffer shorter than one block must be rejected.
    #[test]
    #[should_panic(expected = "Output buffer must hold at least 256 elements")]
    fn test_dequantize_block_small_buffer() {
        let packed = vec![0u8; 64];
        let mut buffer = vec![0.0f32; 128];
        dequantize_bitnet_block(&packed, 1.0, &mut buffer);
    }

    /// A packed block that is not exactly 64 bytes must be rejected.
    #[test]
    #[should_panic(expected = "Packed block must be exactly 64 bytes")]
    fn test_dequantize_block_wrong_size() {
        let packed = vec![0u8; 32];
        let mut buffer = vec![0.0f32; 256];
        dequantize_bitnet_block(&packed, 1.0, &mut buffer);
    }

    /// With only one scale supplied for two blocks, the second block falls
    /// back to a scale of 1.0.
    #[test]
    fn test_dequantize_with_missing_scales() {
        let trits = vec![1i8; 512];
        let packed = pack_ternary(&trits);
        let scales = vec![2.0];

        let values = dequantize_bitnet_t158(&packed, &scales, 512);

        assert!(all_close(&values[..256], 2.0));
        assert!(all_close(&values[256..512], 1.0));
    }
}