use std::collections::HashMap;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Granularity at which quantization parameters (scale / zero-point) apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// A single (scale, zero-point) pair for the whole tensor.
    PerTensor,
    /// One (scale, zero-point) pair per slice along the given axis.
    PerChannel(usize),
}
/// Target integer dtype for quantized values. `Int4` and `Uint8` payloads
/// are still stored in `i8` slots (`Uint8` is bit-cast, see `quantize_val`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDtype {
    Int8,
    Int4,
    Uint8,
}

impl QuantDtype {
    /// Lowest representable quantized value for this dtype.
    #[inline]
    fn qmin(self) -> i32 {
        match self {
            Self::Int8 => i8::MIN as i32,
            Self::Int4 => -8,
            Self::Uint8 => 0,
        }
    }

    /// Highest representable quantized value for this dtype.
    #[inline]
    fn qmax(self) -> i32 {
        match self {
            Self::Int8 => i8::MAX as i32,
            Self::Int4 => 7,
            Self::Uint8 => u8::MAX as i32,
        }
    }
}
/// A tensor whose elements are stored as quantized integers together with
/// the affine parameters (scale, zero-point) needed to recover real values.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    // Quantized values; Uint8 payloads are bit-cast from u8 into i8 slots.
    data: Vec<i8>,
    // One entry for PerTensor, one entry per channel for PerChannel.
    scale: Vec<f32>,
    // Zero-points paired one-to-one with `scale`.
    zero_point: Vec<i32>,
    // Logical (row-major) shape of the tensor.
    shape: Vec<usize>,
    // Quantization granularity this tensor was produced with.
    scheme: QuantScheme,
    // Integer dtype the values were quantized to.
    dtype: QuantDtype,
}
impl QuantizedTensor {
    /// Total number of elements (product of all dimensions).
    #[inline]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Logical shape of the tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Raw quantized storage (Uint8 values are bit-cast into i8).
    #[inline]
    pub fn data(&self) -> &[i8] {
        self.data.as_slice()
    }

    /// Quantization scales: length 1 for PerTensor, one per channel otherwise.
    #[inline]
    pub fn scale(&self) -> &[f32] {
        self.scale.as_slice()
    }

    /// Zero-points paired one-to-one with `scale()`.
    #[inline]
    pub fn zero_point(&self) -> &[i32] {
        self.zero_point.as_slice()
    }

    /// Quantization granularity this tensor was produced with.
    #[inline]
    pub fn scheme(&self) -> QuantScheme {
        self.scheme
    }

