use crate::error::{NpuError, Result};
use ndarray::ArrayD;
#[derive(Debug, Clone)]
pub struct QuantStats {
pub min_val: f32,
pub max_val: f32,
pub mean_val: f32,
pub std_val: f32,
}
impl QuantStats {
pub fn from_tensor(data: &ArrayD<f32>) -> Self {
let values: Vec<f32> = data.iter().cloned().collect();
let min_val = values.iter().cloned().fold(f32::INFINITY, f32::min);
let max_val = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mean_val = values.iter().sum::<f32>() / values.len() as f32;
let variance = values.iter()
.map(|v| (v - mean_val).powi(2))
.sum::<f32>() / values.len() as f32;
let std_val = variance.sqrt();
Self {
min_val,
max_val,
mean_val,
std_val,
}
}
pub fn get_scale(&self, num_bits: u32) -> f32 {
let levels = (1u64 << num_bits) as f32;
self.max_val / (levels - 1.0)
}
pub fn get_zero_point(&self, num_bits: u32, signed: bool) -> i32 {
let levels = (1u64 << num_bits) as f32;
if signed {
(-(levels / 2.0)) as i32
} else {
0
}
}
}
pub struct QuantConverter {
scale: f32,
zero_point: i32,
num_bits: u32,
}
impl QuantConverter {
pub fn new(stats: &QuantStats, num_bits: u32, signed: bool) -> Self {
Self {
scale: stats.get_scale(num_bits),
zero_point: stats.get_zero_point(num_bits, signed),
num_bits,
}
}
pub fn quantize(&self, value: f32) -> i32 {
let quantized = (value / self.scale) as i32 + self.zero_point;
let max_val = (1i64 << self.num_bits) as i32 - 1;
let min_val = -(1i64 << (self.num_bits - 1)) as i32;
quantized.max(min_val).min(max_val)
}
pub fn dequantize(&self, value: i32) -> f32 {
((value - self.zero_point) as f32) * self.scale
}
pub fn quantize_tensor(&self, tensor: &ArrayD<f32>) -> Result<Vec<i32>> {
Ok(tensor.iter().map(|&v| self.quantize(v)).collect())
}
pub fn dequantize_tensor(&self, quantized: &[i32]) -> Result<ArrayD<f32>> {
let values: Vec<f32> = quantized.iter().map(|&v| self.dequantize(v)).collect();
Ok(ArrayD::from_shape_vec(
ndarray::IxDyn(&[quantized.len()]),
values,
).map_err(|_| NpuError::InvalidShape("Failed to reshape".to_string()))?)
}
}
pub struct PTQEngine {
num_bits: u32,
signed: bool,
}
impl PTQEngine {
pub fn new(num_bits: u32, signed: bool) -> Self {
Self { num_bits, signed }
}
pub fn calibrate(&self, sample_data: &[ArrayD<f32>]) -> Result<QuantConverter> {
if sample_data.is_empty() {
return Err(NpuError::InvalidConfiguration(
"No calibration data provided".to_string(),
));
}
let mut all_values = Vec::new();
for tensor in sample_data {
all_values.extend(tensor.iter().cloned());
}
let combined = ArrayD::from_shape_vec(
ndarray::IxDyn(&[all_values.len()]),
all_values,
).map_err(|_| NpuError::InvalidShape("Failed to calibrate".to_string()))?;
let stats = QuantStats::from_tensor(&combined);
Ok(QuantConverter::new(&stats, self.num_bits, self.signed))
}
}