burn_ndarray/ops/
qtensor.rs

1use alloc::{vec, vec::Vec};
2
3use burn_backend::{
4    DType, ExecutionError, Shape, TensorData, TensorMetadata,
5    ops::QTensorOps,
6    quantization::{
7        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
8        QuantizationParametersPrimitive, QuantizedBytes,
9    },
10    tensor::{FloatTensor, IntTensor, QuantizedTensor},
11};
12
13use crate::{
14    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
15    element::{IntNdArrayElement, QuantElement},
16    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
17};
18
19use super::quantization::{QuantizationStrategy, SymmetricQuantization};
20use super::{NdArrayMathOps, NdArrayOps};
21
22impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
23    for NdArray<E, I, Q>
24where
25    NdArrayTensor: From<SharedArray<E>>,
26    NdArrayTensor: From<SharedArray<I>>,
27{
28    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
29        match data.dtype {
30            DType::QFloat(scheme) => {
31                let shape = data.shape.clone();
32                let num_elements = data.num_elements();
33                let q_bytes = QuantizedBytes {
34                    bytes: data.into_bytes(),
35                    scheme,
36                    num_elements,
37                };
38
39                match scheme {
40                    QuantScheme {
41                        level: QuantLevel::Tensor | QuantLevel::Block(_),
42                        mode: QuantMode::Symmetric,
43                        value: QuantValue::Q8F | QuantValue::Q8S,
44                        store: QuantStore::Native | QuantStore::U32,
45                        ..
46                    } => {
47                        // We can load QuantStore::U32 w/ QuantizedBytes impl
48                        let (values, qparams) = q_bytes.into_vec_i8();
49                        let data = TensorData::new(values, shape);
50                        // Overwrite storage
51                        let scheme = scheme.with_store(QuantStore::Native);
52
53                        let qparams = qparams
54                            .scales
55                            .into_iter()
56                            .map(|scales| QParams { scales })
57                            .collect();
58
59                        NdArrayQTensor {
60                            qtensor: NdArrayTensor::from_data(data),
61                            scheme,
62                            qparams,
63                        }
64                    }
65                    QuantScheme {
66                        value:
67                            QuantValue::Q4F
68                            | QuantValue::Q4S
69                            | QuantValue::Q2F
70                            | QuantValue::Q2S
71                            | QuantValue::E2M1
72                            | QuantValue::E4M3
73                            | QuantValue::E5M2,
74                        ..
75                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
76                }
77            }
78            _ => panic!(
79                "Invalid dtype (expected DType::QFloat, got {:?})",
80                data.dtype
81            ),
82        }
83    }
84
85    fn quantize(
86        tensor: FloatTensor<Self>,
87        scheme: &QuantScheme,
88        qparams: QuantizationParametersPrimitive<Self>,
89    ) -> QuantizedTensor<Self> {
90        let shape = tensor.shape();
91        let data_f = tensor.into_data();
92        let scales = qparams.scales.into_data().convert::<f32>();
93
94        // Implement with ndarray instead of QuantizationStrategy?
95        let (data, qparams) = match scheme {
96            QuantScheme {
97                level: QuantLevel::Tensor,
98                mode: QuantMode::Symmetric,
99                #[cfg(not(feature = "export_tests"))]
100                    value: QuantValue::Q8F | QuantValue::Q8S,
101                // For tests, "native" sub-byte quant serves as a reference for value equality.
102                // Values are stored as i8 regardless.
103                #[cfg(feature = "export_tests")]
104                    value:
105                    QuantValue::Q8F
106                    | QuantValue::Q8S
107                    | QuantValue::Q4F
108                    | QuantValue::Q4S
109                    | QuantValue::Q2F
110                    | QuantValue::Q2S,
111                store: QuantStore::Native,
112                ..
113            } => {
114                let scales = scales.iter().next().unwrap();
115                let strategy = QuantizationStrategy::PerTensorSymmetric(
116                    SymmetricQuantization::init(scales, scheme.value),
117                );
118                let values = strategy.quantize(data_f.as_slice().unwrap());
119                (
120                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
121                    vec![QParams { scales }],
122                )
123            }
124            QuantScheme {
125                level: QuantLevel::Block(block_size),
126                mode: QuantMode::Symmetric,
127                #[cfg(not(feature = "export_tests"))]
128                    value: QuantValue::Q8F | QuantValue::Q8S,
129                #[cfg(feature = "export_tests")]
130                    value:
131                    QuantValue::Q8F
132                    | QuantValue::Q8S
133                    | QuantValue::Q4F
134                    | QuantValue::Q4S
135                    | QuantValue::Q2F
136                    | QuantValue::Q2S,
137                store: QuantStore::Native,
138                ..
139            } => {
140                let scales = scales.as_slice().unwrap();
141                let (strategy, qparams) = scales
142                    .iter()
143                    .map(|&s| {
144                        (
145                            SymmetricQuantization::init(s, scheme.value),
146                            QParams { scales: s },
147                        )
148                    })
149                    .unzip();
150                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
151                let values = strategy.quantize(data_f.as_slice().unwrap());
152                (
153                    TensorData::quantized(values, shape.clone(), *scheme, scales),
154                    qparams,
155                )
156            }
157            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
158        };
159
160        let num_elements = data.num_elements();
161        let q_bytes = QuantizedBytes {
162            bytes: data.into_bytes(),
163            scheme: *scheme,
164            num_elements,
165        };
166        let (values, _) = q_bytes.into_vec_i8();
167        let data = TensorData::new(values, shape).convert::<Q>();
168
169        NdArrayQTensor {
170            qtensor: NdArrayTensor::from_data(data),
171            scheme: *scheme,
172            qparams,
173        }
174    }
175
176    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
177        let strategy = tensor.strategy();
178        let scheme = tensor.scheme;
179        let shape = tensor.shape();
180        let data = match tensor.qtensor {
181            NdArrayTensor::I8(storage) => {
182                let data = storage.into_shared().into_iter().collect();
183                dequantize(data, shape, scheme, &strategy)
184            }
185            _ => unreachable!(),
186        };
187        NdArrayTensor::from_data(data)
188    }
189
190    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
191        NdArrayDevice::Cpu
192    }
193
194    fn q_to_device(
195        tensor: QuantizedTensor<Self>,
196        _device: &NdArrayDevice,
197    ) -> QuantizedTensor<Self> {
198        tensor
199    }
200
201    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
202        NdArrayQTensor {
203            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
204                NdArrayOps::reshape(array, shape)
205            }),
206            scheme: tensor.scheme,
207            qparams: tensor.qparams,
208        }
209    }
210
211    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
212        let shape = tensor.qtensor.shape();
213        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
214        Ok(execute_with_numeric_dtype!(
215            tensor.qtensor,
216            E,
217            |array: SharedArray<E>| {
218                let values = array.into_iter().collect();
219                TensorData::quantized(values, shape, tensor.scheme, &scales)
220            }
221        ))
222    }
223
224    fn q_swap_dims(
225        tensor: QuantizedTensor<Self>,
226        dim1: usize,
227        dim2: usize,
228    ) -> QuantizedTensor<Self> {
229        NdArrayQTensor {
230            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
231                NdArrayOps::swap_dims(array, dim1, dim2)
232            }),
233            scheme: tensor.scheme,
234            qparams: tensor.qparams,
235        }
236    }
237
238    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
239        NdArrayQTensor {
240            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
241                NdArrayOps::permute(array, axes)
242            }),
243            scheme: tensor.scheme,
244            qparams: tensor.qparams,
245        }
246    }
247
248    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
249        NdArrayQTensor {
250            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
251                NdArrayOps::flip(array, axes)
252            }),
253            scheme: tensor.scheme,
254            qparams: tensor.qparams,
255        }
256    }
257
258    fn q_gather(
259        dim: usize,
260        tensor: QuantizedTensor<Self>,
261        indices: IntTensor<Self>,
262    ) -> QuantizedTensor<Self> {
263        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
264            IntElem,
265        >|
266         -> NdArrayTensor {
267            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
268                NdArrayOps::gather(dim, array, idx_array)
269            })
270        });
271        NdArrayQTensor {
272            qtensor,
273            scheme: tensor.scheme,
274            qparams: tensor.qparams,
275        }
276    }
277
278    fn q_select(
279        tensor: QuantizedTensor<Self>,
280        dim: usize,
281        indices: IntTensor<Self>,
282    ) -> QuantizedTensor<Self> {
283        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
284            IntElem,
285        >|
286         -> NdArrayTensor {
287            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
288                NdArrayMathOps::select(array, dim, idx_array)
289            })
290        });
291        NdArrayQTensor {
292            qtensor,
293            scheme: tensor.scheme,
294            qparams: tensor.qparams,
295        }
296    }
297
298    fn q_slice(
299        tensor: QuantizedTensor<Self>,
300        slices: &[burn_backend::Slice],
301    ) -> QuantizedTensor<Self> {
302        NdArrayQTensor {
303            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
304                NdArrayOps::slice(array, slices)
305            }),
306            scheme: tensor.scheme,
307            qparams: tensor.qparams,
308        }
309    }
310
311    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
312        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
313            NdArrayMathOps::argmax::<I>(array, dim)
314        })
315    }
316
317    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
318        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
319            NdArrayMathOps::argmin::<I>(array, dim)
320        })
321    }
322
323    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
324        NdArrayQTensor {
325            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
326                NdArrayOps::expand(array, shape)
327            }),
328            scheme: tensor.scheme,
329            qparams: tensor.qparams,
330        }
331    }
332}
333
334fn dequantize<Q: QuantElement>(
335    data: Vec<Q>,
336    shape: Shape,
337    scheme: QuantScheme,
338    strategy: &QuantizationStrategy,
339) -> TensorData {
340    let qparams = match strategy {
341        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
342        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
343            quant.iter().map(|q| q.scale).collect()
344        }
345    };
346    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
347    let (values, _qparams) = q_bytes.into_vec_i8();
348    TensorData::new(strategy.dequantize(&values), shape)
349}