1use crate::error::{QuantError, QuantResult};
22
/// How real values are mapped onto integer codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// Signed range centered on zero; the zero-point is always 0.
    Symmetric,
    /// Unsigned range `[0, 2^bits - 1]` with a calibrated zero-point offset.
    Asymmetric,
}
34
/// Granularity at which scale/zero-point pairs are computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantGranularity {
    /// A single scale/zero-point pair for the whole tensor.
    PerTensor,
    /// One pair per channel along `channel_axis`.
    /// NOTE(review): `MinMaxQuantizer::calibrate` currently ignores the axis
    /// and falls back to per-tensor statistics — confirm against callers.
    PerChannel { channel_axis: usize },
    /// One pair per contiguous group of `group_size` elements.
    PerGroup { group_size: usize },
}
45
/// Calibrated quantization parameters.
///
/// `scales` and `zero_points` hold one entry per tensor, per row, or per
/// group, depending on the granularity used during calibration; the two
/// vectors are always the same length when produced by `MinMaxQuantizer`.
#[derive(Debug, Clone)]
pub struct QuantParams {
    // Positive scale factor(s) mapping integer steps back to real values.
    pub scales: Vec<f32>,
    // Integer offset(s); always 0 for the symmetric scheme.
    pub zero_points: Vec<i32>,
    // Bit width of the integer codes; constructors assert it is in [1, 16].
    pub bits: u32,
    // Scheme the parameters were calibrated for.
    pub scheme: QuantScheme,
}
61
62impl QuantParams {
63 #[must_use]
65 pub fn q_max(&self) -> f32 {
66 match self.scheme {
67 QuantScheme::Symmetric => (1 << (self.bits - 1)) as f32 - 1.0,
68 QuantScheme::Asymmetric => (1 << self.bits) as f32 - 1.0,
69 }
70 }
71
72 #[must_use]
74 pub fn q_min(&self) -> f32 {
75 match self.scheme {
76 QuantScheme::Symmetric => -((1 << (self.bits - 1)) as f32),
77 QuantScheme::Asymmetric => 0.0,
78 }
79 }
80}
81
/// Quantizer that calibrates scale/zero-point parameters from the observed
/// min/max (abs-max for the symmetric scheme) of the input data, then
/// quantizes and dequantizes with them.
#[derive(Debug, Clone)]
pub struct MinMaxQuantizer {
    // Code bit width; `new` asserts it is in [1, 16].
    bits: u32,
    // Symmetric or asymmetric mapping.
    scheme: QuantScheme,
    // Per-tensor, per-channel, or per-group parameters.
    granularity: QuantGranularity,
}
91
92impl MinMaxQuantizer {
93 #[must_use]
99 pub fn new(bits: u32, scheme: QuantScheme, granularity: QuantGranularity) -> Self {
100 assert!(bits > 0 && bits <= 16, "bits must be in [1, 16]");
101 Self {
102 bits,
103 scheme,
104 granularity,
105 }
106 }
107
108 #[must_use]
110 pub fn int8_symmetric() -> Self {
111 Self::new(8, QuantScheme::Symmetric, QuantGranularity::PerTensor)
112 }
113
114 #[must_use]
116 pub fn int4_per_group(group_size: usize) -> Self {
117 Self::new(
118 4,
119 QuantScheme::Symmetric,
120 QuantGranularity::PerGroup { group_size },
121 )
122 }
123
124 pub fn calibrate(&self, tensor: &[f32]) -> QuantResult<QuantParams> {
135 if tensor.is_empty() {
136 return Err(QuantError::EmptyInput("MinMaxQuantizer::calibrate"));
137 }
138 match self.granularity {
139 QuantGranularity::PerTensor => self.calibrate_slice(tensor),
140 QuantGranularity::PerChannel { channel_axis: _ } => {
141 self.calibrate_slice(tensor)
145 }
146 QuantGranularity::PerGroup { group_size } => {
147 if tensor.len() % group_size != 0 {
148 return Err(QuantError::GroupSizeMismatch {
149 len: tensor.len(),
150 group: group_size,
151 });
152 }
153 let n_groups = tensor.len() / group_size;
154 let mut scales = Vec::with_capacity(n_groups);
155 let mut zero_points = Vec::with_capacity(n_groups);
156 for chunk in tensor.chunks_exact(group_size) {
157 let p = self.calibrate_slice(chunk)?;
158 scales.push(p.scales[0]);
159 zero_points.push(p.zero_points[0]);
160 }
161 Ok(QuantParams {
162 scales,
163 zero_points,
164 bits: self.bits,
165 scheme: self.scheme,
166 })
167 }
168 }
169 }
170
171 pub fn calibrate_2d(
180 &self,
181 tensor: &[f32],
182 rows: usize,
183 cols: usize,
184 ) -> QuantResult<QuantParams> {
185 if rows == 0 {
186 return Err(QuantError::EmptyInput("calibrate_2d: rows == 0"));
187 }
188 if cols == 0 {
189 return Err(QuantError::DimensionMismatch {
190 expected: 1,
191 got: 0,
192 });
193 }
194 let mut scales = Vec::with_capacity(rows);
195 let mut zero_points = Vec::with_capacity(rows);
196 for row in tensor.chunks_exact(cols) {
197 let p = self.calibrate_slice(row)?;
198 scales.push(p.scales[0]);
199 zero_points.push(p.zero_points[0]);
200 }
201 Ok(QuantParams {
202 scales,
203 zero_points,
204 bits: self.bits,
205 scheme: self.scheme,
206 })
207 }
208
209 fn calibrate_slice(&self, slice: &[f32]) -> QuantResult<QuantParams> {
210 let mut fmin = f32::INFINITY;
211 let mut fmax = f32::NEG_INFINITY;
212 for &v in slice {
213 if v < fmin {
214 fmin = v;
215 }
216 if v > fmax {
217 fmax = v;
218 }
219 }
220 let (scale, zp) = match self.scheme {
221 QuantScheme::Symmetric => {
222 let q_max = (1 << (self.bits - 1)) as f32 - 1.0;
223 let abs_max = fmin.abs().max(fmax.abs()).max(1e-8);
224 (abs_max / q_max, 0_i32)
225 }
226 QuantScheme::Asymmetric => {
227 let q_range = ((1 << self.bits) - 1) as f32;
228 let range = (fmax - fmin).max(1e-8);
229 let scale = range / q_range;
230 let zp = (-fmin / scale).round().clamp(0.0, q_range) as i32;
231 (scale, zp)
232 }
233 };
234 if !scale.is_finite() || scale <= 0.0 {
235 return Err(QuantError::InvalidScale { scale });
236 }
237 Ok(QuantParams {
238 scales: vec![scale],
239 zero_points: vec![zp],
240 bits: self.bits,
241 scheme: self.scheme,
242 })
243 }
244
245 pub fn quantize(&self, tensor: &[f32], params: &QuantParams) -> QuantResult<Vec<i32>> {
253 let scale = params.scales[0];
254 if scale <= 0.0 || !scale.is_finite() {
255 return Err(QuantError::InvalidScale { scale });
256 }
257 let q_max = params.q_max();
258 let q_min = params.q_min();
259 let zp = params.zero_points[0] as f32;
260 let codes = tensor
261 .iter()
262 .map(|&x| {
263 let xq = (x / scale + zp).round().clamp(q_min, q_max);
264 xq as i32
265 })
266 .collect();
267 Ok(codes)
268 }
269
270 pub fn quantize_grouped(
276 &self,
277 tensor: &[f32],
278 params: &QuantParams,
279 group_size: usize,
280 ) -> QuantResult<Vec<i32>> {
281 if tensor.len() % group_size != 0 {
282 return Err(QuantError::GroupSizeMismatch {
283 len: tensor.len(),
284 group: group_size,
285 });
286 }
287 let q_max = params.q_max();
288 let q_min = params.q_min();
289 let mut out = Vec::with_capacity(tensor.len());
290 for (g, chunk) in tensor.chunks_exact(group_size).enumerate() {
291 let scale = params.scales[g];
292 let zp = params.zero_points[g] as f32;
293 for &x in chunk {
294 let xq = (x / scale + zp).round().clamp(q_min, q_max);
295 out.push(xq as i32);
296 }
297 }
298 Ok(out)
299 }
300
301 pub fn dequantize(&self, codes: &[i32], params: &QuantParams) -> Vec<f32> {
303 let scale = params.scales[0];
304 let zp = params.zero_points[0];
305 codes.iter().map(|&q| (q - zp) as f32 * scale).collect()
306 }
307
308 pub fn dequantize_grouped(
310 &self,
311 codes: &[i32],
312 params: &QuantParams,
313 group_size: usize,
314 ) -> Vec<f32> {
315 let mut out = Vec::with_capacity(codes.len());
316 for (g, chunk) in codes.chunks_exact(group_size).enumerate() {
317 let scale = params.scales[g];
318 let zp = params.zero_points[g];
319 for &q in chunk {
320 out.push((q - zp) as f32 * scale);
321 }
322 }
323 out
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Evenly spaced values covering [-1.0, 1.0].
    fn uniform_tensor(n: usize) -> Vec<f32> {
        let mut values = Vec::with_capacity(n);
        for i in 0..n {
            values.push((i as f32 / (n - 1) as f32) * 2.0 - 1.0);
        }
        values
    }

    #[test]
    fn symmetric_calibrate_scale() {
        let quantizer = MinMaxQuantizer::int8_symmetric();
        let data = [-2.0_f32, -1.0, 0.5, 2.0];
        let params = quantizer.calibrate(&data).unwrap();
        // abs-max is 2.0 and int8 symmetric q_max is 127.
        assert_abs_diff_eq!(params.scales[0], 2.0 / 127.0, epsilon = 1e-6);
        assert_eq!(params.zero_points[0], 0);
    }

    #[test]
    fn asymmetric_calibrate_scale_zp() {
        let quantizer =
            MinMaxQuantizer::new(8, QuantScheme::Asymmetric, QuantGranularity::PerTensor);
        let data = [0.0_f32, 1.0, 2.0, 3.0];
        let params = quantizer.calibrate(&data).unwrap();
        // Range 3.0 over 255 steps; min is 0.0 so the zero-point is 0.
        assert_abs_diff_eq!(params.scales[0], 3.0 / 255.0, epsilon = 1e-5);
        assert_eq!(params.zero_points[0], 0);
    }

    #[test]
    fn per_group_calibrate() {
        let quantizer = MinMaxQuantizer::int4_per_group(4);
        let data = [-1.0_f32, 0.0, 0.5, 1.0, -2.0, 0.0, 1.0, 2.0];
        let params = quantizer.calibrate(&data).unwrap();
        // Eight elements in groups of four -> two scale entries.
        assert_eq!(params.scales.len(), 2);
    }

    #[test]
    fn symmetric_round_trip_low_error() {
        let quantizer = MinMaxQuantizer::int8_symmetric();
        let data = uniform_tensor(128);
        let params = quantizer.calibrate(&data).unwrap();
        let codes = quantizer.quantize(&data, &params).unwrap();
        let restored = quantizer.dequantize(&codes, &params);
        let max_err = data
            .iter()
            .zip(restored.iter())
            .fold(0.0_f32, |acc, (a, b)| acc.max((a - b).abs()));
        assert!(
            max_err < 0.02,
            "max quantization error too large: {max_err}"
        );
    }

    #[test]
    fn grouped_round_trip() {
        let quantizer = MinMaxQuantizer::int4_per_group(16);
        let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let params = quantizer.calibrate(&data).unwrap();
        let codes = quantizer.quantize_grouped(&data, &params, 16).unwrap();
        let restored = quantizer.dequantize_grouped(&codes, &params, 16);
        let max_err = data
            .iter()
            .zip(restored.iter())
            .fold(0.0_f32, |acc, (a, b)| acc.max((a - b).abs()));
        assert!(max_err < 0.5, "max per-group error too large: {max_err}");
    }

    #[test]
    fn empty_input_error() {
        let quantizer = MinMaxQuantizer::int8_symmetric();
        assert!(matches!(
            quantizer.calibrate(&[]),
            Err(QuantError::EmptyInput(_))
        ));
    }

    #[test]
    fn group_size_mismatch_error() {
        let quantizer = MinMaxQuantizer::int4_per_group(3);
        // 10 elements cannot be split into groups of 3.
        let data = [1.0_f32; 10];
        assert!(matches!(
            quantizer.calibrate(&data),
            Err(QuantError::GroupSizeMismatch { .. })
        ));
    }

    #[test]
    fn q_max_q_min_int8() {
        let quantizer = MinMaxQuantizer::int8_symmetric();
        let params = quantizer.calibrate(&[1.0_f32]).unwrap();
        assert_abs_diff_eq!(params.q_max(), 127.0, epsilon = 1e-6);
        assert_abs_diff_eq!(params.q_min(), -128.0, epsilon = 1e-6);
    }

    #[test]
    fn calibrate_2d_per_row() {
        let quantizer = MinMaxQuantizer::int8_symmetric();
        // Row 0 spans [-1, 1]; row 1 spans [-2, 2], so its scale is larger.
        let data = [0.0_f32, 1.0, -1.0, 0.5, 0.0, 2.0, -2.0, 1.5];
        let params = quantizer.calibrate_2d(&data, 2, 4).unwrap();
        assert_eq!(params.scales.len(), 2);
        assert!(
            params.scales[1] > params.scales[0],
            "row1 scale should be larger"
        );
    }
}