1use alloc::vec;
2use core::ops::Range;
3
4use burn_tensor::{
5 DType, Shape, TensorData, TensorMetadata,
6 ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
7 quantization::{
8 QParams, QuantizationMode, QuantizationParametersPrimitive, QuantizationScheme,
9 QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization,
10 },
11};
12
13use crate::{
14 FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, NdArrayTensorFloat,
15 element::{IntNdArrayElement, NdArrayElement, QuantElement},
16 new_tensor_float,
17};
18
19use super::{NdArrayMathOps, NdArrayOps};
20
21fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
22 let shape = tensor.shape();
23 let values = tensor.array.into_iter().collect();
24 TensorData::new(values, shape)
25}
26
27fn into_data_f(tensor: NdArrayTensorFloat) -> TensorData {
28 match tensor {
29 NdArrayTensorFloat::F32(tensor) => into_data(tensor),
30 NdArrayTensorFloat::F64(tensor) => into_data(tensor),
31 }
32}
33
34impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
35 for NdArray<E, I, Q>
36{
37 fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
38 match data.dtype {
39 DType::QFloat(scheme) => {
40 let shape = data.shape.clone();
41 let num_elements = data.num_elements();
42 let q_bytes = QuantizedBytes {
43 bytes: data.into_bytes(),
44 scheme,
45 num_elements,
46 };
47
48 match scheme {
49 QuantizationScheme::PerTensor(mode, QuantizationType::QInt8) => {
50 let (values, qparams) = q_bytes.into_vec_i8();
52 let data = TensorData::new(values, shape);
53
54 let qparams = match mode {
55 QuantizationMode::Symmetric => qparams
56 .scale
57 .into_iter()
58 .map(|scale| QParams {
59 scale,
60 offset: None,
61 })
62 .collect(),
63 };
64
65 NdArrayQTensor {
66 qtensor: NdArrayTensor::<Q>::from_data(data),
67 scheme,
68 qparams,
69 }
70 }
71 }
72 }
73 _ => panic!(
74 "Invalid dtype (expected DType::QFloat, got {:?})",
75 data.dtype
76 ),
77 }
78 }
79
80 fn quantize(
81 tensor: FloatTensor<Self>,
82 scheme: &QuantizationScheme,
83 qparams: QuantizationParametersPrimitive<Self>,
84 ) -> QuantizedTensor<Self> {
85 let (strategy, qparams) = match scheme {
87 QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
88 let scale = into_data_f(qparams.scale).iter().next().unwrap();
89 (
90 QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
91 scale,
92 )),
93 vec![QParams {
94 scale,
95 offset: None,
96 }],
97 )
98 }
99 };
100
101 let shape = tensor.shape();
102 let data = into_data_f(tensor).with_quantization(strategy);
103 let num_elements = data.num_elements();
104 let q_bytes = QuantizedBytes {
105 bytes: data.into_bytes(),
106 scheme: *scheme,
107 num_elements,
108 };
109 let (values, _) = q_bytes.into_vec_i8();
110 let data = TensorData::new(values, shape).convert::<Q>();
111
112 NdArrayQTensor {
113 qtensor: NdArrayTensor::<Q>::from_data(data),
114 scheme: *scheme,
115 qparams,
116 }
117 }
118
119 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
120 let shape = tensor.qtensor.shape();
121 let strategy = tensor.strategy();
122 let values = tensor.qtensor.array.into_iter().collect();
123 let data = TensorData::quantized(values, shape, strategy);
124 new_tensor_float!(NdArrayTensor::from_data(data.dequantize().unwrap()))
125 }
126
127 fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
128 NdArrayDevice::Cpu
129 }
130
131 fn q_to_device(
132 tensor: QuantizedTensor<Self>,
133 _device: &NdArrayDevice,
134 ) -> QuantizedTensor<Self> {
135 tensor
136 }
137
138 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
139 NdArrayQTensor {
140 qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
141 scheme: tensor.scheme,
142 qparams: tensor.qparams,
143 }
144 }
145
146 async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
147 let strategy = tensor.strategy();
148 let shape = tensor.qtensor.shape();
149 let values = tensor.qtensor.array.into_iter().collect();
150 TensorData::quantized(values, shape, strategy)
151 }
152
153 fn q_swap_dims(
154 tensor: QuantizedTensor<Self>,
155 dim1: usize,
156 dim2: usize,
157 ) -> QuantizedTensor<Self> {
158 NdArrayQTensor {
159 qtensor: NdArrayOps::swap_dims(tensor.qtensor, dim1, dim2),
160 scheme: tensor.scheme,
161 qparams: tensor.qparams,
162 }
163 }
164
165 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
166 NdArrayQTensor {
167 qtensor: NdArrayOps::permute(tensor.qtensor, axes),
168 scheme: tensor.scheme,
169 qparams: tensor.qparams,
170 }
171 }
172
173 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
174 NdArrayQTensor {
175 qtensor: NdArrayOps::flip(tensor.qtensor, axes),
176 scheme: tensor.scheme,
177 qparams: tensor.qparams,
178 }
179 }
180
181 fn q_gather(
182 dim: usize,
183 tensor: QuantizedTensor<Self>,
184 indices: IntTensor<Self>,
185 ) -> QuantizedTensor<Self> {
186 NdArrayQTensor {
187 qtensor: NdArrayMathOps::gather(dim, tensor.qtensor, indices),
188 scheme: tensor.scheme,
189 qparams: tensor.qparams,
190 }
191 }
192
193 fn q_select(
194 tensor: QuantizedTensor<Self>,
195 dim: usize,
196 indices: IntTensor<Self>,
197 ) -> QuantizedTensor<Self> {
198 NdArrayQTensor {
199 qtensor: NdArrayMathOps::select(tensor.qtensor, dim, indices),
200 scheme: tensor.scheme,
201 qparams: tensor.qparams,
202 }
203 }
204
205 fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
206 NdArrayQTensor {
207 qtensor: NdArrayOps::slice(tensor.qtensor, ranges),
208 scheme: tensor.scheme,
209 qparams: tensor.qparams,
210 }
211 }
212
213 fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
214 NdArrayMathOps::argmax(tensor.qtensor, dim)
215 }
216
217 fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
218 NdArrayMathOps::argmin(tensor.qtensor, dim)
219 }
220
221 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
222 NdArrayQTensor {
223 qtensor: NdArrayOps::expand(tensor.qtensor, shape),
224 scheme: tensor.scheme,
225 qparams: tensor.qparams,
226 }
227 }
228}