use crate::error::{CnnError, CnnResult};
use serde::{Deserialize, Serialize};
/// Parameters of an affine mapping between `f32` values and `i8` levels.
///
/// `quantize_value` computes `round(value / scale) + zero_point` and clamps
/// the result to `[qmin, qmax]`; `dequantize_value` computes
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
/// Step size: real values are divided by this when quantizing and
/// multiplied by it when dequantizing. `validate` rejects values `<= 0.0`.
pub scale: f32,
/// Integer offset added after scaling; `validate` requires it to lie in
/// `[qmin, qmax]`.
pub zero_point: i32,
/// Lower clamp bound applied by `quantize_value`.
pub qmin: i8,
/// Upper clamp bound applied by `quantize_value`.
pub qmax: i8,
}
impl QuantizationParams {
    /// Lowest quantized level emitted by the constructors (`i8::MIN` is not used).
    const QMIN: i8 = -127;
    /// Highest quantized level emitted by the constructors.
    const QMAX: i8 = 127;

    /// Derives the step size that spreads `span` over `steps` integer steps.
    ///
    /// A span of zero (or one so small that the division underflows to zero)
    /// falls back to a scale of `1.0` so later divisions by the scale stay
    /// well defined.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidParameter`] when the resulting scale is not
    /// finite, which happens when `span` itself overflowed `f32`.
    fn scale_for_span(span: f32, steps: f32) -> CnnResult<f32> {
        let scale = span / steps;
        if !scale.is_finite() {
            return Err(CnnError::InvalidParameter(format!(
                "value range is too wide to quantize (span {})",
                span
            )));
        }
        Ok(if scale > 0.0 { scale } else { 1.0 })
    }

    /// Builds quantization parameters covering the real range `[min_val, max_val]`.
    ///
    /// * `Symmetric`: the zero point is fixed at `0` and the scale is
    ///   `max(|min_val|, |max_val|) / 127`.
    /// * `Asymmetric`: the range is first widened to contain `0.0`, then spread
    ///   over the 254 steps between `-127` and `127`; the zero point is the
    ///   level that real `0.0` maps to. Widening keeps the zero point inside
    ///   `[qmin, qmax]` (as [`validate`](Self::validate) requires) and prevents
    ///   ranges such as `[5, 10]` from saturating to a single level.
    ///
    /// A degenerate range (zero span) yields a scale of `1.0`.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::InvalidParameter`] if either bound is NaN or
    /// infinite, if `min_val > max_val`, or if the span overflows `f32`.
    pub fn from_minmax(min_val: f32, max_val: f32, mode: QuantizationMode) -> CnnResult<Self> {
        // NaN compares false with everything, so it would slip past the
        // ordering check below and poison the scale.
        if !min_val.is_finite() || !max_val.is_finite() {
            return Err(CnnError::InvalidParameter(format!(
                "min_val ({}) and max_val ({}) must be finite",
                min_val, max_val
            )));
        }
        if min_val > max_val {
            return Err(CnnError::InvalidParameter(format!(
                "min_val ({}) must be <= max_val ({})",
                min_val, max_val
            )));
        }

        let qmin_f = f32::from(Self::QMIN);
        let qmax_f = f32::from(Self::QMAX);

        let (scale, zero_point) = match mode {
            QuantizationMode::Symmetric => {
                let max_abs = min_val.abs().max(max_val.abs());
                (Self::scale_for_span(max_abs, qmax_f)?, 0)
            }
            QuantizationMode::Asymmetric => {
                // Real zero must be representable, so make sure the range
                // straddles (or touches) it before computing the step size.
                let lo = min_val.min(0.0);
                let hi = max_val.max(0.0);
                let scale = Self::scale_for_span(hi - lo, qmax_f - qmin_f)?;
                // `lo` maps to QMIN, hence real 0.0 maps to QMIN - lo / scale.
                // The clamp only guards against rounding at the edges; the
                // value is integral and in i8 range, so the cast is exact.
                let zero_point = ((-lo / scale).round() + qmin_f).clamp(qmin_f, qmax_f) as i32;
                (scale, zero_point)
            }
        };

        Ok(Self {
            scale,
            zero_point,
            qmin: Self::QMIN,
            qmax: Self::QMAX,
        })
    }

    /// Builds quantization parameters from percentile-clipped bounds.
    ///
    /// The bounds are forwarded unchanged to [`from_minmax`](Self::from_minmax),
    /// so the same rules and errors apply.
    pub fn from_percentile(
        percentile_min: f32,
        percentile_max: f32,
        mode: QuantizationMode,
    ) -> CnnResult<Self> {
        Self::from_minmax(percentile_min, percentile_max, mode)
    }

    /// Checks that the parameters describe a usable quantizer.
    ///
    /// # Errors
    ///
    /// Returns [`CnnError::QuantizationError`] when the scale is NaN, infinite,
    /// zero or negative, when `qmin > qmax`, or when `zero_point` lies outside
    /// `[qmin, qmax]`.
    pub fn validate(&self) -> CnnResult<()> {
        // `NaN <= 0.0` is false, so non-finite scales need their own check.
        if !self.scale.is_finite() {
            return Err(CnnError::QuantizationError(format!(
                "scale must be finite, got {}",
                self.scale
            )));
        }
        if self.scale <= 0.0 {
            return Err(CnnError::QuantizationError(format!(
                "scale must be positive, got {}",
                self.scale
            )));
        }
        if self.qmin > self.qmax {
            return Err(CnnError::QuantizationError(format!(
                "qmin ({}) must be <= qmax ({})",
                self.qmin, self.qmax
            )));
        }
        let allowed = i32::from(self.qmin)..=i32::from(self.qmax);
        if !allowed.contains(&self.zero_point) {
            return Err(CnnError::QuantizationError(format!(
                "zero_point ({}) must be in range [{}, {}]",
                self.zero_point, self.qmin, self.qmax
            )));
        }
        Ok(())
    }

