// burn_ndarray/ops/qtensor.rs
1use alloc::{vec, vec::Vec};
2
3use burn_tensor::{
4    DType, Shape, TensorData, TensorMetadata,
5    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
6    quantization::{
7        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
8        QuantizationParametersPrimitive, QuantizedBytes,
9    },
10};
11
12use crate::{
13    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
14    element::{IntNdArrayElement, QuantElement},
15    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
16};
17
18use super::quantization::{QuantizationStrategy, SymmetricQuantization};
19use super::{NdArrayMathOps, NdArrayOps};
20
// Quantized tensor operations for the ndarray (CPU) backend.
//
// Quantized values are held as native `i8` inside a regular `NdArrayTensor`;
// the quantization scheme and the per-tensor / per-block scales travel
// alongside the values in `NdArrayQTensor`.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    // Builds a quantized tensor from already-quantized `TensorData`.
    //
    // Only symmetric 8-bit schemes (`Q8F`/`Q8S`) are accepted; sub-byte and
    // float8 values hit `unimplemented!`. Panics when `data.dtype` is not
    // `DType::QFloat`.
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                // Wrap the raw bytes so they can be decoded according to the scheme.
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        store: QuantStore::Native | QuantStore::U32,
                        ..
                    } => {
                        // We can load QuantStore::U32 w/ QuantizedBytes impl
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        // Overwrite storage: values are now held as native i8,
                        // regardless of how they were packed on the way in.
                        let scheme = scheme.with_store(QuantStore::Native);

                        // One `QParams` entry per scale (a single entry for
                        // per-tensor quantization, one per block otherwise).
                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    // Quantizes a float tensor into `i8` values using the given scheme and
    // precomputed quantization parameters (scales).
    //
    // Supports per-tensor and per-block symmetric schemes with native
    // storage; everything else hits `unimplemented!`.
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        // Scales are always consumed as f32, whatever dtype they arrive in.
        let scales = qparams.scales.into_data().convert::<f32>();

        // Implement with ndarray instead of QuantizationStrategy?
        let (data, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                // For tests, "native" sub-byte quant serves as a reference for value equality.
                // Values are stored as i8 regardless.
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                // Per-tensor: a single scale applies to every element.
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                // Per-block: one scale (and thus one `QParams` entry) per block.
                let scales = scales.as_slice().unwrap();
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        // Round-trip through `QuantizedBytes` to normalize the packed
        // representation into a plain i8 vector, then store it with the
        // backend's quantized element type `Q`.
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

    // Converts a quantized tensor back to floating point using the tensor's
    // recorded quantization strategy.
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(qtensor) => {
                let data = qtensor.into_iter().collect();
                dequantize(data, shape, scheme, &strategy)
            }
            // Quantized values are always stored as i8 (see `quantize` /
            // `q_from_data`), so any other variant is a bug.
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

    // The ndarray backend only runs on CPU.
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    // Single-device backend: moving between devices is a no-op.
    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

    // Reshape operates on the raw quantized values; scheme and qparams are
    // carried over unchanged.
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::reshape(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    // Reads the quantized values back out as `TensorData`, re-attaching the
    // collected scales and the scheme.
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        execute_with_numeric_dtype!(tensor.qtensor, E, |qtensor: SharedArray<E>| {
            let values = qtensor.into_iter().collect();
            TensorData::quantized(values, shape, tensor.scheme, &scales)
        })
    }

    // Dimension swap on the raw quantized values; qparams unchanged.
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::swap_dims(
                qtensor, dim1, dim2
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    // Axis permutation on the raw quantized values; qparams unchanged.
    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::permute(
                qtensor, axes
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    // Flip along the given axes on the raw quantized values; qparams unchanged.
    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::flip(qtensor, axes)),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    // Gather along `dim` on the raw quantized values; indices may be any
    // supported int dtype.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::gather(dim, qtensor, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    // Index-select along `dim` on the raw quantized values; indices may be
    // any supported int dtype.
    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::select(qtensor, dim, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    // Slicing on the raw quantized values; qparams unchanged.
    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_tensor::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::slice(
                qtensor, slices
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    // NOTE(review): arg-reductions run directly on the quantized integer
    // values. For per-block schemes with differing scales this may not match
    // the float-domain argmax across block boundaries — confirm intended.
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmax::<I>(
            qtensor, dim
        ))
    }

    // See the note on `q_argmax`: computed on raw quantized values.
    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmin::<I>(
            qtensor, dim
        ))
    }

    // Broadcast-expand on the raw quantized values; qparams unchanged.
    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::expand(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}
320
321fn dequantize<Q: QuantElement>(
322    data: Vec<Q>,
323    shape: Shape,
324    scheme: QuantScheme,
325    strategy: &QuantizationStrategy,
326) -> TensorData {
327    let qparams = match strategy {
328        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
329        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
330            quant.iter().map(|q| q.scale).collect()
331        }
332    };
333    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
334    let (values, _qparams) = q_bytes.into_vec_i8();
335    TensorData::new(strategy.dequantize(&values), shape)
336}