1pub use cubecl_common::quant::scheme::{
5 BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
/// Alignment used when appending the quantization parameters (the `f32` scales) after the
/// quantized values in a packed byte buffer; equal to the alignment of `f32`.
pub const QPARAM_ALIGN: usize = align_of::<f32>();
14
15use alloc::vec::Vec;
16use core::any::TypeId;
17use num_traits::PrimInt;
18use serde::{Deserialize, Serialize};
19
20use crate::{DType, Shape, bytes::Bytes};
21
22#[derive(
23 Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
24)]
25pub enum QuantAcc {
27 #[default]
29 F32,
30 F16,
32 BF16,
34}
35
36#[derive(
39 Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
40)]
41pub enum QuantPropagation {
42 Propagate,
44 #[default]
46 Inhibit,
47}
48
/// Quantization parameters, generic over the container `S` holding the scales
/// (e.g. `Vec<f32>`).
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// Scale factor(s).
    pub scales: S,
}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct QParamTensor {
59 pub offset_start: usize,
61 pub offset_end: usize,
63 pub shape: Shape,
65 pub strides: Vec<usize>,
67 pub dtype: DType,
69}
70
71pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
73 match level {
74 QuantLevel::Tensor => Shape::new([1]),
75 QuantLevel::Block(block_size) => {
76 let mut params_shape = data_shape.clone();
77 let block_size = block_size.to_dim_vec(data_shape.num_dims());
78
79 for (shape, block_size) in params_shape.dims.iter_mut().zip(block_size) {
80 *shape = (*shape).div_ceil(block_size as usize);
81 }
82
83 params_shape
84 }
85 }
86}
87
88pub struct QuantizedBytes {
98 pub bytes: Bytes,
100 pub scheme: QuantScheme,
102 pub num_elements: usize,
104}
105
106impl QuantizedBytes {
107 pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
109 value: Vec<E>,
110 scheme: QuantScheme,
111 scales: &[f32],
112 ) -> Self {
113 let num_elements = value.len();
114 if TypeId::of::<E>() != TypeId::of::<i8>() {
116 panic!("Invalid quantized type");
117 }
118
119 let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
121 let mut bytes = Bytes::from_elems(i8s);
122
123 match scheme.level {
124 QuantLevel::Tensor => {
125 let scale_bytes = bytemuck::bytes_of(&scales[0]);
126 bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
127 }
128 QuantLevel::Block(_block_size) => {
129 let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
130 for scale in scales {
131 scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
132 }
133 bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
134 }
135 }
136
137 Self {
138 bytes,
139 scheme,
140 num_elements,
141 }
142 }
143
144 pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
146 let (values, (qparams, num_params)) = self.split_values_off();
147
148 let scale_size = core::mem::size_of::<f32>(); let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
154 let total_bytes = qparams_bytes.len();
155
156 let scales_size = scale_size * num_params;
157
158 let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
159
160 (values, QParams { scales })
161 }
162
163 fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
164 let mut values = read_bytes_to_i8(self.bytes);
165
166 let scale_size = num_params * size_of::<f32>();
167 let values_end = values.len() - scale_size;
168
169 let qparams = values.split_off(values_end);
170
171 let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
172 let mut qparams = core::mem::ManuallyDrop::new(qparams);
173 unsafe {
174 Vec::<u32>::from_raw_parts(
175 qparams.as_mut_ptr() as _,
176 qparams.len() / 4,
177 qparams.capacity() / 4,
178 )
179 }
180 } else {
181 #[cfg(target_endian = "little")]
182 {
183 bytemuck::cast_vec(qparams)
185 }
186 #[cfg(target_endian = "big")]
187 {
188 crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
189 }
190 };
191 (values, qparams)
192 }
193
194 fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
198 let num_params = match self.scheme.level {
199 QuantLevel::Tensor => 1,
200 QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
201 };
202
203 let (values, qparams) = match self.scheme.store {
204 QuantStore::Native => self.split_i8_values(num_params),
205 QuantStore::U32 => match self.scheme.value {
206 QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
207 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
208 let mut values = self.bytes.try_into_vec::<u32>().unwrap();
209 let scale_size = num_params; let values_end = values.len() - scale_size;
211
212 let qparams = values.split_off(values_end);
213 let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
215 (values, qparams)
216 }
217 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
218 unimplemented!("Not yet supported")
219 }
220 },
221 };
222
223 (values, (qparams, num_params))
224 }
225}
226
227fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
228 match bytes.try_into_vec::<i8>() {
229 Ok(val) => val,
230 Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
234 }
235}
236
/// Packs `i8` values four per word into `u32`s in little-endian order.
///
/// Value `4 * k + j` occupies bits `8 * j .. 8 * j + 8` of word `k`. A trailing partial group is
/// zero-padded, so the result has `values.len().div_ceil(4)` words on every target.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // Build each word arithmetically rather than reinterpreting the `Vec<i8>` allocation as a
    // `Vec<u32>`: that allocation only guarantees alignment 1, and its capacity need not be a
    // multiple of 4, so the reinterpretation would be unsound.
    values
        .chunks(4)
        .map(|group| {
            let mut word = [0u8; 4];
            for (byte, &value) in word.iter_mut().zip(group) {
                *byte = value as u8;
            }
            u32::from_le_bytes(word)
        })
        .collect()
}
275
276pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
278 values: &[Q],
279 numel: usize,
280 value: &QuantValue,
281) -> Vec<i8> {
282 let size_store = size_of::<Q>() * 8;
283 let size_quant = value.size_bits();
284 let num_quants = size_store / size_quant;
285 let mask = Q::from((1 << size_quant) - 1).unwrap();
286 let sign_shift = 8 - size_quant; values
288 .iter()
289 .enumerate()
290 .flat_map(|(i, &packed)| {
291 let n = core::cmp::min(num_quants, numel - i * num_quants);
293 (0..n).map(move |i| {
300 let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
301 ((raw << sign_shift) as i8) >> sign_shift
302 })
303 })
304 .collect()
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_i8s_to_u32() {
        let expected: Vec<u32> = vec![2147287680];

        assert_eq!(pack_i8s_to_u32s(vec![-128, 2, -3, 127]), expected);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let short = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let zero_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);
        let expected: Vec<u32> = vec![2147287680, 55];

        assert_eq!(short, expected);
        assert_eq!(short, zero_padded);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let expected: Vec<i8> = vec![-128, 2, -3, 127];

        assert_eq!(
            unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S),
            expected
        );
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let expected: Vec<i8> = vec![55];

        assert_eq!(unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S), expected);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let packed: [u32; 16] = [
            0, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
            1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
            2004318071,
        ];
        let expected: Vec<i8> = vec![
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
            3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
            5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
            6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        ];

        assert_eq!(unpack_q_to_i8s(&packed, 128, &QuantValue::Q4S), expected);
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        // Symmetric int8 scale (roughly 5 / 127) applied to the whole tensor.
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let (q_values, qparams) =
            QuantizedBytes::new(values.clone(), scheme, &[scale]).into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);
        assert_eq!(q_values, values);
    }
}