// quantize_rs/quantization/mod.rs
//! Core quantization logic for INT8 and INT4.
//!
//! Provides tensor-level quantization (per-tensor and per-channel),
//! INT4 bit-packing, and the high-level [`Quantizer`] that combines
//! a [`QuantConfig`] with optional calibration statistics.

use crate::errors::{QuantizeError, Result};
8
/// Configuration for a quantization pass.
///
/// See [`QuantConfig::should_quantize`] and [`QuantConfig::bits_for_layer`]
/// for how the per-layer fields are applied.
#[derive(Debug, Clone)]
pub struct QuantConfig {
    /// Bit width: `4` for INT4 or `8` for INT8.
    pub bits: u8,
    /// When `true`, compute separate scale/zero-point per output channel (axis 0).
    pub per_channel: bool,
    /// Optional calibration method used for range optimization.
    pub calibration_method: Option<crate::calibration::methods::CalibrationMethod>,
    /// Layer names to skip entirely (exact match against the initializer name).
    pub excluded_layers: Vec<String>,
    /// Per-layer bit-width overrides.  Key = initializer name, value = 4 or 8.
    pub layer_bits: std::collections::HashMap<String, u8>,
    /// Minimum number of elements a tensor must have to be quantized.
    /// Tensors with fewer elements are left in FP32.  Defaults to 0 (no minimum).
    pub min_elements: usize,
}
26
27impl Default for QuantConfig {
28    fn default() -> Self {
29        Self {
30            bits: 8,
31            per_channel: false,
32            calibration_method: None,
33            excluded_layers: Vec::new(),
34            layer_bits: std::collections::HashMap::new(),
35            min_elements: 0,
36        }
37    }
38}
39
40impl QuantConfig {
41    /// Create a default INT8 per-tensor configuration.
42    pub fn int8() -> Self {
43        Self::default()
44    }
45
46    /// Enable or disable per-channel quantization.
47    pub fn with_per_channel(mut self, enabled: bool) -> Self {
48        self.per_channel = enabled;
49        self
50    }
51
52    /// Set the calibration method for range optimization.
53    pub fn with_calibration(mut self, method: crate::calibration::methods::CalibrationMethod) -> Self {
54        self.calibration_method = Some(method);
55        self
56    }
57
58    /// Return `true` if the layer should be quantized.
59    ///
60    /// A layer is skipped when:
61    /// - its name appears in [`excluded_layers`], or
62    /// - `num_elements` is below [`min_elements`] (and `min_elements > 0`).
63    pub fn should_quantize(&self, name: &str, num_elements: usize) -> bool {
64        if self.excluded_layers.iter().any(|e| e == name) {
65            return false;
66        }
67        if self.min_elements > 0 && num_elements < self.min_elements {
68            return false;
69        }
70        true
71    }
72
73    /// Return the effective bit width for a layer.
74    ///
75    /// If the layer name has an entry in [`layer_bits`], that value is used;
76    /// otherwise the global [`bits`] is returned.
77    pub fn bits_for_layer(&self, name: &str) -> u8 {
78        self.layer_bits.get(name).copied().unwrap_or(self.bits)
79    }
80}
81
// ---------------------------------------------------------------------------
// QuantRange trait and marker types
// ---------------------------------------------------------------------------
85
/// Marker trait that supplies the clamp constants for a quantization bit-width.
///
/// Implementors ([`Int8Range`], [`Int4Range`]) are zero-sized markers used as
/// the type parameter of [`QuantParamsGeneric`] and [`QuantizedTensorGeneric`].
pub trait QuantRange: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Minimum quantized value (inclusive).
    const QMIN: f32;
    /// Maximum quantized value (inclusive).
    const QMAX: f32;
    /// Bit width (4 or 8).
    const BITS: u8;
}
95
/// Marker for INT8 quantization (`-128 … 127`).
#[derive(Debug, Clone)]
pub struct Int8Range;

impl QuantRange for Int8Range {
    const QMIN: f32 = -128.0;
    const QMAX: f32 = 127.0;
    const BITS: u8 = 8;
}
104
/// Marker for INT4 quantization (`-8 … 7`).
#[derive(Debug, Clone)]
pub struct Int4Range;

impl QuantRange for Int4Range {
    const QMIN: f32 = -8.0;
    const QMAX: f32 = 7.0;
    const BITS: u8 = 4;
}
113
// ---------------------------------------------------------------------------
// QuantParamsGeneric<R>
// ---------------------------------------------------------------------------
117
/// Affine quantization parameters (scale and zero-point), generic over bit-width.
///
/// - INT8: `q = clamp(round(x / scale) + zero_point, -128, 127)`
/// - INT4: `q = clamp(round(x / scale) + zero_point, -8, 7)`
/// - Dequantization: `x = (q - zero_point) * scale`
#[derive(Debug, Clone)]
pub struct QuantParamsGeneric<R: QuantRange> {
    // Step size between adjacent quantized levels (always > 0, see `from_range`).
    scale: f32,
    // Quantized value that dequantizes to exactly 0.0.
    zero_point: i8,
    // Zero-sized marker tying the params to a bit-width `R`.
    _marker: std::marker::PhantomData<R>,
}

/// INT8 affine quantization parameters — `clamp(-128, 127)`.
pub type QuantParams = QuantParamsGeneric<Int8Range>;
/// INT4 affine quantization parameters — `clamp(-8, 7)`.
pub type QuantParamsInt4 = QuantParamsGeneric<Int4Range>;
134
135impl<R: QuantRange> QuantParamsGeneric<R> {
136    /// Quantization scale factor.
137    pub fn scale(&self) -> f32 { self.scale }
138    /// Quantization zero point.
139    pub fn zero_point(&self) -> i8 { self.zero_point }
140
141    /// Compute quantization parameters from a floating-point range.
142    pub fn from_range(min: f32, max: f32) -> Self {
143        let min = min.min(0.0);
144        let max = max.max(0.0);
145
146        // Handle edge case: if min == max, set a small range
147        let (min, max) = if (max - min).abs() < 1e-8 {
148            (min - 0.01, max + 0.01)
149        } else {
150            (min, max)
151        };
152
153        let scale = (max - min) / (R::QMAX - R::QMIN);
154        let scale = scale.max(1e-8);
155
156        let initial_zero_point = R::QMIN - min / scale;
157        let zero_point = initial_zero_point.round().clamp(R::QMIN, R::QMAX) as i8;
158
159        QuantParamsGeneric {
160            scale,
161            zero_point,
162            _marker: std::marker::PhantomData,
163        }
164    }
165
166    /// Quantize a single float to the target integer type.
167    pub fn quantize(&self, value: f32) -> i8 {
168        if !value.is_finite() {
169            return self.zero_point;
170        }
171        let quantized = (value / self.scale).round() + (self.zero_point as f32);
172        quantized.clamp(R::QMIN, R::QMAX) as i8
173    }
174
175    /// Dequantize a single integer value back to float.
176    pub fn dequantize(&self, value: i8) -> f32 {
177        ((value as i32) - (self.zero_point as i32)) as f32 * self.scale
178    }
179}
180
// ---------------------------------------------------------------------------
// QuantizedTensorGeneric<R>
// ---------------------------------------------------------------------------
184
/// Generic quantized tensor, parameterized by bit-width marker.
///
/// For INT4 tensors, call [`QuantizedTensorGeneric::pack`] to compress two
/// values per byte for 2× storage savings.
#[derive(Debug, Clone)]
pub struct QuantizedTensorGeneric<R: QuantRange> {
    // Quantized values, one `i8` per element.
    pub(crate) data: Vec<i8>,
    /// Bit-packed storage — always `None` for INT8, set by `.pack()` for INT4.
    pub(crate) packed_data: Option<Vec<u8>>,
    // Tensor shape; the element count is the product of the dimensions.
    pub(crate) shape: Vec<usize>,
    // Per-tensor parameters (channel 0's parameters when per-channel).
    pub(crate) params: QuantParamsGeneric<R>,
    // Whether per-channel (axis 0) quantization was used.
    pub(crate) per_channel: bool,
    // One parameter set per channel when `per_channel` is `true`.
    pub(crate) channel_params: Option<Vec<QuantParamsGeneric<R>>>,
}

