use crate::Tensor;
use super::config::FakeQuantConfig;
/// State for a fake-quantization step: values are snapped onto an integer grid
/// and immediately mapped back to `f32`, so downstream code sees the rounding
/// and clipping error of quantization while still working with floats.
#[derive(Clone, Debug)]
pub struct FakeQuantize {
/// Quantization settings; the methods on this type read its `qmin`, `qmax`
/// and `symmetric` fields.
pub config: FakeQuantConfig,
/// Real-valued distance between two adjacent integer levels. Starts at `1.0`
/// in `new` and is overwritten by `calibrate`.
pub scale: f32,
/// Integer offset added before rounding on the asymmetric path. The
/// quantize/dequantize step does not read it when `config.symmetric` is set.
pub zero_point: i32,
/// `false` after `new`; set to `true` by `calibrate` once it has derived
/// `scale`/`zero_point` from data. `forward_with_calibration` checks it so
/// that calibration runs at most once.
pub initialized: bool,
}
impl FakeQuantize {
    /// Smallest scale `calibrate` will store, so later divisions by `scale`
    /// never divide by zero.
    const MIN_SCALE: f32 = 1e-10;

    /// Creates an uncalibrated quantizer: `scale = 1.0`, `zero_point = 0`,
    /// `initialized = false`.
    pub fn new(config: FakeQuantConfig) -> Self {
        Self {
            config,
            scale: 1.0,
            zero_point: 0,
            initialized: false,
        }
    }

    /// Uncalibrated quantizer built from `FakeQuantConfig::q4_symmetric()`
    /// (presumably a 4-bit symmetric preset; see that constructor).
    pub fn q4() -> Self {
        Self::new(FakeQuantConfig::q4_symmetric())
    }

    /// Uncalibrated quantizer built from `FakeQuantConfig::q8_symmetric()`
    /// (presumably an 8-bit symmetric preset; see that constructor).
    pub fn q8() -> Self {
        Self::new(FakeQuantConfig::q8_symmetric())
    }

    /// Derives `scale` and `zero_point` from the value range of `data`.
    ///
    /// * Non-finite samples (NaN, ±inf) are ignored. If no finite sample is
    ///   left (including empty `data`) the call is a no-op and `initialized`
    ///   is not changed.
    /// * Symmetric mode: `scale = max(|x|) / qmax`, `zero_point = 0`.
    /// * Asymmetric mode: the observed range is first widened to contain `0.0`,
    ///   then mapped onto `[qmin, qmax]`. Widening keeps the zero-point inside
    ///   the integer grid; without it, data such as `[5, 10]` would have its
    ///   zero-point clamped and the grid would no longer cover the data.
    /// * `scale` is never stored below `1e-10`, and the zero-point is computed
    ///   from the floored scale, so constant data does not divide by zero.
    ///
    /// # Panics
    ///
    /// In asymmetric mode, panics if `config.qmin > config.qmax` (from
    /// `i32::clamp`).
    pub fn calibrate(&mut self, data: &[f32]) {
        // Single pass over the finite samples, tracking both extremes.
        let (min_val, max_val) = data
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if min_val > max_val {
            // Fold seeds untouched: there was nothing finite to calibrate on.
            return;
        }

        let qmin = self.config.qmin;
        let qmax = self.config.qmax;
        if self.config.symmetric {
            let max_abs = min_val.abs().max(max_val.abs());
            self.scale = Self::floor_scale(max_abs / qmax as f32);
            self.zero_point = 0;
        } else {
            // The affine grid must be able to represent 0.0 exactly.
            let lo = min_val.min(0.0);
            let hi = max_val.max(0.0);
            // Subtract in f32 so an extreme qmin/qmax pair cannot overflow i32.
            self.scale = Self::floor_scale((hi - lo) / (qmax as f32 - qmin as f32));
            // Float-to-int `as` saturates, so an out-of-range value is safe
            // here and is then pulled back onto the grid by the clamp.
            let zero_point = (qmin as f32 - lo / self.scale).round() as i32;
            self.zero_point = zero_point.clamp(qmin, qmax);
        }
        self.initialized = true;
    }

