cubecl_common/quant/scheme.rs

use alloc::vec;
use alloc::vec::Vec;
use core::{default::Default, ops::Deref};
use serde::{Deserialize, Serialize};

/// Describes a quantization scheme/configuration.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct QuantScheme {
    /// The logical data type of quantized input values (e.g., Q8F).
    ///
    /// This defines how values are interpreted during computation, independent of how they're stored.
    pub value: QuantValue,
    /// Precision used for quantization parameters (e.g., scale and biases).
    pub param: QuantParam,
    /// Data type used for storing quantized values.
    pub store: QuantStore,
    /// Granularity level of quantization (e.g., per-tensor).
    pub level: QuantLevel,
    /// Quantization mode (e.g., symmetric).
    pub mode: QuantMode,
}

impl Default for QuantScheme {
    fn default() -> Self {
        Self {
            value: QuantValue::Q8F,
            param: QuantParam::F32,
            store: QuantStore::U32,
            level: QuantLevel::Tensor,
            mode: QuantMode::Symmetric,
        }
    }
}

impl QuantScheme {
    /// Set the quantization level.
    pub fn with_level(mut self, level: QuantLevel) -> Self {
        self.level = level;
        self
    }

    /// Set the quantization mode.
    pub fn with_mode(mut self, mode: QuantMode) -> Self {
        self.mode = mode;
        self
    }

    /// Set the data type used for quantized values.
    pub fn with_value(mut self, value: QuantValue) -> Self {
        self.value = value;
        self
    }

    /// Set the data type used to store quantized values.
    pub fn with_store(mut self, store: QuantStore) -> Self {
        self.store = store;
        self
    }

    /// Set the precision used for quantization parameters.
    pub fn with_param(mut self, param: QuantParam) -> Self {
        self.param = param;
        self
    }

    /// Returns the size of the quantization storage type in bits.
    pub fn size_bits_stored(&self) -> usize {
        // Assume native packing if the store type is < 8 bits
        self.store.size_bits(&self.value).max(8)
    }

    /// Returns the size of the quantized value type in bits.
    pub fn size_bits_value(&self) -> usize {
        self.value.size_bits()
    }

    /// Returns the number of quantized values stored in a single element.
    pub fn num_quants(&self) -> usize {
        self.size_bits_stored() / self.value.size_bits()
    }

    /// Returns the native packing factor for the values. When native packing > 1, the packed
    /// representation stores `num_quants` elements grouped into packs of `native_packing` size.
    pub fn native_packing(&self) -> usize {
        self.value.native_packing()
    }
}

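// Minimal usage sketch of the builder methods and packing helpers above. This test module is an
// illustrative addition (its name and the chosen configuration are not part of the original API);
// the expected numbers follow directly from the definitions in this file.
#[cfg(test)]
mod quant_scheme_sketch {
    use super::*;

    #[test]
    fn packing_follows_value_and_store_widths() {
        // 4-bit symmetric values packed into a `u32` store: 32 / 4 = 8 values per stored element.
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q4S)
            .with_param(QuantParam::F16)
            .with_level(QuantLevel::block([32u8]));
        assert_eq!(scheme.size_bits_value(), 4);
        assert_eq!(scheme.size_bits_stored(), 32);
        assert_eq!(scheme.num_quants(), 8);

        // With a native store, sub-byte values are still assumed to occupy at least a byte,
        // so two 4-bit values share one stored element.
        let native = scheme.with_store(QuantStore::Native);
        assert_eq!(native.size_bits_stored(), 8);
        assert_eq!(native.num_quants(), 2);
    }
}
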
/// Level or granularity of quantization.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantLevel {
    /// Quantize the whole tensor using a single set of quantization parameters.
    Tensor,
    /// Quantize a tensor using multiple blocks.
    Block(BlockSize),
}

impl QuantLevel {
    /// Converting constructor for [`QuantLevel::Block`].
    pub fn block(values: impl AsRef<[u8]>) -> Self {
        QuantLevel::Block(BlockSize::new(values))
    }
}

/// Data type used to represent quantized values.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantValue {
    /// 8-bit quantization with full range.
    Q8F,
    /// 8-bit floating point, e5m2 format.
    E5M2,
    /// 8-bit floating point, e4m3 format.
    E4M3,
    /// 4-bit quantization with full range.
    Q4F,
    /// 4-bit floating point, e2m1 format.
    E2M1,
    /// 2-bit quantization with full range.
    Q2F,
    /// 8-bit quantization with symmetric range.
    Q8S,
    /// 4-bit quantization with symmetric range.
    Q4S,
    /// 2-bit quantization with symmetric range.
    Q2S,
}

impl QuantValue {
    /// Returns the size of the quantized value type in bits.
    pub fn size_bits(&self) -> usize {
        match self {
            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => 8,
            QuantValue::Q4F | QuantValue::Q4S | QuantValue::E2M1 => 4,
            QuantValue::Q2F | QuantValue::Q2S => 2,
        }
    }

    /// Packing factor for the native representation used for intermediate values. If > 1, values
    /// should always be processed in `native_packing`-sized chunks.
    pub fn native_packing(&self) -> usize {
        match self {
            QuantValue::E2M1 => 2,
            _ => 1,
        }
    }

    /// The range of values representable by the quant value.
    pub fn range(&self) -> (f32, f32) {
        match self {
            QuantValue::Q8F => (i8::MIN as f32, i8::MAX as f32),
            QuantValue::Q4F => (-8.0, 7.0),
            QuantValue::Q2F => (-2.0, 1.0),
            QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32),
            QuantValue::Q4S => (-7.0, 7.0),
            QuantValue::Q2S => (-1.0, 1.0),
            QuantValue::E4M3 => (-448.0, 448.0),
            QuantValue::E5M2 => (-57344.0, 57344.0),
            QuantValue::E2M1 => (-6.0, 6.0), // Hardcoded because of no-std
        }
    }

    /// Whether the range of values is symmetric around zero.
    pub fn is_symmetric(&self) -> bool {
        match self {
            Self::Q8F | Self::Q4F | Self::Q2F | Self::E4M3 | Self::E5M2 | Self::E2M1 => false,
            Self::Q8S | Self::Q4S | Self::Q2S => true,
        }
    }
}

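// Illustrative sketch of the per-value helpers above (an added test module, not part of the
// original file): ranges, symmetry, and the native packing factor of the e2m1 format.
#[cfg(test)]
mod quant_value_sketch {
    use super::*;

    #[test]
    fn ranges_and_packing() {
        // Full-range int8 spans the whole i8 range; the symmetric variant drops -128.
        assert_eq!(QuantValue::Q8F.range(), (-128.0, 127.0));
        assert_eq!(QuantValue::Q8S.range(), (-127.0, 127.0));
        assert!(QuantValue::Q8S.is_symmetric());
        assert!(!QuantValue::Q8F.is_symmetric());

        // e2m1 values are 4 bits wide and are handled in pairs natively.
        assert_eq!(QuantValue::E2M1.size_bits(), 4);
        assert_eq!(QuantValue::E2M1.native_packing(), 2);
    }
}
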
impl QuantStore {
    /// Returns the size of the storage type in bits for the given quant value.
    pub fn size_bits(&self, value: &QuantValue) -> usize {
        match self {
            QuantStore::Native => value.size_bits(),
            QuantStore::U32 => 32,
        }
    }
}

/// Data type used to store quantized values.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantStore {
    /// Native quantization doesn't require packing and unpacking.
    Native,
    /// Store packed quantized values in a 4-byte unsigned integer.
    U32,
    // /// Store packed quantized values in an 8-bit unsigned integer.
    // U8,
}

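// Brief sketch of how the store width relates to the value width (an added test module for
// illustration): a native store is exactly as wide as the value, while `U32` always provides
// 32 bits to pack into.
#[cfg(test)]
mod quant_store_sketch {
    use super::*;

    #[test]
    fn store_widths() {
        assert_eq!(QuantStore::Native.size_bits(&QuantValue::Q2F), 2);
        assert_eq!(QuantStore::U32.size_bits(&QuantValue::Q2F), 32);
    }
}
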
/// Strategy used to quantize values.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantMode {
    /// Symmetric or scale quantization.
    Symmetric,
}

/// Quantization floating-point precision.
///
/// This is used to represent the floating-point precision of quantization parameters like the scale(s)
/// or the accumulation precision used during operations like matrix multiplication.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantParam {
    /// Full precision.
    F32,
    /// Half precision.
    F16,
    /// bfloat16 precision.
    BF16,
    /// Unsigned floating point, e8m0 format.
    UE8M0,
    /// Unsigned floating point, e4m3 format.
    UE4M3,
}

const MAX_DIMS: usize = 5;

/// Copyable block size, a specialized version of `SmallVec`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BlockSize {
    storage: [u8; MAX_DIMS],
    len: u8,
}

impl BlockSize {
    /// Maximum number of dimensions for a block size.
    pub const MAX_DIMS: usize = MAX_DIMS;

    /// Create a new block size from a set of values. The number of values must be `<= MAX_DIMS`.
    pub fn new(values: impl AsRef<[u8]>) -> Self {
        let values = values.as_ref();
        debug_assert!(
            values.len() <= MAX_DIMS,
            "Tried creating a block size larger than the cap"
        );
        let len = values.len().min(MAX_DIMS);
        let mut storage = [1; MAX_DIMS];
        storage[..len].copy_from_slice(&values[..len]);
        Self {
            storage,
            len: len as u8,
        }
    }

    /// Return a slice of only the initialized values.
    pub fn as_slice(&self) -> &[u8] {
        &self.storage[..self.len as usize]
    }

    /// Return a vec of only the initialized values.
    pub fn to_vec(&self) -> Vec<u8> {
        self.storage[..self.len as usize].to_vec()
    }

    /// Returns `N` dimensions, unsqueezing if necessary.
    pub fn as_dim<const N: usize>(&self) -> [u8; N] {
        let data_len = N.min(self.len as usize);
        let data_start = N - data_len;
        let mut out = [1; N];
        out[data_start..].copy_from_slice(&self.storage[..data_len]);
        out
    }

    /// Returns a vector of `len` dimensions, unsqueezing if necessary.
    pub fn to_dim_vec(&self, len: usize) -> Vec<u8> {
        let data_len = len.min(self.len as usize);
        let data_start = len - data_len;
        let mut out = vec![1; len];
        out[data_start..].copy_from_slice(&self.storage[..data_len]);
        out
    }

    /// Create an iterator over all stored dimensions.
    pub fn iter(&self) -> impl Iterator<Item = &u8> {
        self.as_slice().iter()
    }

    /// Returns the total number of elements in each block.
    pub fn num_elements(&self) -> usize {
        self.iter().map(|it| *it as usize).product()
    }
}

impl Deref for BlockSize {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<T: AsRef<[u8]>> From<T> for BlockSize {
    fn from(value: T) -> Self {
        BlockSize::new(value)
    }
}
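
// Illustrative sketch of `BlockSize` (an added test module): construction from arrays and the
// unsqueezing behaviour of `as_dim`/`to_dim_vec` for a 2D block.
#[cfg(test)]
mod block_size_sketch {
    use super::*;

    #[test]
    fn two_dimensional_block() {
        // `From<impl AsRef<[u8]>>` and `QuantLevel::block` both go through `BlockSize::new`.
        let block: BlockSize = [8u8, 8].into();
        assert_eq!(QuantLevel::block([8u8, 8]), QuantLevel::Block(block));

        // Only the initialized dimensions are exposed; missing leading dims unsqueeze to 1.
        assert_eq!(block.as_slice(), &[8u8, 8]);
        assert_eq!(block.as_dim::<4>(), [1u8, 1, 8, 8]);
        assert_eq!(block.to_dim_vec(3), [1u8, 8, 8]);
        assert_eq!(block.num_elements(), 64);
    }
}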