/// An INT8 quantized tensor with optional per-channel parameters.
pub type QuantizedTensor = QuantizedTensorGeneric<Int8Range>;

/// An INT4 quantized tensor with optional per-channel parameters and bit packing.
///
/// Values are stored in the range `[-8, 7]`. Call [`pack`](QuantizedTensorInt4::pack) to
/// compress two values into one byte for 2× storage savings.
pub type QuantizedTensorInt4 = QuantizedTensorGeneric<Int4Range>;
208
// ---------------------------------------------------------------------------
// Shared impl for all bit-widths
// ---------------------------------------------------------------------------
212
213impl<R: QuantRange> QuantizedTensorGeneric<R> {
214    /// Tensor shape.
215    pub fn shape(&self) -> &[usize] { &self.shape }
216    /// Per-tensor quantization parameters (channel-0 if per-channel).
217    pub fn params(&self) -> &QuantParamsGeneric<R> { &self.params }
218    /// Whether per-channel quantization was used.
219    pub fn is_per_channel(&self) -> bool { self.per_channel }
220
221    /// Quantize FP32 data, computing the range from the data.
222    ///
223    /// # Errors
224    ///
225    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty or shape mismatches.
226    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self> {
227        if data.is_empty() {
228            return Err(QuantizeError::InvalidTensor { reason: "Cannot quantize empty tensor".into() });
229        }
230
231        let expected_len: usize = shape.iter().product();
232        if expected_len != data.len() {
233            return Err(QuantizeError::InvalidTensor { reason: format!("Shape {:?} expects {} elements but got {}", shape, expected_len, data.len()) });
234        }
235
236        let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
237        let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
238
239        if !min.is_finite() || !max.is_finite() {
240            return Err(QuantizeError::InvalidTensor { reason: "Tensor contains only non-finite values (NaN/Inf)".into() });
241        }
242
243        let params = QuantParamsGeneric::<R>::from_range(min, max);
244
245        let quantized_data: Vec<i8> = data.iter()
246            .map(|&v| params.quantize(v))
247            .collect();
248
249        Ok(QuantizedTensorGeneric {
250            data: quantized_data,
251            packed_data: None,
252            shape,
253            params,
254            per_channel: false,
255            channel_params: None,
256        })
257    }
258
259    /// Quantize FP32 data using an explicit range (for calibration).
260    ///
261    /// # Errors
262    ///
263    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty or shape mismatches.
264    pub fn from_f32_with_range(data: &[f32], shape: Vec<usize>, min: f32, max: f32) -> Result<Self> {
265        if data.is_empty() {
266            return Err(QuantizeError::InvalidTensor { reason: "Cannot quantize empty tensor".into() });
267        }
268
269        let expected_len: usize = shape.iter().product();
270        if expected_len != data.len() {
271            return Err(QuantizeError::InvalidTensor { reason: format!("Shape {:?} expects {} elements but got {}", shape, expected_len, data.len()) });
272        }
273
274        let params = QuantParamsGeneric::<R>::from_range(min, max);
275
276        let quantized_data: Vec<i8> = data.iter()
277            .map(|&v| params.quantize(v))
278            .collect();
279
280        Ok(QuantizedTensorGeneric {
281            data: quantized_data,
282            packed_data: None,
283            shape,
284            params,
285            per_channel: false,
286            channel_params: None,
287        })
288    }
289
290    /// Quantize FP32 data with per-channel ranges (axis 0 only).
291    ///
292    /// # Errors
293    ///
294    /// Returns [`QuantizeError::InvalidTensor`] if `data` is empty, shape
295    /// mismatches, or the tensor is scalar.
296    pub fn from_f32_per_channel(
297        data: &[f32],
298        shape: Vec<usize>,
299    ) -> Result<Self> {
300        if data.is_empty() {
301            return Err(QuantizeError::InvalidTensor { reason: "Cannot quantize empty tensor".into() });
302        }
303
304        if shape.is_empty() {
305            return Err(QuantizeError::InvalidTensor { reason: "Cannot do per-channel quantization on scalar".into() });
306        }
307
308        let expected_len: usize = shape.iter().product();
309        if expected_len != data.len() {
310            return Err(QuantizeError::InvalidTensor { reason: format!("Shape {:?} expects {} elements but got {}", shape, expected_len, data.len()) });
311        }
312
313        let num_channels = shape[0];
314
315        let mut channel_params = Vec::new();
316        let mut quantized_data = Vec::with_capacity(data.len());
317
318        for channel_idx in 0..num_channels {
319            let channel_data = extract_channel(data, &shape, channel_idx)?;
320
321            let min = channel_data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
322            let max = channel_data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
323
324            if !min.is_finite() || !max.is_finite() {
325                return Err(QuantizeError::InvalidTensor {
326                    reason: format!("Channel {} contains only non-finite values (NaN/Inf)", channel_idx),
327                });
328            }
329
330            let params = QuantParamsGeneric::<R>::from_range(min, max);
331            channel_params.push(params.clone());
332
333            for &value in &channel_data {
334                quantized_data.push(params.quantize(value));
335            }
336        }
337
338        // Use first channel params as "representative" for backward compatibility
339        let params = channel_params[0].clone();
340
341        Ok(QuantizedTensorGeneric {
342            data: quantized_data,
343            packed_data: None,
344            shape,
345            params,
346            per_channel: true,
347            channel_params: Some(channel_params),
348        })
349    }
350
351    /// Dequantize all values back to FP32.
352    pub fn to_f32(&self) -> Vec<f32> {
353        // Borrow data directly when unpacked; allocate only for the packed INT4 path.
354        let data_owned;
355        let data: &[i8] = if let Some(ref packed) = self.packed_data {
356            data_owned = unpack_int4(packed, self.data.len());
357            &data_owned
358        } else {
359            &self.data
360        };
361
362        if self.per_channel {
363            if let Some(ref channel_params) = self.channel_params {
364                if channel_params.is_empty() {
365                    return data.iter().map(|&v| self.params.dequantize(v)).collect();
366                }
367                let elements_per_channel = data.len() / channel_params.len();
368                data.iter()
369                    .enumerate()
370                    .map(|(i, &v)| {
371                        let channel_idx = (i / elements_per_channel).min(channel_params.len() - 1);
372                        channel_params[channel_idx].dequantize(v)
373                    })
374                    .collect()
375            } else {
376                data.iter().map(|&v| self.params.dequantize(v)).collect()
377            }
378        } else {
379            data.iter().map(|&v| self.params.dequantize(v)).collect()
380        }
381    }
382
383    /// Size of the quantized data in bytes (packed if available, unpacked otherwise).
384    pub fn size_bytes(&self) -> usize {
385        if let Some(ref packed) = self.packed_data {
386            packed.len()
387        } else {
388            self.data.len() * std::mem::size_of::<i8>()
389        }
390    }
391
392    /// Mean squared error between the original data and the dequantized values.
393    pub fn quantization_error(&self, original: &[f32]) -> f32 {
394        if original.is_empty() {
395            return 0.0;
396        }
397
398        let dequantized = self.to_f32();
399
400        let sum: f32 = original.iter()
401            .zip(dequantized.iter())
402            .map(|(a, b)| (a - b).powi(2))
403            .sum();
404
405        sum / original.len() as f32
406    }
407}
408
// ---------------------------------------------------------------------------
// INT4-specific methods
// ---------------------------------------------------------------------------
412
413impl QuantizedTensorGeneric<Int4Range> {
414    /// Pack two INT4 values per byte for 2× compression.
415    pub fn pack(&mut self) {
416        self.packed_data = Some(pack_int4(&self.data));
417    }
418
419    /// Return unpacked i8 data, decompressing from packed storage if needed.
420    pub fn ensure_unpacked(&self) -> Vec<i8> {
421        if let Some(ref packed) = self.packed_data {
422            unpack_int4(packed, self.data.len())
423        } else {
424            self.data.clone()
425        }
426    }
427
428    /// Whether the data is currently bit-packed.
429    pub fn is_packed(&self) -> bool {
430        self.packed_data.is_some()
431    }
432
433    /// Size that the packed representation would occupy (or already occupies).
434    pub fn packed_size_bytes(&self) -> usize {
435        if let Some(ref packed) = self.packed_data {
436            packed.len()
437        } else {
438            self.data.len().div_ceil(2)
439        }
440    }
441
442    /// Size of the unpacked representation in bytes.
443    pub fn unpacked_size_bytes(&self) -> usize {
444        self.data.len() * std::mem::size_of::<i8>()
445    }
446}
447
// ---------------------------------------------------------------------------
// INT4 bit-packing helpers
// ---------------------------------------------------------------------------
451
/// Pack two signed 4-bit values into one byte (`val1` → high nibble,
/// `val2` → low nibble).
fn pack_int4_pair(val1: i8, val2: i8) -> u8 {
    debug_assert!((-8..=7).contains(&val1), "val1 out of INT4 range: {}", val1);
    debug_assert!((-8..=7).contains(&val2), "val2 out of INT4 range: {}", val2);

    // Casting to u8 keeps the two's-complement low 8 bits; the shift and the
    // mask then retain only each value's low nibble.
    ((val1 as u8) << 4) | ((val2 as u8) & 0x0F)
}

