use core::any::TypeId;
use crate::{Bytes, Element};
use alloc::vec::Vec;
use super::{
pack_i8s_to_u32s, unpack_u32s_to_i8s, AffineQuantization, QParams, Quantization,
QuantizationScheme, QuantizationStrategy, QuantizationType, SymmetricQuantization,
};
pub struct QuantizedBytes {
pub bytes: Bytes,
pub scheme: QuantizationScheme,
pub num_elements: usize,
}
impl QuantizedBytes {
pub fn new<E: Element>(value: Vec<E>, strategy: QuantizationStrategy) -> Self {
let mut bytes: Bytes;
let num_elements = value.len();
match strategy {
QuantizationStrategy::PerTensorAffineInt8(q) => {
if TypeId::of::<E>() == TypeId::of::<i8>() {
let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
bytes = Bytes::from_elems(u32s);
} else {
panic!("Invalid quantized type");
}
let offset = q.offset as i32;
let scale_bytes = bytemuck::bytes_of(&q.scale);
let offset_bytes = bytemuck::bytes_of(&offset);
bytes.extend_from_byte_slice_aligned(offset_bytes, align_of::<i32>());
bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
}
QuantizationStrategy::PerTensorSymmetricInt8(q) => {
if TypeId::of::<E>() == TypeId::of::<i8>() {
let u32s = pack_i8s_to_u32s(bytemuck::allocation::cast_vec(value));
bytes = Bytes::from_elems(u32s);
} else {
panic!("Invalid quantized type");
}
let scale_bytes = bytemuck::bytes_of(&q.scale);
bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
}
}
Self {
bytes,
scheme: strategy.scheme(),
num_elements,
}
}
pub fn into_vec_i8(self) -> (Vec<i8>, QParams<f32, i8>) {
let numel = self.num_elements;
let scheme = self.scheme;
let (values, qparams) = self.split_values_off();
let values = unpack_u32s_to_i8s(values, numel);
let scale_size = core::mem::size_of::<f32>(); let qparams_bytes = bytemuck::cast_slice(&qparams);
let total_bytes = qparams_bytes.len();
let scale = *bytemuck::checked::from_bytes(&qparams_bytes[total_bytes - scale_size..]);
let offset = match scheme {
QuantizationScheme::PerTensorAffine(_) => {
let offset_size = core::mem::size_of::<i32>(); Some(*bytemuck::checked::from_bytes::<i32>(
&qparams_bytes
[total_bytes - scale_size - offset_size..total_bytes - scale_size],
) as i8)
}
QuantizationScheme::PerTensorSymmetric(_) => None,
};
(values, QParams { scale, offset })
}
fn split_values_off(self) -> (Vec<u32>, Vec<u32>) {
let mut values = match self.bytes.align() {
1 => {
let bytes = self.bytes.try_into_vec::<u8>().unwrap();
#[cfg(target_endian = "little")]
{
unsafe { reinterpret_vec(bytes) }
}
#[cfg(target_endian = "big")]
{
pack_i8s_to_u32s(bytemuck::allocation::cast_vec(bytes))
}
}
4 => self.bytes.try_into_vec::<u32>().unwrap(),
_ => unreachable!(),
};
let scale_size = 1; let mut values_end = values.len() - scale_size;
if let QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) = self.scheme {
values_end -= 1; }
let qparams = values.split_off(values_end);
(values, qparams)
}
pub fn dequantize(self) -> (Vec<f32>, QParams<f32, i8>) {
match self.scheme {
QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
let (values, qparams) = self.into_vec_i8();
let strategy = AffineQuantization::<f32, i8, i32>::init(
qparams.scale,
qparams.offset.unwrap(),
);
(strategy.dequantize(&values), qparams)
}
QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
let (values, qparams) = self.into_vec_i8();
let strategy = SymmetricQuantization::<f32, i8>::init(qparams.scale);
(strategy.dequantize(&values), qparams)
}
}
}
}
unsafe fn reinterpret_vec<T, U>(mut input: Vec<T>) -> Vec<U> {
assert!(
input.as_mut_ptr().align_offset(align_of::<U>()) == 0,
"Alignment mismatch"
);
assert!(
size_of::<T>() != 0 && size_of::<U>() != 0,
"Zero-sized types not allowed"
);
assert!(
input.len() * size_of::<T>() % size_of::<U>() == 0,
"Size mismatch"
);
let len = input.len() * size_of::<T>() / size_of::<U>();
let cap = input.capacity() * size_of::<T>() / size_of::<U>();
let ptr = input.as_mut_ptr() as *mut U;
core::mem::forget(input);
Vec::from_raw_parts(ptr, len, cap)
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
#[test]
fn should_pack_unpack_quantization_parameters_symmetric() {
let scale = 0.03937008;
let values = vec![0i8, 25, 51, 76, 102, 127];
let q_bytes = QuantizedBytes::new(
values.clone(),
QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale)),
);
let (q_values, qparams) = q_bytes.into_vec_i8();
assert_eq!(qparams.scale, scale);
assert_eq!(qparams.offset, None);
assert_eq!(q_values, values);
}
#[test]
fn should_pack_unpack_quantization_parameters_affine() {
let scale = 0.019607844;
let offset = -128;
let values = vec![-128i8, -77, -26, 25, 76, 127];
let q_bytes = QuantizedBytes::new(
values.clone(),
QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(scale, offset)),
);
let (q_values, qparams) = q_bytes.into_vec_i8();
assert_eq!(qparams.scale, scale);
assert_eq!(qparams.offset, Some(offset));
assert_eq!(q_values, values);
}
}