use ndarray::{Array1, Array2, Axis};
use tensorlogic_scirs_backend::quantization::{
QuantizationGranularity, QuantizationParams, QuantizationType, QuantizedTensor,
};
/// Errors produced while building or configuring a [`QuantizedLinear`] layer.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can compare and propagate errors
/// (e.g. `assert_eq!` in tests) without matching on the message text by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantizationError {
    /// Lengths or dimensions of the supplied arrays/parameters do not agree.
    ShapeMismatch(String),
    /// The quantization parameters are unsupported or malformed.
    InvalidParams(String),
}

impl std::fmt::Display for QuantizationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuantizationError::ShapeMismatch(msg) => write!(f, "shape mismatch: {msg}"),
            QuantizationError::InvalidParams(msg) => write!(f, "invalid params: {msg}"),
        }
    }
}

impl std::error::Error for QuantizationError {}
/// A linear layer whose weights are stored as int8 values plus scale/zero-point
/// parameters, dequantized on demand.
#[derive(Clone)]
pub struct QuantizedLinear {
    /// Quantized weights, one row per output feature (`out_features × in_features`).
    weight_q: Array2<i8>,
    /// Scale factors: the first entry for per-tensor layers, one per row for per-channel.
    scale: Vec<f64>,
    /// Zero points, indexed like `scale`; a missing entry is treated as 0 when dequantizing.
    zero_point: Vec<i32>,
    /// Whether `scale`/`zero_point` apply to the whole tensor or per output row.
    granularity: QuantizationGranularity,
    /// Optional bias added to every output row; length equals `out_features`.
    bias: Option<Array1<f64>>,
}
impl QuantizedLinear {
    /// Builds an int8 linear layer by quantizing a floating-point weight matrix.
    ///
    /// `weight` is indexed as `(out_features, in_features)`: [`Self::forward`]
    /// multiplies inputs by its transpose, so each row yields one output feature.
    /// The scales and zero points from `params` are kept for [`Self::dequantize`].
    ///
    /// # Errors
    ///
    /// * [`QuantizationError::InvalidParams`] if `params.qtype` is not `Int8`, or if
    ///   per-tensor params contain no scale.
    /// * [`QuantizationError::ShapeMismatch`] if per-channel params do not contain
    ///   exactly one scale per output row, or if the quantized data is not 2-D.
    pub fn from_fp(
        weight: &Array2<f64>,
        params: &QuantizationParams,
    ) -> Result<Self, QuantizationError> {
        if params.qtype != QuantizationType::Int8 {
            return Err(QuantizationError::InvalidParams(format!(
                "only Int8 is supported, got {:?}",
                params.qtype
            )));
        }
        let (out_features, _in_features) = weight.dim();
        match params.granularity {
            QuantizationGranularity::PerTensor => {
                // dequantize() reads scale[0]; reject an empty vector here instead of
                // letting it surface later as an index-out-of-bounds panic.
                if params.scale.is_empty() {
                    return Err(QuantizationError::InvalidParams(
                        "PerTensor: scale must contain at least one value".to_string(),
                    ));
                }
            }
            QuantizationGranularity::PerChannel => {
                if params.scale.len() != out_features {
                    return Err(QuantizationError::ShapeMismatch(format!(
                        "PerChannel: scale.len()={} but out_features={}",
                        params.scale.len(),
                        out_features,
                    )));
                }
            }
        }
        let weight_dyn = weight.clone().into_dyn();
        let qt = QuantizedTensor::quantize(&weight_dyn, params.clone());
        // NOTE(review): the element type of `qt.data` is defined by the backend.
        // `as i8` wraps out-of-range integers (and saturates floats), so this assumes
        // the backend already clamped values to the Int8 range — confirm.
        let weight_i8 = qt
            .data
            .mapv(|x| x as i8)
            .into_dimensionality::<ndarray::Ix2>()
            .map_err(|e| {
                QuantizationError::ShapeMismatch(format!("dimensionality cast failed: {e}"))
            })?;
        Ok(Self {
            weight_q: weight_i8,
            scale: params.scale.clone(),
            zero_point: params.zero_point.clone(),
            granularity: params.granularity,
            bias: None,
        })
    }

    /// Attaches a bias vector that [`Self::forward`] adds to every output row.
    ///
    /// # Errors
    ///
    /// [`QuantizationError::ShapeMismatch`] if `bias.len()` differs from `out_features`.
    pub fn with_bias(mut self, bias: Array1<f64>) -> Result<Self, QuantizationError> {
        let out_features = self.weight_q.nrows();
        if bias.len() != out_features {
            return Err(QuantizationError::ShapeMismatch(format!(
                "bias.len()={} but out_features={}",
                bias.len(),
                out_features
            )));
        }
        self.bias = Some(bias);
        Ok(self)
    }

    /// Applies the layer to a batch of inputs: `x · Wᵀ`, plus the bias if one is set.
    ///
    /// `x` has one row per sample and `in_features` columns; the result has one row
    /// per sample and `out_features` columns. The weights are dequantized on every
    /// call.
    ///
    /// # Panics
    ///
    /// Panics if `x.ncols() != self.in_features()`.
    pub fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
        // Fail with a message naming the layer's dimensions rather than ndarray's
        // generic matrix-multiplication shape panic.
        assert!(
            x.ncols() == self.in_features(),
            "forward: input has {} columns but layer expects in_features={}",
            x.ncols(),
            self.in_features()
        );
        let weight_fp = self.dequantize();
        let out = x.dot(&weight_fp.t());
        match &self.bias {
            Some(b) => out + b,
            None => out,
        }
    }

    /// Reconstructs the floating-point weights as `(q - zero_point) * scale`.
    ///
    /// Per-tensor layers use the first scale and zero point for every element.
    /// Per-channel layers use the entry at each row's index. In both cases a
    /// missing zero point is treated as `0`.
    pub fn dequantize(&self) -> Array2<f64> {
        match self.granularity {
            QuantizationGranularity::PerTensor => {
                // from_fp rejects an empty scale vector for per-tensor layers.
                let s = self.scale[0];
                let zp = f64::from(self.zero_point.first().copied().unwrap_or(0));
                self.weight_q.mapv(|q| (f64::from(q) - zp) * s)
            }
            QuantizationGranularity::PerChannel => {
                let mut fp = Array2::<f64>::zeros(self.weight_q.dim());
                for (c, (q_row, mut fp_row)) in self
                    .weight_q
                    .axis_iter(Axis(0))
                    .zip(fp.axis_iter_mut(Axis(0)))
                    .enumerate()
                {
                    // from_fp guarantees one scale per row; the fallback is defensive
                    // only and is evaluated lazily.
                    let s = self.scale.get(c).copied().unwrap_or_else(|| self.scale[0]);
                    let zp = f64::from(self.zero_point.get(c).copied().unwrap_or(0));
                    for (&q, fp_val) in q_row.iter().zip(fp_row.iter_mut()) {
                        *fp_val = (f64::from(q) - zp) * s;
                    }
                }
                fp
            }
        }
    }

    /// Number of output features (rows of the weight matrix).
    pub fn out_features(&self) -> usize {
        self.weight_q.nrows()
    }

    /// Number of input features (columns of the weight matrix).
    pub fn in_features(&self) -> usize {
        self.weight_q.ncols()
    }

    /// Granularity that the stored scales and zero points apply at.
    pub fn granularity(&self) -> QuantizationGranularity {
        self.granularity
    }

    /// Scale factors as supplied at construction time.
    pub fn scales(&self) -> &[f64] {
        &self.scale
    }
}