burn_tch/ops/
qtensor.rs

1use std::ops::Range;
2
3use burn_tensor::{
4    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
5    quantization::{
6        QParams, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType,
7        QuantizedBytes,
8    },
9    DType, Shape, TensorData, TensorMetadata,
10};
11
12use crate::{LibTorch, LibTorchDevice, QuantElement, TchElement, TchQTensor, TchShape, TchTensor};
13
14use super::TchOps;
15
16fn quantize<E: TchElement, Q: QuantElement>(
17    tensor: tch::Tensor,
18    scheme: &QuantizationScheme,
19    qparams: &QParams<E, Q>,
20) -> tch::Tensor {
21    let mut tensor = tensor;
22    // Quantize only works on Float Tensor
23    if tensor.kind() == tch::Kind::Half {
24        tensor = tensor.to_kind(tch::Kind::Float);
25    }
26
27    match scheme {
28        QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => tensor.quantize_per_tensor(
29            qparams.scale.elem(),
30            qparams.offset.unwrap().elem(),
31            tch::Kind::QInt8,
32        ),
33        QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
34            tensor.quantize_per_tensor(qparams.scale.elem(), 0, tch::Kind::QInt8)
35        }
36    }
37}
38
39impl<E: TchElement, Q: QuantElement> QTensorOps<Self> for LibTorch<E, Q> {
40    fn q_from_data(data: TensorData, device: &LibTorchDevice) -> QuantizedTensor<Self> {
41        let shape_tch = TchShape::from(data.shape.as_slice());
42        let device = (*device).into();
43
44        // NOTE: tch-rs doesn't have `from_blob_quantized_*` APIs
45        // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/quantized/Quantizer.cpp#L322
46        // So for now we have to load the dequantized values to quantize them back since the dequantization
47        // methods take the values provided when quantizing.
48        match data.dtype {
49            DType::QFloat(scheme) => {
50                let num_elements = data.num_elements();
51                let q_bytes = QuantizedBytes {
52                    bytes: data.into_bytes(),
53                    scheme,
54                    num_elements,
55                };
56
57                let (values, qparams) = q_bytes.dequantize();
58                let tensor = tch::Tensor::from_slice(&values).to(device);
59                let tensor = quantize(tensor.reshape(shape_tch.dims), &scheme, &qparams);
60
61                TchQTensor {
62                    qtensor: TchTensor::new(tensor),
63                    scheme,
64                }
65            }
66            _ => panic!(
67                "Invalid dtype (expected DType::QFloat, got {:?})",
68                data.dtype
69            ),
70        }
71    }
72
73    fn quantize(
74        tensor: FloatTensor<Self>,
75        scheme: &QuantizationScheme,
76        qparams: QuantizationParametersPrimitive<Self>,
77    ) -> QuantizedTensor<Self> {
78        let mut tensor = tensor;
79        // Quantize only works on Float Tensor
80        if E::dtype() == DType::F16 {
81            tensor.tensor = tensor.tensor.to_kind(tch::Kind::Float);
82        }
83
84        let qtensor = match scheme {
85            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
86                QuantizationType::QInt8 => tensor.tensor.quantize_per_tensor_tensor_qparams(
87                    &qparams.scale.tensor,
88                    &qparams.offset.unwrap().tensor,
89                    tch::Kind::QInt8,
90                ),
91            },
92            QuantizationScheme::PerTensorSymmetric(_) => {
93                tensor.tensor.quantize_per_tensor_tensor_qparams(
94                    &qparams.scale.tensor,
95                    &tch::Tensor::zeros_like(&qparams.scale.tensor),
96                    tch::Kind::QInt8,
97                )
98            }
99        };
100
101        TchQTensor {
102            qtensor: TchTensor::new(qtensor),
103            scheme: *scheme,
104        }
105    }
106
107    fn quantize_dynamic(
108        tensor: FloatTensor<Self>,
109        scheme: &QuantizationScheme,
110    ) -> QuantizedTensor<Self> {
111        let qtensor = match &scheme {
112            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
113                // Notes on `reduce_range`:
114                // https://github.com/pytorch/pytorch/issues/93140
115                // https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
116                QuantizationType::QInt8 => tensor
117                    .tensor
118                    .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false),
119            },
120            QuantizationScheme::PerTensorSymmetric(dtype) => {
121                log::warn!("LibTorch backend does not support symmetric per-tensor scheme for dynamic quantization, reverting to the default per-tensor affine quantization");
122                match dtype {
123                    QuantizationType::QInt8 => tensor
124                        .tensor
125                        .quantize_per_tensor_dynamic(tch::Kind::QInt8, /*reduce_range*/ false),
126                }
127            }
128        };
129
130        TchQTensor {
131            qtensor: TchTensor::new(qtensor),
132            scheme: *scheme,
133        }
134    }
135
136    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
137        TchTensor::new(tensor.qtensor.tensor.dequantize().to_kind(E::KIND))
138    }
139
140    fn q_device(tensor: &QuantizedTensor<Self>) -> LibTorchDevice {
141        tensor.qtensor.tensor.device().into()
142    }
143
144    fn q_to_device(
145        tensor: QuantizedTensor<Self>,
146        device: &burn_tensor::Device<Self>,
147    ) -> QuantizedTensor<Self> {
148        let mut tensor = tensor;
149        tensor.qtensor = TchOps::to_device(tensor.qtensor, device);
150        tensor
151    }
152
153    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
154        TchQTensor {
155            qtensor: TchOps::reshape(tensor.qtensor, shape),
156            scheme: tensor.scheme,
157        }
158    }
159
160    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
161        let shape = tensor.shape();
162        let tensor = Self::q_reshape(tensor.clone(), Shape::new([shape.num_elements()]));
163        let strategy = tensor.strategy();
164
165        // To get the integer values we have to call `int_repr()`
166        let values: Result<Vec<i8>, tch::TchError> = tensor.qtensor.tensor.int_repr().try_into();
167
168        TensorData::quantized(values.unwrap(), shape, strategy)
169    }
170
171    fn q_swap_dims(
172        tensor: QuantizedTensor<Self>,
173        dim1: usize,
174        dim2: usize,
175    ) -> QuantizedTensor<Self> {
176        // NOTE: with per-channel quantization (future), the channel axis could be impacted by this op
177        let mut tensor = tensor;
178        tensor.qtensor = TchOps::swap_dims(tensor.qtensor, dim1, dim2);
179        tensor
180    }
181
182    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
183        // NOTE: with per-channel quantization (future), the channel axis could be impacted by this op
184        let mut tensor = tensor;
185        tensor.qtensor = TchOps::permute(tensor.qtensor, axes);
186        tensor
187    }
188
189    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
190        let mut tensor = tensor;
191        tensor.qtensor = TchOps::flip(tensor.qtensor, axes);
192        tensor
193    }
194
195    fn q_select(
196        tensor: QuantizedTensor<Self>,
197        dim: usize,
198        indices: IntTensor<Self>,
199    ) -> QuantizedTensor<Self> {
200        let mut tensor = tensor;
201        tensor.qtensor = TchOps::index_select_dim(tensor.qtensor, dim, indices);
202        tensor
203    }
204
205    fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
206        let mut tensor = tensor;
207        tensor.qtensor = TchOps::slice(tensor.qtensor, ranges);
208        tensor
209    }
210
211    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
212        TchOps::argmax(TchTensor::new(tensor.qtensor.tensor.int_repr()), dim)
213    }
214
215    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
216        TchOps::argmin(TchTensor::new(tensor.qtensor.tensor.int_repr()), dim)
217    }
218
219    fn q_max_dim_with_indices(
220        tensor: QuantizedTensor<Self>,
221        dim: usize,
222    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {
223        let (qtensor, indices) = TchOps::max_dim_with_indices(tensor.qtensor, dim);
224        let values = TchQTensor {
225            qtensor,
226            scheme: tensor.scheme,
227        };
228        (values, indices)
229    }
230
231    fn q_max_dim(tensor: QuantizedTensor<Self>, dim: usize) -> QuantizedTensor<Self> {
232        TchQTensor {
233            qtensor: TchOps::max_dim(tensor.qtensor, dim),
234            scheme: tensor.scheme,
235        }
236    }
237
238    fn q_min_dim(tensor: QuantizedTensor<Self>, dim: usize) -> QuantizedTensor<Self> {
239        TchQTensor {
240            qtensor: TchOps::min_dim(tensor.qtensor, dim),
241            scheme: tensor.scheme,
242        }
243    }
244
245    fn q_min_dim_with_indices(
246        tensor: QuantizedTensor<Self>,
247        dim: usize,
248    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {
249        let (qtensor, indices) = TchOps::min_dim_with_indices(tensor.qtensor, dim);
250        let values = TchQTensor {
251            qtensor,
252            scheme: tensor.scheme,
253        };
254        (values, indices)
255    }
256
257    fn q_narrow(
258        tensor: QuantizedTensor<Self>,
259        dim: usize,
260        start: usize,
261        length: usize,
262    ) -> QuantizedTensor<Self> {
263        TchQTensor {
264            qtensor: TchOps::narrow(tensor.qtensor, dim, start, length),
265            scheme: tensor.scheme,
266        }
267    }
268
269    fn q_chunk(
270        tensor: QuantizedTensor<Self>,
271        chunks: usize,
272        dim: usize,
273    ) -> Vec<QuantizedTensor<Self>> {
274        TchOps::chunk(tensor.qtensor, chunks, dim)
275            .into_iter()
276            .map(|x| TchQTensor {
277                qtensor: x,
278                scheme: tensor.scheme,
279            })
280            .collect()
281    }
282
283    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
284        // NOTE: with per-channel quantization (future), the channel axis could be impacted by this op
285        TchQTensor {
286            qtensor: TchOps::expand(tensor.qtensor, shape),
287            scheme: tensor.scheme,
288        }
289    }
290
291    fn q_sort(
292        tensor: QuantizedTensor<Self>,
293        dim: usize,
294        descending: bool,
295    ) -> QuantizedTensor<Self> {
296        TchQTensor {
297            qtensor: TchOps::sort(tensor.qtensor, dim, descending),
298            scheme: tensor.scheme,
299        }
300    }
301
302    fn q_sort_with_indices(
303        tensor: QuantizedTensor<Self>,
304        dim: usize,
305        descending: bool,
306    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {
307        let (qtensor, indices) = TchOps::sort_with_indices(tensor.qtensor, dim, descending);
308        let tensor = TchQTensor {
309            qtensor,
310            scheme: tensor.scheme,
311        };
312        (tensor, indices)
313    }
314
315    fn q_argsort(tensor: QuantizedTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
316        TchOps::argsort(tensor.qtensor, dim, descending)
317    }
318}