1use alloc::vec;
2use alloc::vec::Vec;
3use core::{default::Default, ops::Deref};
4use serde::{Deserialize, Serialize};
5
/// Bundle of the settings that together describe a quantization scheme: value format,
/// parameter type, storage container, level and mode.
///
/// All fields are public and the type is `Copy`, so a scheme can be built either with a
/// struct literal or by chaining the `with_*` methods on an existing value.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct QuantScheme {
    /// Format of the individual quantized values; its bit width and numeric range are
    /// reported by `QuantValue::size_bits` and `QuantValue::range`.
    pub value: QuantValue,
    /// Type used for the quantization parameters.
    // NOTE(review): presumably the precision of the scale factors — confirm with callers.
    pub param: QuantParam,
    /// Container the quantized values are stored in; see `QuantStore::size_bits`.
    pub store: QuantStore,
    /// Level the scheme applies at: `Tensor` or `Block(BlockSize)`.
    pub level: QuantLevel,
    /// Quantization mode.
    pub mode: QuantMode,
}
22
23impl Default for QuantScheme {
24 fn default() -> Self {
25 Self {
26 value: QuantValue::Q8F,
27 param: QuantParam::F32,
28 store: QuantStore::U32,
29 level: QuantLevel::Tensor,
30 mode: QuantMode::Symmetric,
31 }
32 }
33}
34
35impl QuantScheme {
36 pub fn with_level(mut self, level: QuantLevel) -> Self {
38 self.level = level;
39 self
40 }
41
42 pub fn with_mode(mut self, mode: QuantMode) -> Self {
44 self.mode = mode;
45 self
46 }
47
48 pub fn with_value(mut self, value: QuantValue) -> Self {
50 self.value = value;
51 self
52 }
53
54 pub fn with_store(mut self, store: QuantStore) -> Self {
56 self.store = store;
57 self
58 }
59
60 pub fn with_param(mut self, param: QuantParam) -> Self {
62 self.param = param;
63 self
64 }
65
66 pub fn size_bits_stored(&self) -> usize {
68 self.store.size_bits(&self.value).max(8)
70 }
71
72 pub fn size_bits_value(&self) -> usize {
74 self.value.size_bits()
75 }
76
77 pub fn num_quants(&self) -> usize {
79 self.size_bits_stored() / self.value.size_bits()
80 }
81
82 pub fn native_packing(&self) -> usize {
85 self.value.native_packing()
86 }
87}
88
/// Level at which a quantization scheme applies.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantLevel {
    /// Tensor level.
    // NOTE(review): presumably one set of parameters for the whole tensor — confirm.
    Tensor,
    /// Block level, carrying the size of a block along each dimension.
    Block(BlockSize),
}
97
98impl QuantLevel {
99 pub fn block(values: impl AsRef<[u8]>) -> Self {
101 QuantLevel::Block(BlockSize::new(values))
102 }
103}
104
/// Format of a single quantized value.
///
/// The bit widths and ranges quoted on the variants are the ones returned by
/// `QuantValue::size_bits` and `QuantValue::range`.
// NOTE(review): the `E*M*` names look like exponent/mantissa bit counts of small
// floating-point formats — confirm before relying on that reading.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantValue {
    /// 8 bits, range `[-128, 127]`.
    Q8F,
    /// 8 bits, range `[-57344, 57344]`.
    E5M2,
    /// 8 bits, range `[-448, 448]`.
    E4M3,
    /// 4 bits, range `[-8, 7]`.
    Q4F,
    /// 4 bits, range `[-6, 6]`; the only variant whose `native_packing` is 2.
    E2M1,
    /// 2 bits, range `[-2, 1]`.
    Q2F,
    /// 8 bits, range `[-127, 127]`.
    Q8S,
    /// 4 bits, range `[-7, 7]`.
    Q4S,
    /// 2 bits, range `[-1, 1]`.
    Q2S,
}
127
128impl QuantValue {
129 pub fn size_bits(&self) -> usize {
131 match self {
132 QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => 8,
133 QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 => 4,
134 QuantValue::Q2F | QuantValue::Q2S => 2,
135 }
136 }
137
138 pub fn native_packing(&self) -> usize {
141 match self {
142 QuantValue::E2M1 => 2,
143 _ => 1,
144 }
145 }
146
147 pub fn range(&self) -> (f32, f32) {
149 match self {
150 QuantValue::Q8F => (i8::MIN as f32, i8::MAX as f32),
151 QuantValue::Q4F => (-8.0, 7.0),
152 QuantValue::Q2F => (-2.0, 1.0),
153 QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32),
154 QuantValue::Q4S => (-7.0, 7.0),
155 QuantValue::Q2S => (-1.0, 1.0),
156 QuantValue::E4M3 => (-448.0, 448.0),
157 QuantValue::E5M2 => (-57344.0, 57344.0),
158 QuantValue::E2M1 => (-6.0, 6.0), }
160 }
161
162 pub fn is_symmetric(&self) -> bool {
164 match self {
165 Self::Q8F | Self::Q4F | Self::Q2F | Self::E4M3 | Self::E5M2 | Self::E2M1 => false,
166 Self::Q8S | Self::Q4S | Self::Q2S => true,
167 }
168 }
169}
170
171impl QuantStore {
172 pub fn size_bits(&self, value: &QuantValue) -> usize {
174 match self {
175 QuantStore::Native => value.size_bits(),
176 QuantStore::U32 => 32,
177 }
178 }
179}
180
/// Container in which quantized values are stored.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantStore {
    /// Stored at the value format's own bit width (`size_bits` returns `value.size_bits()`).
    Native,
    /// Stored in a 32-bit container (`size_bits` returns 32).
    U32,
}
191
/// Quantization mode of a scheme. `Symmetric` is the only variant declared here.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantMode {
    /// Symmetric mode; this is what `QuantScheme::default()` selects.
    // NOTE(review): presumably scale-only quantization without a zero-point — confirm.
    Symmetric,
}
198
/// Type used for the quantization parameters of a scheme (the `param` field of
/// `QuantScheme`).
// NOTE(review): the variant names read like floating-point formats (f32, f16, bf16 and
// two unsigned exponent/mantissa layouts) — nothing in this file pins that down, so
// confirm against the code that consumes them.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantParam {
    /// Selected by `QuantScheme::default()`.
    F32,
    /// `F16` parameter type.
    F16,
    /// `BF16` parameter type.
    BF16,
    /// `UE8M0` parameter type.
    UE8M0,
    /// `UE4M3` parameter type.
    UE4M3,
}
216
/// Capacity of the fixed-size array inside `BlockSize`, i.e. the largest number of
/// dimensions a block size can hold. Re-exported as `BlockSize::MAX_DIMS`.
const MAX_DIMS: usize = 5;
218
/// Inline, fixed-capacity list of per-dimension block sizes (at most `MAX_DIMS` entries
/// of `u8`), kept `Copy` by avoiding heap storage.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BlockSize {
    /// Backing array; the first `len` slots hold the dimensions. `BlockSize::new`
    /// leaves the remaining slots at 1.
    storage: [u8; MAX_DIMS],
    /// Number of leading slots of `storage` that are in use.
    len: u8,
}
225
226impl BlockSize {
227 pub const MAX_DIMS: usize = MAX_DIMS;
229
230 pub fn new(values: impl AsRef<[u8]>) -> Self {
232 let values = values.as_ref();
233 debug_assert!(
234 values.len() <= MAX_DIMS,
235 "Tried creating a block size larger than the cap"
236 );
237 let len = values.len().min(MAX_DIMS);
238 let mut storage = [1; MAX_DIMS];
239 storage[..len].copy_from_slice(&values[..len]);
240 Self {
241 storage,
242 len: len as u8,
243 }
244 }
245
246 pub fn as_slice(&self) -> &[u8] {
248 &self.storage[..self.len as usize]
249 }
250
251 pub fn to_vec(&self) -> Vec<u8> {
253 self.storage[..self.len as usize].to_vec()
254 }
255
256 pub fn as_dim<const N: usize>(&self) -> [u8; N] {
258 let data_len = N.min(self.len as usize);
259 let data_start = N - data_len;
260 let mut out = [1; N];
261 out[data_start..].copy_from_slice(&self.storage[..data_len]);
262 out
263 }
264
265 pub fn to_dim_vec(&self, len: usize) -> Vec<u8> {
267 let data_len = len.min(self.len as usize);
268 let data_start = len - data_len;
269 let mut out = vec![1; len];
270 out[data_start..].copy_from_slice(&self.storage[..data_len]);
271 out
272 }
273
274 pub fn iter(&self) -> impl Iterator<Item = &u8> {
276 self.as_slice().iter()
277 }
278
279 pub fn num_elements(&self) -> usize {
281 self.iter().map(|it| *it as usize).product()
282 }
283}
284
285impl Deref for BlockSize {
286 type Target = [u8];
287
288 fn deref(&self) -> &Self::Target {
289 self.as_slice()
290 }
291}
292
293impl<T: AsRef<[u8]>> From<T> for BlockSize {
294 fn from(value: T) -> Self {
295 BlockSize::new(value)
296 }
297}