1use alloc::vec;
2use alloc::vec::Vec;
3use core::{default::Default, ops::Deref};
4use serde::{Deserialize, Serialize};
5
/// Describes how values are quantized and stored.
///
/// A scheme bundles five independent choices: the quantized value type, the
/// type of the quantization parameters, the storage container, the
/// granularity level and the quantization mode.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct QuantScheme {
    /// Data type of the quantized values; it determines their bit width and range.
    pub value: QuantValue,
    /// Type of the quantization parameters.
    /// NOTE(review): presumably the precision of the scales — confirm with the consumers.
    pub param: QuantParam,
    /// Container the quantized values are stored in.
    pub store: QuantStore,
    /// Granularity of the quantization (whole tensor or blocks).
    pub level: QuantLevel,
    /// Quantization mode.
    pub mode: QuantMode,
}
22
23impl Default for QuantScheme {
24    fn default() -> Self {
25        Self {
26            value: QuantValue::Q8F,
27            param: QuantParam::F32,
28            store: QuantStore::U32,
29            level: QuantLevel::Tensor,
30            mode: QuantMode::Symmetric,
31        }
32    }
33}
34
35impl QuantScheme {
36    pub fn with_level(mut self, level: QuantLevel) -> Self {
38        self.level = level;
39        self
40    }
41
42    pub fn with_mode(mut self, mode: QuantMode) -> Self {
44        self.mode = mode;
45        self
46    }
47
48    pub fn with_value(mut self, value: QuantValue) -> Self {
50        self.value = value;
51        self
52    }
53
54    pub fn with_store(mut self, store: QuantStore) -> Self {
56        self.store = store;
57        self
58    }
59
60    pub fn with_param(mut self, param: QuantParam) -> Self {
62        self.param = param;
63        self
64    }
65
66    pub fn size_bits_stored(&self) -> usize {
68        self.store.size_bits(&self.value).max(8)
70    }
71
72    pub fn size_bits_value(&self) -> usize {
74        self.value.size_bits()
75    }
76
77    pub fn num_quants(&self) -> usize {
79        self.size_bits_stored() / self.value.size_bits()
80    }
81
82    pub fn native_packing(&self) -> usize {
85        self.value.native_packing()
86    }
87}
88
/// Granularity at which quantization is applied.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantLevel {
    /// Tensor-level quantization.
    /// NOTE(review): presumably one set of parameters for the whole tensor — confirm.
    Tensor,
    /// Block-level quantization with the given block size.
    Block(BlockSize),
}
97
98impl QuantLevel {
99    pub fn block(values: impl AsRef<[u8]>) -> Self {
101        QuantLevel::Block(BlockSize::new(values))
102    }
103}
104
/// Data type of the quantized values.
///
/// Bit widths and value ranges quoted below are the ones reported by
/// `size_bits` and `range`. NOTE(review): the `E*` names look like
/// floating-point formats (exponent/mantissa bit counts) — confirm.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantValue {
    /// 8-bit value with range `[-128, 127]`.
    Q8F,
    /// 8-bit value with range `[-57344, 57344]`.
    E5M2,
    /// 8-bit value with range `[-448, 448]`.
    E4M3,
    /// 4-bit value with range `[-8, 7]`.
    Q4F,
    /// 4-bit value with range `[-6, 6]`; the only variant with a native packing of 2.
    E2M1,
    /// 2-bit value with range `[-2, 1]`.
    Q2F,
    /// 8-bit value with symmetric range `[-127, 127]`.
    Q8S,
    /// 4-bit value with symmetric range `[-7, 7]`.
    Q4S,
    /// 2-bit value with symmetric range `[-1, 1]`.
    Q2S,
}
127
128impl QuantValue {
129    pub fn size_bits(&self) -> usize {
131        match self {
132            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => 8,
133            QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 => 4,
134            QuantValue::Q2F | QuantValue::Q2S => 2,
135        }
136    }
137
138    pub fn native_packing(&self) -> usize {
141        match self {
142            QuantValue::E2M1 => 2,
143            _ => 1,
144        }
145    }
146
147    pub fn range(&self) -> (f32, f32) {
149        match self {
150            QuantValue::Q8F => (i8::MIN as f32, i8::MAX as f32),
151            QuantValue::Q4F => (-8.0, 7.0),
152            QuantValue::Q2F => (-2.0, 1.0),
153            QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32),
154            QuantValue::Q4S => (-7.0, 7.0),
155            QuantValue::Q2S => (-1.0, 1.0),
156            QuantValue::E4M3 => (-448.0, 448.0),
157            QuantValue::E5M2 => (-57344.0, 57344.0),
158            QuantValue::E2M1 => (-6.0, 6.0), }
160    }
161
162    pub fn is_symmetric(&self) -> bool {
164        match self {
165            Self::Q8F | Self::Q4F | Self::Q2F | Self::E4M3 | Self::E5M2 | Self::E2M1 => false,
166            Self::Q8S | Self::Q4S | Self::Q2S => true,
167        }
168    }
169}
170
171impl QuantStore {
172    pub fn size_bits(&self, value: &QuantValue) -> usize {
174        match self {
175            QuantStore::Native => value.size_bits(),
176            QuantStore::U32 => 32,
177        }
178    }
179}
180
/// Container type used to store quantized values.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantStore {
    /// Storage element as wide as the quantized value itself.
    Native,
    /// 32-bit storage element.
    U32,
}
191
/// Quantization mode. Only one mode is currently defined.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantMode {
    /// Symmetric quantization.
    /// NOTE(review): presumably scale-only, with no zero-point offset — confirm.
    Symmetric,
}
198
/// Type of the quantization parameters.
///
/// NOTE(review): the variant notes below are read off the variant names only;
/// confirm the exact encodings against the code that consumes this enum.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantParam {
    /// Presumably 32-bit floating point.
    F32,
    /// Presumably 16-bit floating point.
    F16,
    /// Presumably 16-bit brain floating point.
    BF16,
    /// Presumably an unsigned 8-bit format with 8 exponent bits and no mantissa.
    UE8M0,
    /// Presumably an unsigned format with 4 exponent bits and 3 mantissa bits.
    UE4M3,
}
216
/// Capacity of the inline storage of a `BlockSize`, i.e. the maximum number of
/// dimensions it can hold.
const MAX_DIMS: usize = 5;
218
/// Block dimensions stored inline in a fixed-capacity array.
///
/// At most `MAX_DIMS` dimensions are kept, so the type stays `Copy` and needs
/// no heap allocation.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BlockSize {
    /// Dimension values; slots past `len` are filled with 1 by `new`.
    storage: [u8; MAX_DIMS],
    /// Number of leading `storage` slots that are in use.
    len: u8,
}
225
226impl BlockSize {
227    pub const MAX_DIMS: usize = MAX_DIMS;
229
230    pub fn new(values: impl AsRef<[u8]>) -> Self {
232        let values = values.as_ref();
233        debug_assert!(
234            values.len() <= MAX_DIMS,
235            "Tried creating a block size larger than the cap"
236        );
237        let len = values.len().min(MAX_DIMS);
238        let mut storage = [1; MAX_DIMS];
239        storage[..len].copy_from_slice(&values[..len]);
240        Self {
241            storage,
242            len: len as u8,
243        }
244    }
245
246    pub fn as_slice(&self) -> &[u8] {
248        &self.storage[..self.len as usize]
249    }
250
251    pub fn to_vec(&self) -> Vec<u8> {
253        self.storage[..self.len as usize].to_vec()
254    }
255
256    pub fn as_dim<const N: usize>(&self) -> [u8; N] {
258        let data_len = N.min(self.len as usize);
259        let data_start = N - data_len;
260        let mut out = [1; N];
261        out[data_start..].copy_from_slice(&self.storage[..data_len]);
262        out
263    }
264
265    pub fn to_dim_vec(&self, len: usize) -> Vec<u8> {
267        let data_len = len.min(self.len as usize);
268        let data_start = len - data_len;
269        let mut out = vec![1; len];
270        out[data_start..].copy_from_slice(&self.storage[..data_len]);
271        out
272    }
273
274    pub fn iter(&self) -> impl Iterator<Item = &u8> {
276        self.as_slice().iter()
277    }
278
279    pub fn num_elements(&self) -> usize {
281        self.iter().map(|it| *it as usize).product()
282    }
283}
284
285impl Deref for BlockSize {
286    type Target = [u8];
287
288    fn deref(&self) -> &Self::Target {
289        self.as_slice()
290    }
291}
292
293impl<T: AsRef<[u8]>> From<T> for BlockSize {
294    fn from(value: T) -> Self {
295        BlockSize::new(value)
296    }
297}