use alloc::vec;
use alloc::vec::Vec;
use core::{default::Default, ops::Deref};
use serde::{Deserialize, Serialize};
/// Describes how a tensor is quantized: the quantized value type, the type of the
/// quantization parameters, how values are stored, the granularity, and the mode.
///
/// Field declaration order matters: the derived `PartialOrd`/`Ord` compare fields in this
/// order, and the derived serde impls (de)serialize them in this order.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct QuantScheme {
/// Type of a single quantized value (bit width and representable range; see `QuantValue`).
pub value: QuantValue,
/// Type of the quantization parameters — presumably the scale(s); TODO confirm.
pub param: QuantParam,
/// How quantized values are stored (see `QuantStore::size_bits`).
pub store: QuantStore,
/// Granularity: tensor-wide or per-block (see `QuantLevel`).
pub level: QuantLevel,
/// Quantization mode.
pub mode: QuantMode,
}
impl Default for QuantScheme {
    /// Returns the baseline scheme: `Q8F` values, `F32` parameters, `U32` storage,
    /// tensor-level granularity and symmetric mode.
    fn default() -> Self {
        let value = QuantValue::Q8F;
        let param = QuantParam::F32;
        let store = QuantStore::U32;
        let level = QuantLevel::Tensor;
        let mode = QuantMode::Symmetric;
        Self {
            mode,
            level,
            store,
            param,
            value,
        }
    }
}
impl QuantScheme {
pub fn with_level(mut self, level: QuantLevel) -> Self {
self.level = level;
self
}
pub fn with_mode(mut self, mode: QuantMode) -> Self {
self.mode = mode;
self
}
pub fn with_value(mut self, value: QuantValue) -> Self {
self.value = value;
self
}
pub fn with_store(mut self, store: QuantStore) -> Self {
self.store = store;
self
}
pub fn with_param(mut self, param: QuantParam) -> Self {
self.param = param;
self
}
pub fn size_bits_stored(&self) -> usize {
self.store.size_bits(&self.value).max(8)
}
pub fn size_bits_value(&self) -> usize {
self.value.size_bits()
}
pub fn num_quants(&self) -> usize {
self.size_bits_stored() / self.value.size_bits()
}
pub fn native_packing(&self) -> usize {
self.value.native_packing()
}
}
/// Granularity at which quantization is applied.
///
/// Variant order matters: it drives the derived `PartialOrd`/`Ord` and the variant indices
/// the derived serde impls pass to serializers.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantLevel {
/// Whole-tensor granularity — presumably one set of parameters shared by every value; TODO confirm.
Tensor,
/// Block-wise granularity with the given block dimensions (see `BlockSize`).
Block(BlockSize),
}
impl QuantLevel {
    /// Builds a block-level `QuantLevel` from the given block dimensions.
    ///
    /// The dimensions are forwarded unchanged to `BlockSize::new`.
    pub fn block(values: impl AsRef<[u8]>) -> Self {
        let size = BlockSize::new(values);
        Self::Block(size)
    }
}
/// Data type of a single quantized value.
///
/// See `QuantValue::size_bits` for each variant's bit width and `QuantValue::range` for its
/// bounds. Variant order matters: it drives the derived `PartialOrd`/`Ord` and the variant
/// indices the derived serde impls pass to serializers.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantValue {
/// 8 bits, range [-128, 127] (the full signed 8-bit range).
Q8F,
/// 8 bits, range ±57344 — presumably the FP8 E5M2 float format.
E5M2,
/// 8 bits, range ±448 — presumably the FP8 E4M3 float format.
E4M3,
/// 4 bits, range [-8, 7] (the full signed 4-bit range).
Q4F,
/// 4 bits, range ±6 — presumably the FP4 E2M1 float format; natively packed two at a time.
E2M1,
/// 2 bits, range [-2, 1] (the full signed 2-bit range).
Q2F,
/// 8 bits, symmetric range [-127, 127].
Q8S,
/// 4 bits, symmetric range [-7, 7].
Q4S,
/// 2 bits, symmetric range [-1, 1].
Q2S,
}
impl QuantValue {
pub fn size_bits(&self) -> usize {
match self {
QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => 8,
QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 => 4,
QuantValue::Q2F | QuantValue::Q2S => 2,
}
}
pub fn native_packing(&self) -> usize {
match self {
QuantValue::E2M1 => 2,
_ => 1,
}
}
pub fn range(&self) -> (f32, f32) {
match self {
QuantValue::Q8F => (i8::MIN as f32, i8::MAX as f32),
QuantValue::Q4F => (-8.0, 7.0),
QuantValue::Q2F => (-2.0, 1.0),
QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32),
QuantValue::Q4S => (-7.0, 7.0),
QuantValue::Q2S => (-1.0, 1.0),
QuantValue::E4M3 => (-448.0, 448.0),
QuantValue::E5M2 => (-57344.0, 57344.0),
QuantValue::E2M1 => (-6.0, 6.0), }
}
pub fn is_symmetric(&self) -> bool {
match self {
Self::Q8F | Self::Q4F | Self::Q2F | Self::E4M3 | Self::E5M2 | Self::E2M1 => false,
Self::Q8S | Self::Q4S | Self::Q2S => true,
}
}
}
impl QuantStore {
pub fn size_bits(&self, value: &QuantValue) -> usize {
match self {
QuantStore::Native => value.size_bits(),
QuantStore::U32 => 32,
}
}
}
/// How quantized values are laid out in storage.
///
/// Variant order matters: it drives the derived `PartialOrd`/`Ord` and the variant indices
/// the derived serde impls pass to serializers.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantStore {
/// A storage unit is as wide as the value itself (`QuantValue::size_bits`).
Native,
/// A storage unit is a 32-bit word; `QuantScheme::num_quants` reports how many values fit.
U32,
}
/// Quantization mode.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantMode {
/// Symmetric quantization — the only mode defined here and the `QuantScheme` default.
/// NOTE(review): presumably maps values around zero without a zero-point; confirm.
Symmetric,
}
/// Data type of the quantization parameters.
///
/// NOTE(review): presumably the type the scale(s) are stored in; confirm against the
/// (de)quantization code. Variant order drives the derived `PartialOrd`/`Ord` and the
/// serde variant indices.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantParam {
/// 32-bit float; the `QuantScheme` default.
F32,
/// Presumably 16-bit IEEE half precision.
F16,
/// Presumably bfloat16.
BF16,
/// Presumably an unsigned 8-bit exponent-only format (E8M0); TODO confirm.
UE8M0,
/// Presumably an unsigned 8-bit E4M3 float; TODO confirm.
UE4M3,
}
/// Maximum number of dimensions a `BlockSize` can hold.
const MAX_DIMS: usize = 5;
/// Block dimensions for block-level quantization, stored inline in a fixed-size array.
///
/// Only the first `len` entries of `storage` are meaningful. `BlockSize::new` fills the
/// remaining slots with 1, so the derived `Eq`/`Hash`/`Ord` (which see the whole array)
/// agree for equal dimension lists.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BlockSize {
/// Dimension sizes; entries at index `len..` are unused padding.
storage: [u8; MAX_DIMS],
/// Number of dimensions in use (at most `MAX_DIMS`).
len: u8,
}
impl BlockSize {
    /// Maximum number of dimensions a block size can describe.
    pub const MAX_DIMS: usize = MAX_DIMS;

