// ferrotorch_core/quantize.rs
1//! Post-training quantization (PTQ) for ferrotorch tensors.
2//!
3//! Provides symmetric and asymmetric quantization to INT8, INT4, and UINT8,
4//! with per-tensor or per-channel granularity. Designed for inference-time
5//! model compression — quantize once after training, then run forward passes
6//! with reduced memory and (on supported hardware) faster matmul.
7
8use std::collections::HashMap;
9
10use crate::dtype::Float;
11use crate::error::{FerrotorchError, FerrotorchResult};
12use crate::storage::TensorStorage;
13use crate::tensor::Tensor;
14
15// ---------------------------------------------------------------------------
16// Enums
17// ---------------------------------------------------------------------------
18
/// Granularity of quantization parameters (scale / zero_point).
///
/// Determines how many (scale, zero_point) pairs a quantized tensor carries:
/// a single pair for the whole tensor, or one pair per slice along a chosen
/// axis (typical for weight tensors, one pair per output channel).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// One scale and zero_point for the entire tensor.
    PerTensor,
    /// One scale and zero_point per slice along the given axis.
    /// The payload is the channel axis index into the tensor's shape.
    PerChannel(usize),
}
27
/// Target integer dtype for quantized storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDtype {
    /// Signed 8-bit: [-128, 127].
    Int8,
    /// Signed 4-bit: [-8, 7]. Each value occupies its own `i8` slot with
    /// only the low 4 bits significant; the quantizer in this module does
    /// not bit-pack two values per byte.
    Int4,
    /// Unsigned 8-bit: [0, 255]. Stored as an `i8` bit pattern.
    Uint8,
}

impl QuantDtype {
    /// Smallest representable quantized value for this dtype.
    #[inline]
    fn qmin(self) -> i32 {
        match self {
            Self::Int8 => -128,
            Self::Int4 => -8,
            Self::Uint8 => 0,
        }
    }

