use alloc::{vec, vec::Vec};

use burn_backend::{
    DType, ExecutionError, Shape, TensorData, TensorMetadata,
    ops::QTensorOps,
    quantization::{
        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
        QuantizationParametersPrimitive, QuantizedBytes,
    },
    tensor::{FloatTensor, IntTensor, QuantizedTensor},
};

use crate::{
    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
    element::{IntNdArrayElement, QuantElement},
    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
};

use super::quantization::{QuantizationStrategy, SymmetricQuantization};
use super::{NdArrayMathOps, NdArrayOps};
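// Quantized tensor operations (`QTensorOps`) for the ndarray backend. Quantized values
// are stored natively (one element per value) rather than packed into u32 words.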
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
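    // Builds a quantized tensor from already-quantized `TensorData`. Only 8-bit symmetric
    // schemes (per-tensor or per-block) are handled; sub-byte and float8 values are
    // unimplemented.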
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };
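                // Both the native and u32-packed stores decode to plain i8 values here,
                // so the resulting tensor is normalized to the native store layout.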
                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        store: QuantStore::Native | QuantStore::U32,
                        ..
                    } => {
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        let scheme = scheme.with_store(QuantStore::Native);

                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }
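    // Quantizes a float tensor with the provided scales. The float data is quantized on
    // the CPU through a `QuantizationStrategy`, then repacked into an i8-backed tensor.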
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        let scales = qparams.scales.into_data().convert::<f32>();
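        // Only symmetric 8-bit quantization in the native store is supported here; the
        // `export_tests` feature additionally matches sub-byte values (Q4/Q2) for testing.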
        let (data, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.as_slice().unwrap();
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };
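        // Round-trip through `QuantizedBytes` to recover the raw i8 values, then store
        // them as the backend's quantized element type `Q`.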
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }
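    // Dequantizes back to a float tensor using the tensor's recorded quantization strategy.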
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(storage) => {
                let data = storage.into_shared().into_iter().collect();
                dequantize(data, shape, scheme, &strategy)
            }
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }
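    // The ndarray backend is CPU-only, so device queries and transfers are no-ops.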
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }
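    // Shape and layout ops below operate directly on the quantized values and carry the
    // scheme and quantization parameters through unchanged.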
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::reshape(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
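    // Reads the quantized values back into `TensorData`, re-attaching the per-tensor or
    // per-block scales so the data can be dequantized elsewhere.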
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        Ok(execute_with_numeric_dtype!(
            tensor.qtensor,
            E,
            |array: SharedArray<E>| {
                let values = array.into_iter().collect();
                TensorData::quantized(values, shape, tensor.scheme, &scales)
            }
        ))
    }
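    // Dimension reordering (swap_dims, permute, flip) follows the same pattern as
    // `q_reshape`: delegate to the regular ndarray ops and keep the qparams as-is.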
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::swap_dims(array, dim1, dim2)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::permute(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::flip(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
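    // Indexing ops dispatch twice: once over the integer dtype of the indices and once
    // over the numeric dtype of the quantized values.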
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                    NdArrayOps::gather(dim, array, idx_array)
                })
            }
        );
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                    NdArrayMathOps::select(array, dim, idx_array)
                })
            }
        );
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_backend::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::slice(array, slices)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
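    // Argmax/argmin operate directly on the quantized values; with a positive symmetric
    // scale the quantization mapping is monotonic, so the ordering of the original floats
    // is preserved (up to ties introduced by rounding).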
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
            NdArrayMathOps::argmax::<I>(array, dim)
        })
    }

    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
            NdArrayMathOps::argmin::<I>(array, dim)
        })
    }

    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::expand(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}
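// Helper that dequantizes a buffer of quantized values. For symmetric quantization each
// value is recovered as `q * scale`, with one scale per tensor or per block.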
fn dequantize<Q: QuantElement>(
    data: Vec<Q>,
    shape: Shape,
    scheme: QuantScheme,
    strategy: &QuantizationStrategy,
) -> TensorData {
    let qparams = match strategy {
        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
            quant.iter().map(|q| q.scale).collect()
        }
    };
    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
    let (values, _qparams) = q_bytes.into_vec_i8();
    TensorData::new(strategy.dequantize(&values), shape)
}