1use alloc::{vec, vec::Vec};
2
3use burn_backend::{
4 DType, ExecutionError, Shape, TensorData, TensorMetadata,
5 ops::QTensorOps,
6 quantization::{
7 QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
8 QuantizationParametersPrimitive, QuantizedBytes,
9 },
10 tensor::{FloatTensor, IntTensor, QuantizedTensor},
11};
12use burn_std::{FloatDType, IntDType};
13
14use crate::{
15 FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
16 element::{IntNdArrayElement, QuantElement},
17 execute_with_dtype, execute_with_int_dtype, execute_with_int_out_dtype,
18 execute_with_numeric_dtype, slice,
19};
20
21use super::quantization::{QuantizationStrategy, SymmetricQuantization};
22use super::{NdArrayMathOps, NdArrayOps};
23
/// Quantized tensor operations for the `ndarray` backend.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// Builds a quantized tensor from already-quantized `TensorData`.
    ///
    /// Only symmetric 8-bit schemes (per-tensor or per-block) are handled; the
    /// sub-byte and float8 quant values hit `unimplemented!`. Panics when
    /// `data.dtype` is not `DType::QFloat`.
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                // Re-wrap the raw bytes so the quantized values and the packed
                // quantization parameters can be split apart below.
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        ..
                    } => {
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        // Values are kept as plain i8 elements, so the store is
                        // switched to native (unpacked) representation.
                        let scheme = scheme.with_store(QuantStore::Native);

                        // One `QParams` entry per scale: a single entry for
                        // per-tensor schemes, one per block otherwise.
                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    /// Quantizes a float tensor according to `scheme` using the given scales.
    ///
    /// Supports symmetric per-tensor and per-block schemes with a native store.
    /// With the `export_tests` feature the sub-byte integer values are also
    /// matched (NOTE(review): presumably stored widened to i8 for tests — confirm).
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        // Scales are always processed as f32, independent of the float element type.
        let scales = qparams.scales.into_data().convert::<f32>();

        let (data, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                // Per-tensor: a single scale (the first and only one) drives the
                // whole tensor.
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                // Per-block: build one symmetric strategy and one `QParams`
                // entry per block scale in a single pass.
                let scales = scales.as_slice().unwrap();
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        // Round-trip through `QuantizedBytes` to strip the packed qparams and
        // keep only the raw quantized values for tensor storage.
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

    /// Converts a quantized tensor back to floats of the requested `dtype`.
    fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(storage) => {
                let data = storage.into_shared().into_iter().collect();
                dequantize(data, shape, scheme, &strategy, dtype.into())
            }
            // Non-i8 storage is considered impossible here (only i8-backed
            // quantized tensors are expected by this backend).
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

    /// The ndarray backend only has a CPU device.
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    /// Device transfer is a no-op: everything lives on the CPU.
    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

    /// Reshapes the quantized values; scheme and qparams are unchanged.
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::reshape(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Reads the quantized values plus the collected scales back into `TensorData`.
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        Ok(execute_with_numeric_dtype!(
            tensor.qtensor,
            E,
            |array: SharedArray<E>| {
                let values = array.into_iter().collect();
                TensorData::quantized(values, shape, tensor.scheme, &scales)
            }
        ))
    }

    /// Swaps two dimensions of the quantized values; qparams are unchanged.
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::swap_dims(array, dim1, dim2)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Permutes the axes of the quantized values; qparams are unchanged.
    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::permute(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Flips the given axes of the quantized values; qparams are unchanged.
    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::flip(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Gathers quantized values along `dim` with integer indices.
    ///
    /// NOTE(review): the qparams are carried over untouched — for per-block
    /// schemes this assumes gather does not invalidate block/scale mapping;
    /// confirm against callers.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        // Double dispatch: first on the index dtype, then on the value dtype.
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
            -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::gather(dim, array, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Selects quantized values along `dim` by index; qparams are carried over.
    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        // Double dispatch: first on the index dtype, then on the value dtype.
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
            -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::select(array, dim, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Slices the quantized values; qparams are carried over.
    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_backend::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: slice!(tensor.qtensor, slices),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Argmax over the raw quantized values (valid for symmetric schemes,
    /// where quantization is monotonic).
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::argmax::<I>(array, dim)
            })
        })
    }

    /// Argmin over the raw quantized values (valid for symmetric schemes,
    /// where quantization is monotonic).
    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::argmin::<I>(array, dim)
            })
        })
    }

    /// Broadcast-expands the quantized values; qparams are unchanged.
    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::expand(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}
336
337fn dequantize<Q: QuantElement>(
338 data: Vec<Q>,
339 shape: Shape,
340 scheme: QuantScheme,
341 strategy: &QuantizationStrategy,
342 dtype: DType,
343) -> TensorData {
344 let qparams = match strategy {
345 QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
346 QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
347 quant.iter().map(|q| q.scale).collect()
348 }
349 };
350 let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
351 let (values, _qparams) = q_bytes.into_vec_i8();
352 TensorData::new(strategy.dequantize(&values), shape).convert_dtype(dtype)
353}