use super::error::QuantizationError;
/// Parameters for mapping `f32` values to `i8` codes and back.
///
/// A value `x` is encoded as `round(x / scale) + zero_point`, saturated to the
/// `i8` range, and decoded as `(q - zero_point) * scale`. The constructors below
/// always produce symmetric parameters (`zero_point == 0`), but the field is
/// public, so both directions honour whatever value it holds.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizationParams {
    /// Size of one quantization step; the constructors guarantee it is finite and `> 0`.
    pub scale: f32,
    /// Code that represents `0.0`; always `0` for parameters built by the constructors.
    pub zero_point: i8,
    /// Stored as given and not read by any method here; presumably the length of
    /// the vectors these parameters apply to — confirm against callers.
    pub dims: usize,
}

impl QuantizationParams {
    /// Builds symmetric parameters (`zero_point == 0`) with the given step size.
    ///
    /// # Errors
    ///
    /// Returns [`QuantizationError::InvalidScale`] if `scale` is NaN, infinite,
    /// zero, or negative.
    pub fn new(scale: f32, dims: usize) -> Result<Self, QuantizationError> {
        if !scale.is_finite() || scale <= 0.0 {
            return Err(QuantizationError::InvalidScale { scale });
        }
        Ok(Self {
            scale,
            zero_point: 0,
            dims,
        })
    }

    /// Builds symmetric parameters that map `absmax` to the code `127`.
    ///
    /// Values in `[-absmax, absmax]` therefore land in `[-127, 127]`.
    ///
    /// # Errors
    ///
    /// Returns [`QuantizationError::InvalidScale`] if the derived scale
    /// (`absmax / 127`) is not finite and positive. That includes `absmax`
    /// being `0.0` (e.g. an all-zero input), negative, or NaN.
    pub fn from_absmax(absmax: f32, dims: usize) -> Result<Self, QuantizationError> {
        Self::new(absmax / f32::from(i8::MAX), dims)
    }

    /// Encodes `value` as an `i8` code.
    ///
    /// Computes `round(value / scale) + zero_point` and saturates the result to
    /// `i8::MIN..=i8::MAX`. `round` breaks ties away from zero. Out-of-range and
    /// infinite inputs clamp to the nearest end of the range, and NaN encodes as
    /// `zero_point`.
    #[inline]
    pub fn quantize_value(&self, value: f32) -> i8 {
        // The float -> int `as` cast saturates (NaN -> 0), and `saturating_add`
        // keeps the zero-point offset from overflowing before the clamp.
        let steps = (value / self.scale).round() as i32;
        let code = steps.saturating_add(i32::from(self.zero_point));
        // The clamp guarantees the narrowing cast below cannot truncate.
        code.clamp(i32::from(i8::MIN), i32::from(i8::MAX)) as i8
    }

    /// Decodes an `i8` code back to an approximate `f32`: `(value - zero_point) * scale`.
    #[inline]
    pub fn dequantize_value(&self, value: i8) -> f32 {
        // The i16 difference (at most ±255) cannot overflow and is exact in f32.
        f32::from(i16::from(value) - i16::from(self.zero_point)) * self.scale
    }

    /// Worst-case absolute round-trip error, `scale / 2`.
    ///
    /// This holds for values whose code did not saturate, ignoring f32 rounding
    /// in the divide and multiply. Values clamped by
    /// [`quantize_value`](Self::quantize_value) can be off by more.
    pub fn max_error_bound(&self) -> f32 {
        self.scale / 2.0
    }
}