    /// Creates a block size from the given per-dimension sizes.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if more than [`Self::MAX_DIMS`] values are given. In release
    /// builds the extra values are silently dropped and only the first `MAX_DIMS` are kept.
    pub fn new(values: impl AsRef<[u8]>) -> Self {
        let values = values.as_ref();
        debug_assert!(
            values.len() <= MAX_DIMS,
            "Tried creating a block size larger than the cap"
        );
        let len = values.len().min(MAX_DIMS);
        // Unused slots get a fixed value so the derived Eq/Hash/Ord, which look at the
        // whole array, only depend on the dimensions actually in use.
        let mut storage = [1; MAX_DIMS];
        storage[..len].copy_from_slice(&values[..len]);
        Self {
            storage,
            // Cannot truncate: `len <= MAX_DIMS`, which is far below `u8::MAX`.
            len: len as u8,
        }
    }

    /// Returns the dimensions in use as a slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.storage[..self.len as usize]
    }

    /// Returns the dimensions in use as an owned vector.
    pub fn to_vec(&self) -> Vec<u8> {
        self.as_slice().to_vec()
    }

    /// Returns the dimensions right-aligned in an array of exactly `N` entries.
    ///
    /// If `N` is larger than the number of dimensions, the missing leading entries are `1`.
    /// If `N` is smaller, the leading dimensions are dropped and the trailing ones are kept,
    /// which is consistent with the right alignment used for padding.
    pub fn as_dim<const N: usize>(&self) -> [u8; N] {
        let mut out = [1; N];
        self.write_right_aligned(&mut out);
        out
    }

    /// Like [`Self::as_dim`], but with a runtime length `len` and an owned vector.
    pub fn to_dim_vec(&self, len: usize) -> Vec<u8> {
        let mut out = vec![1; len];
        self.write_right_aligned(&mut out);
        out
    }

    /// Copies the trailing `min(out.len(), self.len)` dimensions into the tail of `out`,
    /// leaving the leading entries of `out` untouched.
    fn write_right_aligned(&self, out: &mut [u8]) {
        let dims = self.as_slice();
        let count = out.len().min(dims.len());
        let out_start = out.len() - count;
        out[out_start..].copy_from_slice(&dims[dims.len() - count..]);
    }

    /// Iterates over the dimensions in use.
    pub fn iter(&self) -> impl Iterator<Item = &u8> {
        self.as_slice().iter()
    }

    /// Number of elements in one block: the product of the dimensions in use
    /// (1 when there are none).
    pub fn num_elements(&self) -> usize {
        self.iter().map(|&dim| dim as usize).product()
    }
}
impl Deref for BlockSize {
    type Target = [u8];

    /// Exposes only the dimensions in use (the first `len` entries of the storage).
    fn deref(&self) -> &[u8] {
        let used = self.len as usize;
        &self.storage[..used]
    }
}
impl<T: AsRef<[u8]>> From<T> for BlockSize {
fn from(value: T) -> Self {
BlockSize::new(value)
}
}