use crate::{CnnError, CnnResult, Tensor};
use super::Linear;
/// Fully-connected layer with INT8 weights using symmetric
/// per-output-channel quantization.
///
/// Weights are stored row-major as `[out_features, in_features]`; row `of`
/// dequantizes as `weights_q[of * in_features + i] as f32 * weight_scales[of]`.
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    /// Quantized weights, clamped to `[-127, 127]`.
    weights_q: Vec<i8>,
    /// Per-output-channel weight scale: `max |w| / 127`, or `1.0` for an
    /// all-zero row.
    weight_scales: Vec<f32>,
    /// Per-output-channel sum of `weights_q`, cached for the input
    /// zero-point correction so `forward_int8` does not recompute it per call.
    weight_sums: Vec<i32>,
    /// Bias quantized at scale `calibration_input_scale * weight_scales[of]`.
    bias_q: Vec<i32>,
    /// FP32 bias (zeros when the source layer has none), kept so the bias can
    /// be re-quantized when `forward_int8` is called with another input scale.
    bias_f32: Vec<f32>,
    /// Input scale that `bias_q` was quantized for.
    calibration_input_scale: f32,
    in_features: usize,
    out_features: usize,
}

impl QuantizedLinear {
    /// Quantizes an FP32 [`Linear`] layer.
    ///
    /// Each output row gets its own scale `max |w| / 127` (`1.0` if the row is
    /// all zeros), and its weights are rounded to the nearest step and clamped
    /// to `[-127, 127]`. The bias (zeros if `linear` has none) is
    /// pre-quantized to `i32` for `input_scale`; `forward_int8` reuses that
    /// cached bias only when it is called with the same input scale.
    pub fn from_fp32(linear: &Linear, input_scale: f32) -> Self {
        let in_features = linear.in_features();
        let out_features = linear.out_features();
        let weights = linear.weight();

        let weight_scales: Vec<f32> = (0..out_features)
            .map(|of| {
                let max_abs = (0..in_features)
                    .map(|if_| weights[of * in_features + if_].abs())
                    .fold(0.0f32, f32::max);
                // An all-zero row would yield a zero scale; 1.0 keeps the
                // division below well-defined (every weight still maps to 0).
                if max_abs > 0.0 {
                    max_abs / 127.0
                } else {
                    1.0
                }
            })
            .collect();

        let mut weights_q = vec![0i8; weights.len()];
        for (of, &scale) in weight_scales.iter().enumerate() {
            for if_ in 0..in_features {
                let idx = of * in_features + if_;
                weights_q[idx] = (weights[idx] / scale).round().clamp(-127.0, 127.0) as i8;
            }
        }

        // Range slices (not `chunks_exact`) so `in_features == 0` is harmless.
        let weight_sums: Vec<i32> = (0..out_features)
            .map(|of| {
                weights_q[of * in_features..(of + 1) * in_features]
                    .iter()
                    .map(|&w| i32::from(w))
                    .sum()
            })
            .collect();

        let bias_f32 = linear
            .bias()
            .map(|b| b.to_vec())
            .unwrap_or_else(|| vec![0.0; out_features]);
        let bias_q = Self::quantize_bias(&bias_f32, &weight_scales, input_scale);

        Self {
            weights_q,
            weight_scales,
            weight_sums,
            bias_q,
            bias_f32,
            calibration_input_scale: input_scale,
            in_features,
            out_features,
        }
    }

    /// Quantizes `bias` to `i32` at scale `input_scale * weight_scales[of]`,
    /// the scale of output channel `of`'s integer accumulator. Channels whose
    /// combined scale is not strictly positive (including NaN) get `0`.
    fn quantize_bias(bias: &[f32], weight_scales: &[f32], input_scale: f32) -> Vec<i32> {
        bias.iter()
            .zip(weight_scales)
            .map(|(&b, &ws)| {
                let combined_scale = input_scale * ws;
                if combined_scale > 0.0 {
                    (b / combined_scale).round() as i32
                } else {
                    0
                }
            })
            .collect()
    }

