use std::f32::consts::PI;
use thiserror::Error;
/// Errors produced by the pi-quantization API.
#[derive(Error, Debug, Clone)]
pub enum PiQuantError {
    /// A channel scale (alpha) was zero, negative, or non-finite.
    #[error("Scale alpha must be positive, got {0}")]
    NonPositiveScale(f32),
    /// The angular divisor `k` was outside `VALID_K_VALUES`.
    #[error("K value must be in {{2, 3, 4, 5}}, got {0}")]
    InvalidK(u8),
    /// The code width was neither 2 nor 3 bits.
    #[error("Bits must be 2 or 3, got {0}")]
    InvalidBits(u8),
    /// A block-level API received a slice of the wrong length.
    #[error("Expected block size {expected}, got {actual}")]
    BlockSizeMismatch { expected: usize, actual: usize },
    /// A channel index (or channel count) did not match the alpha vector.
    #[error("Channel index {index} out of bounds (num_channels: {num_channels})")]
    ChannelOutOfBounds { index: usize, num_channels: usize },
    /// A quantizer cannot be built without at least one channel scale.
    #[error("Alpha vector cannot be empty")]
    EmptyAlpha,
}
/// Crate-local result alias for quantization operations.
pub type Result<T> = std::result::Result<T, PiQuantError>;
/// Weights per packed 3-bit block (8 codes x 3 bits = 24 bits).
pub const PI3_BLOCK_WEIGHTS: usize = 8;
/// Bytes per packed 3-bit block.
pub const PI3_BLOCK_BYTES: usize = 3;
/// Weights per packed 2-bit block (4 codes x 2 bits = 1 byte).
pub const PI2_BLOCK_WEIGHTS: usize = 4;
/// Bytes per packed 2-bit block.
pub const PI2_BLOCK_BYTES: usize = 1;
/// Admissible values of the angular divisor `k` (base step = PI / k).
pub const VALID_K_VALUES: [u8; 4] = [2, 3, 4, 5];
/// Quantizer mapping `f32` weights onto a signed integer grid whose step is
/// `alpha[channel] * PI / k` (see `step_size`).
#[derive(Debug, Clone)]
pub struct PiQuantizer {
    // Code width: 2 or 3 (validated in `new`).
    bits: u8,
    // Angular divisor; one of VALID_K_VALUES.
    k: u8,
    // Per-channel scale factors; validated strictly positive and finite.
    alpha: Vec<f32>,
    // 2^(bits-1); codes span [-half_range, half_range - 1].
    half_range: i8,
    // PI / k, cached at construction.
    base_step: f32,
}
impl PiQuantizer {
    /// Builds a quantizer for `bits`-bit codes on an angular grid whose
    /// per-channel step is `alpha[channel] * PI / k`.
    ///
    /// # Errors
    /// - `InvalidBits` unless `bits` is 2 or 3.
    /// - `InvalidK` unless `k` is one of `VALID_K_VALUES`.
    /// - `EmptyAlpha` if `alpha` is empty.
    /// - `NonPositiveScale` if any scale is not strictly positive and finite
    ///   (NaN and infinities are rejected).
    pub fn new(bits: u8, k: u8, alpha: Vec<f32>) -> Result<Self> {
        if bits != 2 && bits != 3 {
            return Err(PiQuantError::InvalidBits(bits));
        }
        if !VALID_K_VALUES.contains(&k) {
            return Err(PiQuantError::InvalidK(k));
        }
        if alpha.is_empty() {
            return Err(PiQuantError::EmptyAlpha);
        }
        for &a in &alpha {
            // NaN fails the `is_finite` test, so it is rejected here too.
            if a <= 0.0 || !a.is_finite() {
                return Err(PiQuantError::NonPositiveScale(a));
            }
        }
        // `bits` is 2 or 3 here, so the shift cannot overflow an i8:
        // half_range is 2 or 4 and codes span [-half_range, half_range - 1].
        let half_range = 1i8 << (bits - 1);
        let base_step = PI / (k as f32);
        Ok(Self {
            bits,
            k,
            alpha,
            half_range,
            base_step,
        })
    }

    /// Convenience constructor: one shared `scale` for all `num_channels`.
    ///
    /// # Errors
    /// Same conditions as [`PiQuantizer::new`]. The scale is checked first so
    /// a bad scale is reported even when `num_channels` is also invalid.
    pub fn with_uniform_scale(bits: u8, k: u8, scale: f32, num_channels: usize) -> Result<Self> {
        if scale <= 0.0 || !scale.is_finite() {
            return Err(PiQuantError::NonPositiveScale(scale));
        }
        if num_channels == 0 {
            return Err(PiQuantError::EmptyAlpha);
        }
        Self::new(bits, k, vec![scale; num_channels])
    }

    /// Code width in bits (2 or 3).
    #[inline]
    pub fn bits(&self) -> u8 {
        self.bits
    }

    /// Angular divisor; the base step is `PI / k`.
    #[inline]
    pub fn k(&self) -> u8 {
        self.k
    }

    /// Per-channel scale factors.
    #[inline]
    pub fn alpha(&self) -> &[f32] {
        &self.alpha
    }

    /// Number of channels (length of the alpha vector).
    #[inline]
    pub fn num_channels(&self) -> usize {
        self.alpha.len()
    }

    /// Quantization step for `channel`: `alpha[channel] * PI / k`.
    /// An out-of-range channel falls back to a scale of 1.0 instead of
    /// panicking.
    #[inline]
    pub fn step_size(&self, channel: usize) -> f32 {
        self.alpha.get(channel).copied().unwrap_or(1.0) * self.base_step
    }

    /// Effective storage cost per weight. The extra 0.0625 bits over the raw
    /// code width presumably accounts for per-block side metadata —
    /// TODO(review): confirm against the serialization format.
    #[inline]
    pub fn bits_per_weight(&self) -> f32 {
        match self.bits {
            3 => 3.0625,
            2 => 2.0625,
            _ => self.bits as f32,
        }
    }

    /// Quantizes one weight: returns the clamped integer code and the value
    /// it dequantizes back to. Codes lie in `[-half_range, half_range - 1]`.
    #[inline(always)]
    pub fn quantize_scalar(&self, w: f32, channel: usize) -> (i8, f32) {
        let step = self.step_size(channel);
        // Defensive only: construction guarantees positive steps.
        if step <= 0.0 {
            return (0, 0.0);
        }
        let q = (w / step).round() as i32;
        let half = self.half_range as i32;
        let q_clamped = q.clamp(-half, half - 1) as i8;
        let w_q = (q_clamped as f32) * step;
        (q_clamped, w_q)
    }

    /// Quantizes one weight to its integer code only.
    #[inline(always)]
    pub fn quantize_to_int(&self, w: f32, channel: usize) -> i8 {
        self.quantize_scalar(w, channel).0
    }

