// burn_ndarray/ops/qtensor.rs
1use alloc::{vec, vec::Vec};
2
3use burn_backend::{
4    DType, ExecutionError, Shape, TensorData, TensorMetadata,
5    ops::QTensorOps,
6    quantization::{
7        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
8        QuantizationParametersPrimitive, QuantizedBytes,
9    },
10    tensor::{FloatTensor, IntTensor, QuantizedTensor},
11};
12use burn_std::{FloatDType, IntDType};
13
14use crate::{
15    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
16    element::{IntNdArrayElement, QuantElement},
17    execute_with_dtype, execute_with_int_dtype, execute_with_int_out_dtype,
18    execute_with_numeric_dtype, slice,
19};
20
21use super::quantization::{QuantizationStrategy, SymmetricQuantization};
22use super::{NdArrayMathOps, NdArrayOps};
23
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// Loads a quantized tensor from serialized `TensorData`.
    ///
    /// Only `DType::QFloat` data is accepted (panics on any other dtype).
    /// 8-bit symmetric schemes are unpacked into natively stored i8 values;
    /// sub-byte and FP8 quant values are `unimplemented!`.
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                // Capture shape/element count before `into_bytes` consumes `data`.
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        ..
                    } => {
                        // We can load QuantStore::U32 w/ QuantizedBytes impl
                        // (`into_vec_i8` also returns the packed scales).
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        // Overwrite storage: values are now held natively as i8,
                        // regardless of how the incoming bytes were packed.
                        let scheme = scheme.with_store(QuantStore::Native);

                        // One QParams entry per scale (per-tensor yields one;
                        // per-block yields one per block).
                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    /// Quantizes a float tensor using the given scheme and quantization
    /// parameters (scales). Supports symmetric per-tensor and per-block
    /// quantization with native storage; other schemes are `unimplemented!`.
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let shape = tensor.shape();
        let data_f = tensor.into_data();
        // Scales are normalized to f32 for the CPU quantization strategies.
        let scales = qparams.scales.into_data().convert::<f32>();

        // Implement with ndarray instead of QuantizationStrategy?
        let (data, qparams) = match scheme {
            // Per-tensor symmetric: a single scale for the whole tensor.
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                // For tests, "native" sub-byte quant serves as a reference for value equality.
                // Values are stored as i8 regardless.
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                // Per-tensor: exactly one scale is expected in the data.
                let scales = scales.iter().next().unwrap();
                let strategy = QuantizationStrategy::PerTensorSymmetric(
                    SymmetricQuantization::init(scales, scheme.value),
                );
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),
                    vec![QParams { scales }],
                )
            }
            // Per-block symmetric: one scale per block of `block_size` elements.
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = scales.as_slice().unwrap();
                // Build one quantizer and one QParams entry per block scale.
                let (strategy, qparams) = scales
                    .iter()
                    .map(|&s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);
                let values = strategy.quantize(data_f.as_slice().unwrap());
                (
                    TensorData::quantized(values, shape.clone(), *scheme, scales),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        // Round-trip through QuantizedBytes to extract the raw i8 values,
        // then convert to the backend quant element type Q for storage.
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

    /// Dequantizes back to a float tensor of the requested float dtype.
    ///
    /// NOTE(review): only the `I8` storage variant is handled; the
    /// `unreachable!` assumes quantized values are always stored as i8 on
    /// this backend (consistent with `quantize`/`q_from_data` above).
    fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        let strategy = tensor.strategy();
        let scheme = tensor.scheme;
        let shape = tensor.shape();
        let data = match tensor.qtensor {
            NdArrayTensor::I8(storage) => {
                let data = storage.into_shared().into_iter().collect();
                dequantize(data, shape, scheme, &strategy, dtype.into())
            }
            _ => unreachable!(),
        };
        NdArrayTensor::from_data(data)
    }

    /// The ndarray backend is CPU-only, so the device is always `Cpu`.
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    /// No-op: there is a single (CPU) device, so moving returns the tensor unchanged.
    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

    /// Reshapes the underlying quantized values; scheme and qparams are unchanged.
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::reshape(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Reads the tensor back into `TensorData`, re-attaching the stored scales.
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        let shape = tensor.qtensor.shape();
        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();
        Ok(execute_with_numeric_dtype!(
            tensor.qtensor,
            E,
            |array: SharedArray<E>| {
                let values = array.into_iter().collect();
                TensorData::quantized(values, shape, tensor.scheme, &scales)
            }
        ))
    }

    /// Swaps two dimensions; scheme and qparams are unchanged.
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::swap_dims(array, dim1, dim2)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Permutes the axes; scheme and qparams are unchanged.
    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::permute(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Reverses the given axes; scheme and qparams are unchanged.
    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::flip(array, axes)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Gathers values along `dim` using integer indices, dispatching on both
    /// the index dtype and the quantized value dtype.
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
         -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::gather(dim, array, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Selects slices along `dim` by index, dispatching on both the index
    /// dtype and the quantized value dtype.
    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<
            IntElem,
        >|
         -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::select(array, dim, idx_array)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Slices the quantized values; scheme and qparams are unchanged.
    ///
    /// NOTE(review): for block-level quantization, slicing does not adjust
    /// per-block qparams — presumably callers ensure block alignment; verify.
    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_backend::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: slice!(tensor.qtensor, slices),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    /// Argmax over `dim`, computed directly on the quantized values
    /// (valid for symmetric schemes where order is preserved).
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::argmax::<I>(array, dim)
            })
        })
    }

    /// Argmin over `dim`, computed directly on the quantized values.
    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayMathOps::argmin::<I>(array, dim)
            })
        })
    }

    /// Broadcasts the tensor to `shape`; scheme and qparams are unchanged.
    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {
                NdArrayOps::expand(array, shape)
            }),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}
336
337fn dequantize<Q: QuantElement>(
338    data: Vec<Q>,
339    shape: Shape,
340    scheme: QuantScheme,
341    strategy: &QuantizationStrategy,
342    dtype: DType,
343) -> TensorData {
344    let qparams = match strategy {
345        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],
346        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {
347            quant.iter().map(|q| q.scale).collect()
348        }
349    };
350    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);
351    let (values, _qparams) = q_bytes.into_vec_i8();
352    TensorData::new(strategy.dequantize(&values), shape).convert_dtype(dtype)
353}