Skip to main content

ferrotorch_core/
quantize.rs

1//! Post-training quantization (PTQ) for ferrotorch tensors.
2//!
3//! Provides symmetric and asymmetric quantization to INT8, INT4, and UINT8,
4//! with per-tensor or per-channel granularity. Designed for inference-time
5//! model compression — quantize once after training, then run forward passes
6//! with reduced memory and (on supported hardware) faster matmul.
7
8use std::collections::HashMap;
9
10use crate::dtype::Float;
11use crate::error::{FerrotorchError, FerrotorchResult};
12use crate::storage::TensorStorage;
13use crate::tensor::Tensor;
14
15// ---------------------------------------------------------------------------
16// Enums
17// ---------------------------------------------------------------------------
18
/// Granularity of quantization parameters (scale / zero_point).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// One scale and zero_point for the entire tensor.
    PerTensor,
    /// One scale and zero_point per slice along the given axis — e.g.
    /// `PerChannel(0)` on a `[out, in]` weight matrix gives every output
    /// row its own quantization range.
    PerChannel(usize),
}
27
/// Target integer dtype for quantized storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDtype {
    /// Signed 8-bit: [-128, 127].
    Int8,
    /// Signed 4-bit: [-8, 7]. NOTE(review): values are NOT bit-packed two
    /// per byte — `quantize` stores one value per `i8` slot (only the low
    /// 4 bits are significant), so Int4 currently reduces range, not memory.
    Int4,
    /// Unsigned 8-bit: [0, 255]. Stored in `i8` via a wrapping (bit-pattern)
    /// cast; see `quantize_val` / `stored_to_i32`.
    Uint8,
}
38
39impl QuantDtype {
40    /// Minimum representable value.
41    #[inline]
42    fn qmin(self) -> i32 {
43        match self {
44            QuantDtype::Int8 => -128,
45            QuantDtype::Int4 => -8,
46            QuantDtype::Uint8 => 0,
47        }
48    }
49
50    /// Maximum representable value.
51    #[inline]
52    fn qmax(self) -> i32 {
53        match self {
54            QuantDtype::Int8 => 127,
55            QuantDtype::Int4 => 7,
56            QuantDtype::Uint8 => 255,
57        }
58    }
59}
60
61// ---------------------------------------------------------------------------
62// QuantizedTensor
63// ---------------------------------------------------------------------------
64
/// A tensor stored in quantized (integer) representation.
///
/// The real value is recovered by `x = (q - zero_point) * scale`.
///
/// `scale` and `zero_point` are vectors whose length equals:
/// * 1 for `PerTensor`
/// * `shape[axis]` for `PerChannel(axis)`
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values stored as `i8` regardless of logical dtype.
    /// For `Uint8`, the stored `i8` is reinterpreted as `u8` via
    /// wrapping cast; for `Int4` only the low 4 bits are significant
    /// (one value per `i8` slot — not bit-packed).
    data: Vec<i8>,
    /// Per-tensor or per-channel scales. Always strictly positive:
    /// `compute_scale_zp` floors the range at `f32::EPSILON`.
    scale: Vec<f32>,
    /// Per-tensor or per-channel zero points (in quantized domain).
    zero_point: Vec<i32>,
    /// Original tensor shape. Data is interpreted as row-major /
    /// C-contiguous (see the stride math in `channel_index`).
    shape: Vec<usize>,
    /// Quantization granularity.
    scheme: QuantScheme,
    /// Target quantized dtype.
    dtype: QuantDtype,
}
89
90impl QuantizedTensor {
91    /// Number of elements.
92    #[inline]
93    pub fn numel(&self) -> usize {
94        self.shape.iter().product()
95    }
96
97    /// Borrow the shape.
98    #[inline]
99    pub fn shape(&self) -> &[usize] {
100        &self.shape
101    }
102
103    /// Borrow the quantized data.
104    #[inline]
105    pub fn data(&self) -> &[i8] {
106        &self.data
107    }
108
109    /// Borrow the scale vector.
110    #[inline]
111    pub fn scale(&self) -> &[f32] {
112        &self.scale
113    }
114
115    /// Borrow the zero-point vector.
116    #[inline]
117    pub fn zero_point(&self) -> &[i32] {
118        &self.zero_point
119    }
120
121    /// The quantization scheme used.
122    #[inline]
123    pub fn scheme(&self) -> QuantScheme {
124        self.scheme
125    }
126
127    /// The quantized dtype.
128    #[inline]
129    pub fn qdtype(&self) -> QuantDtype {
130        self.dtype
131    }
132}
133
134// ---------------------------------------------------------------------------
135// Helpers
136// ---------------------------------------------------------------------------
137
138/// Compute scale and zero_point for a given (min, max) range and target dtype.
139///
140/// Uses the standard asymmetric affine quantization formula:
141///   scale = (max - min) / (qmax - qmin)
142///   zero_point = round(qmin - min / scale)
143///
144/// The range is always expanded to include zero so that `0.0` maps exactly
145/// to an integer quantized value (important for zero-padding and ReLU outputs).
146/// When min == max the range would collapse to zero, so this expansion also
147/// prevents division-by-zero.
148fn compute_scale_zp(min_val: f32, max_val: f32, dtype: QuantDtype) -> (f32, i32) {
149    let qmin = dtype.qmin();
150    let qmax = dtype.qmax();
151
152    // Ensure the range includes zero (standard PyTorch behaviour).
153    let min_val = min_val.min(0.0);
154    let max_val = max_val.max(0.0);
155
156    // After including zero the range is at least max(|min|, |max|) > 0,
157    // but guard against the degenerate all-zeros case.
158    let range = (max_val - min_val).max(f32::EPSILON);
159    let scale = range / (qmax - qmin) as f32;
160
161    // zero_point is intentionally NOT clamped to [qmin, qmax]. It is stored
162    // as i32 and may lie outside the quantized integer range. This is correct
163    // for asymmetric affine quantization — clamping the zero_point distorts
164    // the mapping when the float range doesn't straddle zero.
165    let zp = (qmin as f32 - min_val / scale).round() as i32;
166
167    (scale, zp)
168}
169
/// Quantize one float: divide by scale, shift by the zero point, round, and
/// clamp into `[qmin, qmax]`.
///
/// The result is returned as `i8`. For unsigned dtypes (the caller passes
/// `qmin=0`, `qmax=255`) the clamped value is cast to `u8` first and then
/// reinterpreted as `i8`, so values 128..=255 survive as a bit pattern
/// instead of saturating.
#[inline]
fn quantize_val(x: f32, scale: f32, zp: i32, qmin: i32, qmax: i32, is_unsigned: bool) -> i8 {
    let q = ((x / scale + zp as f32).round() as i32).clamp(qmin, qmax);
    if is_unsigned { q as u8 as i8 } else { q as i8 }
}
185
/// Widen a stored `i8` back to the logical quantized integer.
///
/// Unsigned dtypes keep their bit pattern: the byte is reinterpreted as `u8`
/// before widening, so e.g. a stored `-56i8` becomes `200i32`.
#[inline]
fn stored_to_i32(val: i8, is_unsigned: bool) -> i32 {
    match is_unsigned {
        true => i32::from(val as u8),
        false => i32::from(val),
    }
}
196
/// Channel index of the element at `flat_index` in a row-major tensor of
/// shape `[d0, d1, ..., dn]` whose channel axis is `axis`.
///
/// Consecutive values of one channel are `stride` apart in the flat buffer,
/// where `stride` is the product of all dimensions after `axis`.
#[inline]
fn channel_index(flat_index: usize, shape: &[usize], axis: usize) -> usize {
    let inner: usize = shape.iter().skip(axis + 1).product();
    flat_index / inner % shape[axis]
}
207
208// ---------------------------------------------------------------------------
209// Quantize
210// ---------------------------------------------------------------------------
211
212/// Quantize a floating-point tensor.
213///
214/// # Per-tensor
215///
216/// Computes a single (scale, zero_point) pair from the global min/max.
217///
218/// # Per-channel
219///
220/// Computes one (scale, zero_point) per slice along the given axis. This is
221/// common for weight tensors where each output channel has its own range.
222pub fn quantize<T: Float>(
223    tensor: &Tensor<T>,
224    scheme: QuantScheme,
225    dtype: QuantDtype,
226) -> FerrotorchResult<QuantizedTensor> {
227    let data = tensor.data()?;
228    let shape = tensor.shape().to_vec();
229    let numel = tensor.numel();
230    let qmin = dtype.qmin();
231    let qmax = dtype.qmax();
232
233    let is_unsigned = dtype == QuantDtype::Uint8;
234
235    match scheme {
236        QuantScheme::PerTensor => {
237            // Global min/max.
238            let mut min_val = f32::INFINITY;
239            let mut max_val = f32::NEG_INFINITY;
240            for &v in data {
241                let f = v.to_f32().unwrap();
242                if f < min_val {
243                    min_val = f;
244                }
245                if f > max_val {
246                    max_val = f;
247                }
248            }
249
250            let (scale, zp) = compute_scale_zp(min_val, max_val, dtype);
251
252            let qdata: Vec<i8> = data
253                .iter()
254                .map(|&v| {
255                    quantize_val(v.to_f32().unwrap(), scale, zp, qmin, qmax, is_unsigned)
256                })
257                .collect();
258
259            Ok(QuantizedTensor {
260                data: qdata,
261                scale: vec![scale],
262                zero_point: vec![zp],
263                shape,
264                scheme,
265                dtype,
266            })
267        }
268
269        QuantScheme::PerChannel(axis) => {
270            if axis >= shape.len() {
271                return Err(FerrotorchError::InvalidArgument {
272                    message: format!(
273                        "PerChannel axis {axis} out of range for {}-d tensor",
274                        shape.len()
275                    ),
276                });
277            }
278
279            let num_channels = shape[axis];
280            let mut mins = vec![f32::INFINITY; num_channels];
281            let mut maxs = vec![f32::NEG_INFINITY; num_channels];
282
283            for (i, &v) in data.iter().enumerate() {
284                let ch = channel_index(i, &shape, axis);
285                let f = v.to_f32().unwrap();
286                if f < mins[ch] {
287                    mins[ch] = f;
288                }
289                if f > maxs[ch] {
290                    maxs[ch] = f;
291                }
292            }
293
294            let params: Vec<(f32, i32)> = mins
295                .iter()
296                .zip(maxs.iter())
297                .map(|(&mn, &mx)| compute_scale_zp(mn, mx, dtype))
298                .collect();
299
300            let scales: Vec<f32> = params.iter().map(|&(s, _)| s).collect();
301            let zps: Vec<i32> = params.iter().map(|&(_, z)| z).collect();
302
303            let mut qdata = Vec::with_capacity(numel);
304            for (i, &v) in data.iter().enumerate() {
305                let ch = channel_index(i, &shape, axis);
306                qdata.push(quantize_val(
307                    v.to_f32().unwrap(),
308                    scales[ch],
309                    zps[ch],
310                    qmin,
311                    qmax,
312                    is_unsigned,
313                ));
314            }
315
316            Ok(QuantizedTensor {
317                data: qdata,
318                scale: scales,
319                zero_point: zps,
320                shape,
321                scheme,
322                dtype,
323            })
324        }
325    }
326}
327
328// ---------------------------------------------------------------------------
329// Dequantize
330// ---------------------------------------------------------------------------
331
332/// Dequantize back to a floating-point tensor.
333///
334/// Applies the inverse mapping: `x = (q - zero_point) * scale`.
335pub fn dequantize<T: Float>(qtensor: &QuantizedTensor) -> FerrotorchResult<Tensor<T>> {
336    let numel = qtensor.numel();
337    let mut result = Vec::with_capacity(numel);
338    let is_unsigned = qtensor.dtype == QuantDtype::Uint8;
339
340    match qtensor.scheme {
341        QuantScheme::PerTensor => {
342            let scale = qtensor.scale[0];
343            let zp = qtensor.zero_point[0];
344            for &q in &qtensor.data {
345                let val = (stored_to_i32(q, is_unsigned) - zp) as f32 * scale;
346                result.push(T::from(val).unwrap());
347            }
348        }
349        QuantScheme::PerChannel(axis) => {
350            for (i, &q) in qtensor.data.iter().enumerate() {
351                let ch = channel_index(i, &qtensor.shape, axis);
352                let val = (stored_to_i32(q, is_unsigned) - qtensor.zero_point[ch]) as f32
353                    * qtensor.scale[ch];
354                result.push(T::from(val).unwrap());
355            }
356        }
357    }
358
359    Tensor::from_storage(TensorStorage::cpu(result), qtensor.shape.clone(), false)
360}
361
362// ---------------------------------------------------------------------------
363// Quantized matmul
364// ---------------------------------------------------------------------------
365
366/// Multiply two quantized 2-D matrices and return a quantized result.
367///
368/// Strategy: accumulate in `i32` to avoid overflow, then rescale to the output
369/// quantized domain. This avoids a full dequantize-matmul-requantize round-trip
370/// while remaining numerically correct for INT8.
371///
372/// Both inputs must be 2-D, with compatible inner dimensions (standard matmul
373/// rules: `[M, K] x [K, N] -> [M, N]`).
374pub fn quantized_matmul(
375    a: &QuantizedTensor,
376    b: &QuantizedTensor,
377) -> FerrotorchResult<QuantizedTensor> {
378    // Validate shapes.
379    if a.shape.len() != 2 || b.shape.len() != 2 {
380        return Err(FerrotorchError::InvalidArgument {
381            message: format!(
382                "quantized_matmul requires 2-D tensors, got shapes {:?} and {:?}",
383                a.shape, b.shape
384            ),
385        });
386    }
387
388    let m = a.shape[0];
389    let k = a.shape[1];
390    let k2 = b.shape[0];
391    let n = b.shape[1];
392
393    if k != k2 {
394        return Err(FerrotorchError::ShapeMismatch {
395            message: format!(
396                "quantized_matmul inner dimensions mismatch: [{m}, {k}] x [{k2}, {n}]"
397            ),
398        });
399    }
400
401    // Both inputs must be PerTensor for the fast path.
402    if a.scale.len() != 1 || b.scale.len() != 1 {
403        return Err(FerrotorchError::InvalidArgument {
404            message: "quantized_matmul currently requires PerTensor-quantized inputs".into(),
405        });
406    }
407
408    let a_scale = a.scale[0];
409    let a_zp = a.zero_point[0];
410    let b_scale = b.scale[0];
411    let b_zp = b.zero_point[0];
412
413    let a_unsigned = a.dtype == QuantDtype::Uint8;
414    let b_unsigned = b.dtype == QuantDtype::Uint8;
415
416    // Accumulate in i32.
417    let mut acc = vec![0i32; m * n];
418    for i in 0..m {
419        for j in 0..n {
420            let mut sum = 0i32;
421            for p in 0..k {
422                let qa = stored_to_i32(a.data[i * k + p], a_unsigned) - a_zp;
423                let qb = stored_to_i32(b.data[p * n + j], b_unsigned) - b_zp;
424                sum += qa * qb;
425            }
426            acc[i * n + j] = sum;
427        }
428    }
429
430    // The real-valued result element is: acc[i,j] * a_scale * b_scale.
431    // Requantize: pick INT8 output with its own scale/zp.
432    let combined_scale = a_scale * b_scale;
433
434    // Find the real-valued min/max of the output.
435    let mut out_min = f32::INFINITY;
436    let mut out_max = f32::NEG_INFINITY;
437    for &a_val in &acc {
438        let real = a_val as f32 * combined_scale;
439        if real < out_min {
440            out_min = real;
441        }
442        if real > out_max {
443            out_max = real;
444        }
445    }
446
447    let out_dtype = QuantDtype::Int8;
448    let (out_scale, out_zp) = compute_scale_zp(out_min, out_max, out_dtype);
449    let qmin = out_dtype.qmin();
450    let qmax = out_dtype.qmax();
451
452    let qdata: Vec<i8> = acc
453        .iter()
454        .map(|&a_val| {
455            let real = a_val as f32 * combined_scale;
456            quantize_val(real, out_scale, out_zp, qmin, qmax, false)
457        })
458        .collect();
459
460    Ok(QuantizedTensor {
461        data: qdata,
462        scale: vec![out_scale],
463        zero_point: vec![out_zp],
464        shape: vec![m, n],
465        scheme: QuantScheme::PerTensor,
466        dtype: out_dtype,
467    })
468}
469
470// ---------------------------------------------------------------------------
471// Module-level quantization utility
472// ---------------------------------------------------------------------------
473
474/// Quantize every weight tensor in a module, returning a name -> QuantizedTensor
475/// map suitable for serialization or quantized inference.
476///
477/// This accepts any type implementing the `Module` trait from `ferrotorch-nn`.
478/// Because `ferrotorch-core` does not depend on `ferrotorch-nn`, we accept a
479/// generic iterator of named tensors instead.
480pub fn quantize_named_tensors<T: Float>(
481    named_tensors: impl IntoIterator<Item = (String, Tensor<T>)>,
482    scheme: QuantScheme,
483    dtype: QuantDtype,
484) -> FerrotorchResult<HashMap<String, QuantizedTensor>> {
485    let mut result = HashMap::new();
486    for (name, tensor) in named_tensors {
487        let qtensor = quantize(&tensor, scheme, dtype)?;
488        result.insert(name, qtensor);
489    }
490    Ok(result)
491}
492
493// ---------------------------------------------------------------------------
494// Tests
495// ---------------------------------------------------------------------------
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: create a tensor from f32 data.
    fn make_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        crate::from_slice(data, shape).unwrap()
    }

    // ----- Round-trip quantize/dequantize -----

    #[test]
    fn test_per_tensor_int8_roundtrip() {
        // 21 evenly spaced values spanning [-5, 5].
        let data: Vec<f32> = (-10..=10).map(|x| x as f32 * 0.5).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        assert_eq!(rt.shape(), t.shape());
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // INT8 over [-5, 5]: step ≈ 10/255 ≈ 0.04, max error ≈ half step ≈ 0.02
            assert!(
                err < 0.05,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_per_tensor_uint8_roundtrip() {
        // Non-negative inputs exercise the u8-as-i8 bit-pattern storage path.
        let data: Vec<f32> = (0..=20).map(|x| x as f32 * 0.1).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Uint8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // UINT8 over [0, 2]: step ≈ 2/255 ≈ 0.008
            assert!(
                err < 0.02,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_per_tensor_int4_roundtrip() {
        // INT4 has only 16 levels, so larger quantization error is expected.
        let data: Vec<f32> = (-8..=7).map(|x| x as f32).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int4).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // INT4 over [-8, 7]: step = 15/15 = 1.0, max error ≈ 0.5
            assert!(
                err < 1.01,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // ----- Per-channel -----

    #[test]
    fn test_per_channel_int8_roundtrip() {
        // Shape [3, 4]: 3 channels along axis 0, each with different ranges.
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            // channel 0: range [0, 3]
            0.0, 1.0, 2.0, 3.0,
            // channel 1: range [-10, 10]
            -10.0, -5.0, 5.0, 10.0,
            // channel 2: range [100, 200]
            100.0, 130.0, 170.0, 200.0,
        ];
        let t = make_tensor(&data, &[3, 4]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        assert_eq!(qt.scale.len(), 3);
        assert_eq!(qt.zero_point.len(), 3);

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // Each channel has its own scale, so error is relative to the
            // channel's range. Worst case channel 2: 100/255 ≈ 0.39.
            assert!(
                err < 0.5,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_per_channel_axis_out_of_bounds() {
        // Axis 5 on a 1-d tensor must be rejected, not panic.
        let t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let result = quantize(&t, QuantScheme::PerChannel(5), QuantDtype::Int8);
        assert!(result.is_err());
    }

    // ----- Quantized matmul -----

    #[test]
    fn test_quantized_matmul_identity() {
        // A * I should ≈ A after quantize -> matmul -> dequantize.
        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let a = make_tensor(&a_data, &[2, 2]);
        let eye = make_tensor(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qi = quantize(&eye, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qi).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        let c_data = c.data().unwrap();
        for (i, (&expected, &got)) in a_data.iter().zip(c_data.iter()).enumerate() {
            let err = (expected - got).abs();
            assert!(
                err < 0.5,
                "element {i}: expected={expected}, got={got}, error={err}"
            );
        }
    }

    #[test]
    fn test_quantized_matmul_correctness() {
        // [2,3] x [3,2] -> [2,2]
        // A = [[1, 2, 3],
        //      [4, 5, 6]]
        // B = [[7,  8],
        //      [9, 10],
        //      [11, 12]]
        // A @ B = [[ 58,  64],
        //          [139, 154]]
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qb).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();

        let expected = [58.0f32, 64.0, 139.0, 154.0];
        let c_data = c.data().unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        for (i, (&e, &g)) in expected.iter().zip(c_data.iter()).enumerate() {
            let err = (e - g).abs();
            // Quantization introduces some error; for small integers in INT8
            // the error should be small relative to the values.
            assert!(
                err < 3.0,
                "element {i}: expected={e}, got={g}, error={err}"
            );
        }
    }

    #[test]
    fn test_quantized_matmul_shape_mismatch() {
        // Inner dimensions disagree: [2,3] x [2,2].
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    #[test]
    fn test_quantized_matmul_non_2d() {
        // 1-d inputs must be rejected by the rank check.
        let a = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let b = make_tensor(&[4.0, 5.0, 6.0], &[3]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    // ----- Module quantization utility -----

    #[test]
    fn test_quantize_named_tensors() {
        let w1 = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let w2 = make_tensor(&[-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], &[3, 2]);

        let named = vec![
            ("layer.weight".to_string(), w1),
            ("layer2.weight".to_string(), w2),
        ];

        let qmap =
            quantize_named_tensors(named, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();

        assert_eq!(qmap.len(), 2);
        assert!(qmap.contains_key("layer.weight"));
        assert!(qmap.contains_key("layer2.weight"));
        assert_eq!(qmap["layer.weight"].shape(), &[2, 2]);
        assert_eq!(qmap["layer2.weight"].shape(), &[3, 2]);
    }

    // ----- Constant values / edge cases -----

    #[test]
    fn test_quantize_constant_tensor() {
        // All values identical — scale should not be zero.
        let t = make_tensor(&[5.0, 5.0, 5.0, 5.0], &[4]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let recovered = rt.data().unwrap();
        for &r in recovered {
            assert!(
                (r - 5.0).abs() < 0.1,
                "constant tensor dequantized to {r}, expected 5.0"
            );
        }
    }

    #[test]
    fn test_quantize_single_element() {
        let t = make_tensor(&[42.0], &[1]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        assert!((rt.data().unwrap()[0] - 42.0).abs() < 0.5);
    }

    #[test]
    fn test_per_channel_int4() {
        // 2 channels, 3 elements each.
        let data = vec![0.0, 1.0, 2.0, -4.0, 0.0, 4.0];
        let t = make_tensor(&data, &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int4).unwrap();

        assert_eq!(qt.scale.len(), 2);
        assert_eq!(qt.zero_point.len(), 2);

        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            // INT4 has coarse resolution, but channel-level ranges are small.
            assert!(
                err < 1.0,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_dequantize_f64() {
        // Quantize from f32 but dequantize into f64 — exercises the generic
        // target type of `dequantize`.
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let t = crate::from_slice(&data, &[4]).unwrap();
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f64> = dequantize(&qt).unwrap();

        assert_eq!(rt.shape(), &[4]);
        let recovered = rt.data().unwrap();
        for (i, &r) in recovered.iter().enumerate() {
            let expected = data[i] as f64;
            let err = (expected - r).abs();
            assert!(
                err < 0.05,
                "element {i}: expected={expected}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_quantized_tensor_accessors() {
        let t = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();

        assert_eq!(qt.numel(), 6);
        assert_eq!(qt.shape(), &[2, 3]);
        assert_eq!(qt.data().len(), 6);
        assert_eq!(qt.scale().len(), 1);
        assert_eq!(qt.zero_point().len(), 1);
        assert_eq!(qt.scheme(), QuantScheme::PerTensor);
        assert_eq!(qt.qdtype(), QuantDtype::Int8);
    }
}