1use alloc::vec;
2use core::ops::Range;
3
4use burn_tensor::{
5 DType, Shape, TensorData, TensorMetadata,
6 ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
7 quantization::{
8 QParams, QuantInputType, QuantLevel, QuantMode, QuantScheme,
9 QuantizationParametersPrimitive, QuantizationStrategy, QuantizedBytes,
10 SymmetricQuantization,
11 },
12};
13
14use crate::{
15 FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, NdArrayTensorFloat,
16 element::{IntNdArrayElement, NdArrayElement, QuantElement},
17 new_tensor_float,
18};
19
20use super::{NdArrayMathOps, NdArrayOps};
21
22fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
23 let shape = tensor.shape();
24 let values = tensor.array.into_iter().collect();
25 TensorData::new(values, shape)
26}
27
28fn into_data_f(tensor: NdArrayTensorFloat) -> TensorData {
29 match tensor {
30 NdArrayTensorFloat::F32(tensor) => into_data(tensor),
31 NdArrayTensorFloat::F64(tensor) => into_data(tensor),
32 }
33}
34
35impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
36 for NdArray<E, I, Q>
37{
38 fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
39 match data.dtype {
40 DType::QFloat(scheme) => {
41 let shape = data.shape.clone();
42 let num_elements = data.num_elements();
43 let q_bytes = QuantizedBytes {
44 bytes: data.into_bytes(),
45 scheme,
46 num_elements,
47 };
48
49 match scheme {
50 QuantScheme {
51 level: QuantLevel::Tensor,
52 mode: QuantMode::Symmetric,
53 q_type: QuantInputType::QInt8,
54 ..
55 } => {
56 let (values, qparams) = q_bytes.into_vec_i8();
58 let data = TensorData::new(values, shape);
59
60 let qparams = qparams
61 .scale
62 .into_iter()
63 .map(|scale| QParams {
64 scale,
65 offset: None,
66 })
67 .collect();
68
69 NdArrayQTensor {
70 qtensor: NdArrayTensor::<Q>::from_data(data),
71 scheme,
72 qparams,
73 }
74 }
75 }
76 }
77 _ => panic!(
78 "Invalid dtype (expected DType::QFloat, got {:?})",
79 data.dtype
80 ),
81 }
82 }
83
84 fn quantize(
85 tensor: FloatTensor<Self>,
86 scheme: &QuantScheme,
87 qparams: QuantizationParametersPrimitive<Self>,
88 ) -> QuantizedTensor<Self> {
89 let (strategy, qparams) = match scheme {
91 QuantScheme {
92 level: QuantLevel::Tensor,
93 mode: QuantMode::Symmetric,
94 q_type: QuantInputType::QInt8,
95 ..
96 } => {
97 let scale = into_data_f(qparams.scale).iter().next().unwrap();
98 (
99 QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
100 scale,
101 )),
102 vec![QParams {
103 scale,
104 offset: None,
105 }],
106 )
107 }
108 };
109
110 let shape = tensor.shape();
111 let data = into_data_f(tensor).with_quantization(strategy);
112 let num_elements = data.num_elements();
113 let q_bytes = QuantizedBytes {
114 bytes: data.into_bytes(),
115 scheme: *scheme,
116 num_elements,
117 };
118 let (values, _) = q_bytes.into_vec_i8();
119 let data = TensorData::new(values, shape).convert::<Q>();
120
121 NdArrayQTensor {
122 qtensor: NdArrayTensor::<Q>::from_data(data),
123 scheme: *scheme,
124 qparams,
125 }
126 }
127
128 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
129 let shape = tensor.qtensor.shape();
130 let strategy = tensor.strategy();
131 let values = tensor.qtensor.array.into_iter().collect();
132 let data = TensorData::quantized(values, shape, strategy);
133 new_tensor_float!(NdArrayTensor::from_data(data.dequantize().unwrap()))
134 }
135
136 fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
137 NdArrayDevice::Cpu
138 }
139
140 fn q_to_device(
141 tensor: QuantizedTensor<Self>,
142 _device: &NdArrayDevice,
143 ) -> QuantizedTensor<Self> {
144 tensor
145 }
146
147 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
148 NdArrayQTensor {
149 qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
150 scheme: tensor.scheme,
151 qparams: tensor.qparams,
152 }
153 }
154
155 async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
156 let strategy = tensor.strategy();
157 let shape = tensor.qtensor.shape();
158 let values = tensor.qtensor.array.into_iter().collect();
159 TensorData::quantized(values, shape, strategy)
160 }
161
162 fn q_swap_dims(
163 tensor: QuantizedTensor<Self>,
164 dim1: usize,
165 dim2: usize,
166 ) -> QuantizedTensor<Self> {
167 NdArrayQTensor {
168 qtensor: NdArrayOps::swap_dims(tensor.qtensor, dim1, dim2),
169 scheme: tensor.scheme,
170 qparams: tensor.qparams,
171 }
172 }
173
174 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
175 NdArrayQTensor {
176 qtensor: NdArrayOps::permute(tensor.qtensor, axes),
177 scheme: tensor.scheme,
178 qparams: tensor.qparams,
179 }
180 }
181
182 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
183 NdArrayQTensor {
184 qtensor: NdArrayOps::flip(tensor.qtensor, axes),
185 scheme: tensor.scheme,
186 qparams: tensor.qparams,
187 }
188 }
189
190 fn q_gather(
191 dim: usize,
192 tensor: QuantizedTensor<Self>,
193 indices: IntTensor<Self>,
194 ) -> QuantizedTensor<Self> {
195 NdArrayQTensor {
196 qtensor: NdArrayMathOps::gather(dim, tensor.qtensor, indices),
197 scheme: tensor.scheme,
198 qparams: tensor.qparams,
199 }
200 }
201
202 fn q_select(
203 tensor: QuantizedTensor<Self>,
204 dim: usize,
205 indices: IntTensor<Self>,
206 ) -> QuantizedTensor<Self> {
207 NdArrayQTensor {
208 qtensor: NdArrayMathOps::select(tensor.qtensor, dim, indices),
209 scheme: tensor.scheme,
210 qparams: tensor.qparams,
211 }
212 }
213
214 fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
215 NdArrayQTensor {
216 qtensor: NdArrayOps::slice(tensor.qtensor, ranges),
217 scheme: tensor.scheme,
218 qparams: tensor.qparams,
219 }
220 }
221
222 fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
223 NdArrayMathOps::argmax(tensor.qtensor, dim)
224 }
225
226 fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
227 NdArrayMathOps::argmin(tensor.qtensor, dim)
228 }
229
230 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
231 NdArrayQTensor {
232 qtensor: NdArrayOps::expand(tensor.qtensor, shape),
233 scheme: tensor.scheme,
234 qparams: tensor.qparams,
235 }
236 }
237}