1use std::fmt;
11
/// Strategy used to map a floating-point range onto the integer grid.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the scheme can be
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantScheme {
    /// Range is treated as centred on zero; the zero point is always `0`.
    Symmetric,
    /// `[min, max]` is mapped onto the whole integer interval using a
    /// non-trivial zero point.
    Asymmetric,
}
20
/// Integer width of the quantized representation.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the width can be
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantBits {
    /// Signed 8-bit levels in `[-128, 127]`.
    Int8,
    /// Signed 4-bit levels in `[-8, 7]`.
    Int4,
}
29
30#[derive(Debug, Clone)]
32pub struct QuantParams {
33 pub scale: f32,
34 pub zero_point: i32,
35 pub bits: QuantBits,
36 pub scheme: QuantScheme,
37 pub per_channel_scales: Option<Vec<f32>>,
39}
40
41impl QuantParams {
42 pub fn from_range(min_val: f32, max_val: f32, bits: QuantBits, scheme: QuantScheme) -> Self {
44 let (qmin, qmax) = match bits {
45 QuantBits::Int8 => (-128i32, 127i32),
46 QuantBits::Int4 => (-8i32, 7i32),
47 };
48
49 match scheme {
50 QuantScheme::Symmetric => {
51 let abs_max = min_val.abs().max(max_val.abs());
52 let scale = abs_max / qmax as f32;
53 Self {
54 scale: if scale == 0.0 { 1.0 } else { scale },
55 zero_point: 0,
56 bits,
57 scheme,
58 per_channel_scales: None,
59 }
60 }
61 QuantScheme::Asymmetric => {
62 let range = max_val - min_val;
63 let scale = range / (qmax - qmin) as f32;
64 let zero_point = (qmin as f32 - min_val / scale).round() as i32;
65 Self {
66 scale: if scale == 0.0 { 1.0 } else { scale },
67 zero_point: zero_point.clamp(qmin, qmax),
68 bits,
69 scheme,
70 per_channel_scales: None,
71 }
72 }
73 }
74 }
75
76 pub fn calibrate(data: &[f32], bits: QuantBits, scheme: QuantScheme) -> Self {
78 if data.is_empty() {
79 return Self { scale: 1.0, zero_point: 0, bits, scheme, per_channel_scales: None };
80 }
81 let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
82 let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
83 Self::from_range(min_val, max_val, bits, scheme)
84 }
85}
86
87pub fn quantize_int8(data: &[f32], params: &QuantParams) -> Vec<i8> {
89 data.iter().map(|&x| {
90 let q = (x / params.scale).round() as i32 + params.zero_point;
91 q.clamp(-128, 127) as i8
92 }).collect()
93}
94
95pub fn dequantize_int8(data: &[i8], params: &QuantParams) -> Vec<f32> {
97 data.iter().map(|&q| {
98 (q as i32 - params.zero_point) as f32 * params.scale
99 }).collect()
100}
101
102pub fn quantize_int4(data: &[f32], params: &QuantParams) -> Vec<u8> {
104 let mut packed = Vec::with_capacity((data.len() + 1) / 2);
105 for chunk in data.chunks(2) {
106 let lo = {
107 let q = (chunk[0] / params.scale).round() as i32 + params.zero_point;
108 (q.clamp(-8, 7) & 0x0F) as u8
109 };
110 let hi = if chunk.len() > 1 {
111 let q = (chunk[1] / params.scale).round() as i32 + params.zero_point;
112 ((q.clamp(-8, 7) & 0x0F) as u8) << 4
113 } else {
114 0
115 };
116 packed.push(lo | hi);
117 }
118 packed
119}
120
121pub fn dequantize_int4(data: &[u8], count: usize, params: &QuantParams) -> Vec<f32> {
123 let mut result = Vec::with_capacity(count);
124 for &byte in data {
125 if result.len() >= count { break; }
126 let lo = (byte & 0x0F) as i8;
128 let lo = if lo & 0x08 != 0 { lo | !0x0F_u8 as i8 } else { lo }; result.push((lo as i32 - params.zero_point) as f32 * params.scale);
130
131 if result.len() >= count { break; }
132 let hi = ((byte >> 4) & 0x0F) as i8;
134 let hi = if hi & 0x08 != 0 { hi | !0x0F_u8 as i8 } else { hi };
135 result.push((hi as i32 - params.zero_point) as f32 * params.scale);
136 }
137 result
138}
139
140pub fn quantized_gemm_int8(
143 a: &[i8], b: &[i8],
144 m: usize, k: usize, n: usize,
145 a_params: &QuantParams, b_params: &QuantParams,
146) -> Vec<f32> {
147 let mut c = vec![0i32; m * n];
148 for i in 0..m {
149 for p in 0..k {
150 let a_val = a[i * k + p] as i32 - a_params.zero_point;
151 for j in 0..n {
152 let b_val = b[p * n + j] as i32 - b_params.zero_point;
153 c[i * n + j] += a_val * b_val;
154 }
155 }
156 }
157 let output_scale = a_params.scale * b_params.scale;
159 c.iter().map(|&v| v as f32 * output_scale).collect()
160}
161
162pub fn quantization_error(original: &[f32], params: &QuantParams) -> QuantError {
164 let quantized = quantize_int8(original, params);
165 let dequantized = dequantize_int8(&quantized, params);
166
167 let mse: f64 = original.iter().zip(dequantized.iter())
168 .map(|(&o, &d)| ((o - d) as f64).powi(2))
169 .sum::<f64>() / original.len() as f64;
170
171 let max_error = original.iter().zip(dequantized.iter())
172 .map(|(&o, &d)| (o - d).abs())
173 .fold(0.0f32, f32::max);
174
175 let signal_power: f64 = original.iter().map(|&x| (x as f64).powi(2)).sum::<f64>() / original.len() as f64;
176 let snr = if mse > 0.0 { 10.0 * (signal_power / mse).log10() } else { f64::INFINITY };
177
178 QuantError {
179 mse: mse as f32,
180 max_error,
181 snr_db: snr as f32,
182 compression_ratio: match params.bits {
183 QuantBits::Int8 => 4.0, QuantBits::Int4 => 8.0, },
186 }
187}
188
/// Summary of the error introduced by a quantize/dequantize round trip.
///
/// All fields are plain `f32`, so the type is `Copy` and comparable with
/// `==` (usual float semantics: a NaN field never compares equal).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantError {
    /// Mean squared difference between original and restored values.
    pub mse: f32,
    /// Largest absolute difference between an original and a restored value.
    pub max_error: f32,
    /// Signal-to-noise ratio in decibels.
    pub snr_db: f32,
    /// Size reduction factor; `quantization_error` reports `4.0` for `Int8`
    /// and `8.0` for `Int4` (presumably relative to 32-bit floats).
    pub compression_ratio: f32,
}
197
198impl fmt::Display for QuantError {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 write!(f, "QuantError: MSE={:.6}, MaxErr={:.4}, SNR={:.1}dB, {}x compression",
201 self.mse, self.max_error, self.snr_db, self.compression_ratio)
202 }
203}
204
// Unit tests for calibration, int8/int4 round trips, the int8 GEMM kernel and
// the error report.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symmetric_int8_roundtrip() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
        let quantized = quantize_int8(&data, &params);
        let dequantized = dequantize_int8(&quantized, &params);

        for (i, (orig, deq)) in data.iter().zip(dequantized.iter()).enumerate() {
            assert!((orig - deq).abs() < 0.02,
                "Mismatch at {}: original={}, dequantized={}", i, orig, deq);
        }
    }

    #[test]
    fn test_asymmetric_int8() {
        let data = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Asymmetric);
        let quantized = quantize_int8(&data, &params);
        let dequantized = dequantize_int8(&quantized, &params);

        for (i, (orig, deq)) in data.iter().zip(dequantized.iter()).enumerate() {
            assert!((orig - deq).abs() < 0.02,
                "Asymmetric mismatch at {}: {} vs {}", i, orig, deq);
        }
    }

    #[test]
    fn test_int4_quantization() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0, 1.5];
        let params = QuantParams::calibrate(&data, QuantBits::Int4, QuantScheme::Symmetric);
        let packed = quantize_int4(&data, &params);
        let dequantized = dequantize_int4(&packed, data.len(), &params);

        assert_eq!(dequantized.len(), data.len());
        // Four bits give only 16 levels, hence the much looser tolerance.
        for (i, (orig, deq)) in data.iter().zip(dequantized.iter()).enumerate() {
            assert!((orig - deq).abs() < 0.5,
                "INT4 mismatch at {}: {} vs {}", i, orig, deq);
        }
    }

    #[test]
    fn test_int4_packing() {
        let data = vec![0.0, 0.0, 0.0, 0.0];
        let params = QuantParams::from_range(-1.0, 1.0, QuantBits::Int4, QuantScheme::Symmetric);
        let packed = quantize_int4(&data, &params);
        // Two values per byte.
        assert_eq!(packed.len(), 2);
    }

    #[test]
    fn test_quantized_gemm() {
        // Row-major 2x2 operands; the exact product is [[19, 22], [43, 50]].
        let a_f32 = vec![1.0f32, 2.0, 3.0, 4.0];
        let b_f32 = vec![5.0f32, 6.0, 7.0, 8.0];

        let a_params = QuantParams::calibrate(&a_f32, QuantBits::Int8, QuantScheme::Symmetric);
        let b_params = QuantParams::calibrate(&b_f32, QuantBits::Int8, QuantScheme::Symmetric);

        let a_q = quantize_int8(&a_f32, &a_params);
        let b_q = quantize_int8(&b_f32, &b_params);

        let c = quantized_gemm_int8(&a_q, &b_q, 2, 2, 2, &a_params, &b_params);
        assert!((c[0] - 19.0).abs() < 1.0, "Got {}", c[0]);
        assert!((c[1] - 22.0).abs() < 1.0, "Got {}", c[1]);
        assert!((c[2] - 43.0).abs() < 1.5, "Got {}", c[2]);
        assert!((c[3] - 50.0).abs() < 1.5, "Got {}", c[3]);
    }

    #[test]
    fn test_quantization_error() {
        let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) / 50.0).collect();
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
        let error = quantization_error(&data, &params);

        assert!(error.mse < 0.001, "MSE too high: {}", error.mse);
        assert!(error.snr_db > 30.0, "SNR too low: {}dB", error.snr_db);
        assert_eq!(error.compression_ratio, 4.0);
    }

    #[test]
    fn test_quantization_error_int4() {
        let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) / 50.0).collect();
        let params = QuantParams::calibrate(&data, QuantBits::Int4, QuantScheme::Symmetric);
        let error = quantization_error(&data, &params);
        assert_eq!(error.compression_ratio, 8.0);
    }

    #[test]
    fn test_zero_range_calibration() {
        let data = vec![0.0, 0.0, 0.0];
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
        // An all-zero range must fall back to a unit scale rather than zero.
        assert_eq!(params.scale, 1.0);
    }

    #[test]
    fn test_empty_calibration() {
        let params = QuantParams::calibrate(&[], QuantBits::Int8, QuantScheme::Symmetric);
        assert_eq!(params.scale, 1.0);
    }

    #[test]
    fn test_quant_error_display() {
        let error = QuantError { mse: 0.001, max_error: 0.01, snr_db: 40.0, compression_ratio: 4.0 };
        let s = format!("{}", error);
        assert!(s.contains("MSE"));
        assert!(s.contains("4x"));
    }
}