use std::ops::Range;

use burn_tensor::{
    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
    quantization::{
        QParams, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType,
        QuantizedBytes,
    },
    DType, Shape, TensorData, TensorMetadata,
};

use crate::{LibTorch, LibTorchDevice, QuantElement, TchElement, TchQTensor, TchShape, TchTensor};

use super::TchOps;

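/// Quantize a float `tch` tensor with the given scheme and quantization parameters.
///
/// Half-precision inputs are upcast to `f32` before quantization. Both per-tensor
/// affine and per-tensor symmetric `QInt8` schemes are handled; the symmetric scheme
/// uses a zero offset.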
fn quantize<E: TchElement, Q: QuantElement>(
    tensor: tch::Tensor,
    scheme: &QuantizationScheme,
    qparams: &QParams<E, Q>,
) -> tch::Tensor {
    let mut tensor = tensor;
    if tensor.kind() == tch::Kind::Half {
        tensor = tensor.to_kind(tch::Kind::Float);
    }

    match scheme {
        QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => tensor.quantize_per_tensor(
            qparams.scale.elem(),
            qparams.offset.unwrap().elem(),
            tch::Kind::QInt8,
        ),
        QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
            tensor.quantize_per_tensor(qparams.scale.elem(), 0, tch::Kind::QInt8)
        }
    }
}

impl<E: TchElement, Q: QuantElement> QTensorOps<Self> for LibTorch<E, Q> {
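    // Builds a quantized tensor from packed `TensorData`: the bytes are unpacked and
    // dequantized, moved to the target device, then re-quantized through LibTorch so
    // the result is backed by a native quantized `tch` tensor.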
    fn q_from_data(data: TensorData, device: &LibTorchDevice) -> QuantizedTensor<Self> {
        let shape_tch = TchShape::from(data.shape.as_slice());
        let device = (*device).into();

        match data.dtype {
            DType::QFloat(scheme) => {
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                let (values, qparams) = q_bytes.dequantize();
                let tensor = tch::Tensor::from_slice(&values).to(device);
                let tensor = quantize(tensor.reshape(shape_tch.dims), &scheme, &qparams);

                TchQTensor {
                    qtensor: TchTensor::new(tensor),
                    scheme,
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

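    // Quantizes a float tensor using tensor-valued quantization parameters
    // (`quantize_per_tensor_tensor_qparams`). Half-precision inputs are upcast to `f32`
    // first, and the symmetric scheme uses a zero offset.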
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantizationScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        let mut tensor = tensor;
        if E::dtype() == DType::F16 {
            tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
        }

        let qtensor = match scheme {
            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
                QuantizationType::QInt8 => tensor.tensor.quantize_per_tensor_tensor_qparams(
                    &qparams.scale.tensor,
                    &qparams.offset.unwrap().tensor,
                    tch::Kind::QInt8,
                ),
            },
            QuantizationScheme::PerTensorSymmetric(_) => {
                tensor.tensor.quantize_per_tensor_tensor_qparams(
                    &qparams.scale.tensor,
                    &tch::Tensor::zeros_like(&qparams.scale.tensor),
                    tch::Kind::QInt8,
                )
            }
        };

        TchQTensor {
            qtensor: TchTensor::new(qtensor),
            scheme: *scheme,
        }
    }

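    // Dynamic quantization: LibTorch computes scale and offset from the tensor values.
    // Only the per-tensor affine `QInt8` scheme is supported natively, so the symmetric
    // scheme falls back to affine quantization with a warning.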
    fn quantize_dynamic(
        tensor: FloatTensor<Self>,
        scheme: &QuantizationScheme,
    ) -> QuantizedTensor<Self> {
        let qtensor = match &scheme {
            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
                QuantizationType::QInt8 => tensor
                    .tensor
                    .quantize_per_tensor_dynamic(tch::Kind::QInt8, false),
            },
            QuantizationScheme::PerTensorSymmetric(dtype) => {
                log::warn!("LibTorch backend does not support symmetric per-tensor scheme for dynamic quantization, reverting to the default per-tensor affine quantization");
                match dtype {
                    QuantizationType::QInt8 => tensor
                        .tensor
                        .quantize_per_tensor_dynamic(tch::Kind::QInt8, false),
                }
            }
        };

        TchQTensor {
            qtensor: TchTensor::new(qtensor),
            scheme: *scheme,
        }
    }

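    // Dequantizes the tensor and casts the result back to the backend's float element type.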
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        TchTensor::new(tensor.qtensor.tensor.dequantize().to_kind(E::KIND))
    }

    fn q_device(tensor: &QuantizedTensor<Self>) -> LibTorchDevice {
        tensor.qtensor.tensor.device().into()
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        device: &burn_tensor::Device<Self>,
    ) -> QuantizedTensor<Self> {
        let mut tensor = tensor;
        tensor.qtensor = TchOps::to_device(tensor.qtensor, device);
        tensor
    }

    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        TchQTensor {
            qtensor: TchOps::reshape(tensor.qtensor, shape),
            scheme: tensor.scheme,
        }
    }

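    // Reads the quantized tensor back into `TensorData`: the tensor is flattened and the
    // raw `i8` values are extracted from its integer representation (`int_repr`).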
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        let shape = tensor.shape();
        let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
        let strategy = tensor.strategy();

        let values: Result<Vec<i8>, tch::TchError> = tensor.qtensor.tensor.int_repr().try_into();

        TensorData::quantized(values.unwrap(), shape, strategy)
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        let mut tensor = tensor;
        tensor.qtensor = TchOps::swap_dims(tensor.qtensor, dim1, dim2);
        tensor
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        let mut tensor = tensor;
        tensor.qtensor = TchOps::permute(tensor.qtensor, axes);
        tensor
    }

    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        let mut tensor = tensor;
        tensor.qtensor = TchOps::flip(tensor.qtensor, axes);
        tensor
    }

    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        let mut tensor = tensor;
        tensor.qtensor = TchOps::index_select_dim(tensor.qtensor, dim, indices);
        tensor
    }

    fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
        let mut tensor = tensor;
        tensor.qtensor = TchOps::slice(tensor.qtensor, ranges);
        tensor
    }

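    // Arg ops run on the integer representation of the quantized tensor; with a
    // per-tensor scheme (positive scale) the mapping is monotonic, so the indices match
    // those of the dequantized values.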
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        TchOps::argmax(TchTensor::new(tensor.qtensor.tensor.int_repr()), dim)
    }

    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        TchOps::argmin(TchTensor::new(tensor.qtensor.tensor.int_repr()), dim)
    }

    fn q_max_dim_with_indices(
        tensor: QuantizedTensor<Self>,
        dim: usize,
    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {
        let (qtensor, indices) = TchOps::max_dim_with_indices(tensor.qtensor, dim);
        let values = TchQTensor {
            qtensor,
            scheme: tensor.scheme,
        };
        (values, indices)
    }

    fn q_max_dim(tensor: QuantizedTensor<Self>, dim: usize) -> QuantizedTensor<Self> {
        TchQTensor {
            qtensor: TchOps::max_dim(tensor.qtensor, dim),
            scheme: tensor.scheme,
        }
    }

    fn q_min_dim(tensor: QuantizedTensor<Self>, dim: usize) -> QuantizedTensor<Self> {
        TchQTensor {
            qtensor: TchOps::min_dim(tensor.qtensor, dim),
            scheme: tensor.scheme,
        }
    }

    fn q_min_dim_with_indices(
        tensor: QuantizedTensor<Self>,
        dim: usize,
    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {
        let (qtensor, indices) = TchOps::min_dim_with_indices(tensor.qtensor, dim);
        let values = TchQTensor {
            qtensor,
            scheme: tensor.scheme,
        };
        (values, indices)
    }

    fn q_narrow(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        start: usize,
        length: usize,
    ) -> QuantizedTensor<Self> {
        TchQTensor {
            qtensor: TchOps::narrow(tensor.qtensor, dim, start, length),
            scheme: tensor.scheme,
        }
    }

    fn q_chunk(
        tensor: QuantizedTensor<Self>,
        chunks: usize,
        dim: usize,
    ) -> Vec<QuantizedTensor<Self>> {
        TchOps::chunk(tensor.qtensor, chunks, dim)
            .into_iter()
            .map(|x| TchQTensor {
                qtensor: x,
                scheme: tensor.scheme,
            })
            .collect()
    }

    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        TchQTensor {
            qtensor: TchOps::expand(tensor.qtensor, shape),
            scheme: tensor.scheme,
        }
    }

    fn q_sort(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        descending: bool,
    ) -> QuantizedTensor<Self> {
        TchQTensor {
            qtensor: TchOps::sort(tensor.qtensor, dim, descending),
            scheme: tensor.scheme,
        }
    }

    fn q_sort_with_indices(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        descending: bool,
    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {
        let (qtensor, indices) = TchOps::sort_with_indices(tensor.qtensor, dim, descending);
        let tensor = TchQTensor {
            qtensor,
            scheme: tensor.scheme,
        };
        (tensor, indices)
    }

    fn q_argsort(tensor: QuantizedTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
        TchOps::argsort(tensor.qtensor, dim, descending)
    }
}