burn_ndarray/ops/qtensor.rs

use alloc::vec;

use burn_tensor::{
    DType, Shape, TensorData, TensorMetadata,
    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
    quantization::{
        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,
        QuantizationParametersPrimitive, QuantizationStrategy, QuantizedBytes,
        SymmetricQuantization,
    },
};

use crate::{
    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,
    element::{IntNdArrayElement, QuantElement},
    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype,
};

use super::{NdArrayMathOps, NdArrayOps};

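// Quantized tensor ops for the ndarray backend. Quantized values are stored unpacked
// (QuantStore::Native) as the backend's quantized element type `Q`, alongside the
// quantization scheme and the per-tensor or per-block scale parameters in `NdArrayQTensor`.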
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
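    // Builds a quantized tensor from serialized `TensorData`: unpacks the bytes into i8
    // values via `QuantizedBytes` and turns the stored scales into `QParams`. Only 8-bit
    // integer symmetric schemes are handled; sub-byte and float8 value types hit
    // `unimplemented!`.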
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        value: QuantValue::Q8F | QuantValue::Q8S,
                        store: QuantStore::Native | QuantStore::U32,
                        ..
                    } => {
                        // We can load QuantStore::U32 via the QuantizedBytes impl, which unpacks the values to i8
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);
                        // Overwrite storage: the unpacked values are now stored natively as i8
                        let scheme = scheme.with_store(QuantStore::Native);

                        let qparams = qparams
                            .scales
                            .into_iter()
                            .map(|scales| QParams { scales })
                            .collect();

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                    QuantScheme {
                        value:
                            QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S
                            | QuantValue::E2M1
                            | QuantValue::E4M3
                            | QuantValue::E5M2,
                        ..
                    } => unimplemented!("from_data not supported for scheme {scheme:?}"),
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

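    // Quantizes a float tensor according to `scheme`: the scales are turned into a
    // `QuantizationStrategy` that computes the i8 values, which are then repacked into
    // a `TensorData` of the backend's quantized element type `Q`.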
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        // Implement with ndarray instead of QuantizationStrategy?
        let (strategy, qparams) = match scheme {
            QuantScheme {
                level: QuantLevel::Tensor,
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                // For tests, "native" sub-byte quant serves as a reference for value equality.
                // Values are stored as i8 regardless.
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let scales = qparams.scales.into_data().iter().next().unwrap();
                (
                    QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
                        scales,
                        scheme.value,
                    )),
                    vec![QParams { scales }],
                )
            }
            QuantScheme {
                level: QuantLevel::Block(block_size),
                mode: QuantMode::Symmetric,
                #[cfg(not(feature = "export_tests"))]
                    value: QuantValue::Q8F | QuantValue::Q8S,
                #[cfg(feature = "export_tests")]
                    value:
                    QuantValue::Q8F
                    | QuantValue::Q8S
                    | QuantValue::Q4F
                    | QuantValue::Q4S
                    | QuantValue::Q2F
                    | QuantValue::Q2S,
                store: QuantStore::Native,
                ..
            } => {
                let (strategy, qparams) = qparams
                    .scales
                    .into_data()
                    .iter()
                    .map(|s| {
                        (
                            SymmetricQuantization::init(s, scheme.value),
                            QParams { scales: s },
                        )
                    })
                    .unzip();
                (
                    QuantizationStrategy::PerBlockSymmetric(strategy, *block_size),
                    qparams,
                )
            }
            scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"),
        };

        let shape = tensor.shape();
        let data_f = tensor.into_data();
        let values = strategy.quantize(data_f.as_slice().unwrap());
        let data = TensorData::quantized(values, shape.clone(), strategy, *scheme);
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

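    // Rebuilds a `TensorData` carrying the tensor's quantization strategy and lets
    // `TensorData::dequantize` produce the float values.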
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let shape = tensor.qtensor.shape();
        let strategy = tensor.strategy();
        let data: TensorData = execute_with_dtype!(tensor.qtensor, E, |qtensor: SharedArray<E>| {
            let values = qtensor.into_iter().collect();
            TensorData::quantized(values, shape, strategy, tensor.scheme)
        });

        NdArrayTensor::from_data(data.dequantize().unwrap())
    }

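    // The ndarray backend is CPU-only, so device queries and transfers are no-ops.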
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

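    // Shape and layout ops (reshape, swap_dims, permute, flip, slice, expand) act on the
    // quantized values directly and carry the scheme and qparams through unchanged.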
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::reshape(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

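    // Reads the quantized values back as `TensorData`, re-attaching the quantization
    // strategy so the data can be dequantized or inspected downstream.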
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        let strategy = tensor.strategy();
        let shape = tensor.qtensor.shape();
        execute_with_numeric_dtype!(tensor.qtensor, E, |qtensor: SharedArray<E>| {
            let values = qtensor.into_iter().collect();
            TensorData::quantized(values, shape, strategy, tensor.scheme)
        })
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::swap_dims(
                qtensor, dim1, dim2
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::permute(
                qtensor, axes
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::flip(qtensor, axes)),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

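    // Gather/select dispatch on the index dtype as well as the quantized value dtype;
    // the gathered result keeps the original scheme and qparams.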
    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::gather(dim, qtensor, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let qtensor = execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
            execute_with_numeric_dtype!(tensor.qtensor, |qtensor| {
                NdArrayMathOps::select(qtensor, dim, indices)
            })
        });
        NdArrayQTensor {
            qtensor,
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_slice(
        tensor: QuantizedTensor<Self>,
        slices: &[burn_tensor::Slice],
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::slice(
                qtensor, slices
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

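    // Argmax/argmin can run on the quantized values directly: symmetric quantization uses a
    // positive scale, so the ordering of the underlying floats is preserved (up to rounding ties).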
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmax::<I>(
            qtensor, dim
        ))
    }

    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_numeric_dtype!(tensor.qtensor, |qtensor| NdArrayMathOps::argmin::<I>(
            qtensor, dim
        ))
    }

    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: execute_with_dtype!(tensor.qtensor, |qtensor| NdArrayOps::expand(
                qtensor, shape
            )),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}