burn_tensor/tensor/quantization/
scheme.rs

1#![allow(missing_docs)] // cube derive macros
2
3use serde::{Deserialize, Serialize};
4
5use crate::{backend::Backend, Tensor, TensorPrimitive};
6
7use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive};
8
9#[cfg(feature = "cubecl")]
10use cubecl::prelude::*;
11
12/// Quantization data type.
13#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
14#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))]
15pub enum QuantizationType {
16    /// 8-bit signed integer.
17    QInt8,
18}
19
20/// Quantization scheme.
21#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
22#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))]
23pub enum QuantizationScheme {
24    /// Per-tensor affine/asymmetric quantization.
25    PerTensorAffine(QuantizationType),
26    /// Per-tensor symmetric quantization.
27    PerTensorSymmetric(QuantizationType),
28    // /// Per-channel affine/asymmetric quantization.
29    // PerChannelAffine,
30    // /// Per-channel symmetric quantization.
31    // PerChannelSymmetric,
32}
33
/// Makes the scheme usable as a cube type; only compiled with the `cubecl` feature.
#[cfg(feature = "cubecl")]
impl CubeType for QuantizationScheme {
    /// The scheme is its own expand type.
    type ExpandType = QuantizationScheme;
}
/// Initialization hook for the scheme; only compiled with the `cubecl` feature.
#[cfg(feature = "cubecl")]
impl cubecl::frontend::Init for QuantizationScheme {
    /// Returns the scheme unchanged; the cube context is not used.
    fn init(self, _ctx: &mut CubeContext) -> QuantizationScheme {
        self
    }
}
44
45impl QuantizationScheme {
46    /// Compute the quantization parameters.
47    pub fn compute_q_params<B: Backend>(
48        &self,
49        range: CalibrationRange<B>,
50    ) -> QuantizationParameters<B> {
51        match self {
52            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
53                QuantizationType::QInt8 => {
54                    // Quantized range `[a, b]`
55                    let a = i8::MIN as i32;
56                    let b = i8::MAX as i32;
57
58                    // We extend the `[min, max]` interval to ensure that it contains 0.
59                    // Otherwise, we would not meet the requirement that 0 be an exactly
60                    // representable value (zero-point).
61                    let zero = Tensor::zeros_like(&range.min);
62                    let min = range.min.min_pair(zero);
63                    let zero = Tensor::zeros_like(&range.max);
64                    let max = range.max.max_pair(zero);
65
66                    // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the
67                    // scale to 0.1 to avoid division by zero.
68                    let scale = max.sub(min.clone()).div_scalar(b - a);
69                    let scale = scale.clone().mask_fill(scale.equal_elem(0.), 0.1);
70                    let offset = Some(-(min.div(scale.clone()).sub_scalar(a)).int());
71                    QuantizationParameters { scale, offset }
72                }
73            },
74            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
75                QuantizationType::QInt8 => {
76                    // Quantized range `[a, b]`
77                    let b = i8::MAX as i32;
78                    let a = -b;
79
80                    // Compute scale to convert an input value in range `[-alpha, alpha]`
81                    let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
82
83                    QuantizationParameters {
84                        scale: values_range.div_scalar(b - a),
85                        offset: None,
86                    }
87                }
88            },
89        }
90    }
91
92    /// Compute the quantization parameters.
93    pub(crate) fn compute_q_params_primitive<B: Backend>(
94        &self,
95        min: B::FloatTensorPrimitive,
96        max: B::FloatTensorPrimitive,
97    ) -> QuantizationParametersPrimitive<B> {
98        let range = CalibrationRange {
99            min: Tensor::from_primitive(TensorPrimitive::Float(min)),
100            max: Tensor::from_primitive(TensorPrimitive::Float(max)),
101        };
102        self.compute_q_params(range).into()
103    }
104}