    /// Reconstructs the weight represented by code `q` on `channel`'s grid.
    #[inline(always)]
    pub fn dequantize_int(&self, q: i8, channel: usize) -> f32 {
        (q as f32) * self.step_size(channel)
    }

    /// Quantizes exactly `PI3_BLOCK_WEIGHTS` weights into one packed block.
    ///
    /// # Errors
    /// `InvalidBits` if this quantizer is not 3-bit; `BlockSizeMismatch` if
    /// `weights` is not exactly 8 elements long.
    pub fn quantize_block_3bit(&self, weights: &[f32], channel: usize) -> Result<Pi3BitBlock> {
        if self.bits != 3 {
            return Err(PiQuantError::InvalidBits(self.bits));
        }
        if weights.len() != PI3_BLOCK_WEIGHTS {
            return Err(PiQuantError::BlockSizeMismatch {
                expected: PI3_BLOCK_WEIGHTS,
                actual: weights.len(),
            });
        }
        let mut q_values = [0i8; 8];
        for (q, &w) in q_values.iter_mut().zip(weights) {
            *q = self.quantize_to_int(w, channel);
        }
        let mut block = Pi3BitBlock::new();
        block.pack(&q_values);
        Ok(block)
    }

    /// Decodes one 3-bit block back to weights using `channel`'s step size.
    #[inline]
    pub fn dequantize_block_3bit(&self, block: &Pi3BitBlock, channel: usize) -> [f32; 8] {
        let step = self.step_size(channel);
        let mut weights = [0.0f32; 8];
        for (w, &q) in weights.iter_mut().zip(block.unpack().iter()) {
            *w = (q as f32) * step;
        }
        weights
    }

    /// Quantizes exactly `PI2_BLOCK_WEIGHTS` weights into one packed block.
    ///
    /// # Errors
    /// `InvalidBits` if this quantizer is not 2-bit; `BlockSizeMismatch` if
    /// `weights` is not exactly 4 elements long.
    pub fn quantize_block_2bit(&self, weights: &[f32], channel: usize) -> Result<Pi2BitBlock> {
        if self.bits != 2 {
            return Err(PiQuantError::InvalidBits(self.bits));
        }
        if weights.len() != PI2_BLOCK_WEIGHTS {
            return Err(PiQuantError::BlockSizeMismatch {
                expected: PI2_BLOCK_WEIGHTS,
                actual: weights.len(),
            });
        }
        let mut q_values = [0i8; 4];
        for (q, &w) in q_values.iter_mut().zip(weights) {
            *q = self.quantize_to_int(w, channel);
        }
        let mut block = Pi2BitBlock::new();
        block.pack(&q_values);
        Ok(block)
    }

    /// Decodes one 2-bit block back to weights using `channel`'s step size.
    #[inline]
    pub fn dequantize_block_2bit(&self, block: &Pi2BitBlock, channel: usize) -> [f32; 4] {
        let step = self.step_size(channel);
        let mut weights = [0.0f32; 4];
        for (w, &q) in weights.iter_mut().zip(block.unpack().iter()) {
            *w = (q as f32) * step;
        }
        weights
    }

    /// Replaces a single channel's scale.
    ///
    /// # Errors
    /// `ChannelOutOfBounds` for a bad index; `NonPositiveScale` if the new
    /// scale is not strictly positive and finite.
    pub fn update_alpha(&mut self, channel: usize, new_alpha: f32) -> Result<()> {
        if channel >= self.alpha.len() {
            return Err(PiQuantError::ChannelOutOfBounds {
                index: channel,
                num_channels: self.alpha.len(),
            });
        }
        if new_alpha <= 0.0 || !new_alpha.is_finite() {
            return Err(PiQuantError::NonPositiveScale(new_alpha));
        }
        self.alpha[channel] = new_alpha;
        Ok(())
    }

    /// Recomputes every channel's alpha so that the largest-magnitude weight
    /// of that channel lands on `(half_range - 0.5) * step`. Alphas are
    /// floored at 1e-8 so an all-zero channel never yields a zero step.
    ///
    /// # Errors
    /// `ChannelOutOfBounds` when the number of slices differs from the number
    /// of channels (the mismatched slice count is reported as `index`).
    pub fn calibrate_from_weights(&mut self, weights_per_channel: &[&[f32]]) -> Result<()> {
        if weights_per_channel.len() != self.alpha.len() {
            return Err(PiQuantError::ChannelOutOfBounds {
                index: weights_per_channel.len(),
                num_channels: self.alpha.len(),
            });
        }
        let half = self.half_range as f32;
        let divisor = (half - 0.5) * self.base_step;
        for (c, weights) in weights_per_channel.iter().enumerate() {
            // f32::max ignores NaN operands, so NaN weights do not poison the
            // running maximum.
            let max_abs = weights
                .iter()
                .map(|w| w.abs())
                .fold(0.0f32, |a, b| a.max(b));
            let new_alpha = (max_abs / divisor).max(1e-8);
            self.alpha[c] = new_alpha;
        }
        Ok(())
    }
}
/// A packed block of eight 3-bit codes occupying exactly 3 bytes.
/// Codes range over [-4, 3] and are stored biased by +4.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Pi3BitBlock {
    pub data: [u8; 3],
}

impl Pi3BitBlock {
    /// A zeroed block.
    #[inline]
    pub const fn new() -> Self {
        Self { data: [0; 3] }
    }

    /// Borrows the raw packed bytes.
    #[inline]
    pub fn as_bytes(&self) -> &[u8; 3] {
        &self.data
    }

    /// Reconstructs a block from raw packed bytes.
    #[inline]
    pub fn from_bytes(bytes: [u8; 3]) -> Self {
        Self { data: bytes }
    }

    /// Packs eight signed codes: each value is clamped to [-4, 3], biased to
    /// an unsigned 3-bit field, and laid out little-endian in a 24-bit word.
    pub fn pack(&mut self, values: &[i8; 8]) {
        let mut word: u32 = 0;
        for (slot, &v) in values.iter().enumerate() {
            let biased = (v.clamp(-4, 3) + 4) as u32;
            word |= (biased & 0x7) << (slot * 3);
        }
        self.data = [
            (word & 0xFF) as u8,
            ((word >> 8) & 0xFF) as u8,
            ((word >> 16) & 0xFF) as u8,
        ];
    }

    /// Unpacks the eight codes, undoing the +4 bias.
    pub fn unpack(&self) -> [i8; 8] {
        let word = (self.data[0] as u32)
            | ((self.data[1] as u32) << 8)
            | ((self.data[2] as u32) << 16);
        let mut codes = [0i8; 8];
        for (slot, code) in codes.iter_mut().enumerate() {
            *code = (((word >> (slot * 3)) & 0x7) as i8) - 4;
        }
        codes
    }

    /// Storage footprint of one block in bytes (== PI3_BLOCK_BYTES).
    #[inline]
    pub const fn size_bytes() -> usize {
        3
    }