    /// Largest representable quantized value for this dtype.
    #[inline]
    fn qmax(self) -> i32 {
        match self {
            Self::Int8 => 127,
            Self::Int4 => 7,
            Self::Uint8 => 255,
        }
    }
}
60
61// ---------------------------------------------------------------------------
62// QuantizedTensor
63// ---------------------------------------------------------------------------
64
/// A tensor stored in quantized (integer) representation.
///
/// The real value is recovered by `x = (q - zero_point) * scale`.
///
/// `scale` and `zero_point` are vectors whose length equals:
/// * 1 for `PerTensor`
/// * `shape[axis]` for `PerChannel(axis)`
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values stored as `i8` regardless of logical dtype.
    /// For `Uint8`, the stored `i8` is reinterpreted as `u8` via
    /// wrapping cast; for `Int4` only the low 4 bits are significant.
    /// Length always equals the product of `shape` (one slot per element).
    data: Vec<i8>,
    /// Per-tensor or per-channel scales (length 1 or `shape[axis]`).
    scale: Vec<f32>,
    /// Per-tensor or per-channel zero points, expressed in the quantized
    /// (integer) domain. May lie outside the dtype's [qmin, qmax] range.
    zero_point: Vec<i32>,
    /// Original tensor shape (row-major flat layout assumed by `data`).
    shape: Vec<usize>,
    /// Quantization granularity.
    scheme: QuantScheme,
    /// Target quantized dtype.
    dtype: QuantDtype,
}
89
90impl QuantizedTensor {
91    /// Number of elements.
92    #[inline]
93    pub fn numel(&self) -> usize {
94        self.shape.iter().product()
95    }
96
97    /// Borrow the shape.
98    #[inline]
99    pub fn shape(&self) -> &[usize] {
100        &self.shape
101    }
102
103    /// Borrow the quantized data.
104    #[inline]
105    pub fn data(&self) -> &[i8] {
106        &self.data
107    }
108
109    /// Borrow the scale vector.
110    #[inline]
111    pub fn scale(&self) -> &[f32] {
112        &self.scale
113    }
114
115    /// Borrow the zero-point vector.
116    #[inline]
117    pub fn zero_point(&self) -> &[i32] {
118        &self.zero_point
119    }
120
121    /// The quantization scheme used.
122    #[inline]
123    pub fn scheme(&self) -> QuantScheme {
124        self.scheme
125    }
126
127    /// The quantized dtype.
128    #[inline]
129    pub fn qdtype(&self) -> QuantDtype {
130        self.dtype
131    }
132}
133
134// ---------------------------------------------------------------------------
135// Helpers
136// ---------------------------------------------------------------------------
137
138/// Compute scale and zero_point for a given (min, max) range and target dtype.
139///
140/// Uses the standard asymmetric affine quantization formula:
141///   scale = (max - min) / (qmax - qmin)
142///   zero_point = round(qmin - min / scale)
143///
144/// The range is always expanded to include zero so that `0.0` maps exactly
145/// to an integer quantized value (important for zero-padding and ReLU outputs).
146/// When min == max the range would collapse to zero, so this expansion also
147/// prevents division-by-zero.
148fn compute_scale_zp(min_val: f32, max_val: f32, dtype: QuantDtype) -> (f32, i32) {
149    let qmin = dtype.qmin();
150    let qmax = dtype.qmax();
151
152    // Ensure the range includes zero (standard PyTorch behaviour).
153    let min_val = min_val.min(0.0);
154    let max_val = max_val.max(0.0);
155
156    // After including zero the range is at least max(|min|, |max|) > 0,
157    // but guard against the degenerate all-zeros case.
158    let range = (max_val - min_val).max(f32::EPSILON);
159    let scale = range / (qmax - qmin) as f32;
160
161    // zero_point is intentionally NOT clamped to [qmin, qmax]. It is stored
162    // as i32 and may lie outside the quantized integer range. This is correct
163    // for asymmetric affine quantization — clamping the zero_point distorts
164    // the mapping when the float range doesn't straddle zero.
165    let zp = (qmin as f32 - min_val / scale).round() as i32;
166
167    (scale, zp)
168}
169
/// Quantize one float: scale, shift by zero_point, round, then clamp.
///
/// The result is returned as `i8`. For `Uint8` the caller passes `qmin=0`,
/// `qmax=255`; the clamped value is first narrowed to `u8` and then
/// reinterpreted as `i8`, so values 128..=255 survive as a bit pattern.
#[inline]
fn quantize_val(x: f32, scale: f32, zp: i32, qmin: i32, qmax: i32, is_unsigned: bool) -> i8 {
    let quantized = ((x / scale) + zp as f32).round() as i32;
    let bounded = quantized.clamp(qmin, qmax);
    match is_unsigned {
        true => (bounded as u8) as i8,
        false => bounded as i8,
    }
}
185
/// Widen a stored `i8` back to the logical `i32` quantized value.
///
/// For unsigned dtypes the `i8` bit pattern actually encodes a `u8`,
/// so it is reinterpreted before widening.
#[inline]
fn stored_to_i32(val: i8, is_unsigned: bool) -> i32 {
    match is_unsigned {
        true => i32::from(val as u8),
        false => i32::from(val),
    }
}
196
/// Map a row-major flat index to its channel index along `axis`.
///
/// For a tensor of shape `[d0, d1, ..., dn]`, consecutive elements along
/// `axis` are separated by the product of all trailing dimensions.
#[inline]
fn channel_index(flat_index: usize, shape: &[usize], axis: usize) -> usize {
    // Stride of the channel axis = product of the dimensions after it.
    let inner: usize = shape.iter().skip(axis + 1).product();
    flat_index / inner % shape[axis]
}
207
208// ---------------------------------------------------------------------------
209// Quantize
210// ---------------------------------------------------------------------------
211
212/// Quantize a floating-point tensor.
213///
214/// # Per-tensor
215///
216/// Computes a single (scale, zero_point) pair from the global min/max.
217///
218/// # Per-channel
219///
220/// Computes one (scale, zero_point) per slice along the given axis. This is
221/// common for weight tensors where each output channel has its own range.
222pub fn quantize<T: Float>(
223    tensor: &Tensor<T>,
224    scheme: QuantScheme,
225    dtype: QuantDtype,
226) -> FerrotorchResult<QuantizedTensor> {
227    let data = tensor.data()?;
228    let shape = tensor.shape().to_vec();
229    let numel = tensor.numel();
230    let qmin = dtype.qmin();
231    let qmax = dtype.qmax();
232
233    let is_unsigned = dtype == QuantDtype::Uint8;
234
235    match scheme {
236        QuantScheme::PerTensor => {
237            // Global min/max.
238            let mut min_val = f32::INFINITY;
239            let mut max_val = f32::NEG_INFINITY;
240            for &v in data {
241                let f = v.to_f32().unwrap();
242                if f < min_val {
243                    min_val = f;
244                }
245                if f > max_val {
246                    max_val = f;
247                }
248            }
249
250            let (scale, zp) = compute_scale_zp(min_val, max_val, dtype);
251
252            let qdata: Vec<i8> = data
253                .iter()
254                .map(|&v| quantize_val(v.to_f32().unwrap(), scale, zp, qmin, qmax, is_unsigned))
255                .collect();
256
257            Ok(QuantizedTensor {
258                data: qdata,
259                scale: vec![scale],
260                zero_point: vec![zp],
261                shape,
262                scheme,
263                dtype,
264            })
265        }
266
267        QuantScheme::PerChannel(axis) => {
268            if axis >= shape.len() {
269                return Err(FerrotorchError::InvalidArgument {
270                    message: format!(
271                        "PerChannel axis {axis} out of range for {}-d tensor",
272                        shape.len()
273                    ),
274                });
275            }
276
277            let num_channels = shape[axis];
278            let mut mins = vec![f32::INFINITY; num_channels];
279            let mut maxs = vec![f32::NEG_INFINITY; num_channels];
280
281            for (i, &v) in data.iter().enumerate() {
282                let ch = channel_index(i, &shape, axis);
283                let f = v.to_f32().unwrap();
284                if f < mins[ch] {
285                    mins[ch] = f;
286                }
287                if f > maxs[ch] {
288                    maxs[ch] = f;
289                }
290            }
291
292            let params: Vec<(f32, i32)> = mins
293                .iter()
294                .zip(maxs.iter())
295                .map(|(&mn, &mx)| compute_scale_zp(mn, mx, dtype))
296                .collect();
297
298            let scales: Vec<f32> = params.iter().map(|&(s, _)| s).collect();
299            let zps: Vec<i32> = params.iter().map(|&(_, z)| z).collect();
300
301            let mut qdata = Vec::with_capacity(numel);
302            for (i, &v) in data.iter().enumerate() {
303                let ch = channel_index(i, &shape, axis);
304                qdata.push(quantize_val(
305                    v.to_f32().unwrap(),
306                    scales[ch],
307                    zps[ch],
308                    qmin,
309                    qmax,
310                    is_unsigned,
311                ));
312            }
313
314            Ok(QuantizedTensor {
315                data: qdata,
316                scale: scales,
317                zero_point: zps,
318                shape,
319                scheme,
320                dtype,
321            })
322        }
323    }
324}
325
326// ---------------------------------------------------------------------------
327// Dequantize
328// ---------------------------------------------------------------------------
329
330/// Dequantize back to a floating-point tensor.
331///
332/// Applies the inverse mapping: `x = (q - zero_point) * scale`.
333pub fn dequantize<T: Float>(qtensor: &QuantizedTensor) -> FerrotorchResult<Tensor<T>> {
334    let numel = qtensor.numel();
335    let mut result = Vec::with_capacity(numel);
336    let is_unsigned = qtensor.dtype == QuantDtype::Uint8;
337
338    match qtensor.scheme {
339        QuantScheme::PerTensor => {
340            let scale = qtensor.scale[0];
341            let zp = qtensor.zero_point[0];
342            for &q in &qtensor.data {
343                let val = (stored_to_i32(q, is_unsigned) - zp) as f32 * scale;
344                result.push(T::from(val).unwrap());
345            }
346        }
347        QuantScheme::PerChannel(axis) => {
348            for (i, &q) in qtensor.data.iter().enumerate() {
349                let ch = channel_index(i, &qtensor.shape, axis);
350                let val = (stored_to_i32(q, is_unsigned) - qtensor.zero_point[ch]) as f32
351                    * qtensor.scale[ch];
352                result.push(T::from(val).unwrap());
353            }
354        }
355    }
356
357    Tensor::from_storage(TensorStorage::cpu(result), qtensor.shape.clone(), false)
358}
359
360// ---------------------------------------------------------------------------
361// Quantized matmul
362// ---------------------------------------------------------------------------
363
364/// Multiply two quantized 2-D matrices and return a quantized result.
365///
366/// Strategy: accumulate in `i32` to avoid overflow, then rescale to the output
367/// quantized domain. This avoids a full dequantize-matmul-requantize round-trip
368/// while remaining numerically correct for INT8.
369///
370/// Both inputs must be 2-D, with compatible inner dimensions (standard matmul
371/// rules: `[M, K] x [K, N] -> [M, N]`).
372pub fn quantized_matmul(
373    a: &QuantizedTensor,
374    b: &QuantizedTensor,
375) -> FerrotorchResult<QuantizedTensor> {
376    // Validate shapes.
377    if a.shape.len() != 2 || b.shape.len() != 2 {
378        return Err(FerrotorchError::InvalidArgument {
379            message: format!(
380                "quantized_matmul requires 2-D tensors, got shapes {:?} and {:?}",
381                a.shape, b.shape
382            ),
383        });
384    }
385
386    let m = a.shape[0];
387    let k = a.shape[1];
388    let k2 = b.shape[0];
389    let n = b.shape[1];
390
391    if k != k2 {
392        return Err(FerrotorchError::ShapeMismatch {
393            message: format!(
394                "quantized_matmul inner dimensions mismatch: [{m}, {k}] x [{k2}, {n}]"
395            ),
396        });
397    }
398
399    // Both inputs must be PerTensor for the fast path.
400    if a.scale.len() != 1 || b.scale.len() != 1 {
401        return Err(FerrotorchError::InvalidArgument {
402            message: "quantized_matmul currently requires PerTensor-quantized inputs".into(),
403        });
404    }
405
406    let a_scale = a.scale[0];
407    let a_zp = a.zero_point[0];
408    let b_scale = b.scale[0];
409    let b_zp = b.zero_point[0];
410
411    let a_unsigned = a.dtype == QuantDtype::Uint8;
412    let b_unsigned = b.dtype == QuantDtype::Uint8;
413
414    // Accumulate in i32.
415    let mut acc = vec![0i32; m * n];
416    for i in 0..m {
417        for j in 0..n {
418            let mut sum = 0i32;
419            for p in 0..k {
420                let qa = stored_to_i32(a.data[i * k + p], a_unsigned) - a_zp;
421                let qb = stored_to_i32(b.data[p * n + j], b_unsigned) - b_zp;
422                sum += qa * qb;
423            }
424            acc[i * n + j] = sum;
425        }
426    }
427
428    // The real-valued result element is: acc[i,j] * a_scale * b_scale.
429    // Requantize: pick INT8 output with its own scale/zp.
430    let combined_scale = a_scale * b_scale;
431
432    // Find the real-valued min/max of the output.
433    let mut out_min = f32::INFINITY;
434    let mut out_max = f32::NEG_INFINITY;
435    for &a_val in &acc {
436        let real = a_val as f32 * combined_scale;
437        if real < out_min {
438            out_min = real;
439        }
440        if real > out_max {
441            out_max = real;
442        }
443    }
444
445    let out_dtype = QuantDtype::Int8;
446    let (out_scale, out_zp) = compute_scale_zp(out_min, out_max, out_dtype);
447    let qmin = out_dtype.qmin();
448    let qmax = out_dtype.qmax();
449
450    let qdata: Vec<i8> = acc
451        .iter()
452        .map(|&a_val| {
453            let real = a_val as f32 * combined_scale;
454            quantize_val(real, out_scale, out_zp, qmin, qmax, false)
455        })
456        .collect();
457
458    Ok(QuantizedTensor {
459        data: qdata,
460        scale: vec![out_scale],
461        zero_point: vec![out_zp],
462        shape: vec![m, n],
463        scheme: QuantScheme::PerTensor,
464        dtype: out_dtype,
465    })
466}
467
468// ---------------------------------------------------------------------------
469// Module-level quantization utility
470// ---------------------------------------------------------------------------
471
472/// Quantize every weight tensor in a module, returning a name -> QuantizedTensor
473/// map suitable for serialization or quantized inference.
474///
475/// This accepts any type implementing the `Module` trait from `ferrotorch-nn`.
476/// Because `ferrotorch-core` does not depend on `ferrotorch-nn`, we accept a
477/// generic iterator of named tensors instead.
478pub fn quantize_named_tensors<T: Float>(
479    named_tensors: impl IntoIterator<Item = (String, Tensor<T>)>,
480    scheme: QuantScheme,
481    dtype: QuantDtype,
482) -> FerrotorchResult<HashMap<String, QuantizedTensor>> {
483    let mut result = HashMap::new();
484    for (name, tensor) in named_tensors {
485        let qtensor = quantize(&tensor, scheme, dtype)?;
486        result.insert(name, qtensor);
487    }
488    Ok(result)
489}
490
491// ===========================================================================
492// QParams — quantization parameters
493// ===========================================================================
494
/// Computed quantization parameters (scale and zero_point).
///
/// Both vectors have the same length: 1 for per-tensor parameters, or the
/// channel count for per-channel parameters.
#[derive(Debug, Clone)]
pub struct QParams {
    /// Per-tensor or per-channel scales.
    pub scale: Vec<f32>,
    /// Per-tensor or per-channel zero points (in the quantized domain).
    pub zero_point: Vec<i32>,
}
503
504impl QParams {
505    /// Compute symmetric quantization parameters.
506    ///
507    /// For symmetric quantization the range is `[-max_abs, max_abs]` and:
508    /// - INT8: `zero_point = 0`, `scale = max_abs / 127`
509    /// - INT4: `zero_point = 0`, `scale = max_abs / 7`
510    /// - UINT8: `zero_point = 128`, `scale = max_abs / 128`
511    pub fn symmetric(max_abs: f32, dtype: QuantDtype) -> Self {
512        let max_abs = max_abs.max(f32::EPSILON);
513        match dtype {
514            QuantDtype::Int8 => QParams {
515                scale: vec![max_abs / 127.0],
516                zero_point: vec![0],
517            },
518            QuantDtype::Int4 => QParams {
519                scale: vec![max_abs / 7.0],
520                zero_point: vec![0],
521            },
522            QuantDtype::Uint8 => QParams {
523                scale: vec![max_abs / 128.0],
524                zero_point: vec![128],
525            },
526        }
527    }
528
529    /// Compute asymmetric quantization parameters from observed min/max.
530    pub fn asymmetric(min_val: f32, max_val: f32, dtype: QuantDtype) -> Self {
531        let (scale, zp) = compute_scale_zp(min_val, max_val, dtype);
532        QParams {
533            scale: vec![scale],
534            zero_point: vec![zp],
535        }
536    }
537}
538
539// ===========================================================================
540// Observers — collect statistics for quantization calibration
541// ===========================================================================
542
/// Trait for quantization observers that collect data statistics.
///
/// Observers accumulate statistics across one or more batches and later
/// turn them into concrete quantization parameters for calibration.
pub trait Observer {
    /// Update the observer with a batch of floating-point values.
    fn observe(&mut self, data: &[f32]);
    /// Calculate quantization parameters from collected statistics.
    fn calculate_qparams(&self, dtype: QuantDtype) -> QParams;
    /// Reset the observer state, discarding all collected statistics.
    fn reset(&mut self);
}
552
553// ---------------------------------------------------------------------------
554// MinMaxObserver
555// ---------------------------------------------------------------------------
556
/// Tracks the running min/max of observed values.
///
/// Non-finite samples (NaN, ±Inf) are ignored when updating the bounds.
#[derive(Debug, Clone)]
pub struct MinMaxObserver {
    min_val: f32,
    max_val: f32,
}

