use crate::autograd::Variable;
use crate::nn::Module;
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive, One, Signed, ToPrimitive, Zero};
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::Sum;
/// Target storage format requested when quantizing a tensor.
///
/// The enum carries no payload, so it derives `Eq` and `Hash` alongside
/// `PartialEq`; this lets the variants be used directly as `HashMap` /
/// `HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantizationType {
    /// Signed 8-bit levels; the constructors in this file use `-128..=127`.
    Int8,
    /// Signed 4-bit levels; the constructor in this file uses `-8..=7`.
    Int4,
    /// Half-precision request.
    ///
    /// NOTE(review): `Quantizer` currently answers this with unit-scale
    /// Int8-style parameters rather than a real f16 conversion.
    Float16,
    /// Dynamic quantization request; `Quantizer` currently treats it like
    /// symmetric Int8.
    Dynamic,
}
/// Strategy used to pick the value range that quantization parameters are
/// derived from.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CalibrationMode {
    /// Use the smallest and largest observed values.
    MinMax,
    /// Trim the tails of the sorted values before taking the range.
    ///
    /// The payload is used as a fraction of values to keep: the calibration
    /// code skips `(1.0 - p) * 0.5 * len` elements at each end.
    Percentile(f32),
    /// Entropy-based calibration.
    ///
    /// NOTE(review): the quantizer in this file currently computes the same
    /// min/max range for this mode as for [`CalibrationMode::MinMax`].
    Entropy,
}
/// Affine mapping between real values and integer quantization levels.
///
/// Code in this file quantizes with `round(x / scale) + zero_point`, clamped
/// to `qmin..=qmax`, and dequantizes with `(q - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantizationParams<T>
where
    T: Float,
{
    /// Real-valued size of one integer step.
    pub scale: T,
    /// Integer level that represents the real value zero.
    pub zero_point: i32,
    /// Lowest integer level a quantized value may take.
    pub qmin: i32,
    /// Highest integer level a quantized value may take.
    pub qmax: i32,
    /// Quantization format these parameters were built for.
    pub qtype: QuantizationType,
}
impl<T: Float + FromPrimitive> QuantizationParams<T> {
    /// Builds Int8 parameters with a zero point of `0` and levels `-128..=127`.
    pub fn int8_symmetric(scale: T) -> Self {
        // Symmetric Int8 is the asymmetric form with the zero point pinned to 0.
        Self::int8_asymmetric(scale, 0)
    }

    /// Builds Int8 parameters with the given zero point and levels `-128..=127`.
    pub fn int8_asymmetric(scale: T, zero_point: i32) -> Self {
        Self {
            qtype: QuantizationType::Int8,
            qmin: -128,
            qmax: 127,
            zero_point,
            scale,
        }
    }

