burn_ndarray/ops/qtensor.rs

1use alloc::{vec, vec::Vec};
2
3use burn_backend::{
4    DType, ExecutionError, Shape, TensorData, TensorMetadata,
5    ops::QTensorOps,
6    quantization::{
7        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
8        QuantizationParametersPrimitive, QuantizedBytes,
9    },
10    tensor::{FloatTensor, IntTensor, QuantizedTensor},
11};
12
13use crate::{
14    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
15    element::{IntNdArrayElement, QuantElement},
16    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
17};
18
19use super::quantization::{QuantizationStrategy, SymmetricQuantization};
20use super::{NdArrayMathOps, NdArrayOps};
21
/// Quantized tensor operations for the ndarray backend.
///
/// Quantized values are always materialized natively as `i8` elements
/// (see `q_from_data`/`quantize`), so shape/movement ops below simply
/// dispatch on the stored dtype and carry `scheme`/`qparams` through.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// Creates a quantized tensor from already-quantized [`TensorData`].
    ///
    /// Supports symmetric 8-bit schemes only (per-tensor or per-block);
    /// sub-byte and FP8 values hit `unimplemented!`.
    ///
    /// # Panics
    /// Panics if `data.dtype` is not `DType::QFloat`.
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        ..
                    } => {
                        // We can load QuantStore::U32 w/ the QuantizedBytes impl:
                        // it unpacks the raw bytes into plain i8 values + scales.
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        // Overwrite storage: values are now held natively as i8,
                        // independent of how they were packed in the input bytes.
                        let scheme = scheme.with_store(QuantStore::Native);

                        // One QParams entry per scale (a single entry for
                        // per-tensor, one per block for per-block schemes).
                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    /// Quantizes a float tensor according to `scheme`, using the scales in
    /// `qparams` (converted to f32).
    ///
    /// Only symmetric per-tensor / per-block schemes with native storage are
    /// supported; anything else hits `unimplemented!`. The quantized values
    /// are stored as the backend's `Q` element type.
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        let scales = qparams.scales.into_data().convert::<f32>();

        // TODO: implement with ndarray instead of QuantizationStrategy?
        let (data, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                // For tests, "native" sub-byte quant serves as a reference for value equality.
                // Values are stored as i8 regardless.
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                // Per-tensor: a single scale quantizes the whole tensor.
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                // Per-block: one scale (and thus one QParams entry) per block.
                let scales = scales.as_slice().unwrap();
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        // Round-trip through QuantizedBytes to normalize the packed
        // representation into plain i8 values, then store them as Q.
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

    /// Dequantizes back to a float tensor using the tensor's own strategy.
    ///
    /// Values are always stored as i8 by `quantize`/`q_from_data`, hence the
    /// single `I8` arm; any other variant would be a backend invariant bug.
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(storage) => {
                let data = storage.into_shared().into_iter().collect();
                dequantize(data, shape, scheme, &strategy)
            }
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

    /// The ndarray backend is CPU-only, so every tensor lives on `Cpu`.
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    /// No-op: there is only one device on this backend.
    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

    /// Reshapes the quantized values; scheme and qparams are unaffected.
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::reshape(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Reads the tensor back as quantized [`TensorData`], bundling the values
    /// with their scales and scheme.
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        Ok(execute_with_numeric_dtype!(
            tensor.qtensor,
            E,
            |array: SharedArray<E>| {
                let values = array.into_iter().collect();
                TensorData::quantized(values, shape, tensor.scheme, &scales)
            }
        ))
    }

    /// Swaps two dimensions of the quantized values.
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::swap_dims(array, dim1, dim2)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Permutes the dimensions of the quantized values.
    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::permute(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Reverses the quantized values along the given axes.
    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::flip(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Gathers quantized values along `dim` at `indices`.
    ///
    /// Double dispatch: outer macro resolves the index dtype, inner macro
    /// resolves the quantized value dtype.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
         -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::gather(dim, array, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Selects slices of quantized values along `dim` at `indices`
    /// (same double dispatch pattern as `q_gather`).
    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
         -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::select(array, dim, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Slices the quantized values with the given per-dimension ranges.
    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_backend::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::slice(array, slices)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Argmax over the raw quantized values along `dim`.
    ///
    /// NOTE(review): operates on the quantized integers directly — valid for
    /// symmetric schemes where quantization is monotonic.
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
            NdArrayMathOps::argmax::<I>(array, dim)
        })
    }

    /// Argmin over the raw quantized values along `dim` (see `q_argmax`).
    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
            NdArrayMathOps::argmin::<I>(array, dim)
        })
    }

    /// Broadcasts the quantized values to a larger shape.
    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::expand(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}
332
333fn dequantize<Q: QuantElement>(
334    data: Vec<Q>,
335    shape: Shape,
336    scheme: QuantScheme,
337    strategy: &QuantizationStrategy,
338) -> TensorData {
339    let qparams = match strategy {
340        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
341        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
342            quant.iter().map(|q| q.scale).collect()
343        }
344    };
345    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
346    let (values, _qparams) = q_bytes.into_vec_i8();
347    TensorData::new(strategy.dequantize(&values), shape)
348}