impl MinMaxObserver {
    /// Create an observer that has seen no data yet: bounds start at
    /// (+Inf, -Inf) so the first finite sample replaces both.
    pub fn new() -> Self {
        MinMaxObserver {
            min_val: f32::INFINITY,
            max_val: f32::NEG_INFINITY,
        }
    }
}

impl Default for MinMaxObserver {
    fn default() -> Self {
        MinMaxObserver::new()
    }
}
580
581impl Observer for MinMaxObserver {
582    fn observe(&mut self, data: &[f32]) {
583        for &x in data {
584            if !x.is_finite() {
585                continue;
586            }
587            if x < self.min_val {
588                self.min_val = x;
589            }
590            if x > self.max_val {
591                self.max_val = x;
592            }
593        }
594    }
595
596    fn calculate_qparams(&self, dtype: QuantDtype) -> QParams {
597        QParams::asymmetric(self.min_val, self.max_val, dtype)
598    }
599
600    fn reset(&mut self) {
601        self.min_val = f32::INFINITY;
602        self.max_val = f32::NEG_INFINITY;
603    }
604}
605
606// ---------------------------------------------------------------------------
607// PerChannelMinMaxObserver
608// ---------------------------------------------------------------------------
609
/// Tracks per-channel running min/max of observed values.
///
/// Filters out NaN and Inf values before updating min/max.
/// Logs a warning and returns an error when the channel count of incoming
/// data doesn't match the configured number of channels.
#[derive(Debug, Clone)]
pub struct PerChannelMinMaxObserver {
    // Expected number of channels; also the length of the vectors below.
    num_channels: usize,
    // The tensor axis along which channels are sliced (used by
    // `observe_with_shape`; the shapeless `observe` ignores it).
    axis: usize,
    // Running minimum per channel; starts at +Inf (no data seen).
    min_vals: Vec<f32>,
    // Running maximum per channel; starts at -Inf (no data seen).
    max_vals: Vec<f32>,
}
622
623impl PerChannelMinMaxObserver {
624    /// Create a new per-channel observer.
625    ///
626    /// * `num_channels` — expected number of channels.
627    /// * `axis` — the axis along which channels are sliced.
628    pub fn new(num_channels: usize, axis: usize) -> Self {
629        Self {
630            num_channels,
631            axis,
632            min_vals: vec![f32::INFINITY; num_channels],
633            max_vals: vec![f32::NEG_INFINITY; num_channels],
634        }
635    }
636
637    /// Observe a tensor's data with the given shape.
638    ///
639    /// Returns `Err` if the channel count along `self.axis` doesn't match.
640    pub fn observe_with_shape(&mut self, data: &[f32], shape: &[usize]) -> FerrotorchResult<()> {
641        if self.axis >= shape.len() {
642            return Err(FerrotorchError::InvalidArgument {
643                message: format!(
644                    "PerChannelMinMaxObserver axis {} out of range for {}-d tensor",
645                    self.axis,
646                    shape.len()
647                ),
648            });
649        }
650        let actual_channels = shape[self.axis];
651        if actual_channels != self.num_channels {
652            eprintln!(
653                "WARNING: PerChannelMinMaxObserver expected {} channels on axis {}, got {}",
654                self.num_channels, self.axis, actual_channels
655            );
656            return Err(FerrotorchError::InvalidArgument {
657                message: format!(
658                    "PerChannelMinMaxObserver expected {} channels on axis {}, got {}",
659                    self.num_channels, self.axis, actual_channels
660                ),
661            });
662        }
663
664        for (i, &x) in data.iter().enumerate() {
665            if !x.is_finite() {
666                continue;
667            }
668            let ch = channel_index(i, shape, self.axis);
669            if x < self.min_vals[ch] {
670                self.min_vals[ch] = x;
671            }
672            if x > self.max_vals[ch] {
673                self.max_vals[ch] = x;
674            }
675        }
676        Ok(())
677    }
678}
679
680impl Observer for PerChannelMinMaxObserver {
681    fn observe(&mut self, data: &[f32]) {
682        // Without shape info, we treat data as [num_channels, N] where N = len / num_channels.
683        if data.len() % self.num_channels != 0 {
684            eprintln!(
685                "WARNING: PerChannelMinMaxObserver data length {} not divisible by {} channels",
686                data.len(),
687                self.num_channels
688            );
689            return;
690        }
691        let per_channel = data.len() / self.num_channels;
692        for (i, &x) in data.iter().enumerate() {
693            if !x.is_finite() {
694                continue;
695            }
696            let ch = i / per_channel;
697            if ch >= self.num_channels {
698                continue;
699            }
700            if x < self.min_vals[ch] {
701                self.min_vals[ch] = x;
702            }
703            if x > self.max_vals[ch] {
704                self.max_vals[ch] = x;
705            }
706        }
707    }
708
709    fn calculate_qparams(&self, dtype: QuantDtype) -> QParams {
710        let params: Vec<(f32, i32)> = self
711            .min_vals
712            .iter()
713            .zip(self.max_vals.iter())
714            .map(|(&mn, &mx)| compute_scale_zp(mn, mx, dtype))
715            .collect();
716        QParams {
717            scale: params.iter().map(|&(s, _)| s).collect(),
718            zero_point: params.iter().map(|&(_, z)| z).collect(),
719        }
720    }
721
722    fn reset(&mut self) {
723        self.min_vals.fill(f32::INFINITY);
724        self.max_vals.fill(f32::NEG_INFINITY);
725    }
726}
727
728// ---------------------------------------------------------------------------
729// HistogramObserver
730// ---------------------------------------------------------------------------
731
/// Histogram-based observer that collects a distribution of values.
///
/// When the observed range expands, existing bin counts are redistributed
/// into the new bin layout rather than being zeroed.
#[derive(Debug, Clone)]
pub struct HistogramObserver {
    // Number of histogram bins; fixed at construction.
    num_bins: usize,
    // Bin counts, covering [min_val, max_val] in equal-width bins.
    // NOTE(review): `calculate_qparams` below uses only min/max — the bins
    // are collected but not yet consulted for parameter selection.
    bins: Vec<u64>,
    // Lower edge of the currently tracked range (+Inf before any data).
    min_val: f32,
    // Upper edge of the currently tracked range (-Inf before any data).
    max_val: f32,
    /// Whether we've seen any data yet.
    initialized: bool,
}
745
746impl HistogramObserver {
747    pub fn new(num_bins: usize) -> Self {
748        Self {
749            num_bins,
750            bins: vec![0u64; num_bins],
751            min_val: f32::INFINITY,
752            max_val: f32::NEG_INFINITY,
753            initialized: false,
754        }
755    }
756
757    /// Redistribute old bins into a new range via linear interpolation.
758    fn redistribute(&mut self, new_min: f32, new_max: f32) {
759        if !self.initialized || self.bins.iter().all(|&c| c == 0) {
760            self.min_val = new_min;
761            self.max_val = new_max;
762            return;
763        }
764
765        let old_min = self.min_val;
766        let old_max = self.max_val;
767        let old_range = old_max - old_min;
768        let new_range = new_max - new_min;
769
770        if old_range <= 0.0 || new_range <= 0.0 {
771            self.min_val = new_min;
772            self.max_val = new_max;
773            return;
774        }
775
776        let n = self.num_bins;
777        let old_bins = self.bins.clone();
778        self.bins.fill(0);
779
780        let old_bin_width = old_range / n as f32;
781        let new_bin_width = new_range / n as f32;
782
783        for (old_idx, &old_count) in old_bins.iter().enumerate().take(n) {
784            if old_count == 0 {
785                continue;
786            }
787            // Center of the old bin in value space.
788            let old_center = old_min + (old_idx as f32 + 0.5) * old_bin_width;
789            // Map to new bin index.
790            let new_frac = (old_center - new_min) / new_bin_width;
791            let new_idx = (new_frac as usize).min(n - 1);
792            self.bins[new_idx] += old_count;
793        }
794
795        self.min_val = new_min;
796        self.max_val = new_max;
797    }
798}
799
impl Observer for HistogramObserver {
    /// Fold a batch into the histogram, expanding the tracked range (and
    /// re-binning existing counts) when the batch exceeds it.
    fn observe(&mut self, data: &[f32]) {
        // First pass: find min/max of new data, filtering NaN/Inf.
        let mut batch_min = f32::INFINITY;
        let mut batch_max = f32::NEG_INFINITY;
        for &x in data {
            if !x.is_finite() {
                continue;
            }
            if x < batch_min {
                batch_min = x;
            }
            if x > batch_max {
                batch_max = x;
            }
        }

        if batch_min > batch_max {
            // No finite values in this batch.
            return;
        }

        // Check if range needs expanding.
        let new_min = if self.initialized {
            self.min_val.min(batch_min)
        } else {
            batch_min
        };
        let new_max = if self.initialized {
            self.max_val.max(batch_max)
        } else {
            batch_max
        };

        if self.initialized && (new_min < self.min_val || new_max > self.max_val) {
            // Range expanded — redistribute existing counts into new layout.
            // (redistribute also updates min_val/max_val.)
            self.redistribute(new_min, new_max);
        } else if !self.initialized {
            // First batch: adopt its range as-is.
            self.min_val = new_min;
            self.max_val = new_max;
            self.initialized = true;
        }

        // Second pass: insert new data into bins sized for the (possibly
        // just-updated) range. EPSILON guards the collapsed min == max case.
        let range = (self.max_val - self.min_val).max(f32::EPSILON);
        let n = self.num_bins;
        for &x in data {
            if !x.is_finite() {
                continue;
            }
            let frac = (x - self.min_val) / range;
            // min(n - 1) clamps x == max_val into the last bin.
            let idx = ((frac * n as f32) as usize).min(n - 1);
            self.bins[idx] += 1;
        }
    }