    /// Builds Int4 parameters with a zero point of `0` and levels `-8..=7`.
    pub fn int4_symmetric(scale: T) -> Self {
        Self {
            qtype: QuantizationType::Int4,
            qmin: -8,
            qmax: 7,
            zero_point: 0,
            scale,
        }
    }
}
/// A tensor stored as integer quantization levels plus the parameters needed
/// to map them back to real values.
#[derive(Debug, Clone)]
pub struct QuantizedTensor<T>
where
    T: Float,
{
    /// Quantized levels, one `i8` per element.
    pub data: Vec<i8>,
    /// Shape handed to `Tensor::from_vec` when dequantizing.
    pub shape: Vec<usize>,
    /// Scale, zero point and level bounds used to produce `data`.
    pub params: QuantizationParams<T>,
}
impl<T> QuantizedTensor<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Default
        + Zero
        + One
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + Signed,
{
    /// Bundles already-quantized levels with their shape and parameters.
    ///
    /// No validation is performed on the arguments.
    pub fn new(data: Vec<i8>, shape: Vec<usize>, params: QuantizationParams<T>) -> Self {
        Self {
            data,
            shape,
            params,
        }
    }

    /// Reconstructs a floating-point tensor by computing
    /// `(q - zero_point) * scale` for every stored level.
    pub fn dequantize(&self) -> Tensor<T> {
        let scale = self.params.scale;
        let zero_point = self.params.zero_point;
        let restored: Vec<T> = self
            .data
            .iter()
            .map(|&level| {
                let centered = i32::from(level) - zero_point;
                T::from_i32(centered).unwrap() * scale
            })
            .collect();
        Tensor::from_vec(restored, self.shape.clone())
    }

    /// Returns a fixed nominal compression ratio for the tensor's format.
    ///
    /// NOTE(review): the constants look like ratios against 32-bit floats;
    /// confirm the intended baseline.
    pub fn compression_ratio(&self) -> f32 {
        let qtype = self.params.qtype;
        match qtype {
            QuantizationType::Int8 => 4.0,
            QuantizationType::Int4 => 8.0,
            QuantizationType::Float16 => 2.0,
            QuantizationType::Dynamic => 3.0,
        }
    }

    /// Returns the nominal storage size in bytes for the tensor's format.
    ///
    /// Int4 counts two elements per byte (rounded up) and Float16 two bytes
    /// per element; the other formats count one byte per element.
    pub fn memory_bytes(&self) -> usize {
        let elements = self.data.len();
        match self.params.qtype {
            QuantizationType::Int8 | QuantizationType::Dynamic => elements,
            QuantizationType::Int4 => elements.div_ceil(2),
            QuantizationType::Float16 => elements * 2,
        }
    }
}
/// Computes quantization parameters for tensors and applies them, caching the
/// parameters per layer name.
#[derive(Debug)]
pub struct Quantizer<T>
where
    T: Float,
{
    /// How the value range of a tensor is estimated.
    calibration_mode: CalibrationMode,
    /// Parameters computed so far, keyed by the layer name they were stored under.
    param_cache: HashMap<String, QuantizationParams<T>>,
    /// Whether Int8 quantization uses a symmetric range (zero point fixed at 0).
    symmetric: bool,
}
impl<T> Quantizer<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Default
        + Zero
        + One
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + Signed,
{
    /// Creates a quantizer with an empty parameter cache.
    ///
    /// `symmetric` only affects [`QuantizationType::Int8`]; the other formats
    /// always use a symmetric range.
    pub fn new(calibration_mode: CalibrationMode, symmetric: bool) -> Self {
        Self {
            calibration_mode,
            param_cache: HashMap::new(),
            symmetric,
        }
    }

    /// Quantizes `tensor` to `qtype`.
    ///
    /// When `layer_name` is given, parameters are looked up in the cache under
    /// that name and computed (then cached) only on a miss. Cached parameters
    /// are reused as-is: they are neither recomputed from `tensor` nor checked
    /// against `qtype`. Call [`Quantizer::clear_cache`] to force recalibration.
    pub fn quantize_tensor(
        &mut self,
        tensor: &Tensor<T>,
        qtype: QuantizationType,
        layer_name: Option<&str>,
    ) -> QuantizedTensor<T> {
        let params = if let Some(name) = layer_name {
            if let Some(cached_params) = self.param_cache.get(name) {
                cached_params.clone()
            } else {
                let params = self.compute_quantization_params(tensor, qtype);
                self.param_cache.insert(name.to_string(), params.clone());
                params
            }
        } else {
            self.compute_quantization_params(tensor, qtype)
        };
        let quantized_data = self.quantize_data(tensor, &params);
        QuantizedTensor::new(quantized_data, tensor.shape().to_vec(), params)
    }

    /// Returns `(min, max)` of `values`, ignoring NaN.
    ///
    /// An empty (or all-NaN) input yields `(+inf, -inf)`.
    fn value_range(values: impl Iterator<Item = T>) -> (T, T) {
        values.fold((T::infinity(), T::neg_infinity()), |(lo, hi), v| {
            // Comparing as `v < lo` / `v > hi` keeps the running bound when `v`
            // is NaN instead of letting the NaN reset it.
            (if v < lo { v } else { lo }, if v > hi { v } else { hi })
        })
    }

    /// Converts a value span into a per-level scale, falling back to `1` when
    /// the result would be zero, negative or non-finite (all-zero, empty or
    /// NaN-only tensors) so later divisions by the scale stay well defined.
    fn scale_for(span: T, levels: i32) -> T {
        let levels = T::from_i32(levels).expect("level count is representable as a float");
        let scale = span / levels;
        if scale.is_finite() && scale > T::zero() {
            scale
        } else {
            T::one()
        }
    }

    /// Derives quantization parameters for `tensor` from the configured
    /// calibration mode and the requested format.
    fn compute_quantization_params(
        &self,
        tensor: &Tensor<T>,
        qtype: QuantizationType,
    ) -> QuantizationParams<T> {
        // Iterate rather than require a contiguous slice, so tensors with a
        // non-standard memory layout do not panic here.
        let tensor_array = tensor.as_array();
        let (min_val, max_val) = match self.calibration_mode {
            // Entropy calibration currently uses the plain min/max range.
            CalibrationMode::MinMax | CalibrationMode::Entropy => {
                Self::value_range(tensor_array.iter().copied())
            }
            CalibrationMode::Percentile(p) => {
                // NaN has no position in a sorted order; drop it up front so
                // the comparison below is a total order.
                let mut sorted: Vec<T> = tensor_array
                    .iter()
                    .copied()
                    .filter(|v| !v.is_nan())
                    .collect();
                if sorted.is_empty() {
                    (T::infinity(), T::neg_infinity())
                } else {
                    sorted.sort_by(|a, b| a.partial_cmp(b).expect("NaN values were filtered out"));
                    let len = sorted.len();
                    // `p` is the fraction of values to keep; clamp it so the
                    // trimmed tail can never be negative or exceed half the data.
                    let tail = ((1.0 - p.clamp(0.0, 1.0)) * 0.5 * len as f32) as usize;
                    let low_idx = tail.min((len - 1) / 2);
                    let high_idx = len - 1 - low_idx;
                    (sorted[low_idx], sorted[high_idx])
                }
            }
        };
        let abs_max = min_val.abs().max(max_val.abs());
        match qtype {
            QuantizationType::Int8 => {
                if self.symmetric {
                    QuantizationParams::int8_symmetric(Self::scale_for(abs_max, 127))
                } else {
                    let range = max_val - min_val;
                    // A constant tensor has zero range; fall back to its
                    // magnitude so the value still survives a round trip.
                    let span = if range > T::zero() { range } else { abs_max };
                    let scale = Self::scale_for(span, 255);
                    // Map `min_val` onto the lowest level (-128).
                    let min_level = (min_val / scale).round().to_i32().unwrap_or(0);
                    let zero_point = min_level.saturating_neg().saturating_sub(128);
                    QuantizationParams::int8_asymmetric(scale, zero_point)
                }
            }
            QuantizationType::Int4 => {
                QuantizationParams::int4_symmetric(Self::scale_for(abs_max, 7))
            }
            // NOTE(review): Float16 is answered with unit-scale Int8 parameters;
            // no real half-precision conversion happens here.
            QuantizationType::Float16 => QuantizationParams::int8_symmetric(T::one()),
            QuantizationType::Dynamic => {
                QuantizationParams::int8_symmetric(Self::scale_for(abs_max, 127))
            }
        }
    }

    /// Maps one real value to its clamped integer level.
    fn quantize_value(val: T, params: &QuantizationParams<T>) -> i8 {
        let scaled = (val / params.scale).round();
        // `to_i32` yields `None` for NaN and for magnitudes beyond `i32`.
        // Out-of-range values must saturate towards the matching bound
        // instead of silently becoming level 0.
        let level = scaled.to_i32().unwrap_or_else(|| {
            if scaled.is_nan() {
                0
            } else if scaled > T::zero() {
                i32::MAX
            } else {
                i32::MIN
            }
        });
        let clamped = level
            .saturating_add(params.zero_point)
            .max(params.qmin)
            .min(params.qmax);
        // Saturate rather than wrap if custom params carry bounds wider than i8.
        clamped.clamp(i32::from(i8::MIN), i32::from(i8::MAX)) as i8
    }

    /// Quantizes every element of `tensor` with `params`, in iteration order.
    fn quantize_data(&self, tensor: &Tensor<T>, params: &QuantizationParams<T>) -> Vec<i8> {
        let tensor_array = tensor.as_array();
        tensor_array
            .iter()
            .map(|&val| Self::quantize_value(val, params))
            .collect()
    }

    /// Quantizes every parameter returned by `module.parameters()`.
    ///
    /// The parameter at index `i` is cached under the name `"{layer_name}_{i}"`.
    ///
    /// # Panics
    ///
    /// Panics if reading a parameter's data lock fails.
    pub fn quantize_module<M: Module<T>>(
        &mut self,
        module: &M,
        qtype: QuantizationType,
        layer_name: &str,
    ) -> Vec<QuantizedTensor<T>> {
        let parameters = module.parameters();
        let mut quantized_params = Vec::with_capacity(parameters.len());
        for (i, param) in parameters.iter().enumerate() {
            let param_name = format!("{}_{}", layer_name, i);
            let param_tensor = param.data();
            let param_data = param_tensor.read().unwrap();
            let quantized = self.quantize_tensor(&*param_data, qtype, Some(&param_name));
            quantized_params.push(quantized);
        }
        quantized_params
    }

    /// Drops all cached parameters so later calls recalibrate.
    pub fn clear_cache(&mut self) {
        self.param_cache.clear();
    }

    /// Returns a copy of the cached parameters, keyed by layer name.
    pub fn get_statistics(&self) -> HashMap<String, QuantizationParams<T>> {
        self.param_cache.clone()
    }
}
/// Wraps a module so its input can be passed through simulated ("fake")
/// quantization before the wrapped module runs.
#[derive(Debug)]
pub struct QuantizationAwareModule<T, M>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
    M: Module<T> + 'static,
{
    /// The wrapped module that performs the actual computation.
    module: M,
    /// Quantization parameters looked up by layer name during fake quantization.
    qparams: HashMap<String, QuantizationParams<T>>,
    /// Master switch for quantization-aware training.
    qat_enabled: bool,
    /// Whether fake quantization is applied while QAT is enabled.
    fake_quantize: bool,
}
impl<T, M> QuantizationAwareModule<T, M>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Default
+ Zero
+ One
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ Sum
+ Signed
+ 'static,
M: Module<T> + 'static,
{
pub fn new(module: M) -> Self {
QuantizationAwareModule {
module,
qparams: HashMap::new(),
qat_enabled: false,
fake_quantize: true,
}
}
pub fn enable_qat(&mut self) {
self.qat_enabled = true;
}
pub fn disable_qat(&mut self) {
self.qat_enabled = false;
}
pub fn set_fake_quantize(&mut self, enabled: bool) {
self.fake_quantize = enabled;
}
fn apply_fake_quantization(&self, input: &Variable<T>, layer_name: &str) -> Variable<T> {
if !self.qat_enabled || !self.fake_quantize {
return input.clone();
}
if let Some(params) = self.qparams.get(layer_name) {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let quantized_data: Vec<T> = input_data
.as_array()
.iter()
.map(|&val| {
let scaled = val / params.scale;
let quantized = scaled.round().to_i32().unwrap_or(0) + params.zero_point;
let clamped = quantized.max(params.qmin).min(params.qmax);
let qval_adjusted = clamped - params.zero_point;
T::from_i32(qval_adjusted).unwrap() * params.scale
})
.collect();
let fake_quantized_tensor =
Tensor::from_vec(quantized_data, input_data.shape().to_vec());
Variable::new(fake_quantized_tensor, input.requires_grad())
} else {
input.clone()
}
}
}
impl<T, M> Module<T> for QuantizationAwareModule<T, M>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Default
        + Zero
        + One
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + Signed
        + 'static,
    M: Module<T> + 'static,
{
    /// Runs the wrapped module on the input after passing it through fake
    /// quantization under the layer name `"input"`.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        self.module
            .forward(&self.apply_fake_quantization(input, "input"))
    }

    /// Returns the wrapped module's parameters unchanged.
    fn parameters(&self) -> Vec<Variable<T>> {
        self.module.parameters()
    }

    /// Exposes the wrapper itself as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Symmetric Int8 parameters use the full signed range and a zero offset.
    #[test]
    fn test_quantization_params_creation() {
        let params = QuantizationParams::<f32>::int8_symmetric(0.1);
        assert_eq!((params.qmin, params.qmax), (-128, 127));
        assert_eq!(params.zero_point, 0);
        assert_eq!(params.qtype, QuantizationType::Int8);
    }

    /// The constructor stores the calibration mode and symmetry flag as given.
    #[test]
    fn test_quantizer_creation() {
        let quantizer = Quantizer::<f32>::new(CalibrationMode::MinMax, true);
        assert!(quantizer.symmetric);
        assert_eq!(quantizer.calibration_mode, CalibrationMode::MinMax);
    }

    /// Quantizing keeps the shape and element count and records the format.
    #[test]
    fn test_tensor_quantization() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0], vec![2, 3]);
        let mut quantizer = Quantizer::<f32>::new(CalibrationMode::MinMax, true);
        let quantized = quantizer.quantize_tensor(&tensor, QuantizationType::Int8, None);
        assert_eq!(quantized.data.len(), 6);
        assert_eq!(quantized.shape, vec![2, 3]);
        assert_eq!(quantized.params.qtype, QuantizationType::Int8);
    }

    /// Dequantizing multiplies each stored level by the scale.
    #[test]
    fn test_quantized_tensor_dequantization() {
        let quantized = QuantizedTensor::new(
            vec![10, 20, 30, -10, -20, -30],
            vec![2, 3],
            QuantizationParams::<f32>::int8_symmetric(0.1),
        );
        let dequantized = quantized.dequantize();
        assert_eq!(dequantized.shape(), &[2, 3]);

        let expected = [1.0, 2.0, 3.0, -1.0, -2.0, -3.0];
        let dequant_array = dequantized.as_array();
        for (idx, &actual) in dequant_array.iter().enumerate() {
            let error = (actual - expected[idx]).abs();
            assert!(error < 0.01);
        }
    }

    /// Int8 tensors report the nominal 4x compression ratio.
    #[test]
    fn test_compression_ratio() {
        let quantized = QuantizedTensor::new(
            vec![1, 2, 3],
            vec![3],
            QuantizationParams::<f32>::int8_symmetric(0.1),
        );
        assert_eq!(quantized.compression_ratio(), 4.0);
    }
}