1pub use cubecl_common::quant::scheme::{
5 BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
/// Alignment, in bytes, requested when quantization parameters are appended to the
/// serialized tensor bytes.
///
/// It equals the alignment of `f32`, the type the scales are passed as in this file.
pub const QPARAM_ALIGN: usize = align_of::<f32>();
14
15use alloc::vec::Vec;
16use core::any::TypeId;
17use num_traits::PrimInt;
18use serde::{Deserialize, Serialize};
19
20use crate::{DType, Metadata, Shape, bytes::Bytes};
21
22#[derive(
23 Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
24)]
25pub enum QuantAcc {
27 #[default]
29 F32,
30 F16,
32 BF16,
34}
35
36#[derive(
39 Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
40)]
41pub enum QuantPropagation {
42 Propagate,
44 #[default]
46 Inhibit,
47}
48
/// Quantization parameters, generic over the container `S` that holds the scales.
#[derive(Debug, Clone)]
pub struct QParams<S> {
    /// The scale value(s).
    pub scales: S,
}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct QParamTensor {
59 pub offset_start: usize,
61 pub offset_end: usize,
63 pub metadata: Metadata,
65 pub dtype: DType,
67}
68
69pub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {
71 match level {
72 QuantLevel::Tensor => Shape::new([1]),
73 QuantLevel::Block(block_size) => {
74 let mut params_shape = data_shape.clone();
75 let block_size = block_size.to_dim_vec(data_shape.num_dims());
76
77 for (shape, block_size) in params_shape.iter_mut().zip(block_size) {
78 *shape = (*shape).div_ceil(block_size as usize);
79 }
80
81 params_shape
82 }
83 }
84}
85
86pub struct QuantizedBytes {
96 pub bytes: Bytes,
98 pub scheme: QuantScheme,
100 pub num_elements: usize,
102}
103
104impl QuantizedBytes {
105 pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
107 value: Vec<E>,
108 scheme: QuantScheme,
109 scales: &[f32],
110 ) -> Self {
111 let num_elements = value.len();
112 if TypeId::of::<E>() != TypeId::of::<i8>() {
114 panic!("Invalid quantized type");
115 }
116
117 let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
119 let mut bytes = Bytes::from_elems(i8s);
120
121 match scheme.level {
122 QuantLevel::Tensor => {
123 let scale_bytes = bytemuck::bytes_of(&scales[0]);
124 bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);
125 }
126 QuantLevel::Block(_block_size) => {
127 let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
128 for scale in scales {
129 scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
130 }
131 bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);
132 }
133 }
134
135 Self {
136 bytes,
137 scheme,
138 num_elements,
139 }
140 }
141
142 pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
144 let (values, (qparams, num_params)) = self.split_values_off();
145
146 let scale_size = core::mem::size_of::<f32>(); let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
152 let total_bytes = qparams_bytes.len();
153
154 let scales_size = scale_size * num_params;
155
156 let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
157
158 (values, QParams { scales })
159 }
160
161 fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
162 let mut values = read_bytes_to_i8(self.bytes);
163
164 let scale_size = num_params * size_of::<f32>();
165 let values_end = values.len() - scale_size;
166
167 let qparams = values.split_off(values_end);
168
169 let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
170 let mut qparams = core::mem::ManuallyDrop::new(qparams);
171 unsafe {
172 Vec::<u32>::from_raw_parts(
173 qparams.as_mut_ptr() as _,
174 qparams.len() / 4,
175 qparams.capacity() / 4,
176 )
177 }
178 } else {
179 #[cfg(target_endian = "little")]
180 {
181 bytemuck::cast_vec(qparams)
183 }
184 #[cfg(target_endian = "big")]
185 {
186 crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
187 }
188 };
189 (values, qparams)
190 }
191
192 fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
196 let num_params = match self.scheme.level {
197 QuantLevel::Tensor => 1,
198 QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
199 };
200
201 if let QuantStore::PackedU32(packed_dim) = self.scheme.store {
202 assert_eq!(
203 packed_dim, 0,
204 "Packing must be on innermost dimension for splitting off values"
205 );
206 }
207
208 let (values, qparams) = match self.scheme.store {
209 QuantStore::Native => self.split_i8_values(num_params),
210 QuantStore::PackedU32(_) => match self.scheme.value {
211 QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
212 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
213 let mut values = self.bytes.try_into_vec::<u32>().unwrap();
214 let scale_size = num_params; let values_end = values.len() - scale_size;
216
217 let qparams = values.split_off(values_end);
218 let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
220 (values, qparams)
221 }
222 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
223 unimplemented!("Not yet supported")
224 }
225 },
226 QuantStore::PackedNative(_) => unimplemented!("Not yet supported"),
227 };
228
229 (values, (qparams, num_params))
230 }
231}
232
233fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
234 match bytes.try_into_vec::<i8>() {
235 Ok(val) => val,
236 Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
240 }
241}
242
/// Packs signed 8-bit integers into unsigned 32-bit integers, four values per word.
///
/// Within each word the first value occupies the least-significant byte, i.e.
/// `(d & 0xFF) << 24 | (c & 0xFF) << 16 | (b & 0xFF) << 8 | (a & 0xFF)` for the group
/// `[a, b, c, d]`. When the number of values is not a multiple of four, the last word is
/// padded with zeros. The result is the same on every target endianness.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    // Build each word from an explicit little-endian byte group instead of re-labelling
    // the `Vec<i8>` allocation as `Vec<u32>`: that allocation has neither the alignment
    // nor (in general) a capacity compatible with `u32`.
    values
        .chunks(4)
        .map(|group| {
            let mut word = [0u8; 4];
            for (dst, src) in word.iter_mut().zip(group) {
                *dst = *src as u8;
            }
            u32::from_le_bytes(word)
        })
        .collect()
}
281
282pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
284 values: &[Q],
285 numel: usize,
286 value: &QuantValue,
287) -> Vec<i8> {
288 let size_store = size_of::<Q>() * 8;
289 let size_quant = value.size_bits();
290 let num_quants = size_store / size_quant;
291 let mask = Q::from((1 << size_quant) - 1).unwrap();
292 let sign_shift = 8 - size_quant; values
294 .iter()
295 .enumerate()
296 .flat_map(|(i, &packed)| {
297 let n = core::cmp::min(num_quants, numel - i * num_quants);
299 (0..n).map(move |i| {
306 let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
307 ((raw << sign_shift) as i8) >> sign_shift
308 })
309 })
310 .collect()
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// The values `[-128, 2, -3, 127]` packed with the first value in the lowest byte.
    const PACKED_WORD: u32 = 2147287680;

    #[test]
    fn should_pack_i8s_to_u32() {
        let words = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(words, vec![PACKED_WORD]);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        // Five values need a second word; it must equal the explicitly zero-padded input.
        let implicit = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let explicit = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(implicit, vec![PACKED_WORD, 55]);
        assert_eq!(implicit, explicit);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let values = unpack_q_to_i8s(&[PACKED_WORD], 4, &QuantValue::Q8S);

        assert_eq!(values, vec![-128, 2, -3, 127]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        // Only one of the four byte slots is populated.
        let values = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(values, vec![55]);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let packed = [
            0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
            1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
            2004318071,
        ];

        let values = unpack_q_to_i8s(&packed, 128, &QuantValue::Q4S);

        // Expected output as (value, repetitions) runs; 128 values in total.
        let expected: Vec<i8> = [
            (0i8, 10usize),
            (1, 18),
            (2, 18),
            (3, 18),
            (4, 18),
            (5, 18),
            (6, 18),
            (7, 10),
        ]
        .into_iter()
        .flat_map(|(value, count)| core::iter::repeat_n(value, count))
        .collect();

        assert_eq!(values, expected);
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        // Round trip: serialize values and scale, then split them apart again.
        let q_bytes = QuantizedBytes::new(values.clone(), scheme, &[scale]);
        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);
        assert_eq!(q_values, values);
    }
}