    /// Integer dtype of the stored values.
    #[inline]
    pub fn qdtype(&self) -> QuantDtype {
        self.dtype
    }
}
/// Derive the affine (scale, zero-point) pair that maps the real interval
/// [min_val, max_val] onto the integer range of `dtype`.
///
/// The interval is first widened to include 0.0 so that real zero is exactly
/// representable, and the span is floored at `f32::EPSILON` so a constant
/// tensor never yields a zero scale. With those clamps the computed
/// zero-point always lands inside [qmin, qmax].
fn compute_scale_zp(min_val: f32, max_val: f32, dtype: QuantDtype) -> (f32, i32) {
    let (qmin, qmax) = (dtype.qmin(), dtype.qmax());
    // Ensure 0.0 is inside the representable interval.
    let lo = min_val.min(0.0);
    let hi = max_val.max(0.0);
    let span = (hi - lo).max(f32::EPSILON);
    let scale = span / (qmax - qmin) as f32;
    let zero_point = (qmin as f32 - lo / scale).round() as i32;
    (scale, zero_point)
}
/// Quantize one real value with the given affine parameters, clamping to
/// [qmin, qmax]. For unsigned dtypes the clamped value (0..=255) is bit-cast
/// from u8 into the i8 storage representation.
#[inline]
fn quantize_val(x: f32, scale: f32, zp: i32, qmin: i32, qmax: i32, is_unsigned: bool) -> i8 {
    let rounded = ((x / scale) + zp as f32).round() as i32;
    let clamped = rounded.clamp(qmin, qmax);
    match is_unsigned {
        true => (clamped as u8) as i8,
        false => clamped as i8,
    }
}
/// Widen a stored i8 to i32, first reinterpreting the bits as u8 when the
/// value was quantized with an unsigned dtype.
#[inline]
fn stored_to_i32(val: i8, is_unsigned: bool) -> i32 {
    if is_unsigned {
        i32::from(val as u8)
    } else {
        i32::from(val)
    }
}
/// Map a flat (row-major) element index to its channel index along `axis`.
#[inline]
fn channel_index(flat_index: usize, shape: &[usize], axis: usize) -> usize {
    // Row-major stride of `axis` is the product of all trailing dims.
    let axis_stride: usize = shape[axis + 1..].iter().product();
    flat_index / axis_stride % shape[axis]
}
/// Quantize a float tensor to the requested integer dtype.
///
/// `PerTensor` derives a single (scale, zero-point) pair from the global
/// min/max; `PerChannel(axis)` derives one pair per slice along `axis`.
/// Values are rounded to nearest and clamped to the dtype's range.
///
/// # Errors
/// Returns `InvalidArgument` when the per-channel axis is out of range for
/// the tensor's rank, and propagates any error from reading tensor data.
pub fn quantize<T: Float>(
    tensor: &Tensor<T>,
    scheme: QuantScheme,
    dtype: QuantDtype,
) -> FerrotorchResult<QuantizedTensor> {
    let data = tensor.data()?;
    let shape = tensor.shape().to_vec();
    let numel = tensor.numel();
    let (qmin, qmax) = (dtype.qmin(), dtype.qmax());
    let is_unsigned = dtype == QuantDtype::Uint8;
    match scheme {
        QuantScheme::PerTensor => {
            // Single min/max sweep over every element.
            let mut lo = f32::INFINITY;
            let mut hi = f32::NEG_INFINITY;
            for f in data.iter().map(|v| v.to_f32().unwrap()) {
                if f < lo {
                    lo = f;
                }
                if f > hi {
                    hi = f;
                }
            }
            let (scale, zp) = compute_scale_zp(lo, hi, dtype);
            let qdata: Vec<i8> = data
                .iter()
                .map(|v| quantize_val(v.to_f32().unwrap(), scale, zp, qmin, qmax, is_unsigned))
                .collect();
            Ok(QuantizedTensor {
                data: qdata,
                scale: vec![scale],
                zero_point: vec![zp],
                shape,
                scheme,
                dtype,
            })
        }
        QuantScheme::PerChannel(axis) => {
            if axis >= shape.len() {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "PerChannel axis {axis} out of range for {}-d tensor",
                        shape.len()
                    ),
                });
            }
            // Track min/max independently for each channel along `axis`.
            let num_channels = shape[axis];
            let mut mins = vec![f32::INFINITY; num_channels];
            let mut maxs = vec![f32::NEG_INFINITY; num_channels];
            for (i, v) in data.iter().enumerate() {
                let ch = channel_index(i, &shape, axis);
                let f = v.to_f32().unwrap();
                if f < mins[ch] {
                    mins[ch] = f;
                }
                if f > maxs[ch] {
                    maxs[ch] = f;
                }
            }
            // One (scale, zero-point) pair per channel.
            let mut scales = Vec::with_capacity(num_channels);
            let mut zps = Vec::with_capacity(num_channels);
            for (&mn, &mx) in mins.iter().zip(maxs.iter()) {
                let (s, z) = compute_scale_zp(mn, mx, dtype);
                scales.push(s);
                zps.push(z);
            }
            // Quantize each element with its channel's parameters.
            let mut qdata = Vec::with_capacity(numel);
            for (i, v) in data.iter().enumerate() {
                let ch = channel_index(i, &shape, axis);
                let f = v.to_f32().unwrap();
                qdata.push(quantize_val(f, scales[ch], zps[ch], qmin, qmax, is_unsigned));
            }
            Ok(QuantizedTensor {
                data: qdata,
                scale: scales,
                zero_point: zps,
                shape,
                scheme,
                dtype,
            })
        }
    }
}
/// Reconstruct a float tensor from a quantized one via the affine map
/// `real = (q - zero_point) * scale`, using per-tensor or per-channel
/// parameters according to the stored scheme.
pub fn dequantize<T: Float>(qtensor: &QuantizedTensor) -> FerrotorchResult<Tensor<T>> {
    let is_unsigned = qtensor.dtype == QuantDtype::Uint8;
    let result: Vec<T> = match qtensor.scheme {
        QuantScheme::PerTensor => {
            let (scale, zp) = (qtensor.scale[0], qtensor.zero_point[0]);
            qtensor
                .data
                .iter()
                .map(|&q| {
                    let real = (stored_to_i32(q, is_unsigned) - zp) as f32 * scale;
                    T::from(real).unwrap()
                })
                .collect()
        }
        QuantScheme::PerChannel(axis) => qtensor
            .data
            .iter()
            .enumerate()
            .map(|(i, &q)| {
                // Look up this element's channel parameters.
                let ch = channel_index(i, &qtensor.shape, axis);
                let real = (stored_to_i32(q, is_unsigned) - qtensor.zero_point[ch]) as f32
                    * qtensor.scale[ch];
                T::from(real).unwrap()
            })
            .collect(),
    };
    Tensor::from_storage(TensorStorage::cpu(result), qtensor.shape.clone(), false)
}
/// Integer matrix multiply of two per-tensor-quantized 2-D tensors.
///
/// The product is computed entirely in integer arithmetic — each output cell
/// accumulates `(qa - zp_a) * (qb - zp_b)` over the inner dimension — and the
/// real-valued result is then requantized into a fresh per-tensor `Int8`
/// tensor whose scale/zero-point are chosen from the output's min/max.
///
/// # Errors
/// Returns `InvalidArgument` if either input is not 2-D or not per-tensor
/// quantized, and `ShapeMismatch` if the inner dimensions differ.
pub fn quantized_matmul(
    a: &QuantizedTensor,
    b: &QuantizedTensor,
) -> FerrotorchResult<QuantizedTensor> {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "quantized_matmul requires 2-D tensors, got shapes {:?} and {:?}",
                a.shape, b.shape
            ),
        });
    }
    let m = a.shape[0];
    let k = a.shape[1];
    let k2 = b.shape[0];
    let n = b.shape[1];
    if k != k2 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "quantized_matmul inner dimensions mismatch: [{m}, {k}] x [{k2}, {n}]"
            ),
        });
    }
    if a.scale.len() != 1 || b.scale.len() != 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: "quantized_matmul currently requires PerTensor-quantized inputs".into(),
        });
    }
    let a_scale = a.scale[0];
    let a_zp = a.zero_point[0];
    let b_scale = b.scale[0];
    let b_zp = b.zero_point[0];
    let a_unsigned = a.dtype == QuantDtype::Uint8;
    let b_unsigned = b.dtype == QuantDtype::Uint8;
    // Accumulate in i64: each product is bounded by 255 * 255, so an i32
    // accumulator can overflow once k exceeds ~33 000 (silently wrapping in
    // release builds, panicking in debug). i64 is safe for any realistic k.
    let mut acc = vec![0i64; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0i64;
            for p in 0..k {
                let qa = (stored_to_i32(a.data[i * k + p], a_unsigned) - a_zp) as i64;
                let qb = (stored_to_i32(b.data[p * n + j], b_unsigned) - b_zp) as i64;
                sum += qa * qb;
            }
            acc[i * n + j] = sum;
        }
    }
    // Real value of each accumulated cell is acc * (a_scale * b_scale).
    let combined_scale = a_scale * b_scale;
    let mut out_min = f32::INFINITY;
    let mut out_max = f32::NEG_INFINITY;
    for &a_val in &acc {
        let real = a_val as f32 * combined_scale;
        if real < out_min {
            out_min = real;
        }
        if real > out_max {
            out_max = real;
        }
    }
    // Requantize the output to per-tensor Int8.
    let out_dtype = QuantDtype::Int8;
    let (out_scale, out_zp) = compute_scale_zp(out_min, out_max, out_dtype);
    let qmin = out_dtype.qmin();
    let qmax = out_dtype.qmax();
    let qdata: Vec<i8> = acc
        .iter()
        .map(|&a_val| {
            let real = a_val as f32 * combined_scale;
            quantize_val(real, out_scale, out_zp, qmin, qmax, false)
        })
        .collect();
    Ok(QuantizedTensor {
        data: qdata,
        scale: vec![out_scale],
        zero_point: vec![out_zp],
        shape: vec![m, n],
        scheme: QuantScheme::PerTensor,
        dtype: out_dtype,
    })
}
/// Quantize every tensor in a name → tensor collection with the same scheme
/// and dtype, returning a name → quantized-tensor map. Short-circuits on the
/// first tensor that fails to quantize.
pub fn quantize_named_tensors<T: Float>(
    named_tensors: impl IntoIterator<Item = (String, Tensor<T>)>,
    scheme: QuantScheme,
    dtype: QuantDtype,
) -> FerrotorchResult<HashMap<String, QuantizedTensor>> {
    named_tensors
        .into_iter()
        .map(|(name, tensor)| quantize(&tensor, scheme, dtype).map(|qt| (name, qt)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience: build an f32 tensor from raw data and a shape.
    fn make_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        crate::from_slice(data, shape).unwrap()
    }

    // Per-tensor Int8: data spanning [-5, 5] should round-trip within
    // roughly half a quantization step.
    #[test]
    fn test_per_tensor_int8_roundtrip() {
        let data: Vec<f32> = (-10..=10).map(|x| x as f32 * 0.5).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        assert_eq!(rt.shape(), t.shape());
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 0.05,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // Per-tensor Uint8 on non-negative data.
    #[test]
    fn test_per_tensor_uint8_roundtrip() {
        let data: Vec<f32> = (0..=20).map(|x| x as f32 * 0.1).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Uint8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 0.02,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // Int4 has only 16 levels, so the error tolerance is much looser.
    #[test]
    fn test_per_tensor_int4_roundtrip() {
        let data: Vec<f32> = (-8..=7).map(|x| x as f32).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int4).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 1.01,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // Rows with very different ranges: per-channel params should keep the
    // per-element error small even though the global range is wide.
    #[test]
    fn test_per_channel_int8_roundtrip() {
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            0.0, 1.0, 2.0, 3.0,
            -10.0, -5.0, 5.0, 10.0,
            100.0, 130.0, 170.0, 200.0,
        ];
        let t = make_tensor(&data, &[3, 4]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        // One (scale, zero-point) pair per row.
        assert_eq!(qt.scale.len(), 3);
        assert_eq!(qt.zero_point.len(), 3);
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 0.5,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // An axis beyond the tensor's rank must be rejected, not panic.
    #[test]
    fn test_per_channel_axis_out_of_bounds() {
        let t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let result = quantize(&t, QuantScheme::PerChannel(5), QuantDtype::Int8);
        assert!(result.is_err());
    }

    // A * I should reproduce A (within quantization error).
    #[test]
    fn test_quantized_matmul_identity() {
        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let a = make_tensor(&a_data, &[2, 2]);
        let eye = make_tensor(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);
        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qi = quantize(&eye, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qi).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        let c_data = c.data().unwrap();
        for (i, (&expected, &got)) in a_data.iter().zip(c_data.iter()).enumerate() {
            let err = (expected - got).abs();
            assert!(
                err < 0.5,
                "element {i}: expected={expected}, got={got}, error={err}"
            );
        }
    }

    // Compare against the exact float matmul result; tolerance accounts for
    // quantization of inputs and requantization of the output.
    #[test]
    fn test_quantized_matmul_correctness() {
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);
        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qb).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();
        let expected = [58.0f32, 64.0, 139.0, 154.0];
        let c_data = c.data().unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        for (i, (&e, &g)) in expected.iter().zip(c_data.iter()).enumerate() {
            let err = (e - g).abs();
            assert!(
                err < 3.0,
                "element {i}: expected={e}, got={g}, error={err}"
            );
        }
    }

    // Inner-dimension mismatch ([2,3] x [2,2]) must error.
    #[test]
    fn test_quantized_matmul_shape_mismatch() {
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    // 1-D inputs must be rejected.
    #[test]
    fn test_quantized_matmul_non_2d() {
        let a = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let b = make_tensor(&[4.0, 5.0, 6.0], &[3]);
        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    // Batch helper: every named tensor comes back quantized under its name.
    #[test]
    fn test_quantize_named_tensors() {
        let w1 = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let w2 = make_tensor(&[-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], &[3, 2]);
        let named = vec![
            ("layer.weight".to_string(), w1),
            ("layer2.weight".to_string(), w2),
        ];
        let qmap =
            quantize_named_tensors(named, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        assert_eq!(qmap.len(), 2);
        assert!(qmap.contains_key("layer.weight"));
        assert!(qmap.contains_key("layer2.weight"));
        assert_eq!(qmap["layer.weight"].shape(), &[2, 2]);
        assert_eq!(qmap["layer2.weight"].shape(), &[3, 2]);
    }

    // A constant tensor has zero range; the EPSILON floor in
    // compute_scale_zp must keep the round trip sane.
    #[test]
    fn test_quantize_constant_tensor() {
        let t = make_tensor(&[5.0, 5.0, 5.0, 5.0], &[4]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        let recovered = rt.data().unwrap();
        for &r in recovered {
            assert!(
                (r - 5.0).abs() < 0.1,
                "constant tensor dequantized to {r}, expected 5.0"
            );
        }
    }

    // Degenerate single-element tensor round trip.
    #[test]
    fn test_quantize_single_element() {
        let t = make_tensor(&[42.0], &[1]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        assert!((rt.data().unwrap()[0] - 42.0).abs() < 0.5);
    }

    // Per-channel with the low-precision Int4 dtype.
    #[test]
    fn test_per_channel_int4() {
        let data = vec![0.0, 1.0, 2.0, -4.0, 0.0, 4.0];
        let t = make_tensor(&data, &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int4).unwrap();
        assert_eq!(qt.scale.len(), 2);
        assert_eq!(qt.zero_point.len(), 2);
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 1.0,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // dequantize is generic over Float: quantize from f32, recover into f64.
    #[test]
    fn test_dequantize_f64() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let t = crate::from_slice(&data, &[4]).unwrap();
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f64> = dequantize(&qt).unwrap();
        assert_eq!(rt.shape(), &[4]);
        let recovered = rt.data().unwrap();
        for (i, &r) in recovered.iter().enumerate() {
            let expected = data[i] as f64;
            let err = (expected - r).abs();
            assert!(
                err < 0.05,
                "element {i}: expected={expected}, recovered={r}, error={err}"
            );
        }
    }

    // Smoke-test every public accessor on QuantizedTensor.
    #[test]
    fn test_quantized_tensor_accessors() {
        let t = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        assert_eq!(qt.numel(), 6);
        assert_eq!(qt.shape(), &[2, 3]);
        assert_eq!(qt.data().len(), 6);
        assert_eq!(qt.scale().len(), 1);
        assert_eq!(qt.zero_point().len(), 1);
        assert_eq!(qt.scheme(), QuantScheme::PerTensor);
        assert_eq!(qt.qdtype(), QuantDtype::Int8);
    }
}