Skip to main content

oxicuda_quant/scheme/
minmax.rs

1//! # MinMax Quantizer
2//!
3//! Calibrates quantization parameters (scale, zero-point) using the observed
4//! min/max of a tensor or calibration dataset.
5//!
6//! ## Modes
7//!
8//! | Mode | Description |
9//! |------|-------------|
//! | Symmetric | `scale = max(abs(min), abs(max)) / q_max`; zero-point = 0 |
11//! | Asymmetric | `scale = (max - min) / (2^bits - 1)`; zp = `round(-min/scale)` |
12//!
13//! ## Granularity
14//!
15//! | Granularity | Scales computed per |
16//! |-------------|---------------------|
17//! | PerTensor   | whole tensor |
18//! | PerChannel  | each slice along a chosen axis |
19//! | PerGroup    | non-overlapping groups of `group_size` elements |
20
21use crate::error::{QuantError, QuantResult};
22
23// ─── Enums ────────────────────────────────────────────────────────────────────
24
/// Whether the quantization range is centred on zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantScheme {
    /// Scale only; the zero-point is fixed at 0 and codes are signed.
    /// Efficient for symmetric weight distributions (most weight tensors).
    Symmetric,
    /// Scale plus an integer zero-point; codes are unsigned.
    /// Better for non-symmetric activations.
    Asymmetric,
}
34
/// Scope at which quantization parameters are computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantGranularity {
    /// One `(scale, zp)` pair for the whole tensor.
    PerTensor,
    /// One pair per slice along `channel_axis`.
    ///
    /// Note: [`MinMaxQuantizer::calibrate`] receives a flat slice with no shape,
    /// so it treats this like `PerTensor`; use
    /// [`MinMaxQuantizer::calibrate_2d`] to obtain one pair per row.
    PerChannel { channel_axis: usize },
    /// One pair per contiguous block of `group_size` elements (e.g. 128).
    PerGroup { group_size: usize },
}
45
46// ─── QuantParams ─────────────────────────────────────────────────────────────
47
48/// Calibrated quantization parameters.
49#[derive(Debug, Clone)]
50pub struct QuantParams {
51    /// Per-scale values (length 1 for `PerTensor`, length n_channels or n_groups
52    /// for the other modes).
53    pub scales: Vec<f32>,
54    /// Per-zero-point values (always 0 for `Symmetric`).
55    pub zero_points: Vec<i32>,
56    /// Number of quantization bits.
57    pub bits: u32,
58    /// Scheme.
59    pub scheme: QuantScheme,
60}
61
62impl QuantParams {
63    /// Maximum representable integer value for the given bit-width.
64    #[must_use]
65    pub fn q_max(&self) -> f32 {
66        match self.scheme {
67            QuantScheme::Symmetric => (1 << (self.bits - 1)) as f32 - 1.0,
68            QuantScheme::Asymmetric => (1 << self.bits) as f32 - 1.0,
69        }
70    }
71
72    /// Minimum representable integer value.
73    #[must_use]
74    pub fn q_min(&self) -> f32 {
75        match self.scheme {
76            QuantScheme::Symmetric => -((1 << (self.bits - 1)) as f32),
77            QuantScheme::Asymmetric => 0.0,
78        }
79    }
80}
81
82// ─── MinMaxQuantizer ─────────────────────────────────────────────────────────
83
84/// Calibrates quantization parameters using tensor min/max statistics.
85#[derive(Debug, Clone)]
86pub struct MinMaxQuantizer {
87    bits: u32,
88    scheme: QuantScheme,
89    granularity: QuantGranularity,
90}
91
92impl MinMaxQuantizer {
93    /// Create a new quantizer.
94    ///
95    /// # Panics
96    ///
97    /// Panics if `bits` is 0 or > 16.
98    #[must_use]
99    pub fn new(bits: u32, scheme: QuantScheme, granularity: QuantGranularity) -> Self {
100        assert!(bits > 0 && bits <= 16, "bits must be in [1, 16]");
101        Self {
102            bits,
103            scheme,
104            granularity,
105        }
106    }
107
108    /// Standard INT8 symmetric per-tensor quantizer.
109    #[must_use]
110    pub fn int8_symmetric() -> Self {
111        Self::new(8, QuantScheme::Symmetric, QuantGranularity::PerTensor)
112    }
113
114    /// Standard INT4 symmetric per-group quantizer (group = 128, as in GGML).
115    #[must_use]
116    pub fn int4_per_group(group_size: usize) -> Self {
117        Self::new(
118            4,
119            QuantScheme::Symmetric,
120            QuantGranularity::PerGroup { group_size },
121        )
122    }
123
124    /// Calibrate parameters from a flat tensor.
125    ///
126    /// For `PerChannel`, the tensor is assumed to be in row-major layout with
127    /// `n_channels` rows of length `tensor.len() / n_channels`.
128    ///
129    /// # Errors
130    ///
131    /// * [`QuantError::EmptyInput`] — if `tensor` is empty.
132    /// * [`QuantError::GroupSizeMismatch`] — if `PerGroup` size does not divide.
133    /// * [`QuantError::DimensionMismatch`] — if `PerChannel` axis is inconsistent.
134    pub fn calibrate(&self, tensor: &[f32]) -> QuantResult<QuantParams> {
135        if tensor.is_empty() {
136            return Err(QuantError::EmptyInput("MinMaxQuantizer::calibrate"));
137        }
138        match self.granularity {
139            QuantGranularity::PerTensor => self.calibrate_slice(tensor),
140            QuantGranularity::PerChannel { channel_axis: _ } => {
141                // Treat tensor as flat vector of rows; each row = one channel.
142                // We require the caller to reshape correctly before calling.
143                // Here we do a single pass treating each 1-element "channel".
144                self.calibrate_slice(tensor)
145            }
146            QuantGranularity::PerGroup { group_size } => {
147                if tensor.len() % group_size != 0 {
148                    return Err(QuantError::GroupSizeMismatch {
149                        len: tensor.len(),
150                        group: group_size,
151                    });
152                }
153                let n_groups = tensor.len() / group_size;
154                let mut scales = Vec::with_capacity(n_groups);
155                let mut zero_points = Vec::with_capacity(n_groups);
156                for chunk in tensor.chunks_exact(group_size) {
157                    let p = self.calibrate_slice(chunk)?;
158                    scales.push(p.scales[0]);
159                    zero_points.push(p.zero_points[0]);
160                }
161                Ok(QuantParams {
162                    scales,
163                    zero_points,
164                    bits: self.bits,
165                    scheme: self.scheme,
166                })
167            }
168        }
169    }
170
171    /// Calibrate from a 2-D tensor (rows = channels).
172    ///
173    /// Returns one `(scale, zp)` per row.
174    ///
175    /// # Errors
176    ///
177    /// * [`QuantError::EmptyInput`] if `rows == 0`.
178    /// * [`QuantError::DimensionMismatch`] if `cols == 0`.
179    pub fn calibrate_2d(
180        &self,
181        tensor: &[f32],
182        rows: usize,
183        cols: usize,
184    ) -> QuantResult<QuantParams> {
185        if rows == 0 {
186            return Err(QuantError::EmptyInput("calibrate_2d: rows == 0"));
187        }
188        if cols == 0 {
189            return Err(QuantError::DimensionMismatch {
190                expected: 1,
191                got: 0,
192            });
193        }
194        let mut scales = Vec::with_capacity(rows);
195        let mut zero_points = Vec::with_capacity(rows);
196        for row in tensor.chunks_exact(cols) {
197            let p = self.calibrate_slice(row)?;
198            scales.push(p.scales[0]);
199            zero_points.push(p.zero_points[0]);
200        }
201        Ok(QuantParams {
202            scales,
203            zero_points,
204            bits: self.bits,
205            scheme: self.scheme,
206        })
207    }
208
209    fn calibrate_slice(&self, slice: &[f32]) -> QuantResult<QuantParams> {
210        let mut fmin = f32::INFINITY;
211        let mut fmax = f32::NEG_INFINITY;
212        for &v in slice {
213            if v < fmin {
214                fmin = v;
215            }
216            if v > fmax {
217                fmax = v;
218            }
219        }
220        let (scale, zp) = match self.scheme {
221            QuantScheme::Symmetric => {
222                let q_max = (1 << (self.bits - 1)) as f32 - 1.0;
223                let abs_max = fmin.abs().max(fmax.abs()).max(1e-8);
224                (abs_max / q_max, 0_i32)
225            }
226            QuantScheme::Asymmetric => {
227                let q_range = ((1 << self.bits) - 1) as f32;
228                let range = (fmax - fmin).max(1e-8);
229                let scale = range / q_range;
230                let zp = (-fmin / scale).round().clamp(0.0, q_range) as i32;
231                (scale, zp)
232            }
233        };
234        if !scale.is_finite() || scale <= 0.0 {
235            return Err(QuantError::InvalidScale { scale });
236        }
237        Ok(QuantParams {
238            scales: vec![scale],
239            zero_points: vec![zp],
240            bits: self.bits,
241            scheme: self.scheme,
242        })
243    }
244
245    /// Quantize a flat tensor given pre-computed params (PerTensor mode).
246    ///
247    /// Returns `Vec<i32>` of integer codes.
248    ///
249    /// # Errors
250    ///
251    /// * [`QuantError::InvalidScale`] if `params.scales[0] <= 0`.
252    pub fn quantize(&self, tensor: &[f32], params: &QuantParams) -> QuantResult<Vec<i32>> {
253        let scale = params.scales[0];
254        if scale <= 0.0 || !scale.is_finite() {
255            return Err(QuantError::InvalidScale { scale });
256        }
257        let q_max = params.q_max();
258        let q_min = params.q_min();
259        let zp = params.zero_points[0] as f32;
260        let codes = tensor
261            .iter()
262            .map(|&x| {
263                let xq = (x / scale + zp).round().clamp(q_min, q_max);
264                xq as i32
265            })
266            .collect();
267        Ok(codes)
268    }
269
270    /// Quantize using per-group params.
271    ///
272    /// # Errors
273    ///
274    /// * [`QuantError::GroupSizeMismatch`] if tensor size is not divisible by group_size.
275    pub fn quantize_grouped(
276        &self,
277        tensor: &[f32],
278        params: &QuantParams,
279        group_size: usize,
280    ) -> QuantResult<Vec<i32>> {
281        if tensor.len() % group_size != 0 {
282            return Err(QuantError::GroupSizeMismatch {
283                len: tensor.len(),
284                group: group_size,
285            });
286        }
287        let q_max = params.q_max();
288        let q_min = params.q_min();
289        let mut out = Vec::with_capacity(tensor.len());
290        for (g, chunk) in tensor.chunks_exact(group_size).enumerate() {
291            let scale = params.scales[g];
292            let zp = params.zero_points[g] as f32;
293            for &x in chunk {
294                let xq = (x / scale + zp).round().clamp(q_min, q_max);
295                out.push(xq as i32);
296            }
297        }
298        Ok(out)
299    }
300
301    /// Dequantize integer codes back to f32.
302    pub fn dequantize(&self, codes: &[i32], params: &QuantParams) -> Vec<f32> {
303        let scale = params.scales[0];
304        let zp = params.zero_points[0];
305        codes.iter().map(|&q| (q - zp) as f32 * scale).collect()
306    }
307
308    /// Dequantize per-group codes.
309    pub fn dequantize_grouped(
310        &self,
311        codes: &[i32],
312        params: &QuantParams,
313        group_size: usize,
314    ) -> Vec<f32> {
315        let mut out = Vec::with_capacity(codes.len());
316        for (g, chunk) in codes.chunks_exact(group_size).enumerate() {
317            let scale = params.scales[g];
318            let zp = params.zero_points[g];
319            for &q in chunk {
320                out.push((q - zp) as f32 * scale);
321            }
322        }
323        out
324    }
325}
326
327// ─── Tests ───────────────────────────────────────────────────────────────────
328
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// `n` evenly spaced samples covering [-1, 1] inclusive (needs `n >= 2`).
    fn uniform_tensor(n: usize) -> Vec<f32> {
        let last = (n - 1) as f32;
        (0..n).map(|i| (i as f32 / last) * 2.0 - 1.0).collect()
    }

    /// Largest element-wise absolute difference between two slices.
    fn max_abs_error(reference: &[f32], decoded: &[f32]) -> f32 {
        reference
            .iter()
            .zip(decoded)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max)
    }

    #[test]
    fn symmetric_calibrate_scale() {
        let params = MinMaxQuantizer::int8_symmetric()
            .calibrate(&[-2.0_f32, -1.0, 0.5, 2.0])
            .unwrap();
        // abs-max is 2.0, spread over 127 positive levels.
        assert_abs_diff_eq!(params.scales[0], 2.0 / 127.0, epsilon = 1e-6);
        assert_eq!(params.zero_points[0], 0);
    }

    #[test]
    fn asymmetric_calibrate_scale_zp() {
        let quantizer =
            MinMaxQuantizer::new(8, QuantScheme::Asymmetric, QuantGranularity::PerTensor);
        let params = quantizer.calibrate(&[0.0_f32, 1.0, 2.0, 3.0]).unwrap();
        // Range [0, 3] over 255 steps; the minimum is already 0, so zp = 0.
        assert_abs_diff_eq!(params.scales[0], 3.0 / 255.0, epsilon = 1e-5);
        assert_eq!(params.zero_points[0], 0);
    }

    #[test]
    fn per_group_calibrate() {
        let data = [-1.0_f32, 0.0, 0.5, 1.0, -2.0, 0.0, 1.0, 2.0];
        let params = MinMaxQuantizer::int4_per_group(4).calibrate(&data).unwrap();
        assert_eq!(params.scales.len(), 2, "8 elements in groups of 4");
    }

    #[test]
    fn symmetric_round_trip_low_error() {
        let quantizer = MinMaxQuantizer::int8_symmetric();
        let input = uniform_tensor(128);
        let params = quantizer.calibrate(&input).unwrap();
        let codes = quantizer.quantize(&input, &params).unwrap();
        let decoded = quantizer.dequantize(&codes, &params);
        let max_err = max_abs_error(&input, &decoded);
        assert!(
            max_err < 0.02,
            "max quantization error too large: {max_err}"
        );
    }

    #[test]
    fn grouped_round_trip() {
        const GROUP: usize = 16;
        let quantizer = MinMaxQuantizer::int4_per_group(GROUP);
        let input: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let params = quantizer.calibrate(&input).unwrap();
        let codes = quantizer.quantize_grouped(&input, &params, GROUP).unwrap();
        let decoded = quantizer.dequantize_grouped(&codes, &params, GROUP);
        let max_err = max_abs_error(&input, &decoded);
        // INT4 symmetric has q_max = 7 positive levels, so the rounding error
        // is at most abs_max / (2 * 7) ≈ 6.3 / 14 ≈ 0.45 for the largest group.
        assert!(max_err < 0.5, "max per-group error too large: {max_err}");
    }

    #[test]
    fn empty_input_error() {
        let result = MinMaxQuantizer::int8_symmetric().calibrate(&[]);
        assert!(matches!(result, Err(QuantError::EmptyInput(_))));
    }

    #[test]
    fn group_size_mismatch_error() {
        // 10 elements cannot be split evenly into groups of 3.
        let data = [1.0_f32; 10];
        let result = MinMaxQuantizer::int4_per_group(3).calibrate(&data);
        assert!(matches!(result, Err(QuantError::GroupSizeMismatch { .. })));
    }

    #[test]
    fn q_max_q_min_int8() {
        let params = MinMaxQuantizer::int8_symmetric()
            .calibrate(&[1.0_f32])
            .unwrap();
        assert_abs_diff_eq!(params.q_max(), 127.0, epsilon = 1e-6);
        assert_abs_diff_eq!(params.q_min(), -128.0, epsilon = 1e-6);
    }

    #[test]
    fn calibrate_2d_per_row() {
        let quantizer = MinMaxQuantizer::int8_symmetric();
        // Two rows of four values each.
        let data = [
            0.0_f32, 1.0, -1.0, 0.5, // row 0: max_abs = 1
            0.0, 2.0, -2.0, 1.5, // row 1: max_abs = 2
        ];
        let params = quantizer.calibrate_2d(&data, 2, 4).unwrap();
        assert_eq!(params.scales.len(), 2);
        assert!(params.scales[1] > params.scales[0], "row1 scale should be larger");
    }
}