burn_tensor/tensor/quantization/
scheme.rs

1#![allow(missing_docs)] // cube derive macros
2
3use serde::{Deserialize, Serialize};
4
5use crate::{Tensor, TensorPrimitive, backend::Backend};
6
7use super::{
8    Calibration, CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive,
9};
10
11#[cfg(feature = "cubecl")]
12use cubecl::prelude::*;
13
14/// Quantization data type.
15#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
16#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))]
17pub enum QuantizationType {
18    /// 8-bit signed integer.
19    QInt8,
20}
21
22/// Quantization mode.
23#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
24#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))]
25pub enum QuantizationMode {
26    /// Symmetric or scale quantization.
27    Symmetric,
28}
29
30/// Quantization scheme.
31#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))]
33pub enum QuantizationScheme {
34    /// Per-tensor quantization.
35    PerTensor(QuantizationMode, QuantizationType),
36}
37
// The scheme is its own expand type: no separate expanded representation is declared.
#[cfg(feature = "cubecl")]
impl CubeType for QuantizationScheme {
    type ExpandType = QuantizationScheme;
}
42
// Empty impl: nothing is overridden, so only the trait's provided items apply.
#[cfg(feature = "cubecl")]
impl cubecl::prelude::CubeDebug for QuantizationScheme {}
45
#[cfg(feature = "cubecl")]
impl cubecl::frontend::Init for QuantizationScheme {
    /// Returns the scheme unchanged; the scope is not used.
    fn init(self, _: &mut cubecl::ir::Scope) -> Self {
        self
    }
}
52
53impl QuantizationScheme {
54    /// Get the [quantization mode](QuantizationMode)
55    pub fn mode(&self) -> QuantizationMode {
56        match self {
57            QuantizationScheme::PerTensor(mode, ..) => *mode,
58        }
59    }
60
61    /// Compute the quantization range mapping.
62    pub fn compute_range<B: Backend, const D: usize>(
63        &self,
64        tensor: &Tensor<B, D>,
65        calibration: &Calibration,
66    ) -> CalibrationRange<B> {
67        let (min, max) = match &tensor.primitive {
68            TensorPrimitive::Float(tensor) => {
69                self.compute_range_primitive::<B>(tensor.clone(), calibration)
70            }
71            TensorPrimitive::QFloat(_) => unreachable!(),
72        };
73
74        CalibrationRange {
75            min: Tensor::from_primitive(TensorPrimitive::Float(min)),
76            max: Tensor::from_primitive(TensorPrimitive::Float(max)),
77        }
78    }
79
80    pub(crate) fn compute_range_primitive<B: Backend>(
81        &self,
82        tensor: B::FloatTensorPrimitive,
83        calibration: &Calibration,
84    ) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
85        match calibration {
86            Calibration::MinMax => match self {
87                QuantizationScheme::PerTensor(_, _) => {
88                    (B::float_min(tensor.clone()), B::float_max(tensor))
89                }
90            },
91        }
92    }
93
94    /// Compute the quantization parameters.
95    pub fn compute_q_params<B: Backend>(
96        &self,
97        range: CalibrationRange<B>,
98    ) -> QuantizationParameters<B> {
99        match self {
100            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
101                // Quantized range `[a, b]`
102                let b = i8::MAX as i32;
103                let a = -b;
104
105                // Compute scale to convert an input value in range `[-alpha, alpha]`
106                let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
107
108                QuantizationParameters {
109                    scale: values_range.div_scalar(b - a),
110                    offset: None,
111                }
112            }
113        }
114    }
115
116    /// Compute the quantization parameters.
117    pub(crate) fn compute_q_params_primitive<B: Backend>(
118        &self,
119        min: B::FloatTensorPrimitive,
120        max: B::FloatTensorPrimitive,
121    ) -> QuantizationParametersPrimitive<B> {
122        let range = CalibrationRange {
123            min: Tensor::from_primitive(TensorPrimitive::Float(min)),
124            max: Tensor::from_primitive(TensorPrimitive::Float(max)),
125        };
126        self.compute_q_params(range).into()
127    }
128
129    pub fn q_type(&self) -> QuantizationType {
130        match self {
131            QuantizationScheme::PerTensor(_, quantization_type) => *quantization_type,
132        }
133    }
134}