1use alloc::vec;
2use alloc::vec::Vec;
3use core::{default::Default, ops::Deref};
4use serde::{Deserialize, Serialize};
5
6#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
8pub struct QuantScheme {
9 pub value: QuantValue,
13 pub param: QuantParam,
15 pub store: QuantStore,
17 pub level: QuantLevel,
19 pub mode: QuantMode,
21}
22
23impl Default for QuantScheme {
24 fn default() -> Self {
25 Self {
26 value: QuantValue::Q8F,
27 param: QuantParam::F32,
28 store: QuantStore::PackedU32(0),
29 level: QuantLevel::Tensor,
30 mode: QuantMode::Symmetric,
31 }
32 }
33}
34
35impl QuantScheme {
36 pub fn with_level(mut self, level: QuantLevel) -> Self {
38 self.level = level;
39 self
40 }
41
42 pub fn with_mode(mut self, mode: QuantMode) -> Self {
44 self.mode = mode;
45 self
46 }
47
48 pub fn with_value(mut self, value: QuantValue) -> Self {
50 self.value = value;
51 self
52 }
53
54 pub fn with_store(mut self, store: QuantStore) -> Self {
56 self.store = store;
57 self
58 }
59
60 pub fn with_param(mut self, param: QuantParam) -> Self {
62 self.param = param;
63 self
64 }
65
66 pub fn size_bits_stored(&self) -> usize {
68 self.store.size_bits(&self.value)
69 }
70
71 pub fn size_bits_value(&self) -> usize {
73 self.value.size_bits()
74 }
75
76 pub fn num_quants(&self) -> usize {
78 self.size_bits_stored() / self.value.size_bits()
79 }
80
81 pub fn native_packing(&self) -> usize {
84 self.value.native_packing()
85 }
86
87 pub fn packing_dim(&self) -> Option<usize> {
89 self.store.packing_dim()
90 }
91
92 pub fn swap_packing_dim(&mut self, dim0: usize, dim1: usize) {
95 if let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
96 &mut self.store
97 {
98 if *packed_dim == dim0 {
99 *packed_dim = dim1;
100 } else if *packed_dim == dim1 {
101 *packed_dim = dim0;
102 }
103 }
104 }
105}
106
107#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
109pub enum QuantLevel {
110 Tensor,
112 Block(BlockSize),
114}
115
116impl QuantLevel {
117 pub fn block(values: impl AsRef<[u8]>) -> Self {
119 QuantLevel::Block(BlockSize::new(values))
120 }
121}
122
123#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
125pub enum QuantValue {
126 Q8F,
128 E5M2,
130 E4M3,
132 Q4F,
134 E2M1,
136 Q2F,
138 Q8S,
140 Q4S,
142 Q2S,
144}
145
146impl QuantValue {
147 pub fn size_bits(&self) -> usize {
149 match self {
150 QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => 8,
151 QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 => 4,
152 QuantValue::Q2F | QuantValue::Q2S => 2,
153 }
154 }
155
156 pub fn native_packing(&self) -> usize {
159 match self {
160 QuantValue::E2M1 => 2,
161 _ => 1,
162 }
163 }
164
165 pub fn range(&self) -> (f32, f32) {
167 match self {
168 QuantValue::Q8F => (i8::MIN as f32, i8::MAX as f32),
169 QuantValue::Q4F => (-8.0, 7.0),
170 QuantValue::Q2F => (-2.0, 1.0),
171 QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32),
172 QuantValue::Q4S => (-7.0, 7.0),
173 QuantValue::Q2S => (-1.0, 1.0),
174 QuantValue::E4M3 => (-448.0, 448.0),
175 QuantValue::E5M2 => (-57344.0, 57344.0),
176 QuantValue::E2M1 => (-6.0, 6.0), }
178 }
179
180 pub fn is_symmetric(&self) -> bool {
182 match self {
183 Self::Q8F | Self::Q4F | Self::Q2F | Self::E4M3 | Self::E5M2 | Self::E2M1 => false,
184 Self::Q8S | Self::Q4S | Self::Q2S => true,
185 }
186 }
187}
188
189impl QuantStore {
190 pub fn size_bits(&self, value: &QuantValue) -> usize {
192 match self {
193 QuantStore::Native => value.size_bits(),
194 QuantStore::PackedNative(_) => value.size_bits() * value.native_packing(),
195 QuantStore::PackedU32(_) => 32,
196 }
197 }
198
199 fn packing_dim(&self) -> Option<usize> {
200 match self {
201 QuantStore::Native => None,
202 QuantStore::PackedNative(packing_dim) | QuantStore::PackedU32(packing_dim) => {
203 Some(*packing_dim)
204 }
205 }
206 }
207}
208
209#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
211pub enum QuantStore {
212 Native,
214 PackedNative(usize),
217 PackedU32(usize),
220 }
223
224#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
226pub enum QuantMode {
227 Symmetric,
229}
230
231#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
236pub enum QuantParam {
237 F32,
239 F16,
241 BF16,
243 UE8M0,
245 UE4M3,
247}
248
249const MAX_DIMS: usize = 5;
250
251#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
253pub struct BlockSize {
254 storage: [u8; MAX_DIMS],
255 len: u8,
256}
257
258impl core::fmt::Debug for BlockSize {
259 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
260 write!(f, "BlockSize({:?})", self.as_slice())
261 }
262}
263
264impl BlockSize {
265 pub const MAX_DIMS: usize = MAX_DIMS;
267
268 pub fn new(values: impl AsRef<[u8]>) -> Self {
270 let values = values.as_ref();
271 debug_assert!(
272 values.len() <= MAX_DIMS,
273 "Tried creating a block size larger than the cap"
274 );
275 let len = values.len().min(MAX_DIMS);
276 let mut storage = [1; MAX_DIMS];
277 storage[..len].copy_from_slice(&values[..len]);
278 Self {
279 storage,
280 len: len as u8,
281 }
282 }
283
284 pub fn new_trim(values: impl AsRef<[u8]>) -> Self {
287 let values = values.as_ref();
288 let first_value = values.iter().position(|s| *s != 1).unwrap_or(0);
289 Self::new(&values[first_value..])
290 }
291
292 pub fn as_slice(&self) -> &[u8] {
294 &self.storage[..self.len as usize]
295 }
296
297 pub fn to_vec(&self) -> Vec<u8> {
299 self.storage[..self.len as usize].to_vec()
300 }
301
302 pub fn as_dim<const N: usize>(&self) -> [u8; N] {
304 let data_len = N.min(self.len as usize);
305 let data_start = N - data_len;
306 let mut out = [1; N];
307 out[data_start..].copy_from_slice(&self.storage[..data_len]);
308 out
309 }
310
311 pub fn to_dim_vec(&self, len: usize) -> Vec<u8> {
313 let data_len = len.min(self.len as usize);
314 let data_start = len - data_len;
315 let mut out = vec![1; len];
316 out[data_start..].copy_from_slice(&self.storage[..data_len]);
317 out
318 }
319
320 pub fn iter(&self) -> impl Iterator<Item = &u8> {
322 self.as_slice().iter()
323 }
324
325 pub fn num_elements(&self) -> usize {
327 self.iter().map(|it| *it as usize).product()
328 }
329}
330
331impl Deref for BlockSize {
332 type Target = [u8];
333
334 fn deref(&self) -> &Self::Target {
335 self.as_slice()
336 }
337}
338
339impl<T: AsRef<[u8]>> From<T> for BlockSize {
340 fn from(value: T) -> Self {
341 BlockSize::new(value)
342 }
343}