1use crate::error::{NpuError, Result};
2use ndarray::ArrayD;
3
4#[derive(Debug, Clone)]
6pub struct QuantStats {
7 pub min_val: f32,
8 pub max_val: f32,
9 pub mean_val: f32,
10 pub std_val: f32,
11}
12
13impl QuantStats {
14 pub fn from_tensor(data: &ArrayD<f32>) -> Self {
16 let values: Vec<f32> = data.iter().cloned().collect();
17
18 let min_val = values.iter().cloned().fold(f32::INFINITY, f32::min);
19 let max_val = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
20
21 let mean_val = values.iter().sum::<f32>() / values.len() as f32;
22 let variance = values.iter()
23 .map(|v| (v - mean_val).powi(2))
24 .sum::<f32>() / values.len() as f32;
25 let std_val = variance.sqrt();
26
27 Self {
28 min_val,
29 max_val,
30 mean_val,
31 std_val,
32 }
33 }
34
35 pub fn get_scale(&self, num_bits: u32) -> f32 {
37 let levels = (1u64 << num_bits) as f32;
38 self.max_val / (levels - 1.0)
39 }
40
41 pub fn get_zero_point(&self, num_bits: u32, signed: bool) -> i32 {
43 let levels = (1u64 << num_bits) as f32;
44 if signed {
45 (-(levels / 2.0)) as i32
46 } else {
47 0
48 }
49 }
50}
51
52pub struct QuantConverter {
54 scale: f32,
55 zero_point: i32,
56 num_bits: u32,
57}
58
59impl QuantConverter {
60 pub fn new(stats: &QuantStats, num_bits: u32, signed: bool) -> Self {
62 Self {
63 scale: stats.get_scale(num_bits),
64 zero_point: stats.get_zero_point(num_bits, signed),
65 num_bits,
66 }
67 }
68
69 pub fn quantize(&self, value: f32) -> i32 {
71 let quantized = (value / self.scale) as i32 + self.zero_point;
72 let max_val = (1i64 << self.num_bits) as i32 - 1;
73 let min_val = -(1i64 << (self.num_bits - 1)) as i32;
74 quantized.max(min_val).min(max_val)
75 }
76
77 pub fn dequantize(&self, value: i32) -> f32 {
79 ((value - self.zero_point) as f32) * self.scale
80 }
81
82 pub fn quantize_tensor(&self, tensor: &ArrayD<f32>) -> Result<Vec<i32>> {
84 Ok(tensor.iter().map(|&v| self.quantize(v)).collect())
85 }
86
87 pub fn dequantize_tensor(&self, quantized: &[i32]) -> Result<ArrayD<f32>> {
89 let values: Vec<f32> = quantized.iter().map(|&v| self.dequantize(v)).collect();
90 Ok(ArrayD::from_shape_vec(
91 ndarray::IxDyn(&[quantized.len()]),
92 values,
93 ).map_err(|_| NpuError::InvalidShape("Failed to reshape".to_string()))?)
94 }
95}
96
97pub struct PTQEngine {
99 num_bits: u32,
100 signed: bool,
101}
102
103impl PTQEngine {
104 pub fn new(num_bits: u32, signed: bool) -> Self {
106 Self { num_bits, signed }
107 }
108
109 pub fn calibrate(&self, sample_data: &[ArrayD<f32>]) -> Result<QuantConverter> {
111 if sample_data.is_empty() {
112 return Err(NpuError::InvalidConfiguration(
113 "No calibration data provided".to_string(),
114 ));
115 }
116
117 let mut all_values = Vec::new();
118 for tensor in sample_data {
119 all_values.extend(tensor.iter().cloned());
120 }
121
122 let combined = ArrayD::from_shape_vec(
123 ndarray::IxDyn(&[all_values.len()]),
124 all_values,
125 ).map_err(|_| NpuError::InvalidShape("Failed to calibrate".to_string()))?;
126
127 let stats = QuantStats::from_tensor(&combined);
128 Ok(QuantConverter::new(&stats, self.num_bits, self.signed))
129 }
130}