use std::f32::consts::PI;
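// PI3 packing: each group of 8 signed 3-bit values (levels -4..=3, stored with
// a +4 offset as codes 0..=7) occupies one little-endian 24-bit word, written
// as 3 bytes. Dequantization maps code `c` to `(c - 4) * scale`.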
pub const DEFAULT_K: u32 = 4;
pub const PI_F32: f32 = PI;
pub const PI3_VALUES_PER_GROUP: usize = 8;
pub const PI3_BYTES_PER_GROUP: usize = 3;
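/// Signed level for each 3-bit code: code `c` decodes to `c - 4`.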
const SIGNED_3BIT_LUT: [i8; 8] = [-4, -3, -2, -1, 0, 1, 2, 3];
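/// Dequantizes PI3-packed data into `output` using the portable scalar path.
///
/// Panics if `packed.len()` is not a multiple of [`PI3_BYTES_PER_GROUP`] or if
/// `output.len()` does not equal the number of decoded values.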
pub fn pi_dequantize_scalar(packed: &[u8], scale: f32, output: &mut [f32]) {
assert!(
packed.len() % PI3_BYTES_PER_GROUP == 0,
"Packed length {} must be a multiple of {}",
packed.len(),
PI3_BYTES_PER_GROUP
);
let num_groups = packed.len() / PI3_BYTES_PER_GROUP;
let expected_output_len = num_groups * PI3_VALUES_PER_GROUP;
assert_eq!(
output.len(),
expected_output_len,
"Output length {} doesn't match expected {} (from {} packed bytes)",
output.len(),
expected_output_len,
packed.len()
);
for group in 0..num_groups {
let byte_offset = group * PI3_BYTES_PER_GROUP;
let out_offset = group * PI3_VALUES_PER_GROUP;
let b0 = packed[byte_offset] as u32;
let b1 = packed[byte_offset + 1] as u32;
let b2 = packed[byte_offset + 2] as u32;
let combined = b0 | (b1 << 8) | (b2 << 16);
for i in 0..8 {
let shift = i * 3;
let raw = ((combined >> shift) & 0x7) as i32;
let signed = raw - 4;
output[out_offset + i] = (signed as f32) * scale;
}
}
}
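/// Extracts the signed 3-bit level at `index` from a PI3-packed buffer.
///
/// Panics via slice indexing if `index` addresses bytes beyond `packed`.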
#[inline]
pub fn extract_pi3_value(packed: &[u8], index: usize) -> i8 {
let group = index / PI3_VALUES_PER_GROUP;
let offset_in_group = index % PI3_VALUES_PER_GROUP;
let byte_offset = group * PI3_BYTES_PER_GROUP;
let b0 = packed[byte_offset] as u32;
let b1 = packed[byte_offset + 1] as u32;
let b2 = packed[byte_offset + 2] as u32;
let combined = b0 | (b1 << 8) | (b2 << 16);
let shift = offset_in_group * 3;
    // Look up the signed level for the 3-bit code (equivalent to `raw - 4`).
    let raw = ((combined >> shift) & 0x7) as usize;
    SIGNED_3BIT_LUT[raw]
}
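/// NEON dequantization kernel for aarch64.
///
/// Unlike [`pi_dequantize_scalar`], trailing bytes that do not form a complete
/// 3-byte group are ignored rather than rejected.
///
/// # Safety
///
/// The caller must ensure NEON is available (always true on aarch64) and pass
/// buffers satisfying the length check asserted below.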
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn pi_dequantize_neon(packed: &[u8], scale: f32, output: &mut [f32]) {
use core::arch::aarch64::*;
let num_groups = packed.len() / PI3_BYTES_PER_GROUP;
let total_values = num_groups * PI3_VALUES_PER_GROUP;
assert_eq!(
output.len(),
total_values,
"Output length mismatch: {} vs expected {}",
output.len(),
total_values
);
if num_groups == 0 {
return;
}
    let scale_vec = vdupq_n_f32(scale);
    // Fold the -4 bias into the FMA accumulator: `raw * scale + (-4 * scale)`
    // equals `(raw - 4) * scale` with a single rounding.
    let bias_scaled = vdupq_n_f32(-4.0 * scale);
    let bias_f32 = vdupq_n_f32(-4.0);
    // vshlq_u32 shifts left for positive lane counts and right for negative
    // ones, so these counts move each 3-bit field down to bit 0.
    let shifts_lo: int32x4_t = vld1q_s32([0i32, -3, -6, -9].as_ptr());
    let shifts_hi: int32x4_t = vld1q_s32([-12i32, -15, -18, -21].as_ptr());
    let mask_3bit = vdupq_n_u32(0x7);
    // How many bytes ahead of the current read position to prefetch.
    const PREFETCH_DISTANCE: usize = 256;
let packed_ptr = packed.as_ptr();
let output_ptr = output.as_mut_ptr();
    let mut group = 0usize;
    // Main loop: decode 8 groups (64 values) per iteration.
    while group + 8 <= num_groups {
let byte_offset = group * PI3_BYTES_PER_GROUP;
let out_offset = group * PI3_VALUES_PER_GROUP;
let p = packed_ptr.add(byte_offset);
let o = output_ptr.add(out_offset);
if byte_offset + PREFETCH_DISTANCE < packed.len() {
let prefetch_addr = packed_ptr.add(byte_offset + PREFETCH_DISTANCE);
core::arch::asm!(
"prfm pldl1keep, [{addr}]",
"prfm pldl1keep, [{addr}, #64]",
addr = in(reg) prefetch_addr,
options(nostack, preserves_flags)
);
}
let c0 = (*p as u32) | ((*p.add(1) as u32) << 8) | ((*p.add(2) as u32) << 16);
let c1 = (*p.add(3) as u32) | ((*p.add(4) as u32) << 8) | ((*p.add(5) as u32) << 16);
let c2 = (*p.add(6) as u32) | ((*p.add(7) as u32) << 8) | ((*p.add(8) as u32) << 16);
let c3 = (*p.add(9) as u32) | ((*p.add(10) as u32) << 8) | ((*p.add(11) as u32) << 16);
let c4 = (*p.add(12) as u32) | ((*p.add(13) as u32) << 8) | ((*p.add(14) as u32) << 16);
let c5 = (*p.add(15) as u32) | ((*p.add(16) as u32) << 8) | ((*p.add(17) as u32) << 16);
let c6 = (*p.add(18) as u32) | ((*p.add(19) as u32) << 8) | ((*p.add(20) as u32) << 16);
let c7 = (*p.add(21) as u32) | ((*p.add(22) as u32) << 8) | ((*p.add(23) as u32) << 16);
let v0 = vdupq_n_u32(c0);
let lo0 = vandq_u32(vshlq_u32(v0, shifts_lo), mask_3bit);
let hi0 = vandq_u32(vshlq_u32(v0, shifts_hi), mask_3bit);
vst1q_f32(o, vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo0), scale_vec));
vst1q_f32(
o.add(4),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi0), scale_vec),
);
let v1 = vdupq_n_u32(c1);
let lo1 = vandq_u32(vshlq_u32(v1, shifts_lo), mask_3bit);
let hi1 = vandq_u32(vshlq_u32(v1, shifts_hi), mask_3bit);
vst1q_f32(
o.add(8),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo1), scale_vec),
);
vst1q_f32(
o.add(12),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi1), scale_vec),
);
let v2 = vdupq_n_u32(c2);
let lo2 = vandq_u32(vshlq_u32(v2, shifts_lo), mask_3bit);
let hi2 = vandq_u32(vshlq_u32(v2, shifts_hi), mask_3bit);
vst1q_f32(
o.add(16),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo2), scale_vec),
);
vst1q_f32(
o.add(20),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi2), scale_vec),
);
let v3 = vdupq_n_u32(c3);
let lo3 = vandq_u32(vshlq_u32(v3, shifts_lo), mask_3bit);
let hi3 = vandq_u32(vshlq_u32(v3, shifts_hi), mask_3bit);
vst1q_f32(
o.add(24),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo3), scale_vec),
);
vst1q_f32(
o.add(28),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi3), scale_vec),
);
let v4 = vdupq_n_u32(c4);
let lo4 = vandq_u32(vshlq_u32(v4, shifts_lo), mask_3bit);
let hi4 = vandq_u32(vshlq_u32(v4, shifts_hi), mask_3bit);
vst1q_f32(
o.add(32),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo4), scale_vec),
);
vst1q_f32(
o.add(36),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi4), scale_vec),
);
let v5 = vdupq_n_u32(c5);
let lo5 = vandq_u32(vshlq_u32(v5, shifts_lo), mask_3bit);
let hi5 = vandq_u32(vshlq_u32(v5, shifts_hi), mask_3bit);
vst1q_f32(
o.add(40),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo5), scale_vec),
);
vst1q_f32(
o.add(44),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi5), scale_vec),
);
let v6 = vdupq_n_u32(c6);
let lo6 = vandq_u32(vshlq_u32(v6, shifts_lo), mask_3bit);
let hi6 = vandq_u32(vshlq_u32(v6, shifts_hi), mask_3bit);
vst1q_f32(
o.add(48),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo6), scale_vec),
);
vst1q_f32(
o.add(52),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi6), scale_vec),
);
let v7 = vdupq_n_u32(c7);
let lo7 = vandq_u32(vshlq_u32(v7, shifts_lo), mask_3bit);
let hi7 = vandq_u32(vshlq_u32(v7, shifts_hi), mask_3bit);
vst1q_f32(
o.add(56),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(lo7), scale_vec),
);
vst1q_f32(
o.add(60),
vfmaq_f32(bias_scaled, vcvtq_f32_u32(hi7), scale_vec),
);
group += 8;
}
    // Remainder: decode one group (8 values) at a time via the helpers below.
    while group < num_groups {
let byte_offset = group * PI3_BYTES_PER_GROUP;
let out_offset = group * PI3_VALUES_PER_GROUP;
let combined = neon_load_combined_24bit(packed_ptr.add(byte_offset));
let (lo, hi) = neon_extract_and_convert(combined, bias_f32, scale_vec);
vst1q_f32(output_ptr.add(out_offset), lo);
vst1q_f32(output_ptr.add(out_offset + 4), hi);
group += 1;
}
}
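/// Loads one 3-byte group as a little-endian 24-bit word.
///
/// # Safety
///
/// `ptr` must be valid for reads of 3 bytes.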
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn neon_load_combined_24bit(ptr: *const u8) -> u32 {
let b0 = *ptr as u32;
let b1 = *ptr.add(1) as u32;
let b2 = *ptr.add(2) as u32;
b0 | (b1 << 8) | (b2 << 16)
}
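/// Splits a 24-bit group into eight 3-bit codes, applies the -4 bias, and
/// scales, returning the low and high `float32x4_t` halves.
///
/// # Safety
///
/// Requires NEON; callers must already be inside a NEON-enabled context.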
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn neon_extract_and_convert(
combined: u32,
bias_f32: core::arch::aarch64::float32x4_t,
scale_vec: core::arch::aarch64::float32x4_t,
) -> (
core::arch::aarch64::float32x4_t,
core::arch::aarch64::float32x4_t,
) {
use core::arch::aarch64::*;
let c_vec = vdupq_n_u32(combined);
let mask_3bit = vdupq_n_u32(0x7);
let shifts_lo = vld1q_s32([0i32, -3, -6, -9].as_ptr());
let shifts_hi = vld1q_s32([-12i32, -15, -18, -21].as_ptr());
let lo_u32 = vandq_u32(vshlq_u32(c_vec, shifts_lo), mask_3bit);
let hi_u32 = vandq_u32(vshlq_u32(c_vec, shifts_hi), mask_3bit);
let lo_f32 = vcvtq_f32_u32(lo_u32);
let hi_f32 = vcvtq_f32_u32(hi_u32);
let biased_lo = vaddq_f32(lo_f32, bias_f32);
let biased_hi = vaddq_f32(hi_f32, bias_f32);
let result_lo = vmulq_f32(biased_lo, scale_vec);
let result_hi = vmulq_f32(biased_hi, scale_vec);
(result_lo, result_hi)
}
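/// AVX-512 dequantization kernel.
///
/// Trailing bytes that do not form a complete 3-byte group are ignored.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F, e.g. via
/// `is_x86_feature_detected!("avx512f")`.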
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn pi_dequantize_avx512(packed: &[u8], scale: f32, output: &mut [f32]) {
use core::arch::x86_64::*;
let num_groups = packed.len() / PI3_BYTES_PER_GROUP;
let total_values = num_groups * PI3_VALUES_PER_GROUP;
assert_eq!(
output.len(),
total_values,
"Output length mismatch: {} vs expected {}",
output.len(),
total_values
);
if num_groups == 0 {
return;
}
let scale_vec = _mm512_set1_ps(scale);
let bias_vec = _mm512_set1_epi32(-4);
let simd_groups = num_groups / 8;
let mut group = 0usize;
while group < simd_groups * 8 {
let byte_offset = group * PI3_BYTES_PER_GROUP;
let out_offset = group * PI3_VALUES_PER_GROUP;
        // Each batch decodes two 3-byte groups into one 16-lane vector.
        for batch in 0..4 {
let g0 = batch * 2;
let g1 = batch * 2 + 1;
let gb0 = byte_offset + g0 * 3;
let b0_0 = *packed.get_unchecked(gb0) as u32;
let b0_1 = *packed.get_unchecked(gb0 + 1) as u32;
let b0_2 = *packed.get_unchecked(gb0 + 2) as u32;
let combined0 = b0_0 | (b0_1 << 8) | (b0_2 << 16);
let v0_0 = (combined0 & 0x7) as i32;
let v0_1 = ((combined0 >> 3) & 0x7) as i32;
let v0_2 = ((combined0 >> 6) & 0x7) as i32;
let v0_3 = ((combined0 >> 9) & 0x7) as i32;
let v0_4 = ((combined0 >> 12) & 0x7) as i32;
let v0_5 = ((combined0 >> 15) & 0x7) as i32;
let v0_6 = ((combined0 >> 18) & 0x7) as i32;
let v0_7 = ((combined0 >> 21) & 0x7) as i32;
let gb1 = byte_offset + g1 * 3;
let b1_0 = *packed.get_unchecked(gb1) as u32;
let b1_1 = *packed.get_unchecked(gb1 + 1) as u32;
let b1_2 = *packed.get_unchecked(gb1 + 2) as u32;
let combined1 = b1_0 | (b1_1 << 8) | (b1_2 << 16);
let v1_0 = (combined1 & 0x7) as i32;
let v1_1 = ((combined1 >> 3) & 0x7) as i32;
let v1_2 = ((combined1 >> 6) & 0x7) as i32;
let v1_3 = ((combined1 >> 9) & 0x7) as i32;
let v1_4 = ((combined1 >> 12) & 0x7) as i32;
let v1_5 = ((combined1 >> 15) & 0x7) as i32;
let v1_6 = ((combined1 >> 18) & 0x7) as i32;
let v1_7 = ((combined1 >> 21) & 0x7) as i32;
let raw_vec = _mm512_setr_epi32(
v0_0, v0_1, v0_2, v0_3, v0_4, v0_5, v0_6, v0_7, v1_0, v1_1, v1_2, v1_3, v1_4, v1_5,
v1_6, v1_7,
);
let signed_vec = _mm512_add_epi32(raw_vec, bias_vec);
let float_vec = _mm512_cvtepi32_ps(signed_vec);
let result_vec = _mm512_mul_ps(float_vec, scale_vec);
let go = out_offset + batch * 16;
_mm512_storeu_ps(output.as_mut_ptr().add(go), result_vec);
}
group += 8;
}
    // Scalar tail for any groups left over after the unrolled loop.
    while group < num_groups {
let byte_offset = group * PI3_BYTES_PER_GROUP;
let out_offset = group * PI3_VALUES_PER_GROUP;
let b0 = *packed.get_unchecked(byte_offset) as u32;
let b1 = *packed.get_unchecked(byte_offset + 1) as u32;
let b2 = *packed.get_unchecked(byte_offset + 2) as u32;
let combined = b0 | (b1 << 8) | (b2 << 16);
for i in 0..8 {
let shift = i * 3;
let raw = ((combined >> shift) & 0x7) as i32;
let signed = raw - 4;
*output.get_unchecked_mut(out_offset + i) = (signed as f32) * scale;
}
group += 1;
}
}
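/// AVX-512 quantization kernel: rounds, clamps to [-4, 3], and packs each
/// group of 8 floats into 3 bytes.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F.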
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn pi_quantize_avx512(weights: &[f32], scale: f32, output: &mut [u8]) {
use core::arch::x86_64::*;
assert!(
weights.len() % PI3_VALUES_PER_GROUP == 0,
"Weights length must be multiple of 8"
);
let num_groups = weights.len() / PI3_VALUES_PER_GROUP;
assert_eq!(
output.len(),
num_groups * PI3_BYTES_PER_GROUP,
"Output buffer size mismatch"
);
if num_groups == 0 {
return;
}
    // Degenerate scale: quantize everything to level 0 instead of dividing by ~0.
    let inv_scale = if scale.abs() > 1e-10 {
1.0 / scale
} else {
0.0
};
let inv_scale_vec = _mm512_set1_ps(inv_scale);
let bias_vec = _mm512_set1_epi32(4);
let min_vec = _mm512_set1_epi32(0);
let max_vec = _mm512_set1_epi32(7);
    // 0x08 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC: round to nearest
    // (ties to even) without raising exceptions. Exact .5 ties can therefore
    // differ from the scalar path's f32::round (ties away from zero).
    const ROUNDING: i32 = 0x08;
let simd_groups = num_groups / 2;
let mut group = 0usize;
while group < simd_groups * 2 {
let val_offset = group * PI3_VALUES_PER_GROUP;
let byte_offset = group * PI3_BYTES_PER_GROUP;
let weights_vec = _mm512_loadu_ps(weights.as_ptr().add(val_offset));
let scaled_vec = _mm512_mul_ps(weights_vec, inv_scale_vec);
        // The rounding immediate is a const-generic parameter in std::arch.
        let rounded_vec = _mm512_roundscale_ps::<ROUNDING>(scaled_vec);
let quantized_vec = _mm512_cvtps_epi32(rounded_vec);
let biased_vec = _mm512_add_epi32(quantized_vec, bias_vec);
let clamped_vec = _mm512_max_epi32(_mm512_min_epi32(biased_vec, max_vec), min_vec);
let mut values = [0i32; 16];
_mm512_storeu_si512(values.as_mut_ptr() as *mut __m512i, clamped_vec);
let combined0: u32 = (values[0] as u32 & 0x7)
| ((values[1] as u32 & 0x7) << 3)
| ((values[2] as u32 & 0x7) << 6)
| ((values[3] as u32 & 0x7) << 9)
| ((values[4] as u32 & 0x7) << 12)
| ((values[5] as u32 & 0x7) << 15)
| ((values[6] as u32 & 0x7) << 18)
| ((values[7] as u32 & 0x7) << 21);
*output.get_unchecked_mut(byte_offset) = (combined0 & 0xFF) as u8;
*output.get_unchecked_mut(byte_offset + 1) = ((combined0 >> 8) & 0xFF) as u8;
*output.get_unchecked_mut(byte_offset + 2) = ((combined0 >> 16) & 0xFF) as u8;
let combined1: u32 = (values[8] as u32 & 0x7)
| ((values[9] as u32 & 0x7) << 3)
| ((values[10] as u32 & 0x7) << 6)
| ((values[11] as u32 & 0x7) << 9)
| ((values[12] as u32 & 0x7) << 12)
| ((values[13] as u32 & 0x7) << 15)
| ((values[14] as u32 & 0x7) << 18)
| ((values[15] as u32 & 0x7) << 21);
*output.get_unchecked_mut(byte_offset + 3) = (combined1 & 0xFF) as u8;
*output.get_unchecked_mut(byte_offset + 4) = ((combined1 >> 8) & 0xFF) as u8;
*output.get_unchecked_mut(byte_offset + 5) = ((combined1 >> 16) & 0xFF) as u8;
group += 2;
}
    // Scalar tail for any group left over after the paired loop.
    while group < num_groups {
let val_offset = group * PI3_VALUES_PER_GROUP;
let byte_offset = group * PI3_BYTES_PER_GROUP;
let mut combined: u32 = 0;
for i in 0..8 {
let v = *weights.get_unchecked(val_offset + i);
let quantized = (v * inv_scale).round() as i32;
let clamped = quantized.clamp(-4, 3);
let unsigned = (clamped + 4) as u32;
combined |= (unsigned & 0x7) << (i * 3);
}
*output.get_unchecked_mut(byte_offset) = (combined & 0xFF) as u8;
*output.get_unchecked_mut(byte_offset + 1) = ((combined >> 8) & 0xFF) as u8;
*output.get_unchecked_mut(byte_offset + 2) = ((combined >> 16) & 0xFF) as u8;
group += 1;
}
}
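/// AVX2 dequantization kernel.
///
/// Trailing bytes that do not form a complete 3-byte group are ignored.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2, e.g. via
/// `is_x86_feature_detected!("avx2")`.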
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn pi_dequantize_avx2(packed: &[u8], scale: f32, output: &mut [f32]) {
use core::arch::x86_64::*;
let num_groups = packed.len() / PI3_BYTES_PER_GROUP;
let total_values = num_groups * PI3_VALUES_PER_GROUP;
assert_eq!(
output.len(),
total_values,
"Output length mismatch: {} vs expected {}",
output.len(),
total_values
);
if num_groups == 0 {
return;
}
let scale_vec = _mm256_set1_ps(scale);
let bias_vec = _mm256_set1_epi32(-4);
let simd_groups = num_groups / 4;
let mut group = 0usize;
while group < simd_groups * 4 {
let byte_offset = group * PI3_BYTES_PER_GROUP;
let out_offset = group * PI3_VALUES_PER_GROUP;
for g in 0..4 {
let gb = byte_offset + g * 3;
let go = out_offset + g * 8;
let b0 = *packed.get_unchecked(gb) as u32;
let b1 = *packed.get_unchecked(gb + 1) as u32;
let b2 = *packed.get_unchecked(gb + 2) as u32;
let combined = b0 | (b1 << 8) | (b2 << 16);
let v0 = (combined & 0x7) as i32;
let v1 = ((combined >> 3) & 0x7) as i32;
let v2 = ((combined >> 6) & 0x7) as i32;
let v3 = ((combined >> 9) & 0x7) as i32;
let v4 = ((combined >> 12) & 0x7) as i32;
let v5 = ((combined >> 15) & 0x7) as i32;
let v6 = ((combined >> 18) & 0x7) as i32;
let v7 = ((combined >> 21) & 0x7) as i32;
let raw_vec = _mm256_setr_epi32(v0, v1, v2, v3, v4, v5, v6, v7);
let signed_vec = _mm256_add_epi32(raw_vec, bias_vec);
let float_vec = _mm256_cvtepi32_ps(signed_vec);
let result_vec = _mm256_mul_ps(float_vec, scale_vec);
_mm256_storeu_ps(output.as_mut_ptr().add(go), result_vec);
}
group += 4;
}
    // Scalar tail for any groups left over after the unrolled loop.
    while group < num_groups {
let byte_offset = group * PI3_BYTES_PER_GROUP;
let out_offset = group * PI3_VALUES_PER_GROUP;
let b0 = *packed.get_unchecked(byte_offset) as u32;
let b1 = *packed.get_unchecked(byte_offset + 1) as u32;
let b2 = *packed.get_unchecked(byte_offset + 2) as u32;
let combined = b0 | (b1 << 8) | (b2 << 16);
for i in 0..8 {
let shift = i * 3;
let raw = ((combined >> shift) & 0x7) as i32;
let signed = raw - 4;
*output.get_unchecked_mut(out_offset + i) = (signed as f32) * scale;
}
group += 1;
}
}
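/// Dequantizes PI3-packed data, dispatching at runtime to the best available
/// kernel: NEON on aarch64; AVX-512F, then AVX2, then scalar on x86_64.
///
/// Example (illustrative sketch, not compiled as a doctest):
///
/// ```ignore
/// // Eight 3-bit codes of 4 (0x924924 packed little-endian) decode to level 0.
/// let packed = [0x24u8, 0x49, 0x92];
/// let mut out = [0.0f32; 8];
/// pi_dequantize(&packed, 1.0, &mut out);
/// assert!(out.iter().all(|&v| v == 0.0));
/// ```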
pub fn pi_dequantize(packed: &[u8], scale: f32, output: &mut [f32]) {
#[cfg(target_arch = "aarch64")]
{
unsafe {
pi_dequantize_neon(packed, scale, output);
}
return;
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
unsafe {
pi_dequantize_avx512(packed, scale, output);
}
return;
}
if is_x86_feature_detected!("avx2") {
unsafe {
pi_dequantize_avx2(packed, scale, output);
}
return;
}
pi_dequantize_scalar(packed, scale, output);
return;
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
{
pi_dequantize_scalar(packed, scale, output);
}
}
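/// Reports which kernel [`pi_dequantize`] would select on this machine.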
// On aarch64 the early return makes the trailing "scalar" unreachable.
#[allow(unreachable_code)]
pub fn pi_dequantize_kernel_name() -> &'static str {
#[cfg(target_arch = "aarch64")]
{
return "neon";
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
return "avx512";
}
if is_x86_feature_detected!("avx2") {
return "avx2";
}
}
"scalar"
}
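/// Quantizes floats into PI3 packing, using the AVX-512 kernel when available
/// on x86_64 and the scalar path otherwise.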
pub fn pi_quantize(weights: &[f32], scale: f32, output: &mut [u8]) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
unsafe {
pi_quantize_avx512(weights, scale, output);
}
return;
}
}
pi_quantize_scalar(weights, scale, output);
}
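/// Reports which kernel [`pi_quantize`] would select on this machine.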
pub fn pi_quantize_kernel_name() -> &'static str {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
return "avx512";
}
}
"scalar"
}
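/// Scalar quantization: rounds `values[i] / scale` to the nearest integer,
/// clamps to [-4, 3], and packs 8 levels per 3-byte group.
///
/// Panics if `values.len()` is not a multiple of 8 or `output` is mis-sized.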
pub fn pi_quantize_scalar(values: &[f32], scale: f32, output: &mut [u8]) {
assert!(
values.len() % PI3_VALUES_PER_GROUP == 0,
"Values length must be multiple of 8"
);
let num_groups = values.len() / PI3_VALUES_PER_GROUP;
assert_eq!(
output.len(),
num_groups * PI3_BYTES_PER_GROUP,
"Output buffer size mismatch"
);
    // Degenerate scale: quantize everything to level 0 instead of dividing by ~0.
    let inv_scale = if scale.abs() > 1e-10 {
1.0 / scale
} else {
0.0
};
for group in 0..num_groups {
let val_offset = group * PI3_VALUES_PER_GROUP;
let byte_offset = group * PI3_BYTES_PER_GROUP;
let mut combined: u32 = 0;
for i in 0..8 {
let v = values[val_offset + i];
let quantized = (v * inv_scale).round() as i32;
let clamped = quantized.clamp(-4, 3);
let unsigned = (clamped + 4) as u32;
combined |= (unsigned & 0x7) << (i * 3);
}
output[byte_offset] = (combined & 0xFF) as u8;
output[byte_offset + 1] = ((combined >> 8) & 0xFF) as u8;
output[byte_offset + 2] = ((combined >> 16) & 0xFF) as u8;
}
}
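/// Returns the base quantization scale `π / k`.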
#[inline]
pub fn pi_scale(k: u32) -> f32 {
PI_F32 / (k as f32)
}
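/// Returns `alpha * π / k`, a weighted variant of [`pi_scale`].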
#[inline]
pub fn pi_scale_adaptive(alpha: f32, k: u32) -> f32 {
alpha * PI_F32 / (k as f32)
}
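/// Derives a scale from the maximum absolute value so that `max_abs` maps to
/// the top positive level (3); falls back to `π / 4` for near-zero inputs.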
#[inline]
pub fn pi_scale_from_max(max_abs: f32) -> f32 {
if max_abs < 1e-10 {
PI_F32 / 4.0
} else {
max_abs / 3.0
}
}
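/// Quantizes a single value to a signed 3-bit level in [-4, 3]; returns 0 for
/// a degenerate (near-zero) scale.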
#[inline]
pub fn pi_quantize_value(value: f32, scale: f32) -> i8 {
    // Use the magnitude so that negative scales behave like pi_quantize_scalar.
    if scale.abs() < 1e-10 {
return 0;
}
let quantized = (value / scale).round() as i32;
quantized.clamp(-4, 3) as i8
}
#[cfg(test)]
mod tests {
use super::*;
const EPSILON: f32 = 1e-6;
fn ulp_distance(a: f32, b: f32) -> u32 {
if a == b {
return 0;
}
if a.is_nan() || b.is_nan() {
return u32::MAX;
}
let a_bits = a.to_bits() as i32;
let b_bits = b.to_bits() as i32;
let a_signed = if a_bits < 0 {
i32::MIN - a_bits
} else {
a_bits
};
let b_signed = if b_bits < 0 {
i32::MIN - b_bits
} else {
b_bits
};
        // Widen to i64 so the subtraction cannot overflow for distant values.
        ((a_signed as i64) - (b_signed as i64)).unsigned_abs() as u32
}
#[test]
    fn test_pi_dequantize_scalar_zero_bytes() {
        // All-zero bytes decode 3-bit code 0 in every lane, i.e. level -4.
        let packed = vec![0u8; 3];
let scale = 1.0;
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
for &v in &output {
assert!((v - (-4.0)).abs() < EPSILON, "Expected -4.0, got {}", v);
}
}
#[test]
    fn test_pi_dequantize_scalar_all_4s() {
        // Code 4 in every lane decodes to level 0.
        let combined: u32 = (0..8).map(|i| 4u32 << (i * 3)).sum();
let packed = vec![
(combined & 0xFF) as u8,
((combined >> 8) & 0xFF) as u8,
((combined >> 16) & 0xFF) as u8,
];
let scale = 1.0;
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
for &v in &output {
assert!((v - 0.0).abs() < EPSILON, "Expected 0.0, got {}", v);
}
}
#[test]
fn test_pi_dequantize_scalar_range() {
let values: Vec<i32> = (-4..=3).collect();
let mut combined: u32 = 0;
for (i, &v) in values.iter().enumerate() {
let unsigned = (v + 4) as u32;
combined |= (unsigned & 0x7) << (i * 3);
}
let packed = vec![
(combined & 0xFF) as u8,
((combined >> 8) & 0xFF) as u8,
((combined >> 16) & 0xFF) as u8,
];
let scale = 0.5;
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
for (i, &v) in values.iter().enumerate() {
let expected = (v as f32) * scale;
assert!(
(output[i] - expected).abs() < EPSILON,
"Index {}: expected {}, got {}",
i,
expected,
output[i]
);
}
}
#[test]
fn test_pi_dequantize_scalar_pi_scale() {
let scale = pi_scale(4);
let values = [-4i32, -3, -2, -1, 0, 1, 2, 3];
let mut combined: u32 = 0;
for (i, &v) in values.iter().enumerate() {
let unsigned = (v + 4) as u32;
combined |= (unsigned & 0x7) << (i * 3);
}
let packed = vec![
(combined & 0xFF) as u8,
((combined >> 8) & 0xFF) as u8,
((combined >> 16) & 0xFF) as u8,
];
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
for (i, &v) in values.iter().enumerate() {
let expected = (v as f32) * scale;
assert!(
(output[i] - expected).abs() < EPSILON,
"Index {}: expected {}, got {}",
i,
expected,
output[i]
);
}
}
#[test]
fn test_pi_dequantize_scalar_multiple_groups() {
let scale = 1.0;
let mut packed = vec![0u8; 12];
for group in 0..4 {
let base_val = group as i32;
let clamped = base_val.clamp(-4, 3);
let unsigned = (clamped + 4) as u32;
            // The 3-bit fields don't overlap, so summing the shifted codes
            // is equivalent to OR-ing them together.
            let combined: u32 = (0..8).map(|i| unsigned << (i * 3)).sum();
packed[group * 3] = (combined & 0xFF) as u8;
packed[group * 3 + 1] = ((combined >> 8) & 0xFF) as u8;
packed[group * 3 + 2] = ((combined >> 16) & 0xFF) as u8;
}
let mut output = vec![0.0f32; 32];
pi_dequantize_scalar(&packed, scale, &mut output);
for group in 0..4 {
let expected_val = (group as i32).clamp(-4, 3) as f32;
for i in 0..8 {
let idx = group * 8 + i;
assert!(
(output[idx] - expected_val).abs() < EPSILON,
"Group {}, index {}: expected {}, got {}",
group,
i,
expected_val,
output[idx]
);
}
}
}
#[test]
fn test_extract_pi3_value() {
let values = [-4i32, -3, -2, -1, 0, 1, 2, 3];
let mut combined: u32 = 0;
for (i, &v) in values.iter().enumerate() {
let unsigned = (v + 4) as u32;
combined |= (unsigned & 0x7) << (i * 3);
}
let packed = vec![
(combined & 0xFF) as u8,
((combined >> 8) & 0xFF) as u8,
((combined >> 16) & 0xFF) as u8,
];
for (i, &expected) in values.iter().enumerate() {
let actual = extract_pi3_value(&packed, i);
assert_eq!(
actual, expected as i8,
"Index {}: expected {}, got {}",
i, expected, actual
);
}
}
#[test]
fn test_quantize_dequantize_roundtrip() {
let scale = pi_scale(4);
let original: Vec<f32> = (-4..=3).map(|v| (v as f32) * scale).collect();
let mut packed = vec![0u8; 3];
let mut reconstructed = vec![0.0f32; 8];
pi_quantize_scalar(&original, scale, &mut packed);
pi_dequantize_scalar(&packed, scale, &mut reconstructed);
for (i, (&orig, &recon)) in original.iter().zip(reconstructed.iter()).enumerate() {
let ulp = ulp_distance(orig, recon);
assert!(
ulp <= 1,
"Index {}: roundtrip error > 1 ULP: orig={}, recon={}, ulp={}",
i,
orig,
recon,
ulp
);
}
}
#[test]
fn test_quantize_clipping() {
let scale = 1.0;
let values = vec![-10.0, -5.0, -4.0, 0.0, 3.0, 5.0, 10.0, 100.0];
let mut packed = vec![0u8; 3];
pi_quantize_scalar(&values, scale, &mut packed);
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
        assert!((output[0] - (-4.0)).abs() < EPSILON);
        assert!((output[1] - (-4.0)).abs() < EPSILON);
        assert!((output[2] - (-4.0)).abs() < EPSILON);
        assert!((output[3] - 0.0).abs() < EPSILON);
        assert!((output[4] - 3.0).abs() < EPSILON);
        assert!((output[5] - 3.0).abs() < EPSILON);
        assert!((output[6] - 3.0).abs() < EPSILON);
        assert!((output[7] - 3.0).abs() < EPSILON);
    }
#[cfg(target_arch = "aarch64")]
#[test]
fn test_neon_equivalence_to_scalar() {
for num_groups in [1, 4, 16, 100] {
let packed: Vec<u8> = (0..num_groups * 3)
                .map(|i| (i * 17) as u8)
                .collect();
let scale = pi_scale(4);
let mut scalar_output = vec![0.0f32; num_groups * 8];
let mut neon_output = vec![0.0f32; num_groups * 8];
pi_dequantize_scalar(&packed, scale, &mut scalar_output);
unsafe {
pi_dequantize_neon(&packed, scale, &mut neon_output);
}
for i in 0..scalar_output.len() {
let ulp = ulp_distance(scalar_output[i], neon_output[i]);
assert!(
ulp <= 1,
"NEON divergence at index {} (groups={}): scalar={}, neon={}, ulp={}",
i,
num_groups,
scalar_output[i],
neon_output[i],
ulp
);
}
}
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx2_equivalence_to_scalar() {
if !is_x86_feature_detected!("avx2") {
println!("Skipping AVX2 test: feature not available");
return;
}
for num_groups in [1, 4, 16, 100] {
let packed: Vec<u8> = (0..num_groups * 3)
                .map(|i| (i * 17) as u8)
                .collect();
let scale = pi_scale(4);
let mut scalar_output = vec![0.0f32; num_groups * 8];
let mut avx2_output = vec![0.0f32; num_groups * 8];
pi_dequantize_scalar(&packed, scale, &mut scalar_output);
unsafe {
pi_dequantize_avx2(&packed, scale, &mut avx2_output);
}
for i in 0..scalar_output.len() {
let ulp = ulp_distance(scalar_output[i], avx2_output[i]);
assert!(
ulp <= 1,
"AVX2 divergence at index {} (groups={}): scalar={}, avx2={}, ulp={}",
i,
num_groups,
scalar_output[i],
avx2_output[i],
ulp
);
}
}
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx512_dequantize_equivalence_to_scalar() {
if !is_x86_feature_detected!("avx512f") {
println!("Skipping AVX-512 dequantize test: feature not available");
return;
}
for num_groups in [1, 4, 8, 16, 32, 100, 123] {
let packed: Vec<u8> = (0..num_groups * 3)
                .map(|i| (i * 17) as u8)
                .collect();
let scale = pi_scale(4);
let mut scalar_output = vec![0.0f32; num_groups * 8];
let mut avx512_output = vec![0.0f32; num_groups * 8];
pi_dequantize_scalar(&packed, scale, &mut scalar_output);
unsafe {
pi_dequantize_avx512(&packed, scale, &mut avx512_output);
}
for i in 0..scalar_output.len() {
let ulp = ulp_distance(scalar_output[i], avx512_output[i]);
assert!(
ulp <= 1,
"AVX-512 dequantize divergence at index {} (groups={}): scalar={}, avx512={}, ulp={}",
i,
num_groups,
scalar_output[i],
avx512_output[i],
ulp
);
}
}
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx512_quantize_equivalence_to_scalar() {
if !is_x86_feature_detected!("avx512f") {
println!("Skipping AVX-512 quantize test: feature not available");
return;
}
for num_groups in [1, 2, 4, 8, 16, 32, 100, 123] {
let num_values = num_groups * 8;
let scale = pi_scale(4);
let weights: Vec<f32> = (0..num_values)
.map(|i| {
let t = (i as f32) / (num_values as f32);
-4.0 * scale + t * 7.0 * scale
})
.collect();
let num_bytes = num_groups * 3;
let mut scalar_output = vec![0u8; num_bytes];
let mut avx512_output = vec![0u8; num_bytes];
pi_quantize_scalar(&weights, scale, &mut scalar_output);
unsafe {
pi_quantize_avx512(&weights, scale, &mut avx512_output);
}
for i in 0..num_bytes {
assert_eq!(
scalar_output[i], avx512_output[i],
"AVX-512 quantize divergence at byte {} (groups={}): scalar={:#04x}, avx512={:#04x}",
i, num_groups, scalar_output[i], avx512_output[i]
);
}
}
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx512_quantize_dequantize_roundtrip() {
if !is_x86_feature_detected!("avx512f") {
println!("Skipping AVX-512 roundtrip test: feature not available");
return;
}
let scale = pi_scale(4);
for num_groups in [1, 2, 8, 16, 100] {
let num_values = num_groups * 8;
let original: Vec<f32> = (0..num_values)
.map(|i| {
let level = ((i % 8) as i32) - 4;
(level as f32) * scale
})
.collect();
let num_bytes = num_groups * 3;
let mut packed = vec![0u8; num_bytes];
let mut reconstructed = vec![0.0f32; num_values];
unsafe {
pi_quantize_avx512(&original, scale, &mut packed);
pi_dequantize_avx512(&packed, scale, &mut reconstructed);
}
for (i, (&orig, &recon)) in original.iter().zip(reconstructed.iter()).enumerate() {
let ulp = ulp_distance(orig, recon);
assert!(
ulp <= 1,
"AVX-512 roundtrip error > 1 ULP at index {}: orig={}, recon={}, ulp={}",
i,
orig,
recon,
ulp
);
}
}
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx512_all_value_combinations() {
if !is_x86_feature_detected!("avx512f") {
println!("Skipping AVX-512 all values test: feature not available");
return;
}
let scale = pi_scale(4);
let values: [i32; 8] = [-4, -3, -2, -1, 0, 1, 2, 3];
let mut combined: u32 = 0;
for (i, &v) in values.iter().enumerate() {
let unsigned = (v + 4) as u32;
combined |= (unsigned & 0x7) << (i * 3);
}
let packed = vec![
(combined & 0xFF) as u8,
((combined >> 8) & 0xFF) as u8,
((combined >> 16) & 0xFF) as u8,
];
let mut scalar_output = vec![0.0f32; 8];
let mut avx512_output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut scalar_output);
unsafe {
pi_dequantize_avx512(&packed, scale, &mut avx512_output);
}
for i in 0..8 {
let ulp = ulp_distance(scalar_output[i], avx512_output[i]);
assert!(
ulp <= 1,
"AVX-512 all values divergence at index {}: scalar={}, avx512={}, ulp={}",
i,
scalar_output[i],
avx512_output[i],
ulp
);
let expected = (values[i] as f32) * scale;
let ulp_expected = ulp_distance(expected, avx512_output[i]);
assert!(
ulp_expected <= 1,
"AVX-512 expected value divergence at index {}: expected={}, avx512={}, ulp={}",
i,
expected,
avx512_output[i],
ulp_expected
);
}
}
#[cfg(target_arch = "x86_64")]
#[test]
fn test_avx512_edge_cases() {
if !is_x86_feature_detected!("avx512f") {
println!("Skipping AVX-512 edge cases test: feature not available");
return;
}
        let test_scales = [
            1.0f32,
            0.001,
            1000.0,
            -1.0,
            PI / 4.0,
            PI / 2.0,
            f32::MIN_POSITIVE,
        ];
for &scale in &test_scales {
let num_groups = 8;
let packed: Vec<u8> = (0..num_groups * 3).map(|i| (i * 31) as u8).collect();
let mut scalar_output = vec![0.0f32; num_groups * 8];
let mut avx512_output = vec![0.0f32; num_groups * 8];
pi_dequantize_scalar(&packed, scale, &mut scalar_output);
unsafe {
pi_dequantize_avx512(&packed, scale, &mut avx512_output);
}
for i in 0..scalar_output.len() {
let ulp = ulp_distance(scalar_output[i], avx512_output[i]);
assert!(
ulp <= 1,
"AVX-512 edge case (scale={}) divergence at index {}: scalar={}, avx512={}, ulp={}",
scale,
i,
scalar_output[i],
avx512_output[i],
ulp
);
}
}
}
#[test]
fn test_dispatch_equivalence() {
for num_groups in [1, 4, 8, 16, 32, 100, 123] {
let packed: Vec<u8> = (0..num_groups * 3).map(|i| (i * 23) as u8).collect();
let scale = pi_scale(4);
let mut scalar_output = vec![0.0f32; num_groups * 8];
let mut dispatch_output = vec![0.0f32; num_groups * 8];
pi_dequantize_scalar(&packed, scale, &mut scalar_output);
pi_dequantize(&packed, scale, &mut dispatch_output);
for i in 0..scalar_output.len() {
let ulp = ulp_distance(scalar_output[i], dispatch_output[i]);
assert!(
ulp <= 1,
"Dispatch ({}) divergence at index {}: scalar={}, dispatch={}, ulp={}",
pi_dequantize_kernel_name(),
i,
scalar_output[i],
dispatch_output[i],
ulp
);
}
}
}
#[test]
fn test_quantize_dispatch_equivalence() {
for num_groups in [1, 2, 4, 8, 16, 100] {
let num_values = num_groups * 8;
let scale = pi_scale(4);
let weights: Vec<f32> = (0..num_values)
.map(|i| {
let t = (i as f32) / (num_values as f32);
-4.0 * scale + t * 7.0 * scale
})
.collect();
let num_bytes = num_groups * 3;
let mut scalar_output = vec![0u8; num_bytes];
let mut dispatch_output = vec![0u8; num_bytes];
pi_quantize_scalar(&weights, scale, &mut scalar_output);
pi_quantize(&weights, scale, &mut dispatch_output);
for i in 0..num_bytes {
assert_eq!(
scalar_output[i], dispatch_output[i],
"Quantize dispatch ({}) divergence at byte {}: scalar={:#04x}, dispatch={:#04x}",
pi_quantize_kernel_name(),
i,
scalar_output[i],
dispatch_output[i]
);
}
}
}
#[test]
fn test_empty_input() {
let packed: Vec<u8> = vec![];
let scale = 1.0;
let mut output: Vec<f32> = vec![];
pi_dequantize_scalar(&packed, scale, &mut output);
assert!(output.is_empty());
}
#[test]
fn test_zero_scale() {
        let packed = vec![0xFFu8; 3]; // every 3-bit code is 7 (level +3)
        let scale = 0.0;
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
for &v in &output {
assert!((v - 0.0).abs() < EPSILON, "Zero scale should give zeros");
}
}
#[test]
fn test_negative_scale() {
let values = [1i32; 8];
let mut combined: u32 = 0;
for (i, &v) in values.iter().enumerate() {
let unsigned = (v + 4) as u32;
combined |= (unsigned & 0x7) << (i * 3);
}
let packed = vec![
(combined & 0xFF) as u8,
((combined >> 8) & 0xFF) as u8,
((combined >> 16) & 0xFF) as u8,
];
let scale = -1.0;
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
for &v in &output {
assert!((v - (-1.0)).abs() < EPSILON, "Expected -1.0, got {}", v);
}
}
#[test]
#[should_panic(expected = "Packed length")]
fn test_invalid_packed_length() {
        let packed = vec![0u8; 4]; // 4 bytes: not a multiple of 3
        let scale = 1.0;
let mut output = vec![0.0f32; 8];
pi_dequantize_scalar(&packed, scale, &mut output);
}
#[test]
#[should_panic(expected = "Output length")]
fn test_output_length_mismatch() {
let packed = vec![0u8; 3];
let scale = 1.0;
let mut output = vec![0.0f32; 4];
pi_dequantize_scalar(&packed, scale, &mut output);
}
#[test]
fn test_pi_scale() {
assert!((pi_scale(4) - PI / 4.0).abs() < EPSILON);
assert!((pi_scale(2) - PI / 2.0).abs() < EPSILON);
assert!((pi_scale(8) - PI / 8.0).abs() < EPSILON);
}
#[test]
fn test_pi_scale_adaptive() {
assert!((pi_scale_adaptive(2.0, 4) - 2.0 * PI / 4.0).abs() < EPSILON);
assert!((pi_scale_adaptive(0.5, 4) - 0.5 * PI / 4.0).abs() < EPSILON);
}
#[test]
fn test_kernel_name() {
let name = pi_dequantize_kernel_name();
assert!(
name == "neon" || name == "avx512" || name == "avx2" || name == "scalar",
"Unknown kernel name: {}",
name
);
let quant_name = pi_quantize_kernel_name();
assert!(
quant_name == "avx512" || quant_name == "scalar",
"Unknown quantize kernel name: {}",
quant_name
);
}
#[test]
fn test_large_data_simd_path() {
let num_groups = 1000;
let packed: Vec<u8> = (0..num_groups * 3).map(|i| (i % 256) as u8).collect();
let scale = pi_scale(4);
let mut output = vec![0.0f32; num_groups * 8];
pi_dequantize(&packed, scale, &mut output);
for (i, &v) in output.iter().enumerate() {
assert!(v.is_finite(), "Non-finite value at index {}: {}", i, v);
let min_val = -4.0 * scale;
let max_val = 3.0 * scale;
assert!(
v >= min_val && v <= max_val,
"Value {} at index {} out of range [{}, {}]",
v,
i,
min_val,
max_val
);
}
}
#[test]
fn test_simd_remainder_handling() {
for num_groups in [1, 2, 3, 5, 6, 7, 9, 13, 17] {
let packed: Vec<u8> = (0..num_groups * 3).map(|i| (i * 37) as u8).collect();
let scale = 1.0;
let mut scalar_output = vec![0.0f32; num_groups * 8];
let mut dispatch_output = vec![0.0f32; num_groups * 8];
pi_dequantize_scalar(&packed, scale, &mut scalar_output);
pi_dequantize(&packed, scale, &mut dispatch_output);
for i in 0..scalar_output.len() {
assert!(
(scalar_output[i] - dispatch_output[i]).abs() < EPSILON,
"Remainder mismatch at index {} (groups={}): {} vs {}",
i,
num_groups,
scalar_output[i],
dispatch_output[i]
);
}
}
}
}