1use alloc::{vec, vec::Vec};
2
3use burn_backend::{
4 DType, ExecutionError, Shape, TensorData, TensorMetadata,
5 ops::QTensorOps,
6 quantization::{
7 QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
8 QuantizationParametersPrimitive, QuantizedBytes,
9 },
10 tensor::{FloatTensor, IntTensor, QuantizedTensor},
11};
12
13use crate::{
14 FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
15 element::{IntNdArrayElement, QuantElement},
16 execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
17};
18
19use super::quantization::{QuantizationStrategy, SymmetricQuantization};
20use super::{NdArrayMathOps, NdArrayOps};
21
// Quantized tensor ops for the ndarray backend.
//
// Runtime support is limited to symmetric schemes stored as native `i8`
// values (Q8F/Q8S — plus the narrower Q4/Q2 values compiled in under the
// `export_tests` feature); all other schemes hit `unimplemented!`.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// Builds a quantized tensor from already-quantized [`TensorData`].
    ///
    /// Panics if `data.dtype` is not `DType::QFloat`, and is
    /// `unimplemented!` for sub-byte and float quantized value types.
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                // Reinterpret the raw bytes as packed quantized values + qparams.
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        ..
                    } => {
                        // Split the payload into the i8 values and their scales.
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        // This backend always stores q8 values natively
                        // (one i8 per element).
                        let scheme = scheme.with_store(QuantStore::Native);

                        // One `QParams` per scale: a single entry for per-tensor
                        // quantization, one per block for block-level.
                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    /// Quantizes a float tensor according to `scheme`, using the scales
    /// provided in `qparams`.
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        // Scales arrive as a tensor; read them out as f32 on the host.
        let scales = qparams.scales.into_data().convert::<f32>();

        let (data, qparams) = match scheme {
            // Per-tensor symmetric: a single scale for the whole tensor.
            // NOTE: the `export_tests` feature widens the accepted value
            // types to the sub-byte variants for test export.
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            // Per-block symmetric: one scale (and one QParams entry) per block.
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.as_slice().unwrap();
                // Build one symmetric quantizer and one QParams per block scale.
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        // `TensorData::quantized` packs values + qparams into one byte buffer;
        // unpack it again so the ndarray tensor holds only the quantized values.
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

    /// Dequantizes back to a float tensor using the tensor's stored strategy.
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(storage) => {
                let data = storage.into_shared().into_iter().collect();
                dequantize(data, shape, scheme, &strategy)
            }
            // Only native i8 storage is ever produced by this backend
            // (see `q_from_data` / `quantize` above).
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        // ndarray is a CPU-only backend.
        NdArrayDevice::Cpu
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        // Single-device backend: moving between devices is a no-op.
        tensor
    }

    /// Reshapes the quantized values; scheme and qparams carry over unchanged.
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::reshape(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Reads the tensor back as packed quantized [`TensorData`]
    /// (values plus scales re-packed into one buffer).
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        Ok(execute_with_numeric_dtype!(
            tensor.qtensor,
            E,
            |array: SharedArray<E>| {
                let values = array.into_iter().collect();
                TensorData::quantized(values, shape, tensor.scheme, &scales)
            }
        ))
    }

    /// Swaps two dimensions of the quantized values.
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::swap_dims(array, dim1, dim2)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Permutes the axes of the quantized values.
    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::permute(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Flips the quantized values along the given axes.
    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::flip(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Gathers quantized values along `dim` using integer `indices`.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        // Dispatch on the runtime int dtype of `indices`, then on the
        // numeric dtype of the stored quantized values.
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
         -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::gather(dim, array, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Selects slices of quantized values along `dim` at the given `indices`.
    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        // Same double dispatch as `q_gather`: index dtype, then value dtype.
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
         -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::select(array, dim, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Slices the quantized values.
    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_backend::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::slice(array, slices)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Argmax over the raw quantized values — assumes the quantization is
    /// order-preserving (positive scale) so the result matches argmax of the
    /// dequantized values. TODO confirm scales are always positive.
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
            NdArrayMathOps::argmax::<I>(array, dim)
        })
    }

    /// Argmin over the raw quantized values — same ordering assumption as
    /// `q_argmax`.
    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
            NdArrayMathOps::argmin::<I>(array, dim)
        })
    }

    /// Broadcasts the quantized values to `shape`.
    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::expand(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}
332
333fn dequantize<Q: QuantElement>(
334 data: Vec<Q>,
335 shape: Shape,
336 scheme: QuantScheme,
337 strategy: &QuantizationStrategy,
338) -> TensorData {
339 let qparams = match strategy {
340 QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
341 QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
342 quant.iter().map(|q| q.scale).collect()
343 }
344 };
345 let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
346 let (values, _qparams) = q_bytes.into_vec_i8();
347 TensorData::new(strategy.dequantize(&values), shape)
348}