/// Split a byte into two signed 4-bit values (high nibble first).
fn unpack_int4_pair(byte: u8) -> (i8, i8) {
    // Position each nibble in the top 4 bits, then use an arithmetic right
    // shift on i8 to sign-extend it back to a full i8 in [-8, 7].
    let hi = ((byte & 0xF0) as i8) >> 4;
    let lo = (((byte & 0x0F) << 4) as i8) >> 4;
    (hi, lo)
}

/// Pack a slice of INT4 values (two per byte, high nibble first).
///
/// An odd-length slice is padded with a trailing zero nibble.
pub fn pack_int4(values: &[i8]) -> Vec<u8> {
    values
        .chunks(2)
        .map(|pair| pack_int4_pair(pair[0], pair.get(1).copied().unwrap_or(0)))
        .collect()
}

/// Unpack INT4 values from packed bytes, returning exactly `num_values` i8s.
pub fn unpack_int4(packed: &[u8], num_values: usize) -> Vec<i8> {
    packed
        .iter()
        .flat_map(|&byte| {
            let (hi, lo) = unpack_int4_pair(byte);
            [hi, lo]
        })
        // Stops at `num_values`, which also drops any trailing pad nibble.
        .take(num_values)
        .collect()
}
514
515/// Extract contiguous data for a single channel along axis 0.
516///
517/// Only correct for axis 0 (the leading dimension), which is the standard
518/// layout for weight tensors (e.g. [out_channels, in_channels, H, W]).
519fn extract_channel(data: &[f32], shape: &[usize], channel_idx: usize) -> Result<Vec<f32>> {
520    if shape.is_empty() {
521        return Err(QuantizeError::InvalidTensor { reason: "Cannot extract channel from empty shape".into() });
522    }
523    let num_channels = shape[0];
524    if num_channels == 0 {
525        return Err(QuantizeError::InvalidTensor { reason: "Number of channels is 0".into() });
526    }
527    if channel_idx >= num_channels {
528        return Err(QuantizeError::InvalidTensor { reason: format!("Channel index {} out of bounds for {} channels", channel_idx, num_channels) });
529    }
530    if data.len() % num_channels != 0 {
531        return Err(QuantizeError::InvalidTensor { reason: format!("Data length {} not evenly divisible by {} channels", data.len(), num_channels) });
532    }
533    let elements_per_channel = data.len() / num_channels;
534    let start = channel_idx * elements_per_channel;
535    let end = start + elements_per_channel;
536    Ok(data[start..end].to_vec())
537}
538
// ---------------------------------------------------------------------------
// QuantizedTensorType
// ---------------------------------------------------------------------------
542
/// Type-erased wrapper over [`QuantizedTensor`] (INT8) and [`QuantizedTensorInt4`] (INT4).
///
/// Lets callers handle tensors of either bit-width uniformly; use
/// [`QuantizedTensorType::bits`] or the `is_int8`/`is_int4` predicates to
/// recover the concrete width.
#[derive(Debug, Clone)]
pub enum QuantizedTensorType {
    /// INT8 tensor (values in `[-128, 127]`, never bit-packed).
    Int8(QuantizedTensor),
    /// INT4 tensor (values in `[-8, 7]`, optionally bit-packed).
    Int4(QuantizedTensorInt4),
}
549
550impl QuantizedTensorType {
551    /// Dequantize all values back to FP32.
552    pub fn to_f32(&self) -> Vec<f32> {
553        match self {
554            QuantizedTensorType::Int8(t) => t.to_f32(),
555            QuantizedTensorType::Int4(t) => t.to_f32(),
556        }
557    }
558
559    /// Size of the quantized data in bytes.
560    pub fn size_bytes(&self) -> usize {
561        match self {
562            QuantizedTensorType::Int8(t) => t.size_bytes(),
563            QuantizedTensorType::Int4(t) => t.size_bytes(),
564        }
565    }
566
567    #[must_use]
568    pub fn quantization_error(&self, original: &[f32]) -> f32 {
569        match self {
570            QuantizedTensorType::Int8(t) => t.quantization_error(original),
571            QuantizedTensorType::Int4(t) => t.quantization_error(original),
572        }
573    }
574
575    #[must_use]
576    pub fn data(&self) -> Vec<i8> {
577        match self {
578            QuantizedTensorType::Int8(t) => t.data.clone(),
579            QuantizedTensorType::Int4(t) => t.ensure_unpacked(),
580        }
581    }
582
583    /// Per-tensor scale and zero-point.
584    pub fn get_scale_zero_point(&self) -> (f32, i8) {
585        match self {
586            QuantizedTensorType::Int8(t) => (t.params.scale, t.params.zero_point),
587            QuantizedTensorType::Int4(t) => (t.params.scale, t.params.zero_point),
588        }
589    }
590
591    /// Return all per-channel scales and zero-points.
592    ///
593    /// For per-tensor quantization, returns single-element vectors.
594    /// For per-channel, returns one entry per channel.
595    pub fn get_all_scales_zero_points(&self) -> (Vec<f32>, Vec<i8>) {
596        match self {
597            QuantizedTensorType::Int8(t) => {
598                if let Some(ref cp) = t.channel_params {
599                    (
600                        cp.iter().map(|p| p.scale).collect(),
601                        cp.iter().map(|p| p.zero_point).collect(),
602                    )
603                } else {
604                    (vec![t.params.scale], vec![t.params.zero_point])
605                }
606            }
607            QuantizedTensorType::Int4(t) => {
608                if let Some(ref cp) = t.channel_params {
609                    (
610                        cp.iter().map(|p| p.scale).collect(),
611                        cp.iter().map(|p| p.zero_point).collect(),
612                    )
613                } else {
614                    (vec![t.params.scale], vec![t.params.zero_point])
615                }
616            }
617        }
618    }
619
620    /// Whether per-channel quantization was used.
621    pub fn is_per_channel(&self) -> bool {
622        match self {
623            QuantizedTensorType::Int8(t) => t.per_channel,
624            QuantizedTensorType::Int4(t) => t.per_channel,
625        }
626    }
627
628    #[must_use]
629    pub fn bits(&self) -> u8 {
630        match self {
631            QuantizedTensorType::Int8(_) => 8,
632            QuantizedTensorType::Int4(_) => 4,
633        }
634    }
635
636    /// `true` if this is an INT8 tensor.
637    pub fn is_int8(&self) -> bool {
638        matches!(self, QuantizedTensorType::Int8(_))
639    }
640
641    /// `true` if this is an INT4 tensor.
642    pub fn is_int4(&self) -> bool {
643        matches!(self, QuantizedTensorType::Int4(_))
644    }
645
646    /// Borrow quantized data without cloning.
647    ///
648    /// Returns `None` for packed INT4 tensors (must use `data()` which unpacks).
649    pub fn data_ref(&self) -> Option<&[i8]> {
650        match self {
651            QuantizedTensorType::Int8(t) => Some(&t.data),
652            QuantizedTensorType::Int4(t) => {
653                if t.packed_data.is_some() {
654                    None // packed: caller must use data() to unpack
655                } else {
656                    Some(&t.data)
657                }
658            }
659        }
660    }
661}
662
// ---------------------------------------------------------------------------
// Quantizer
// ---------------------------------------------------------------------------
666
/// High-level quantizer that combines configuration with optional calibration.
pub struct Quantizer {
    // Quantization settings (bit width, per-channel, calibration method, …).
    config: QuantConfig,
    // Per-layer activation statistics keyed by layer name; `None` when no
    // calibration data was supplied.
    calibration_stats: Option<std::collections::HashMap<String, crate::calibration::stats::ActivationStats>>,
}
672
673impl std::fmt::Debug for Quantizer {
674    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
675        let stats_count = self.calibration_stats.as_ref().map(|m| m.len());
676        f.debug_struct("Quantizer")
677            .field("config", &self.config)
678            .field("calibration_stats_count", &stats_count)
679            .finish()
680    }
681}
682
683impl Quantizer {
684    /// Create a quantizer with the given configuration (no calibration).
685    pub fn new(config: QuantConfig) -> Self {
686        Self {
687            config,
688            calibration_stats: None,
689        }
690    }
691
692    /// Create a quantizer with configuration and pre-collected activation statistics.
693    pub fn with_calibration(
694        config: QuantConfig,
695        stats: std::collections::HashMap<String, crate::calibration::stats::ActivationStats>,
696    ) -> Self {
697        Self {
698            config,
699            calibration_stats: Some(stats),
700        }
701    }
702
703    /// Quantize a tensor with optional calibration.
704    pub fn quantize_tensor_with_name(
705        &self,
706        name: &str,
707        data: &[f32],
708        shape: Vec<usize>,
709    ) -> Result<QuantizedTensorType> {
710        let (min, max) = if let Some(ref stats_map) = self.calibration_stats {
711            if let Some(stats) = stats_map.get(name) {
712                if let Some(method) = self.config.calibration_method {
713                    use crate::calibration::stats::calculate_optimal_range;
714
715                    let sample_data = sample_from_activation_stats(stats, 1000);
716                    calculate_optimal_range(&sample_data, method)
717                } else {
718                    (stats.min(), stats.max())
719                }
720            } else {
721                // No stats for this layer, use data min/max
722                let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
723                let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
724                if !min.is_finite() || !max.is_finite() {
725                    return Err(QuantizeError::InvalidTensor {
726                        reason: format!("Tensor '{}' contains only non-finite values (NaN/Inf)", name),
727                    });
728                }
729                (min, max)
730            }
731        } else {
732            // No calibration, use data min/max
733            let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
734            let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
735            if !min.is_finite() || !max.is_finite() {
736                return Err(QuantizeError::InvalidTensor {
737                    reason: format!("Tensor '{}' contains only non-finite values (NaN/Inf)", name),
738                });
739            }
740            (min, max)
741        };
742
743        self.quantize_with_range(data, shape, min, max)
744    }
745
746    /// Quantize a tensor using the configured bit width and per-channel setting.
747    ///
748    /// # Errors
749    ///
750    /// Returns [`QuantizeError::InvalidTensor`] or [`QuantizeError::UnsupportedConfig`].
751    pub fn quantize_tensor(&self, data: &[f32], shape: Vec<usize>) -> Result<QuantizedTensorType> {
752        self.build_tensor_with_optional_range(data, shape, None)
753    }
754
755    /// Quantize with specific range (for calibration).
756    ///
757    /// When `per_channel` is enabled, the provided `min`/`max` are ignored
758    /// because per-channel quantization computes separate ranges from the
759    /// weight data for each channel.  The calibration range (derived from
760    /// activation statistics) applies to per-tensor mode only.
761    fn quantize_with_range(
762        &self,
763        data: &[f32],
764        shape: Vec<usize>,
765        min: f32,
766        max: f32,
767    ) -> Result<QuantizedTensorType> {
768        self.build_tensor_with_optional_range(data, shape, Some((min, max)))
769    }
770
771    /// Shared core: build a [`QuantizedTensorType`] for any bit-width and range mode.
772    fn build_tensor_with_optional_range(
773        &self,
774        data: &[f32],
775        shape: Vec<usize>,
776        range: Option<(f32, f32)>,
777    ) -> Result<QuantizedTensorType> {
778        let pc = self.config.per_channel && shape.len() >= 2;
779        match self.config.bits {
780            8 => {
781                let t = match (pc, range) {
782                    (true, _) => QuantizedTensor::from_f32_per_channel(data, shape)?,
783                    (false, Some((min, max))) => QuantizedTensor::from_f32_with_range(data, shape, min, max)?,
784                    (false, None) => QuantizedTensor::from_f32(data, shape)?,
785                };
786                Ok(QuantizedTensorType::Int8(t))
787            }
788            4 => {
789                let mut t = match (pc, range) {
790                    (true, _) => QuantizedTensorInt4::from_f32_per_channel(data, shape)?,
791                    (false, Some((min, max))) => QuantizedTensorInt4::from_f32_with_range(data, shape, min, max)?,
792                    (false, None) => QuantizedTensorInt4::from_f32(data, shape)?,
793                };
794                t.pack();
795                Ok(QuantizedTensorType::Int4(t))
796            }
797            b => Err(QuantizeError::UnsupportedConfig {
798                reason: format!("bits must be 4 or 8, got {b}"),
799            }),
800        }
801    }
802}
803
// ---------------------------------------------------------------------------
// Calibration helper
// ---------------------------------------------------------------------------
807
808/// Sample synthetic data from the observed activation histogram distribution.
809fn sample_from_activation_stats(stats: &crate::calibration::stats::ActivationStats, n: usize) -> Vec<f32> {
810    use rand::Rng;
811
812    let histogram = stats.histogram_data();
813    if histogram.is_empty() {
814        // Fallback to uniform
815        let mut rng = rand::thread_rng();
816        let range = stats.max() - stats.min();
817        if !range.is_finite() || range.abs() < 1e-8 {
818            return vec![stats.mean(); n];
819        }
820        return (0..n).map(|_| rng.gen::<f32>() * range + stats.min()).collect();
821    }
822
823    let total_count: usize = histogram.iter().map(|&(_, c)| c).sum();
824    if total_count == 0 {
825        let mut rng = rand::thread_rng();
826        let range = stats.max() - stats.min();
827        if !range.is_finite() || range.abs() < 1e-8 {
828            return vec![stats.mean(); n];
829        }
830        return (0..n).map(|_| rng.gen::<f32>() * range + stats.min()).collect();
831    }
832
833    let mut samples = Vec::with_capacity(n);
834    for &(value, count) in &histogram {
835        let num_samples = ((count as f64 / total_count as f64) * n as f64).round() as usize;
836        for _ in 0..num_samples {
837            samples.push(value);
838        }
839    }
840
841    // Trim or pad to exactly n
842    samples.truncate(n);
843    while samples.len() < n {
844        samples.push(stats.mean());
845    }
846
847    samples
848}
849
850#[cfg(test)]
851mod tests {
852    use super::*;
853
854    // -----------------------------------------------------------------------
855    // QuantConfig per-layer selection
856    // -----------------------------------------------------------------------
857
858    #[test]
859    fn test_should_quantize_no_restrictions() {
860        let config = QuantConfig::default();
861        assert!(config.should_quantize("any.layer", 1));
862        assert!(config.should_quantize("any.layer", 1_000_000));
863    }
864
865    #[test]
866    fn test_should_quantize_excluded_layer() {
867        let config = QuantConfig {
868            excluded_layers: vec!["head.weight".to_string()],
869            ..Default::default()
870        };
871        assert!(!config.should_quantize("head.weight", 1024));
872        assert!(config.should_quantize("body.weight", 1024));
873    }
874
875    #[test]
876    fn test_should_quantize_min_elements() {
877        let config = QuantConfig {
878            min_elements: 512,
879            ..Default::default()
880        };
881        assert!(!config.should_quantize("small.bias", 4));
882        assert!(!config.should_quantize("small.bias", 511));
883        assert!(config.should_quantize("large.weight", 512));
884        assert!(config.should_quantize("large.weight", 1024));
885    }
886
887    #[test]
888    fn test_should_quantize_excluded_takes_priority_over_min_elements() {
889        let config = QuantConfig {
890            excluded_layers: vec!["head.weight".to_string()],
891            min_elements: 1,
892            ..Default::default()
893        };
894        // excluded → skipped regardless of size
895        assert!(!config.should_quantize("head.weight", 1_000_000));
896    }
897
898    #[test]
899    fn test_bits_for_layer_default() {
900        let config = QuantConfig { bits: 8, ..Default::default() };
901        assert_eq!(config.bits_for_layer("any.weight"), 8);
902    }
903
904    #[test]
905    fn test_bits_for_layer_override() {
906        let mut layer_bits = std::collections::HashMap::new();
907        layer_bits.insert("head.weight".to_string(), 4u8);
908        let config = QuantConfig {
909            bits: 8,
910            layer_bits,
911            ..Default::default()
912        };
913        assert_eq!(config.bits_for_layer("head.weight"), 4);
914        assert_eq!(config.bits_for_layer("body.weight"), 8);
915    }
916
917    // -----------------------------------------------------------------------
918    // Existing tests below
919    // -----------------------------------------------------------------------
920
921    #[test]
922    fn test_quant_params() {
923        let params = QuantParams::from_range(-1.0, 1.0);
924
925        assert_eq!(params.quantize(0.0), params.zero_point);
926
927        let original = 0.5;
928        let quantized = params.quantize(original);
929        let dequantized = params.dequantize(quantized);
930
931        assert!((original - dequantized).abs() < 0.01);
932    }
933
934    #[test]
935    fn test_quantize_tensor() {
936        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
937        let shape = vec![5];
938
939        let quantized = QuantizedTensor::from_f32(&data, shape).unwrap();
940
941        assert_eq!(quantized.data.len(), 5);
942        assert_eq!(quantized.size_bytes(), 5);
943    }
944
945    #[test]
946    fn test_per_channel_quantization() {
947        let mut data = vec![];
948        for _ in 0..100 {
949            data.push(0.5); // Channel 0
950        }
951        for _ in 0..100 {
952            data.push(5.0); // Channel 1
953        }
954
955        let shape = vec![2, 100];
956
957        let quantized = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
958
959        assert!(quantized.per_channel);
960        assert!(quantized.channel_params.is_some());
961        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
962
963        let dequantized = quantized.to_f32();
964        let error: f32 = data.iter()
965            .zip(dequantized.iter())
966            .map(|(a, b)| (a - b).powi(2))
967            .sum::<f32>() / data.len() as f32;
968
969        println!("Per-channel MSE: {}", error);
970        assert!(error < 0.1);
971    }
972
973    #[test]
974    fn test_per_channel_vs_per_tensor() {
975        let mut data = vec![];
976
977        for _ in 0..1000 {
978            data.push(0.01);
979        }
980
981        for _ in 0..1000 {
982            data.push(10.0);
983        }
984
985        let shape = vec![2, 1000];
986
987        // Per-tensor quantization
988        let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
989        let per_tensor_error = per_tensor.quantization_error(&data);
990
991        // Per-channel quantization
992        let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
993        let per_channel_error = per_channel.quantization_error(&data);
994
995        println!("Per-tensor error:  {:.8}", per_tensor_error);
996        println!("Per-channel error: {:.8}", per_channel_error);
997
998        // Per-channel
999        assert!(per_channel_error < per_tensor_error);
1000        assert!(per_channel_error < per_tensor_error * 0.5);
1001    }
1002
1003    #[test]
1004    fn test_per_channel_benefit() {
1005        let mut data = vec![];
1006
1007        for i in 0..1000 {
1008            data.push(-0.1 + (i as f32 / 1000.0) * 0.2);
1009        }
1010
1011        for i in 0..1000 {
1012            data.push(-10.0 + (i as f32 / 1000.0) * 20.0);
1013        }
1014
1015        let shape = vec![2, 1000];
1016
1017        let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1018        let per_tensor_error = per_tensor.quantization_error(&data);
1019
1020        let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1021        let per_channel_error = per_channel.quantization_error(&data);
1022
1023        println!("Per-tensor MSE:  {:.8}", per_tensor_error);
1024        println!("Per-channel MSE: {:.8}", per_channel_error);
1025
1026        assert!(per_channel_error < per_tensor_error,
1027                "Per-channel ({:.8}) should be better than per-tensor ({:.8})",
1028                per_channel_error, per_tensor_error);
1029    }
1030
1031    #[test]
1032    fn test_int4_quant_params() {
1033        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1034
1035        assert!(params.quantize(-10.0) >= -8);
1036        assert!(params.quantize(-10.0) <= 7);
1037        assert!(params.quantize(10.0) >= -8);
1038        assert!(params.quantize(10.0) <= 7);
1039
1040        let zero_quant = params.quantize(0.0);
1041        assert!(zero_quant >= -8 && zero_quant <= 7);
1042
1043        for &original in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
1044            let quantized = params.quantize(original);
1045            let dequantized = params.dequantize(quantized);
1046
1047            println!("Original: {:.2}, Quantized: {}, Dequantized: {:.2}, Error: {:.4}",
1048                     original, quantized, dequantized, (original - dequantized).abs());
1049
1050            assert!((original - dequantized).abs() < params.scale * 2.0);
1051        }
1052    }
1053
1054    #[test]
1055    fn test_int4_extreme_values() {
1056        // Test with extreme value ranges
1057        let params = QuantParamsInt4::from_range(-100.0, 100.0);
1058
1059        let q_neg = params.quantize(-100.0);
1060        let q_pos = params.quantize(100.0);
1061
1062        assert_eq!(q_neg, -8);
1063        assert_eq!(q_pos, 7);
1064    }
1065
1066    #[test]
1067    fn test_int4_vs_int8_error() {
1068        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1069
1070        let params_int8 = QuantParams::from_range(-1.0, 1.0);
1071        let error_int8: f32 = data.iter()
1072            .map(|&v| {
1073                let q = params_int8.quantize(v);
1074                let dq = params_int8.dequantize(q);
1075                (v - dq).powi(2)
1076            })
1077            .sum::<f32>() / data.len() as f32;
1078
1079        let params_int4 = QuantParamsInt4::from_range(-1.0, 1.0);
1080        let error_int4: f32 = data.iter()
1081            .map(|&v| {
1082                let q = params_int4.quantize(v);
1083                let dq = params_int4.dequantize(q);
1084                (v - dq).powi(2)
1085            })
1086            .sum::<f32>() / data.len() as f32;
1087
1088        println!("INT8 MSE: {:.8}", error_int8);
1089        println!("INT4 MSE: {:.8}", error_int4);
1090
1091        assert!(error_int4 > error_int8);
1092
1093        assert!(error_int4 < error_int8 * 500.0,
1094                "INT4 error ({:.8}) is too high compared to INT8 ({:.8})",
1095                error_int4, error_int8);
1096
1097        assert!(error_int4.is_finite());
1098        assert!(error_int4 < 0.01);
1099
1100    }
1101
1102    #[test]
1103    fn test_int4_range() {
1104        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1105
1106        assert!(params.quantize(-10.0) == -8);
1107        assert!(params.quantize(10.0) == 7);
1108
1109        // Test quantization within range
1110        for i in -8..=7 {
1111            let value = i as f32 * params.scale;
1112            let quantized = params.quantize(value);
1113            assert!(quantized >= -8 && quantized <= 7);
1114        }
1115    }
1116
1117    #[test]
1118    fn test_int4_optimal_precision() {
1119        let params = QuantParamsInt4::from_range(-1.0, 1.0);
1120
1121        let mut unique_values = std::collections::HashSet::new();
1122
1123        // Sample across the range
1124        for i in 0..1000 {
1125            let value = -1.0 + (i as f32 / 1000.0) * 2.0;
1126            unique_values.insert(params.quantize(value));
1127        }
1128
1129        println!("Unique quantized values: {}", unique_values.len());
1130        assert!(unique_values.len() >= 14);
1131    }
1132
1133    #[test]
1134    fn test_int4_tensor_quantization() {
1135        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1136        let shape = vec![5];
1137
1138        let quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1139
1140        assert_eq!(quantized.data.len(), 5);
1141        assert_eq!(quantized.size_bytes(), 5);
1142        assert_eq!(quantized.packed_size_bytes(), 3);
1143
1144        for &val in &quantized.data {
1145            assert!(val >= -8 && val <= 7, "Value {} out of INT4 range", val);
1146        }
1147    }
1148
1149    #[test]
1150    fn test_int4_round_trip() {
1151        let original = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1152        let shape = vec![5];
1153
1154        let quantized = QuantizedTensorInt4::from_f32(&original, shape).unwrap();
1155        let dequantized = quantized.to_f32();
1156
1157        println!("Original:    {:?}", original);
1158        println!("Quantized:   {:?}", quantized.data);
1159        println!("Dequantized: {:?}", dequantized);
1160
1161        for (orig, deq) in original.iter().zip(dequantized.iter()) {
1162            let error = (orig - deq).abs();
1163            println!("  {:.2} -> {:.2}, error: {:.4}", orig, deq, error);
1164            assert!(error < 0.15, "Error too large: {}", error);
1165        }
1166    }
1167
1168    #[test]
1169    fn test_int4_per_channel() {
1170        let mut data = vec![];
1171
1172        // Channel 0: small range [-0.1, 0.1]
1173        for i in 0..100 {
1174            data.push(-0.1 + (i as f32 / 100.0) * 0.2);
1175        }
1176
1177        // Channel 1: large range [-10.0, 10.0]
1178        for i in 0..100 {
1179            data.push(-10.0 + (i as f32 / 100.0) * 20.0);
1180        }
1181
1182        let shape = vec![2, 100];
1183
1184        let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1185
1186        assert!(quantized.per_channel);
1187        assert!(quantized.channel_params.is_some());
1188        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1189
1190        let error = quantized.quantization_error(&data);
1191        println!("INT4 per-channel MSE: {:.8}", error);
1192
1193        assert!(error < 1.0, "Error too high: {}", error);
1194    }
1195
    #[test]
    fn test_int4_vs_int8_compression() {
        // Same ramp through INT8 and INT4: identical unpacked footprint,
        // ~half the packed footprint, higher (but still small) INT4 error.
        let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
        let shape = vec![1000];

        let int8_quantized = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
        let int8_size = int8_quantized.size_bytes();
        let int8_error = int8_quantized.quantization_error(&data);

        let int4_quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
        let int4_size = int4_quantized.size_bytes();
        let int4_packed_size = int4_quantized.packed_size_bytes();
        let int4_error = int4_quantized.quantization_error(&data);

        println!("INT8: {} bytes, MSE: {:.8}", int8_size, int8_error);
        println!("INT4 (unpacked): {} bytes, MSE: {:.8}", int4_size, int4_error);
        println!("INT4 (packed): {} bytes, MSE: {:.8}", int4_packed_size, int4_error);

        // Unpacked INT4 still uses one byte per element, same as INT8.
        assert_eq!(int4_size, int8_size);

        // Packing halves the footprint, with one byte of slack for odd lengths.
        assert!(int4_packed_size <= int8_size / 2 + 1);

        // Fewer levels → INT4 must be lossier than INT8 ...
        assert!(int4_error > int8_error);

        // ... but still small on this well-behaved ramp.
        assert!(int4_error < 0.01, "INT4 error too high: {}", int4_error);
    }
