burn_tensor/tensor/quantization/
scheme.rs1#![allow(missing_docs)] use serde::{Deserialize, Serialize};
4
5use crate::{Tensor, TensorPrimitive, backend::Backend};
6
7use super::{
8 Calibration, CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive,
9};
10
11#[cfg(feature = "cubecl")]
12use cubecl::prelude::*;
13
14#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
16#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))]
17pub enum QuantizationType {
18 QInt8,
20}
21
22#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
24#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))]
25pub enum QuantizationMode {
26 Symmetric,
28}
29
30#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
32#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))]
33pub enum QuantizationScheme {
34 PerTensor(QuantizationMode, QuantizationType),
36}
37
/// The scheme is plain data, so inside cubecl kernels it acts as its own expand type.
#[cfg(feature = "cubecl")]
impl CubeType for QuantizationScheme {
    type ExpandType = QuantizationScheme;
}
42
/// Marker implementation required by cubecl; the trait's provided methods are used as-is.
#[cfg(feature = "cubecl")]
impl CubeDebug for QuantizationScheme {}
45
/// Kernel-side initialization of a scheme is the identity: the value is returned as-is.
#[cfg(feature = "cubecl")]
impl cubecl::frontend::Init for QuantizationScheme {
    fn init(self, _scope: &mut cubecl::ir::Scope) -> QuantizationScheme {
        // Nothing to register in the scope; the scheme carries no runtime state.
        self
    }
}
52
53impl QuantizationScheme {
54 pub fn mode(&self) -> QuantizationMode {
56 match self {
57 QuantizationScheme::PerTensor(mode, ..) => *mode,
58 }
59 }
60
61 pub fn compute_range<B: Backend, const D: usize>(
63 &self,
64 tensor: &Tensor<B, D>,
65 calibration: &Calibration,
66 ) -> CalibrationRange<B> {
67 let (min, max) = match &tensor.primitive {
68 TensorPrimitive::Float(tensor) => {
69 self.compute_range_primitive::<B>(tensor.clone(), calibration)
70 }
71 TensorPrimitive::QFloat(_) => unreachable!(),
72 };
73
74 CalibrationRange {
75 min: Tensor::from_primitive(TensorPrimitive::Float(min)),
76 max: Tensor::from_primitive(TensorPrimitive::Float(max)),
77 }
78 }
79
80 pub(crate) fn compute_range_primitive<B: Backend>(
81 &self,
82 tensor: B::FloatTensorPrimitive,
83 calibration: &Calibration,
84 ) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {
85 match calibration {
86 Calibration::MinMax => match self {
87 QuantizationScheme::PerTensor(_, _) => {
88 (B::float_min(tensor.clone()), B::float_max(tensor))
89 }
90 },
91 }
92 }
93
94 pub fn compute_q_params<B: Backend>(
96 &self,
97 range: CalibrationRange<B>,
98 ) -> QuantizationParameters<B> {
99 match self {
100 QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
101 let b = i8::MAX as i32;
103 let a = -b;
104
105 let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
107
108 QuantizationParameters {
109 scale: values_range.div_scalar(b - a),
110 offset: None,
111 }
112 }
113 }
114 }
115
116 pub(crate) fn compute_q_params_primitive<B: Backend>(
118 &self,
119 min: B::FloatTensorPrimitive,
120 max: B::FloatTensorPrimitive,
121 ) -> QuantizationParametersPrimitive<B> {
122 let range = CalibrationRange {
123 min: Tensor::from_primitive(TensorPrimitive::Float(min)),
124 max: Tensor::from_primitive(TensorPrimitive::Float(max)),
125 };
126 self.compute_q_params(range).into()
127 }
128
129 pub fn q_type(&self) -> QuantizationType {
130 match self {
131 QuantizationScheme::PerTensor(_, quantization_type) => *quantization_type,
132 }
133 }
134}