burn_ndarray/ops/
qtensor.rs

1use core::ops::Range;
2
3use burn_tensor::{
4    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
5    quantization::{
6        AffineQuantization, QParams, QuantizationParametersPrimitive, QuantizationScheme,
7        QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization,
8    },
9    DType, ElementConversion, Shape, TensorData, TensorMetadata,
10};
11
12use crate::{
13    element::{IntNdArrayElement, NdArrayElement, QuantElement},
14    new_tensor_float, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor,
15    NdArrayTensorFloat,
16};
17
18use super::{NdArrayMathOps, NdArrayOps};
19
20fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
21    let shape = tensor.shape();
22    let values = tensor.array.into_iter().collect();
23    TensorData::new(values, shape)
24}
25
26fn into_data_f(tensor: NdArrayTensorFloat) -> TensorData {
27    match tensor {
28        NdArrayTensorFloat::F32(tensor) => into_data(tensor),
29        NdArrayTensorFloat::F64(tensor) => into_data(tensor),
30    }
31}
32
33impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
34    for NdArray<E, I, Q>
35{
36    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
37        match data.dtype {
38            DType::QFloat(scheme) => {
39                let shape = data.shape.clone();
40                let num_elements = data.num_elements();
41                let q_bytes = QuantizedBytes {
42                    bytes: data.into_bytes(),
43                    scheme,
44                    num_elements,
45                };
46
47                match scheme {
48                    QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
49                    | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
50                        let (values, qparams) = q_bytes.into_vec_i8();
51
52                        let data = TensorData::new(values, shape).convert::<Q>();
53                        let qparams = QParams {
54                            scale: qparams.scale,
55                            offset: qparams.offset.map(|x| x.elem::<Q>()),
56                        };
57
58                        NdArrayQTensor {
59                            qtensor: NdArrayTensor::<Q>::from_data(data),
60                            scheme,
61                            qparams,
62                        }
63                    }
64                }
65            }
66            _ => panic!(
67                "Invalid dtype (expected DType::QFloat, got {:?})",
68                data.dtype
69            ),
70        }
71    }
72
73    fn quantize(
74        tensor: FloatTensor<Self>,
75        scheme: &QuantizationScheme,
76        qparams: QuantizationParametersPrimitive<Self>,
77    ) -> QuantizedTensor<Self> {
78        let (strategy, qparams) = match scheme {
79            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
80                QuantizationType::QInt8 => {
81                    let scale = into_data_f(qparams.scale).iter().next().unwrap();
82                    let offset = into_data(qparams.offset.unwrap())
83                        .iter::<Q>()
84                        .next()
85                        .unwrap();
86                    (
87                        QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
88                            scale,
89                            offset.elem(),
90                        )),
91                        QParams {
92                            scale,
93                            offset: Some(offset),
94                        },
95                    )
96                }
97            },
98            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
99                QuantizationType::QInt8 => {
100                    let scale = into_data_f(qparams.scale).iter().next().unwrap();
101                    (
102                        QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
103                            scale,
104                        )),
105                        QParams {
106                            scale,
107                            offset: None,
108                        },
109                    )
110                }
111            },
112        };
113
114        let shape = tensor.shape();
115        let data = into_data_f(tensor).with_quantization(strategy);
116        let num_elements = data.num_elements();
117        let q_bytes = QuantizedBytes {
118            bytes: data.into_bytes(),
119            scheme: *scheme,
120            num_elements,
121        };
122        let (values, _) = q_bytes.into_vec_i8();
123        let data = TensorData::new(values, shape).convert::<Q>();
124
125        NdArrayQTensor {
126            qtensor: NdArrayTensor::<Q>::from_data(data),
127            scheme: *scheme,
128            qparams,
129        }
130    }
131
132    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
133        let shape = tensor.qtensor.shape();
134        let strategy = tensor.strategy();
135        let values = tensor.qtensor.array.into_iter().collect();
136        let data = TensorData::quantized(values, shape, strategy);
137        new_tensor_float!(NdArrayTensor::from_data(data.dequantize().unwrap()))
138    }
139
140    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
141        NdArrayDevice::Cpu
142    }
143
144    fn q_to_device(
145        tensor: QuantizedTensor<Self>,
146        _device: &NdArrayDevice,
147    ) -> QuantizedTensor<Self> {
148        tensor
149    }
150
151    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
152        NdArrayQTensor {
153            qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
154            scheme: tensor.scheme,
155            qparams: tensor.qparams,
156        }
157    }
158
159    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
160        let strategy = tensor.strategy();
161        let shape = tensor.qtensor.shape();
162        let values = tensor.qtensor.array.into_iter().collect();
163        TensorData::quantized(values, shape, strategy)
164    }
165
166    fn q_swap_dims(
167        tensor: QuantizedTensor<Self>,
168        dim1: usize,
169        dim2: usize,
170    ) -> QuantizedTensor<Self> {
171        NdArrayQTensor {
172            qtensor: NdArrayOps::swap_dims(tensor.qtensor, dim1, dim2),
173            scheme: tensor.scheme,
174            qparams: tensor.qparams,
175        }
176    }
177
178    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
179        NdArrayQTensor {
180            qtensor: NdArrayOps::permute(tensor.qtensor, axes),
181            scheme: tensor.scheme,
182            qparams: tensor.qparams,
183        }
184    }
185
186    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
187        NdArrayQTensor {
188            qtensor: NdArrayOps::flip(tensor.qtensor, axes),
189            scheme: tensor.scheme,
190            qparams: tensor.qparams,
191        }
192    }
193
194    fn q_gather(
195        dim: usize,
196        tensor: QuantizedTensor<Self>,
197        indices: IntTensor<Self>,
198    ) -> QuantizedTensor<Self> {
199        NdArrayQTensor {
200            qtensor: NdArrayMathOps::gather(dim, tensor.qtensor, indices),
201            scheme: tensor.scheme,
202            qparams: tensor.qparams,
203        }
204    }
205
206    fn q_select(
207        tensor: QuantizedTensor<Self>,
208        dim: usize,
209        indices: IntTensor<Self>,
210    ) -> QuantizedTensor<Self> {
211        NdArrayQTensor {
212            qtensor: NdArrayMathOps::select(tensor.qtensor, dim, indices),
213            scheme: tensor.scheme,
214            qparams: tensor.qparams,
215        }
216    }
217
218    fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
219        NdArrayQTensor {
220            qtensor: NdArrayOps::slice(tensor.qtensor, ranges),
221            scheme: tensor.scheme,
222            qparams: tensor.qparams,
223        }
224    }
225
226    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
227        NdArrayMathOps::argmax(tensor.qtensor, dim)
228    }
229
230    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
231        NdArrayMathOps::argmin(tensor.qtensor, dim)
232    }
233
234    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
235        NdArrayQTensor {
236            qtensor: NdArrayOps::expand(tensor.qtensor, shape),
237            scheme: tensor.scheme,
238            qparams: tensor.qparams,
239        }
240    }
241}