burn_tensor/tensor/quantization/
scheme.rs1#![allow(missing_docs)] use serde::{Deserialize, Serialize};
4
5use crate::{backend::Backend, Tensor, TensorPrimitive};
6
7use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive};
8
9#[cfg(feature = "cubecl")]
10use cubecl::prelude::*;
11
/// Quantized integer data type. It selects the integer range `[a, b]` that
/// `QuantizationScheme::compute_q_params` maps floating-point values onto.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))]
pub enum QuantizationType {
    /// 8-bit signed integer: quantization parameters are computed against `i8` bounds.
    QInt8,
}
19
/// Strategy used to map floating-point values onto a quantized integer type.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))]
pub enum QuantizationScheme {
    /// Per-tensor affine (asymmetric) quantization: the computed parameters consist of a
    /// scale and an integer offset (zero-point).
    PerTensorAffine(QuantizationType),
    /// Per-tensor symmetric quantization: the computed parameters consist of a scale only
    /// (the offset is `None`).
    PerTensorSymmetric(QuantizationType),
}
33
// Manual `CubeType` impl (the enum does not derive it, unlike `QuantizationType`):
// the expand representation of the scheme is the scheme value itself.
#[cfg(feature = "cubecl")]
impl CubeType for QuantizationScheme {
    type ExpandType = Self;
}
#[cfg(feature = "cubecl")]
impl cubecl::frontend::Init for QuantizationScheme {
    /// Identity initialization: the scheme is handed back unchanged and the
    /// CubeCL context is not consulted.
    fn init(self, _ctx: &mut CubeContext) -> Self {
        self
    }
}
44
45impl QuantizationScheme {
46 pub fn compute_q_params<B: Backend>(
48 &self,
49 range: CalibrationRange<B>,
50 ) -> QuantizationParameters<B> {
51 match self {
52 QuantizationScheme::PerTensorAffine(dtype) => match dtype {
53 QuantizationType::QInt8 => {
54 let a = i8::MIN as i32;
56 let b = i8::MAX as i32;
57
58 let zero = Tensor::zeros_like(&range.min);
62 let min = range.min.min_pair(zero);
63 let zero = Tensor::zeros_like(&range.max);
64 let max = range.max.max_pair(zero);
65
66 let scale = max.sub(min.clone()).div_scalar(b - a);
69 let scale = scale.clone().mask_fill(scale.equal_elem(0.), 0.1);
70 let offset = Some(-(min.div(scale.clone()).sub_scalar(a)).int());
71 QuantizationParameters { scale, offset }
72 }
73 },
74 QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
75 QuantizationType::QInt8 => {
76 let b = i8::MAX as i32;
78 let a = -b;
79
80 let values_range = range.min.abs().max_pair(range.max.abs()).mul_scalar(2);
82
83 QuantizationParameters {
84 scale: values_range.div_scalar(b - a),
85 offset: None,
86 }
87 }
88 },
89 }
90 }
91
92 pub(crate) fn compute_q_params_primitive<B: Backend>(
94 &self,
95 min: B::FloatTensorPrimitive,
96 max: B::FloatTensorPrimitive,
97 ) -> QuantizationParametersPrimitive<B> {
98 let range = CalibrationRange {
99 min: Tensor::from_primitive(TensorPrimitive::Float(min)),
100 max: Tensor::from_primitive(TensorPrimitive::Float(max)),
101 };
102 self.compute_q_params(range).into()
103 }
104}