    /// Parameters currently come from the raw min/max only; the collected
    /// bin counts are not consulted here.
    fn calculate_qparams(&self, dtype: QuantDtype) -> QParams {
        QParams::asymmetric(self.min_val, self.max_val, dtype)
    }

    /// Clear bins and tracked range, returning to the uninitialized state.
    fn reset(&mut self) {
        self.bins.fill(0);
        self.min_val = f32::INFINITY;
        self.max_val = f32::NEG_INFINITY;
        self.initialized = false;
    }
}
867
868// ===========================================================================
869// FakeQuantize — differentiable quantize/dequantize for QAT
870// ===========================================================================
871
/// Simulates quantization during training by quantizing and immediately
/// dequantizing values, while allowing gradients to flow through via the
/// straight-through estimator (STE).
///
/// Implements clipped STE: gradients are passed through unchanged for
/// values within the quantization range `[dequantize(qmin), dequantize(qmax)]`,
/// and zeroed for out-of-range values (the mask is returned by `forward`).
#[derive(Debug, Clone)]
pub struct FakeQuantize {
    /// Target quantized dtype.
    pub dtype: QuantDtype,
    /// Cached quantization parameters. `None` until the first `forward`
    /// call computes them; reused only while the observer is disabled.
    pub qparams: Option<QParams>,
    /// Whether the observer is enabled (collects statistics).
    pub observer_enabled: bool,
    /// Whether fake quantization is enabled. When false, `forward` passes
    /// data through untouched with an all-ones gradient mask.
    pub fake_quant_enabled: bool,
    /// The per-tensor observer used to compute qparams.
    observer: MinMaxObserver,
}
892
893impl FakeQuantize {
894    /// Create a new FakeQuantize module.
895    pub fn new(dtype: QuantDtype) -> Self {
896        Self {
897            dtype,
898            qparams: None,
899            observer_enabled: true,
900            fake_quant_enabled: true,
901            observer: MinMaxObserver::new(),
902        }
903    }
904
905    /// Forward pass: observe data, fake-quantize, and return the result.
906    ///
907    /// Returns the fake-quantized data and a gradient mask for clipped STE.
908    /// The mask is 1.0 for in-range values and 0.0 for out-of-range values.
909    pub fn forward(&mut self, data: &[f32]) -> (Vec<f32>, Vec<f32>) {
910        if !self.fake_quant_enabled {
911            let ones = vec![1.0f32; data.len()];
912            return (data.to_vec(), ones);
913        }
914
915        // Observe if enabled.
916        if self.observer_enabled {
917            self.observer.observe(data);
918        }
919
920        // Calculate or use cached qparams.
921        // When observer is disabled and we have cached params, skip recalculation.
922        let qparams = if let Some(cached) = self.qparams.as_ref().filter(|_| !self.observer_enabled)
923        {
924            cached.clone()
925        } else {
926            let qp = self.observer.calculate_qparams(self.dtype);
927            self.qparams = Some(qp.clone());
928            qp
929        };
930
931        let scale = qparams.scale[0];
932        let zp = qparams.zero_point[0];
933        let qmin = self.dtype.qmin();
934        let qmax = self.dtype.qmax();
935
936        // Compute the dequantized range boundaries for clipped STE.
937        let range_min = (qmin as f32 - zp as f32) * scale;
938        let range_max = (qmax as f32 - zp as f32) * scale;
939
940        let mut output = Vec::with_capacity(data.len());
941        let mut grad_mask = Vec::with_capacity(data.len());
942
943        for &x in data {
944            // Fake quantize: quantize then dequantize.
945            let q = (x / scale + zp as f32)
946                .round()
947                .clamp(qmin as f32, qmax as f32);
948            let dq = (q - zp as f32) * scale;
949            output.push(dq);
950
951            // Clipped STE: zero gradient for out-of-range inputs.
952            if x >= range_min && x <= range_max {
953                grad_mask.push(1.0);
954            } else {
955                grad_mask.push(0.0);
956            }
957        }
958
959        (output, grad_mask)
960    }
961}
962
963// ===========================================================================
964// QatModel — quantization-aware training wrapper
965// ===========================================================================
966
/// A layer with associated FakeQuantize modules for QAT.
///
/// Both modules are created fresh (observer and fake quantization enabled,
/// no cached qparams) when the layer is registered via
/// `QatModel::register_layer`.
#[derive(Debug, Clone)]
pub struct QatLayer {
    /// FakeQuantize for this layer's weights.
    pub weight_fq: FakeQuantize,
    /// FakeQuantize for this layer's activations (applied after forward).
    pub activation_fq: FakeQuantize,
}
975
/// Wraps a collection of named weight tensors for quantization-aware training.
///
/// Applies `FakeQuantize` to weights before forward and to activations after
/// each layer's forward pass. Original weights are saved before fake-quantization
/// and restored after forward to preserve full-precision values for gradient
/// updates.
#[derive(Debug)]
pub struct QatModel {
    /// Per-layer FakeQuantize state, keyed by layer name.
    pub layers: HashMap<String, QatLayer>,
    /// Target quantized dtype, applied to every registered layer's weight
    /// and activation FakeQuantize modules.
    pub dtype: QuantDtype,
}
989
990impl QatModel {
991    /// Create a new QAT model wrapper.
992    pub fn new(dtype: QuantDtype) -> Self {
993        Self {
994            layers: HashMap::new(),
995            dtype,
996        }
997    }
998
999    /// Register a layer for QAT.
1000    pub fn register_layer(&mut self, name: &str) {
1001        self.layers.insert(
1002            name.to_string(),
1003            QatLayer {
1004                weight_fq: FakeQuantize::new(self.dtype),
1005                activation_fq: FakeQuantize::new(self.dtype),
1006            },
1007        );
1008    }
1009
1010    /// Fake-quantize weights for a named layer.
1011    ///
1012    /// Returns `(fake_quantized_weights, original_weights)` so the caller
1013    /// can restore originals after the forward pass.
1014    pub fn fake_quantize_weights(
1015        &mut self,
1016        layer_name: &str,
1017        weights: &[f32],
1018    ) -> FerrotorchResult<(Vec<f32>, Vec<f32>)> {
1019        let layer =
1020            self.layers
1021                .get_mut(layer_name)
1022                .ok_or_else(|| FerrotorchError::InvalidArgument {
1023                    message: format!("layer '{layer_name}' not registered for QAT"),
1024                })?;
1025
1026        // Save original weights.
1027        let originals = weights.to_vec();
1028
1029        // Fake-quantize (gradient mask is used during backward, not here).
1030        let (fq_weights, _mask) = layer.weight_fq.forward(weights);
1031
1032        Ok((fq_weights, originals))
1033    }
1034
1035    /// Fake-quantize activations for a named layer.
1036    ///
1037    /// Applied after each layer's forward output, not just the last layer.
1038    pub fn fake_quantize_activations(
1039        &mut self,
1040        layer_name: &str,
1041        activations: &[f32],
1042    ) -> FerrotorchResult<(Vec<f32>, Vec<f32>)> {
1043        let layer =
1044            self.layers
1045                .get_mut(layer_name)
1046                .ok_or_else(|| FerrotorchError::InvalidArgument {
1047                    message: format!("layer '{layer_name}' not registered for QAT"),
1048                })?;
1049
1050        let (fq_activations, grad_mask) = layer.activation_fq.forward(activations);
1051        Ok((fq_activations, grad_mask))
1052    }
1053}
1054
1055/// Prepare a set of named parameters for quantization-aware training.
1056///
1057/// Creates a `QatModel` and registers layers. Only parameters whose name
1058/// contains "weight" get weight FakeQuantize; bias parameters are skipped.
1059pub fn prepare_qat(param_names: &[&str], dtype: QuantDtype) -> QatModel {
1060    let mut model = QatModel::new(dtype);
1061
1062    for &name in param_names {
1063        // Extract the layer name (everything before the last `.weight` or `.bias`).
1064        let layer_name = if let Some(prefix) = name.strip_suffix(".weight") {
1065            prefix
1066        } else if let Some(prefix) = name.strip_suffix(".bias") {
1067            // Only register the layer if not already registered — don't apply
1068            // weight FakeQuantize to bias parameters.
1069            if !model.layers.contains_key(prefix) {
1070                model.register_layer(prefix);
1071            }
1072            continue;
1073        } else {
1074            name
1075        };
1076
1077        model.register_layer(layer_name);
1078    }
1079
1080    model
1081}
1082
1083// ===========================================================================
1084// CUDA RNG — fork/join for reproducible GPU random state
1085// ===========================================================================
1086
/// Thread-safe GPU RNG state for fork/join semantics.
///
/// Uses `Mutex` with graceful poisoning recovery
/// (`unwrap_or_else(|e| e.into_inner())`) to avoid panics when a thread
/// panics while holding the lock.
pub mod cuda_rng {
    use std::sync::Mutex;