1222
1223    #[test]
1224    fn test_int4_large_tensor() {
1225        let size = 64 * 3 * 3 * 3; // 64 filters, 3x3x3 kernels
1226        let data: Vec<f32> = (0..size).map(|i| {
1227            ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5
1228        }).collect();
1229
1230        let shape = vec![64, 3, 3, 3];
1231
1232        let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1233
1234        assert_eq!(quantized.data.len(), size);
1235        assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 64);
1236
1237        let error = quantized.quantization_error(&data);
1238        println!("Large tensor INT4 error: {:.8}", error);
1239
1240        assert!(error < 0.01, "Error too high for large tensor: {}", error);
1241    }
1242
1243    #[test]
1244    fn test_int4_extreme_ranges() {
1245        let test_cases = vec![
1246            (vec![-0.001, 0.0, 0.001], "tiny range"),
1247            (vec![-100.0, 0.0, 100.0], "large range"),
1248            (vec![0.0, 0.0, 0.0], "all zeros"),
1249            (vec![1.0, 1.0, 1.0], "all same"),
1250        ];
1251
1252        for (data, desc) in test_cases {
1253            println!("\nTesting: {}", desc);
1254            let shape = vec![data.len()];
1255
1256            let result = QuantizedTensorInt4::from_f32(&data, shape);
1257            assert!(result.is_ok(), "Failed on {}", desc);
1258
1259            let quantized = result.unwrap();
1260            let dequantized = quantized.to_f32();
1261
1262            println!("  Original:    {:?}", data);
1263            println!("  Dequantized: {:?}", dequantized);
1264
1265            for &val in &quantized.data {
1266                assert!(val >= -8 && val <= 7, "Value {} out of range for {}", val, desc);
1267            }
1268        }
1269    }
1270
1271    #[test]
1272    fn test_int4_pack_unpack_pair() {
1273        let test_cases = vec![
1274            (-8, 7),
1275            (-8, -8),
1276            (7, 7),
1277            (0, 0),
1278            (-1, 0),
1279            (0, -1),
1280            (-5, 3),
1281            (6, -4),
1282        ];
1283
1284        for (val1, val2) in test_cases {
1285            println!("\nTesting: ({}, {})", val1, val2);
1286
1287            let packed = pack_int4_pair(val1, val2);
1288            let (unpacked1, unpacked2) = unpack_int4_pair(packed);
1289
1290            println!("  Packed: 0x{:02X} (binary: {:08b})", packed, packed);
1291            println!("  Unpacked: ({}, {})", unpacked1, unpacked2);
1292
1293            assert_eq!(val1, unpacked1, "First value mismatch");
1294            assert_eq!(val2, unpacked2, "Second value mismatch");
1295        }
1296    }
1297
    #[test]
    fn test_int4_pack_unpack_vector() {
        // Even-length vector: packs to exactly len/2 bytes and round-trips
        // losslessly through pack_int4/unpack_int4.
        let values = vec![-8, -7, -1, 0, 1, 7];
        let packed = pack_int4(&values);
        let unpacked = unpack_int4(&packed, values.len());

        println!("\nEven length:");
        println!("  Original: {:?}", values);
        println!("  Packed:   {:?} ({} bytes)", packed, packed.len());
        println!("  Unpacked: {:?}", unpacked);

        assert_eq!(values, unpacked);
        assert_eq!(packed.len(), (values.len() + 1) / 2);
    }
