1pub use cubecl_quant::scheme::{
5 BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
8use alloc::vec::Vec;
9use core::any::TypeId;
10use num_traits::PrimInt;
11use serde::{Deserialize, Serialize};
12
13use crate::bytes::Bytes;
14
/// Floating-point precision used for accumulation in quantized operations.
/// NOTE(review): purpose inferred from the name (`Acc`) — confirm against call sites.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantAcc {
    /// Accumulate in single-precision `f32` (the default).
    #[default]
    F32,
    /// Accumulate in half-precision `f16`.
    F16,
    /// Accumulate in `bf16`.
    BF16,
}
28
/// Controls whether the quantized representation is carried through to the
/// output of an operation, or stopped at this operation.
/// NOTE(review): exact downstream semantics are defined by consumers of this
/// enum, not visible in this file — verify before relying on this description.
#[derive(
    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,
)]
pub enum QuantPropagation {
    /// Keep the output quantized.
    Propagate,
    /// Do not propagate quantization to the output (the default).
    #[default]
    Inhibit,
}
41
/// Quantization parameters recovered from a packed buffer.
///
/// `S` is the storage for the scaling factors (e.g. `Vec<f32>`): one scale
/// for per-tensor schemes, one per block for block-wise schemes.
#[derive(Clone, Debug)]
pub struct QParams<S> {
    /// The scaling factor(s) used to dequantize the values.
    pub scales: S,
}
48
/// Quantized tensor data packed together with its quantization parameters:
/// the value bytes come first, followed by the f32 scale bytes (see `new`).
pub struct QuantizedBytes {
    /// The raw buffer holding both the quantized values and, at the end, the scales.
    pub bytes: Bytes,
    /// The quantization scheme describing how `bytes` is laid out and interpreted.
    pub scheme: QuantScheme,
    /// Number of quantized *elements* (not bytes; elements may be sub-byte packed).
    pub num_elements: usize,
}
66
impl QuantizedBytes {
    /// Creates a new `QuantizedBytes` buffer from quantized `value`s and their
    /// `scales`, appending the scale bytes *after* the value bytes in a single
    /// `Bytes` allocation (4-byte aligned for f32 access).
    ///
    /// # Panics
    /// Panics if `E` is not `i8` (the only packed representation supported
    /// here), or if `scales` is empty under a per-tensor scheme.
    pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
        value: Vec<E>,
        scheme: QuantScheme,
        scales: &[f32],
    ) -> Self {
        let num_elements = value.len();
        // Runtime type guard: the generic bound admits any POD type, but the
        // layout logic below assumes the values are stored as `i8`.
        if TypeId::of::<E>() != TypeId::of::<i8>() {
            panic!("Invalid quantized type");
        }

        // Zero-copy reinterpretation of the input allocation as `i8`s.
        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
        let mut bytes = Bytes::from_elems(i8s);

        match scheme.level {
            QuantLevel::Tensor => {
                // Per-tensor: a single f32 scale is appended after the values.
                let scale_bytes = bytemuck::bytes_of(&scales[0]);
                bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
            }
            QuantLevel::Block(_block_size) => {
                // Per-block: all block scales are appended, in order.
                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
                for scale in scales {
                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
                }
                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
            }
        }

        Self {
            bytes,
            scheme,
            num_elements,
        }
    }

    /// Splits the buffer back into the quantized `i8` values and the f32
    /// scales that dequantize them.
    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
        let (values, (qparams, num_params)) = self.split_values_off();

        // `qparams` is a u32 view of the trailing parameter region; the
        // scales are its *last* `num_params` f32 words.
        let scale_size = core::mem::size_of::<f32>(); let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
        let total_bytes = qparams_bytes.len();

        let scales_size = scale_size * num_params;

        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();

        (values, QParams { scales })
    }

    /// Splits natively stored (`i8`) values from the trailing `num_params`
    /// f32 scales, returning the scale region as raw `u32` words.
    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
        let mut values = read_bytes_to_i8(self.bytes);

        // The parameters occupy the last `num_params * 4` bytes of the buffer.
        let scale_size = num_params * size_of::<f32>();
        let values_end = values.len() - scale_size;

        let qparams = values.split_off(values_end);

        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
            // Aligned: reinterpret the tail allocation as u32 in place.
            // SAFETY: the pointer is 4-byte aligned (checked above) and the
            // length is a multiple of 4 (it holds `num_params` f32 words).
            // NOTE(review): `capacity() / 4` assumes the byte capacity of the
            // vector returned by `split_off` is also a multiple of 4; if it
            // is not, the reconstructed `Vec`'s layout mismatches the actual
            // allocation on drop (undefined behavior) — confirm.
            let mut qparams = core::mem::ManuallyDrop::new(qparams);
            unsafe {
                Vec::<u32>::from_raw_parts(
                    qparams.as_mut_ptr() as _,
                    qparams.len() / 4,
                    qparams.capacity() / 4,
                )
            }
        } else {
            // Unaligned fallback: repack the bytes into u32 words.
            #[cfg(target_endian = "little")]
            {
                bytemuck::cast_vec(qparams)
            }
            #[cfg(target_endian = "big")]
            {
                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
            }
        };
        (values, qparams)
    }

    /// Splits the quantized values of the tensor from the quantization
    /// parameters appended at the end of the buffer.
    ///
    /// Returns `(values, (qparam_words, num_params))`. Sub-byte packed values
    /// (Q4/Q2 under u32 storage) are unpacked into one `i8` per element.
    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
        // One scale for the whole tensor, or one per block.
        let num_params = match self.scheme.level {
            QuantLevel::Tensor => 1,
            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
        };

        let (values, qparams) = match self.scheme.store {
            QuantStore::Native => self.split_i8_values(num_params),
            QuantStore::U32 => match self.scheme.value {
                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();
                    // Each f32 scale occupies exactly one u32 word here.
                    let scale_size = num_params; let values_end = values.len() - scale_size;

                    let qparams = values.split_off(values_end);
                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
                    (values, qparams)
                }
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    unimplemented!("Not yet supported")
                }
            },
        };

        (values, (qparams, num_params))
    }
}
187
188fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
189 match bytes.try_into_vec::<i8>() {
190 Ok(val) => val,
191 Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
195 }
196}
197
/// Packs each group of four `i8` values into one `u32`, little-endian style:
/// group element `i` occupies bits `i * 8 .. i * 8 + 8` of the word.
///
/// The final group is zero-padded when `values.len()` is not a multiple of 4.
///
/// This replaces the previous little-endian fast path, which rebuilt a
/// `Vec<u32>` via `from_raw_parts(ptr, len / 4, capacity() / 4)`: whenever the
/// byte capacity was not a multiple of 4, the reconstructed vector's layout no
/// longer matched the underlying allocation — undefined behavior on drop. The
/// safe byte-wise packing below produces the identical result on both little-
/// and big-endian targets (the old big-endian branch computed exactly
/// `u32::from_le_bytes` of each chunk), so no `cfg(target_endian)` split is
/// needed.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    values
        .chunks(4)
        .map(|group| {
            // Missing trailing bytes stay zero (implicit padding).
            let mut bytes = [0u8; 4];
            for (dst, &src) in bytes.iter_mut().zip(group) {
                *dst = src as u8;
            }
            u32::from_le_bytes(bytes)
        })
        .collect()
}
236
237pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
239 values: &[Q],
240 numel: usize,
241 value: &QuantValue,
242) -> Vec<i8> {
243 let size_store = size_of::<Q>() * 8;
244 let size_quant = value.size_bits();
245 let num_quants = size_store / size_quant;
246 let mask = Q::from((1 << size_quant) - 1).unwrap();
247 let sign_shift = 8 - size_quant; values
249 .iter()
250 .enumerate()
251 .flat_map(|(i, &packed)| {
252 let n = core::cmp::min(num_quants, numel - i * num_quants);
254 (0..n).map(move |i| {
261 let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
262 ((raw << sign_shift) as i8) >> sign_shift
263 })
264 })
265 .collect()
266}
267
#[cfg(test)]
mod tests {