    /// Runs the layer on a batch of asymmetric-quantized `u8` activations.
    ///
    /// `input` is row-major `[batch_size, in_features]`; each byte `q`
    /// represents `(q - input_zero_point) * input_scale`. Returns an FP32
    /// tensor of shape `[batch_size, out_features]`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-shape error if `input.len()` is not
    /// `batch_size * in_features`; also propagates any error from
    /// `Tensor::from_data`.
    pub fn forward_int8(
        &self,
        input: &[u8],
        batch_size: usize,
        input_scale: f32,
        input_zero_point: u8,
    ) -> CnnResult<Tensor> {
        if input.len() != batch_size * self.in_features {
            return Err(CnnError::invalid_shape(
                format!("input size {}", batch_size * self.in_features),
                format!("size {}", input.len())
            ));
        }

        // The accumulator is dequantized with `input_scale * weight_scales[of]`,
        // so the bias must be quantized at that same scale. The cached bias is
        // only valid for the calibration scale; otherwise re-quantize it.
        let requantized_bias: Vec<i32>;
        let bias_q: &[i32] = if input_scale == self.calibration_input_scale {
            &self.bias_q
        } else {
            requantized_bias =
                Self::quantize_bias(&self.bias_f32, &self.weight_scales, input_scale);
            &requantized_bias
        };

        // Dequantization factor for each output channel's accumulator.
        let output_scales: Vec<f32> = self
            .weight_scales
            .iter()
            .map(|&ws| input_scale * ws)
            .collect();
        let zero_point = i32::from(input_zero_point);

        let mut output = Vec::with_capacity(batch_size * self.out_features);
        for b in 0..batch_size {
            let x_row = &input[b * self.in_features..(b + 1) * self.in_features];
            for of in 0..self.out_features {
                let w_row = &self.weights_q[of * self.in_features..(of + 1) * self.in_features];
                // Each term is bounded by 255 * 127; the i32 accumulator follows
                // the usual INT8 GEMM convention.
                let dot: i32 = x_row
                    .iter()
                    .zip(w_row)
                    .map(|(&x, &w)| i32::from(x) * i32::from(w))
                    .sum();
                // sum((x - zp) * w) == sum(x * w) - zp * sum(w)
                let acc = bias_q[of] - zero_point * self.weight_sums[of] + dot;
                output.push(acc as f32 * output_scales[of]);
            }
        }

        Tensor::from_data(output, &[batch_size, self.out_features])
    }

    /// Number of input features per sample.
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Number of output features per sample.
    pub fn out_features(&self) -> usize {
        self.out_features
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_linear_creation() {
        let linear = Linear::new(128, 64, true).unwrap();
        let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
        assert_eq!(qlinear.in_features(), 128);
        assert_eq!(qlinear.out_features(), 64);
    }

    /// Every quantized weight must lie in `[-127, 127]` and dequantize to
    /// within half a quantization step of the original FP32 weight.
    #[test]
    fn test_quantized_weights_within_half_step() {
        let (in_f, out_f) = (16, 8);
        let linear = Linear::new(in_f, out_f, true).unwrap();
        let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
        let weights = linear.weight();
        for of in 0..out_f {
            let scale = qlinear.weight_scales[of];
            assert!(scale > 0.0, "channel {of} has non-positive scale {scale}");
            for if_ in 0..in_f {
                let idx = of * in_f + if_;
                let q = qlinear.weights_q[idx];
                assert!(q >= -127, "weight {idx} quantized to {q}");
                let err = (f32::from(q) * scale - weights[idx]).abs();
                assert!(err <= 0.501 * scale, "weight {idx}: error {err} > half step of {scale}");
            }
        }
    }

    #[test]
    fn test_quantized_linear_forward() {
        let linear = Linear::new(32, 16, true).unwrap();
        let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
        let batch_size = 4;
        let input = vec![128u8; batch_size * 32];
        let output = qlinear.forward_int8(&input, batch_size, 0.01, 128).unwrap();
        assert_eq!(output.shape(), &[batch_size, 16]);
    }

    /// Forward must also work when the runtime input scale differs from the
    /// one passed to `from_fp32`.
    #[test]
    fn test_quantized_linear_forward_other_input_scale() {
        let linear = Linear::new(8, 4, true).unwrap();
        let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
        let input = vec![100u8; 2 * 8];
        let output = qlinear.forward_int8(&input, 2, 0.05, 128).unwrap();
        assert_eq!(output.shape(), &[2, 4]);
    }

    #[test]
    fn test_forward_rejects_mismatched_input_len() {
        let linear = Linear::new(8, 4, true).unwrap();
        let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
        let input = vec![0u8; 7];
        assert!(qlinear.forward_int8(&input, 1, 0.01, 0).is_err());
    }

    #[test]
    fn test_quantized_linear_zero_point_correction() {
        let linear = Linear::new(8, 4, true).unwrap();
        let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
        let input = vec![200u8; 8];
        let output = qlinear.forward_int8(&input, 1, 0.01, 128).unwrap();
        // NOTE(review): only the shape is asserted; checking the corrected
        // values needs a Tensor data accessor not visible from this module.
        assert_eq!(output.shape(), &[1, 4]);
    }
}