1312
    #[test]
    fn test_int4_pack_unpack_odd_length() {
        // Odd length: the trailing nibble occupies a final half-used byte,
        // and the explicit element count lets unpack_int4 drop the padding.
        let values = vec![-8, -5, 0, 5, 7];
        let packed = pack_int4(&values);
        let unpacked = unpack_int4(&packed, values.len());

        println!("\nOdd length:");
        println!("  Original: {:?}", values);
        println!("  Packed:   {:?} ({} bytes)", packed, packed.len());
        println!("  Unpacked: {:?}", unpacked);

        assert_eq!(values, unpacked);
        assert_eq!(packed.len(), (values.len() + 1) / 2);
    }
1327
    #[test]
    fn test_int4_pack_all_values() {
        // All 16 representable INT4 codes must round-trip; 16 nibbles pack
        // into exactly 8 bytes.
        let values: Vec<i8> = (-8..=7).collect();
        let packed = pack_int4(&values);
        let unpacked = unpack_int4(&packed, values.len());

        println!("\nAll INT4 values:");
        println!("  Original: {:?}", values);
        println!("  Packed:   {} bytes", packed.len());
        println!("  Unpacked: {:?}", unpacked);

        assert_eq!(values, unpacked);
        assert_eq!(packed.len(), 8);
    }
