burn_ndarray/ops/

use alloc::{vec, vec::Vec};

use burn_tensor::{
    DType, Shape, TensorData, TensorMetadata,
    backend::ExecutionError,
    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
    quantization::{
        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
        QuantizationParametersPrimitive, QuantizedBytes,
    },
};

use crate::{
    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
    element::{IntNdArrayElement, QuantElement},
    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
};

use super::quantization::{QuantizationStrategy, SymmetricQuantization};
use super::{NdArrayMathOps, NdArrayOps};

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
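    // Builds a quantized tensor from serialized `QFloat` data. Supported 8-bit symmetric
    // schemes are unpacked to native i8 storage; other quant values are not implemented here.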
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        store: QuantStore::Native | QuantStore::U32,
                        ..
                    } => {
                        // We can load QuantStore::U32 w/ QuantizedBytes impl
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        // Overwrite storage
                        let scheme = scheme.with_store(QuantStore::Native);

                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value: QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

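    // Quantizes a float tensor with the given scheme and parameters. Only symmetric
    // per-tensor and per-block schemes with native storage are handled; values are
    // quantized through `QuantizationStrategy` and stored as the backend's quant element.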
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        let scales = qparams.scales.into_data().convert::<f32>();

        // Implement with ndarray instead of QuantizationStrategy?
        let (data, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                // For tests, "native" sub-byte quant serves as a reference for value equality.
                // Values are stored as i8 regardless.
                #[cfg(feature = "export_tests")]
                value: QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                value: QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.as_slice().unwrap();
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

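    // Dequantizes back to floats using the tensor's quantization strategy. Only the
    // native i8 representation is expected here.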
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(qtensor) => {
                let data = qtensor.into_iter().collect();
                dequantize(data, shape, scheme, &strategy)
            }
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

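    // The ndarray backend is CPU-only, so the device is always `Cpu` and moving between
    // devices is a no-op.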
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

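    // Layout ops (reshape, swap_dims, permute, flip, slice, expand) act on the raw
    // quantized values and carry the scheme and qparams through unchanged.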
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::reshape(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

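    // Reconstructs `TensorData` in the quantized layout, re-attaching the per-tensor or
    // per-block scales stored in `qparams`.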
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        Ok(execute_with_numeric_dtype!(
            tensor.qtensor,
            E,
            |qtensor: SharedArray<E>| {
                let values = qtensor.into_iter().collect();
                TensorData::quantized(values, shape, tensor.scheme, &scales)
            }
        ))
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::swap_dims(
                qtensor, dim1, dim2
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::permute(
                qtensor, axes
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::flip(qtensor, axes)),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

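    // Indexing ops dispatch on both the integer index dtype and the quantized value dtype,
    // then wrap the result back up with the original scheme and qparams.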
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::gather(dim, qtensor, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::select(qtensor, dim, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_tensor::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::slice(
                qtensor, slices
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

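    // argmax/argmin are computed directly on the stored quantized values and return an
    // integer index tensor.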
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmax::<I>(
            qtensor, dim
        ))
    }

    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmin::<I>(
            qtensor, dim
        ))
    }

    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::expand(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}

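/// Dequantizes raw quantized values into float `TensorData`, recovering the scales from
/// the provided `QuantizationStrategy`.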
fn dequantize<Q: QuantElement>(
    data: Vec<Q>,
    shape: Shape,
    scheme: QuantScheme,
    strategy: &QuantizationStrategy,
) -> TensorData {
    let qparams = match strategy {
        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
            quant.iter().map(|q| q.scale).collect()
        }
    };
    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
    let (values, _qparams) = q_bytes.into_vec_i8();
    TensorData::new(strategy.dequantize(&values), shape)
}