    use super::*;
    use alloc::vec;

    // 2147287680 == 0x7FFD0280 == bytes [0x80, 0x02, 0xFD, 0x7F] packed
    // low-byte-first, i.e. [-128, 2, -3, 127] as i8s.
    #[test]
    fn should_pack_i8s_to_u32() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);

        assert_eq!(packed, vec![2147287680]);
    }

    // A trailing partial group is zero-padded, so explicit zero padding must
    // produce the same packed words.
    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);
        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);

        assert_eq!(packed, vec![2147287680, 55]);
        assert_eq!(packed, packed_padded);
    }

    // Round-trips the word from `should_pack_i8s_to_u32` back to i8s.
    #[test]
    fn should_unpack_u32s_to_i8s() {
        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![-128, 2, -3, 127]);
    }

    // `numel` caps the output, so the three padding bytes are not emitted.
    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);

        assert_eq!(unpacked, vec![55]);
    }

    // 4-bit signed values, 8 per u32 word, over 128 elements forming a
    // slowly increasing 0..=7 staircase.
    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let unpacked = unpack_q_to_i8s(
            &[
                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
                2004318071,
            ],
            128,
            &QuantValue::Q4S,
        );

        assert_eq!(
            unpacked,
            vec![
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
            ]
        );
    }

    // Full round-trip: pack values + a per-tensor scale into QuantizedBytes,
    // then split them back apart and compare against the inputs.
    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];

        let q_bytes = QuantizedBytes::new(
            values.clone(),
            QuantScheme::default()
                .with_value(QuantValue::Q8S)
                .with_store(QuantStore::Native),
            &[scale],
        );

        let (q_values, qparams) = q_bytes.into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);

        assert_eq!(q_values, values);
    }
}