1use core::ops::Range;
2
3use burn_tensor::{
4 ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
5 quantization::{
6 AffineQuantization, QParams, QuantizationParametersPrimitive, QuantizationScheme,
7 QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization,
8 },
9 DType, ElementConversion, Shape, TensorData, TensorMetadata,
10};
11
12use crate::{
13 element::{IntNdArrayElement, NdArrayElement, QuantElement},
14 new_tensor_float, FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor,
15 NdArrayTensorFloat,
16};
17
18use super::{NdArrayMathOps, NdArrayOps};
19
20fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
21 let shape = tensor.shape();
22 let values = tensor.array.into_iter().collect();
23 TensorData::new(values, shape)
24}
25
26fn into_data_f(tensor: NdArrayTensorFloat) -> TensorData {
27 match tensor {
28 NdArrayTensorFloat::F32(tensor) => into_data(tensor),
29 NdArrayTensorFloat::F64(tensor) => into_data(tensor),
30 }
31}
32
33impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
34 for NdArray<E, I, Q>
35{
36 fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
37 match data.dtype {
38 DType::QFloat(scheme) => {
39 let shape = data.shape.clone();
40 let num_elements = data.num_elements();
41 let q_bytes = QuantizedBytes {
42 bytes: data.into_bytes(),
43 scheme,
44 num_elements,
45 };
46
47 match scheme {
48 QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
49 | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
50 let (values, qparams) = q_bytes.into_vec_i8();
51
52 let data = TensorData::new(values, shape).convert::<Q>();
53 let qparams = QParams {
54 scale: qparams.scale,
55 offset: qparams.offset.map(|x| x.elem::<Q>()),
56 };
57
58 NdArrayQTensor {
59 qtensor: NdArrayTensor::<Q>::from_data(data),
60 scheme,
61 qparams,
62 }
63 }
64 }
65 }
66 _ => panic!(
67 "Invalid dtype (expected DType::QFloat, got {:?})",
68 data.dtype
69 ),
70 }
71 }
72
73 fn quantize(
74 tensor: FloatTensor<Self>,
75 scheme: &QuantizationScheme,
76 qparams: QuantizationParametersPrimitive<Self>,
77 ) -> QuantizedTensor<Self> {
78 let (strategy, qparams) = match scheme {
79 QuantizationScheme::PerTensorAffine(dtype) => match dtype {
80 QuantizationType::QInt8 => {
81 let scale = into_data_f(qparams.scale).iter().next().unwrap();
82 let offset = into_data(qparams.offset.unwrap())
83 .iter::<Q>()
84 .next()
85 .unwrap();
86 (
87 QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
88 scale,
89 offset.elem(),
90 )),
91 QParams {
92 scale,
93 offset: Some(offset),
94 },
95 )
96 }
97 },
98 QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
99 QuantizationType::QInt8 => {
100 let scale = into_data_f(qparams.scale).iter().next().unwrap();
101 (
102 QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
103 scale,
104 )),
105 QParams {
106 scale,
107 offset: None,
108 },
109 )
110 }
111 },
112 };
113
114 let shape = tensor.shape();
115 let data = into_data_f(tensor).with_quantization(strategy);
116 let num_elements = data.num_elements();
117 let q_bytes = QuantizedBytes {
118 bytes: data.into_bytes(),
119 scheme: *scheme,
120 num_elements,
121 };
122 let (values, _) = q_bytes.into_vec_i8();
123 let data = TensorData::new(values, shape).convert::<Q>();
124
125 NdArrayQTensor {
126 qtensor: NdArrayTensor::<Q>::from_data(data),
127 scheme: *scheme,
128 qparams,
129 }
130 }
131
132 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
133 let shape = tensor.qtensor.shape();
134 let strategy = tensor.strategy();
135 let values = tensor.qtensor.array.into_iter().collect();
136 let data = TensorData::quantized(values, shape, strategy);
137 new_tensor_float!(NdArrayTensor::from_data(data.dequantize().unwrap()))
138 }
139
140 fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
141 NdArrayDevice::Cpu
142 }
143
144 fn q_to_device(
145 tensor: QuantizedTensor<Self>,
146 _device: &NdArrayDevice,
147 ) -> QuantizedTensor<Self> {
148 tensor
149 }
150
151 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
152 NdArrayQTensor {
153 qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
154 scheme: tensor.scheme,
155 qparams: tensor.qparams,
156 }
157 }
158
159 async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
160 let strategy = tensor.strategy();
161 let shape = tensor.qtensor.shape();
162 let values = tensor.qtensor.array.into_iter().collect();
163 TensorData::quantized(values, shape, strategy)
164 }
165
166 fn q_swap_dims(
167 tensor: QuantizedTensor<Self>,
168 dim1: usize,
169 dim2: usize,
170 ) -> QuantizedTensor<Self> {
171 NdArrayQTensor {
172 qtensor: NdArrayOps::swap_dims(tensor.qtensor, dim1, dim2),
173 scheme: tensor.scheme,
174 qparams: tensor.qparams,
175 }
176 }
177
178 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
179 NdArrayQTensor {
180 qtensor: NdArrayOps::permute(tensor.qtensor, axes),
181 scheme: tensor.scheme,
182 qparams: tensor.qparams,
183 }
184 }
185
186 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
187 NdArrayQTensor {
188 qtensor: NdArrayOps::flip(tensor.qtensor, axes),
189 scheme: tensor.scheme,
190 qparams: tensor.qparams,
191 }
192 }
193
194 fn q_gather(
195 dim: usize,
196 tensor: QuantizedTensor<Self>,
197 indices: IntTensor<Self>,
198 ) -> QuantizedTensor<Self> {
199 NdArrayQTensor {
200 qtensor: NdArrayMathOps::gather(dim, tensor.qtensor, indices),
201 scheme: tensor.scheme,
202 qparams: tensor.qparams,
203 }
204 }
205
206 fn q_select(
207 tensor: QuantizedTensor<Self>,
208 dim: usize,
209 indices: IntTensor<Self>,
210 ) -> QuantizedTensor<Self> {
211 NdArrayQTensor {
212 qtensor: NdArrayMathOps::select(tensor.qtensor, dim, indices),
213 scheme: tensor.scheme,
214 qparams: tensor.qparams,
215 }
216 }
217
218 fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
219 NdArrayQTensor {
220 qtensor: NdArrayOps::slice(tensor.qtensor, ranges),
221 scheme: tensor.scheme,
222 qparams: tensor.qparams,
223 }
224 }
225
226 fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
227 NdArrayMathOps::argmax(tensor.qtensor, dim)
228 }
229
230 fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
231 NdArrayMathOps::argmin(tensor.qtensor, dim)
232 }
233
234 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
235 NdArrayQTensor {
236 qtensor: NdArrayOps::expand(tensor.qtensor, shape),
237 scheme: tensor.scheme,
238 qparams: tensor.qparams,
239 }
240 }
241}