    /// Number of weights held by one block (== PI3_BLOCK_WEIGHTS).
    #[inline]
    pub const fn num_weights() -> usize {
        8
    }
}
/// A packed block of four 2-bit codes occupying exactly 1 byte.
/// Codes range over [-2, 1] and are stored biased by +2.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Pi2BitBlock {
    pub data: u8,
}

impl Pi2BitBlock {
    /// A zeroed block.
    #[inline]
    pub const fn new() -> Self {
        Self { data: 0 }
    }

    /// Returns the raw packed byte.
    #[inline]
    pub fn as_byte(&self) -> u8 {
        self.data
    }

    /// Reconstructs a block from a raw packed byte.
    #[inline]
    pub fn from_byte(byte: u8) -> Self {
        Self { data: byte }
    }

    /// Packs four signed codes: each value is clamped to [-2, 1], biased to
    /// an unsigned 2-bit field, and laid out little-endian within the byte.
    pub fn pack(&mut self, values: &[i8; 4]) {
        let mut byte = 0u8;
        for (slot, &v) in values.iter().enumerate() {
            let biased = ((v.clamp(-2, 1) + 2) as u8) & 0x03;
            byte |= biased << (slot * 2);
        }
        self.data = byte;
    }

    /// Unpacks the four codes, undoing the +2 bias.
    pub fn unpack(&self) -> [i8; 4] {
        let mut codes = [0i8; 4];
        for (slot, code) in codes.iter_mut().enumerate() {
            *code = (((self.data >> (slot * 2)) & 0x03) as i8) - 2;
        }
        codes
    }

    /// Storage footprint of one block in bytes (== PI2_BLOCK_BYTES).
    #[inline]
    pub const fn size_bytes() -> usize {
        1
    }

