#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
pub mod pi_constants {
use std::f32::consts::PI;
pub const PI_FRAC: f32 = PI - 3.0;
pub const PI_SCALE: f32 = PI / 4.0;
pub const PHI_APPROX: f32 = 2.0 / (PI - 1.0);
pub const PI_DIGITS: [u8; 16] = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3];
#[inline]
pub fn anti_resonance(bits: u8) -> f32 {
PI_FRAC / (1u32 << bits) as f32
}
#[inline]
pub fn jitter(index: usize) -> f32 {
let digit = PI_DIGITS[index % 16];
(digit as f32) * 0.001 * PI_FRAC
}
}
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
pub scale: f32,
pub zero_point: i8,
pub anti_resonance: f32,
pub bits: u8,
}
impl QuantParams {
pub fn symmetric(min_val: f32, max_val: f32) -> Self {
let abs_max = min_val.abs().max(max_val.abs());
let bits = 7u8;
let qmax = 127.0f32;
let anti_resonance = pi_constants::anti_resonance(bits);
let scale = (abs_max + anti_resonance) / qmax;
Self {
scale: scale.max(1e-10), zero_point: 0,
anti_resonance,
bits,
}
}
pub fn asymmetric(min_val: f32, max_val: f32) -> Self {
let bits = 8u8;
let qmin = -128.0f32;
let qmax = 127.0f32;
let anti_resonance = pi_constants::anti_resonance(bits);
let range = (max_val - min_val).max(1e-10) + anti_resonance;
let scale = range / (qmax - qmin);
let zero_point_float = qmin - min_val / scale + pi_constants::jitter(0);
let zero_point = zero_point_float.round().clamp(-128.0, 127.0) as i8;
Self {
scale: scale.max(1e-10),
zero_point,
anti_resonance,
bits,
}
}
#[inline]
pub fn quantize(&self, value: f32) -> i8 {
let scaled = value / self.scale + self.zero_point as f32;
let rounded = (scaled + self.anti_resonance * 0.5).round();
rounded.clamp(-128.0, 127.0) as i8
}
#[inline]
pub fn dequantize(&self, quantized: i8) -> f32 {
(quantized as f32 - self.zero_point as f32) * self.scale
}
}
impl Default for QuantParams {
fn default() -> Self {
Self::symmetric(-1.0, 1.0)
}
}
#[derive(Debug, Clone)]
pub struct PerChannelQuantParams {
pub scales: Vec<f32>,
pub zero_points: Vec<i8>,
pub num_channels: usize,
}
impl PerChannelQuantParams {
pub fn symmetric_per_channel(weights: &[f32], out_channels: usize, in_channels: usize) -> Self {
let kernel_size = weights.len() / (out_channels * in_channels);
let mut scales = Vec::with_capacity(out_channels);
let zero_points = vec![0i8; out_channels];
for oc in 0..out_channels {
let start = oc * in_channels * kernel_size;
let end = start + in_channels * kernel_size;
let channel_weights = &weights[start..end];
let abs_max = channel_weights
.iter()
.map(|x| x.abs())
.fold(0.0f32, |a, b| a.max(b));
let anti_res = pi_constants::anti_resonance(7);
let scale = (abs_max + anti_res) / 127.0;
scales.push(scale.max(1e-10));
}
Self {
scales,
zero_points,
num_channels: out_channels,
}
}
#[inline]
pub fn channel_params(&self, channel: usize) -> QuantParams {
QuantParams {
scale: self.scales[channel],
zero_point: self.zero_points[channel],
anti_resonance: pi_constants::anti_resonance(7),
bits: 7,
}
}
}
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
pub data: Vec<i8>,
pub shape: Vec<usize>,
pub params: QuantizationType,
}
#[derive(Debug, Clone)]
pub enum QuantizationType {
PerTensor(QuantParams),
PerChannel(PerChannelQuantParams),
}
impl QuantizedTensor {
pub fn from_float_symmetric(data: &[f32], shape: &[usize]) -> Self {
let min_val = data.iter().fold(f32::MAX, |a, &b| a.min(b));
let max_val = data.iter().fold(f32::MIN, |a, &b| a.max(b));
let params = QuantParams::symmetric(min_val, max_val);
let quantized: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
Self {
data: quantized,
shape: shape.to_vec(),
params: QuantizationType::PerTensor(params),
}
}
pub fn from_weights_per_channel(
weights: &[f32],
out_channels: usize,
in_channels: usize,
kernel_h: usize,
kernel_w: usize,
) -> Self {
let per_channel = PerChannelQuantParams::symmetric_per_channel(weights, out_channels, in_channels);
let kernel_size = kernel_h * kernel_w;
let mut quantized = Vec::with_capacity(weights.len());
for oc in 0..out_channels {
let params = per_channel.channel_params(oc);
let start = oc * in_channels * kernel_size;
let end = start + in_channels * kernel_size;
for &w in &weights[start..end] {
quantized.push(params.quantize(w));
}
}
Self {
data: quantized,
shape: vec![out_channels, in_channels, kernel_h, kernel_w],
params: QuantizationType::PerChannel(per_channel),
}
}
pub fn dequantize(&self) -> Vec<f32> {
match &self.params {
QuantizationType::PerTensor(params) => {
self.data.iter().map(|&q| params.dequantize(q)).collect()
}
QuantizationType::PerChannel(per_channel) => {
let out_channels = self.shape[0];
let channel_size = self.data.len() / out_channels;
let mut output = Vec::with_capacity(self.data.len());
for oc in 0..out_channels {
let params = per_channel.channel_params(oc);
let start = oc * channel_size;
let end = start + channel_size;
for &q in &self.data[start..end] {
output.push(params.dequantize(q));
}
}
output
}
}
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
pub fn quantize_batch(input: &[f32], output: &mut [i8], params: &QuantParams) {
debug_assert_eq!(input.len(), output.len());
let inv_scale = 1.0 / params.scale;
let zp = params.zero_point as f32;
let anti_res = params.anti_resonance * 0.5;
for (i, &val) in input.iter().enumerate() {
let scaled = val * inv_scale + zp + anti_res;
output[i] = scaled.round().clamp(-128.0, 127.0) as i8;
}
}
pub fn dequantize_batch(input: &[i8], output: &mut [f32], params: &QuantParams) {
debug_assert_eq!(input.len(), output.len());
let zp = params.zero_point as f32;
for (i, &q) in input.iter().enumerate() {
output[i] = (q as f32 - zp) * params.scale;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn quantize_batch_avx2(input: &[f32], output: &mut [i8], params: &QuantParams) {
let len = input.len();
let chunks = len / 8;
let inv_scale = _mm256_set1_ps(1.0 / params.scale);
let zp = _mm256_set1_ps(params.zero_point as f32);
let anti_res = _mm256_set1_ps(params.anti_resonance * 0.5);
let half = _mm256_set1_ps(0.5);
let min_val = _mm256_set1_ps(-128.0);
let max_val = _mm256_set1_ps(127.0);
for i in 0..chunks {
let offset = i * 8;
let v = _mm256_loadu_ps(input.as_ptr().add(offset));
let scaled = _mm256_add_ps(_mm256_mul_ps(v, inv_scale), zp);
let adjusted = _mm256_add_ps(scaled, anti_res);
let rounded = _mm256_round_ps(adjusted, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
let clamped = _mm256_min_ps(_mm256_max_ps(rounded, min_val), max_val);
let i32_vals = _mm256_cvtps_epi32(clamped);
let i32_array: [i32; 8] = std::mem::transmute(i32_vals);
for j in 0..8 {
output[offset + j] = i32_array[j] as i8;
}
}
let remainder_start = chunks * 8;
for i in remainder_start..len {
let scaled = input[i] / params.scale + params.zero_point as f32 + params.anti_resonance * 0.5;
output[i] = scaled.round().clamp(-128.0, 127.0) as i8;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn dequantize_batch_avx2(input: &[i8], output: &mut [f32], params: &QuantParams) {
let len = input.len();
let chunks = len / 8;
let scale = _mm256_set1_ps(params.scale);
let zp = _mm256_set1_ps(params.zero_point as f32);
for i in 0..chunks {
let offset = i * 8;
let mut i32_array = [0i32; 8];
for j in 0..8 {
i32_array[j] = input[offset + j] as i32;
}
let i32_vals: __m256i = std::mem::transmute(i32_array);
let f32_vals = _mm256_cvtepi32_ps(i32_vals);
let shifted = _mm256_sub_ps(f32_vals, zp);
let result = _mm256_mul_ps(shifted, scale);
_mm256_storeu_ps(output.as_mut_ptr().add(offset), result);
}
let remainder_start = chunks * 8;
for i in remainder_start..len {
output[i] = (input[i] as f32 - params.zero_point as f32) * params.scale;
}
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn quantize_batch_avx2(_input: &[f32], _output: &mut [i8], _params: &QuantParams) {}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dequantize_batch_avx2(_input: &[i8], _output: &mut [f32], _params: &QuantParams) {}
#[inline(always)]
pub fn quantize_simd(input: &[f32], output: &mut [i8], params: &QuantParams) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe {
quantize_batch_avx2(input, output, params);
}
return;
}
}
quantize_batch(input, output, params);
}
#[inline(always)]
pub fn dequantize_simd(input: &[i8], output: &mut [f32], params: &QuantParams) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe {
dequantize_batch_avx2(input, output, params);
}
return;
}
}
dequantize_batch(input, output, params);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_symmetric_quantization() {
let params = QuantParams::symmetric(-1.0, 1.0);
let q = params.quantize(0.5);
let dq = params.dequantize(q);
assert!((0.5 - dq).abs() < 0.02);
}
#[test]
fn test_asymmetric_quantization() {
let params = QuantParams::asymmetric(0.0, 1.0);
let q = params.quantize(0.5);
let dq = params.dequantize(q);
assert!((0.5 - dq).abs() < 0.02);
}
#[test]
fn test_pi_anti_resonance() {
use std::f32::consts::PI;
let anti_res = pi_constants::anti_resonance(8);
assert!(anti_res > 0.0);
assert!(anti_res < 0.001);
let expected = (PI - 3.0) / 256.0;
assert!((anti_res - expected).abs() < 1e-10);
}
#[test]
fn test_quantized_tensor_roundtrip() {
let data = vec![0.1, 0.2, 0.3, 0.4, -0.1, -0.2, -0.3, -0.4];
let shape = vec![2, 4];
let quantized = QuantizedTensor::from_float_symmetric(&data, &shape);
let dequantized = quantized.dequantize();
for (original, recovered) in data.iter().zip(dequantized.iter()) {
assert!((original - recovered).abs() < 0.02);
}
}
#[test]
fn test_per_channel_quantization() {
let weights: Vec<f32> = (0..36).map(|i| (i as f32 - 18.0) * 0.1).collect();
let quantized = QuantizedTensor::from_weights_per_channel(&weights, 2, 2, 3, 3);
let dequantized = quantized.dequantize();
let max_error: f32 = weights
.iter()
.zip(dequantized.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, |a, b| a.max(b));
assert!(max_error < 0.05);
}
#[test]
fn test_batch_quantize() {
let input = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
let mut output = vec![0i8; 8];
let params = QuantParams::symmetric(-1.0, 1.0);
quantize_batch(&input, &mut output, ¶ms);
for &q in &output {
assert!(q >= -128 && q <= 127);
}
}
#[test]
fn test_batch_dequantize() {
let input = vec![10i8, 20, 30, 40, -10, -20, -30, -40];
let mut output = vec![0.0f32; 8];
let params = QuantParams::symmetric(-1.0, 1.0);
dequantize_batch(&input, &mut output, ¶ms);
assert!(output[0] > 0.0);
assert!(output[4] < 0.0);
}
#[test]
fn test_simd_dispatch() {
let input = vec![0.1f32; 16];
let mut output = vec![0i8; 16];
let params = QuantParams::symmetric(-1.0, 1.0);
quantize_simd(&input, &mut output, ¶ms);
let first = output[0];
for &q in &output {
assert_eq!(q, first);
}
}
}