// burn_ndarray/ops/qtensor.rs

1use alloc::vec;
2use core::ops::Range;
3
4use burn_tensor::{
5    DType, Shape, TensorData, TensorMetadata,
6    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
7    quantization::{
8        QParams, QuantInputType, QuantLevel, QuantMode, QuantScheme,
9        QuantizationParametersPrimitive, QuantizationStrategy, QuantizedBytes,
10        SymmetricQuantization,
11    },
12};
13
14use crate::{
15    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, NdArrayTensorFloat,
16    element::{IntNdArrayElement, NdArrayElement, QuantElement},
17    new_tensor_float,
18};
19
20use super::{NdArrayMathOps, NdArrayOps};
21
22fn into_data<E: NdArrayElement>(tensor: NdArrayTensor<E>) -> TensorData {
23    let shape = tensor.shape();
24    let values = tensor.array.into_iter().collect();
25    TensorData::new(values, shape)
26}
27
28fn into_data_f(tensor: NdArrayTensorFloat) -> TensorData {
29    match tensor {
30        NdArrayTensorFloat::F32(tensor) => into_data(tensor),
31        NdArrayTensorFloat::F64(tensor) => into_data(tensor),
32    }
33}
34
35impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>
36    for NdArray<E, I, Q>
37{
38    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {
39        match data.dtype {
40            DType::QFloat(scheme) => {
41                let shape = data.shape.clone();
42                let num_elements = data.num_elements();
43                let q_bytes = QuantizedBytes {
44                    bytes: data.into_bytes(),
45                    scheme,
46                    num_elements,
47                };
48
49                match scheme {
50                    QuantScheme {
51                        level: QuantLevel::Tensor,
52                        mode: QuantMode::Symmetric,
53                        q_type: QuantInputType::QInt8,
54                        ..
55                    } => {
56                        // We should probably check that `Q` matches i8.. but it's the only valid type now
57                        let (values, qparams) = q_bytes.into_vec_i8();
58                        let data = TensorData::new(values, shape);
59
60                        let qparams = qparams
61                            .scale
62                            .into_iter()
63                            .map(|scale| QParams {
64                                scale,
65                                offset: None,
66                            })
67                            .collect();
68
69                        NdArrayQTensor {
70                            qtensor: NdArrayTensor::<Q>::from_data(data),
71                            scheme,
72                            qparams,
73                        }
74                    }
75                }
76            }
77            _ => panic!(
78                "Invalid dtype (expected DType::QFloat, got {:?})",
79                data.dtype
80            ),
81        }
82    }
83
84    fn quantize(
85        tensor: FloatTensor<Self>,
86        scheme: &QuantScheme,
87        qparams: QuantizationParametersPrimitive<Self>,
88    ) -> QuantizedTensor<Self> {
89        // Implement with ndarray instead of QuantizationStrategy?
90        let (strategy, qparams) = match scheme {
91            QuantScheme {
92                level: QuantLevel::Tensor,
93                mode: QuantMode::Symmetric,
94                q_type: QuantInputType::QInt8,
95                ..
96            } => {
97                let scale = into_data_f(qparams.scale).iter().next().unwrap();
98                (
99                    QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
100                        scale,
101                    )),
102                    vec![QParams {
103                        scale,
104                        offset: None,
105                    }],
106                )
107            }
108        };
109
110        let shape = tensor.shape();
111        let data = into_data_f(tensor).with_quantization(strategy);
112        let num_elements = data.num_elements();
113        let q_bytes = QuantizedBytes {
114            bytes: data.into_bytes(),
115            scheme: *scheme,
116            num_elements,
117        };
118        let (values, _) = q_bytes.into_vec_i8();
119        let data = TensorData::new(values, shape).convert::<Q>();
120
121        NdArrayQTensor {
122            qtensor: NdArrayTensor::<Q>::from_data(data),
123            scheme: *scheme,
124            qparams,
125        }
126    }
127
128    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
129        let shape = tensor.qtensor.shape();
130        let strategy = tensor.strategy();
131        let values = tensor.qtensor.array.into_iter().collect();
132        let data = TensorData::quantized(values, shape, strategy);
133        new_tensor_float!(NdArrayTensor::from_data(data.dequantize().unwrap()))
134    }
135
136    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {
137        NdArrayDevice::Cpu
138    }
139
140    fn q_to_device(
141        tensor: QuantizedTensor<Self>,
142        _device: &NdArrayDevice,
143    ) -> QuantizedTensor<Self> {
144        tensor
145    }
146
147    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
148        NdArrayQTensor {
149            qtensor: NdArrayOps::reshape(tensor.qtensor, shape),
150            scheme: tensor.scheme,
151            qparams: tensor.qparams,
152        }
153    }
154
155    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
156        let strategy = tensor.strategy();
157        let shape = tensor.qtensor.shape();
158        let values = tensor.qtensor.array.into_iter().collect();
159        TensorData::quantized(values, shape, strategy)
160    }
161
162    fn q_swap_dims(
163        tensor: QuantizedTensor<Self>,
164        dim1: usize,
165        dim2: usize,
166    ) -> QuantizedTensor<Self> {
167        NdArrayQTensor {
168            qtensor: NdArrayOps::swap_dims(tensor.qtensor, dim1, dim2),
169            scheme: tensor.scheme,
170            qparams: tensor.qparams,
171        }
172    }
173
174    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
175        NdArrayQTensor {
176            qtensor: NdArrayOps::permute(tensor.qtensor, axes),
177            scheme: tensor.scheme,
178            qparams: tensor.qparams,
179        }
180    }
181
182    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
183        NdArrayQTensor {
184            qtensor: NdArrayOps::flip(tensor.qtensor, axes),
185            scheme: tensor.scheme,
186            qparams: tensor.qparams,
187        }
188    }
189
190    fn q_gather(
191        dim: usize,
192        tensor: QuantizedTensor<Self>,
193        indices: IntTensor<Self>,
194    ) -> QuantizedTensor<Self> {
195        NdArrayQTensor {
196            qtensor: NdArrayMathOps::gather(dim, tensor.qtensor, indices),
197            scheme: tensor.scheme,
198            qparams: tensor.qparams,
199        }
200    }
201
202    fn q_select(
203        tensor: QuantizedTensor<Self>,
204        dim: usize,
205        indices: IntTensor<Self>,
206    ) -> QuantizedTensor<Self> {
207        NdArrayQTensor {
208            qtensor: NdArrayMathOps::select(tensor.qtensor, dim, indices),
209            scheme: tensor.scheme,
210            qparams: tensor.qparams,
211        }
212    }
213
214    fn q_slice(tensor: QuantizedTensor<Self>, ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
215        NdArrayQTensor {
216            qtensor: NdArrayOps::slice(tensor.qtensor, ranges),
217            scheme: tensor.scheme,
218            qparams: tensor.qparams,
219        }
220    }
221
222    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
223        NdArrayMathOps::argmax(tensor.qtensor, dim)
224    }
225
226    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {
227        NdArrayMathOps::argmin(tensor.qtensor, dim)
228    }
229
230    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
231        NdArrayQTensor {
232            qtensor: NdArrayOps::expand(tensor.qtensor, shape),
233            scheme: tensor.scheme,
234            qparams: tensor.qparams,
235        }
236    }
237}