1use alloc::vec;
2
3use burn_tensor::{
4 DType, Shape, TensorData, TensorMetadata,
5 ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
6 quantization::{
7 QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
8 QuantizationParametersPrimitive, QuantizationStrategy, QuantizedBytes,
9 SymmetricQuantization,
10 },
11};
12
13use crate::{
14 FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
15 element::{IntNdArrayElement, QuantElement},
16 execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
17};
18
19use super::{NdArrayMathOps, NdArrayOps};
20
21impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
22 for NdArray<E, I, Q>
23where
24 NdArrayTensor: From<SharedArray<E>>,
25 NdArrayTensor: From<SharedArray<I>>,
26{
27 fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
28 match data.dtype {
29 DType::QFloat(scheme) => {
30 let shape = data.shape.clone();
31 let num_elements = data.num_elements();
32 let q_bytes = QuantizedBytes {
33 bytes: data.into_bytes(),
34 scheme,
35 num_elements,
36 };
37
38 match scheme {
39 QuantScheme {
40 level: QuantLevel::Tensor | QuantLevel::Block(_),
41 mode: QuantMode::Symmetric,
42 value: QuantValue::Q8F | QuantValue::Q8S,
43 store: QuantStore::Native | QuantStore::U32,
44 ..
45 } => {
46 let (values, qparams) = q_bytes.into_vec_i8();
48 let data = TensorData::new(values, shape);
49 let scheme = scheme.with_store(QuantStore::Native);
51
52 let qparams = qparams
53 .scales
54 .into_iter()
55 .map(|scales| QParams { scales })
56 .collect();
57
58 NdArrayQTensor {
59 qtensor: NdArrayTensor::from_data(data),
60 scheme,
61 qparams,
62 }
63 }
64 QuantScheme {
65 value:
66 QuantValue::Q4F
67 | QuantValue::Q4S
68 | QuantValue::Q2F
69 | QuantValue::Q2S
70 | QuantValue::E2M1
71 | QuantValue::E4M3
72 | QuantValue::E5M2,
73 ..
74 } => unimplemented!("from_data not supported for scheme {scheme:?}"),
75 }
76 }
77 _ => panic!(
78 "Invalid dtype (expected DType::QFloat, got {:?})",
79 data.dtype
80 ),
81 }
82 }
83
84 fn quantize(
85 tensor: FloatTensor<Self>,
86 scheme: &QuantScheme,
87 qparams: QuantizationParametersPrimitive<Self>,
88 ) -> QuantizedTensor<Self> {
89 let (strategy, qparams) = match scheme {
91 QuantScheme {
92 level: QuantLevel::Tensor,
93 mode: QuantMode::Symmetric,
94 #[cfg(not(feature = "export_tests"))]
95 value: QuantValue::Q8F | QuantValue::Q8S,
96 #[cfg(feature = "export_tests")]
99 value:
100 QuantValue::Q8F
101 | QuantValue::Q8S
102 | QuantValue::Q4F
103 | QuantValue::Q4S
104 | QuantValue::Q2F
105 | QuantValue::Q2S,
106 store: QuantStore::Native,
107 ..
108 } => {
109 let scales = qparams.scales.into_data().iter().next().unwrap();
110 (
111 QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
112 scales,
113 scheme.value,
114 )),
115 vec![QParams { scales }],
116 )
117 }
118 QuantScheme {
119 level: QuantLevel::Block(block_size),
120 mode: QuantMode::Symmetric,
121 #[cfg(not(feature = "export_tests"))]
122 value: QuantValue::Q8F | QuantValue::Q8S,
123 #[cfg(feature = "export_tests")]
124 value:
125 QuantValue::Q8F
126 | QuantValue::Q8S
127 | QuantValue::Q4F
128 | QuantValue::Q4S
129 | QuantValue::Q2F
130 | QuantValue::Q2S,
131 store: QuantStore::Native,
132 ..
133 } => {
134 let (strategy, qparams) = qparams
135 .scales
136 .into_data()
137 .iter()
138 .map(|s| {
139 (
140 SymmetricQuantization::init(s, scheme.value),
141 QParams { scales: s },
142 )
143 })
144 .unzip();
145 (
146 QuantizationStrategy::PerBlockSymmetric(strategy, *block_size),
147 qparams,
148 )
149 }
150 scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
151 };
152
153 let shape = tensor.shape();
154 let data_f = tensor.into_data();
155 let values = strategy.quantize(data_f.as_slice().unwrap());
156 let data = TensorData::quantized(values, shape.clone(), strategy, *scheme);
157 let num_elements = data.num_elements();
158 let q_bytes = QuantizedBytes {
159 bytes: data.into_bytes(),
160 scheme: *scheme,
161 num_elements,
162 };
163 let (values, _) = q_bytes.into_vec_i8();
164 let data = TensorData::new(values, shape).convert::<Q>();
165
166 NdArrayQTensor {
167 qtensor: NdArrayTensor::from_data(data),
168 scheme: *scheme,
169 qparams,
170 }
171 }
172
173 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
174 let shape = tensor.qtensor.shape();
175 let strategy = tensor.strategy();
176 let data: TensorData = execute_with_dtype!(tensor.qtensor, E, |qtensor: SharedArray<E>| {
177 let values = qtensor.into_iter().collect();
178 TensorData::quantized(values, shape, strategy, tensor.scheme)
179 });
180
181 NdArrayTensor::from_data(data.dequantize().unwrap())
182 }
183
184 fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
185 NdArrayDevice::Cpu
186 }
187
188 fn q_to_device(
189 tensor: QuantizedTensor<Self>,
190 _device: &NdArrayDevice,
191 ) -> QuantizedTensor<Self> {
192 tensor
193 }
194
195 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
196 NdArrayQTensor {
197 qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::reshape(
198 qtensor, shape
199 )),
200 scheme: tensor.scheme,
201 qparams: tensor.qparams,
202 }
203 }
204
205 async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
206 let strategy = tensor.strategy();
207 let shape = tensor.qtensor.shape();
208 execute_with_numeric_dtype!(tensor.qtensor, E, |qtensor: SharedArray<E>| {
209 let values = qtensor.into_iter().collect();
210 TensorData::quantized(values, shape, strategy, tensor.scheme)
211 })
212 }
213
214 fn q_swap_dims(
215 tensor: QuantizedTensor<Self>,
216 dim1: usize,
217 dim2: usize,
218 ) -> QuantizedTensor<Self> {
219 NdArrayQTensor {
220 qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::swap_dims(
221 qtensor, dim1, dim2
222 )),
223 scheme: tensor.scheme,
224 qparams: tensor.qparams,
225 }
226 }
227
228 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
229 NdArrayQTensor {
230 qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::permute(
231 qtensor, axes
232 )),
233 scheme: tensor.scheme,
234 qparams: tensor.qparams,
235 }
236 }
237
238 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
239 NdArrayQTensor {
240 qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::flip(qtensor, axes)),
241 scheme: tensor.scheme,
242 qparams: tensor.qparams,
243 }
244 }
245
246 fn q_gather(
247 dim: usize,
248 tensor: QuantizedTensor<Self>,
249 indices: IntTensor<Self>,
250 ) -> QuantizedTensor<Self> {
251 let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
252 execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
253 NdArrayMathOps::gather(dim, qtensor, indices)
254 })
255 });
256 NdArrayQTensor {
257 qtensor,
258 scheme: tensor.scheme,
259 qparams: tensor.qparams,
260 }
261 }
262
263 fn q_select(
264 tensor: QuantizedTensor<Self>,
265 dim: usize,
266 indices: IntTensor<Self>,
267 ) -> QuantizedTensor<Self> {
268 let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
269 execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
270 NdArrayMathOps::select(qtensor, dim, indices)
271 })
272 });
273 NdArrayQTensor {
274 qtensor,
275 scheme: tensor.scheme,
276 qparams: tensor.qparams,
277 }
278 }
279
280 fn q_slice(
281 tensor: QuantizedTensor<Self>,
282 slices: &[burn_tensor::Slice],
283 ) -> QuantizedTensor<Self> {
284 NdArrayQTensor {
285 qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::slice(
286 qtensor, slices
287 )),
288 scheme: tensor.scheme,
289 qparams: tensor.qparams,
290 }
291 }
292
293 fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
294 execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmax::<I>(
295 qtensor, dim
296 ))
297 }
298
299 fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
300 execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmin::<I>(
301 qtensor, dim
302 ))
303 }
304
305 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
306 NdArrayQTensor {
307 qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::expand(
308 qtensor, shape
309 )),
310 scheme: tensor.scheme,
311 qparams: tensor.qparams,
312 }
313 }
314}