// cubecl_common/quant/scheme.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::{default::Default, ops::Deref};
4use serde::{Deserialize, Serialize};
5
6/// Describes a quantization scheme/configuration.
/// Describes a quantization scheme/configuration.
///
/// Build one from [`QuantScheme::default`] and customize it with the `with_*` builder methods.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct QuantScheme {
    /// The logical data type of quantized input values (e.g., QInt8).
    ///
    /// This defines how values are interpreted during computation, independent of how they're stored.
    pub value: QuantValue,
    /// Precision used for quantization parameters (e.g., scale and biases).
    pub param: QuantParam,
    /// Data type used for storing quantized values.
    pub store: QuantStore,
    /// Granularity level of quantization (e.g., per-tensor).
    pub level: QuantLevel,
    /// Quantization mode (e.g., symmetric).
    pub mode: QuantMode,
}
22
impl Default for QuantScheme {
    fn default() -> Self {
        Self {
            // 8-bit values using the full signed range.
            value: QuantValue::Q8F,
            // Quantization parameters (scales) kept in full precision.
            param: QuantParam::F32,
            // Values packed into 32-bit words along dim 0 (counted from the innermost dimension).
            store: QuantStore::PackedU32(0),
            // A single set of parameters for the whole tensor.
            level: QuantLevel::Tensor,
            mode: QuantMode::Symmetric,
        }
    }
}
34
35impl QuantScheme {
36    /// Set the quantization level.
37    pub fn with_level(mut self, level: QuantLevel) -> Self {
38        self.level = level;
39        self
40    }
41
42    /// Set the quantization mode.
43    pub fn with_mode(mut self, mode: QuantMode) -> Self {
44        self.mode = mode;
45        self
46    }
47
48    /// Set the data type used for quantized values.
49    pub fn with_value(mut self, value: QuantValue) -> Self {
50        self.value = value;
51        self
52    }
53
54    /// Set the data type used to store quantized values.
55    pub fn with_store(mut self, store: QuantStore) -> Self {
56        self.store = store;
57        self
58    }
59
60    /// Set the precision used for quantization parameters
61    pub fn with_param(mut self, param: QuantParam) -> Self {
62        self.param = param;
63        self
64    }
65
66    /// Returns the size of the quantization storage type in bits.
67    pub fn size_bits_stored(&self) -> usize {
68        self.store.size_bits(&self.value)
69    }
70
71    /// Returns the size of the quantization storage type in bits.
72    pub fn size_bits_value(&self) -> usize {
73        self.value.size_bits()
74    }
75
76    /// Returns the number of quantized values stored in a single element.
77    pub fn num_quants(&self) -> usize {
78        self.size_bits_stored() / self.value.size_bits()
79    }
80
81    /// Returns the native packing factor for the values. When native packing > 1, the packed
82    /// representation stores `num_quants` elements grouped into packs of `native_packing` size.
83    pub fn native_packing(&self) -> usize {
84        self.value.native_packing()
85    }
86
87    /// Returns the packing dim for the store.
88    pub fn packing_dim(&self) -> Option<usize> {
89        self.store.packing_dim()
90    }
91
92    /// Swaps the packing dim if it's either of `dim0` or `dim1`.
93    /// Executes the corresponding update to `shape.swap(dim0, dim1)`.
94    pub fn swap_packing_dim(&mut self, dim0: usize, dim1: usize) {
95        if let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
96            &mut self.store
97        {
98            if *packed_dim == dim0 {
99                *packed_dim = dim1;
100            } else if *packed_dim == dim1 {
101                *packed_dim = dim0;
102            }
103        }
104    }
105}
106
/// Level or granularity of quantization.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantLevel {
    /// Quantize the whole tensor using a single set of quantization parameters.
    Tensor,
    /// Quantize a tensor using multiple blocks of the given block size.
    Block(BlockSize),
}
115
116impl QuantLevel {
117    /// Converting constructor for [`QuantLevel::Block`]
118    pub fn block(values: impl AsRef<[u8]>) -> Self {
119        QuantLevel::Block(BlockSize::new(values))
120    }
121}
122
/// Data type used to represent quantized values.
///
/// "Full range" (`Q*F`) variants use the whole two's-complement range (e.g. [-128, 127] for 8
/// bits), while "symmetric" (`Q*S`) variants drop the most negative value so the range is
/// symmetric around zero (see [`QuantValue::range`]).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantValue {
    /// 8-bit quantization with full range.
    Q8F,
    /// 8-bit floating point, e5m2 format.
    E5M2,
    /// 8-bit floating point, e4m3 format.
    E4M3,
    /// 4-bit quantization with full range.
    Q4F,
    /// 4-bit floating point, e2m1 format.
    E2M1,
    /// 2-bit quantization with full range.
    Q2F,
    /// 8-bit quantization with symmetric range.
    Q8S,
    /// 4-bit quantization with symmetric range.
    Q4S,
    /// 2-bit quantization with symmetric range.
    Q2S,
}
145
146impl QuantValue {
147    /// Returns the size of the quantization input type in bits.
148    pub fn size_bits(&self) -> usize {
149        match self {
150            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => 8,
151            QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 => 4,
152            QuantValue::Q2F | QuantValue::Q2S => 2,
153        }
154    }
155
156    /// Packing factor for the native representation used for intermediate values. If > 1, values
157    /// should always be processed in `native_packing` sized chunks.
158    pub fn native_packing(&self) -> usize {
159        match self {
160            QuantValue::E2M1 => 2,
161            _ => 1,
162        }
163    }
164
165    /// The possible range of values allowed by the quant value.
166    pub fn range(&self) -> (f32, f32) {
167        match self {
168            QuantValue::Q8F => (i8::MIN as f32, i8::MAX as f32),
169            QuantValue::Q4F => (-8.0, 7.0),
170            QuantValue::Q2F => (-2.0, 1.0),
171            QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32),
172            QuantValue::Q4S => (-7.0, 7.0),
173            QuantValue::Q2S => (-1.0, 1.0),
174            QuantValue::E4M3 => (-448.0, 448.0),
175            QuantValue::E5M2 => (-57344.0, 57344.0),
176            QuantValue::E2M1 => (-6.0, 6.0), // Hardcoded because of no-std
177        }
178    }
179
180    /// If the range of values is symmetric around zero.
181    pub fn is_symmetric(&self) -> bool {
182        match self {
183            Self::Q8F | Self::Q4F | Self::Q2F | Self::E4M3 | Self::E5M2 | Self::E2M1 => false,
184            Self::Q8S | Self::Q4S | Self::Q2S => true,
185        }
186    }
187}
188
189impl QuantStore {
190    /// Returns the size of the quantization input type in bits.
191    pub fn size_bits(&self, value: &QuantValue) -> usize {
192        match self {
193            QuantStore::Native => value.size_bits(),
194            QuantStore::PackedNative(_) => value.size_bits() * value.native_packing(),
195            QuantStore::PackedU32(_) => 32,
196        }
197    }
198
199    fn packing_dim(&self) -> Option<usize> {
200        match self {
201            QuantStore::Native => None,
202            QuantStore::PackedNative(packing_dim) | QuantStore::PackedU32(packing_dim) => {
203                Some(*packing_dim)
204            }
205        }
206    }
207}
208
/// Data type used to store quantized values.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantStore {
    /// Native quantization doesn't require packing and unpacking.
    Native,
    /// Store packed quantized values in a natively supported packing format (i.e. e2m1x2).
    /// Argument is the dimension the tensor is packed on, starting from the innermost dimension.
    PackedNative(usize),
    /// Store packed quantized values in a 4-byte unsigned integer.
    /// Argument is the dimension the tensor is packed on, starting from the innermost dimension.
    PackedU32(usize),
    // /// Store packed quantized values in a 8-bit unsigned integer.
    // U8,
}
223
/// Strategy used to quantize values.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantMode {
    /// Symmetric or scale quantization: values are mapped by a scale only, with no zero-point offset.
    Symmetric,
}
230
/// Quantization floating-point precision.
///
/// This is used to represent the floating-point precision of quantization parameters like the scale(s)
/// or the accumulation precision used during operations like matrix multiplication.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantParam {
    /// Full precision.
    F32,
    /// Half precision.
    F16,
    /// bfloat16 precision.
    BF16,
    /// Unsigned floating point, e8m0 format (exponent only, no mantissa).
    UE8M0,
    /// Unsigned floating point, e4m3 format.
    UE4M3,
}
248
/// Maximum number of dimensions a [`BlockSize`] can hold inline.
const MAX_DIMS: usize = 5;