1342
    #[test]
    fn test_int4_pack_large_vector() {
        // 1000 nibbles (cycling through all 16 codes) pack into exactly
        // 500 bytes, losslessly.
        let values: Vec<i8> = (0..1000).map(|i| ((i % 16) - 8) as i8).collect();
        let packed = pack_int4(&values);
        let unpacked = unpack_int4(&packed, values.len());

        assert_eq!(values, unpacked);
        assert_eq!(packed.len(), 500);

        println!("\nLarge vector:");
        println!("  Original: {} values", values.len());
        println!("  Packed:   {} bytes ({}x compression)", packed.len(),
                values.len() / packed.len());
        println!("  Unpacked: {} values", unpacked.len());
    }
1358
    #[test]
    fn test_int4_compression_ratio() {
        // Two nibbles per byte → the packed footprint should be almost
        // exactly half the unpacked i8 footprint (2x compression).
        let size = 10000;
        let values: Vec<i8> = (0..size).map(|i| ((i % 16) - 8) as i8).collect();

        let unpacked_size = values.len() * std::mem::size_of::<i8>();

        let packed = pack_int4(&values);
        let packed_size = packed.len();

        let compression_ratio = unpacked_size as f32 / packed_size as f32;

        println!("\nCompression test:");
        println!("  Values:      {}", size);
        println!("  Unpacked:    {} bytes", unpacked_size);
        println!("  Packed:      {} bytes", packed_size);
        println!("  Compression: {:.2}x", compression_ratio);

        assert!((compression_ratio - 2.0).abs() < 0.01,
                "Expected ~2x compression, got {:.2}x", compression_ratio);
    }