    /// Global RNG state — a simple seed counter.
    static RNG_STATE: Mutex<u64> = Mutex::new(0xdeadbeef_cafebabe);

    /// Saved RNG states for fork/join.
    static RNG_STACK: Mutex<Vec<u64>> = Mutex::new(Vec::new());

    /// Get the current RNG state, recovering gracefully from mutex poisoning.
    pub fn get_state() -> u64 {
        *RNG_STATE.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Set the RNG state.
    pub fn set_state(state: u64) {
        *RNG_STATE.lock().unwrap_or_else(|e| e.into_inner()) = state;
    }

    /// Save the current RNG state to a stack and set a new state.
    ///
    /// The read-and-replace of the global state happens under a single lock
    /// acquisition (`mem::replace` while holding the guard), so a concurrent
    /// `set_state`/`next_seed` can no longer slip in between reading the old
    /// state and installing the new seed — the previous implementation
    /// released the lock between those two steps. The state lock is released
    /// before the stack lock is taken, so the two mutexes are never held
    /// simultaneously (no lock-ordering deadlock with `join_rng`).
    pub fn fork_rng(new_seed: u64) {
        let previous = {
            let mut guard = RNG_STATE.lock().unwrap_or_else(|e| e.into_inner());
            std::mem::replace(&mut *guard, new_seed)
        };

        RNG_STACK
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .push(previous);
    }

    /// Restore the previously saved RNG state from the stack.
    ///
    /// A no-op when the stack is empty (unbalanced `join_rng`). Recovers
    /// gracefully from mutex poisoning like the other accessors.
    pub fn join_rng() {
        let saved = RNG_STACK
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .pop();

        if let Some(state) = saved {
            set_state(state);
        }
    }

    /// Advance the RNG state and return the new value.
    ///
    /// One splitmix64 step: add the golden-ratio increment to the state,
    /// then mix. Deterministic for a given starting state.
    pub fn next_seed() -> u64 {
        let mut guard = RNG_STATE.lock().unwrap_or_else(|e| e.into_inner());
        *guard = guard.wrapping_add(0x9e3779b97f4a7c15);
        let mut z = *guard;
        z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
        z ^ (z >> 31)
    }
}
1156
1157// ---------------------------------------------------------------------------
1158// Tests
1159// ---------------------------------------------------------------------------
1160
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: create a tensor from f32 data.
    fn make_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        crate::from_slice(data, shape).unwrap()
    }

    // ----- Round-trip quantize/dequantize -----

    #[test]
    fn test_per_tensor_int8_roundtrip() {
        let data: Vec<f32> = (-10..=10).map(|x| x as f32 * 0.5).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        assert_eq!(rt.shape(), t.shape());
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // INT8 over [-5, 5]: step ≈ 10/255 ≈ 0.04, max error ≈ half step ≈ 0.02
            assert!(
                err < 0.05,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_per_tensor_uint8_roundtrip() {
        let data: Vec<f32> = (0..=20).map(|x| x as f32 * 0.1).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Uint8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // UINT8 over [0, 2]: step ≈ 2/255 ≈ 0.008
            assert!(
                err < 0.02,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_per_tensor_int4_roundtrip() {
        // INT4 has only 16 levels, so larger quantization error is expected.
        let data: Vec<f32> = (-8..=7).map(|x| x as f32).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int4).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // INT4 over [-8, 7]: step = 15/15 = 1.0, max error ≈ 0.5
            assert!(
                err < 1.01,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // ----- Per-channel -----

    #[test]
    fn test_per_channel_int8_roundtrip() {
        // Shape [3, 4]: 3 channels along axis 0, each with different ranges.
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            // channel 0: range [0, 3]
            0.0, 1.0, 2.0, 3.0,
            // channel 1: range [-10, 10]
            -10.0, -5.0, 5.0, 10.0,
            // channel 2: range [100, 200]
            100.0, 130.0, 170.0, 200.0,
        ];
        let t = make_tensor(&data, &[3, 4]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        assert_eq!(qt.scale.len(), 3);
        assert_eq!(qt.zero_point.len(), 3);

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // Each channel has its own scale, so error is relative to the
            // channel's range. Worst case channel 2: 100/255 ≈ 0.39.
            assert!(
                err < 0.5,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_per_channel_axis_out_of_bounds() {
        let t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let result = quantize(&t, QuantScheme::PerChannel(5), QuantDtype::Int8);
        assert!(result.is_err());
    }

    // ----- Quantized matmul -----

    #[test]
    fn test_quantized_matmul_identity() {
        // A * I should ≈ A after quantize -> matmul -> dequantize.
        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let a = make_tensor(&a_data, &[2, 2]);
        let eye = make_tensor(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qi = quantize(&eye, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qi).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        let c_data = c.data().unwrap();
        for (i, (&expected, &got)) in a_data.iter().zip(c_data.iter()).enumerate() {
            let err = (expected - got).abs();
            assert!(
                err < 0.5,
                "element {i}: expected={expected}, got={got}, error={err}"
            );
        }
    }

    #[test]
    fn test_quantized_matmul_correctness() {
        // [2,3] x [3,2] -> [2,2]
        // A = [[1, 2, 3],
        //      [4, 5, 6]]
        // B = [[7,  8],
        //      [9, 10],
        //      [11, 12]]
        // A @ B = [[ 58,  64],
        //          [139, 154]]
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qb).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();

        let expected = [58.0f32, 64.0, 139.0, 154.0];
        let c_data = c.data().unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        for (i, (&e, &g)) in expected.iter().zip(c_data.iter()).enumerate() {
            let err = (e - g).abs();
            // Quantization introduces some error; for small integers in INT8
            // the error should be small relative to the values.
            assert!(err < 3.0, "element {i}: expected={e}, got={g}, error={err}");
        }
    }

    #[test]
    fn test_quantized_matmul_shape_mismatch() {
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    #[test]
    fn test_quantized_matmul_non_2d() {
        let a = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let b = make_tensor(&[4.0, 5.0, 6.0], &[3]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    // ----- Module quantization utility -----

    #[test]
    fn test_quantize_named_tensors() {
        let w1 = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let w2 = make_tensor(&[-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], &[3, 2]);

        let named = vec![
            ("layer.weight".to_string(), w1),
            ("layer2.weight".to_string(), w2),
        ];

        let qmap = quantize_named_tensors(named, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();

        assert_eq!(qmap.len(), 2);
        assert!(qmap.contains_key("layer.weight"));
        assert!(qmap.contains_key("layer2.weight"));
        assert_eq!(qmap["layer.weight"].shape(), &[2, 2]);
        assert_eq!(qmap["layer2.weight"].shape(), &[3, 2]);
    }

    // ----- Constant values / edge cases -----

    #[test]
    fn test_quantize_constant_tensor() {
        // All values identical — scale should not be zero.
        let t = make_tensor(&[5.0, 5.0, 5.0, 5.0], &[4]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let recovered = rt.data().unwrap();
        for &r in recovered {
            assert!(
                (r - 5.0).abs() < 0.1,
                "constant tensor dequantized to {r}, expected 5.0"
            );
        }
    }

    #[test]
    fn test_quantize_single_element() {
        let t = make_tensor(&[42.0], &[1]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        assert!((rt.data().unwrap()[0] - 42.0).abs() < 0.5);
    }

    #[test]
    fn test_per_channel_int4() {
        // 2 channels, 3 elements each.
        let data = vec![0.0, 1.0, 2.0, -4.0, 0.0, 4.0];
        let t = make_tensor(&data, &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int4).unwrap();

        assert_eq!(qt.scale.len(), 2);
        assert_eq!(qt.zero_point.len(), 2);

        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // INT4 has coarse resolution, but channel-level ranges are small.
            assert!(
                err < 1.0,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_dequantize_f64() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let t = crate::from_slice(&data, &[4]).unwrap();
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f64> = dequantize(&qt).unwrap();

        assert_eq!(rt.shape(), &[4]);
        let recovered = rt.data().unwrap();
        for (i, &r) in recovered.iter().enumerate() {
            let expected = data[i] as f64;
            let err = (expected - r).abs();
            assert!(
                err < 0.05,
                "element {i}: expected={expected}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_quantized_tensor_accessors() {
        let t = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();

        assert_eq!(qt.numel(), 6);
        assert_eq!(qt.shape(), &[2, 3]);
        assert_eq!(qt.data().len(), 6);
        assert_eq!(qt.scale().len(), 1);
        assert_eq!(qt.zero_point().len(), 1);
        assert_eq!(qt.scheme(), QuantScheme::PerTensor);
        assert_eq!(qt.qdtype(), QuantDtype::Int8);
    }

    // ----- QParams -----

    #[test]
    fn test_qparams_symmetric_int8() {
        let qp = QParams::symmetric(5.0, QuantDtype::Int8);
        assert_eq!(qp.zero_point, vec![0]);
        assert!((qp.scale[0] - 5.0 / 127.0).abs() < 1e-7);
    }

    #[test]
    fn test_qparams_symmetric_uint8() {
        let qp = QParams::symmetric(5.0, QuantDtype::Uint8);
        assert_eq!(qp.zero_point, vec![128]);
        assert!((qp.scale[0] - 5.0 / 128.0).abs() < 1e-7);
    }

    #[test]
    fn test_qparams_symmetric_int4() {
        let qp = QParams::symmetric(7.0, QuantDtype::Int4);
        assert_eq!(qp.zero_point, vec![0]);
        assert!((qp.scale[0] - 1.0).abs() < 1e-7);
    }

    // ----- MinMaxObserver -----

    #[test]
    fn test_minmax_observer() {
        let mut obs = MinMaxObserver::new();
        obs.observe(&[1.0, 2.0, 3.0]);
        obs.observe(&[-1.0, 5.0]);
        let qp = obs.calculate_qparams(QuantDtype::Int8);
        // Range includes zero: min=-1, max=5.
        assert_eq!(qp.scale.len(), 1);
        assert_eq!(qp.zero_point.len(), 1);
    }

    #[test]
    fn test_minmax_observer_filters_nan_inf() {
        let mut obs = MinMaxObserver::new();
        obs.observe(&[1.0, f32::NAN, 2.0, f32::INFINITY, -1.0, f32::NEG_INFINITY]);
        let qp = obs.calculate_qparams(QuantDtype::Int8);
        // Should only see range [-1, 2], NaN/Inf filtered.
        let expected_range = 2.0 - (-1.0); // = 3.0
        let expected_scale = expected_range / 255.0;
        assert!((qp.scale[0] - expected_scale).abs() < 1e-5);
    }

    // ----- PerChannelMinMaxObserver -----

    #[test]
    fn test_per_channel_observer_with_shape() {
        let mut obs = PerChannelMinMaxObserver::new(2, 0);
        // Shape [2, 3]: channel 0 = [0, 1, 2], channel 1 = [10, 20, 30]
        obs.observe_with_shape(&[0.0, 1.0, 2.0, 10.0, 20.0, 30.0], &[2, 3])
            .unwrap();
        let qp = obs.calculate_qparams(QuantDtype::Int8);
        assert_eq!(qp.scale.len(), 2);
        assert_eq!(qp.zero_point.len(), 2);
    }

    #[test]
    fn test_per_channel_observer_shape_mismatch() {
        let mut obs = PerChannelMinMaxObserver::new(3, 0);
        // Shape [2, 3] has 2 channels on axis 0, but observer expects 3.
        let result = obs.observe_with_shape(&[1.0; 6], &[2, 3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_per_channel_observer_axis() {
        let mut obs = PerChannelMinMaxObserver::new(3, 1);
        // Shape [2, 3]: axis 1 has 3 channels.
        obs.observe_with_shape(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .unwrap();
        let qp = obs.calculate_qparams(QuantDtype::Int8);
        assert_eq!(qp.scale.len(), 3);
    }

    #[test]
    fn test_per_channel_observer_filters_nan_inf() {
        let mut obs = PerChannelMinMaxObserver::new(2, 0);
        obs.observe_with_shape(&[f32::NAN, 1.0, 2.0, 10.0, f32::INFINITY, 30.0], &[2, 3])
            .unwrap();
        // Channel 0 should only see [1, 2], channel 1 should only see [10, 30].
        let qp = obs.calculate_qparams(QuantDtype::Int8);
        assert_eq!(qp.scale.len(), 2);
    }

    // ----- HistogramObserver -----

    #[test]
    fn test_histogram_observer_basic() {
        let mut obs = HistogramObserver::new(100);
        obs.observe(&[0.0, 0.5, 1.0]);
        let qp = obs.calculate_qparams(QuantDtype::Int8);
        assert_eq!(qp.scale.len(), 1);
    }

    #[test]
    fn test_histogram_observer_range_expansion() {
        let mut obs = HistogramObserver::new(100);
        obs.observe(&[0.0, 1.0]);
        // Initial range is [0, 1].
        let bins_after_first = obs.bins.clone();
        let total_first: u64 = bins_after_first.iter().sum();
        assert_eq!(total_first, 2);

        obs.observe(&[-1.0, 2.0]);
        // Range expanded to [-1, 2]. Old counts should be redistributed, not zeroed.
        let total_second: u64 = obs.bins.iter().sum();
        // Should have 4 total counts (2 original redistributed + 2 new).
        assert_eq!(total_second, 4);
    }

    #[test]
    fn test_histogram_observer_filters_nan_inf() {
        let mut obs = HistogramObserver::new(50);
        obs.observe(&[f32::NAN, 1.0, f32::INFINITY, 2.0]);
        let total: u64 = obs.bins.iter().sum();
        // Only 2 finite values should be counted.
        assert_eq!(total, 2);
    }

    // ----- FakeQuantize -----

    #[test]
    fn test_fake_quantize_roundtrip() {
        let mut fq = FakeQuantize::new(QuantDtype::Int8);
        let data = vec![0.0, 0.5, 1.0, 1.5, 2.0];
        let (output, mask) = fq.forward(&data);
        assert_eq!(output.len(), 5);
        assert_eq!(mask.len(), 5);

        // Output should be close to input (quantize then dequantize).
        for (i, (&o, &d)) in output.iter().zip(data.iter()).enumerate() {
            assert!((o - d).abs() < 0.1, "element {i}: output={o}, data={d}");
        }
    }

    #[test]
    fn test_fake_quantize_ste_clipping() {
        let mut fq = FakeQuantize::new(QuantDtype::Int8);
        // First, observe a range [0, 2].
        let (_, _) = fq.forward(&[0.0, 1.0, 2.0]);

        // Disable observer so range stays locked at [0, 2].
        fq.observer_enabled = false;

        // Now forward with values outside the observed range.
        let (_, mask) = fq.forward(&[0.5, 1.0, 100.0, -100.0]);
        // In-range values should have mask = 1.0.
        assert_eq!(mask[0], 1.0);
        assert_eq!(mask[1], 1.0);
        // Out-of-range values should have mask = 0.0.
        assert_eq!(mask[2], 0.0);
        assert_eq!(mask[3], 0.0);
    }

    #[test]
    fn test_fake_quantize_observer_disabled_uses_cached() {
        let mut fq = FakeQuantize::new(QuantDtype::Int8);
        // Observe initial range.
        let (_, _) = fq.forward(&[0.0, 10.0]);
        let cached_scale = fq.qparams.as_ref().unwrap().scale[0];

        // Disable observer.
        fq.observer_enabled = false;

        // Forward with a much larger range — should NOT update qparams.
        let (_, _) = fq.forward(&[0.0, 1000.0]);
        let scale_after = fq.qparams.as_ref().unwrap().scale[0];
        assert!(
            (scale_after - cached_scale).abs() < 1e-10,
            "scale should not change when observer is disabled"
        );
    }

    #[test]
    fn test_fake_quantize_disabled_is_identity() {
        let mut fq = FakeQuantize::new(QuantDtype::Int8);
        fq.fake_quant_enabled = false;
        let data = vec![1.234, 5.678, -9.012];
        let (output, mask) = fq.forward(&data);
        assert_eq!(output, data);
        assert!(mask.iter().all(|&m| m == 1.0));
    }

    // ----- QatModel -----

    #[test]
    fn test_qat_model_register_and_fq_weights() {
        let mut model = QatModel::new(QuantDtype::Int8);
        model.register_layer("fc1");

        let weights = vec![0.1, 0.2, 0.3, 0.4];
        let (fq_weights, originals) = model.fake_quantize_weights("fc1", &weights).unwrap();

        // Originals should be exact copies.
        assert_eq!(originals, weights);
        // Fake-quantized weights should be close to originals.
        for (i, (&fq, &orig)) in fq_weights.iter().zip(weights.iter()).enumerate() {
            assert!((fq - orig).abs() < 0.1, "weight {i}: fq={fq}, orig={orig}");
        }
    }

    #[test]
    fn test_qat_model_activation_fq_per_layer() {
        let mut model = QatModel::new(QuantDtype::Int8);
        model.register_layer("layer1");
        model.register_layer("layer2");

        // Both layers should have independent activation FakeQuantize.
        let (act1, _) = model
            .fake_quantize_activations("layer1", &[1.0, 2.0])
            .unwrap();
        let (act2, _) = model
            .fake_quantize_activations("layer2", &[10.0, 20.0])
            .unwrap();
        assert_eq!(act1.len(), 2);
        assert_eq!(act2.len(), 2);
    }

    #[test]
    fn test_qat_model_unregistered_layer_errors() {
        let mut model = QatModel::new(QuantDtype::Int8);
        let result = model.fake_quantize_weights("nonexistent", &[1.0]);
        assert!(result.is_err());
    }

    // ----- prepare_qat -----

    #[test]
    fn test_prepare_qat_skips_bias() {
        let names = &["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"];
        let model = prepare_qat(names, QuantDtype::Int8);

        assert!(model.layers.contains_key("fc1"));
        assert!(model.layers.contains_key("fc2"));
        assert_eq!(model.layers.len(), 2);
    }

    #[test]
    fn test_prepare_qat_bias_only_still_registers() {
        let names = &["fc1.bias"];
        let model = prepare_qat(names, QuantDtype::Int8);
        // Even bias-only parameters should get a layer registered.
        assert!(model.layers.contains_key("fc1"));
    }

    // ----- cuda_rng -----

    /// Serializes the `cuda_rng` tests. They mutate process-global RNG
    /// state, and `cargo test` runs tests in parallel by default, so
    /// without mutual exclusion `test_cuda_rng_next_seed` could advance
    /// the state between `test_cuda_rng_fork_join`'s observations and make
    /// its `get_state()` assertions fail spuriously.
    static CUDA_RNG_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    #[test]
    fn test_cuda_rng_fork_join() {
        let _guard = CUDA_RNG_TEST_LOCK
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        let initial = cuda_rng::get_state();
        cuda_rng::fork_rng(0x12345678);
        assert_eq!(cuda_rng::get_state(), 0x12345678);
        cuda_rng::join_rng();
        assert_eq!(cuda_rng::get_state(), initial);
    }

    #[test]
    fn test_cuda_rng_next_seed() {
        let _guard = CUDA_RNG_TEST_LOCK
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        let s1 = cuda_rng::next_seed();
        let s2 = cuda_rng::next_seed();
        assert_ne!(s1, s2, "consecutive seeds should differ");
    }
}