// burn_tensor/tensor/quantization/strategy.rs
use core::{
hash::{Hash, Hasher},
marker::PhantomData,
};
use alloc::vec::Vec;
use burn_common::{iter_par, run_par};
use num_traits::{Float, PrimInt};
use serde::{Deserialize, Serialize};
use super::{QuantizationScheme, QuantizationType};
/// Quantization strategy: a concrete quantization scheme bundled with its parameters.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationStrategy {
/// Per-tensor affine (asymmetric) quantization of `f32` values to `i8`, using `i32` as the
/// intermediate integer type.
PerTensorAffineInt8(AffineQuantization<f32, i8, i32>),
/// Per-tensor symmetric quantization of `f32` values to `i8`.
PerTensorSymmetricInt8(SymmetricQuantization<f32, i8>),
}
impl QuantizationStrategy {
    /// Returns the [`QuantizationScheme`] that corresponds to this strategy.
    ///
    /// The parameters carried by the strategy are ignored; only the variant matters.
    pub fn scheme(&self) -> QuantizationScheme {
        // Both strategies currently quantize to 8-bit integers.
        let int8 = QuantizationType::QInt8;
        match *self {
            Self::PerTensorAffineInt8(_) => QuantizationScheme::PerTensorAffine(int8),
            Self::PerTensorSymmetricInt8(_) => QuantizationScheme::PerTensorSymmetric(int8),
        }
    }
}
/// Conversion between a floating-point element type `E` and an integer type `Q`.
pub trait Quantization<E: Float, Q: PrimInt> {
/// Creates a new quantization scheme for an input range `[alpha, beta]`.
fn new(alpha: E, beta: E) -> Self;
/// Converts floating-point values to their quantized integer representation.
fn quantize(&self, values: &[E]) -> Vec<Q>;
/// Converts quantized integer values back to floating point.
fn dequantize(&self, values: &[Q]) -> Vec<E>;
}
/// Affine (asymmetric) quantization parameters: a scaling factor and an integer offset.
///
/// `A` is the integer type in which the offset is subtracted when dequantizing; it is carried
/// only as a marker. NOTE(review): presumably `A` should have a wider range than `Q` so that
/// the subtraction does not saturate — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct AffineQuantization<E: Float, Q: PrimInt, A: PrimInt> {
/// The scaling factor.
pub scale: E,
/// The zero-point offset: the quantized value that `0.0` maps to.
pub offset: Q,
/// Marker for the intermediate integer type `A`; holds no data.
_a: PhantomData<A>,
}
impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
    /// Builds an affine scheme directly from a `scale` and an `offset`.
    ///
    /// A `scale` equal to zero is replaced by `0.1`, because `quantize` divides by the scale.
    pub fn init(scale: E, offset: Q) -> Self {
        let scale = if scale == E::zero() {
            E::from(0.1).unwrap()
        } else {
            scale
        };

        Self {
            scale,
            offset,
            _a: PhantomData,
        }
    }
}
impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization<E, Q, A> {
    /// Creates the affine scheme that maps the input range `[alpha, beta]` onto the full range
    /// of `Q`.
    ///
    /// The range is first widened to contain zero, so that `0.0` maps exactly onto the offset.
    /// A degenerate range (for example `alpha == beta == 0`, i.e. an all-zero input) yields a
    /// zero scale; in that case the offset is pinned to `Q::min_value()` and
    /// [`AffineQuantization::init`] substitutes its fallback scale, instead of dividing by zero.
    ///
    /// # Panics
    ///
    /// Panics if the bounds of `Q` cannot be represented in `E`, or if the computed offset
    /// cannot be converted to `Q`.
    fn new(alpha: E, beta: E) -> Self {
        // Quantized range `[a, b]`.
        let a = E::from(Q::min_value()).unwrap();
        let b = E::from(Q::max_value()).unwrap();

        // Widen `[alpha, beta]` so that it always contains zero.
        let alpha = E::min(alpha, E::zero());
        let beta = E::max(beta, E::zero());

        let scale = (beta - alpha) / (b - a);
        if scale == E::zero() {
            // `alpha / scale` would be NaN (or infinite) here and the conversion to `Q` below
            // would panic. With `alpha` at zero the offset formula reduces to `a`, so use the
            // lower bound of `Q` directly and let `init` pick the fallback scale.
            return Self::init(scale, Q::min_value());
        }

        // The conversion to `Q` truncates the fractional part of the zero-point.
        let z = -(alpha / scale - a);
        Self::init(scale, Q::from(z).unwrap())
    }

    /// Quantizes each value as `round(x / scale + offset)`, clamped to the range of `Q`.
    ///
    /// # Panics
    ///
    /// Panics if a clamped value cannot be converted to `Q` (presumably the case for NaN
    /// inputs — confirm).
    fn quantize(&self, values: &[E]) -> Vec<Q> {
        // Quantized range `[a, b]` and the offset, all expressed in `E`.
        let a = E::from(Q::min_value()).unwrap();
        let b = E::from(Q::max_value()).unwrap();
        let z = E::from(self.offset).unwrap();

        run_par!(|| {
            iter_par!(values.iter())
                .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap())
                .collect()
        })
    }

    /// Dequantizes each value as `scale * (x_q - offset)`.
    ///
    /// The subtraction is carried out in the integer type `A` and saturates at its bounds.
    ///
    /// # Panics
    ///
    /// Panics if a quantized value or the offset cannot be converted to `A`, or if the
    /// difference cannot be converted to `E`.
    fn dequantize(&self, values: &[Q]) -> Vec<E> {
        run_par!(|| {
            iter_par!(values.iter())
                .map(|x_q| {
                    self.scale
                        * (E::from(
                            A::from(*x_q)
                                .unwrap()
                                .saturating_sub(A::from(self.offset).unwrap()),
                        )
                        .unwrap())
                })
                .collect()
        })
    }
}
/// Symmetric quantization parameters: a single scaling factor and no offset.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SymmetricQuantization<E: Float, Q: PrimInt> {
/// The scaling factor.
pub scale: E,
/// Marker for the quantized integer type `Q`; holds no data.
_q: PhantomData<Q>,
}
impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
    /// Builds a symmetric scheme directly from a `scale`.
    ///
    /// A `scale` equal to zero is replaced by `0.1`, because `quantize` divides by the scale.
    pub fn init(scale: E) -> Self {
        let scale = if scale == E::zero() {
            E::from(0.1).unwrap()
        } else {
            scale
        };

        Self {
            scale,
            _q: PhantomData,
        }
    }
}
impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
    /// Creates the symmetric scheme for the input range `[alpha, beta]`.
    ///
    /// The larger of `|alpha|` and `|beta|` is mapped onto `Q::max_value()`, so the quantized
    /// range is `[-Q::max_value(), Q::max_value()]`.
    ///
    /// # Panics
    ///
    /// Panics if `Q::min_value()` is zero (an unsigned `Q`), or if `Q::max_value()` cannot be
    /// represented in `E`.
    fn new(alpha: E, beta: E) -> Self {
        assert!(
            !Q::min_value().is_zero(),
            "Symmetric quantization is only valid for signed integers."
        );

        // Symmetric quantized range `[lower, upper]`.
        let upper = E::from(Q::max_value()).unwrap();
        let lower = -upper;

        // Largest magnitude of the input range; the input range becomes `[-bound, bound]`.
        let bound = alpha.abs().max(beta.abs());
        let input_span = bound + bound;
        let quantized_span = upper - lower;

        Self::init(input_span / quantized_span)
    }

    /// Quantizes each value as `round(x / scale)`, clamped to
    /// `[-Q::max_value(), Q::max_value()]`.
    ///
    /// # Panics
    ///
    /// Panics if a clamped value cannot be converted to `Q`.
    fn quantize(&self, values: &[E]) -> Vec<Q> {
        let upper = E::from(Q::max_value()).unwrap();
        let lower = -upper;

        let to_quantized = |value: &E| {
            let scaled = (*value / self.scale).round();
            Q::from(scaled.clamp(lower, upper)).unwrap()
        };

        values.iter().map(to_quantized).collect()
    }

    /// Dequantizes each value as `scale * x_q`.
    ///
    /// # Panics
    ///
    /// Panics if a quantized value cannot be converted to `E`.
    fn dequantize(&self, values: &[Q]) -> Vec<E> {
        let to_float = |quantized: &Q| self.scale * E::from(*quantized).unwrap();

        values.iter().map(to_float).collect()
    }
}
// Bit masks for the three fields that `raw_double_bits` packs into a `u64`.
// Sign field: bit 63.
const SIGN_MASK: u64 = 0x8000000000000000u64;
// Exponent field: bits 52..=62.
const EXP_MASK: u64 = 0x7ff0000000000000u64;
// Mantissa field: bits 0..=51.
const MAN_MASK: u64 = 0x000fffffffffffffu64;
/// Packs the components returned by `integer_decode` into a single `u64`.
///
/// The mantissa occupies bits 0..=51, the exponent bits 52..=62 and the sign bit 63; each
/// component is masked to its field. The sign bit is set when the decoded sign is positive.
/// Within this file the result is only used as input to hashing.
#[inline]
fn raw_double_bits<F: Float>(f: &F) -> u64 {
    let (mantissa, exponent, sign) = f.integer_decode();

    let mantissa_field = mantissa & MAN_MASK;
    // Reinterpret the signed exponent as unsigned before widening, so no sign extension occurs.
    let exponent_field = ((exponent as u16 as u64) << 52) & EXP_MASK;
    let sign_field = (u64::from(sign > 0) << 63) & SIGN_MASK;

    mantissa_field | exponent_field | sign_field
}
/// Returns `x` with a negative zero replaced by positive zero, so that both zeros share one
/// representation.
#[inline(always)]
fn canonicalize_signed_zero<T: Float>(x: T) -> T {
    // Under the default IEEE 754 rounding mode `-0.0 + 0.0` is `+0.0`, while adding zero
    // leaves any other finite or infinite value as it was.
    let positive_zero = T::zero();
    x + positive_zero
}
impl<E: Float, Q: PrimInt + Hash, A: PrimInt> Hash for AffineQuantization<E, Q, A> {
    /// Hashes the bit encoding of the scale followed by the offset.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // `-0.0 == 0.0` under `PartialEq`, so both zeros must feed the same bits to the hasher.
        let canonical_scale = canonicalize_signed_zero(self.scale);
        raw_double_bits(&canonical_scale).hash(state);
        self.offset.hash(state);
    }
}
impl<E: Float, Q: PrimInt, A: PrimInt> PartialEq for AffineQuantization<E, Q, A> {
    /// Two affine schemes are equal when both their scales and their offsets are equal.
    ///
    /// The scale uses floating-point equality, so a NaN scale never compares equal.
    fn eq(&self, other: &Self) -> bool {
        if self.scale != other.scale {
            return false;
        }
        self.offset == other.offset
    }
}
// Marker impl: equality is the field-wise comparison defined by `PartialEq`.
// NOTE(review): a NaN `scale` does not compare equal to itself, which `Eq` formally requires —
// confirm that NaN scales cannot occur.
impl<E: Float, Q: PrimInt, A: PrimInt> Eq for AffineQuantization<E, Q, A> {}
impl<E: Float, Q: PrimInt> Hash for SymmetricQuantization<E, Q> {
    /// Hashes the bit encoding of the scale.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // `-0.0 == 0.0` under `PartialEq`, so both zeros must feed the same bits to the hasher.
        let canonical_scale = canonicalize_signed_zero(self.scale);
        raw_double_bits(&canonical_scale).hash(state);
    }
}
impl<E: Float, Q: PrimInt> PartialEq for SymmetricQuantization<E, Q> {
    /// Two symmetric schemes are equal when their scales are equal.
    ///
    /// The scale uses floating-point equality, so a NaN scale never compares equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.scale, other.scale);
        lhs == rhs
    }
}
// Marker impl: equality is the scale comparison defined by `PartialEq`.
// NOTE(review): a NaN `scale` does not compare equal to itself, which `Eq` formally requires —
// confirm that NaN scales cannot occur.
impl<E: Float, Q: PrimInt> Eq for SymmetricQuantization<E, Q> {}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Affine int8 round trip over a range that already contains zero.
    #[test]
    fn test_int8_affine_quantization() {
        let input: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let quantized_ref = vec![-128, -40, 71, 126];
        let dequantized_ref = vec![-1.794902, -1.0011765, 0.0, 0.49607843];

        let scheme = AffineQuantization::<f32, i8, i32>::new(-1.8, 0.5);

        let quantized = scheme.quantize(&input);
        assert_eq!(quantized, quantized_ref);

        let dequantized = scheme.dequantize(&quantized_ref);
        assert_eq!(dequantized, dequantized_ref);
    }

    /// The input range `[1, 5]` excludes zero; the scheme must still place the zero-point
    /// inside the quantized range.
    #[test]
    fn test_affine_should_ensure_zero_point() {
        let input: [f32; 6] = [2.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized_ref = vec![-26, -77, -26, 25, 76, 127];
        let dequantized_ref = input.to_vec();

        let scheme = AffineQuantization::<f32, i8, i32>::new(1.0, 5.0);
        assert_eq!(scheme.offset, -128);
        assert_eq!(scheme.scale, 0.019607844);

        let quantized = scheme.quantize(&input);
        assert_eq!(quantized, quantized_ref);

        let dequantized = scheme.dequantize(&quantized_ref);
        assert_eq!(dequantized, dequantized_ref);
    }

    /// Symmetric int8 round trip.
    #[test]
    fn test_int8_symmetric_quantization() {
        let input: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];
        let quantized_ref = vec![-127, -71, 0, 35];
        let dequantized_ref = vec![-1.8, -1.0062993, 0.0, 0.496063];

        let scheme = SymmetricQuantization::<f32, i8>::new(-1.8, 0.5);

        let quantized: Vec<i8> = scheme.quantize(&input);
        assert_eq!(quantized, quantized_ref);

        let dequantized = scheme.dequantize(&quantized_ref);
        assert_eq!(dequantized, dequantized_ref);
    }
}