    /// Applies fake quantization to every element of `input` and returns the
    /// result as a new one-dimensional tensor with the same `requires_grad`
    /// flag. Uses the current `scale`/`zero_point` whether or not `calibrate`
    /// has run.
    pub fn forward(&self, input: &Tensor) -> Tensor {
        let data: Vec<f32> = input
            .data()
            .iter()
            .map(|&x| self.fake_quantize_value(x))
            .collect();
        // `from_vec` takes ownership of the buffer instead of copying it again.
        Tensor::new(ndarray::Array1::from_vec(data), input.requires_grad())
    }

    /// Like [`forward`](Self::forward), but calibrates on `input` first if the
    /// quantizer has not been initialized yet.
    ///
    /// If the tensor data cannot be viewed as one contiguous slice, the values
    /// are gathered through the iterator so calibration still sees them.
    pub fn forward_with_calibration(&mut self, input: &Tensor) -> Tensor {
        if !self.initialized {
            let values = input.data();
            match values.as_slice() {
                Some(contiguous) => self.calibrate(contiguous),
                None => {
                    let gathered: Vec<f32> = values.iter().copied().collect();
                    self.calibrate(&gathered);
                }
            }
        }
        self.forward(input)
    }

    /// Straight-through estimator: returns a clone of `grad_output`,
    /// treating the quantizer as the identity for gradient purposes.
    pub fn backward(&self, grad_output: &Tensor) -> Tensor {
        grad_output.clone()
    }

    /// Straight-through estimator that zeroes the gradient wherever `input`
    /// lies outside the representable range
    /// `[(qmin - zp) * scale, (qmax - zp) * scale]`, where `zp` is the same
    /// offset the forward pass uses (`0` in symmetric mode).
    ///
    /// Elements are paired positionally; if the two tensors differ in length
    /// the result has the length of the shorter one. NaN inputs compare false
    /// against both bounds and therefore pass the gradient through.
    pub fn backward_clamped(&self, grad_output: &Tensor, input: &Tensor) -> Tensor {
        let offset = self.effective_zero_point() as f32;
        let lower = (self.config.qmin as f32 - offset) * self.scale;
        let upper = (self.config.qmax as f32 - offset) * self.scale;
        let data: Vec<f32> = grad_output
            .data()
            .iter()
            .zip(input.data().iter())
            .map(|(&grad, &x)| if x < lower || x > upper { 0.0 } else { grad })
            .collect();
        Tensor::new(ndarray::Array1::from_vec(data), grad_output.requires_grad())
    }

    /// Quantizes one value to the integer grid and maps it straight back:
    /// `q = clamp(round(x / scale + zp), qmin, qmax)`, result `(q - zp) * scale`.
    ///
    /// A NaN input survives `round`/`clamp` as NaN and becomes integer `0` in
    /// the saturating `as i32` cast, so it yields `(0 - zp) * scale`.
    fn fake_quantize_value(&self, x: f32) -> f32 {
        let zero_point = self.effective_zero_point();
        let q = (x / self.scale + zero_point as f32)
            .round()
            .clamp(self.config.qmin as f32, self.config.qmax as f32) as i32;
        (q - zero_point) as f32 * self.scale
    }

    /// Zero-point actually applied by the quantize/dequantize math: the stored
    /// value in asymmetric mode, always `0` in symmetric mode.
    fn effective_zero_point(&self) -> i32 {
        if self.config.symmetric {
            0
        } else {
            self.zero_point
        }
    }

    /// Raises a too-small or NaN scale to [`Self::MIN_SCALE`]. NaN is checked
    /// explicitly because it compares false against any threshold.
    fn floor_scale(scale: f32) -> f32 {
        if scale.is_nan() || scale < Self::MIN_SCALE {
            Self::MIN_SCALE
        } else {
            scale
        }
    }

    /// Current quantization step size.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Current stored zero-point.
    pub fn zero_point(&self) -> i32 {
        self.zero_point
    }

    /// Whether `calibrate` has derived parameters from data.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Number of integer levels in `[qmin, qmax]`, or `0` when the range is
    /// empty (`qmin > qmax`).
    pub fn num_levels(&self) -> usize {
        // Widen to i64 so `qmax - qmin + 1` cannot overflow i32.
        let span = i64::from(self.config.qmax) - i64::from(self.config.qmin) + 1;
        if span > 0 {
            span as usize
        } else {
            0
        }
    }
}