use std::ops::{Add, Mul};
#[derive(Clone, Debug)]
pub struct Int8Quantizer {
    /// Step size: one int8 unit corresponds to `scale` in float space.
    pub scale: f32,
    /// Always 0: this quantizer is symmetric around zero.
    pub zero_point: i8,
}

impl Int8Quantizer {
    /// Builds a symmetric int8 quantizer whose representable range covers
    /// the largest-magnitude value in `weights`.
    ///
    /// Degenerate inputs (empty slice, all zeros, non-finite extrema) fall
    /// back to `scale = 1.0` so `quantize`/`dequantize` never divide by
    /// zero or propagate NaN/inf.
    pub fn from_weights(weights: &[f32]) -> Self {
        let (min, max) = weights.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(min, max), &w| (min.min(w), max.max(w)),
        );
        let abs_max = min.abs().max(max.abs());
        // Guard against scale == 0 (all-zero weights) and scale == inf
        // (empty slice): either would poison quantize()/dequantize().
        let scale = if abs_max.is_finite() && abs_max > 0.0 {
            abs_max / 127.0
        } else {
            1.0
        };
        Self { scale, zero_point: 0 }
    }

    /// Maps a float to the nearest representable int8 level, saturating at
    /// ±127 (the -128 slot is deliberately unused to keep the range symmetric).
    #[inline]
    pub fn quantize(&self, value: f32) -> i8 {
        let scaled = value / self.scale;
        scaled.round().clamp(-127.0, 127.0) as i8
    }

    /// Inverse of `quantize` up to rounding error (at most `scale / 2`
    /// for in-range inputs).
    #[inline]
    pub fn dequantize(&self, value: i8) -> f32 {
        value as f32 * self.scale
    }

    /// Quantizes every element of `values`.
    pub fn quantize_array(&self, values: &[f32]) -> Vec<i8> {
        values.iter().map(|&v| self.quantize(v)).collect()
    }

    /// Dequantizes every element of `values`.
    pub fn dequantize_array(&self, values: &[i8]) -> Vec<f32> {
        values.iter().map(|&v| self.dequantize(v)).collect()
    }
}
/// A row-major matrix stored as int8 values together with the quantizer
/// that produced them; elements are dequantized on the fly during reads
/// and matrix-vector products.
pub struct QuantizedWeights {
    // Row-major quantized values; intended length is shape.0 * shape.1.
    pub weights: Vec<i8>,
    // Quantizer used to convert between int8 storage and f32.
    pub quantizer: Int8Quantizer,
    // (rows, cols).
    pub shape: (usize, usize),
}
impl QuantizedWeights {
    /// Quantizes a row-major `rows x cols` f32 matrix into int8 storage.
    ///
    /// # Panics
    /// Panics if `weights.len() != rows * cols`.
    pub fn from_float_matrix(weights: &[f32], rows: usize, cols: usize) -> Self {
        assert_eq!(
            weights.len(),
            rows * cols,
            "weight slice length must equal rows * cols"
        );
        let quantizer = Int8Quantizer::from_weights(weights);
        let quantized = quantizer.quantize_array(weights);
        Self {
            weights: quantized,
            quantizer,
            shape: (rows, cols),
        }
    }

    /// Returns the dequantized element at (`row`, `col`).
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> f32 {
        let idx = row * self.shape.1 + col;
        self.quantizer.dequantize(self.weights[idx])
    }

    /// Computes `output = W * input` where `W` is this matrix.
    ///
    /// Int8 weights are accumulated directly and the quantization scale is
    /// applied once per output row, so only one multiply by `scale` is
    /// paid per row instead of one per element.
    ///
    /// # Panics
    /// Panics if `input.len() < cols` or `output.len() < rows`.
    pub fn matmul_quantized(&self, input: &[f32], output: &mut [f32]) {
        let (rows, cols) = self.shape;
        assert!(input.len() >= cols, "input shorter than matrix width");
        assert!(output.len() >= rows, "output shorter than matrix height");
        for i in 0..rows {
            let row = &self.weights[i * cols..(i + 1) * cols];
            // Iterator zip avoids per-element bounds checks in the hot loop.
            let sum: f32 = row
                .iter()
                .zip(&input[..cols])
                .map(|(&w, &x)| (w as f32) * x)
                .sum();
            output[i] = sum * self.quantizer.scale;
        }
    }

