// burn_tensor/tensor/quantization/scheme.rs
use serde::{Deserialize, Serialize};
use crate::{backend::Backend, Tensor, TensorPrimitive};
use super::{CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive};
/// Data type used to store quantized values.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
/// 8-bit signed integer; parameter computation for this type is based on the
/// `i8::MIN` / `i8::MAX` bounds.
QInt8,
}
/// Strategy used to derive quantization parameters from a calibration range.
///
/// Each variant carries the [`QuantizationType`] of the quantized values.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
/// Affine (asymmetric) quantization: `compute_q_params` yields a scale and an
/// integer offset.
PerTensorAffine(QuantizationType),
/// Symmetric quantization: `compute_q_params` yields a scale only (the offset
/// is `None`).
PerTensorSymmetric(QuantizationType),
}
impl QuantizationScheme {
    /// Derives the quantization parameters for this scheme from a calibration range.
    ///
    /// * Affine: the range is widened so that it contains zero, the scale is the
    ///   widened span divided by `i8::MAX - i8::MIN`, and the offset is
    ///   `-int(min / scale - i8::MIN)`.
    /// * Symmetric: the scale is `2 * max(|min|, |max|)` divided by `2 * i8::MAX`,
    ///   and no offset is produced.
    ///
    /// NOTE(review): a calibration range that collapses to zero produces a zero
    /// scale, and the affine branch then divides by it — confirm that callers
    /// never pass such a range.
    pub fn compute_q_params<B: Backend>(
        &self,
        range: CalibrationRange<B>,
    ) -> QuantizationParameters<B> {
        match self {
            Self::PerTensorAffine(QuantizationType::QInt8) => Self::affine_qint8(range),
            Self::PerTensorSymmetric(QuantizationType::QInt8) => Self::symmetric_qint8(range),
        }
    }

    /// Computes the scale and offset for affine quantization to `i8`.
    fn affine_qint8<B: Backend>(range: CalibrationRange<B>) -> QuantizationParameters<B> {
        let q_min = i32::from(i8::MIN);
        let q_max = i32::from(i8::MAX);
        let CalibrationRange { min, max } = range;

        // Stretch `[min, max]` so that it always contains zero.
        let origin = Tensor::zeros_like(&min);
        let lower = min.min_pair(origin);
        let origin = Tensor::zeros_like(&max);
        let upper = max.max_pair(origin);

        let span = upper.sub(lower.clone());
        let scale = span.div_scalar(q_max - q_min);

        // The float value is converted to an integer tensor first, then negated.
        let shifted = lower.div(scale.clone()).sub_scalar(q_min);
        let offset = Some(-shifted.int());

        QuantizationParameters { scale, offset }
    }

    /// Computes the scale for symmetric quantization to `i8` (no offset).
    fn symmetric_qint8<B: Backend>(range: CalibrationRange<B>) -> QuantizationParameters<B> {
        let q_max = i32::from(i8::MAX);
        // The quantized interval is mirrored around zero: `[-i8::MAX, i8::MAX]`.
        let q_min = -q_max;
        let CalibrationRange { min, max } = range;

        // Twice the largest absolute calibration bound.
        let largest_magnitude = min.abs().max_pair(max.abs());
        let span = largest_magnitude.mul_scalar(2);

        QuantizationParameters {
            scale: span.div_scalar(q_max - q_min),
            offset: None,
        }
    }

    /// Same as [`Self::compute_q_params`], but takes the calibration bounds as float
    /// tensor primitives and converts the result into its primitive form.
    pub(crate) fn compute_q_params_primitive<B: Backend>(
        &self,
        min: B::FloatTensorPrimitive,
        max: B::FloatTensorPrimitive,
    ) -> QuantizationParametersPrimitive<B> {
        let min = Tensor::from_primitive(TensorPrimitive::Float(min));
        let max = Tensor::from_primitive(TensorPrimitive::Float(max));
        let params = self.compute_q_params::<B>(CalibrationRange { min, max });
        params.into()
    }
}