1380
    #[test]
    fn test_int4_tensor_packing() {
        // End-to-end pack(): the reported size must drop from 1000 to 500
        // bytes while dequantization still yields all 1000 values intact.
        let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
        let shape = vec![1000];

        let mut quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();

        println!("Before packing:");
        println!("  Unpacked size: {} bytes", quantized.unpacked_size_bytes());
        println!("  Is packed: {}", quantized.is_packed());

        // Freshly-created tensors start unpacked (one byte per nibble).
        assert!(!quantized.is_packed());
        assert_eq!(quantized.size_bytes(), 1000);

        quantized.pack();

        println!("\nAfter packing:");
        println!("  Packed size: {} bytes", quantized.size_bytes());
        println!("  Is packed: {}", quantized.is_packed());
        println!("  Compression: {}x", quantized.unpacked_size_bytes() / quantized.size_bytes());

        assert!(quantized.is_packed());
        assert_eq!(quantized.size_bytes(), 500);

        // to_f32() must transparently unpack back to the full length.
        let dequantized = quantized.to_f32();
        assert_eq!(dequantized.len(), 1000);

        let error = quantized.quantization_error(&data);
        println!("  MSE after packing: {:.8}", error);
        assert!(error < 0.01);
    }
1412
    #[test]
    fn test_int4_packed_vs_unpacked_error() {
        // Bit-packing must be lossless: the MSE of a packed tensor equals
        // that of its unpacked twin built from the same data.
        let data: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) * 2.0 - 1.0).collect();
        let shape = vec![100];

        let unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
        let error_unpacked = unpacked.quantization_error(&data);

        let mut packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
        packed.pack();
        let error_packed = packed.quantization_error(&data);

        println!("Unpacked error: {:.8}", error_unpacked);
        println!("Packed error:   {:.8}", error_packed);

        assert!((error_unpacked - error_packed).abs() < 1e-6);
    }
1430
1431    #[test]
1432    fn test_int4_per_channel_packing() {
1433        let mut data = vec![];
1434        for i in 0..500 {
1435            data.push((i as f32 / 500.0) * 0.2 - 0.1); // Channel 0
1436        }
1437        for i in 0..500 {
1438            data.push((i as f32 / 500.0) * 20.0 - 10.0); // Channel 1
1439        }
1440
1441        let shape = vec![2, 500];
1442
1443        let mut quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1444
1445        let error_before = quantized.quantization_error(&data);
1446        println!("Error before packing: {:.8}", error_before);
1447
1448        quantized.pack();
1449
1450        let error_after = quantized.quantization_error(&data);
1451        println!("Error after packing:  {:.8}", error_after);
1452        println!("Size: {} bytes (packed from {} bytes)",
1453                quantized.size_bytes(),
1454                quantized.unpacked_size_bytes());
1455
1456        assert!((error_before - error_after).abs() < 1e-6);
1457
1458        assert_eq!(quantized.size_bytes(), 500);
1459    }
1460
    #[test]
    fn test_int4_compression_comparison() {
        // FP32 vs INT8 vs INT4 (unpacked and packed) on the same 10k tensor:
        // expect exactly 4x and 8x compression respectively.
        let size = 10000;
        let data: Vec<f32> = (0..size).map(|i| {
            ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5
        }).collect();
        let shape = vec![size];

        let fp32_size = size * std::mem::size_of::<f32>();

        let int8 = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
        let int8_size = int8.size_bytes();

        let int4_unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
        let int4_unpacked_size = int4_unpacked.size_bytes();

        let mut int4_packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
        int4_packed.pack();
        let int4_packed_size = int4_packed.size_bytes();

        println!("\nCompression Comparison:");
        println!("  FP32:          {} bytes", fp32_size);
        println!("  INT8:          {} bytes ({:.1}x)", int8_size, fp32_size as f32 / int8_size as f32);
        println!("  INT4 unpacked: {} bytes ({:.1}x)", int4_unpacked_size, fp32_size as f32 / int4_unpacked_size as f32);
        println!("  INT4 packed:   {} bytes ({:.1}x)", int4_packed_size, fp32_size as f32 / int4_packed_size as f32);

        // Integer division is exact here: 40000/10000 and 40000/5000.
        assert_eq!(fp32_size / int8_size, 4); // 4x compression
        assert_eq!(fp32_size / int4_packed_size, 8); // 8x compression!
    }
