Skip to main content

quantize_rs/quantization/
mod.rs

1//! Core quantization logic for INT8 and INT4.
2//!
3//! Provides tensor-level quantization (per-tensor and per-channel),
4//! INT4 bit-packing, and the high-level [`Quantizer`] that combines
5//! a [`QuantConfig`] with optional calibration statistics.
6
7use crate::errors::{QuantizeError, Result};
8
9/// Configuration for a quantization pass.
10#[derive(Debug, Clone)]
11pub struct QuantConfig {
12    /// Bit width: `4` for INT4 or `8` for INT8.
13    pub bits: u8,
14    /// When `true`, compute separate scale/zero-point per output channel (axis 0).
15    pub per_channel: bool,
16    /// When `true`, force `zero_point == 0` (symmetric quantization) — required
17    /// by most ONNX Runtime / TensorRT INT8 matmul kernels for per-channel
18    /// weight quantization.  Defaults to `false` (asymmetric).
19    pub symmetric: bool,
20    /// Optional calibration method used for range optimization.
21    pub calibration_method: Option<crate::calibration::methods::CalibrationMethod>,
22    /// Layer names to skip entirely (exact match against the initializer name).
23    pub excluded_layers: Vec<String>,
24    /// Per-layer bit-width overrides.  Key = initializer name, value = 4 or 8.
25    pub layer_bits: std::collections::HashMap<String, u8>,
26    /// Minimum number of elements a tensor must have to be quantized.
27    /// Tensors with fewer elements are left in FP32.  Defaults to 0 (no minimum).
28    pub min_elements: usize,
29}
30
31impl Default for QuantConfig {
32    fn default() -> Self {
33        Self {
34            bits: 8,
35            per_channel: false,
36            symmetric: false,
37            calibration_method: None,
38            excluded_layers: Vec::new(),
39            layer_bits: std::collections::HashMap::new(),
40            min_elements: 0,
41        }
42    }
43}
44
45impl QuantConfig {
46    /// Create a default INT8 per-tensor configuration.
47    pub fn int8() -> Self {
48        Self::default()
49    }
50
51    /// Enable or disable per-channel quantization.
52    pub fn with_per_channel(mut self, enabled: bool) -> Self {
53        self.per_channel = enabled;
54        self
55    }
56
57    /// Enable or disable symmetric quantization (`zero_point == 0`).
58    pub fn with_symmetric(mut self, enabled: bool) -> Self {
59        self.symmetric = enabled;
60        self
61    }
62
63    /// Set the calibration method for range optimization.
64    pub fn with_calibration(
65        mut self,
66        method: crate::calibration::methods::CalibrationMethod,
67    ) -> Self {
68        self.calibration_method = Some(method);
69        self
70    }
71
72    /// Return `true` if the layer should be quantized.
73    ///
74    /// A layer is skipped when:
75    /// - its name appears in [`excluded_layers`], or
76    /// - `num_elements` is below [`min_elements`] (and `min_elements > 0`).
77    pub fn should_quantize(&self, name: &str, num_elements: usize) -> bool {
78        if self.excluded_layers.iter().any(|e| e == name) {
79            return false;
80        }
81        if self.min_elements > 0 && num_elements < self.min_elements {
82            return false;
83        }
84        true
85    }
86
87    /// Return the effective bit width for a layer.
88    ///
89    /// If the layer name has an entry in [`layer_bits`], that value is used;
90    /// otherwise the global [`bits`] is returned.
91    pub fn bits_for_layer(&self, name: &str) -> u8 {
92        self.layer_bits.get(name).copied().unwrap_or(self.bits)
93    }
94}
95
96// ---------------------------------------------------------------------------
97// QuantRange trait and marker types
98// ---------------------------------------------------------------------------
99
100/// Marker trait that supplies the clamp constants for a quantization bit-width.
101pub trait QuantRange: Clone + std::fmt::Debug + Send + Sync + 'static {
102    /// Minimum quantized value (inclusive).
103    const QMIN: f32;
104    /// Maximum quantized value (inclusive).
105    const QMAX: f32;
106    /// Bit width (4 or 8).
107    const BITS: u8;
108}
109
110/// Marker for INT8 quantization (`-128 … 127`).
111#[derive(Debug, Clone)]
112pub struct Int8Range;
113impl QuantRange for Int8Range {
114    const QMIN: f32 = -128.0;
115    const QMAX: f32 = 127.0;
116    const BITS: u8 = 8;
117}
118
119/// Marker for INT4 quantization (`-8 … 7`).
120#[derive(Debug, Clone)]
121pub struct Int4Range;
122impl QuantRange for Int4Range {
123    const QMIN: f32 = -8.0;
124    const QMAX: f32 = 7.0;
125    const BITS: u8 = 4;
126}
127
128// ---------------------------------------------------------------------------
129// QuantParamsGeneric<R>
130// ---------------------------------------------------------------------------
131
132/// Affine quantization parameters (scale and zero-point), generic over bit-width.
133///
134/// - INT8: `q = clamp(round(x / scale) + zero_point, -128, 127)`
135/// - INT4: `q = clamp(round(x / scale) + zero_point, -8, 7)`
136/// - Dequantization: `x = (q - zero_point) * scale`
137#[derive(Debug, Clone)]
138pub struct QuantParamsGeneric<R: QuantRange> {
139    scale: f32,
140    zero_point: i8,
141    _marker: std::marker::PhantomData<R>,
142}
143
144/// INT8 affine quantization parameters — `clamp(-128, 127)`.
145pub type QuantParams = QuantParamsGeneric<Int8Range>;
146/// INT4 affine quantization parameters — `clamp(-8, 7)`.
147pub type QuantParamsInt4 = QuantParamsGeneric<Int4Range>;
148
149impl<R: QuantRange> QuantParamsGeneric<R> {
150    /// Quantization scale factor.
151    pub fn scale(&self) -> f32 {
152        self.scale
153    }
154    /// Quantization zero point.
155    pub fn zero_point(&self) -> i8 {
156        self.zero_point
157    }
158
159    /// Compute asymmetric quantization parameters from a floating-point range.
160    ///
161    /// The resulting zero-point is in `[QMIN, QMAX]` and the full integer
162    /// range `[QMIN, QMAX]` is used to represent the input.  For ONNX Runtime /
163    /// TensorRT per-channel weight quantization, prefer
164    /// [`from_range_symmetric`](Self::from_range_symmetric) instead — those
165    /// kernels assume `zero_point == 0`.
166    pub fn from_range(min: f32, max: f32) -> Self {
167        let min = min.min(0.0);
168        let max = max.max(0.0);
169
170        // Handle constant-value tensors: when min ≈ max the data is (near-)constant.
171        // Use unit scale centred on zero so that the constant dequantizes accurately.
172        let (min, max) = if (max - min).abs() < 1e-8 {
173            let abs = min.abs().max(max.abs()).max(1e-8);
174            (-abs, abs)
175        } else {
176            (min, max)
177        };
178
179        let scale = (max - min) / (R::QMAX - R::QMIN);
180        let scale = scale.max(1e-8);
181
182        let initial_zero_point = R::QMIN - min / scale;
183        // Guard against NaN — if min/scale produced NaN (degenerate input),
184        // fall back to 0 to avoid undefined behaviour on the `as i8` cast.
185        let zero_point = if initial_zero_point.is_finite() {
186            initial_zero_point.round().clamp(R::QMIN, R::QMAX) as i8
187        } else {
188            0i8
189        };
190
191        QuantParamsGeneric {
192            scale,
193            zero_point,
194            _marker: std::marker::PhantomData,
195        }
196    }
197
198    /// Compute symmetric quantization parameters: `zero_point = 0`, `scale`
199    /// chosen so that the positive half of the quantized range covers
200    /// `max(|min|, |max|)`.
201    ///
202    /// This is the convention ONNX Runtime, TensorRT, and most accelerated
203    /// INT8 matmul kernels expect for per-channel weight quantization.
204    /// For INT8, the effective representable range is `[-127*scale, 127*scale]`;
205    /// any input value mapped to `-128` after rounding is clamped back into
206    /// range.  Zero always dequantizes to exactly 0.0.
207    pub fn from_range_symmetric(min: f32, max: f32) -> Self {
208        let abs_max = min.abs().max(max.abs()).max(1e-8);
209        // R::QMAX is the positive maximum (127 for INT8, 7 for INT4).  Dividing
210        // by QMAX — not by (QMAX - QMIN) — is what makes the result symmetric
211        // and keeps the zero-point at 0.
212        let scale = (abs_max / R::QMAX).max(1e-8);
213        QuantParamsGeneric {
214            scale,
215            zero_point: 0,
216            _marker: std::marker::PhantomData,
217        }
218    }
219
220    /// Quantize a single float to the target integer type.
221    pub fn quantize(&self, value: f32) -> i8 {
222        if !value.is_finite() {
223            return self.zero_point;
224        }
225        let quantized = (value / self.scale).round() + (self.zero_point as f32);
226        quantized.clamp(R::QMIN, R::QMAX) as i8
227    }
228
229    /// Dequantize a single integer value back to float.
230    pub fn dequantize(&self, value: i8) -> f32 {
231        ((value as i32) - (self.zero_point as i32)) as f32 * self.scale
232    }
233}
234
235// ---------------------------------------------------------------------------
236// QuantizedTensorGeneric<R>
237// ---------------------------------------------------------------------------
238
239/// Generic quantized tensor, parameterized by bit-width marker.
240///
241/// For INT4 tensors, call [`QuantizedTensorGeneric::pack`] to compress two
242/// values per byte for 2× storage savings.
243#[derive(Debug, Clone)]
244pub struct QuantizedTensorGeneric<R: QuantRange> {
245    pub(crate) data: Vec<i8>,
246    /// Bit-packed storage — always `None` for INT8, set by `.pack()` for INT4.
247    pub(crate) packed_data: Option<Vec<u8>>,
248    pub(crate) shape: Vec<usize>,
249    pub(crate) params: QuantParamsGeneric<R>,
250    pub(crate) per_channel: bool,
251    pub(crate) channel_params: Option<Vec<QuantParamsGeneric<R>>>,
252}
253
254/// An INT8 quantized tensor with optional per-channel parameters.
255pub type QuantizedTensor = QuantizedTensorGeneric<Int8Range>;
256
257/// An INT4 quantized tensor with optional per-channel parameters and bit packing.
258///
259/// Values are stored in the range `[-8, 7]`. Call [`pack`](QuantizedTensorInt4::pack) to
260/// compress two values into one byte for 2× storage savings.
261pub type QuantizedTensorInt4 = QuantizedTensorGeneric<Int4Range>;
262
263// ---------------------------------------------------------------------------
264// Shared impl for all bit-widths
265// ---------------------------------------------------------------------------
266
267impl<R: QuantRange> QuantizedTensorGeneric<R> {
268    /// Tensor shape.
269    pub fn shape(&self) -> &[usize] {
270        &self.shape
271    }
272    /// Per-tensor quantization parameters (channel-0 if per-channel).
273    pub fn params(&self) -> &QuantParamsGeneric<R> {
274        &self.params
275    }
276    /// Whether per-channel quantization was used.
277    pub fn is_per_channel(&self) -> bool {
278        self.per_channel
279    }
280
281    /// Quantize FP32 data, computing the range from the data (asymmetric).
282    ///
283    /// # Errors
284    ///
285    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty or shape mismatches.
286    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self> {
287        Self::from_f32_with_mode(data, shape, false)
288    }
289
290    /// Quantize FP32 data using symmetric quantization (`zero_point == 0`).
291    ///
292    /// Required by most ONNX Runtime / TensorRT INT8 kernels for per-channel
293    /// weight quantization.  See [`QuantParamsGeneric::from_range_symmetric`].
294    pub fn from_f32_symmetric(data: &[f32], shape: Vec<usize>) -> Result<Self> {
295        Self::from_f32_with_mode(data, shape, true)
296    }
297
298    fn from_f32_with_mode(data: &[f32], shape: Vec<usize>, symmetric: bool) -> Result<Self> {
299        if data.is_empty() {
300            return Err(QuantizeError::InvalidTensor {
301                reason: "Cannot quantize empty tensor".into(),
302            });
303        }
304
305        let expected_len: usize = shape.iter().product();
306        if expected_len != data.len() {
307            return Err(QuantizeError::InvalidTensor {
308                reason: format!(
309                    "Shape {:?} expects {} elements but got {}",
310                    shape,
311                    expected_len,
312                    data.len()
313                ),
314            });
315        }
316
317        let min = data
318            .iter()
319            .copied()
320            .filter(|v| v.is_finite())
321            .fold(f32::INFINITY, f32::min);
322        let max = data
323            .iter()
324            .copied()
325            .filter(|v| v.is_finite())
326            .fold(f32::NEG_INFINITY, f32::max);
327
328        if !min.is_finite() || !max.is_finite() {
329            return Err(QuantizeError::InvalidTensor {
330                reason: "Tensor contains only non-finite values (NaN/Inf)".into(),
331            });
332        }
333
334        let params = if symmetric {
335            QuantParamsGeneric::<R>::from_range_symmetric(min, max)
336        } else {
337            QuantParamsGeneric::<R>::from_range(min, max)
338        };
339
340        let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
341
342        Ok(QuantizedTensorGeneric {
343            data: quantized_data,
344            packed_data: None,
345            shape,
346            params,
347            per_channel: false,
348            channel_params: None,
349        })
350    }
351
352    /// Quantize FP32 data using an explicit range (for calibration; asymmetric).
353    ///
354    /// # Errors
355    ///
356    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty or shape mismatches.
357    pub fn from_f32_with_range(
358        data: &[f32],
359        shape: Vec<usize>,
360        min: f32,
361        max: f32,
362    ) -> Result<Self> {
363        Self::from_f32_with_range_and_mode(data, shape, min, max, false)
364    }
365
366    /// Same as [`from_f32_with_range`](Self::from_f32_with_range) but produces
367    /// symmetric parameters (`zero_point == 0`).
368    pub fn from_f32_with_range_symmetric(
369        data: &[f32],
370        shape: Vec<usize>,
371        min: f32,
372        max: f32,
373    ) -> Result<Self> {
374        Self::from_f32_with_range_and_mode(data, shape, min, max, true)
375    }
376
377    fn from_f32_with_range_and_mode(
378        data: &[f32],
379        shape: Vec<usize>,
380        min: f32,
381        max: f32,
382        symmetric: bool,
383    ) -> Result<Self> {
384        if data.is_empty() {
385            return Err(QuantizeError::InvalidTensor {
386                reason: "Cannot quantize empty tensor".into(),
387            });
388        }
389
390        let expected_len: usize = shape.iter().product();
391        if expected_len != data.len() {
392            return Err(QuantizeError::InvalidTensor {
393                reason: format!(
394                    "Shape {:?} expects {} elements but got {}",
395                    shape,
396                    expected_len,
397                    data.len()
398                ),
399            });
400        }
401
402        let params = if symmetric {
403            QuantParamsGeneric::<R>::from_range_symmetric(min, max)
404        } else {
405            QuantParamsGeneric::<R>::from_range(min, max)
406        };
407
408        let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
409
410        Ok(QuantizedTensorGeneric {
411            data: quantized_data,
412            packed_data: None,
413            shape,
414            params,
415            per_channel: false,
416            channel_params: None,
417        })
418    }
419
420    /// Quantize FP32 data with per-channel ranges (axis 0 only, asymmetric).
421    ///
422    /// # Errors
423    ///
424    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty, shape
425    /// mismatches, or the tensor is scalar.
426    pub fn from_f32_per_channel(data: &[f32], shape: Vec<usize>) -> Result<Self> {
427        Self::from_f32_per_channel_with_mode(data, shape, false)
428    }
429
430    /// Same as [`from_f32_per_channel`](Self::from_f32_per_channel) but emits
431    /// symmetric parameters (`zero_point == 0` for every channel).  Required
432    /// by most INT8 per-channel matmul kernels.
433    pub fn from_f32_per_channel_symmetric(data: &[f32], shape: Vec<usize>) -> Result<Self> {
434        Self::from_f32_per_channel_with_mode(data, shape, true)
435    }
436
437    fn from_f32_per_channel_with_mode(
438        data: &[f32],
439        shape: Vec<usize>,
440        symmetric: bool,
441    ) -> Result<Self> {
442        if data.is_empty() {
443            return Err(QuantizeError::InvalidTensor {
444                reason: "Cannot quantize empty tensor".into(),
445            });
446        }
447
448        if shape.is_empty() {
449            return Err(QuantizeError::InvalidTensor {
450                reason: "Cannot do per-channel quantization on scalar".into(),
451            });
452        }
453
454        let expected_len: usize = shape.iter().product();
455        if expected_len != data.len() {
456            return Err(QuantizeError::InvalidTensor {
457                reason: format!(
458                    "Shape {:?} expects {} elements but got {}",
459                    shape,
460                    expected_len,
461                    data.len()
462                ),
463            });
464        }
465
466        let num_channels = shape[0];
467        if num_channels == 0 {
468            return Err(QuantizeError::InvalidTensor {
469                reason: "Number of channels is 0".into(),
470            });
471        }
472        if !data.len().is_multiple_of(num_channels) {
473            return Err(QuantizeError::InvalidTensor {
474                reason: format!(
475                    "Data length {} not evenly divisible by {} channels",
476                    data.len(),
477                    num_channels
478                ),
479            });
480        }
481        let elements_per_channel = data.len() / num_channels;
482
483        let mut channel_params = Vec::with_capacity(num_channels);
484        let mut quantized_data = Vec::with_capacity(data.len());
485
486        // Walk the data channel-by-channel with a borrowed slice — no Vec alloc
487        // per channel.  For typical Conv weights this avoids hundreds of small
488        // allocations that used to dominate the per-channel hot path.
489        for (channel_idx, channel_slice) in data.chunks_exact(elements_per_channel).enumerate() {
490            let mut min = f32::INFINITY;
491            let mut max = f32::NEG_INFINITY;
492            for &v in channel_slice {
493                if v.is_finite() {
494                    if v < min {
495                        min = v;
496                    }
497                    if v > max {
498                        max = v;
499                    }
500                }
501            }
502
503            if !min.is_finite() || !max.is_finite() {
504                return Err(QuantizeError::InvalidTensor {
505                    reason: format!(
506                        "Channel {} contains only non-finite values (NaN/Inf)",
507                        channel_idx
508                    ),
509                });
510            }
511
512            let params = if symmetric {
513                QuantParamsGeneric::<R>::from_range_symmetric(min, max)
514            } else {
515                QuantParamsGeneric::<R>::from_range(min, max)
516            };
517
518            quantized_data.extend(channel_slice.iter().map(|&v| params.quantize(v)));
519            channel_params.push(params);
520        }
521
522        // Use first channel params as "representative" for backward compatibility
523        let params = channel_params[0].clone();
524
525        Ok(QuantizedTensorGeneric {
526            data: quantized_data,
527            packed_data: None,
528            shape,
529            params,
530            per_channel: true,
531            channel_params: Some(channel_params),
532        })
533    }
534
535    /// Dequantize all values back to FP32.
536    pub fn to_f32(&self) -> Vec<f32> {
537        // Borrow data directly when unpacked; allocate only for the packed INT4 path.
538        let data_owned;
539        let data: &[i8] = if let Some(ref packed) = self.packed_data {
540            data_owned = unpack_int4(packed, self.data.len());
541            &data_owned
542        } else {
543            &self.data
544        };
545
546        if self.per_channel {
547            if let Some(ref channel_params) = self.channel_params {
548                if channel_params.is_empty() {
549                    return data.iter().map(|&v| self.params.dequantize(v)).collect();
550                }
551                // Chunk the data by elements_per_channel and zip with channel params.
552                // This replaces a per-element division (`i / elements_per_channel`)
553                // with outer-loop iteration over contiguous channel slices.
554                let elements_per_channel = data.len() / channel_params.len();
555                let mut out = Vec::with_capacity(data.len());
556                if elements_per_channel == 0 {
557                    // Degenerate: fewer values than channels.  Fall back to the
558                    // representative params so we don't panic on the zip below.
559                    return data.iter().map(|&v| self.params.dequantize(v)).collect();
560                }
561                for (chunk, params) in data.chunks(elements_per_channel).zip(channel_params.iter())
562                {
563                    out.extend(chunk.iter().map(|&v| params.dequantize(v)));
564                }
565                out
566            } else {
567                data.iter().map(|&v| self.params.dequantize(v)).collect()
568            }
569        } else {
570            data.iter().map(|&v| self.params.dequantize(v)).collect()
571        }
572    }
573
574    /// Size of the quantized data in bytes (packed if available, unpacked otherwise).
575    pub fn size_bytes(&self) -> usize {
576        if let Some(ref packed) = self.packed_data {
577            packed.len()
578        } else {
579            self.data.len() * std::mem::size_of::<i8>()
580        }
581    }
582
583    /// Mean squared error between the original data and the dequantized values.
584    pub fn quantization_error(&self, original: &[f32]) -> f32 {
585        if original.is_empty() {
586            return 0.0;
587        }
588
589        let dequantized = self.to_f32();
590
591        let sum: f32 = original
592            .iter()
593            .zip(dequantized.iter())
594            .map(|(a, b)| (a - b).powi(2))
595            .sum();
596
597        sum / original.len() as f32
598    }
599}
600
601// ---------------------------------------------------------------------------
602// INT4-specific methods
603// ---------------------------------------------------------------------------
604
605impl QuantizedTensorGeneric<Int4Range> {
606    /// Pack two INT4 values per byte for 2× compression.
607    pub fn pack(&mut self) {
608        self.packed_data = Some(pack_int4(&self.data));
609    }
610
611    /// Return unpacked i8 data, decompressing from packed storage if needed.
612    pub fn ensure_unpacked(&self) -> Vec<i8> {
613        if let Some(ref packed) = self.packed_data {
614            unpack_int4(packed, self.data.len())
615        } else {
616            self.data.clone()
617        }
618    }
619
620    /// Whether the data is currently bit-packed.
621    pub fn is_packed(&self) -> bool {
622        self.packed_data.is_some()
623    }
624
625    /// Size that the packed representation would occupy (or already occupies).
626    pub fn packed_size_bytes(&self) -> usize {
627        if let Some(ref packed) = self.packed_data {
628            packed.len()
629        } else {
630            self.data.len().div_ceil(2)
631        }
632    }
633
634    /// Size of the unpacked representation in bytes.
635    pub fn unpacked_size_bytes(&self) -> usize {
636        self.data.len() * std::mem::size_of::<i8>()
637    }
638}
639
640// ---------------------------------------------------------------------------
641// INT4 bit-packing helpers
642// ---------------------------------------------------------------------------
643
644fn pack_int4_pair(val1: i8, val2: i8) -> u8 {
645    debug_assert!((-8..=7).contains(&val1), "val1 out of INT4 range: {}", val1);
646    debug_assert!((-8..=7).contains(&val2), "val2 out of INT4 range: {}", val2);
647
648    // Convert to 4-bit representation
649    let nibble1 = (val1 & 0x0F) as u8;
650    let nibble2 = (val2 & 0x0F) as u8;
651
652    // Pack: high 4 bits = val1, low 4 bits = val2
653    (nibble1 << 4) | nibble2
654}
655
656fn unpack_int4_pair(byte: u8) -> (i8, i8) {
657    let nibble1 = (byte >> 4) & 0x0F;
658    let nibble2 = byte & 0x0F;
659
660    // Convert from 4-bit to signed i8
661    let val1 = if nibble1 >= 8 {
662        (nibble1 as i8) | !0x0F
663    } else {
664        nibble1 as i8
665    };
666
667    let val2 = if nibble2 >= 8 {
668        (nibble2 as i8) | !0x0F
669    } else {
670        nibble2 as i8
671    };
672
673    (val1, val2)
674}
675
676/// Pack a slice of INT4 values (two per byte, high nibble first).
677pub fn pack_int4(values: &[i8]) -> Vec<u8> {
678    let mut packed = Vec::with_capacity(values.len().div_ceil(2));
679
680    for chunk in values.chunks(2) {
681        let val1 = chunk[0];
682        let val2 = if chunk.len() > 1 { chunk[1] } else { 0 };
683
684        packed.push(pack_int4_pair(val1, val2));
685    }
686
687    packed
688}
689
690/// Unpack INT4 values from packed bytes, returning exactly `num_values` i8s.
691pub fn unpack_int4(packed: &[u8], num_values: usize) -> Vec<i8> {
692    let mut values = Vec::with_capacity(num_values);
693
694    for &byte in packed {
695        let (val1, val2) = unpack_int4_pair(byte);
696        values.push(val1);
697        if values.len() < num_values {
698            values.push(val2);
699        }
700    }
701
702    // Truncate to exact size (removes padding)
703    values.truncate(num_values);
704    values
705}
706
707// ---------------------------------------------------------------------------
708// QuantizedTensorType
709// ---------------------------------------------------------------------------
710
711/// Type-erased wrapper over [`QuantizedTensor`] (INT8) and [`QuantizedTensorInt4`] (INT4).
712#[derive(Debug, Clone)]
713pub enum QuantizedTensorType {
714    Int8(QuantizedTensor),
715    Int4(QuantizedTensorInt4),
716}
717
718impl QuantizedTensorType {
719    /// Dequantize all values back to FP32.
720    pub fn to_f32(&self) -> Vec<f32> {
721        match self {
722            QuantizedTensorType::Int8(t) => t.to_f32(),
723            QuantizedTensorType::Int4(t) => t.to_f32(),
724        }
725    }
726
727    /// Size of the quantized data in bytes.
728    pub fn size_bytes(&self) -> usize {
729        match self {
730            QuantizedTensorType::Int8(t) => t.size_bytes(),
731            QuantizedTensorType::Int4(t) => t.size_bytes(),
732        }
733    }
734
735    #[must_use]
736    pub fn quantization_error(&self, original: &[f32]) -> f32 {
737        match self {
738            QuantizedTensorType::Int8(t) => t.quantization_error(original),
739            QuantizedTensorType::Int4(t) => t.quantization_error(original),
740        }
741    }
742
743    #[must_use]
744    pub fn data(&self) -> Vec<i8> {
745        match self {
746            QuantizedTensorType::Int8(t) => t.data.clone(),
747            QuantizedTensorType::Int4(t) => t.ensure_unpacked(),
748        }
749    }
750
751    /// Per-tensor scale and zero-point.
752    pub fn get_scale_zero_point(&self) -> (f32, i8) {
753        match self {
754            QuantizedTensorType::Int8(t) => (t.params.scale, t.params.zero_point),
755            QuantizedTensorType::Int4(t) => (t.params.scale, t.params.zero_point),
756        }
757    }
758
759    /// Return all per-channel scales and zero-points.
760    ///
761    /// For per-tensor quantization, returns single-element vectors.
762    /// For per-channel, returns one entry per channel.
763    pub fn get_all_scales_zero_points(&self) -> (Vec<f32>, Vec<i8>) {
764        match self {
765            QuantizedTensorType::Int8(t) => {
766                if let Some(ref cp) = t.channel_params {
767                    (
768                        cp.iter().map(|p| p.scale).collect(),
769                        cp.iter().map(|p| p.zero_point).collect(),
770                    )
771                } else {
772                    (vec![t.params.scale], vec![t.params.zero_point])
773                }
774            }
775            QuantizedTensorType::Int4(t) => {
776                if let Some(ref cp) = t.channel_params {
777                    (
778                        cp.iter().map(|p| p.scale).collect(),
779                        cp.iter().map(|p| p.zero_point).collect(),
780                    )
781                } else {
782                    (vec![t.params.scale], vec![t.params.zero_point])
783                }
784            }
785        }
786    }
787
788    /// Whether per-channel quantization was used.
789    pub fn is_per_channel(&self) -> bool {
790        match self {
791            QuantizedTensorType::Int8(t) => t.per_channel,
792            QuantizedTensorType::Int4(t) => t.per_channel,
793        }
794    }
795
796    #[must_use]
797    pub fn bits(&self) -> u8 {
798        match self {
799            QuantizedTensorType::Int8(_) => 8,
800            QuantizedTensorType::Int4(_) => 4,
801        }
802    }
803
804    /// `true` if this is an INT8 tensor.
805    pub fn is_int8(&self) -> bool {
806        matches!(self, QuantizedTensorType::Int8(_))
807    }
808
809    /// `true` if this is an INT4 tensor.
810    pub fn is_int4(&self) -> bool {
811        matches!(self, QuantizedTensorType::Int4(_))
812    }
813
814    /// Borrow quantized data without cloning.
815    ///
816    /// Returns `None` for packed INT4 tensors (must use `data()` which unpacks).
817    pub fn data_ref(&self) -> Option<&[i8]> {
818        match self {
819            QuantizedTensorType::Int8(t) => Some(&t.data),
820            QuantizedTensorType::Int4(t) => {
821                if t.packed_data.is_some() {
822                    None // packed: caller must use data() to unpack
823                } else {
824                    Some(&t.data)
825                }
826            }
827        }
828    }
829}
830
831// ---------------------------------------------------------------------------
832// Quantizer
833// ---------------------------------------------------------------------------
834
835/// High-level quantizer that combines configuration with optional calibration.
836pub struct Quantizer {
837    config: QuantConfig,
838    calibration_stats:
839        Option<std::collections::HashMap<String, crate::calibration::stats::ActivationStats>>,
840}
841
842impl std::fmt::Debug for Quantizer {
843    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844        let stats_count = self.calibration_stats.as_ref().map(|m| m.len());
845        f.debug_struct("Quantizer")
846            .field("config", &self.config)
847            .field("calibration_stats_count", &stats_count)
848            .finish()
849    }
850}
851
852impl Quantizer {
853    /// Create a quantizer with the given configuration (no calibration).
854    pub fn new(config: QuantConfig) -> Self {
855        Self {
856            config,
857            calibration_stats: None,
858        }
859    }
860
861    /// Create a quantizer with configuration and pre-collected activation statistics.
862    pub fn with_calibration(
863        config: QuantConfig,
864        stats: std::collections::HashMap<String, crate::calibration::stats::ActivationStats>,
865    ) -> Self {
866        Self {
867            config,
868            calibration_stats: Some(stats),
869        }
870    }
871
872    /// Quantize a tensor with optional calibration.
873    pub fn quantize_tensor_with_name(
874        &self,
875        name: &str,
876        data: &[f32],
877        shape: Vec<usize>,
878    ) -> Result<QuantizedTensorType> {
879        let (min, max) = if let Some(ref stats_map) = self.calibration_stats {
880            if let Some(stats) = stats_map.get(name) {
881                if let Some(method) = self.config.calibration_method {
882                    // Compute the range directly from the histogram — deterministic,
883                    // no sample regeneration, no RNG.
884                    use crate::calibration::stats::calculate_optimal_range_from_stats;
885                    calculate_optimal_range_from_stats(stats, method)
886                } else {
887                    (stats.min(), stats.max())
888                }
889            } else {
890                finite_min_max(data, name)?
891            }
892        } else {
893            finite_min_max(data, name)?
894        };
895
896        self.quantize_with_range(data, shape, min, max)
897    }
898
899    /// Quantize a tensor using the configured bit width and per-channel setting.
900    ///
901    /// # Errors
902    ///
903    /// Returns [`QuantizeError::InvalidTensor`] or [`QuantizeError::UnsupportedConfig`].
904    pub fn quantize_tensor(&self, data: &[f32], shape: Vec<usize>) -> Result<QuantizedTensorType> {
905        self.build_tensor_with_optional_range(data, shape, None)
906    }
907
908    /// Quantize with specific range (for calibration).
909    ///
910    /// When `per_channel` is enabled, the provided `min`/`max` are ignored
911    /// because per-channel quantization computes separate ranges from the
912    /// weight data for each channel.  The calibration range (derived from
913    /// activation statistics) applies to per-tensor mode only.
914    fn quantize_with_range(
915        &self,
916        data: &[f32],
917        shape: Vec<usize>,
918        min: f32,
919        max: f32,
920    ) -> Result<QuantizedTensorType> {
921        self.build_tensor_with_optional_range(data, shape, Some((min, max)))
922    }
923
924    /// Shared core: build a [`QuantizedTensorType`] for any bit-width and range mode.
925    fn build_tensor_with_optional_range(
926        &self,
927        data: &[f32],
928        shape: Vec<usize>,
929        range: Option<(f32, f32)>,
930    ) -> Result<QuantizedTensorType> {
931        let pc = self.config.per_channel && shape.len() >= 2;
932        let sym = self.config.symmetric;
933        match self.config.bits {
934            8 => {
935                let t = match (pc, range, sym) {
936                    (true, _, true) => {
937                        QuantizedTensor::from_f32_per_channel_symmetric(data, shape)?
938                    }
939                    (true, _, false) => QuantizedTensor::from_f32_per_channel(data, shape)?,
940                    (false, Some((min, max)), true) => {
941                        QuantizedTensor::from_f32_with_range_symmetric(data, shape, min, max)?
942                    }
943                    (false, Some((min, max)), false) => {
944                        QuantizedTensor::from_f32_with_range(data, shape, min, max)?
945                    }
946                    (false, None, true) => QuantizedTensor::from_f32_symmetric(data, shape)?,
947                    (false, None, false) => QuantizedTensor::from_f32(data, shape)?,
948                };
949                Ok(QuantizedTensorType::Int8(t))
950            }
951            4 => {
952                let mut t = match (pc, range, sym) {
953                    (true, _, true) => {
954                        QuantizedTensorInt4::from_f32_per_channel_symmetric(data, shape)?
955                    }
956                    (true, _, false) => QuantizedTensorInt4::from_f32_per_channel(data, shape)?,
957                    (false, Some((min, max)), true) => {
958                        QuantizedTensorInt4::from_f32_with_range_symmetric(data, shape, min, max)?
959                    }
960                    (false, Some((min, max)), false) => {
961                        QuantizedTensorInt4::from_f32_with_range(data, shape, min, max)?
962                    }
963                    (false, None, true) => QuantizedTensorInt4::from_f32_symmetric(data, shape)?,
964                    (false, None, false) => QuantizedTensorInt4::from_f32(data, shape)?,
965                };
966                t.pack();
967                Ok(QuantizedTensorType::Int4(t))
968            }
969            b => Err(QuantizeError::UnsupportedConfig {
970                reason: format!("bits must be 4 or 8, got {b}"),
971            }),
972        }
973    }
974
975    /// Quantize every weight in `model` that passes
976    /// [`QuantConfig::should_quantize`].  Honours per-layer bit-width overrides.
977    ///
978    /// When this quantizer was built with calibration, activation-based
979    /// range optimization is used for the default bit-width; layers whose
980    /// bit-width is overridden fall back to weight-only quantization
981    /// (the calibration stats are keyed by the default configuration).
982    ///
983    /// Skipped weights do not appear in the returned vector.
984    pub fn quantize_model(
985        &self,
986        model: &crate::onnx_utils::OnnxModel,
987    ) -> Result<Vec<QuantizedWeightOutput>> {
988        use rayon::prelude::*;
989
990        let weights = model.extract_weights();
991        let to_quantize: Vec<_> = weights
992            .iter()
993            .filter(|w| self.config.should_quantize(&w.name, w.num_elements()))
994            .collect();
995
996        to_quantize
997            .par_iter()
998            .map(|w| self.quantize_weight_to_output(w))
999            .collect()
1000    }
1001
1002    fn quantize_weight_to_output(
1003        &self,
1004        weight: &crate::onnx_utils::WeightTensor,
1005    ) -> Result<QuantizedWeightOutput> {
1006        let layer_bits = self.config.bits_for_layer(&weight.name);
1007
1008        // For the default bit-width, use the shared (possibly calibrated)
1009        // quantizer.  For per-layer bit-width overrides, build a layer-local
1010        // quantizer: calibration stats are keyed by the default configuration
1011        // and re-applying them at a different bit-width is ill-defined.
1012        let quantized = if layer_bits == self.config.bits {
1013            self.quantize_tensor_with_name(&weight.name, &weight.data, weight.shape.clone())?
1014        } else {
1015            let layer_config = QuantConfig {
1016                bits: layer_bits,
1017                per_channel: self.config.per_channel,
1018                symmetric: self.config.symmetric,
1019                ..Default::default()
1020            };
1021            Quantizer::new(layer_config).quantize_tensor(&weight.data, weight.shape.clone())?
1022        };
1023
1024        let mse = quantized.quantization_error(&weight.data);
1025        let (scales, zero_points) = quantized.get_all_scales_zero_points();
1026        let is_per_channel = quantized.is_per_channel();
1027        let bits_used = quantized.bits();
1028        let quantized_size_bytes = quantized.size_bytes();
1029
1030        Ok(QuantizedWeightOutput {
1031            qdq: crate::onnx_utils::graph_builder::QdqWeightInput {
1032                original_name: weight.name.clone(),
1033                quantized_values: quantized.data(),
1034                scales,
1035                zero_points,
1036                bits: bits_used,
1037                axis: if is_per_channel { Some(0) } else { None },
1038            },
1039            quantized_size_bytes,
1040            mse,
1041        })
1042    }
1043}
1044
1045/// Per-weight result of [`Quantizer::quantize_model`].
1046///
1047/// Bundles the QDQ block ready for
1048/// [`OnnxModel::save_quantized_with_options`](crate::onnx_utils::OnnxModel::save_quantized_with_options)
1049/// with the two telemetry values callers typically want to print: on-disk
1050/// size and round-trip MSE.
1051#[derive(Debug, Clone)]
1052pub struct QuantizedWeightOutput {
1053    /// QDQ block for `save_quantized_with_options`.
1054    pub qdq: crate::onnx_utils::graph_builder::QdqWeightInput,
1055    /// Size of the quantized payload in bytes (INT8 = 1 byte/elem;
1056    /// packed INT4 = ceil(elem/2)).
1057    pub quantized_size_bytes: usize,
1058    /// MSE between the original FP32 values and the dequantized output.
1059    pub mse: f32,
1060}
1061
1062// ---------------------------------------------------------------------------
1063// Calibration helper
1064// ---------------------------------------------------------------------------
1065
1066/// Compute the finite min/max of `data`, returning an error if all values are NaN/Inf.
1067fn finite_min_max(data: &[f32], name: &str) -> Result<(f32, f32)> {
1068    let min = data
1069        .iter()
1070        .copied()
1071        .filter(|v| v.is_finite())
1072        .fold(f32::INFINITY, f32::min);
1073    let max = data
1074        .iter()
1075        .copied()
1076        .filter(|v| v.is_finite())
1077        .fold(f32::NEG_INFINITY, f32::max);
1078    if !min.is_finite() || !max.is_finite() {
1079        return Err(QuantizeError::InvalidTensor {
1080            reason: format!(
1081                "Tensor '{}' contains only non-finite values (NaN/Inf)",
1082                name
1083            ),
1084        });
1085    }
1086    Ok((min, max))
1087}
1088
1089#[cfg(test)]
1090mod tests {
1091    use super::*;
1092
1093    // -----------------------------------------------------------------------
1094    // QuantConfig per-layer selection
1095    // -----------------------------------------------------------------------
1096
1097    #[test]
1098    fn test_should_quantize_no_restrictions() {
1099        let config = QuantConfig::default();
1100        assert!(config.should_quantize("any.layer", 1));
1101        assert!(config.should_quantize("any.layer", 1_000_000));
1102    }
1103
1104    #[test]
1105    fn test_should_quantize_excluded_layer() {
1106        let config = QuantConfig {
1107            excluded_layers: vec!["head.weight".to_string()],
1108            ..Default::default()
1109        };
1110        assert!(!config.should_quantize("head.weight", 1024));
1111        assert!(config.should_quantize("body.weight", 1024));
1112    }
1113
1114    #[test]
1115    fn test_should_quantize_min_elements() {
1116        let config = QuantConfig {
1117            min_elements: 512,
1118            ..Default::default()
1119        };
1120        assert!(!config.should_quantize("small.bias", 4));
1121        assert!(!config.should_quantize("small.bias", 511));
1122        assert!(config.should_quantize("large.weight", 512));
1123        assert!(config.should_quantize("large.weight", 1024));
1124    }
1125
1126    #[test]
1127    fn test_should_quantize_excluded_takes_priority_over_min_elements() {
1128        let config = QuantConfig {
1129            excluded_layers: vec!["head.weight".to_string()],
1130            min_elements: 1,
1131            ..Default::default()
1132        };
1133        // excluded → skipped regardless of size
1134        assert!(!config.should_quantize("head.weight", 1_000_000));
1135    }
1136
1137    #[test]
1138    fn test_bits_for_layer_default() {
1139        let config = QuantConfig {
1140            bits: 8,
1141            ..Default::default()
1142        };
1143        assert_eq!(config.bits_for_layer("any.weight"), 8);
1144    }
1145
1146    #[test]
1147    fn test_bits_for_layer_override() {
1148        let mut layer_bits = std::collections::HashMap::new();
1149        layer_bits.insert("head.weight".to_string(), 4u8);
1150        let config = QuantConfig {
1151            bits: 8,
1152            layer_bits,
1153            ..Default::default()
1154        };
1155        assert_eq!(config.bits_for_layer("head.weight"), 4);
1156        assert_eq!(config.bits_for_layer("body.weight"), 8);
1157    }
1158
1159    // -----------------------------------------------------------------------
1160    // Existing tests below
1161    // -----------------------------------------------------------------------
1162
1163    #[test]
1164    fn test_quant_params() {
1165        let params = QuantParams::from_range(-1.0, 1.0);
1166
1167        assert_eq!(params.quantize(0.0), params.zero_point);
1168
1169        let original = 0.5;
1170        let quantized = params.quantize(original);
1171        let dequantized = params.dequantize(quantized);
1172
1173        assert!((original - dequantized).abs() < 0.01);
1174    }
1175
1176    #[test]
1177    fn test_quantize_tensor() {
1178        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1179        let shape = vec![5];
1180
1181        let quantized = QuantizedTensor::from_f32(&data, shape).unwrap();
1182
1183        assert_eq!(quantized.data.len(), 5);
1184        assert_eq!(quantized.size_bytes(), 5);
1185    }
1186
1187    #[test]
1188    fn test_per_channel_quantization() {
1189        let mut data = vec![];
1190        for _ in 0..100 {
1191            data.push(0.5); // Channel 0
1192        }
1193        for _ in 0..100 {
1194            data.push(5.0); // Channel 1
1195        }
1196
1197        let shape = vec![2, 100];
1198
1199        let quantized = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1200
1201        assert!(quantized.per_channel);
1202        assert!(quantized.channel_params.is_some());
1203        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1204
1205        let dequantized = quantized.to_f32();
1206        let error: f32 = data
1207            .iter()
1208            .zip(dequantized.iter())
1209            .map(|(a, b)| (a - b).powi(2))
1210            .sum::<f32>()
1211            / data.len() as f32;
1212
1213        println!("Per-channel MSE: {}", error);
1214        assert!(error < 0.1);
1215    }
1216
1217    #[test]
1218    fn test_per_channel_vs_per_tensor() {
1219        let mut data = vec![];
1220
1221        for _ in 0..1000 {
1222            data.push(0.01);
1223        }
1224
1225        for _ in 0..1000 {
1226            data.push(10.0);
1227        }
1228
1229        let shape = vec![2, 1000];
1230
1231        // Per-tensor quantization
1232        let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1233        let per_tensor_error = per_tensor.quantization_error(&data);
1234
1235        // Per-channel quantization
1236        let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1237        let per_channel_error = per_channel.quantization_error(&data);
1238
1239        println!("Per-tensor error:  {:.8}", per_tensor_error);
1240        println!("Per-channel error: {:.8}", per_channel_error);
1241
1242        // Per-channel
1243        assert!(per_channel_error < per_tensor_error);
1244        assert!(per_channel_error < per_tensor_error * 0.5);
1245    }
1246
1247    #[test]
1248    fn test_per_channel_benefit() {
1249        let mut data = vec![];
1250
1251        for i in 0..1000 {
1252            data.push(-0.1 + (i as f32 / 1000.0) * 0.2);
1253        }
1254
1255        for i in 0..1000 {
1256            data.push(-10.0 + (i as f32 / 1000.0) * 20.0);
1257        }
1258
1259        let shape = vec![2, 1000];
1260
1261        let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1262        let per_tensor_error = per_tensor.quantization_error(&data);
1263
1264        let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1265        let per_channel_error = per_channel.quantization_error(&data);
1266
1267        println!("Per-tensor MSE:  {:.8}", per_tensor_error);
1268        println!("Per-channel MSE: {:.8}", per_channel_error);
1269
1270        assert!(
1271            per_channel_error < per_tensor_error,
1272            "Per-channel ({:.8}) should be better than per-tensor ({:.8})",
1273            per_channel_error,
1274            per_tensor_error
1275        );
1276    }
1277
1278    #[test]
1279    fn test_int4_quant_params() {
1280        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1281
1282        assert!(params.quantize(-10.0) >= -8);
1283        assert!(params.quantize(-10.0) <= 7);
1284        assert!(params.quantize(10.0) >= -8);
1285        assert!(params.quantize(10.0) <= 7);
1286
1287        let zero_quant = params.quantize(0.0);
1288        assert!(zero_quant >= -8 && zero_quant <= 7);
1289
1290        for &original in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
1291            let quantized = params.quantize(original);
1292            let dequantized = params.dequantize(quantized);
1293
1294            println!(
1295                "Original: {:.2}, Quantized: {}, Dequantized: {:.2}, Error: {:.4}",
1296                original,
1297                quantized,
1298                dequantized,
1299                (original - dequantized).abs()
1300            );
1301
1302            assert!((original - dequantized).abs() < params.scale * 2.0);
1303        }
1304    }
1305
1306    #[test]
1307    fn test_int4_extreme_values() {
1308        // Test with extreme value ranges
1309        let params = QuantParamsInt4::from_range(-100.0, 100.0);
1310
1311        let q_neg = params.quantize(-100.0);
1312        let q_pos = params.quantize(100.0);
1313
1314        assert_eq!(q_neg, -8);
1315        assert_eq!(q_pos, 7);
1316    }
1317
1318    #[test]
1319    fn test_int4_vs_int8_error() {
1320        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1321
1322        let params_int8 = QuantParams::from_range(-1.0, 1.0);
1323        let error_int8: f32 = data
1324            .iter()
1325            .map(|&v| {
1326                let q = params_int8.quantize(v);
1327                let dq = params_int8.dequantize(q);
1328                (v - dq).powi(2)
1329            })
1330            .sum::<f32>()
1331            / data.len() as f32;
1332
1333        let params_int4 = QuantParamsInt4::from_range(-1.0, 1.0);
1334        let error_int4: f32 = data
1335            .iter()
1336            .map(|&v| {
1337                let q = params_int4.quantize(v);
1338                let dq = params_int4.dequantize(q);
1339                (v - dq).powi(2)
1340            })
1341            .sum::<f32>()
1342            / data.len() as f32;
1343
1344        println!("INT8 MSE: {:.8}", error_int8);
1345        println!("INT4 MSE: {:.8}", error_int4);
1346
1347        assert!(error_int4 > error_int8);
1348
1349        assert!(
1350            error_int4 < error_int8 * 500.0,
1351            "INT4 error ({:.8}) is too high compared to INT8 ({:.8})",
1352            error_int4,
1353            error_int8
1354        );
1355
1356        assert!(error_int4.is_finite());
1357        assert!(error_int4 < 0.01);
1358    }
1359
1360    #[test]
1361    fn test_int4_range() {
1362        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1363
1364        assert!(params.quantize(-10.0) == -8);
1365        assert!(params.quantize(10.0) == 7);
1366
1367        // Test quantization within range
1368        for i in -8..=7 {
1369            let value = i as f32 * params.scale;
1370            let quantized = params.quantize(value);
1371            assert!(quantized >= -8 && quantized <= 7);
1372        }
1373    }
1374
1375    #[test]
1376    fn test_int4_optimal_precision() {
1377        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1378
1379        let mut unique_values = std::collections::HashSet::new();
1380
1381        // Sample across the range
1382        for i in 0..1000 {
1383            let value = -1.0 + (i as f32 / 1000.0) * 2.0;
1384            unique_values.insert(params.quantize(value));
1385        }
1386
1387        println!("Unique quantized values: {}", unique_values.len());
1388        assert!(unique_values.len() >= 14);
1389    }
1390
1391    #[test]
1392    fn test_int4_tensor_quantization() {
1393        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1394        let shape = vec![5];
1395
1396        let quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1397
1398        assert_eq!(quantized.data.len(), 5);
1399        assert_eq!(quantized.size_bytes(), 5);
1400        assert_eq!(quantized.packed_size_bytes(), 3);
1401
1402        for &val in &quantized.data {
1403            assert!(val >= -8 && val <= 7, "Value {} out of INT4 range", val);
1404        }
1405    }
1406
1407    #[test]
1408    fn test_int4_round_trip() {
1409        let original = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1410        let shape = vec![5];
1411
1412        let quantized = QuantizedTensorInt4::from_f32(&original, shape).unwrap();
1413        let dequantized = quantized.to_f32();
1414
1415        println!("Original:    {:?}", original);
1416        println!("Quantized:   {:?}", quantized.data);
1417        println!("Dequantized: {:?}", dequantized);
1418
1419        for (orig, deq) in original.iter().zip(dequantized.iter()) {
1420            let error = (orig - deq).abs();
1421            println!("  {:.2} -> {:.2}, error: {:.4}", orig, deq, error);
1422            assert!(error < 0.15, "Error too large: {}", error);
1423        }
1424    }
1425
1426    #[test]
1427    fn test_int4_per_channel() {
1428        let mut data = vec![];
1429
1430        // Channel 0: small range [-0.1, 0.1]
1431        for i in 0..100 {
1432            data.push(-0.1 + (i as f32 / 100.0) * 0.2);
1433        }
1434
1435        // Channel 1: large range [-10.0, 10.0]
1436        for i in 0..100 {
1437            data.push(-10.0 + (i as f32 / 100.0) * 20.0);
1438        }
1439
1440        let shape = vec![2, 100];
1441
1442        let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1443
1444        assert!(quantized.per_channel);
1445        assert!(quantized.channel_params.is_some());
1446        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1447
1448        let error = quantized.quantization_error(&data);
1449        println!("INT4 per-channel MSE: {:.8}", error);
1450
1451        assert!(error < 1.0, "Error too high: {}", error);
1452    }
1453
1454    #[test]
1455    fn test_int4_vs_int8_compression() {
1456        let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1457        let shape = vec![1000];
1458
1459        let int8_quantized = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1460        let int8_size = int8_quantized.size_bytes();
1461        let int8_error = int8_quantized.quantization_error(&data);
1462
1463        let int4_quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1464        let int4_size = int4_quantized.size_bytes();
1465        let int4_packed_size = int4_quantized.packed_size_bytes();
1466        let int4_error = int4_quantized.quantization_error(&data);
1467
1468        println!("INT8: {} bytes, MSE: {:.8}", int8_size, int8_error);
1469        println!(
1470            "INT4 (unpacked): {} bytes, MSE: {:.8}",
1471            int4_size, int4_error
1472        );
1473        println!(
1474            "INT4 (packed): {} bytes, MSE: {:.8}",
1475            int4_packed_size, int4_error
1476        );
1477
1478        assert_eq!(int4_size, int8_size);
1479
1480        assert!(int4_packed_size <= int8_size / 2 + 1);
1481
1482        assert!(int4_error > int8_error);
1483
1484        assert!(int4_error < 0.01, "INT4 error too high: {}", int4_error);
1485    }
1486
1487    #[test]
1488    fn test_int4_large_tensor() {
1489        let size = 64 * 3 * 3 * 3; // 64 filters, 3x3x3 kernels
1490        let data: Vec<f32> = (0..size)
1491            .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1492            .collect();
1493
1494        let shape = vec![64, 3, 3, 3];
1495
1496        let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1497
1498        assert_eq!(quantized.data.len(), size);
1499        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 64);
1500
1501        let error = quantized.quantization_error(&data);
1502        println!("Large tensor INT4 error: {:.8}", error);
1503
1504        assert!(error < 0.01, "Error too high for large tensor: {}", error);
1505    }
1506
1507    #[test]
1508    fn test_int4_extreme_ranges() {
1509        let test_cases = vec![
1510            (vec![-0.001, 0.0, 0.001], "tiny range"),
1511            (vec![-100.0, 0.0, 100.0], "large range"),
1512            (vec![0.0, 0.0, 0.0], "all zeros"),
1513            (vec![1.0, 1.0, 1.0], "all same"),
1514        ];
1515
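        // Degenerate inputs (all zeros, all identical) give a zero range, so the
        // quantizer is expected to clamp the scale away from zero rather than
        // divide by it (see the constant-tensor symmetric test further down).
        // The hard requirement checked here is only that construction succeeds
        // and every stored value stays inside the signed 4-bit range [-8, 7].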
1516        for (data, desc) in test_cases {
1517            println!("\nTesting: {}", desc);
1518            let shape = vec![data.len()];
1519
1520            let result = QuantizedTensorInt4::from_f32(&data, shape);
1521            assert!(result.is_ok(), "Failed on {}", desc);
1522
1523            let quantized = result.unwrap();
1524            let dequantized = quantized.to_f32();
1525
1526            println!("  Original:    {:?}", data);
1527            println!("  Dequantized: {:?}", dequantized);
1528
1529            for &val in &quantized.data {
1530                assert!(
1531                    val >= -8 && val <= 7,
1532                    "Value {} out of range for {}",
1533                    val,
1534                    desc
1535                );
1536            }
1537        }
1538    }
1539
1540    #[test]
1541    fn test_int4_pack_unpack_pair() {
1542        let test_cases = vec![
1543            (-8, 7),
1544            (-8, -8),
1545            (7, 7),
1546            (0, 0),
1547            (-1, 0),
1548            (0, -1),
1549            (-5, 3),
1550            (6, -4),
1551        ];
1552
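        // Each i8 in [-8, 7] fits in a single 4-bit nibble, so pack_int4_pair
        // stores two of them in one byte and unpack_int4_pair has to sign-extend
        // each nibble back to i8 to recover negative values. Which nibble holds
        // which value is an implementation detail; the test only asserts the
        // round trip.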
1553        for (val1, val2) in test_cases {
1554            println!("\nTesting: ({}, {})", val1, val2);
1555
1556            let packed = pack_int4_pair(val1, val2);
1557            let (unpacked1, unpacked2) = unpack_int4_pair(packed);
1558
1559            println!("  Packed: 0x{:02X} (binary: {:08b})", packed, packed);
1560            println!("  Unpacked: ({}, {})", unpacked1, unpacked2);
1561
1562            assert_eq!(val1, unpacked1, "First value mismatch");
1563            assert_eq!(val2, unpacked2, "Second value mismatch");
1564        }
1565    }
1566
1567    #[test]
1568    fn test_int4_pack_unpack_vector() {
1569        let values = vec![-8, -7, -1, 0, 1, 7];
1570        let packed = pack_int4(&values);
1571        let unpacked = unpack_int4(&packed, values.len());
1572
1573        println!("\nEven length:");
1574        println!("  Original: {:?}", values);
1575        println!("  Packed:   {:?} ({} bytes)", packed, packed.len());
1576        println!("  Unpacked: {:?}", unpacked);
1577
1578        assert_eq!(values, unpacked);
1579        assert_eq!(packed.len(), (values.len() + 1) / 2);
1580    }
1581
1582    #[test]
1583    fn test_int4_pack_unpack_odd_length() {
1584        let values = vec![-8, -5, 0, 5, 7];
1585        let packed = pack_int4(&values);
1586        let unpacked = unpack_int4(&packed, values.len());
1587
1588        println!("\nOdd length:");
1589        println!("  Original: {:?}", values);
1590        println!("  Packed:   {:?} ({} bytes)", packed, packed.len());
1591        println!("  Unpacked: {:?}", unpacked);
1592
1593        assert_eq!(values, unpacked);
1594        assert_eq!(packed.len(), (values.len() + 1) / 2);
1595    }
1596
1597    #[test]
1598    fn test_int4_pack_all_values() {
1599        let values: Vec<i8> = (-8..=7).collect();
1600        let packed = pack_int4(&values);
1601        let unpacked = unpack_int4(&packed, values.len());
1602
1603        println!("\nAll INT4 values:");
1604        println!("  Original: {:?}", values);
1605        println!("  Packed:   {} bytes", packed.len());
1606        println!("  Unpacked: {:?}", unpacked);
1607
1608        assert_eq!(values, unpacked);
1609        assert_eq!(packed.len(), 8);
1610    }
1611
1612    #[test]
1613    fn test_int4_pack_large_vector() {
1614        let values: Vec<i8> = (0..1000).map(|i| ((i % 16) - 8) as i8).collect();
1615        let packed = pack_int4(&values);
1616        let unpacked = unpack_int4(&packed, values.len());
1617
1618        assert_eq!(values, unpacked);
1619        assert_eq!(packed.len(), 500);
1620
1621        println!("\nLarge vector:");
1622        println!("  Original: {} values", values.len());
1623        println!(
1624            "  Packed:   {} bytes ({}x compression)",
1625            packed.len(),
1626            values.len() / packed.len()
1627        );
1628        println!("  Unpacked: {} values", unpacked.len());
1629    }
1630
1631    #[test]
1632    fn test_int4_compression_ratio() {
1633        let size = 10000;
1634        let values: Vec<i8> = (0..size).map(|i| ((i % 16) - 8) as i8).collect();
1635
1636        let unpacked_size = values.len() * std::mem::size_of::<i8>();
1637
1638        let packed = pack_int4(&values);
1639        let packed_size = packed.len();
1640
1641        let compression_ratio = unpacked_size as f32 / packed_size as f32;
1642
1643        println!("\nCompression test:");
1644        println!("  Values:      {}", size);
1645        println!("  Unpacked:    {} bytes", unpacked_size);
1646        println!("  Packed:      {} bytes", packed_size);
1647        println!("  Compression: {:.2}x", compression_ratio);
1648
1649        assert!(
1650            (compression_ratio - 2.0).abs() < 0.01,
1651            "Expected ~2x compression, got {:.2}x",
1652            compression_ratio
1653        );
1654    }
1655
1656    #[test]
1657    fn test_int4_tensor_packing() {
1658        let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1659        let shape = vec![1000];
1660
1661        let mut quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1662
1663        println!("Before packing:");
1664        println!("  Unpacked size: {} bytes", quantized.unpacked_size_bytes());
1665        println!("  Is packed: {}", quantized.is_packed());
1666
1667        assert!(!quantized.is_packed());
1668        assert_eq!(quantized.size_bytes(), 1000);
1669
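        // pack() compresses the stored values in place (two nibbles per byte);
        // to_f32() and quantization_error() must keep working on the packed form.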
1670        quantized.pack();
1671
1672        println!("\nAfter packing:");
1673        println!("  Packed size: {} bytes", quantized.size_bytes());
1674        println!("  Is packed: {}", quantized.is_packed());
1675        println!(
1676            "  Compression: {}x",
1677            quantized.unpacked_size_bytes() / quantized.size_bytes()
1678        );
1679
1680        assert!(quantized.is_packed());
1681        assert_eq!(quantized.size_bytes(), 500);
1682
1683        let dequantized = quantized.to_f32();
1684        assert_eq!(dequantized.len(), 1000);
1685
1686        let error = quantized.quantization_error(&data);
1687        println!("  MSE after packing: {:.8}", error);
1688        assert!(error < 0.01);
1689    }
1690
1691    #[test]
1692    fn test_int4_packed_vs_unpacked_error() {
1693        let data: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) * 2.0 - 1.0).collect();
1694        let shape = vec![100];
1695
1696        let unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1697        let error_unpacked = unpacked.quantization_error(&data);
1698
1699        let mut packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1700        packed.pack();
1701        let error_packed = packed.quantization_error(&data);
1702
1703        println!("Unpacked error: {:.8}", error_unpacked);
1704        println!("Packed error:   {:.8}", error_packed);
1705
1706        assert!((error_unpacked - error_packed).abs() < 1e-6);
1707    }
1708
1709    #[test]
1710    fn test_int4_per_channel_packing() {
1711        let mut data = vec![];
1712        for i in 0..500 {
1713            data.push((i as f32 / 500.0) * 0.2 - 0.1); // Channel 0
1714        }
1715        for i in 0..500 {
1716            data.push((i as f32 / 500.0) * 20.0 - 10.0); // Channel 1
1717        }
1718
1719        let shape = vec![2, 500];
1720
1721        let mut quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1722
1723        let error_before = quantized.quantization_error(&data);
1724        println!("Error before packing: {:.8}", error_before);
1725
1726        quantized.pack();
1727
1728        let error_after = quantized.quantization_error(&data);
1729        println!("Error after packing:  {:.8}", error_after);
1730        println!(
1731            "Size: {} bytes (packed from {} bytes)",
1732            quantized.size_bytes(),
1733            quantized.unpacked_size_bytes()
1734        );
1735
1736        assert!((error_before - error_after).abs() < 1e-6);
1737
1738        assert_eq!(quantized.size_bytes(), 500);
1739    }
1740
1741    #[test]
1742    fn test_int4_compression_comparison() {
1743        let size = 10000;
1744        let data: Vec<f32> = (0..size)
1745            .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1746            .collect();
1747        let shape = vec![size];
1748
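        // Expected sizes for 10 000 values: FP32 = 40 000 B, INT8 = 10 000 B (4x),
        // packed INT4 = 5 000 B (8x). Unpacked INT4 stays at 10 000 B because each
        // nibble still occupies a full i8 until pack() is called.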
1749        let fp32_size = size * std::mem::size_of::<f32>();
1750
1751        let int8 = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1752        let int8_size = int8.size_bytes();
1753
1754        let int4_unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1755        let int4_unpacked_size = int4_unpacked.size_bytes();
1756
1757        let mut int4_packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1758        int4_packed.pack();
1759        let int4_packed_size = int4_packed.size_bytes();
1760
1761        println!("\nCompression Comparison:");
1762        println!("  FP32:          {} bytes", fp32_size);
1763        println!(
1764            "  INT8:          {} bytes ({:.1}x)",
1765            int8_size,
1766            fp32_size as f32 / int8_size as f32
1767        );
1768        println!(
1769            "  INT4 unpacked: {} bytes ({:.1}x)",
1770            int4_unpacked_size,
1771            fp32_size as f32 / int4_unpacked_size as f32
1772        );
1773        println!(
1774            "  INT4 packed:   {} bytes ({:.1}x)",
1775            int4_packed_size,
1776            fp32_size as f32 / int4_packed_size as f32
1777        );
1778
1779        assert_eq!(fp32_size / int8_size, 4); // 4x compression
1780        assert_eq!(fp32_size / int4_packed_size, 8); // 8x compression!
1781    }
1782
1783    #[test]
1784    #[ignore] // Run manually with: cargo test test_int4_real_model -- --ignored --nocapture
1785    fn test_int4_real_model() {
1786        use crate::onnx_utils::OnnxModel;
1787
1788        println!("\n{}", "=".repeat(60));
1789        println!("INT4 Real Model Test");
1790        println!("{}", "=".repeat(60));
1791
1792        let model_paths = vec![
1793            "test_models/mnist.onnx",
1794            "mnist.onnx",
1795            "test_models/resnet18-v1-7.onnx",
1796            "resnet18-v1-7.onnx",
1797        ];
1798
1799        let mut model = None;
1800        for path in &model_paths {
1801            if std::path::Path::new(path).exists() {
1802                println!("Loading model: {}", path);
1803                match OnnxModel::load(path) {
1804                    Ok(m) => {
1805                        model = Some(m);
1806                        break;
1807                    }
1808                    Err(e) => println!("  Failed: {}", e),
1809                }
1810            }
1811        }
1812
1813        let model = match model {
1814            Some(m) => m,
1815            None => {
1816                println!("No test models found. Skipping test.");
1817                println!("Place mnist.onnx or resnet18-v1-7.onnx in current directory.");
1818                return;
1819            }
1820        };
1821
1822        let info = model.info();
1823        println!("✓ Model loaded: {}", info.name);
1824        println!("  Nodes: {}", info.num_nodes);
1825        println!();
1826
1827        println!("Extracting weights...");
1828        let weights = model.extract_weights();
1829        println!("✓ Found {} weight tensors", weights.len());
1830
1831        if weights.is_empty() {
1832            println!("No weights to quantize!");
1833            return;
1834        }
1835
1836        println!();
1837        println!("\n{}", "=".repeat(60));
1838        println!("Testing Per-Tensor Quantization");
1839        println!("{}", "=".repeat(60));
1840
1841        let test_weights: Vec<_> = weights
1842            .iter()
1843            .filter(|w| w.data.len() > 1000)
1844            .take(5)
1845            .collect();
1846
1847        println!("Testing {} large layers:\n", test_weights.len());
1848
1849        for (idx, weight) in test_weights.iter().enumerate() {
1850            let name = if weight.name.len() > 40 {
1851                format!("{}...", &weight.name[..37])
1852            } else {
1853                weight.name.clone()
1854            };
1855
1856            println!("[{}] {}", idx + 1, name);
1857            println!(
1858                "    Shape: {:?}, Elements: {}",
1859                weight.shape,
1860                weight.data.len()
1861            );
1862
1863            let fp32_size = weight.data.len() * 4;
1864
1865            let int8_result = QuantizedTensor::from_f32(&weight.data, weight.shape.clone());
1866            let (int8_size, int8_error) = if let Ok(q) = int8_result {
1867                (q.size_bytes(), q.quantization_error(&weight.data))
1868            } else {
1869                println!("    INT8 failed!");
1870                continue;
1871            };
1872
1873            let int4_result = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone());
1874            let (int4_unpacked_size, int4_error) = if let Ok(q) = int4_result {
1875                (q.size_bytes(), q.quantization_error(&weight.data))
1876            } else {
1877                println!("    INT4 failed!");
1878                continue;
1879            };
1880
1881            let mut int4_packed =
1882                QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
1883            int4_packed.pack();
1884            let int4_packed_size = int4_packed.size_bytes();
1885            let int4_packed_error = int4_packed.quantization_error(&weight.data);
1886
1887            println!("    FP32:          {:7} bytes", fp32_size);
1888            println!(
1889                "    INT8:          {:7} bytes ({:.1}x) MSE: {:.8}",
1890                int8_size,
1891                fp32_size as f32 / int8_size as f32,
1892                int8_error
1893            );
1894            println!(
1895                "    INT4 unpacked: {:7} bytes ({:.1}x) MSE: {:.8}",
1896                int4_unpacked_size,
1897                fp32_size as f32 / int4_unpacked_size as f32,
1898                int4_error
1899            );
1900            println!(
1901                "    INT4 packed:   {:7} bytes ({:.1}x) MSE: {:.8}",
1902                int4_packed_size,
1903                fp32_size as f32 / int4_packed_size as f32,
1904                int4_packed_error
1905            );
1906
1907            assert_eq!(int4_error, int4_packed_error, "Packing changed error!");
1908
1909            let int8_ratio = fp32_size as f32 / int8_size as f32;
1910            let int4_ratio = fp32_size as f32 / int4_packed_size as f32;
1911
1912            assert!(
1913                (int8_ratio - 4.0).abs() < 0.1,
1914                "INT8 compression should be ~4x"
1915            );
1916            assert!(
1917                (int4_ratio - 8.0).abs() < 0.1,
1918                "INT4 compression should be ~8x"
1919            );
1920
1921            println!();
1922        }
1923
1924        println!("\n{}", "=".repeat(60));
1925        println!("Testing Per-Channel Quantization");
1926        println!("{}", "=".repeat(60));
1927
1928        // Test per-channel on Conv layers (multi-dimensional)
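        // Axis 0 is treated as the output-channel axis, so any weight with
        // shape[0] > 1 can receive one scale/zero-point per channel.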
1929        let conv_weights: Vec<_> = weights
1930            .iter()
1931            .filter(|w| w.shape.len() >= 2 && w.shape[0] > 1)
1932            .take(3)
1933            .collect();
1934
1935        if conv_weights.is_empty() {
1936            println!("No multi-channel layers found for per-channel test.");
1937        } else {
1938            println!("Testing {} conv layers:\n", conv_weights.len());
1939
1940            for (idx, weight) in conv_weights.iter().enumerate() {
1941                let name = if weight.name.len() > 40 {
1942                    format!("{}...", &weight.name[..37])
1943                } else {
1944                    weight.name.clone()
1945                };
1946
1947                println!("[{}] {}", idx + 1, name);
1948                println!(
1949                    "    Shape: {:?}, Channels: {}",
1950                    weight.shape, weight.shape[0]
1951                );
1952
1953                let per_tensor =
1954                    QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
1955                let per_tensor_error = per_tensor.quantization_error(&weight.data);
1956
1957                let per_channel_result =
1958                    QuantizedTensorInt4::from_f32_per_channel(&weight.data, weight.shape.clone());
1959
1960                if let Ok(per_channel) = per_channel_result {
1961                    let per_channel_error = per_channel.quantization_error(&weight.data);
1962
1963                    let improvement =
1964                        ((per_tensor_error - per_channel_error) / per_tensor_error) * 100.0;
1965
1966                    println!("    Per-tensor:  MSE: {:.8}", per_tensor_error);
1967                    println!(
1968                        "    Per-channel: MSE: {:.8} ({:.1}% better)",
1969                        per_channel_error, improvement
1970                    );
1971
1972                    assert!(
1973                        per_channel_error <= per_tensor_error * 1.1,
1974                        "Per-channel should not be significantly worse"
1975                    );
1976                } else {
1977                    println!("    Per-channel failed!");
1978                }
1979
1980                println!();
1981            }
1982        }
1983
1984        println!("\n{}", "=".repeat(60));
1985        println!("Summary");
1986        println!("{}", "=".repeat(60));
1987
1988        println!("✓ INT4 quantization works on real model weights");
1989        println!("✓ Compression ratios correct (4x INT8, 8x INT4)");
1990        println!("✓ Bit packing is lossless");
1991        println!("✓ Per-channel quantization works");
1992        println!("\nINT4 implementation is ready for CLI integration!");
1993    }
1994
1995    // -----------------------------------------------------------------------
1996    // All-NaN / all-Inf edge cases
1997    // -----------------------------------------------------------------------
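    // A tensor with no finite values has no usable range, so construction must
    // fail with an error rather than emit garbage parameters; when at least one
    // finite value is present, the range is taken from the finite subset
    // (see test_mixed_nan_finite_succeeds).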
1998
1999    #[test]
2000    fn test_all_nan_returns_error() {
2001        let data = vec![f32::NAN, f32::NAN, f32::NAN];
2002        let result = QuantizedTensor::from_f32(&data, vec![3]);
2003        assert!(result.is_err());
2004        let err = result.unwrap_err().to_string();
2005        assert!(
2006            err.contains("non-finite"),
2007            "error should mention non-finite: {}",
2008            err
2009        );
2010    }
2011
2012    #[test]
2013    fn test_all_inf_returns_error() {
2014        let data = vec![f32::INFINITY, f32::NEG_INFINITY];
2015        let result = QuantizedTensor::from_f32(&data, vec![2]);
2016        assert!(result.is_err());
2017    }
2018
2019    #[test]
2020    fn test_all_nan_int4_returns_error() {
2021        let data = vec![f32::NAN; 4];
2022        let result = QuantizedTensorInt4::from_f32(&data, vec![4]);
2023        assert!(result.is_err());
2024    }
2025
2026    #[test]
2027    fn test_all_nan_per_channel_returns_error() {
2028        let data = vec![f32::NAN; 6];
2029        let result = QuantizedTensor::from_f32_per_channel(&data, vec![2, 3]);
2030        assert!(result.is_err());
2031        let err = result.unwrap_err().to_string();
2032        assert!(
2033            err.contains("Channel 0"),
2034            "error should mention channel: {}",
2035            err
2036        );
2037    }
2038
2039    #[test]
2040    fn test_mixed_nan_finite_succeeds() {
2041        // Some NaN, some finite — should succeed using finite range
2042        let data = vec![f32::NAN, 1.0, -1.0, f32::NAN];
2043        let result = QuantizedTensor::from_f32(&data, vec![4]);
2044        assert!(result.is_ok());
2045    }
2046
2047    // -----------------------------------------------------------------------
2048    // Symmetric quantization
2049    // -----------------------------------------------------------------------
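    // The symmetric contract exercised below: scale = abs_max / 127 for INT8
    // (abs_max / 7 for INT4), the zero point is fixed at 0 so that 0.0
    // round-trips exactly, and a constant tensor must still yield a positive
    // scale.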
2050
2051    #[test]
2052    fn test_int8_symmetric_params_zero_point_is_zero() {
2053        let params = QuantParams::from_range_symmetric(-0.5, 2.0);
2054        assert_eq!(params.zero_point(), 0, "symmetric must have zp=0");
2055        // scale = abs_max / QMAX = 2.0 / 127
2056        let expected_scale = 2.0_f32 / 127.0;
2057        assert!(
2058            (params.scale() - expected_scale).abs() < 1e-6,
2059            "scale {} vs expected {}",
2060            params.scale(),
2061            expected_scale
2062        );
2063    }
2064
2065    #[test]
2066    fn test_int4_symmetric_params_zero_point_is_zero() {
2067        let params = QuantParamsInt4::from_range_symmetric(-3.0, 1.0);
2068        assert_eq!(params.zero_point(), 0);
2069        // scale = 3.0 / 7
2070        let expected_scale = 3.0_f32 / 7.0;
2071        assert!((params.scale() - expected_scale).abs() < 1e-6);
2072    }
2073
2074    #[test]
2075    fn test_symmetric_zero_dequantizes_to_zero() {
2076        // The defining property of symmetric quantization: 0.0 → 0 → 0.0 exactly.
2077        let params = QuantParams::from_range_symmetric(-10.0, 10.0);
2078        let q = params.quantize(0.0);
2079        assert_eq!(q, 0);
2080        let dq = params.dequantize(q);
2081        assert_eq!(dq, 0.0);
2082    }
2083
2084    #[test]
2085    fn test_symmetric_asymmetric_produce_different_scales() {
2086        // For a skewed range, asymmetric gives a tighter scale than symmetric.
2087        let asym = QuantParams::from_range(0.0, 10.0);
2088        let sym = QuantParams::from_range_symmetric(0.0, 10.0);
2089        assert_ne!(asym.zero_point(), sym.zero_point());
2090        // Asymmetric packs [0, 10] into [-128, 127] → scale = 10/255 ≈ 0.039
2091        // Symmetric uses abs_max/127 → scale = 10/127 ≈ 0.079
2092        assert!(
2093            sym.scale() > asym.scale(),
2094            "symmetric scale {} should exceed asymmetric {}",
2095            sym.scale(),
2096            asym.scale()
2097        );
2098    }
2099
2100    #[test]
2101    fn test_symmetric_constant_tensor_handled() {
2102        // All-zero tensor would give abs_max = 0 → scale must be clamped away from 0.
2103        let params = QuantParams::from_range_symmetric(0.0, 0.0);
2104        assert!(params.scale() > 0.0);
2105        assert_eq!(params.zero_point(), 0);
2106    }
2107
2108    #[test]
2109    fn test_from_f32_symmetric_tensor_has_zero_zp() {
2110        let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) * 0.1).collect();
2111        let tensor = QuantizedTensor::from_f32_symmetric(&data, vec![100]).unwrap();
2112        assert_eq!(tensor.params().zero_point(), 0);
2113    }
2114
2115    #[test]
2116    fn test_from_f32_per_channel_symmetric_every_channel_zp_zero() {
2117        // 4 channels with different ranges.
2118        let mut data = Vec::new();
2119        for ch in 0..4 {
2120            let scale = (ch + 1) as f32;
2121            for i in 0..16 {
2122                data.push((i as f32 - 8.0) * 0.1 * scale);
2123            }
2124        }
2125        let tensor = QuantizedTensor::from_f32_per_channel_symmetric(&data, vec![4, 16]).unwrap();
2126
2127        let channel_params = tensor
2128            .channel_params
2129            .as_ref()
2130            .expect("per-channel expected");
2131        assert_eq!(channel_params.len(), 4);
2132        for (i, p) in channel_params.iter().enumerate() {
2133            assert_eq!(p.zero_point(), 0, "channel {} zp should be 0", i);
2134            assert!(p.scale() > 0.0, "channel {} scale must be positive", i);
2135        }
2136    }
2137
2138    #[test]
2139    fn test_symmetric_round_trip_error_bounded() {
2140        let data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 250.0).collect();
2141        let tensor = QuantizedTensor::from_f32_symmetric(&data, vec![500]).unwrap();
2142        let mse = tensor.quantization_error(&data);
2143        // Symmetric INT8 on [-1, 1] should still be very accurate.
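        // Rough bound: step ≈ 1 / 127 ≈ 0.0079, and uniform quantization noise is
        // about step² / 12 ≈ 5e-6, comfortably below the 1e-3 threshold.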
2144        assert!(mse < 1e-3, "symmetric MSE unexpectedly high: {}", mse);
2145    }
2146
2147    #[test]
2148    fn test_int4_symmetric_round_trip_error_bounded() {
2149        let data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 250.0).collect();
2150        let tensor = QuantizedTensorInt4::from_f32_symmetric(&data, vec![500]).unwrap();
2151        let mse = tensor.quantization_error(&data);
2152        // INT4 symmetric has ~15 levels over [-1, 1] → expected MSE < ~0.005
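        // Rough bound: step ≈ 1 / 7 ≈ 0.14, so uniform quantization noise is about
        // step² / 12 ≈ 0.0017, which is why the assertion below allows up to 0.01.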
2153        assert!(mse < 0.01, "INT4 symmetric MSE too high: {}", mse);
2154    }
2155
2156    #[test]
2157    fn test_quantizer_symmetric_config_routes_correctly() {
2158        let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
2159        let config = QuantConfig {
2160            bits: 8,
2161            per_channel: true,
2162            symmetric: true,
2163            ..Default::default()
2164        };
2165        let q = Quantizer::new(config)
2166            .quantize_tensor(&data, vec![4, 16])
2167            .unwrap();
2168        let (_, zero_points) = q.get_all_scales_zero_points();
2169        assert!(
2170            zero_points.iter().all(|&z| z == 0),
2171            "all zero_points must be 0 under symmetric config, got {:?}",
2172            zero_points
2173        );
2174    }
2175}