Skip to main content

quantize_rs/quantization/
mod.rs

1//! Core quantization logic for INT8 and INT4.
2//!
3//! Provides tensor-level quantization (per-tensor and per-channel),
4//! INT4 bit-packing, and the high-level [`Quantizer`] that combines
5//! a [`QuantConfig`] with optional calibration statistics.
6
7use crate::errors::{QuantizeError, Result};
8
9/// Configuration for a quantization pass.
10#[derive(Debug, Clone)]
11pub struct QuantConfig {
12    /// Bit width: `4` for INT4 or `8` for INT8.
13    pub bits: u8,
14    /// When `true`, compute separate scale/zero-point per output channel (axis 0).
15    pub per_channel: bool,
16    /// Optional calibration method used for range optimization.
17    pub calibration_method: Option<crate::calibration::methods::CalibrationMethod>,
18    /// Layer names to skip entirely (exact match against the initializer name).
19    pub excluded_layers: Vec<String>,
20    /// Per-layer bit-width overrides.  Key = initializer name, value = 4 or 8.
21    pub layer_bits: std::collections::HashMap<String, u8>,
22    /// Minimum number of elements a tensor must have to be quantized.
23    /// Tensors with fewer elements are left in FP32.  Defaults to 0 (no minimum).
24    pub min_elements: usize,
25}
26
27impl Default for QuantConfig {
28    fn default() -> Self {
29        Self {
30            bits: 8,
31            per_channel: false,
32            calibration_method: None,
33            excluded_layers: Vec::new(),
34            layer_bits: std::collections::HashMap::new(),
35            min_elements: 0,
36        }
37    }
38}
39
40impl QuantConfig {
41    /// Create a default INT8 per-tensor configuration.
42    pub fn int8() -> Self {
43        Self::default()
44    }
45
46    /// Enable or disable per-channel quantization.
47    pub fn with_per_channel(mut self, enabled: bool) -> Self {
48        self.per_channel = enabled;
49        self
50    }
51
52    /// Set the calibration method for range optimization.
53    pub fn with_calibration(
54        mut self,
55        method: crate::calibration::methods::CalibrationMethod,
56    ) -> Self {
57        self.calibration_method = Some(method);
58        self
59    }
60
61    /// Return `true` if the layer should be quantized.
62    ///
63    /// A layer is skipped when:
64    /// - its name appears in [`excluded_layers`], or
65    /// - `num_elements` is below [`min_elements`] (and `min_elements > 0`).
66    pub fn should_quantize(&self, name: &str, num_elements: usize) -> bool {
67        if self.excluded_layers.iter().any(|e| e == name) {
68            return false;
69        }
70        if self.min_elements > 0 && num_elements < self.min_elements {
71            return false;
72        }
73        true
74    }
75
76    /// Return the effective bit width for a layer.
77    ///
78    /// If the layer name has an entry in [`layer_bits`], that value is used;
79    /// otherwise the global [`bits`] is returned.
80    pub fn bits_for_layer(&self, name: &str) -> u8 {
81        self.layer_bits.get(name).copied().unwrap_or(self.bits)
82    }
83}
84
85// ---------------------------------------------------------------------------
86// QuantRange trait and marker types
87// ---------------------------------------------------------------------------
88
/// Compile-time description of a quantized integer range.
///
/// The associated constants supply the clamp bounds and the bit width that
/// the generic quantization types in this module are parameterized over.
pub trait QuantRange: std::fmt::Debug + Clone + Send + Sync + 'static {
    /// Smallest quantized value, inclusive.
    const QMIN: f32;
    /// Largest quantized value, inclusive.
    const QMAX: f32;
    /// Number of bits per quantized value (4 or 8).
    const BITS: u8;
}
98
99/// Marker for INT8 quantization (`-128 … 127`).
100#[derive(Debug, Clone)]
101pub struct Int8Range;
102impl QuantRange for Int8Range {
103    const QMIN: f32 = -128.0;
104    const QMAX: f32 = 127.0;
105    const BITS: u8 = 8;
106}
107
108/// Marker for INT4 quantization (`-8 … 7`).
109#[derive(Debug, Clone)]
110pub struct Int4Range;
111impl QuantRange for Int4Range {
112    const QMIN: f32 = -8.0;
113    const QMAX: f32 = 7.0;
114    const BITS: u8 = 4;
115}
116
117// ---------------------------------------------------------------------------
118// QuantParamsGeneric<R>
119// ---------------------------------------------------------------------------
120
121/// Affine quantization parameters (scale and zero-point), generic over bit-width.
122///
123/// - INT8: `q = clamp(round(x / scale) + zero_point, -128, 127)`
124/// - INT4: `q = clamp(round(x / scale) + zero_point, -8, 7)`
125/// - Dequantization: `x = (q - zero_point) * scale`
126#[derive(Debug, Clone)]
127pub struct QuantParamsGeneric<R: QuantRange> {
128    scale: f32,
129    zero_point: i8,
130    _marker: std::marker::PhantomData<R>,
131}
132
133/// INT8 affine quantization parameters — `clamp(-128, 127)`.
134pub type QuantParams = QuantParamsGeneric<Int8Range>;
135/// INT4 affine quantization parameters — `clamp(-8, 7)`.
136pub type QuantParamsInt4 = QuantParamsGeneric<Int4Range>;
137
138impl<R: QuantRange> QuantParamsGeneric<R> {
139    /// Quantization scale factor.
140    pub fn scale(&self) -> f32 {
141        self.scale
142    }
143    /// Quantization zero point.
144    pub fn zero_point(&self) -> i8 {
145        self.zero_point
146    }
147
148    /// Compute quantization parameters from a floating-point range.
149    pub fn from_range(min: f32, max: f32) -> Self {
150        let min = min.min(0.0);
151        let max = max.max(0.0);
152
153        // Handle constant-value tensors: when min ≈ max the data is (near-)constant.
154        // Use unit scale centred on zero so that the constant dequantizes accurately.
155        let (min, max) = if (max - min).abs() < 1e-8 {
156            let abs = min.abs().max(max.abs()).max(1e-8);
157            (-abs, abs)
158        } else {
159            (min, max)
160        };
161
162        let scale = (max - min) / (R::QMAX - R::QMIN);
163        let scale = scale.max(1e-8);
164
165        let initial_zero_point = R::QMIN - min / scale;
166        // Guard against NaN — if min/scale produced NaN (degenerate input),
167        // fall back to 0 to avoid undefined behaviour on the `as i8` cast.
168        let zero_point = if initial_zero_point.is_finite() {
169            initial_zero_point.round().clamp(R::QMIN, R::QMAX) as i8
170        } else {
171            0i8
172        };
173
174        QuantParamsGeneric {
175            scale,
176            zero_point,
177            _marker: std::marker::PhantomData,
178        }
179    }
180
181    /// Quantize a single float to the target integer type.
182    pub fn quantize(&self, value: f32) -> i8 {
183        if !value.is_finite() {
184            return self.zero_point;
185        }
186        let quantized = (value / self.scale).round() + (self.zero_point as f32);
187        quantized.clamp(R::QMIN, R::QMAX) as i8
188    }
189
190    /// Dequantize a single integer value back to float.
191    pub fn dequantize(&self, value: i8) -> f32 {
192        ((value as i32) - (self.zero_point as i32)) as f32 * self.scale
193    }
194}
195
196// ---------------------------------------------------------------------------
197// QuantizedTensorGeneric<R>
198// ---------------------------------------------------------------------------
199
200/// Generic quantized tensor, parameterized by bit-width marker.
201///
202/// For INT4 tensors, call [`QuantizedTensorGeneric::pack`] to compress two
203/// values per byte for 2× storage savings.
204#[derive(Debug, Clone)]
205pub struct QuantizedTensorGeneric<R: QuantRange> {
206    pub(crate) data: Vec<i8>,
207    /// Bit-packed storage — always `None` for INT8, set by `.pack()` for INT4.
208    pub(crate) packed_data: Option<Vec<u8>>,
209    pub(crate) shape: Vec<usize>,
210    pub(crate) params: QuantParamsGeneric<R>,
211    pub(crate) per_channel: bool,
212    pub(crate) channel_params: Option<Vec<QuantParamsGeneric<R>>>,
213}
214
215/// An INT8 quantized tensor with optional per-channel parameters.
216pub type QuantizedTensor = QuantizedTensorGeneric<Int8Range>;
217
218/// An INT4 quantized tensor with optional per-channel parameters and bit packing.
219///
220/// Values are stored in the range `[-8, 7]`. Call [`pack`](QuantizedTensorInt4::pack) to
221/// compress two values into one byte for 2× storage savings.
222pub type QuantizedTensorInt4 = QuantizedTensorGeneric<Int4Range>;
223
224// ---------------------------------------------------------------------------
225// Shared impl for all bit-widths
226// ---------------------------------------------------------------------------
227
228impl<R: QuantRange> QuantizedTensorGeneric<R> {
229    /// Tensor shape.
230    pub fn shape(&self) -> &[usize] {
231        &self.shape
232    }
233    /// Per-tensor quantization parameters (channel-0 if per-channel).
234    pub fn params(&self) -> &QuantParamsGeneric<R> {
235        &self.params
236    }
237    /// Whether per-channel quantization was used.
238    pub fn is_per_channel(&self) -> bool {
239        self.per_channel
240    }
241
242    /// Quantize FP32 data, computing the range from the data.
243    ///
244    /// # Errors
245    ///
246    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty or shape mismatches.
247    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self> {
248        if data.is_empty() {
249            return Err(QuantizeError::InvalidTensor {
250                reason: "Cannot quantize empty tensor".into(),
251            });
252        }
253
254        let expected_len: usize = shape.iter().product();
255        if expected_len != data.len() {
256            return Err(QuantizeError::InvalidTensor {
257                reason: format!(
258                    "Shape {:?} expects {} elements but got {}",
259                    shape,
260                    expected_len,
261                    data.len()
262                ),
263            });
264        }
265
266        let min = data
267            .iter()
268            .copied()
269            .filter(|v| v.is_finite())
270            .fold(f32::INFINITY, f32::min);
271        let max = data
272            .iter()
273            .copied()
274            .filter(|v| v.is_finite())
275            .fold(f32::NEG_INFINITY, f32::max);
276
277        if !min.is_finite() || !max.is_finite() {
278            return Err(QuantizeError::InvalidTensor {
279                reason: "Tensor contains only non-finite values (NaN/Inf)".into(),
280            });
281        }
282
283        let params = QuantParamsGeneric::<R>::from_range(min, max);
284
285        let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
286
287        Ok(QuantizedTensorGeneric {
288            data: quantized_data,
289            packed_data: None,
290            shape,
291            params,
292            per_channel: false,
293            channel_params: None,
294        })
295    }
296
297    /// Quantize FP32 data using an explicit range (for calibration).
298    ///
299    /// # Errors
300    ///
301    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty or shape mismatches.
302    pub fn from_f32_with_range(
303        data: &[f32],
304        shape: Vec<usize>,
305        min: f32,
306        max: f32,
307    ) -> Result<Self> {
308        if data.is_empty() {
309            return Err(QuantizeError::InvalidTensor {
310                reason: "Cannot quantize empty tensor".into(),
311            });
312        }
313
314        let expected_len: usize = shape.iter().product();
315        if expected_len != data.len() {
316            return Err(QuantizeError::InvalidTensor {
317                reason: format!(
318                    "Shape {:?} expects {} elements but got {}",
319                    shape,
320                    expected_len,
321                    data.len()
322                ),
323            });
324        }
325
326        let params = QuantParamsGeneric::<R>::from_range(min, max);
327
328        let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
329
330        Ok(QuantizedTensorGeneric {
331            data: quantized_data,
332            packed_data: None,
333            shape,
334            params,
335            per_channel: false,
336            channel_params: None,
337        })
338    }
339
340    /// Quantize FP32 data with per-channel ranges (axis 0 only).
341    ///
342    /// # Errors
343    ///
344    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty, shape
345    /// mismatches, or the tensor is scalar.
346    pub fn from_f32_per_channel(data: &[f32], shape: Vec<usize>) -> Result<Self> {
347        if data.is_empty() {
348            return Err(QuantizeError::InvalidTensor {
349                reason: "Cannot quantize empty tensor".into(),
350            });
351        }
352
353        if shape.is_empty() {
354            return Err(QuantizeError::InvalidTensor {
355                reason: "Cannot do per-channel quantization on scalar".into(),
356            });
357        }
358
359        let expected_len: usize = shape.iter().product();
360        if expected_len != data.len() {
361            return Err(QuantizeError::InvalidTensor {
362                reason: format!(
363                    "Shape {:?} expects {} elements but got {}",
364                    shape,
365                    expected_len,
366                    data.len()
367                ),
368            });
369        }
370
371        let num_channels = shape[0];
372
373        let mut channel_params = Vec::new();
374        let mut quantized_data = Vec::with_capacity(data.len());
375
376        for channel_idx in 0..num_channels {
377            let channel_data = extract_channel(data, &shape, channel_idx)?;
378
379            let min = channel_data
380                .iter()
381                .copied()
382                .filter(|v| v.is_finite())
383                .fold(f32::INFINITY, f32::min);
384            let max = channel_data
385                .iter()
386                .copied()
387                .filter(|v| v.is_finite())
388                .fold(f32::NEG_INFINITY, f32::max);
389
390            if !min.is_finite() || !max.is_finite() {
391                return Err(QuantizeError::InvalidTensor {
392                    reason: format!(
393                        "Channel {} contains only non-finite values (NaN/Inf)",
394                        channel_idx
395                    ),
396                });
397            }
398
399            let params = QuantParamsGeneric::<R>::from_range(min, max);
400            channel_params.push(params.clone());
401
402            for &value in &channel_data {
403                quantized_data.push(params.quantize(value));
404            }
405        }
406
407        // Use first channel params as "representative" for backward compatibility
408        let params = channel_params[0].clone();
409
410        Ok(QuantizedTensorGeneric {
411            data: quantized_data,
412            packed_data: None,
413            shape,
414            params,
415            per_channel: true,
416            channel_params: Some(channel_params),
417        })
418    }
419
420    /// Dequantize all values back to FP32.
421    pub fn to_f32(&self) -> Vec<f32> {
422        // Borrow data directly when unpacked; allocate only for the packed INT4 path.
423        let data_owned;
424        let data: &[i8] = if let Some(ref packed) = self.packed_data {
425            data_owned = unpack_int4(packed, self.data.len());
426            &data_owned
427        } else {
428            &self.data
429        };
430
431        if self.per_channel {
432            if let Some(ref channel_params) = self.channel_params {
433                if channel_params.is_empty() {
434                    return data.iter().map(|&v| self.params.dequantize(v)).collect();
435                }
436                let elements_per_channel = data.len() / channel_params.len();
437                data.iter()
438                    .enumerate()
439                    .map(|(i, &v)| {
440                        let channel_idx = (i / elements_per_channel).min(channel_params.len() - 1);
441                        channel_params[channel_idx].dequantize(v)
442                    })
443                    .collect()
444            } else {
445                data.iter().map(|&v| self.params.dequantize(v)).collect()
446            }
447        } else {
448            data.iter().map(|&v| self.params.dequantize(v)).collect()
449        }
450    }
451
452    /// Size of the quantized data in bytes (packed if available, unpacked otherwise).
453    pub fn size_bytes(&self) -> usize {
454        if let Some(ref packed) = self.packed_data {
455            packed.len()
456        } else {
457            self.data.len() * std::mem::size_of::<i8>()
458        }
459    }
460
461    /// Mean squared error between the original data and the dequantized values.
462    pub fn quantization_error(&self, original: &[f32]) -> f32 {
463        if original.is_empty() {
464            return 0.0;
465        }
466
467        let dequantized = self.to_f32();
468
469        let sum: f32 = original
470            .iter()
471            .zip(dequantized.iter())
472            .map(|(a, b)| (a - b).powi(2))
473            .sum();
474
475        sum / original.len() as f32
476    }
477}
478
479// ---------------------------------------------------------------------------
480// INT4-specific methods
481// ---------------------------------------------------------------------------
482
483impl QuantizedTensorGeneric<Int4Range> {
484    /// Pack two INT4 values per byte for 2× compression.
485    pub fn pack(&mut self) {
486        self.packed_data = Some(pack_int4(&self.data));
487    }
488
489    /// Return unpacked i8 data, decompressing from packed storage if needed.
490    pub fn ensure_unpacked(&self) -> Vec<i8> {
491        if let Some(ref packed) = self.packed_data {
492            unpack_int4(packed, self.data.len())
493        } else {
494            self.data.clone()
495        }
496    }
497
498    /// Whether the data is currently bit-packed.
499    pub fn is_packed(&self) -> bool {
500        self.packed_data.is_some()
501    }
502
503    /// Size that the packed representation would occupy (or already occupies).
504    pub fn packed_size_bytes(&self) -> usize {
505        if let Some(ref packed) = self.packed_data {
506            packed.len()
507        } else {
508            self.data.len().div_ceil(2)
509        }
510    }
511
512    /// Size of the unpacked representation in bytes.
513    pub fn unpacked_size_bytes(&self) -> usize {
514        self.data.len() * std::mem::size_of::<i8>()
515    }
516}
517
518// ---------------------------------------------------------------------------
519// INT4 bit-packing helpers
520// ---------------------------------------------------------------------------
521
/// Combine two INT4 values into one byte: `val1` in the high nibble, `val2`
/// in the low nibble (two's-complement nibbles).
///
/// Debug builds assert that both inputs lie in `-8..=7`.
fn pack_int4_pair(val1: i8, val2: i8) -> u8 {
    debug_assert!((-8..=7).contains(&val1), "val1 out of INT4 range: {}", val1);
    debug_assert!((-8..=7).contains(&val2), "val2 out of INT4 range: {}", val2);

    // Shifting a u8 left by 4 discards the upper nibble, which leaves exactly
    // the low four bits of `val1` in the high half of the byte.
    let high = (val1 as u8) << 4;
    let low = (val2 as u8) & 0x0F;

    high | low
}
533
/// Split one byte into its two signed INT4 values: `(high nibble, low nibble)`.
fn unpack_int4_pair(byte: u8) -> (i8, i8) {
    // Place each nibble in the top half of an i8, then shift right by four.
    // `>>` on a signed integer is an arithmetic shift, so the nibble's sign
    // bit is replicated and the result is the sign-extended 4-bit value.
    let high = ((byte & 0xF0) as i8) >> 4;
    let low = ((byte << 4) as i8) >> 4;

    (high, low)
}
553
554/// Pack a slice of INT4 values (two per byte, high nibble first).
555pub fn pack_int4(values: &[i8]) -> Vec<u8> {
556    let mut packed = Vec::with_capacity(values.len().div_ceil(2));
557
558    for chunk in values.chunks(2) {
559        let val1 = chunk[0];
560        let val2 = if chunk.len() > 1 { chunk[1] } else { 0 };
561
562        packed.push(pack_int4_pair(val1, val2));
563    }
564
565    packed
566}
567
568/// Unpack INT4 values from packed bytes, returning exactly `num_values` i8s.
569pub fn unpack_int4(packed: &[u8], num_values: usize) -> Vec<i8> {
570    let mut values = Vec::with_capacity(num_values);
571
572    for &byte in packed {
573        let (val1, val2) = unpack_int4_pair(byte);
574        values.push(val1);
575        if values.len() < num_values {
576            values.push(val2);
577        }
578    }
579
580    // Truncate to exact size (removes padding)
581    values.truncate(num_values);
582    values
583}
584
585/// Extract contiguous data for a single channel along axis 0.
586///
587/// Only correct for axis 0 (the leading dimension), which is the standard
588/// layout for weight tensors (e.g. [out_channels, in_channels, H, W]).
589fn extract_channel(data: &[f32], shape: &[usize], channel_idx: usize) -> Result<Vec<f32>> {
590    if shape.is_empty() {
591        return Err(QuantizeError::InvalidTensor {
592            reason: "Cannot extract channel from empty shape".into(),
593        });
594    }
595    let num_channels = shape[0];
596    if num_channels == 0 {
597        return Err(QuantizeError::InvalidTensor {
598            reason: "Number of channels is 0".into(),
599        });
600    }
601    if channel_idx >= num_channels {
602        return Err(QuantizeError::InvalidTensor {
603            reason: format!(
604                "Channel index {} out of bounds for {} channels",
605                channel_idx, num_channels
606            ),
607        });
608    }
609    if !data.len().is_multiple_of(num_channels) {
610        return Err(QuantizeError::InvalidTensor {
611            reason: format!(
612                "Data length {} not evenly divisible by {} channels",
613                data.len(),
614                num_channels
615            ),
616        });
617    }
618    let elements_per_channel = data.len() / num_channels;
619    let start = channel_idx * elements_per_channel;
620    let end = start + elements_per_channel;
621    Ok(data[start..end].to_vec())
622}
623
624// ---------------------------------------------------------------------------
625// QuantizedTensorType
626// ---------------------------------------------------------------------------
627
628/// Type-erased wrapper over [`QuantizedTensor`] (INT8) and [`QuantizedTensorInt4`] (INT4).
629#[derive(Debug, Clone)]
630pub enum QuantizedTensorType {
631    Int8(QuantizedTensor),
632    Int4(QuantizedTensorInt4),
633}
634
635impl QuantizedTensorType {
636    /// Dequantize all values back to FP32.
637    pub fn to_f32(&self) -> Vec<f32> {
638        match self {
639            QuantizedTensorType::Int8(t) => t.to_f32(),
640            QuantizedTensorType::Int4(t) => t.to_f32(),
641        }
642    }
643
644    /// Size of the quantized data in bytes.
645    pub fn size_bytes(&self) -> usize {
646        match self {
647            QuantizedTensorType::Int8(t) => t.size_bytes(),
648            QuantizedTensorType::Int4(t) => t.size_bytes(),
649        }
650    }
651
652    #[must_use]
653    pub fn quantization_error(&self, original: &[f32]) -> f32 {
654        match self {
655            QuantizedTensorType::Int8(t) => t.quantization_error(original),
656            QuantizedTensorType::Int4(t) => t.quantization_error(original),
657        }
658    }
659
660    #[must_use]
661    pub fn data(&self) -> Vec<i8> {
662        match self {
663            QuantizedTensorType::Int8(t) => t.data.clone(),
664            QuantizedTensorType::Int4(t) => t.ensure_unpacked(),
665        }
666    }
667
668    /// Per-tensor scale and zero-point.
669    pub fn get_scale_zero_point(&self) -> (f32, i8) {
670        match self {
671            QuantizedTensorType::Int8(t) => (t.params.scale, t.params.zero_point),
672            QuantizedTensorType::Int4(t) => (t.params.scale, t.params.zero_point),
673        }
674    }
675
676    /// Return all per-channel scales and zero-points.
677    ///
678    /// For per-tensor quantization, returns single-element vectors.
679    /// For per-channel, returns one entry per channel.
680    pub fn get_all_scales_zero_points(&self) -> (Vec<f32>, Vec<i8>) {
681        match self {
682            QuantizedTensorType::Int8(t) => {
683                if let Some(ref cp) = t.channel_params {
684                    (
685                        cp.iter().map(|p| p.scale).collect(),
686                        cp.iter().map(|p| p.zero_point).collect(),
687                    )
688                } else {
689                    (vec![t.params.scale], vec![t.params.zero_point])
690                }
691            }
692            QuantizedTensorType::Int4(t) => {
693                if let Some(ref cp) = t.channel_params {
694                    (
695                        cp.iter().map(|p| p.scale).collect(),
696                        cp.iter().map(|p| p.zero_point).collect(),
697                    )
698                } else {
699                    (vec![t.params.scale], vec![t.params.zero_point])
700                }
701            }
702        }
703    }
704
705    /// Whether per-channel quantization was used.
706    pub fn is_per_channel(&self) -> bool {
707        match self {
708            QuantizedTensorType::Int8(t) => t.per_channel,
709            QuantizedTensorType::Int4(t) => t.per_channel,
710        }
711    }
712
713    #[must_use]
714    pub fn bits(&self) -> u8 {
715        match self {
716            QuantizedTensorType::Int8(_) => 8,
717            QuantizedTensorType::Int4(_) => 4,
718        }
719    }
720
721    /// `true` if this is an INT8 tensor.
722    pub fn is_int8(&self) -> bool {
723        matches!(self, QuantizedTensorType::Int8(_))
724    }
725
726    /// `true` if this is an INT4 tensor.
727    pub fn is_int4(&self) -> bool {
728        matches!(self, QuantizedTensorType::Int4(_))
729    }
730
731    /// Borrow quantized data without cloning.
732    ///
733    /// Returns `None` for packed INT4 tensors (must use `data()` which unpacks).
734    pub fn data_ref(&self) -> Option<&[i8]> {
735        match self {
736            QuantizedTensorType::Int8(t) => Some(&t.data),
737            QuantizedTensorType::Int4(t) => {
738                if t.packed_data.is_some() {
739                    None // packed: caller must use data() to unpack
740                } else {
741                    Some(&t.data)
742                }
743            }
744        }
745    }
746}
747
748// ---------------------------------------------------------------------------
749// Quantizer
750// ---------------------------------------------------------------------------
751
752/// High-level quantizer that combines configuration with optional calibration.
753pub struct Quantizer {
754    config: QuantConfig,
755    calibration_stats:
756        Option<std::collections::HashMap<String, crate::calibration::stats::ActivationStats>>,
757}
758
759impl std::fmt::Debug for Quantizer {
760    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
761        let stats_count = self.calibration_stats.as_ref().map(|m| m.len());
762        f.debug_struct("Quantizer")
763            .field("config", &self.config)
764            .field("calibration_stats_count", &stats_count)
765            .finish()
766    }
767}
768
769impl Quantizer {
770    /// Create a quantizer with the given configuration (no calibration).
771    pub fn new(config: QuantConfig) -> Self {
772        Self {
773            config,
774            calibration_stats: None,
775        }
776    }
777
778    /// Create a quantizer with configuration and pre-collected activation statistics.
779    pub fn with_calibration(
780        config: QuantConfig,
781        stats: std::collections::HashMap<String, crate::calibration::stats::ActivationStats>,
782    ) -> Self {
783        Self {
784            config,
785            calibration_stats: Some(stats),
786        }
787    }
788
789    /// Quantize a tensor with optional calibration.
790    pub fn quantize_tensor_with_name(
791        &self,
792        name: &str,
793        data: &[f32],
794        shape: Vec<usize>,
795    ) -> Result<QuantizedTensorType> {
796        let (min, max) = if let Some(ref stats_map) = self.calibration_stats {
797            if let Some(stats) = stats_map.get(name) {
798                if let Some(method) = self.config.calibration_method {
799                    use crate::calibration::stats::calculate_optimal_range;
800
801                    let sample_data = sample_from_activation_stats(stats, 1000);
802                    calculate_optimal_range(&sample_data, method)
803                } else {
804                    (stats.min(), stats.max())
805                }
806            } else {
807                finite_min_max(data, name)?
808            }
809        } else {
810            finite_min_max(data, name)?
811        };
812
813        self.quantize_with_range(data, shape, min, max)
814    }
815
816    /// Quantize a tensor using the configured bit width and per-channel setting.
817    ///
818    /// # Errors
819    ///
820    /// Returns [`QuantizeError::InvalidTensor`] or [`QuantizeError::UnsupportedConfig`].
821    pub fn quantize_tensor(&self, data: &[f32], shape: Vec<usize>) -> Result<QuantizedTensorType> {
822        self.build_tensor_with_optional_range(data, shape, None)
823    }
824
825    /// Quantize with specific range (for calibration).
826    ///
827    /// When `per_channel` is enabled, the provided `min`/`max` are ignored
828    /// because per-channel quantization computes separate ranges from the
829    /// weight data for each channel.  The calibration range (derived from
830    /// activation statistics) applies to per-tensor mode only.
831    fn quantize_with_range(
832        &self,
833        data: &[f32],
834        shape: Vec<usize>,
835        min: f32,
836        max: f32,
837    ) -> Result<QuantizedTensorType> {
838        self.build_tensor_with_optional_range(data, shape, Some((min, max)))
839    }
840
841    /// Shared core: build a [`QuantizedTensorType`] for any bit-width and range mode.
842    fn build_tensor_with_optional_range(
843        &self,
844        data: &[f32],
845        shape: Vec<usize>,
846        range: Option<(f32, f32)>,
847    ) -> Result<QuantizedTensorType> {
848        let pc = self.config.per_channel && shape.len() >= 2;
849        match self.config.bits {
850            8 => {
851                let t = match (pc, range) {
852                    (true, _) => QuantizedTensor::from_f32_per_channel(data, shape)?,
853                    (false, Some((min, max))) => {
854                        QuantizedTensor::from_f32_with_range(data, shape, min, max)?
855                    }
856                    (false, None) => QuantizedTensor::from_f32(data, shape)?,
857                };
858                Ok(QuantizedTensorType::Int8(t))
859            }
860            4 => {
861                let mut t = match (pc, range) {
862                    (true, _) => QuantizedTensorInt4::from_f32_per_channel(data, shape)?,
863                    (false, Some((min, max))) => {
864                        QuantizedTensorInt4::from_f32_with_range(data, shape, min, max)?
865                    }
866                    (false, None) => QuantizedTensorInt4::from_f32(data, shape)?,
867                };
868                t.pack();
869                Ok(QuantizedTensorType::Int4(t))
870            }
871            b => Err(QuantizeError::UnsupportedConfig {
872                reason: format!("bits must be 4 or 8, got {b}"),
873            }),
874        }
875    }
876}
877
878// ---------------------------------------------------------------------------
879// Calibration helper
880// ---------------------------------------------------------------------------
881
882/// Compute the finite min/max of `data`, returning an error if all values are NaN/Inf.
883fn finite_min_max(data: &[f32], name: &str) -> Result<(f32, f32)> {
884    let min = data
885        .iter()
886        .copied()
887        .filter(|v| v.is_finite())
888        .fold(f32::INFINITY, f32::min);
889    let max = data
890        .iter()
891        .copied()
892        .filter(|v| v.is_finite())
893        .fold(f32::NEG_INFINITY, f32::max);
894    if !min.is_finite() || !max.is_finite() {
895        return Err(QuantizeError::InvalidTensor {
896            reason: format!(
897                "Tensor '{}' contains only non-finite values (NaN/Inf)",
898                name
899            ),
900        });
901    }
902    Ok((min, max))
903}
904
905/// Sample synthetic data from the observed activation histogram distribution.
906fn sample_from_activation_stats(
907    stats: &crate::calibration::stats::ActivationStats,
908    n: usize,
909) -> Vec<f32> {
910    use rand::Rng;
911
912    let histogram = stats.histogram_data();
913    if histogram.is_empty() {
914        // Fallback to uniform
915        let mut rng = rand::thread_rng();
916        let range = stats.max() - stats.min();
917        if !range.is_finite() || range.abs() < 1e-8 {
918            return vec![stats.mean(); n];
919        }
920        return (0..n)
921            .map(|_| rng.gen::<f32>() * range + stats.min())
922            .collect();
923    }
924
925    let total_count: usize = histogram.iter().map(|&(_, c)| c).sum();
926    if total_count == 0 {
927        let mut rng = rand::thread_rng();
928        let range = stats.max() - stats.min();
929        if !range.is_finite() || range.abs() < 1e-8 {
930            return vec![stats.mean(); n];
931        }
932        return (0..n)
933            .map(|_| rng.gen::<f32>() * range + stats.min())
934            .collect();
935    }
936
937    let mut samples = Vec::with_capacity(n);
938    for &(value, count) in &histogram {
939        let num_samples = ((count as f64 / total_count as f64) * n as f64).round() as usize;
940        for _ in 0..num_samples {
941            samples.push(value);
942        }
943    }
944
945    // Trim or pad to exactly n
946    samples.truncate(n);
947    while samples.len() < n {
948        samples.push(stats.mean());
949    }
950
951    samples
952}
953
954#[cfg(test)]
955mod tests {
956    use super::*;
957
958    // -----------------------------------------------------------------------
959    // QuantConfig per-layer selection
960    // -----------------------------------------------------------------------
961
962    #[test]
963    fn test_should_quantize_no_restrictions() {
964        let config = QuantConfig::default();
965        assert!(config.should_quantize("any.layer", 1));
966        assert!(config.should_quantize("any.layer", 1_000_000));
967    }
968
969    #[test]
970    fn test_should_quantize_excluded_layer() {
971        let config = QuantConfig {
972            excluded_layers: vec!["head.weight".to_string()],
973            ..Default::default()
974        };
975        assert!(!config.should_quantize("head.weight", 1024));
976        assert!(config.should_quantize("body.weight", 1024));
977    }
978
979    #[test]
980    fn test_should_quantize_min_elements() {
981        let config = QuantConfig {
982            min_elements: 512,
983            ..Default::default()
984        };
985        assert!(!config.should_quantize("small.bias", 4));
986        assert!(!config.should_quantize("small.bias", 511));
987        assert!(config.should_quantize("large.weight", 512));
988        assert!(config.should_quantize("large.weight", 1024));
989    }
990
991    #[test]
992    fn test_should_quantize_excluded_takes_priority_over_min_elements() {
993        let config = QuantConfig {
994            excluded_layers: vec!["head.weight".to_string()],
995            min_elements: 1,
996            ..Default::default()
997        };
998        // excluded → skipped regardless of size
999        assert!(!config.should_quantize("head.weight", 1_000_000));
1000    }
1001
1002    #[test]
1003    fn test_bits_for_layer_default() {
1004        let config = QuantConfig {
1005            bits: 8,
1006            ..Default::default()
1007        };
1008        assert_eq!(config.bits_for_layer("any.weight"), 8);
1009    }
1010
1011    #[test]
1012    fn test_bits_for_layer_override() {
1013        let mut layer_bits = std::collections::HashMap::new();
1014        layer_bits.insert("head.weight".to_string(), 4u8);
1015        let config = QuantConfig {
1016            bits: 8,
1017            layer_bits,
1018            ..Default::default()
1019        };
1020        assert_eq!(config.bits_for_layer("head.weight"), 4);
1021        assert_eq!(config.bits_for_layer("body.weight"), 8);
1022    }
1023
1024    // -----------------------------------------------------------------------
1025    // Existing tests below
1026    // -----------------------------------------------------------------------
1027
1028    #[test]
1029    fn test_quant_params() {
1030        let params = QuantParams::from_range(-1.0, 1.0);
1031
1032        assert_eq!(params.quantize(0.0), params.zero_point);
1033
1034        let original = 0.5;
1035        let quantized = params.quantize(original);
1036        let dequantized = params.dequantize(quantized);
1037
1038        assert!((original - dequantized).abs() < 0.01);
1039    }
1040
1041    #[test]
1042    fn test_quantize_tensor() {
1043        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1044        let shape = vec![5];
1045
1046        let quantized = QuantizedTensor::from_f32(&data, shape).unwrap();
1047
1048        assert_eq!(quantized.data.len(), 5);
1049        assert_eq!(quantized.size_bytes(), 5);
1050    }
1051
1052    #[test]
1053    fn test_per_channel_quantization() {
1054        let mut data = vec![];
1055        for _ in 0..100 {
1056            data.push(0.5); // Channel 0
1057        }
1058        for _ in 0..100 {
1059            data.push(5.0); // Channel 1
1060        }
1061
1062        let shape = vec![2, 100];
1063
1064        let quantized = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1065
1066        assert!(quantized.per_channel);
1067        assert!(quantized.channel_params.is_some());
1068        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1069
1070        let dequantized = quantized.to_f32();
1071        let error: f32 = data
1072            .iter()
1073            .zip(dequantized.iter())
1074            .map(|(a, b)| (a - b).powi(2))
1075            .sum::<f32>()
1076            / data.len() as f32;
1077
1078        println!("Per-channel MSE: {}", error);
1079        assert!(error < 0.1);
1080    }
1081
1082    #[test]
1083    fn test_per_channel_vs_per_tensor() {
1084        let mut data = vec![];
1085
1086        for _ in 0..1000 {
1087            data.push(0.01);
1088        }
1089
1090        for _ in 0..1000 {
1091            data.push(10.0);
1092        }
1093
1094        let shape = vec![2, 1000];
1095
1096        // Per-tensor quantization
1097        let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1098        let per_tensor_error = per_tensor.quantization_error(&data);
1099
1100        // Per-channel quantization
1101        let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1102        let per_channel_error = per_channel.quantization_error(&data);
1103
1104        println!("Per-tensor error:  {:.8}", per_tensor_error);
1105        println!("Per-channel error: {:.8}", per_channel_error);
1106
1107        // Per-channel
1108        assert!(per_channel_error < per_tensor_error);
1109        assert!(per_channel_error < per_tensor_error * 0.5);
1110    }
1111
1112    #[test]
1113    fn test_per_channel_benefit() {
1114        let mut data = vec![];
1115
1116        for i in 0..1000 {
1117            data.push(-0.1 + (i as f32 / 1000.0) * 0.2);
1118        }
1119
1120        for i in 0..1000 {
1121            data.push(-10.0 + (i as f32 / 1000.0) * 20.0);
1122        }
1123
1124        let shape = vec![2, 1000];
1125
1126        let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1127        let per_tensor_error = per_tensor.quantization_error(&data);
1128
1129        let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1130        let per_channel_error = per_channel.quantization_error(&data);
1131
1132        println!("Per-tensor MSE:  {:.8}", per_tensor_error);
1133        println!("Per-channel MSE: {:.8}", per_channel_error);
1134
1135        assert!(
1136            per_channel_error < per_tensor_error,
1137            "Per-channel ({:.8}) should be better than per-tensor ({:.8})",
1138            per_channel_error,
1139            per_tensor_error
1140        );
1141    }
1142
1143    #[test]
1144    fn test_int4_quant_params() {
1145        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1146
1147        assert!(params.quantize(-10.0) >= -8);
1148        assert!(params.quantize(-10.0) <= 7);
1149        assert!(params.quantize(10.0) >= -8);
1150        assert!(params.quantize(10.0) <= 7);
1151
1152        let zero_quant = params.quantize(0.0);
1153        assert!(zero_quant >= -8 && zero_quant <= 7);
1154
1155        for &original in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
1156            let quantized = params.quantize(original);
1157            let dequantized = params.dequantize(quantized);
1158
1159            println!(
1160                "Original: {:.2}, Quantized: {}, Dequantized: {:.2}, Error: {:.4}",
1161                original,
1162                quantized,
1163                dequantized,
1164                (original - dequantized).abs()
1165            );
1166
1167            assert!((original - dequantized).abs() < params.scale * 2.0);
1168        }
1169    }
1170
1171    #[test]
1172    fn test_int4_extreme_values() {
1173        // Test with extreme value ranges
1174        let params = QuantParamsInt4::from_range(-100.0, 100.0);
1175
1176        let q_neg = params.quantize(-100.0);
1177        let q_pos = params.quantize(100.0);
1178
1179        assert_eq!(q_neg, -8);
1180        assert_eq!(q_pos, 7);
1181    }
1182
1183    #[test]
1184    fn test_int4_vs_int8_error() {
1185        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1186
1187        let params_int8 = QuantParams::from_range(-1.0, 1.0);
1188        let error_int8: f32 = data
1189            .iter()
1190            .map(|&v| {
1191                let q = params_int8.quantize(v);
1192                let dq = params_int8.dequantize(q);
1193                (v - dq).powi(2)
1194            })
1195            .sum::<f32>()
1196            / data.len() as f32;
1197
1198        let params_int4 = QuantParamsInt4::from_range(-1.0, 1.0);
1199        let error_int4: f32 = data
1200            .iter()
1201            .map(|&v| {
1202                let q = params_int4.quantize(v);
1203                let dq = params_int4.dequantize(q);
1204                (v - dq).powi(2)
1205            })
1206            .sum::<f32>()
1207            / data.len() as f32;
1208
1209        println!("INT8 MSE: {:.8}", error_int8);
1210        println!("INT4 MSE: {:.8}", error_int4);
1211
1212        assert!(error_int4 > error_int8);
1213
1214        assert!(
1215            error_int4 < error_int8 * 500.0,
1216            "INT4 error ({:.8}) is too high compared to INT8 ({:.8})",
1217            error_int4,
1218            error_int8
1219        );
1220
1221        assert!(error_int4.is_finite());
1222        assert!(error_int4 < 0.01);
1223    }
1224
1225    #[test]
1226    fn test_int4_range() {
1227        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1228
1229        assert!(params.quantize(-10.0) == -8);
1230        assert!(params.quantize(10.0) == 7);
1231
1232        // Test quantization within range
1233        for i in -8..=7 {
1234            let value = i as f32 * params.scale;
1235            let quantized = params.quantize(value);
1236            assert!(quantized >= -8 && quantized <= 7);
1237        }
1238    }
1239
1240    #[test]
1241    fn test_int4_optimal_precision() {
1242        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1243
1244        let mut unique_values = std::collections::HashSet::new();
1245
1246        // Sample across the range
1247        for i in 0..1000 {
1248            let value = -1.0 + (i as f32 / 1000.0) * 2.0;
1249            unique_values.insert(params.quantize(value));
1250        }
1251
1252        println!("Unique quantized values: {}", unique_values.len());
1253        assert!(unique_values.len() >= 14);
1254    }
1255
1256    #[test]
1257    fn test_int4_tensor_quantization() {
1258        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1259        let shape = vec![5];
1260
1261        let quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1262
1263        assert_eq!(quantized.data.len(), 5);
1264        assert_eq!(quantized.size_bytes(), 5);
1265        assert_eq!(quantized.packed_size_bytes(), 3);
1266
1267        for &val in &quantized.data {
1268            assert!(val >= -8 && val <= 7, "Value {} out of INT4 range", val);
1269        }
1270    }
1271
1272    #[test]
1273    fn test_int4_round_trip() {
1274        let original = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1275        let shape = vec![5];
1276
1277        let quantized = QuantizedTensorInt4::from_f32(&original, shape).unwrap();
1278        let dequantized = quantized.to_f32();
1279
1280        println!("Original:    {:?}", original);
1281        println!("Quantized:   {:?}", quantized.data);
1282        println!("Dequantized: {:?}", dequantized);
1283
1284        for (orig, deq) in original.iter().zip(dequantized.iter()) {
1285            let error = (orig - deq).abs();
1286            println!("  {:.2} -> {:.2}, error: {:.4}", orig, deq, error);
1287            assert!(error < 0.15, "Error too large: {}", error);
1288        }
1289    }
1290
1291    #[test]
1292    fn test_int4_per_channel() {
1293        let mut data = vec![];
1294
1295        // Channel 0: small range [-0.1, 0.1]
1296        for i in 0..100 {
1297            data.push(-0.1 + (i as f32 / 100.0) * 0.2);
1298        }
1299
1300        // Channel 1: large range [-10.0, 10.0]
1301        for i in 0..100 {
1302            data.push(-10.0 + (i as f32 / 100.0) * 20.0);
1303        }
1304
1305        let shape = vec![2, 100];
1306
1307        let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1308
1309        assert!(quantized.per_channel);
1310        assert!(quantized.channel_params.is_some());
1311        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1312
1313        let error = quantized.quantization_error(&data);
1314        println!("INT4 per-channel MSE: {:.8}", error);
1315
1316        assert!(error < 1.0, "Error too high: {}", error);
1317    }
1318
1319    #[test]
1320    fn test_int4_vs_int8_compression() {
1321        let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1322        let shape = vec![1000];
1323
1324        let int8_quantized = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1325        let int8_size = int8_quantized.size_bytes();
1326        let int8_error = int8_quantized.quantization_error(&data);
1327
1328        let int4_quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1329        let int4_size = int4_quantized.size_bytes();
1330        let int4_packed_size = int4_quantized.packed_size_bytes();
1331        let int4_error = int4_quantized.quantization_error(&data);
1332
1333        println!("INT8: {} bytes, MSE: {:.8}", int8_size, int8_error);
1334        println!(
1335            "INT4 (unpacked): {} bytes, MSE: {:.8}",
1336            int4_size, int4_error
1337        );
1338        println!(
1339            "INT4 (packed): {} bytes, MSE: {:.8}",
1340            int4_packed_size, int4_error
1341        );
1342
1343        assert_eq!(int4_size, int8_size);
1344
1345        assert!(int4_packed_size <= int8_size / 2 + 1);
1346
1347        assert!(int4_error > int8_error);
1348
1349        assert!(int4_error < 0.01, "INT4 error too high: {}", int4_error);
1350    }
1351
1352    #[test]
1353    fn test_int4_large_tensor() {
1354        let size = 64 * 3 * 3 * 3; // 64 filters, 3x3x3 kernels
1355        let data: Vec<f32> = (0..size)
1356            .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1357            .collect();
1358
1359        let shape = vec![64, 3, 3, 3];
1360
1361        let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1362
1363        assert_eq!(quantized.data.len(), size);
1364        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 64);
1365
1366        let error = quantized.quantization_error(&data);
1367        println!("Large tensor INT4 error: {:.8}", error);
1368
1369        assert!(error < 0.01, "Error too high for large tensor: {}", error);
1370    }
1371
1372    #[test]
1373    fn test_int4_extreme_ranges() {
1374        let test_cases = vec![
1375            (vec![-0.001, 0.0, 0.001], "tiny range"),
1376            (vec![-100.0, 0.0, 100.0], "large range"),
1377            (vec![0.0, 0.0, 0.0], "all zeros"),
1378            (vec![1.0, 1.0, 1.0], "all same"),
1379        ];
1380
1381        for (data, desc) in test_cases {
1382            println!("\nTesting: {}", desc);
1383            let shape = vec![data.len()];
1384
1385            let result = QuantizedTensorInt4::from_f32(&data, shape);
1386            assert!(result.is_ok(), "Failed on {}", desc);
1387
1388            let quantized = result.unwrap();
1389            let dequantized = quantized.to_f32();
1390
1391            println!("  Original:    {:?}", data);
1392            println!("  Dequantized: {:?}", dequantized);
1393
1394            for &val in &quantized.data {
1395                assert!(
1396                    val >= -8 && val <= 7,
1397                    "Value {} out of range for {}",
1398                    val,
1399                    desc
1400                );
1401            }
1402        }
1403    }
1404
1405    #[test]
1406    fn test_int4_pack_unpack_pair() {
1407        let test_cases = vec![
1408            (-8, 7),
1409            (-8, -8),
1410            (7, 7),
1411            (0, 0),
1412            (-1, 0),
1413            (0, -1),
1414            (-5, 3),
1415            (6, -4),
1416        ];
1417
1418        for (val1, val2) in test_cases {
1419            println!("\nTesting: ({}, {})", val1, val2);
1420
1421            let packed = pack_int4_pair(val1, val2);
1422            let (unpacked1, unpacked2) = unpack_int4_pair(packed);
1423
1424            println!("  Packed: 0x{:02X} (binary: {:08b})", packed, packed);
1425            println!("  Unpacked: ({}, {})", unpacked1, unpacked2);
1426
1427            assert_eq!(val1, unpacked1, "First value mismatch");
1428            assert_eq!(val2, unpacked2, "Second value mismatch");
1429        }
1430    }
1431
1432    #[test]
1433    fn test_int4_pack_unpack_vector() {
1434        let values = vec![-8, -7, -1, 0, 1, 7];
1435        let packed = pack_int4(&values);
1436        let unpacked = unpack_int4(&packed, values.len());
1437
1438        println!("\nEven length:");
1439        println!("  Original: {:?}", values);
1440        println!("  Packed:   {:?} ({} bytes)", packed, packed.len());
1441        println!("  Unpacked: {:?}", unpacked);
1442
1443        assert_eq!(values, unpacked);
1444        assert_eq!(packed.len(), (values.len() + 1) / 2);
1445    }
1446
1447    #[test]
1448    fn test_int4_pack_unpack_odd_length() {
1449        let values = vec![-8, -5, 0, 5, 7];
1450        let packed = pack_int4(&values);
1451        let unpacked = unpack_int4(&packed, values.len());
1452
1453        println!("\nOdd length:");
1454        println!("  Original: {:?}", values);
1455        println!("  Packed:   {:?} ({} bytes)", packed, packed.len());
1456        println!("  Unpacked: {:?}", unpacked);
1457
1458        assert_eq!(values, unpacked);
1459        assert_eq!(packed.len(), (values.len() + 1) / 2);
1460    }
1461
1462    #[test]
1463    fn test_int4_pack_all_values() {
1464        let values: Vec<i8> = (-8..=7).collect();
1465        let packed = pack_int4(&values);
1466        let unpacked = unpack_int4(&packed, values.len());
1467
1468        println!("\nAll INT4 values:");
1469        println!("  Original: {:?}", values);
1470        println!("  Packed:   {} bytes", packed.len());
1471        println!("  Unpacked: {:?}", unpacked);
1472
1473        assert_eq!(values, unpacked);
1474        assert_eq!(packed.len(), 8);
1475    }
1476
1477    #[test]
1478    fn test_int4_pack_large_vector() {
1479        let values: Vec<i8> = (0..1000).map(|i| ((i % 16) - 8) as i8).collect();
1480        let packed = pack_int4(&values);
1481        let unpacked = unpack_int4(&packed, values.len());
1482
1483        assert_eq!(values, unpacked);
1484        assert_eq!(packed.len(), 500);
1485
1486        println!("\nLarge vector:");
1487        println!("  Original: {} values", values.len());
1488        println!(
1489            "  Packed:   {} bytes ({}x compression)",
1490            packed.len(),
1491            values.len() / packed.len()
1492        );
1493        println!("  Unpacked: {} values", unpacked.len());
1494    }
1495
1496    #[test]
1497    fn test_int4_compression_ratio() {
1498        let size = 10000;
1499        let values: Vec<i8> = (0..size).map(|i| ((i % 16) - 8) as i8).collect();
1500
1501        let unpacked_size = values.len() * std::mem::size_of::<i8>();
1502
1503        let packed = pack_int4(&values);
1504        let packed_size = packed.len();
1505
1506        let compression_ratio = unpacked_size as f32 / packed_size as f32;
1507
1508        println!("\nCompression test:");
1509        println!("  Values:      {}", size);
1510        println!("  Unpacked:    {} bytes", unpacked_size);
1511        println!("  Packed:      {} bytes", packed_size);
1512        println!("  Compression: {:.2}x", compression_ratio);
1513
1514        assert!(
1515            (compression_ratio - 2.0).abs() < 0.01,
1516            "Expected ~2x compression, got {:.2}x",
1517            compression_ratio
1518        );
1519    }
1520
1521    #[test]
1522    fn test_int4_tensor_packing() {
1523        let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1524        let shape = vec![1000];
1525
1526        let mut quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1527
1528        println!("Before packing:");
1529        println!("  Unpacked size: {} bytes", quantized.unpacked_size_bytes());
1530        println!("  Is packed: {}", quantized.is_packed());
1531
1532        assert!(!quantized.is_packed());
1533        assert_eq!(quantized.size_bytes(), 1000);
1534
1535        quantized.pack();
1536
1537        println!("\nAfter packing:");
1538        println!("  Packed size: {} bytes", quantized.size_bytes());
1539        println!("  Is packed: {}", quantized.is_packed());
1540        println!(
1541            "  Compression: {}x",
1542            quantized.unpacked_size_bytes() / quantized.size_bytes()
1543        );
1544
1545        assert!(quantized.is_packed());
1546        assert_eq!(quantized.size_bytes(), 500);
1547
1548        let dequantized = quantized.to_f32();
1549        assert_eq!(dequantized.len(), 1000);
1550
1551        let error = quantized.quantization_error(&data);
1552        println!("  MSE after packing: {:.8}", error);
1553        assert!(error < 0.01);
1554    }
1555
1556    #[test]
1557    fn test_int4_packed_vs_unpacked_error() {
1558        let data: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) * 2.0 - 1.0).collect();
1559        let shape = vec![100];
1560
1561        let unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1562        let error_unpacked = unpacked.quantization_error(&data);
1563
1564        let mut packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1565        packed.pack();
1566        let error_packed = packed.quantization_error(&data);
1567
1568        println!("Unpacked error: {:.8}", error_unpacked);
1569        println!("Packed error:   {:.8}", error_packed);
1570
1571        assert!((error_unpacked - error_packed).abs() < 1e-6);
1572    }
1573
1574    #[test]
1575    fn test_int4_per_channel_packing() {
1576        let mut data = vec![];
1577        for i in 0..500 {
1578            data.push((i as f32 / 500.0) * 0.2 - 0.1); // Channel 0
1579        }
1580        for i in 0..500 {
1581            data.push((i as f32 / 500.0) * 20.0 - 10.0); // Channel 1
1582        }
1583
1584        let shape = vec![2, 500];
1585
1586        let mut quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1587
1588        let error_before = quantized.quantization_error(&data);
1589        println!("Error before packing: {:.8}", error_before);
1590
1591        quantized.pack();
1592
1593        let error_after = quantized.quantization_error(&data);
1594        println!("Error after packing:  {:.8}", error_after);
1595        println!(
1596            "Size: {} bytes (packed from {} bytes)",
1597            quantized.size_bytes(),
1598            quantized.unpacked_size_bytes()
1599        );
1600
1601        assert!((error_before - error_after).abs() < 1e-6);
1602
1603        assert_eq!(quantized.size_bytes(), 500);
1604    }
1605
1606    #[test]
1607    fn test_int4_compression_comparison() {
1608        let size = 10000;
1609        let data: Vec<f32> = (0..size)
1610            .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1611            .collect();
1612        let shape = vec![size];
1613
1614        let fp32_size = size * std::mem::size_of::<f32>();
1615
1616        let int8 = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1617        let int8_size = int8.size_bytes();
1618
1619        let int4_unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1620        let int4_unpacked_size = int4_unpacked.size_bytes();
1621
1622        let mut int4_packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1623        int4_packed.pack();
1624        let int4_packed_size = int4_packed.size_bytes();
1625
1626        println!("\nCompression Comparison:");
1627        println!("  FP32:          {} bytes", fp32_size);
1628        println!(
1629            "  INT8:          {} bytes ({:.1}x)",
1630            int8_size,
1631            fp32_size as f32 / int8_size as f32
1632        );
1633        println!(
1634            "  INT4 unpacked: {} bytes ({:.1}x)",
1635            int4_unpacked_size,
1636            fp32_size as f32 / int4_unpacked_size as f32
1637        );
1638        println!(
1639            "  INT4 packed:   {} bytes ({:.1}x)",
1640            int4_packed_size,
1641            fp32_size as f32 / int4_packed_size as f32
1642        );
1643
1644        assert_eq!(fp32_size / int8_size, 4); // 4x compression
1645        assert_eq!(fp32_size / int4_packed_size, 8); // 8x compression!
1646    }
1647
1648    #[test]
1649    #[ignore] // Run manually with: cargo test test_int4_real_model -- --ignored --nocapture
1650    fn test_int4_real_model() {
1651        use crate::onnx_utils::OnnxModel;
1652
1653        println!("\n{}", "=".repeat(60));
1654        println!("INT4 Real Model Test");
1655        println!("\n{}", "=".repeat(60));
1656
1657        let model_paths = vec![
1658            "test_models/mnist.onnx",
1659            "mnist.onnx",
1660            "test_models/resnet18-v1-7.onnx",
1661            "resnet18-v1-7.onnx",
1662        ];
1663
1664        let mut model = None;
1665        for path in &model_paths {
1666            if std::path::Path::new(path).exists() {
1667                println!("Loading model: {}", path);
1668                match OnnxModel::load(path) {
1669                    Ok(m) => {
1670                        model = Some(m);
1671                        break;
1672                    }
1673                    Err(e) => println!("  Failed: {}", e),
1674                }
1675            }
1676        }
1677
1678        let model = match model {
1679            Some(m) => m,
1680            None => {
1681                println!("No test models found. Skipping test.");
1682                println!("Place mnist.onnx or resnet18-v1-7.onnx in current directory.");
1683                return;
1684            }
1685        };
1686
1687        let info = model.info();
1688        println!("✓ Model loaded: {}", info.name);
1689        println!("  Nodes: {}", info.num_nodes);
1690        println!();
1691
1692        println!("Extracting weights...");
1693        let weights = model.extract_weights();
1694        println!("✓ Found {} weight tensors", weights.len());
1695
1696        if weights.is_empty() {
1697            println!("No weights to quantize!");
1698            return;
1699        }
1700
1701        println!();
1702        println!("\n{}", "=".repeat(60));
1703        println!("Testing Per-Tensor Quantization");
1704        println!("\n{}", "=".repeat(60));
1705
1706        let test_weights: Vec<_> = weights
1707            .iter()
1708            .filter(|w| w.data.len() > 1000)
1709            .take(5)
1710            .collect();
1711
1712        println!("Testing {} large layers:\n", test_weights.len());
1713
1714        for (idx, weight) in test_weights.iter().enumerate() {
1715            let name = if weight.name.len() > 40 {
1716                format!("{}...", &weight.name[..37])
1717            } else {
1718                weight.name.clone()
1719            };
1720
1721            println!("[{}] {}", idx + 1, name);
1722            println!(
1723                "    Shape: {:?}, Elements: {}",
1724                weight.shape,
1725                weight.data.len()
1726            );
1727
1728            let fp32_size = weight.data.len() * 4;
1729
1730            let int8_result = QuantizedTensor::from_f32(&weight.data, weight.shape.clone());
1731            let (int8_size, int8_error) = if let Ok(q) = int8_result {
1732                (q.size_bytes(), q.quantization_error(&weight.data))
1733            } else {
1734                println!("    INT8 failed!");
1735                continue;
1736            };
1737
1738            let int4_result = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone());
1739            let (int4_unpacked_size, int4_error) = if let Ok(q) = int4_result {
1740                (q.size_bytes(), q.quantization_error(&weight.data))
1741            } else {
1742                println!("    INT4 failed!");
1743                continue;
1744            };
1745
1746            let mut int4_packed =
1747                QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
1748            int4_packed.pack();
1749            let int4_packed_size = int4_packed.size_bytes();
1750            let int4_packed_error = int4_packed.quantization_error(&weight.data);
1751
1752            println!("    FP32:          {:7} bytes", fp32_size);
1753            println!(
1754                "    INT8:          {:7} bytes ({:.1}x) MSE: {:.8}",
1755                int8_size,
1756                fp32_size as f32 / int8_size as f32,
1757                int8_error
1758            );
1759            println!(
1760                "    INT4 unpacked: {:7} bytes ({:.1}x) MSE: {:.8}",
1761                int4_unpacked_size,
1762                fp32_size as f32 / int4_unpacked_size as f32,
1763                int4_error
1764            );
1765            println!(
1766                "    INT4 packed:   {:7} bytes ({:.1}x) MSE: {:.8}",
1767                int4_packed_size,
1768                fp32_size as f32 / int4_packed_size as f32,
1769                int4_packed_error
1770            );
1771
1772            assert_eq!(int4_error, int4_packed_error, "Packing changed error!");
1773
1774            let int8_ratio = fp32_size as f32 / int8_size as f32;
1775            let int4_ratio = fp32_size as f32 / int4_packed_size as f32;
1776
1777            assert!(
1778                (int8_ratio - 4.0).abs() < 0.1,
1779                "INT8 compression should be ~4x"
1780            );
1781            assert!(
1782                (int4_ratio - 8.0).abs() < 0.1,
1783                "INT4 compression should be ~8x"
1784            );
1785
1786            println!();
1787        }
1788
1789        println!("\n{}", "=".repeat(60));
1790        println!("Testing Per-Channel Quantization");
1791        println!("\n{}", "=".repeat(60));
1792
1793        // Test per-channel on Conv layers (multi-dimensional)
1794        let conv_weights: Vec<_> = weights
1795            .iter()
1796            .filter(|w| w.shape.len() >= 2 && w.shape[0] > 1)
1797            .take(3)
1798            .collect();
1799
1800        if conv_weights.is_empty() {
1801            println!("No multi-channel layers found for per-channel test.");
1802        } else {
1803            println!("Testing {} conv layers:\n", conv_weights.len());
1804
1805            for (idx, weight) in conv_weights.iter().enumerate() {
1806                let name = if weight.name.len() > 40 {
1807                    format!("{}...", &weight.name[..37])
1808                } else {
1809                    weight.name.clone()
1810                };
1811
1812                println!("[{}] {}", idx + 1, name);
1813                println!(
1814                    "    Shape: {:?}, Channels: {}",
1815                    weight.shape, weight.shape[0]
1816                );
1817
1818                let per_tensor =
1819                    QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
1820                let per_tensor_error = per_tensor.quantization_error(&weight.data);
1821
1822                let per_channel_result =
1823                    QuantizedTensorInt4::from_f32_per_channel(&weight.data, weight.shape.clone());
1824
1825                if let Ok(per_channel) = per_channel_result {
1826                    let per_channel_error = per_channel.quantization_error(&weight.data);
1827
1828                    let improvement =
1829                        ((per_tensor_error - per_channel_error) / per_tensor_error) * 100.0;
1830
1831                    println!("    Per-tensor:  MSE: {:.8}", per_tensor_error);
1832                    println!(
1833                        "    Per-channel: MSE: {:.8} ({:.1}% better)",
1834                        per_channel_error, improvement
1835                    );
1836
1837                    assert!(
1838                        per_channel_error <= per_tensor_error * 1.1,
1839                        "Per-channel should not be significantly worse"
1840                    );
1841                } else {
1842                    println!("    Per-channel failed!");
1843                }
1844
1845                println!();
1846            }
1847        }
1848
1849        println!("\n{}", "=".repeat(60));
1850        println!("Summary");
1851        println!("\n{}", "=".repeat(60));
1852
1853        println!("✓ INT4 quantization works on real model weights");
1854        println!("✓ Compression ratios correct (4x INT8, 8x INT4)");
1855        println!("✓ Bit packing is lossless");
1856        println!("✓ Per-channel quantization works");
1857        println!("\nINT4 implementation is ready for CLI integration!");
1858    }
1859
1860    // -----------------------------------------------------------------------
1861    // All-NaN / all-Inf edge cases
1862    // -----------------------------------------------------------------------
1863
1864    #[test]
1865    fn test_all_nan_returns_error() {
1866        let data = vec![f32::NAN, f32::NAN, f32::NAN];
1867        let result = QuantizedTensor::from_f32(&data, vec![3]);
1868        assert!(result.is_err());
1869        let err = result.unwrap_err().to_string();
1870        assert!(
1871            err.contains("non-finite"),
1872            "error should mention non-finite: {}",
1873            err
1874        );
1875    }
1876
1877    #[test]
1878    fn test_all_inf_returns_error() {
1879        let data = vec![f32::INFINITY, f32::NEG_INFINITY];
1880        let result = QuantizedTensor::from_f32(&data, vec![2]);
1881        assert!(result.is_err());
1882    }
1883
1884    #[test]
1885    fn test_all_nan_int4_returns_error() {
1886        let data = vec![f32::NAN; 4];
1887        let result = QuantizedTensorInt4::from_f32(&data, vec![4]);
1888        assert!(result.is_err());
1889    }
1890
1891    #[test]
1892    fn test_all_nan_per_channel_returns_error() {
1893        let data = vec![f32::NAN; 6];
1894        let result = QuantizedTensor::from_f32_per_channel(&data, vec![2, 3]);
1895        assert!(result.is_err());
1896        let err = result.unwrap_err().to_string();
1897        assert!(
1898            err.contains("Channel 0"),
1899            "error should mention channel: {}",
1900            err
1901        );
1902    }
1903
1904    #[test]
1905    fn test_mixed_nan_finite_succeeds() {
1906        // Some NaN, some finite — should succeed using finite range
1907        let data = vec![f32::NAN, 1.0, -1.0, f32::NAN];
1908        let result = QuantizedTensor::from_f32(&data, vec![4]);
1909        assert!(result.is_ok());
1910    }
1911}