// burn_dispatch/ops/qtensor.rs

1use burn_backend::{
2    ExecutionError, QTensorPrimitive, TensorData, TensorPrimitive,
3    ops::QTensorOps,
4    quantization::QuantizationParametersPrimitive,
5    tensor::{FloatTensor, IntTensor, QuantizedTensor},
6};
7use burn_std::{QuantPropagation, Shape, Slice};
8
9use crate::backends::*;
10use crate::{Dispatch, DispatchDevice};
11
12impl QTensorOps<Self> for Dispatch {
13    fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor<Self> {
14        creation_op!(Quantized, device, |device| B::q_from_data(data, device))
15    }
16
17    fn quantize(
18        tensor: FloatTensor<Self>,
19        scheme: &burn_std::QuantScheme,
20        qparams: QuantizationParametersPrimitive<Self>,
21    ) -> QuantizedTensor<Self> {
22        binary_op!(
23            (tensor, float),
24            (qparams.scales, float),
25            |tensor, scales| {
26                B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales })
27            } => Quantized
28        )
29    }
30
31    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
32        unary_op!(tensor, quantized, |tensor| B::dequantize(tensor) => Float)
33    }
34
35    fn q_device(tensor: &QuantizedTensor<Self>) -> DispatchDevice {
36        tensor.device()
37    }
38
39    fn q_to_device(
40        tensor: QuantizedTensor<Self>,
41        device: &DispatchDevice,
42    ) -> QuantizedTensor<Self> {
43        to_device!(
44            Quantized,
45            quantized,
46            tensor,
47            device,
48            q_to_device,
49            |inner, device| {
50                let data =
51                    burn_backend::read_sync(B1::q_into_data(inner)).expect("Should read data");
52                B2::q_from_data(data, device)
53            }
54        )
55    }
56
57    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
58        unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized)
59    }
60
61    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
62        unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await)
63    }
64
65    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
66        unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized)
67    }
68
69    fn q_swap_dims(
70        tensor: QuantizedTensor<Self>,
71        dim1: usize,
72        dim2: usize,
73    ) -> QuantizedTensor<Self> {
74        unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized)
75    }
76
77    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
78        unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized)
79    }
80
81    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
82        unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized)
83    }
84
85    fn q_select(
86        tensor: QuantizedTensor<Self>,
87        dim: usize,
88        indices: IntTensor<Self>,
89    ) -> QuantizedTensor<Self> {
90        binary_op!(
91            (tensor, quantized),
92            (indices, int),
93            |tensor, indices| B::q_select(tensor, dim, indices) => Quantized
94        )
95    }
96
97    fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
98        unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized)
99    }
100
101    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
102        // TODO: this would be much cleaner if we consolidated tensor primitive types
103        match (lhs, rhs) {
104            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
105                if matches!(lhs.propagation(), QuantPropagation::Propagate) {
106                    let out = binary_op!(
107                        (lhs, quantized),
108                        (rhs, quantized),
109                        |lhs, rhs| {
110                            if let TensorPrimitive::QFloat(out) = B::q_matmul(
111                                TensorPrimitive::QFloat(lhs),
112                                TensorPrimitive::QFloat(rhs),
113                            ) {
114                                out
115                            } else {
116                                unreachable!()
117                            }
118                        } => Quantized
119                    );
120                    TensorPrimitive::QFloat(out)
121                } else {
122                    let out = binary_op!(
123                        (lhs, quantized),
124                        (rhs, quantized),
125                        |lhs, rhs| {
126                            if let TensorPrimitive::Float(out) = B::q_matmul(
127                                TensorPrimitive::QFloat(lhs),
128                                TensorPrimitive::QFloat(rhs),
129                            ) {
130                                out
131                            } else {
132                                unreachable!()
133                            }
134                        } => Float
135                    );
136                    TensorPrimitive::Float(out)
137                }
138            }
139            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
140                if matches!(rhs.propagation(), QuantPropagation::Propagate) {
141                    let out = binary_op!(
142                        (lhs, float),
143                        (rhs, quantized),
144                        |lhs, rhs| {
145                            if let TensorPrimitive::QFloat(out) = B::q_matmul(
146                                TensorPrimitive::Float(lhs),
147                                TensorPrimitive::QFloat(rhs),
148                            ) {
149                                out
150                            } else {
151                                unreachable!()
152                            }
153                        } => Quantized
154                    );
155                    TensorPrimitive::QFloat(out)
156                } else {
157                    let out = binary_op!(
158                        (lhs, float),
159                        (rhs, quantized),
160                        |lhs, rhs| {
161                            if let TensorPrimitive::Float(out) = B::q_matmul(
162                                TensorPrimitive::Float(lhs),
163                                TensorPrimitive::QFloat(rhs),
164                            ) {
165                                out
166                            } else {
167                                unreachable!()
168                            }
169                        } => Float
170                    );
171                    TensorPrimitive::Float(out)
172                }
173            }
174            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
175                if matches!(lhs.propagation(), QuantPropagation::Propagate) {
176                    let out = binary_op!(
177                        (lhs, quantized),
178                        (rhs, float),
179                        |lhs, rhs| {
180                            if let TensorPrimitive::QFloat(out) = B::q_matmul(
181                                TensorPrimitive::QFloat(lhs),
182                                TensorPrimitive::Float(rhs),
183                            ) {
184                                out
185                            } else {
186                                unreachable!()
187                            }
188                        } => Quantized
189                    );
190                    TensorPrimitive::QFloat(out)
191                } else {
192                    let out = binary_op!(
193                        (lhs, quantized),
194                        (rhs, float),
195                        |lhs, rhs| {
196                            if let TensorPrimitive::Float(out) = B::q_matmul(
197                                TensorPrimitive::QFloat(lhs),
198                                TensorPrimitive::Float(rhs),
199                            ) {
200                                out
201                            } else {
202                                unreachable!()
203                            }
204                        } => Float
205                    );
206                    TensorPrimitive::Float(out)
207                }
208            }
209            _ => unreachable!(),
210        }
211    }
212}