1pub use cubecl_quant::scheme::{
5 BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,
6};
7
8use alloc::vec::Vec;
9use core::any::TypeId;
10use num_traits::PrimInt;
11
12use crate::bytes::Bytes;
13
/// Quantization parameters recovered from quantized tensor data.
///
/// `S` is the storage used for the scale parameter(s), e.g. `Vec<f32>`.
#[derive(Debug, Clone)]
pub struct QParams<S> {
    /// The scale parameter(s): one per tensor, or one per block.
    pub scales: S,
}
20
21pub struct QuantizedBytes {
31 pub bytes: Bytes,
33 pub scheme: QuantScheme,
35 pub num_elements: usize,
37}
38
39impl QuantizedBytes {
40 pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
42 value: Vec<E>,
43 scheme: QuantScheme,
44 scales: &[f32],
45 ) -> Self {
46 let num_elements = value.len();
47 if TypeId::of::<E>() != TypeId::of::<i8>() {
49 panic!("Invalid quantized type");
50 }
51
52 let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);
54 let mut bytes = Bytes::from_elems(i8s);
55
56 match scheme.level {
57 QuantLevel::Tensor => {
58 let scale_bytes = bytemuck::bytes_of(&scales[0]);
59 bytes.extend_from_byte_slice_aligned(scale_bytes, align_of::<f32>());
60 }
61 QuantLevel::Block(_block_size) => {
62 let mut scale_bytes = Vec::with_capacity(size_of_val(scales));
63 for scale in scales {
64 scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));
65 }
66 bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), align_of::<f32>());
67 }
68 }
69
70 Self {
71 bytes,
72 scheme,
73 num_elements,
74 }
75 }
76
77 pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {
79 let (values, (qparams, num_params)) = self.split_values_off();
80
81 let scale_size = core::mem::size_of::<f32>(); let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);
87 let total_bytes = qparams_bytes.len();
88
89 let scales_size = scale_size * num_params;
90
91 let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();
92
93 (values, QParams { scales })
94 }
95
96 fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {
97 let mut values = read_bytes_to_i8(self.bytes);
98
99 let scale_size = num_params * size_of::<f32>();
100 let values_end = values.len() - scale_size;
101
102 let qparams = values.split_off(values_end);
103
104 let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {
105 let mut qparams = core::mem::ManuallyDrop::new(qparams);
106 unsafe {
107 Vec::<u32>::from_raw_parts(
108 qparams.as_mut_ptr() as _,
109 qparams.len() / 4,
110 qparams.capacity() / 4,
111 )
112 }
113 } else {
114 #[cfg(target_endian = "little")]
115 {
116 bytemuck::cast_vec(qparams)
118 }
119 #[cfg(target_endian = "big")]
120 {
121 crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))
122 }
123 };
124 (values, qparams)
125 }
126
127 fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {
131 let num_params = match self.scheme.level {
132 QuantLevel::Tensor => 1,
133 QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),
134 };
135
136 let (values, qparams) = match self.scheme.store {
137 QuantStore::Native => self.split_i8_values(num_params),
138 QuantStore::U32 => match self.scheme.value {
139 QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),
140 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
141 let mut values = self.bytes.try_into_vec::<u32>().unwrap();
142 let scale_size = num_params; let values_end = values.len() - scale_size;
144
145 let qparams = values.split_off(values_end);
146 let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);
148 (values, qparams)
149 }
150 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
151 unimplemented!("Not yet supported")
152 }
153 },
154 };
155
156 (values, (qparams, num_params))
157 }
158}
159
160fn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {
161 match bytes.try_into_vec::<i8>() {
162 Ok(val) => val,
163 Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },
167 }
168}
169
/// Packs `i8` values into `u32` words, four values per word.
///
/// Within each group of four, value `i` occupies bits `8 * i .. 8 * i + 8`, so the first
/// value is the least-significant byte. This holds regardless of target endianness. If
/// `values.len()` is not a multiple of four, the last word's unused upper bytes are zero.
///
/// Words are assembled explicitly rather than by reinterpreting the `i8` allocation as
/// `u32`. That allocation is only 1-byte aligned and its capacity need not be a multiple of
/// four, so rebuilding it with `Vec::<u32>::from_raw_parts` would be undefined behavior.
pub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {
    values
        .chunks(4)
        .map(|chunk| {
            let mut word = [0u8; 4];
            for (byte, &value) in word.iter_mut().zip(chunk) {
                // Keeps the two's-complement bit pattern; `i8 as u8` is lossless.
                *byte = value as u8;
            }
            u32::from_le_bytes(word)
        })
        .collect()
}
208
209pub(crate) fn unpack_q_to_i8s<Q: PrimInt>(
211 values: &[Q],
212 numel: usize,
213 value: &QuantValue,
214) -> Vec<i8> {
215 let size_store = size_of::<Q>() * 8;
216 let size_quant = value.size_bits();
217 let num_quants = size_store / size_quant;
218 let mask = Q::from((1 << size_quant) - 1).unwrap();
219 let sign_shift = 8 - size_quant; values
221 .iter()
222 .enumerate()
223 .flat_map(|(i, &packed)| {
224 let n = core::cmp::min(num_quants, numel - i * num_quants);
226 (0..n).map(move |i| {
233 let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();
234 ((raw << sign_shift) as i8) >> sign_shift
235 })
236 })
237 .collect()
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn should_pack_i8s_to_u32() {
        let input: Vec<i8> = vec![-128, 2, -3, 127];
        let expected: Vec<u32> = vec![2147287680];

        assert_eq!(pack_i8s_to_u32s(input), expected);
    }

    #[test]
    fn should_pack_i8s_to_u32_padded() {
        let short: Vec<i8> = vec![-128, 2, -3, 127, 55];
        let zero_padded: Vec<i8> = vec![-128, 2, -3, 127, 55, 0, 0, 0];
        let expected: Vec<u32> = vec![2147287680, 55];

        let packed_short = pack_i8s_to_u32s(short);
        let packed_padded = pack_i8s_to_u32s(zero_padded);

        assert_eq!(packed_short, expected);
        assert_eq!(packed_short, packed_padded);
    }

    #[test]
    fn should_unpack_u32s_to_i8s() {
        let expected: Vec<i8> = vec![-128, 2, -3, 127];

        assert_eq!(
            unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S),
            expected
        );
    }

    #[test]
    fn should_unpack_u32s_to_i8s_padded() {
        let expected: Vec<i8> = vec![55];

        assert_eq!(unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S), expected);
    }

    #[test]
    fn should_unpack_u32s_to_i8s_arange() {
        let packed: [u32; 16] = [
            0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,
            1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,
            2004318071,
        ];
        let expected: Vec<i8> = vec![
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
            3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
            5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
            6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
        ];

        let unpacked = unpack_q_to_i8s(&packed, 128, &QuantValue::Q4S);

        assert_eq!(unpacked, expected);
    }

    #[test]
    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {
        let scale = 0.03937008;
        let values = vec![0i8, 25, 51, 76, 102, 127];
        let scheme = QuantScheme::default()
            .with_value(QuantValue::Q8S)
            .with_store(QuantStore::Native);

        let (q_values, qparams) =
            QuantizedBytes::new(values.clone(), scheme, &[scale]).into_vec_i8();

        assert_eq!(qparams.scales, vec![scale]);
        assert_eq!(q_values, values);
    }
}