    /// AVX2/FMA variant of [`QuantizedWeights::matmul_quantized`].
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports both the `avx2`
    /// and `fma` features (e.g. via `is_x86_feature_detected!`).
    #[cfg(target_arch = "x86_64")]
    // `_mm256_fmadd_ps` is an FMA intrinsic, not an AVX2 one: the
    // function must enable "fma" as well, otherwise executing it is UB
    // on CPUs that have AVX2 but not FMA.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn matmul_quantized_avx2(&self, input: &[f32], output: &mut [f32]) {
        use std::arch::x86_64::*;
        let (rows, cols) = self.shape;
        // Validate up front: the raw loads below perform no bounds checks.
        assert!(input.len() >= cols, "input shorter than matrix width");
        assert!(output.len() >= rows, "output shorter than matrix height");
        let chunks = cols / 8;
        for i in 0..rows {
            let mut sum = _mm256_setzero_ps();
            let row_offset = i * cols;
            for j in 0..chunks {
                let idx = j * 8;
                // Load 8 int8 weights, widen to i32, then convert to f32.
                let w_ptr = self.weights.as_ptr().add(row_offset + idx);
                let w_i8 = _mm_loadl_epi64(w_ptr as *const __m128i);
                let w_i32 = _mm256_cvtepi8_epi32(w_i8);
                let w_f32 = _mm256_cvtepi32_ps(w_i32);
                let x = _mm256_loadu_ps(input.as_ptr().add(idx));
                sum = _mm256_fmadd_ps(w_f32, x, sum);
            }
            // Horizontal reduction of the 8 partial sums.
            let sum_array: [f32; 8] = std::mem::transmute(sum);
            let mut result = sum_array.iter().sum::<f32>();
            // Scalar tail for cols not divisible by 8.
            for j in (chunks * 8)..cols {
                result += (self.weights[row_offset + j] as f32) * input[j];
            }
            output[i] = result * self.quantizer.scale;
        }
    }

    /// Size in bytes of the quantized weight storage (one byte per
    /// element; excludes the quantizer and shape metadata).
    pub fn memory_size(&self) -> usize {
        self.weights.len()
    }
}
/// Two-layer MLP (linear -> ReLU -> linear) with int8-quantized weight
/// matrices and f32 biases.
pub struct QuantizedMlp {
    // First-layer weights; built as (hidden_dim, input_dim) in from_float_mlp.
    pub w1: QuantizedWeights,
    // First-layer bias, then second-layer weights built as (output_dim, hidden_dim).
    pub b1: Vec<f32>,
    pub w2: QuantizedWeights,
    // Second-layer bias.
    pub b2: Vec<f32>,
    // Cached so forward() can size the hidden activation buffer.
    hidden_dim: usize,
}
impl QuantizedMlp {
    /// Builds a quantized two-layer MLP from row-major float weights.
    ///
    /// `w1` is `(hidden_dim x input_dim)` and `w2` is
    /// `(output_dim x hidden_dim)`; each matrix gets its own
    /// quantization scale derived from its own value range.
    pub fn from_float_mlp(
        w1: &[f32],
        b1: &[f32],
        w2: &[f32],
        b2: &[f32],
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
    ) -> Self {
        Self {
            w1: QuantizedWeights::from_float_matrix(w1, hidden_dim, input_dim),
            b1: b1.to_vec(),
            w2: QuantizedWeights::from_float_matrix(w2, output_dim, hidden_dim),
            b2: b2.to_vec(),
            hidden_dim,
        }
    }

    /// Scalar forward pass: `output = W2 * relu(W1 * input + b1) + b2`.
    pub fn forward(&self, input: &[f32], output: &mut [f32]) {
        let mut hidden = vec![0.0f32; self.hidden_dim];
        self.w1.matmul_quantized(input, &mut hidden);
        for (h, &b) in hidden.iter_mut().zip(&self.b1) {
            // Bias, then ReLU.
            *h = (*h + b).max(0.0);
        }
        self.w2.matmul_quantized(&hidden, output);
        for (o, &b) in output.iter_mut().zip(&self.b2) {
            *o += b;
        }
    }

