use alloc::{vec, vec::Vec};

use burn_tensor::{
    DType, Shape, TensorData, TensorMetadata,
    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
    quantization::{
        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
        QuantizationParametersPrimitive, QuantizedBytes,
    },
};

use crate::{
    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
    element::{IntNdArrayElement, QuantElement},
    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
};

use super::quantization::{QuantizationStrategy, SymmetricQuantization};
use super::{NdArrayMathOps, NdArrayOps};

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

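                // Only symmetric 8-bit schemes are handled here; sub-byte and float8 values fall
                // through to `unimplemented!`.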
                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        store: QuantStore::Native | QuantStore::U32,
                        ..
                    } => {
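                        // Unpack the bytes into i8 values and their scales; the resulting tensor
                        // uses the native store, with one `QParams` entry per scale.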
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        let scheme = scheme.with_store(QuantStore::Native);

                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

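    // Quantizes float values with a symmetric strategy (per-tensor or per-block), then stores the
    // resulting i8 values as the backend's quantized element type `Q`.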
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        let scales = qparams.scales.into_data().convert::<f32>();

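        // Only 8-bit symmetric quantization with native storage is supported; the `export_tests`
        // feature additionally accepts the sub-byte values (Q4/Q2) in the patterns below.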
        let (data, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.as_slice().unwrap();
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

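        // Round-trip through `QuantizedBytes` to recover the i8 values, then convert them to the
        // quantized element type `Q`.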
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

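    // Dequantization expects the values to be stored as i8; any other dtype is unexpected here.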
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(qtensor) => {
                let data = qtensor.into_iter().collect();
                dequantize(data, shape, scheme, &strategy)
            }
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::reshape(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

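    // Reads the quantized values and scales back into a `TensorData` tagged with the scheme.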
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        execute_with_numeric_dtype!(tensor.qtensor, E, |qtensor: SharedArray<E>| {
            let values = qtensor.into_iter().collect();
            TensorData::quantized(values, shape, tensor.scheme, &scales)
        })
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::swap_dims(
                qtensor, dim1, dim2
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::permute(
                qtensor, axes
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::flip(qtensor, axes)),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

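    // Gather operates directly on the stored quantized values; the scheme and quantization
    // parameters are carried over unchanged.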
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::gather(dim, qtensor, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::select(qtensor, dim, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_tensor::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::slice(
                qtensor, slices
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmax::<I>(
            qtensor, dim
        ))
    }

    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmin::<I>(
            qtensor, dim
        ))
    }

    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::expand(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}

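// Converts quantized values back to floats: collect the scales from the strategy, round-trip
// through `QuantizedBytes` to obtain i8 values, and apply the strategy's dequantization.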
fn dequantize<Q: QuantElement>(
    data: Vec<Q>,
    shape: Shape,
    scheme: QuantScheme,
    strategy: &QuantizationStrategy,
) -> TensorData {
    let qparams = match strategy {
        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
            quant.iter().map(|q| q.scale).collect()
        }
    };
    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
    let (values, _qparams) = q_bytes.into_vec_i8();
    TensorData::new(strategy.dequantize(&values), shape)
}