1pub use cubecl_common::quant::scheme::{
5 BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
/// Alignment, in bytes, passed to `extend_from_byte_slice_aligned` when the quantization
/// scales are appended to the quantized values. It equals the alignment of `f32`, the
/// type the scales are stored as.
pub const QPARAM_ALIGN: usize = core::mem::align_of::<f32>();
14
15use alloc::vec::Vec;
16use core::any::TypeId;
17use num_traits::PrimInt;
18use serde::{Deserialize, Serialize};
19
20use crate::{DType, Shape, bytes::Bytes};
21
/// Floating-point precision selector with three options.
///
/// NOTE(review): judging by the name this is presumably the accumulator precision used by
/// quantized operations — confirm against the call sites, which are not visible here.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantAcc {
    /// 32-bit floating point. This is the default variant.
    #[default]
    F32,
    /// 16-bit floating point.
    F16,
    /// bfloat16.
    BF16,
}
35
/// Two-state propagation switch.
///
/// NOTE(review): presumably controls whether the result of an operation on quantized
/// inputs stays quantized (`Propagate`) or not (`Inhibit`) — confirm against the call
/// sites, which are not visible here.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// Propagation enabled.
    Propagate,
    /// Propagation disabled. This is the default variant.
    #[default]
    Inhibit,
}
48
/// Quantization parameters, generic over the container `S` that holds them.
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// The quantization scale(s).
    pub scales: S,
}
55
/// Description of a quantization parameter tensor located inside a larger buffer.
///
/// NOTE(review): the unit of the offsets (bytes vs. elements) and whether `offset_end`
/// is inclusive cannot be determined from this file — confirm against the users of this
/// type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QParamTensor {
    /// Start offset of the parameter data.
    pub offset_start: usize,
    /// End offset of the parameter data.
    pub offset_end: usize,
    /// Shape of the parameter tensor.
    pub shape: Shape,
    /// Strides of the parameter tensor, one entry per dimension (presumably — confirm).
    pub strides: Vec<usize>,
    /// Element data type of the parameter tensor.
    pub dtype: DType,
}
70
71pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
73 match level {
74 QuantLevel::Tensor => Shape::new([1]),
75 QuantLevel::Block(block_size) => {
76 let mut params_shape = data_shape.clone();
77 let block_size = block_size.to_dim_vec(data_shape.num_dims());
78
79 for (shape, block_size) in params_shape.dims.iter_mut().zip(block_size) {
80 *shape = (*shape).div_ceil(block_size as usize);
81 }
82
83 params_shape
84 }
85 }
86}
87
/// Quantized data stored as raw bytes together with the scheme needed to interpret it.
///
/// As built by `QuantizedBytes::new`, the buffer holds the quantized values followed by
/// the quantization scales (`f32`).
pub struct QuantizedBytes {
    /// The quantized values followed by the quantization scales.
    pub bytes: Bytes,
    /// The quantization scheme describing how `bytes` is laid out.
    pub scheme: QuantScheme,
    /// The number of quantized elements (not counting the scales).
    pub num_elements: usize,
}
105
106impl QuantizedBytes {
107 pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
109 value: Vec<E>,
110 scheme: QuantScheme,
111 scales: &[f32],
112 ) -> Self {
113 let num_elements = value.len();
114 if TypeId::of::<E>() != TypeId::of::<i8>() {
116 panic!("Invalid quantized type");
117 }
118
119 let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
121 let mut bytes = Bytes::from_elems(i8s);
122
123 match scheme.level {
124 QuantLevel::Tensor => {
125 let scale_bytes = bytemuck::bytes_of(&scales[0]);
126 bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
127 }
128 QuantLevel::Block(_block_size) => {
129 let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
130 for scale in scales {
131 scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
132 }
133 bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
134 }
135 }
136
137 Self {
138 bytes,
139 scheme,
140 num_elements,
141 }
142 }
143
144 pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
146 let (values, (qparams, num_params)) = self.split_values_off();
147
148 let scale_size = core::mem::size_of::<f32>(); let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
154 let total_bytes = qparams_bytes.len();
155
156 let scales_size = scale_size * num_params;
157
158 let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
159
160 (values, QParams { scales })
161 }
162
163 fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
164 let mut values = read_bytes_to_i8(self.bytes);
165
166 let scale_size = num_params * size_of::<f32>();
167 let values_end = values.len() - scale_size;
168
169 let qparams = values.split_off(values_end);
170
171 let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
172 let mut qparams = core::mem::ManuallyDrop::new(qparams);
173 unsafe {
174 Vec::<u32>::from_raw_parts(
175 qparams.as_mut_ptr() as _,
176 qparams.len() / 4,
177 qparams.capacity() / 4,
178 )
179 }
180 } else {
181 #[cfg(target_endian = "little")]
182 {
183 bytemuck::cast_vec(qparams)
185 }
186 #[cfg(target_endian = "big")]
187 {
188 crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
189 }
190 };
191 (values, qparams)
192 }
193
194 fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
198 let num_params = match self.scheme.level {
199 QuantLevel::Tensor => 1,
200 QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
201 };
202
203 if let QuantStore::PackedU32(packed_dim) = self.scheme.store {
204 assert_eq!(
205 packed_dim, 0,
206 "Packing must be on innermost dimension for splitting off values"
207 );
208 }
209
210 let (values, qparams) = match self.scheme.store {
211 QuantStore::Native => self.split_i8_values(num_params),
212 QuantStore::PackedU32(_) => match self.scheme.value {
213 QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
214 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
215 let mut values = self.bytes.try_into_vec::<u32>().unwrap();
216 let scale_size = num_params; let values_end = values.len() - scale_size;
218
219 let qparams = values.split_off(values_end);
220 let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
222 (values, qparams)
223 }
224 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
225 unimplemented!("Not yet supported")
226 }
227 },
228 QuantStore::PackedNative(_) => unimplemented!("Not yet supported"),
229 };
230
231 (values, (qparams, num_params))
232 }
233}
234
235fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
236 match bytes.try_into_vec::<i8>() {
237 Ok(val) => val,
238 Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
242 }
243}
244
/// Packs signed 8-bit integers into 32-bit unsigned words.
///
/// Every group of four consecutive values forms one word, with the first value in the
/// least significant byte. A trailing group with fewer than four values is padded with
/// zeros. The result is the same on little- and big-endian targets.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // The words are built by value. Re-interpreting the `Vec<i8>` allocation as a
    // `Vec<u32>` is not sound: the buffer was allocated with the layout of `i8`, so it
    // has no 4-byte alignment guarantee and would be freed with a mismatched layout.
    values
        .chunks(4)
        .map(|chunk| {
            let mut word = [0u8; 4];
            for (dst, src) in word.iter_mut().zip(chunk) {
                *dst = *src as u8;
            }
            u32::from_le_bytes(word)
        })
        .collect()
}
283
284pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
286 values: &[Q],
287 numel: usize,
288 value: &QuantValue,
289) -> Vec<i8> {
290 let size_store = size_of::<Q>() * 8;
291 let size_quant = value.size_bits();
292 let num_quants = size_store / size_quant;
293 let mask = Q::from((1 << size_quant) - 1).unwrap();
294 let sign_shift = 8 - size_quant; values
296 .iter()
297 .enumerate()
298 .flat_map(|(i, &packed)| {
299 let n = core::cmp::min(num_quants, numel - i * num_quants);
301 (0..n).map(move |i| {
308 let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
309 ((raw << sign_shift) as i8) >> sign_shift
310 })
311 })
312 .collect()
313}
314
// Unit tests for the packing/unpacking helpers and the `QuantizedBytes` round trip.
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    // Four values fill exactly one word; the first value ends up in the lowest byte
    // (2147287680 == 0x7FFD0280).
    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(packed, vec![2147287680]);
    }

    // A trailing partial group must give the same result as explicit zero padding.
    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    // Inverse of `should_pack_i8s_to_u32` for 8-bit values.
    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    // Only `numel` values are returned; the padding in the word is dropped.
    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![55]);
    }

    // 4-bit values: 16 words unpack to 128 values.
    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );

        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }

    // Round trip through `QuantizedBytes`: the values and the single per-tensor scale
    // must come back unchanged.
    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantScheme::default()
                .with_value(QuantValue::Q8S)
                .with_store(QuantStore::Native),
            &[scale],
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);

        assert_eq!(q_values, values);
    }
}