    /// Maps a real value to its quantized level, saturating at `qmin`/`qmax`.
    ///
    /// A NaN input yields `0`, because a float-to-integer `as` cast turns NaN
    /// into zero.
    ///
    /// # Panics
    ///
    /// Panics if `qmin > qmax` (a precondition of `f32::clamp`); parameters
    /// that pass [`validate`](Self::validate) never trigger this.
    #[inline]
    pub fn quantize_value(&self, value: f32) -> i8 {
        let level = (value / self.scale).round() + self.zero_point as f32;
        // Clamped to i8 bounds first, so the cast cannot truncate.
        level.clamp(f32::from(self.qmin), f32::from(self.qmax)) as i8
    }

    /// Maps a quantized level back to the real value it represents.
    #[inline]
    pub fn dequantize_value(&self, value: i8) -> f32 {
        (f32::from(value) - self.zero_point as f32) * self.scale
    }
}
/// Selects the granularity at which quantization parameters are applied.
///
/// Not referenced by any other item visible in this part of the file.
/// NOTE(review): presumably `PerTensor` means one parameter set for a whole
/// tensor and `PerChannel` one set per channel — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
PerTensor,
PerChannel,
}
/// Chooses how `QuantizationParams::from_minmax` derives scale and zero point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMode {
/// Zero point is fixed at 0; the scale comes from the larger absolute bound.
Symmetric,
/// The scale comes from the width of the range and the zero point is
/// computed from the range instead of being fixed at 0.
Asymmetric,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Symmetric parameters for the range [-10, 10], shared by most tests.
    fn symmetric_pm10() -> QuantizationParams {
        QuantizationParams::from_minmax(-10.0, 10.0, QuantizationMode::Symmetric).unwrap()
    }

    /// Asymmetric parameters for the range [0, 10].
    fn asymmetric_0_10() -> QuantizationParams {
        QuantizationParams::from_minmax(0.0, 10.0, QuantizationMode::Asymmetric).unwrap()
    }

    /// Quantizes then dequantizes `value` and returns the reconstruction.
    fn round_trip(params: &QuantizationParams, value: f32) -> f32 {
        params.dequantize_value(params.quantize_value(value))
    }

    #[test]
    fn test_symmetric_minmax() {
        let params = symmetric_pm10();
        assert_eq!(params.zero_point, 0);
        assert!(params.scale > 0.0);
        assert_eq!(params.qmin, -127);
        assert_eq!(params.qmax, 127);
        params.validate().unwrap();
    }

    #[test]
    fn test_asymmetric_minmax() {
        let params = asymmetric_0_10();
        assert!(params.scale > 0.0);
        // Compare against the parameters' own bounds rather than the raw i8
        // range: the constructors use -127 (not -128) as the lowest level.
        assert!(params.zero_point >= i32::from(params.qmin));
        assert!(params.zero_point <= i32::from(params.qmax));
        params.validate().unwrap();
    }

    #[test]
    fn test_quantize_dequantize_symmetric() {
        let value = 5.0f32;
        let restored = round_trip(&symmetric_pm10(), value);
        assert!((restored - value).abs() < 0.1);
    }

    #[test]
    fn test_quantize_dequantize_asymmetric() {
        let value = 5.0f32;
        let restored = round_trip(&asymmetric_0_10(), value);
        assert!((restored - value).abs() < 0.1);
    }

    #[test]
    fn test_zero_value_quantization() {
        let params = symmetric_pm10();
        assert_eq!(params.quantize_value(0.0), 0);
        assert_eq!(params.dequantize_value(0), 0.0);
    }

    #[test]
    fn test_clipping() {
        let params =
            QuantizationParams::from_minmax(-1.0, 1.0, QuantizationMode::Symmetric).unwrap();
        // Values far outside the calibrated range saturate at the end levels.
        assert_eq!(params.quantize_value(1000.0), 127);
        assert_eq!(params.quantize_value(-1000.0), -127);
    }

    #[test]
    fn test_invalid_range() {
        let result = QuantizationParams::from_minmax(10.0, -10.0, QuantizationMode::Symmetric);
        assert!(result.is_err());
    }

    #[test]
    fn test_percentile_constructor() {
        let params =
            QuantizationParams::from_percentile(-9.5, 9.5, QuantizationMode::Symmetric).unwrap();
        assert_eq!(params.zero_point, 0);
        params.validate().unwrap();
    }

    #[test]
    fn test_validation_negative_scale() {
        let params = QuantizationParams {
            scale: -1.0,
            ..symmetric_pm10()
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_zero_scale() {
        let params = QuantizationParams {
            scale: 0.0,
            ..symmetric_pm10()
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_invalid_qmin_qmax() {
        let params = QuantizationParams {
            qmin: 127,
            qmax: -127,
            ..symmetric_pm10()
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn test_validation_zero_point_out_of_range() {
        let params = QuantizationParams {
            zero_point: 200,
            ..symmetric_pm10()
        };
        assert!(params.validate().is_err());
    }
}