/// Copyable block size, specialized version of `SmallVec`.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BlockSize {
    // Inline storage; unused trailing slots are filled with 1 (identity dimension).
    storage: [u8; MAX_DIMS],
    // Number of initialized entries in `storage`; always <= MAX_DIMS.
    len: u8,
}
257
258impl core::fmt::Debug for BlockSize {
259    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
260        write!(f, "BlockSize({:?})", self.as_slice())
261    }
262}
263
264impl BlockSize {
265    /// Max number of dimensions for block size
266    pub const MAX_DIMS: usize = MAX_DIMS;
267
268    /// Create a new blocksize from a set of values. The number of values must be `<= MAX_DIMS`.
269    pub fn new(values: impl AsRef<[u8]>) -> Self {
270        let values = values.as_ref();
271        debug_assert!(
272            values.len() <= MAX_DIMS,
273            "Tried creating a block size larger than the cap"
274        );
275        let len = values.len().min(MAX_DIMS);
276        let mut storage = [1; MAX_DIMS];
277        storage[..len].copy_from_slice(&values[..len]);
278        Self {
279            storage,
280            len: len as u8,
281        }
282    }
283
284    /// Create a new blocksize from a set of values. The number of values must be `<= MAX_DIMS`.
285    /// Trims any leading zeros.
286    pub fn new_trim(values: impl AsRef<[u8]>) -> Self {
287        let values = values.as_ref();
288        let first_value = values.iter().position(|s| *s != 1).unwrap_or(0);
289        Self::new(&values[first_value..])
290    }
291
292    /// Return a slice of only the initialized valeus
293    pub fn as_slice(&self) -> &[u8] {
294        &self.storage[..self.len as usize]
295    }
296
297    /// Return a vec of only the initialized values
298    pub fn to_vec(&self) -> Vec<u8> {
299        self.storage[..self.len as usize].to_vec()
300    }
301
302    /// Returns `N` dimensions, unsqueezing if necessary.
303    pub fn as_dim<const N: usize>(&self) -> [u8; N] {
304        let data_len = N.min(self.len as usize);
305        let data_start = N - data_len;
306        let mut out = [1; N];
307        out[data_start..].copy_from_slice(&self.storage[..data_len]);
308        out
309    }
310
311    /// Returns a vector of `len` dimensions, unsqueezing if necessary.
312    pub fn to_dim_vec(&self, len: usize) -> Vec<u8> {
313        let data_len = len.min(self.len as usize);
314        let data_start = len - data_len;
315        let mut out = vec![1; len];
316        out[data_start..].copy_from_slice(&self.storage[..data_len]);
317        out
318    }
319
320    /// Create an iterator over all stored dimensions
321    pub fn iter(&self) -> impl Iterator<Item = &u8> {
322        self.as_slice().iter()
323    }
324
325    /// Returns the total number of elements in each block
326    pub fn num_elements(&self) -> usize {
327        self.iter().map(|it| *it as usize).product()
328    }
329}
330
impl Deref for BlockSize {
    type Target = [u8];

    /// A `BlockSize` derefs to the slice of its initialized dimensions (see [`BlockSize::as_slice`]).
    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}
338
339impl<T: AsRef<[u8]>> From<T> for BlockSize {
340    fn from(value: T) -> Self {
341        BlockSize::new(value)
342    }
343}