    /// Number of weights held by one block (== PI2_BLOCK_WEIGHTS).
    #[inline]
    pub const fn num_weights() -> usize {
        4
    }
}
/// Quantizes an arbitrary-length tensor into packed 3-bit blocks; a trailing
/// partial chunk is zero-padded to a full block of 8 weights.
///
/// # Errors
/// `InvalidBits` if `quantizer` is not a 3-bit quantizer, plus any error
/// from the per-block quantization.
pub fn quantize_tensor_3bit(
    weights: &[f32],
    quantizer: &PiQuantizer,
    channel: usize,
) -> Result<Vec<Pi3BitBlock>> {
    if quantizer.bits() != 3 {
        return Err(PiQuantError::InvalidBits(quantizer.bits()));
    }
    weights
        .chunks(PI3_BLOCK_WEIGHTS)
        .map(|chunk| {
            // Pad the final short chunk with zeros before packing.
            let mut padded = [0.0f32; PI3_BLOCK_WEIGHTS];
            padded[..chunk.len()].copy_from_slice(chunk);
            quantizer.quantize_block_3bit(&padded, channel)
        })
        .collect()
}
pub fn dequantize_tensor_3bit(
blocks: &[Pi3BitBlock],
quantizer: &PiQuantizer,
channel: usize,
output: &mut [f32],
) {
let step = quantizer.step_size(channel);
for (i, block) in blocks.iter().enumerate() {
let q_values = block.unpack();
let base_idx = i * 8;
for j in 0..8 {
let out_idx = base_idx + j;
if out_idx < output.len() {
output[out_idx] = (q_values[j] as f32) * step;
}
}
}
}
/// Quantizes an arbitrary-length tensor into packed 2-bit blocks; a trailing
/// partial chunk is zero-padded to a full block of 4 weights.
///
/// # Errors
/// `InvalidBits` if `quantizer` is not a 2-bit quantizer, plus any error
/// from the per-block quantization.
pub fn quantize_tensor_2bit(
    weights: &[f32],
    quantizer: &PiQuantizer,
    channel: usize,
) -> Result<Vec<Pi2BitBlock>> {
    if quantizer.bits() != 2 {
        return Err(PiQuantError::InvalidBits(quantizer.bits()));
    }
    weights
        .chunks(PI2_BLOCK_WEIGHTS)
        .map(|chunk| {
            // Pad the final short chunk with zeros before packing.
            let mut padded = [0.0f32; PI2_BLOCK_WEIGHTS];
            padded[..chunk.len()].copy_from_slice(chunk);
            quantizer.quantize_block_2bit(&padded, channel)
        })
        .collect()
}
pub fn dequantize_tensor_2bit(
blocks: &[Pi2BitBlock],
quantizer: &PiQuantizer,
channel: usize,
output: &mut [f32],
) {
let step = quantizer.step_size(channel);
for (i, block) in blocks.iter().enumerate() {
let q_values = block.unpack();
let base_idx = i * 4;
for j in 0..4 {
let out_idx = base_idx + j;
if out_idx < output.len() {
output[out_idx] = (q_values[j] as f32) * step;
}
}
}
}
/// Scalar fast path: quantizes `weights` (length a multiple of 8) into
/// packed 3-bit codes, 8 weights -> 3 bytes. Each weight is scaled by
/// `1/step`, rounded, clamped to [-4, 3], biased by +4, and the eight 3-bit
/// fields are stored little-endian in a 24-bit word. Returns the number of
/// bytes written.
///
/// The previous implementation guarded its raw-pointer writes with
/// `debug_assert!` only, which is undefined behavior in release builds when
/// the buffers are too small; this version is fully safe and enforces the
/// contract with hard asserts.
///
/// # Panics
/// Panics if `weights.len()` is not a multiple of 8 or if `output` cannot
/// hold `(weights.len() / 8) * 3` bytes.
pub fn quantize_3bit_fast(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    // 8 == PI3_BLOCK_WEIGHTS, 3 == PI3_BLOCK_BYTES.
    assert!(
        weights.len() % 8 == 0,
        "Weight length must be multiple of 8"
    );
    let num_blocks = weights.len() / 8;
    let output_bytes = num_blocks * 3;
    assert!(output.len() >= output_bytes, "Output buffer too small");
    if num_blocks == 0 {
        return 0;
    }
    // Near-zero step: use inv_step = 0 so every finite weight maps to code 0.
    let inv_step = if step.abs() > 1e-10 { 1.0 / step } else { 0.0 };
    for (chunk, out) in weights.chunks_exact(8).zip(output.chunks_exact_mut(3)) {
        let mut combined: u32 = 0;
        for (i, &w) in chunk.iter().enumerate() {
            let q = ((w * inv_step).round() as i32).clamp(-4, 3);
            // Bias to 0..=7 and place as the i-th little-endian 3-bit field.
            combined |= (((q + 4) as u32) & 0x7) << (i * 3);
        }
        out[0] = (combined & 0xFF) as u8;
        out[1] = ((combined >> 8) & 0xFF) as u8;
        out[2] = ((combined >> 16) & 0xFF) as u8;
    }
    output_bytes
}
/// Scalar fast path: quantizes `weights` (length a multiple of 4) into
/// packed 2-bit codes, 4 weights -> 1 byte. Each weight is scaled by
/// `1/step`, rounded, clamped to [-2, 1], biased by +2, and the four 2-bit
/// fields are stored little-endian within the byte. Returns the number of
/// bytes written.
///
/// The previous implementation guarded its raw-pointer writes with
/// `debug_assert!` only, which is undefined behavior in release builds when
/// the buffers are too small; this version is fully safe and enforces the
/// contract with hard asserts.
///
/// # Panics
/// Panics if `weights.len()` is not a multiple of 4 or if `output` cannot
/// hold `weights.len() / 4` bytes.
pub fn quantize_2bit_fast(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    // 4 == PI2_BLOCK_WEIGHTS, 1 == PI2_BLOCK_BYTES.
    assert!(
        weights.len() % 4 == 0,
        "Weight length must be multiple of 4"
    );
    let num_blocks = weights.len() / 4;
    assert!(output.len() >= num_blocks, "Output buffer too small");
    if num_blocks == 0 {
        return 0;
    }
    // Near-zero step: use inv_step = 0 so every finite weight maps to code 0.
    let inv_step = if step.abs() > 1e-10 { 1.0 / step } else { 0.0 };
    for (chunk, out) in weights.chunks_exact(4).zip(output.iter_mut()) {
        let mut packed = 0u8;
        for (i, &w) in chunk.iter().enumerate() {
            let q = ((w * inv_step).round() as i32).clamp(-2, 1);
            // Bias to 0..=3 and place as the i-th little-endian 2-bit field.
            packed |= (((q + 2) as u8) & 0x03) << (i * 2);
        }
        *out = packed;
    }
    num_blocks
}
/// NEON implementation of 3-bit quantization: each block of 8 weights is
/// scaled by `1/step`, rounded, clamped to [-4, 3], biased by +4 and packed
/// little-endian into 3 bytes. Weights beyond the last full block of 8 are
/// ignored. Returns the number of bytes written.
///
/// NOTE(review): `vrndnq_f32` rounds ties to even, whereas the scalar paths
/// use `f32::round` (ties away from zero) — outputs can differ on exact
/// half-step inputs.
///
/// # Safety
/// `output` must hold at least `(weights.len() / 8) * 3` bytes: all stores
/// are unchecked raw-pointer writes. NEON must be available (always the case
/// on aarch64 targets).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn quantize_3bit_neon(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    use core::arch::aarch64::*;
    let num_blocks = weights.len() / 8;
    let output_bytes = num_blocks * 3;
    if num_blocks == 0 {
        return 0;
    }
    // Near-zero step: inv_step = 0 maps every finite weight to code 0.
    let inv_step = if step.abs() > 1e-10 { 1.0 / step } else { 0.0 };
    let inv_step_vec = vdupq_n_f32(inv_step);
    let min_val = vdupq_n_s32(-4);
    let max_val = vdupq_n_s32(3);
    let offset = vdupq_n_s32(4);
    let weights_ptr = weights.as_ptr();
    let output_ptr = output.as_mut_ptr();
    let simd_iterations = num_blocks / 4;
    let mut block = 0usize;
    // Main loop: groups of 4 blocks; each block's 8 weights are processed as
    // two f32x4 vectors (scale -> round -> convert -> clamp -> bias).
    while block < simd_iterations * 4 {
        for inner in 0..4 {
            let b = block + inner;
            let w_offset = b * 8;
            let o_offset = b * 3;
            let w_lo = vld1q_f32(weights_ptr.add(w_offset));
            let w_hi = vld1q_f32(weights_ptr.add(w_offset + 4));
            let scaled_lo = vmulq_f32(w_lo, inv_step_vec);
            let scaled_hi = vmulq_f32(w_hi, inv_step_vec);
            let rounded_lo = vrndnq_f32(scaled_lo);
            let rounded_hi = vrndnq_f32(scaled_hi);
            let q_lo = vcvtq_s32_f32(rounded_lo);
            let q_hi = vcvtq_s32_f32(rounded_hi);
            let clamped_lo = vminq_s32(vmaxq_s32(q_lo, min_val), max_val);
            let clamped_hi = vminq_s32(vmaxq_s32(q_hi, min_val), max_val);
            let unsigned_lo = vaddq_s32(clamped_lo, offset);
            let unsigned_hi = vaddq_s32(clamped_hi, offset);
            let mut vals = [0u32; 8];
            vst1q_s32(vals.as_mut_ptr() as *mut i32, unsigned_lo);
            vst1q_s32(vals.as_mut_ptr().add(4) as *mut i32, unsigned_hi);
            // Pack the eight 3-bit fields little-endian into a 24-bit word.
            let mut combined: u32 = 0;
            for i in 0..8 {
                combined |= (vals[i] & 0x7) << (i * 3);
            }
            *output_ptr.add(o_offset) = (combined & 0xFF) as u8;
            *output_ptr.add(o_offset + 1) = ((combined >> 8) & 0xFF) as u8;
            *output_ptr.add(o_offset + 2) = ((combined >> 16) & 0xFF) as u8;
        }
        block += 4;
    }
    // Scalar tail for the remaining (num_blocks % 4) blocks.
    while block < num_blocks {
        let w_offset = block * 8;
        let o_offset = block * 3;
        let mut combined: u32 = 0;
        for i in 0..8 {
            let w = *weights_ptr.add(w_offset + i);
            let q = (w * inv_step).round() as i32;
            let clamped = q.clamp(-4, 3);
            let unsigned = (clamped + 4) as u32;
            combined |= (unsigned & 0x7) << (i * 3);
        }
        *output_ptr.add(o_offset) = (combined & 0xFF) as u8;
        *output_ptr.add(o_offset + 1) = ((combined >> 8) & 0xFF) as u8;
        *output_ptr.add(o_offset + 2) = ((combined >> 16) & 0xFF) as u8;
        block += 1;
    }
    output_bytes
}
/// NEON implementation of 2-bit quantization: each block of 4 weights is
/// scaled by `1/step`, rounded, clamped to [-2, 1], biased by +2 and packed
/// little-endian into one byte. Weights beyond the last full block of 4 are
/// ignored. Returns the number of bytes written.
///
/// NOTE(review): `vrndnq_f32` rounds ties to even, whereas the scalar paths
/// use `f32::round` (ties away from zero) — outputs can differ on exact
/// half-step inputs.
///
/// # Safety
/// `output` must hold at least `weights.len() / 4` bytes: all stores are
/// unchecked raw-pointer writes. NEON must be available (always the case on
/// aarch64 targets).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn quantize_2bit_neon(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    use core::arch::aarch64::*;
    let num_blocks = weights.len() / 4;
    if num_blocks == 0 {
        return 0;
    }
    // Near-zero step: inv_step = 0 maps every finite weight to code 0.
    let inv_step = if step.abs() > 1e-10 { 1.0 / step } else { 0.0 };
    let inv_step_vec = vdupq_n_f32(inv_step);
    let min_val = vdupq_n_s32(-2);
    let max_val = vdupq_n_s32(1);
    let offset = vdupq_n_s32(2);
    let weights_ptr = weights.as_ptr();
    let output_ptr = output.as_mut_ptr();
    let simd_iterations = num_blocks / 4;
    let mut block = 0usize;
    // Main loop: 4 blocks (16 weights) per iteration, one f32x4 per block.
    while block < simd_iterations * 4 {
        let w0 = vld1q_f32(weights_ptr.add(block * 4));
        let w1 = vld1q_f32(weights_ptr.add((block + 1) * 4));
        let w2 = vld1q_f32(weights_ptr.add((block + 2) * 4));
        let w3 = vld1q_f32(weights_ptr.add((block + 3) * 4));
        let scaled0 = vmulq_f32(w0, inv_step_vec);
        let scaled1 = vmulq_f32(w1, inv_step_vec);
        let scaled2 = vmulq_f32(w2, inv_step_vec);
        let scaled3 = vmulq_f32(w3, inv_step_vec);
        let rounded0 = vrndnq_f32(scaled0);
        let rounded1 = vrndnq_f32(scaled1);
        let rounded2 = vrndnq_f32(scaled2);
        let rounded3 = vrndnq_f32(scaled3);
        let q0 = vminq_s32(vmaxq_s32(vcvtq_s32_f32(rounded0), min_val), max_val);
        let q1 = vminq_s32(vmaxq_s32(vcvtq_s32_f32(rounded1), min_val), max_val);
        let q2 = vminq_s32(vmaxq_s32(vcvtq_s32_f32(rounded2), min_val), max_val);
        let q3 = vminq_s32(vmaxq_s32(vcvtq_s32_f32(rounded3), min_val), max_val);
        let u0 = vaddq_s32(q0, offset);
        let u1 = vaddq_s32(q1, offset);
        let u2 = vaddq_s32(q2, offset);
        let u3 = vaddq_s32(q3, offset);
        let mut vals0 = [0i32; 4];
        let mut vals1 = [0i32; 4];
        let mut vals2 = [0i32; 4];
        let mut vals3 = [0i32; 4];
        vst1q_s32(vals0.as_mut_ptr(), u0);
        vst1q_s32(vals1.as_mut_ptr(), u1);
        vst1q_s32(vals2.as_mut_ptr(), u2);
        vst1q_s32(vals3.as_mut_ptr(), u3);
        // Pack each block's four 2-bit fields little-endian into its byte.
        *output_ptr.add(block) = ((vals0[0] as u8) & 0x03)
            | (((vals0[1] as u8) & 0x03) << 2)
            | (((vals0[2] as u8) & 0x03) << 4)
            | (((vals0[3] as u8) & 0x03) << 6);
        *output_ptr.add(block + 1) = ((vals1[0] as u8) & 0x03)
            | (((vals1[1] as u8) & 0x03) << 2)
            | (((vals1[2] as u8) & 0x03) << 4)
            | (((vals1[3] as u8) & 0x03) << 6);
        *output_ptr.add(block + 2) = ((vals2[0] as u8) & 0x03)
            | (((vals2[1] as u8) & 0x03) << 2)
            | (((vals2[2] as u8) & 0x03) << 4)
            | (((vals2[3] as u8) & 0x03) << 6);
        *output_ptr.add(block + 3) = ((vals3[0] as u8) & 0x03)
            | (((vals3[1] as u8) & 0x03) << 2)
            | (((vals3[2] as u8) & 0x03) << 4)
            | (((vals3[3] as u8) & 0x03) << 6);
        block += 4;
    }
    // Scalar tail for the remaining (num_blocks % 4) blocks.
    while block < num_blocks {
        let w_offset = block * 4;
        let w0 = *weights_ptr.add(w_offset);
        let w1 = *weights_ptr.add(w_offset + 1);
        let w2 = *weights_ptr.add(w_offset + 2);
        let w3 = *weights_ptr.add(w_offset + 3);
        let q0 = ((w0 * inv_step).round() as i32).clamp(-2, 1);
        let q1 = ((w1 * inv_step).round() as i32).clamp(-2, 1);
        let q2 = ((w2 * inv_step).round() as i32).clamp(-2, 1);
        let q3 = ((w3 * inv_step).round() as i32).clamp(-2, 1);
        *output_ptr.add(block) = ((q0 + 2) as u8 & 0x03)
            | (((q1 + 2) as u8 & 0x03) << 2)
            | (((q2 + 2) as u8 & 0x03) << 4)
            | (((q3 + 2) as u8 & 0x03) << 6);
        block += 1;
    }
    num_blocks
}
/// AVX2 implementation of 3-bit quantization: each block of 8 weights is
/// scaled by `1/step`, rounded, clamped to [-4, 3], biased by +4 and packed
/// little-endian into 3 bytes. Weights beyond the last full block of 8 are
/// ignored. Returns the number of bytes written.
///
/// NOTE(review): `_MM_FROUND_TO_NEAREST_INT` rounds ties to even, whereas
/// the scalar paths use `f32::round` (ties away from zero) — outputs can
/// differ on exact half-step inputs.
///
/// # Safety
/// The CPU must support AVX2 (check with `is_x86_feature_detected!("avx2")`
/// before calling), and `output` must hold at least
/// `(weights.len() / 8) * 3` bytes: all stores are unchecked raw-pointer
/// writes.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn quantize_3bit_avx2(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    use core::arch::x86_64::*;
    let num_blocks = weights.len() / 8;
    let output_bytes = num_blocks * 3;
    if num_blocks == 0 {
        return 0;
    }
    // Near-zero step: inv_step = 0 maps every finite weight to code 0.
    let inv_step = if step.abs() > 1e-10 { 1.0 / step } else { 0.0 };
    let inv_step_vec = _mm256_set1_ps(inv_step);
    let min_val = _mm256_set1_epi32(-4);
    let max_val = _mm256_set1_epi32(3);
    let offset = _mm256_set1_epi32(4);
    let weights_ptr = weights.as_ptr();
    let output_ptr = output.as_mut_ptr();
    for block in 0..num_blocks {
        let w_offset = block * 8;
        let o_offset = block * 3;
        // One f32x8 vector per block: scale -> round -> convert -> clamp -> bias.
        let w = _mm256_loadu_ps(weights_ptr.add(w_offset));
        let scaled = _mm256_mul_ps(w, inv_step_vec);
        let rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        let q = _mm256_cvtps_epi32(rounded);
        let clamped = _mm256_min_epi32(_mm256_max_epi32(q, min_val), max_val);
        let unsigned = _mm256_add_epi32(clamped, offset);
        let mut vals = [0i32; 8];
        _mm256_storeu_si256(vals.as_mut_ptr() as *mut __m256i, unsigned);
        // Pack the eight 3-bit fields little-endian into a 24-bit word.
        let mut combined: u32 = 0;
        for i in 0..8 {
            combined |= ((vals[i] as u32) & 0x7) << (i * 3);
        }
        *output_ptr.add(o_offset) = (combined & 0xFF) as u8;
        *output_ptr.add(o_offset + 1) = ((combined >> 8) & 0xFF) as u8;
        *output_ptr.add(o_offset + 2) = ((combined >> 16) & 0xFF) as u8;
    }
    output_bytes
}
/// SSE4.1-width implementation of 2-bit quantization (compiled under the
/// AVX2 feature gate): each block of 4 weights is scaled by `1/step`,
/// rounded, clamped to [-2, 1], biased by +2 and packed little-endian into
/// one byte. Weights beyond the last full block of 4 are ignored. Returns
/// the number of bytes written.
///
/// NOTE(review): `_MM_FROUND_TO_NEAREST_INT` rounds ties to even, whereas
/// the scalar paths use `f32::round` (ties away from zero) — outputs can
/// differ on exact half-step inputs.
///
/// # Safety
/// The CPU must support AVX2 (check with `is_x86_feature_detected!("avx2")`
/// before calling), and `output` must hold at least `weights.len() / 4`
/// bytes: all stores are unchecked raw-pointer writes.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn quantize_2bit_avx2(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    use core::arch::x86_64::*;
    let num_blocks = weights.len() / 4;
    if num_blocks == 0 {
        return 0;
    }
    // Near-zero step: inv_step = 0 maps every finite weight to code 0.
    let inv_step = if step.abs() > 1e-10 { 1.0 / step } else { 0.0 };
    let inv_step_vec = _mm_set1_ps(inv_step);
    let min_val = _mm_set1_epi32(-2);
    let max_val = _mm_set1_epi32(1);
    let offset = _mm_set1_epi32(2);
    let weights_ptr = weights.as_ptr();
    let output_ptr = output.as_mut_ptr();
    for block in 0..num_blocks {
        let w_offset = block * 4;
        // One f32x4 vector per block: scale -> round -> convert -> clamp -> bias.
        let w = _mm_loadu_ps(weights_ptr.add(w_offset));
        let scaled = _mm_mul_ps(w, inv_step_vec);
        let rounded = _mm_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        let q = _mm_cvtps_epi32(rounded);
        let clamped = _mm_min_epi32(_mm_max_epi32(q, min_val), max_val);
        let unsigned = _mm_add_epi32(clamped, offset);
        let mut vals = [0i32; 4];
        _mm_storeu_si128(vals.as_mut_ptr() as *mut __m128i, unsigned);
        // Pack the four 2-bit fields little-endian into the block's byte.
        *output_ptr.add(block) = ((vals[0] as u8) & 0x03)
            | (((vals[1] as u8) & 0x03) << 2)
            | (((vals[2] as u8) & 0x03) << 4)
            | (((vals[3] as u8) & 0x03) << 6);
    }
    num_blocks
}
/// Quantizes `weights` into packed 3-bit codes using the best kernel for
/// this target (NEON on aarch64, AVX2 when detected on x86_64, otherwise the
/// scalar path). Only complete blocks of 8 weights are encoded; returns the
/// number of bytes written.
///
/// # Panics
/// Panics if `output` cannot hold `(weights.len() / 8) * 3` bytes. This
/// up-front check makes the safe wrapper sound: the SIMD kernels perform
/// unchecked raw-pointer writes.
pub fn quantize_3bit(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    // 8 weights -> 3 bytes per block (PI3_BLOCK_WEIGHTS -> PI3_BLOCK_BYTES).
    let required = (weights.len() / 8) * 3;
    assert!(
        output.len() >= required,
        "Output buffer too small: need {} bytes, got {}",
        required,
        output.len()
    );
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64 and the output size was
        // verified above.
        unsafe {
            return quantize_3bit_neon(weights, step, output);
        }
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just checked and the output size
            // was verified above.
            unsafe {
                return quantize_3bit_avx2(weights, step, output);
            }
        }
    }
    quantize_3bit_fast(weights, step, output)
}
/// Quantizes `weights` into packed 2-bit codes using the best kernel for
/// this target (NEON on aarch64, AVX2 when detected on x86_64, otherwise the
/// scalar path). Only complete blocks of 4 weights are encoded; returns the
/// number of bytes written.
///
/// # Panics
/// Panics if `output` cannot hold `weights.len() / 4` bytes. This up-front
/// check makes the safe wrapper sound: the SIMD kernels perform unchecked
/// raw-pointer writes.
pub fn quantize_2bit(weights: &[f32], step: f32, output: &mut [u8]) -> usize {
    // 4 weights -> 1 byte per block (PI2_BLOCK_WEIGHTS -> PI2_BLOCK_BYTES).
    let required = weights.len() / 4;
    assert!(
        output.len() >= required,
        "Output buffer too small: need {} bytes, got {}",
        required,
        output.len()
    );
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is baseline on aarch64 and the output size was
        // verified above.
        unsafe {
            return quantize_2bit_neon(weights, step, output);
        }
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just checked and the output size
            // was verified above.
            unsafe {
                return quantize_2bit_avx2(weights, step, output);
            }
        }
    }
    quantize_2bit_fast(weights, step, output)
}
/// Name of the kernel `quantize_3bit`/`quantize_2bit` will select on this
/// machine: "neon", "avx2", or "scalar".
///
/// Restructured with exhaustive cfg arms: the previous version left a
/// trailing `"scalar"` expression after an unconditional `return` on
/// aarch64, triggering an unreachable-expression warning.
pub fn quantize_kernel_name() -> &'static str {
    #[cfg(target_arch = "aarch64")]
    return "neon";
    #[cfg(target_arch = "x86_64")]
    return if is_x86_feature_detected!("avx2") {
        "avx2"
    } else {
        "scalar"
    };
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    "scalar"
}
/// Quantizes several (weights, output) pairs with one shared step size and
/// returns the total number of bytes written across all tensors.
pub fn batch_quantize_3bit(tensors: &mut [(&[f32], &mut [u8])], step: f32) -> usize {
    tensors
        .iter_mut()
        .map(|(weights, output)| quantize_3bit(weights, step, output))
        .sum()
}
/// Mean squared error between two slices, accumulated in f64.
/// Returns 0.0 for empty or length-mismatched inputs.
pub fn compute_mse(original: &[f32], quantized: &[f32]) -> f64 {
    if original.len() != quantized.len() || original.is_empty() {
        return 0.0;
    }
    let mut acc = 0.0f64;
    for (&o, &q) in original.iter().zip(quantized) {
        let diff = (o - q) as f64;
        acc += diff * diff;
    }
    acc / (original.len() as f64)
}
/// Quantization noise relative to mean signal power, in decibels (more
/// negative is better). Returns `NEG_INFINITY` for empty or mismatched
/// inputs and 0.0 for an all-zero signal.
pub fn compute_spectral_distortion_db(original: &[f32], quantized: &[f32]) -> f64 {
    if original.len() != quantized.len() || original.is_empty() {
        return f64::NEG_INFINITY;
    }
    let mut signal_power = 0.0f64;
    for &x in original {
        let v = x as f64;
        signal_power += v * v;
    }
    if signal_power == 0.0 {
        return 0.0;
    }
    let mean_power = signal_power / original.len() as f64;
    10.0 * (compute_mse(original, quantized) / mean_power).log10()
}
/// Cosine similarity between two slices. Returns 0.0 for empty, mismatched,
/// or zero-norm inputs.
pub fn compute_cosine_similarity(original: &[f32], quantized: &[f32]) -> f32 {
    if original.len() != quantized.len() || original.is_empty() {
        return 0.0;
    }
    let mut dot = 0.0f32;
    let mut sq_o = 0.0f32;
    let mut sq_q = 0.0f32;
    for (&o, &q) in original.iter().zip(quantized) {
        dot += o * q;
        sq_o += o * o;
        sq_q += q * q;
    }
    let norm_o = sq_o.sqrt();
    let norm_q = sq_q.sqrt();
    if norm_o == 0.0 || norm_q == 0.0 {
        return 0.0;
    }
    dot / (norm_o * norm_q)
}
#[cfg(test)]
mod tests {
    //! Unit tests: constructor validation, scalar and block quantization,
    //! bit-packing round trips, tensor helpers, and quality metrics.
    use super::*;
    // --- Constructor validation ---
    #[test]
    fn test_pi_quantizer_new_valid() {
        let quantizer = PiQuantizer::new(3, 4, vec![1.0, 2.0, 0.5]).unwrap();
        assert_eq!(quantizer.bits(), 3);
        assert_eq!(quantizer.k(), 4);
        assert_eq!(quantizer.num_channels(), 3);
    }
    #[test]
    fn test_pi_quantizer_invalid_bits() {
        let result = PiQuantizer::new(4, 4, vec![1.0]);
        assert!(matches!(result, Err(PiQuantError::InvalidBits(4))));
        let result = PiQuantizer::new(1, 4, vec![1.0]);
        assert!(matches!(result, Err(PiQuantError::InvalidBits(1))));
    }
    #[test]
    fn test_pi_quantizer_invalid_k() {
        let result = PiQuantizer::new(3, 1, vec![1.0]);
        assert!(matches!(result, Err(PiQuantError::InvalidK(1))));
        let result = PiQuantizer::new(3, 6, vec![1.0]);
        assert!(matches!(result, Err(PiQuantError::InvalidK(6))));
    }
    #[test]
    fn test_pi_quantizer_inv2_positive_scale() {
        // Negative, zero, and NaN scales must all be rejected.
        let result = PiQuantizer::new(3, 4, vec![1.0, -0.5, 2.0]);
        assert!(matches!(result, Err(PiQuantError::NonPositiveScale(_))));
        let result = PiQuantizer::new(3, 4, vec![0.0]);
        assert!(matches!(result, Err(PiQuantError::NonPositiveScale(0.0))));
        let result = PiQuantizer::new(3, 4, vec![f32::NAN]);
        assert!(matches!(result, Err(PiQuantError::NonPositiveScale(_))));
    }
    #[test]
    fn test_pi_quantizer_empty_alpha() {
        let result = PiQuantizer::new(3, 4, vec![]);
        assert!(matches!(result, Err(PiQuantError::EmptyAlpha)));
    }
    #[test]
    fn test_pi_quantizer_step_size() {
        let quantizer = PiQuantizer::new(3, 4, vec![1.0, 2.0]).unwrap();
        let step_0 = quantizer.step_size(0);
        let step_1 = quantizer.step_size(1);
        assert!((step_0 - PI / 4.0).abs() < 1e-6);
        assert!((step_1 - 2.0 * PI / 4.0).abs() < 1e-6);
    }
    // --- Scalar quantization, including clamping at the range edges ---
    #[test]
    fn test_quantize_scalar_3bit() {
        let quantizer = PiQuantizer::new(3, 4, vec![1.0]).unwrap();
        let step = PI / 4.0;
        let (q, dq) = quantizer.quantize_scalar(0.0, 0);
        assert_eq!(q, 0);
        assert!(dq.abs() < 1e-6);
        let (q, dq) = quantizer.quantize_scalar(step * 2.0, 0);
        assert_eq!(q, 2);
        assert!((dq - step * 2.0).abs() < 1e-6);
        let (q, dq) = quantizer.quantize_scalar(-step * 3.0, 0);
        assert_eq!(q, -3);
        assert!((dq + step * 3.0).abs() < 1e-6);
        // Out-of-range magnitudes saturate to [-4, 3].
        let (q, _dq) = quantizer.quantize_scalar(step * 10.0, 0);
        assert_eq!(q, 3);
        let (q, _dq) = quantizer.quantize_scalar(-step * 10.0, 0);
        assert_eq!(q, -4);
    }
    #[test]
    fn test_quantize_scalar_2bit() {
        let quantizer = PiQuantizer::new(2, 4, vec![1.0]).unwrap();
        let step = PI / 4.0;
        let (q, dq) = quantizer.quantize_scalar(0.0, 0);
        assert_eq!(q, 0);
        assert!(dq.abs() < 1e-6);
        let (q, _dq) = quantizer.quantize_scalar(step * 1.0, 0);
        assert_eq!(q, 1);
        // Out-of-range magnitudes saturate to [-2, 1].
        let (q, _dq) = quantizer.quantize_scalar(step * 10.0, 0);
        assert_eq!(q, 1);
        let (q, _dq) = quantizer.quantize_scalar(-step * 10.0, 0);
        assert_eq!(q, -2);
    }
    // --- 3-bit block packing round trips ---
    #[test]
    fn test_pi3bit_block_pack_unpack_roundtrip() {
        let values: [i8; 8] = [-4, -3, -2, -1, 0, 1, 2, 3];
        let mut block = Pi3BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }
    #[test]
    fn test_pi3bit_block_all_zeros() {
        let values: [i8; 8] = [0, 0, 0, 0, 0, 0, 0, 0];
        let mut block = Pi3BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }
    #[test]
    fn test_pi3bit_block_all_max() {
        let values: [i8; 8] = [3, 3, 3, 3, 3, 3, 3, 3];
        let mut block = Pi3BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }
    #[test]
    fn test_pi3bit_block_all_min() {
        let values: [i8; 8] = [-4, -4, -4, -4, -4, -4, -4, -4];
        let mut block = Pi3BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }
    #[test]
    fn test_pi3bit_block_clamping() {
        // Values outside [-4, 3] are clamped before packing.
        let values: [i8; 8] = [-10, -5, -4, 0, 3, 4, 5, 10];
        let mut block = Pi3BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        let expected: [i8; 8] = [-4, -4, -4, 0, 3, 3, 3, 3];
        assert_eq!(expected, unpacked);
    }
    #[test]
    fn test_pi3bit_block_size() {
        assert_eq!(Pi3BitBlock::size_bytes(), 3);
        assert_eq!(Pi3BitBlock::num_weights(), 8);
    }
    // --- 2-bit block packing round trips ---
    #[test]
    fn test_pi2bit_block_pack_unpack_roundtrip() {
        let values: [i8; 4] = [-2, -1, 0, 1];
        let mut block = Pi2BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }
    #[test]
    fn test_pi2bit_block_all_zeros() {
        let values: [i8; 4] = [0, 0, 0, 0];
        let mut block = Pi2BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }
    #[test]
    fn test_pi2bit_block_extremes() {
        let values: [i8; 4] = [-2, -2, 1, 1];
        let mut block = Pi2BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }
    #[test]
    fn test_pi2bit_block_clamping() {
        // Values outside [-2, 1] are clamped before packing.
        let values: [i8; 4] = [-10, -3, 2, 10];
        let mut block = Pi2BitBlock::new();
        block.pack(&values);
        let unpacked = block.unpack();
        let expected: [i8; 4] = [-2, -2, 1, 1];
        assert_eq!(expected, unpacked);
    }
    #[test]
    fn test_pi2bit_block_size() {
        assert_eq!(Pi2BitBlock::size_bytes(), 1);
        assert_eq!(Pi2BitBlock::num_weights(), 4);
    }
    // --- Block- and tensor-level quantize -> dequantize round trips ---
    #[test]
    fn test_quantize_block_3bit() {
        let quantizer = PiQuantizer::new(3, 4, vec![1.0]).unwrap();
        let weights = [0.0, 0.1, -0.1, 0.5, -0.5, 1.0, -1.0, 0.25];
        let block = quantizer.quantize_block_3bit(&weights, 0).unwrap();
        let dequantized = quantizer.dequantize_block_3bit(&block, 0);
        let mse = compute_mse(&weights, &dequantized);
        assert!(mse < 0.5, "MSE too high: {}", mse);
    }
    #[test]
    fn test_quantize_block_2bit() {
        let quantizer = PiQuantizer::new(2, 4, vec![1.0]).unwrap();
        let weights = [0.0, 0.5, -0.5, 1.0];
        let block = quantizer.quantize_block_2bit(&weights, 0).unwrap();
        let dequantized = quantizer.dequantize_block_2bit(&block, 0);
        assert_eq!(dequantized.len(), 4);
    }
    #[test]
    fn test_quantize_tensor_3bit() {
        let quantizer = PiQuantizer::new(3, 4, vec![1.0]).unwrap();
        // 24 weights -> exactly 3 full blocks.
        let weights: Vec<f32> = (0..24).map(|i| (i as f32 - 12.0) * 0.1).collect();
        let blocks = quantize_tensor_3bit(&weights, &quantizer, 0).unwrap();
        assert_eq!(blocks.len(), 3);
        let mut output = vec![0.0f32; weights.len()];
        dequantize_tensor_3bit(&blocks, &quantizer, 0, &mut output);
        let mse = compute_mse(&weights, &output);
        assert!(mse < 0.5, "MSE too high: {}", mse);
    }
    #[test]
    fn test_quantize_tensor_2bit() {
        let quantizer = PiQuantizer::new(2, 4, vec![1.0]).unwrap();
        // 12 weights -> exactly 3 full blocks.
        let weights: Vec<f32> = (0..12).map(|i| (i as f32 - 6.0) * 0.2).collect();
        let blocks = quantize_tensor_2bit(&weights, &quantizer, 0).unwrap();
        assert_eq!(blocks.len(), 3);
        let mut output = vec![0.0f32; weights.len()];
        dequantize_tensor_2bit(&blocks, &quantizer, 0, &mut output);
        assert_eq!(output.len(), 12);
    }
    // --- Quality metrics ---
    #[test]
    fn test_compute_mse_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 2.0, 3.0, 4.0];
        let mse = compute_mse(&a, &b);
        assert!((mse - 0.0).abs() < 1e-10);
    }
    #[test]
    fn test_compute_mse_different() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 3.0, 4.0, 5.0];
        let mse = compute_mse(&a, &b);
        assert!((mse - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_compute_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = compute_cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_compute_cosine_similarity_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![-1.0, -2.0, -3.0];
        let sim = compute_cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }
    #[test]
    fn test_compute_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = compute_cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }
    // --- Calibration and scale updates ---
    #[test]
    fn test_calibrate_from_weights() {
        let mut quantizer = PiQuantizer::new(3, 4, vec![1.0, 1.0]).unwrap();
        let channel_0_weights: Vec<f32> = vec![0.1, -0.2, 0.3, -0.4];
        let channel_1_weights: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0];
        quantizer
            .calibrate_from_weights(&[&channel_0_weights, &channel_1_weights])
            .unwrap();
        // The larger-magnitude channel must calibrate to a larger alpha.
        assert!(quantizer.alpha()[1] > quantizer.alpha()[0]);
    }
    #[test]
    fn test_update_alpha() {
        let mut quantizer = PiQuantizer::new(3, 4, vec![1.0]).unwrap();
        quantizer.update_alpha(0, 2.0).unwrap();
        assert!((quantizer.alpha()[0] - 2.0).abs() < 1e-6);
        let result = quantizer.update_alpha(0, -1.0);
        assert!(result.is_err());
        let result = quantizer.update_alpha(99, 1.0);
        assert!(result.is_err());
    }
    // --- Error paths ---
    #[test]
    fn test_quantize_block_wrong_bits() {
        let quantizer = PiQuantizer::new(3, 4, vec![1.0]).unwrap();
        let weights = [0.0f32; 4];
        let result = quantizer.quantize_block_2bit(&weights, 0);
        assert!(matches!(result, Err(PiQuantError::InvalidBits(3))));
    }
    #[test]
    fn test_quantize_block_wrong_size() {
        let quantizer = PiQuantizer::new(3, 4, vec![1.0]).unwrap();
        let weights = [0.0f32; 4];
        let result = quantizer.quantize_block_3bit(&weights, 0);
        assert!(matches!(
            result,
            Err(PiQuantError::BlockSizeMismatch {
                expected: 8,
                actual: 4
            })
        ));
    }
    #[test]
    fn test_bits_per_weight() {
        let q3 = PiQuantizer::new(3, 4, vec![1.0]).unwrap();
        assert!((q3.bits_per_weight() - 3.0625).abs() < 1e-4);
        let q2 = PiQuantizer::new(2, 4, vec![1.0]).unwrap();
        assert!((q2.bits_per_weight() - 2.0625).abs() < 1e-4);
    }
}