    /// AVX2-accelerated forward pass.
    ///
    /// Falls back to the scalar [`QuantizedMlp::forward`] when the running
    /// CPU lacks AVX2/FMA, so this safe fn is sound to call unconditionally.
    #[cfg(target_arch = "x86_64")]
    pub fn forward_avx2(&self, input: &[f32], output: &mut [f32]) {
        // Runtime detection keeps this *safe* fn sound: executing the SIMD
        // path on a CPU without avx2+fma would be undefined behavior.
        if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
            self.forward(input, output);
            return;
        }
        // SAFETY: avx2 and fma support was verified at runtime just above.
        unsafe {
            let mut hidden = vec![0.0f32; self.hidden_dim];
            self.w1.matmul_quantized_avx2(input, &mut hidden);
            use std::arch::x86_64::*;
            let zero = _mm256_setzero_ps();
            // Vectorized bias + ReLU over full 8-lane chunks...
            let mut i = 0;
            while i + 8 <= self.hidden_dim {
                let h = _mm256_loadu_ps(hidden.as_ptr().add(i));
                let b = _mm256_loadu_ps(self.b1.as_ptr().add(i));
                let relu = _mm256_max_ps(_mm256_add_ps(h, b), zero);
                _mm256_storeu_ps(hidden.as_mut_ptr().add(i), relu);
                i += 8;
            }
            // ...and a scalar tail for the remainder.
            for j in i..self.hidden_dim {
                hidden[j] = (hidden[j] + self.b1[j]).max(0.0);
            }
            self.w2.matmul_quantized_avx2(&hidden, output);
            for (o, &b) in output.iter_mut().zip(&self.b2) {
                *o += b;
            }
        }
    }

    /// Total model size in bytes: int8 weights (1 byte each) plus f32
    /// biases (4 bytes each).
    pub fn model_size(&self) -> usize {
        self.w1.memory_size()
            + self.w2.memory_size()
            + (self.b1.len() + self.b2.len()) * 4
    }

    /// Ratio of the original f32 model's byte size (`original_params * 4`)
    /// to this quantized model's byte size.
    pub fn compression_ratio(&self, original_params: usize) -> f32 {
        let original_bytes = original_params * 4;
        original_bytes as f32 / self.model_size() as f32
    }
}
/// Configuration for quantization-aware training (QAT).
pub struct QuantizationAwareTraining {
    // When false, fake_quantize_weights is a no-op.
    pub fake_quantize: bool,
    // Target bit width. NOTE(review): not read by any code visible here —
    // fake_quantize_weights always uses the 8-bit Int8Quantizer; confirm
    // whether other callers honor this field.
    pub num_bits: u8,
}
impl QuantizationAwareTraining {
    /// Creates a QAT config with fake quantization enabled at 8 bits.
    pub fn new() -> Self {
        Self {
            fake_quantize: true,
            num_bits: 8,
        }
    }

    /// Simulates inference-time quantization during training: each weight
    /// is replaced by `dequantize(quantize(w))` so the training loss sees
    /// the rounding error that quantized inference will incur.
    ///
    /// No-op when `fake_quantize` is false.
    ///
    /// NOTE(review): `num_bits` is currently ignored — the underlying
    /// `Int8Quantizer` is always 8-bit, so only `num_bits == 8` is honored.
    pub fn fake_quantize_weights(&self, weights: &mut [f32]) {
        if !self.fake_quantize {
            return;
        }
        let quantizer = Int8Quantizer::from_weights(weights);
        for w in weights.iter_mut() {
            *w = quantizer.dequantize(quantizer.quantize(*w));
        }
    }
}

// `new()` takes no arguments, so provide the conventional Default impl
// (clippy::new_without_default).
impl Default for QuantizationAwareTraining {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every weight must survive quantize -> dequantize within half a
    /// quantization step (scale = 1.5 / 127 here, so 0.02 is a safe bound).
    #[test]
    fn test_quantization_round_trip() {
        let weights = vec![-1.5, -0.5, 0.0, 0.5, 1.5];
        let quantizer = Int8Quantizer::from_weights(&weights);
        for &original in &weights {
            let restored = quantizer.dequantize(quantizer.quantize(original));
            assert!((original - restored).abs() < 0.02);
        }
    }

    /// int8 weights + f32 biases should compress close to (but under) 4x.
    #[test]
    fn test_compression_ratio() {
        let (input_dim, hidden_dim, output_dim) = (32, 64, 3);
        let total_params =
            input_dim * hidden_dim + hidden_dim + hidden_dim * output_dim + output_dim;
        let qmlp = QuantizedMlp::from_float_mlp(
            &vec![0.1; input_dim * hidden_dim],
            &vec![0.0; hidden_dim],
            &vec![0.1; hidden_dim * output_dim],
            &vec![0.0; output_dim],
            input_dim,
            hidden_dim,
            output_dim,
        );
        let ratio = qmlp.compression_ratio(total_params);
        assert!(ratio > 3.0 && ratio < 4.5);
        println!("Compression ratio: {:.2}x", ratio);
        println!("Original size: {} bytes", total_params * 4);
        println!("Quantized size: {} bytes", qmlp.model_size());
    }
}