use alloc::{vec, vec::Vec};

use burn_tensor::{
    DType, Shape, TensorData, TensorMetadata,
    backend::ExecutionError,
    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
    quantization::{
        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
        QuantizationParametersPrimitive, QuantizedBytes,
    },
};

use crate::{
    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
    element::{IntNdArrayElement, QuantElement},
    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
};

use super::quantization::{QuantizationStrategy, SymmetricQuantization};
use super::{NdArrayMathOps, NdArrayOps};

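// Quantized tensor ops (`QTensorOps`) for the ndarray backend.
//
// Only symmetric quantization with 8-bit values (`Q8F`/`Q8S`) stored natively as `i8` is
// handled here; sub-byte and float8 value types fall through to `unimplemented!`.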
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        store: QuantStore::Native | QuantStore::U32,
                        ..
                    } => {
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        let scheme = scheme.with_store(QuantStore::Native);

                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

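    // Quantization happens on the CPU via `QuantizationStrategy`: the float values are
    // converted with the per-tensor or per-block symmetric strategy, packed into
    // `QuantizedBytes`, then read back as `i8` and converted to the backend's quantized
    // element type `Q`. With the `export_tests` feature, the narrower Q4/Q2 value types are
    // also accepted so the packing tests can exercise them.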
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        let scales = qparams.scales.into_data().convert::<f32>();

        let (data, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.as_slice().unwrap();
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

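    // Dequantization reads the stored `i8` values back and applies the tensor's
    // `QuantizationStrategy` to reconstruct a float tensor.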
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(qtensor) => {
                let data = qtensor.into_iter().collect();
                dequantize(data, shape, scheme, &strategy)
            }
            // Quantized values are stored as `i8` on this backend (see `q_from_data`),
            // so other variants should not occur here.
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

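    // The layout ops below (reshape, swap_dims, permute, flip, slice, expand) only move the
    // stored quantized values around, so the scheme and qparams are carried over unchanged.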
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::reshape(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

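    // Reading back a quantized tensor re-packs the raw values together with the collected
    // scales into `TensorData::quantized`.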
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        Ok(execute_with_numeric_dtype!(
            tensor.qtensor,
            E,
            |qtensor: SharedArray<E>| {
                let values = qtensor.into_iter().collect();
                TensorData::quantized(values, shape, tensor.scheme, &scales)
            }
        ))
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::swap_dims(
                qtensor, dim1, dim2
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::permute(
                qtensor, axes
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::flip(qtensor, axes)),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

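    // Gather and select dispatch on both the integer index dtype and the quantized value
    // dtype, then index into the raw quantized values; the original scheme and qparams are
    // kept as-is.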
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::gather(dim, qtensor, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::select(qtensor, dim, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_tensor::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::slice(
                qtensor, slices
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

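    // argmax/argmin run directly on the stored quantized values (no dequantization); with a
    // single positive symmetric scale, quantized and float values share the same ordering.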
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmax::<I>(
            qtensor, dim
        ))
    }

    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmin::<I>(
            qtensor, dim
        ))
    }

    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::expand(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}

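// Converts quantized `i8` values back to floats: the scales are recovered from the
// quantization strategy so a `QuantizedBytes` container can be rebuilt, and the strategy's
// `dequantize` is applied to the unpacked values.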
fn dequantize<Q: QuantElement>(
    data: Vec<Q>,
    shape: Shape,
    scheme: QuantScheme,
    strategy: &QuantizationStrategy,
) -> TensorData {
    let qparams = match strategy {
        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
            quant.iter().map(|q| q.scale).collect()
        }
    };
    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
    let (values, _qparams) = q_bytes.into_vec_i8();
    TensorData::new(strategy.dequantize(&values), shape)
}