use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array1, Array2, Axis};
/// Strategy for mapping f32 weights onto a low-bit integer grid.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantizationScheme {
    /// Single scale, zero point fixed at 0; grid symmetric around zero.
    Symmetric,
    /// Single scale plus a zero point covering an arbitrary [min, max] range.
    Asymmetric,
    /// Independent asymmetric parameters for each row (channel) of the matrix.
    PerChannel,
}
/// A weight matrix quantized to at most 8 bits, together with the
/// parameters needed to reconstruct approximate f32 values.
#[derive(Debug, Clone)]
pub struct QuantizedLayer {
    /// Quantized values, stored as `i8` regardless of the bit width.
    pub quantized: Array2<i8>,
    /// One scale for whole-tensor schemes; one per row for per-channel.
    pub scale: Vec<f32>,
    /// Zero points, paired index-for-index with `scale`.
    pub zero_point: Vec<i32>,
    /// Bit width used for quantization, in [1, 8].
    pub bits: u8,
    /// Scheme that produced the quantized values.
    pub scheme: QuantizationScheme,
    /// (rows, cols) of the original weight matrix.
    pub shape: (usize, usize),
}
/// Scan a batch of activation tensors and return the global `(min, max)`
/// range for quantization calibration.
///
/// # Errors
/// Returns `NeuralError::InvalidArchitecture` when the slice is empty, when
/// any activation value is NaN or infinite, or when every tensor has zero
/// elements (so no range exists).
pub fn calibrate_range(activations: &[Array2<f32>]) -> Result<(f32, f32)> {
    if activations.is_empty() {
        return Err(NeuralError::InvalidArchitecture(
            "calibrate_range: activation slice must not be empty".into(),
        ));
    }
    let mut global_min = f32::INFINITY;
    let mut global_max = f32::NEG_INFINITY;
    for act in activations {
        for &v in act.iter() {
            // Explicit per-element check: NaN compares false with both `<`
            // and `>`, so it would otherwise slip past a min/max scan and
            // defeat the post-hoc finiteness test on the extremes.
            if !v.is_finite() {
                return Err(NeuralError::InvalidArchitecture(
                    "calibrate_range: activations contain non-finite values".into(),
                ));
            }
            global_min = global_min.min(v);
            global_max = global_max.max(v);
        }
    }
    // Sentinels untouched means every tensor was zero-sized.
    if global_min > global_max {
        return Err(NeuralError::InvalidArchitecture(
            "calibrate_range: activations contain no elements".into(),
        ));
    }
    Ok((global_min, global_max))
}
/// Quantize a weight matrix to `bits`-bit integers under the given scheme.
///
/// Values are always stored as `i8`. For the asymmetric and per-channel
/// schemes the unsigned level `q` in `[0, 2^bits - 1]` is shifted down by
/// `2^(bits-1)` before storage, and the recorded zero point is shifted by
/// the same amount. This keeps the full 8-bit level range representable in
/// `i8` (an unshifted level of, say, 255 would wrap to -1 on the cast) while
/// leaving `dequantize`'s `scale * (q - zero_point)` reconstruction exact.
///
/// The calibrated range is nudged to include zero before computing the
/// asymmetric parameters, so the zero point always lands on the integer grid.
///
/// # Errors
/// Returns `NeuralError::InvalidArchitecture` when `bits` is outside [1, 8],
/// when range calibration fails (e.g. non-finite weights), or when a
/// per-channel row is not contiguous in memory.
pub fn quantize_weights(
    weights: &Array2<f32>,
    bits: u8,
    scheme: QuantizationScheme,
) -> Result<QuantizedLayer> {
    if bits == 0 || bits > 8 {
        return Err(NeuralError::InvalidArchitecture(format!(
            "quantize_weights: bits must be in [1, 8], got {bits}"
        )));
    }
    let (nrows, ncols) = (weights.nrows(), weights.ncols());
    match scheme {
        QuantizationScheme::Symmetric => {
            // Signed range [-2^(b-1), 2^(b-1)-1] already fits in i8, so no
            // storage shift is needed; the zero point is always 0.
            let (scale, zp) = symmetric_params(weights, bits)?;
            let q = quantize_tensor_symmetric(weights, scale, bits)?;
            Ok(QuantizedLayer {
                quantized: q,
                scale: vec![scale],
                zero_point: vec![zp],
                bits,
                scheme,
                shape: (nrows, ncols),
            })
        }
        QuantizationScheme::Asymmetric => {
            let (w_min, w_max) = tensor_min_max(weights)?;
            // Nudge the range so it contains zero; otherwise the zero point
            // would be clamped off-grid and bias every reconstructed value.
            let (scale, zp) = asymmetric_params(w_min.min(0.0), w_max.max(0.0), bits);
            let (qmin, qmax) = int_range(bits, false);
            // Shift unsigned levels into signed i8 storage (see doc comment).
            let offset = 1i32 << (bits - 1);
            let flat: Vec<i8> = weights
                .iter()
                .map(|&v| {
                    let q = (v / scale).round() as i32 + zp;
                    (q.clamp(qmin, qmax) - offset) as i8
                })
                .collect();
            let q = Array2::from_shape_vec((nrows, ncols), flat).map_err(|e| {
                NeuralError::InvalidArchitecture(format!("asymmetric quantization reshape: {e}"))
            })?;
            Ok(QuantizedLayer {
                quantized: q,
                scale: vec![scale],
                // Zero point shifted by the same offset as the stored levels
                // so that dequantize's (q - zp) is unchanged.
                zero_point: vec![zp - offset],
                bits,
                scheme,
                shape: (nrows, ncols),
            })
        }
        QuantizationScheme::PerChannel => {
            let mut scales = Vec::with_capacity(nrows);
            let mut zps = Vec::with_capacity(nrows);
            let mut q_data = Vec::with_capacity(nrows * ncols);
            let (qmin, qmax) = int_range(bits, false);
            let offset = 1i32 << (bits - 1);
            for row in weights.rows() {
                let row_arr = row.to_owned();
                let (rmin, rmax) = slice_min_max(row_arr.as_slice().ok_or_else(|| {
                    NeuralError::InvalidArchitecture(
                        "per-channel row is not contiguous".into(),
                    )
                })?)?;
                // Per-row asymmetric parameters over a zero-containing range.
                let (s, zp) = asymmetric_params(rmin.min(0.0), rmax.max(0.0), bits);
                scales.push(s);
                zps.push(zp - offset);
                for &v in row_arr.iter() {
                    let q = (v / s).round() as i32 + zp;
                    q_data.push((q.clamp(qmin, qmax) - offset) as i8);
                }
            }
            let q = Array2::from_shape_vec((nrows, ncols), q_data).map_err(|e| {
                NeuralError::InvalidArchitecture(format!("per-channel quantization reshape: {e}"))
            })?;
            Ok(QuantizedLayer {
                quantized: q,
                scale: scales,
                zero_point: zps,
                bits,
                scheme,
                shape: (nrows, ncols),
            })
        }
    }
}
/// Reconstruct an f32 matrix from a quantized layer using
/// `scale * (q - zero_point)` per element.
///
/// # Errors
/// Returns `NeuralError::InvalidArchitecture` when the stored shape disagrees
/// with the quantized matrix, or when `scale`/`zero_point` hold too few
/// entries for the scheme (1 for whole-tensor schemes, one per row for
/// per-channel) — the original code would panic on such a malformed layer.
pub fn dequantize(layer: &QuantizedLayer) -> Result<Array2<f32>> {
    let (nrows, ncols) = layer.shape;
    if layer.quantized.shape() != [nrows, ncols] {
        return Err(NeuralError::InvalidArchitecture(format!(
            "dequantize: stored shape ({nrows},{ncols}) does not match quantized shape {:?}",
            layer.quantized.shape()
        )));
    }
    // Validate parameter vector lengths up front so the indexing below
    // cannot panic.
    let needed = match layer.scheme {
        QuantizationScheme::Symmetric | QuantizationScheme::Asymmetric => 1,
        QuantizationScheme::PerChannel => nrows,
    };
    if layer.scale.len() < needed || layer.zero_point.len() < needed {
        return Err(NeuralError::InvalidArchitecture(format!(
            "dequantize: need {needed} scale/zero_point entries, got {} and {}",
            layer.scale.len(),
            layer.zero_point.len()
        )));
    }
    let mut out = Array2::zeros((nrows, ncols));
    match layer.scheme {
        QuantizationScheme::Symmetric | QuantizationScheme::Asymmetric => {
            let scale = layer.scale[0];
            let zp = layer.zero_point[0];
            for ((r, c), &q) in layer.quantized.indexed_iter() {
                out[(r, c)] = scale * (q as i32 - zp) as f32;
            }
        }
        QuantizationScheme::PerChannel => {
            // Each row carries its own (scale, zero_point) pair.
            for (r, row) in layer.quantized.rows().into_iter().enumerate() {
                let scale = layer.scale[r];
                let zp = layer.zero_point[r];
                for (c, &q) in row.iter().enumerate() {
                    out[(r, c)] = scale * (q as i32 - zp) as f32;
                }
            }
        }
    }
    Ok(out)
}
/// Inclusive integer range representable with `bits` bits: a two's-complement
/// signed range when `signed` is true, otherwise an unsigned range from zero.
fn int_range(bits: u8, signed: bool) -> (i32, i32) {
    let levels = 1i32 << bits;
    if signed {
        let half = levels >> 1;
        (-half, half - 1)
    } else {
        (0, levels - 1)
    }
}
/// Minimum and maximum over all elements of a weight matrix.
///
/// # Errors
/// Returns `NeuralError::InvalidArchitecture` when any element is NaN or
/// infinite, or when the matrix has no elements.
fn tensor_min_max(weights: &Array2<f32>) -> Result<(f32, f32)> {
    let mut mn = f32::INFINITY;
    let mut mx = f32::NEG_INFINITY;
    for &v in weights.iter() {
        // Check every element: NaN compares false with `<` and `>`, so it
        // would otherwise be skipped and the final finiteness test on the
        // extremes would not catch it.
        if !v.is_finite() {
            return Err(NeuralError::InvalidArchitecture(
                "weight matrix contains non-finite values".into(),
            ));
        }
        mn = mn.min(v);
        mx = mx.max(v);
    }
    // Sentinels untouched means the matrix was empty.
    if mn > mx {
        return Err(NeuralError::InvalidArchitecture(
            "weight matrix contains no elements".into(),
        ));
    }
    Ok((mn, mx))
}
/// Minimum and maximum of a slice, rejecting NaN/infinite values.
///
/// # Errors
/// Returns `NeuralError::InvalidArchitecture` when any element is NaN or
/// infinite, or when the slice is empty.
fn slice_min_max(s: &[f32]) -> Result<(f32, f32)> {
    let mut mn = f32::INFINITY;
    let mut mx = f32::NEG_INFINITY;
    for &v in s {
        // Explicit per-element check: NaN compares false with both `<` and
        // `>`, so it would otherwise slip through undetected.
        if !v.is_finite() {
            return Err(NeuralError::InvalidArchitecture(
                "slice contains non-finite values".into(),
            ));
        }
        mn = mn.min(v);
        mx = mx.max(v);
    }
    // Sentinels untouched means the slice was empty.
    if mn > mx {
        return Err(NeuralError::InvalidArchitecture(
            "slice is empty; cannot compute min/max".into(),
        ));
    }
    Ok((mn, mx))
}
fn symmetric_params(weights: &Array2<f32>, bits: u8) -> Result<(f32, i32)> {
let max_abs = weights
.iter()
.map(|v| v.abs())
.fold(0.0_f32, f32::max);
if max_abs == 0.0 {
return Ok((1.0, 0));
}
let qmax = ((1i32 << (bits - 1)) - 1) as f32;
let scale = max_abs / qmax;
Ok((scale, 0))
}
/// Scale and zero point for asymmetric (affine) quantization over [w_min, w_max].
///
/// The range is first nudged to include zero (standard practice, e.g. in
/// TFLite-style quantization). Without the nudge, a range lying entirely on
/// one side of zero yields a zero point outside `[qmin, qmax]`; the final
/// clamp would then bias every quantized value, saturating the whole tensor
/// in the worst case. With the nudge, `-w_min / scale` always lies in
/// `[0, qmax]`, so the clamp is a no-op safety net.
fn asymmetric_params(w_min: f32, w_max: f32, bits: u8) -> (f32, i32) {
    let qmin = 0_i32;
    let qmax = (1i32 << bits) - 1;
    // Nudge the real range so it always contains zero.
    let w_min = w_min.min(0.0);
    let w_max = w_max.max(0.0);
    let range = w_max - w_min;
    // Degenerate (all-equal-zero) range: scale 1.0 keeps dequant exact.
    let scale = if range == 0.0 { 1.0 } else { range / (qmax - qmin) as f32 };
    let zp = (qmin as f32 - w_min / scale).round() as i32;
    (scale, zp.clamp(qmin, qmax))
}
/// Quantize every element onto the signed grid `[-2^(bits-1), 2^(bits-1)-1]`
/// via `round(v / scale)`, saturating at the grid edges.
fn quantize_tensor_symmetric(
    weights: &Array2<f32>,
    scale: f32,
    bits: u8,
) -> Result<Array2<i8>> {
    let upper = (1i32 << (bits - 1)) - 1;
    let lower = -upper - 1;
    let shape = (weights.nrows(), weights.ncols());
    let mut levels = Vec::with_capacity(shape.0 * shape.1);
    for &w in weights.iter() {
        let level = (w / scale).round() as i32;
        levels.push(level.clamp(lower, upper) as i8);
    }
    Array2::from_shape_vec(shape, levels)
        .map_err(|e| NeuralError::InvalidArchitecture(format!("symmetric quant reshape: {e}")))
}
/// Quantize a tensor with a single asymmetric scale/zero-point pair,
/// producing levels in the unsigned range `[0, 2^bits - 1]`.
///
/// NOTE(review): for `bits == 8` the unsigned level range [0, 255] does not
/// fit in `i8` — the `as i8` cast below wraps levels above 127 to negative
/// values, corrupting the stored tensor. The unused `q_shifted` binding
/// suggests a signed shift by `2^(bits-1)` was intended; a correct fix must
/// also shift the caller's stored zero_point by the same offset so that
/// `dequantize`'s `scale * (q - zero_point)` stays consistent.
fn quantize_tensor_asymmetric(
    weights: &Array2<f32>,
    scale: f32,
    zp: i32,
    bits: u8,
) -> Result<Array2<i8>> {
    let qmin = 0_i32;
    let qmax = (1i32 << bits) - 1;
    let (nrows, ncols) = (weights.nrows(), weights.ncols());
    let flat: Vec<i8> = weights
        .iter()
        .map(|&v| {
            // Affine quantization: level = round(v / scale) + zero_point.
            let q = (v / scale).round() as i32 + zp;
            // Computed but never used — see the NOTE in the doc comment.
            let q_shifted = q - (1i32 << (bits - 1));
            q.clamp(qmin, qmax) as i8
        })
        .collect();
    Array2::from_shape_vec((nrows, ncols), flat)
        .map_err(|e| NeuralError::InvalidArchitecture(format!("asymmetric quant reshape: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Deterministic fixture: an nrows x ncols matrix whose elements ramp
    // linearly from `lo` to `hi` inclusive.
    fn uniform_weights(nrows: usize, ncols: usize, lo: f32, hi: f32) -> Array2<f32> {
        let step = (hi - lo) / ((nrows * ncols) as f32 - 1.0);
        let flat: Vec<f32> = (0..nrows * ncols).map(|i| lo + i as f32 * step).collect();
        Array2::from_shape_vec((nrows, ncols), flat).expect("shape error")
    }

    // Calibrated range must be the global min/max over every batch.
    #[test]
    fn test_calibrate_range_basic() {
        let acts = vec![
            Array2::from_shape_vec((2, 2), vec![-1.0, 0.5, 0.0, 2.0]).expect("shape"),
            Array2::from_shape_vec((2, 2), vec![0.3, -3.0, 1.2, 0.0]).expect("shape"),
        ];
        let (mn, mx) = calibrate_range(&acts).expect("calibrate_range failed");
        assert!((mn - (-3.0)).abs() < 1e-6);
        assert!((mx - 2.0).abs() < 1e-6);
    }

    // An empty batch list cannot be calibrated.
    #[test]
    fn test_calibrate_range_empty() {
        let result = calibrate_range(&[]);
        assert!(result.is_err());
    }

    // 8-bit symmetric round trip: error bounded by about one grid step.
    #[test]
    fn test_quantize_dequantize_symmetric() {
        let w = uniform_weights(4, 4, -1.0, 1.0);
        let ql = quantize_weights(&w, 8, QuantizationScheme::Symmetric).expect("quant failed");
        let deq = dequantize(&ql).expect("dequant failed");
        for (o, d) in w.iter().zip(deq.iter()) {
            assert!((o - d).abs() < 0.02, "error too large: orig={o}, dequant={d}");
        }
    }

    // 8-bit asymmetric round trip over a non-negative range.
    #[test]
    fn test_quantize_dequantize_asymmetric() {
        let w = uniform_weights(4, 4, 0.0, 1.0);
        let ql = quantize_weights(&w, 8, QuantizationScheme::Asymmetric).expect("quant failed");
        let deq = dequantize(&ql).expect("dequant failed");
        for (o, d) in w.iter().zip(deq.iter()) {
            assert!((o - d).abs() < 0.02, "error too large: orig={o}, dequant={d}");
        }
    }

    // Per-channel: one (scale, zero_point) pair per row, bounded round-trip error.
    #[test]
    fn test_quantize_dequantize_per_channel() {
        let w = uniform_weights(8, 4, -2.0, 2.0);
        let ql = quantize_weights(&w, 8, QuantizationScheme::PerChannel).expect("quant failed");
        assert_eq!(ql.scale.len(), 8);
        assert_eq!(ql.zero_point.len(), 8);
        let deq = dequantize(&ql).expect("dequant failed");
        for (o, d) in w.iter().zip(deq.iter()) {
            assert!((o - d).abs() < 0.05, "error too large: orig={o}, dequant={d}");
        }
    }

    // Bit widths outside [1, 8] are rejected.
    #[test]
    fn test_quantize_invalid_bits() {
        let w = uniform_weights(2, 2, -1.0, 1.0);
        assert!(quantize_weights(&w, 0, QuantizationScheme::Symmetric).is_err());
        assert!(quantize_weights(&w, 9, QuantizationScheme::Symmetric).is_err());
    }

    // All-zero tensor must survive the round trip exactly.
    #[test]
    fn test_quantize_all_zeros() {
        let w = Array2::zeros((3, 3));
        let ql = quantize_weights(&w, 8, QuantizationScheme::Symmetric).expect("quant failed");
        let deq = dequantize(&ql).expect("dequant failed");
        assert!(deq.iter().all(|&v| v == 0.0));
    }
}