1490
    #[test]
    #[ignore] // Run manually with: cargo test test_int4_real_model -- --ignored --nocapture
    fn test_int4_real_model() {
        // End-to-end smoke test against a real ONNX model on disk, if any of
        // the candidate files exists.  Verifies compression ratios (~4x for
        // INT8, ~8x for packed INT4), that bit packing is lossless, and that
        // per-channel quantization is not significantly worse than per-tensor
        // on genuine conv weights.  Prints a report; use --nocapture to see it.
        use crate::onnx_utils::OnnxModel;

        println!("\n{}", "=".repeat(60));
        println!("INT4 Real Model Test");
        println!("\n{}", "=".repeat(60));

        // Candidate model locations, tried in order; the first that loads wins.
        let model_paths = vec![
            "test_models/mnist.onnx",
            "mnist.onnx",
            "test_models/resnet18-v1-7.onnx",
            "resnet18-v1-7.onnx",
        ];

        let mut model = None;
        for path in &model_paths {
            if std::path::Path::new(path).exists() {
                println!("Loading model: {}", path);
                match OnnxModel::load(path) {
                    Ok(m) => {
                        model = Some(m);
                        break;
                    }
                    Err(e) => println!("  Failed: {}", e),
                }
            }
        }

        // No model on disk is not a failure: this test is opt-in.
        let model = match model {
            Some(m) => m,
            None => {
                println!("No test models found. Skipping test.");
                println!("Place mnist.onnx or resnet18-v1-7.onnx in current directory.");
                return;
            }
        };

        let info = model.info();
        println!("✓ Model loaded: {}", info.name);
        println!("  Nodes: {}", info.num_nodes);
        println!();

        println!("Extracting weights...");
        let weights = model.extract_weights();
        println!("✓ Found {} weight tensors", weights.len());

        if weights.is_empty() {
            println!("No weights to quantize!");
            return;
        }

        println!();
        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Tensor Quantization");
        println!("\n{}", "=".repeat(60));

        // Only look at reasonably large layers; tiny biases are uninteresting.
        let test_weights: Vec<_> = weights.iter()
            .filter(|w| w.data.len() > 1000)
            .take(5)
            .collect();

        println!("Testing {} large layers:\n", test_weights.len());

        for (idx, weight) in test_weights.iter().enumerate() {
            // NOTE(review): `&weight.name[..37]` slices by byte index, which
            // panics if index 37 falls inside a multi-byte UTF-8 character.
            // ONNX initializer names are typically ASCII, but worth confirming.
            let name = if weight.name.len() > 40 {
                format!("{}...", &weight.name[..37])
            } else {
                weight.name.clone()
            };

            println!("[{}] {}", idx + 1, name);
            println!("    Shape: {:?}, Elements: {}", weight.shape, weight.data.len());

            let fp32_size = weight.data.len() * 4;

            let int8_result = QuantizedTensor::from_f32(&weight.data, weight.shape.clone());
            let (int8_size, int8_error) = if let Ok(q) = int8_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!("    INT8 failed!");
                continue;
            };

            let int4_result = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone());
            let (int4_unpacked_size, int4_error) = if let Ok(q) = int4_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!("    INT4 failed!");
                continue;
            };

            let mut int4_packed = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
            int4_packed.pack();
            let int4_packed_size = int4_packed.size_bytes();
            let int4_packed_error = int4_packed.quantization_error(&weight.data);

            println!("    FP32:          {:7} bytes", fp32_size);
            println!("    INT8:          {:7} bytes ({:.1}x) MSE: {:.8}",
                    int8_size, fp32_size as f32 / int8_size as f32, int8_error);
            println!("    INT4 unpacked: {:7} bytes ({:.1}x) MSE: {:.8}",
                    int4_unpacked_size, fp32_size as f32 / int4_unpacked_size as f32, int4_error);
            println!("    INT4 packed:   {:7} bytes ({:.1}x) MSE: {:.8}",
                    int4_packed_size, fp32_size as f32 / int4_packed_size as f32, int4_packed_error);

            // Packing only changes layout, never the reconstructed values.
            assert_eq!(int4_error, int4_packed_error, "Packing changed error!");

            let int8_ratio = fp32_size as f32 / int8_size as f32;
            let int4_ratio = fp32_size as f32 / int4_packed_size as f32;

            assert!((int8_ratio - 4.0).abs() < 0.1, "INT8 compression should be ~4x");
            assert!((int4_ratio - 8.0).abs() < 0.1, "INT4 compression should be ~8x");

            println!();
        }

        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Channel Quantization");
        println!("\n{}", "=".repeat(60));

        // Test per-channel on Conv layers (multi-dimensional)
        let conv_weights: Vec<_> = weights.iter()
            .filter(|w| w.shape.len() >= 2 && w.shape[0] > 1)
            .take(3)
            .collect();

        if conv_weights.is_empty() {
            println!("No multi-channel layers found for per-channel test.");
        } else {
            println!("Testing {} conv layers:\n", conv_weights.len());

            for (idx, weight) in conv_weights.iter().enumerate() {
                // Same byte-slicing caveat as above for long names.
                let name = if weight.name.len() > 40 {
                    format!("{}...", &weight.name[..37])
                } else {
                    weight.name.clone()
                };

                println!("[{}] {}", idx + 1, name);
                println!("    Shape: {:?}, Channels: {}", weight.shape, weight.shape[0]);

                let per_tensor = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
                let per_tensor_error = per_tensor.quantization_error(&weight.data);

                let per_channel_result = QuantizedTensorInt4::from_f32_per_channel(
                    &weight.data,
                    weight.shape.clone(),
                );

                if let Ok(per_channel) = per_channel_result {
                    let per_channel_error = per_channel.quantization_error(&weight.data);

                    let improvement = ((per_tensor_error - per_channel_error) / per_tensor_error) * 100.0;

                    println!("    Per-tensor:  MSE: {:.8}", per_tensor_error);
                    println!("    Per-channel: MSE: {:.8} ({:.1}% better)",
                            per_channel_error, improvement);

                    // 10% slack: per-channel may tie but must not clearly lose.
                    assert!(per_channel_error <= per_tensor_error * 1.1,
                        "Per-channel should not be significantly worse");
                } else {
                    println!("    Per-channel failed!");
                }

                println!();
            }
        }

        println!("\n{}", "=".repeat(60));
        println!("Summary");
        println!("\n{}", "=".repeat(60));

        println!("✓ INT4 quantization works on real model weights");
        println!("✓ Compression ratios correct (4x INT8, 8x INT4)");
        println!("✓ Bit packing is lossless");
        println!("✓ Per-channel quantization works");
        println!("\nINT4 implementation is ready for CLI integration!");
    }
1670
1671    // -----------------------------------------------------------------------
1672    // All-NaN / all-Inf edge cases
1673    // -----------------------------------------------------------------------
1674
1675    #[test]
1676    fn test_all_nan_returns_error() {
1677        let data = vec![f32::NAN, f32::NAN, f32::NAN];
1678        let result = QuantizedTensor::from_f32(&data, vec![3]);
1679        assert!(result.is_err());
1680        let err = result.unwrap_err().to_string();
1681        assert!(err.contains("non-finite"), "error should mention non-finite: {}", err);
1682    }
1683
1684    #[test]
1685    fn test_all_inf_returns_error() {
1686        let data = vec![f32::INFINITY, f32::NEG_INFINITY];
1687        let result = QuantizedTensor::from_f32(&data, vec![2]);
1688        assert!(result.is_err());
1689    }
1690
1691    #[test]
1692    fn test_all_nan_int4_returns_error() {
1693        let data = vec![f32::NAN; 4];
1694        let result = QuantizedTensorInt4::from_f32(&data, vec![4]);
1695        assert!(result.is_err());
1696    }
1697
1698    #[test]
1699    fn test_all_nan_per_channel_returns_error() {
1700        let data = vec![f32::NAN; 6];
1701        let result = QuantizedTensor::from_f32_per_channel(&data, vec![2, 3]);
1702        assert!(result.is_err());
1703        let err = result.unwrap_err().to_string();
1704        assert!(err.contains("Channel 0"), "error should mention channel: {}", err);
1705    }
1706
1707    #[test]
1708    fn test_mixed_nan_finite_succeeds() {
1709        // Some NaN, some finite — should succeed using finite range
1710        let data = vec![f32::NAN, 1.0, -1.0, f32::NAN];
1711        let result = QuantizedTensor::from_f32(&data, vec![4]);
1712        assert!(result.is_ok());
1713    }
1714}