burn_ndarray/ops/qtensor.rs
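//! Quantized tensor ops (`QTensorOps`) for the ndarray backend. A quantized
//! tensor is kept as an `NdArrayQTensor`: the integer values (element type `Q`)
//! together with the quantization scheme and its per-tensor parameters.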

use alloc::vec;
use core::ops::Range;

use burn_tensor::{
    DType, Shape, TensorData, TensorMetadata,
    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
    quantization::{
        QParams, QuantizationMode, QuantizationParametersPrimitive, QuantizationScheme,
        QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization,
    },
};

use crate::{
    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, NdArrayTensorFloat,
    element::{IntNdArrayElement, NdArrayElement, QuantElement},
    new_tensor_float,
};

use super::{NdArrayMathOps, NdArrayOps};

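/// Collects the elements of an `NdArrayTensor` into a `TensorData` with the same shape.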
fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
    let shape = tensor.shape();
    let values = tensor.array.into_iter().collect();
    TensorData::new(values, shape)
}

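/// Dispatches on the float precision (`F32`/`F64`) and converts the tensor into `TensorData`.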
fn into_data_f(tensor: NdArrayTensorFloat) -> TensorData {
    match tensor {
        NdArrayTensorFloat::F32(tensor) => into_data(tensor),
        NdArrayTensorFloat::F64(tensor) => into_data(tensor),
    }
}

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
    for NdArray<E, I, Q>
{
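    /// Rebuilds a quantized tensor from `TensorData` carrying a `QFloat` dtype:
    /// the packed bytes are unpacked into `i8` values plus the per-tensor
    /// quantization parameters. Panics if the data has any other dtype.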
    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => {
                let shape = data.shape.clone();
                let num_elements = data.num_elements();
                let q_bytes = QuantizedBytes {
                    bytes: data.into_bytes(),
                    scheme,
                    num_elements,
                };

                match scheme {
                    QuantizationScheme::PerTensor(mode, QuantizationType::QInt8) => {
                        // We should probably check that `Q` matches i8.. but it's the only valid type now
                        let (values, qparams) = q_bytes.into_vec_i8();
                        let data = TensorData::new(values, shape);

                        let qparams = match mode {
                            QuantizationMode::Symmetric => qparams
                                .scale
                                .into_iter()
                                .map(|scale| QParams {
                                    scale,
                                    offset: None,
                                })
                                .collect(),
                        };

                        NdArrayQTensor {
                            qtensor: NdArrayTensor::<Q>::from_data(data),
                            scheme,
                            qparams,
                        }
                    }
                }
            }
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

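    /// Quantizes a float tensor with per-tensor symmetric `i8` quantization,
    /// using the provided scale; the values are stored as the backend's
    /// quantized element type `Q`.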
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantizationScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        // Implement with ndarray instead of QuantizationStrategy?
        let (strategy, qparams) = match scheme {
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
                let scale = into_data_f(qparams.scale).iter().next().unwrap();
                (
                    QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
                        scale,
                    )),
                    vec![QParams {
                        scale,
                        offset: None,
                    }],
                )
            }
        };

        let shape = tensor.shape();
        let data = into_data_f(tensor).with_quantization(strategy);
        let num_elements = data.num_elements();
        let q_bytes = QuantizedBytes {
            bytes: data.into_bytes(),
            scheme: *scheme,
            num_elements,
        };
        let (values, _) = q_bytes.into_vec_i8();
        let data = TensorData::new(values, shape).convert::<Q>();

        NdArrayQTensor {
            qtensor: NdArrayTensor::<Q>::from_data(data),
            scheme: *scheme,
            qparams,
        }
    }

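    /// Packs the stored values back into quantized `TensorData` and dequantizes
    /// them into a float tensor.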
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        let shape = tensor.qtensor.shape();
        let strategy = tensor.strategy();
        let values = tensor.qtensor.array.into_iter().collect();
        let data = TensorData::quantized(values, shape, strategy);
        new_tensor_float!(NdArrayTensor::from_data(data.dequantize().unwrap()))
    }

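    // The ndarray backend runs on the CPU only, so device queries and transfers
    // are no-ops.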
    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn q_to_device(
        tensor: QuantizedTensor<Self>,
        _device: &NdArrayDevice,
    ) -> QuantizedTensor<Self> {
        tensor
    }

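    // The shape and indexing ops (reshape, swap_dims, permute, flip, gather,
    // select, slice, expand) forward to the regular ndarray ops on the quantized
    // values and carry the scheme and qparams through unchanged.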
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

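    /// Returns the quantized values as `TensorData` tagged with the tensor's
    /// quantization strategy.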
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        let strategy = tensor.strategy();
        let shape = tensor.qtensor.shape();
        let values = tensor.qtensor.array.into_iter().collect();
        TensorData::quantized(values, shape, strategy)
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayOps::swap_dims(tensor.qtensor, dim1, dim2),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayOps::permute(tensor.qtensor, axes),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayOps::flip(tensor.qtensor, axes),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_gather(
        dim: usize,
        tensor: QuantizedTensor<Self>,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayMathOps::gather(dim, tensor.qtensor, indices),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_select(
        tensor: QuantizedTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayMathOps::select(tensor.qtensor, dim, indices),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

    fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayOps::slice(tensor.qtensor, ranges),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }

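    // Arg reductions run on the quantized values directly; with a positive
    // per-tensor symmetric scale the ordering matches the dequantized values.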
    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        NdArrayMathOps::argmax(tensor.qtensor, dim)
    }

    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
        NdArrayMathOps::argmin(tensor.qtensor, dim)
    }

    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        NdArrayQTensor {
            qtensor: NdArrayOps::expand(tensor.qtensor, shape),
            scheme: tensor.scheme,
            qparams: tensor.qparams,
        }
    }
}