// burn_dispatch/ops/qtensor.rs

1use burn_backend::{
2    ExecutionError, FloatDType, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive,
3    ops::QTensorOps,
4    quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive},
5    tensor::{FloatTensor, IntTensor, QuantizedTensor},
6};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
11impl QTensorOps<Self> for Dispatch {
12    fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor<Self> {
13        creation_op!(Quantized, device, |device| B::q_from_data(data, device))
14    }
15
16    fn quantize(
17        tensor: FloatTensor<Self>,
18        scheme: &QuantScheme,
19        qparams: QuantizationParametersPrimitive<Self>,
20    ) -> QuantizedTensor<Self> {
21        binary_op!(
22            (tensor, float),
23            (qparams.scales, float),
24            |tensor, scales| {
25                B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales })
26            } => Quantized
27        )
28    }
29
30    fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
31        unary_op!(tensor, quantized, |tensor| B::dequantize(tensor, dtype) => Float)
32    }
33
34    fn q_device(tensor: &QuantizedTensor<Self>) -> DispatchDevice {
35        tensor.device()
36    }
37
38    fn q_to_device(
39        tensor: QuantizedTensor<Self>,
40        device: &DispatchDevice,
41    ) -> QuantizedTensor<Self> {
42        to_device!(
43            Quantized,
44            quantized,
45            tensor,
46            device,
47            q_to_device,
48            |inner, device| {
49                let data =
50                    burn_backend::read_sync(B1::q_into_data(inner)).expect("Should read data");
51                B2::q_from_data(data, device)
52            }
53        )
54    }
55
56    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
57        unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized)
58    }
59
60    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
61        unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await)
62    }
63
64    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
65        unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized)
66    }
67
68    fn q_swap_dims(
69        tensor: QuantizedTensor<Self>,
70        dim1: usize,
71        dim2: usize,
72    ) -> QuantizedTensor<Self> {
73        unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized)
74    }
75
76    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
77        unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized)
78    }
79
80    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
81        unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized)
82    }
83
84    fn q_select(
85        tensor: QuantizedTensor<Self>,
86        dim: usize,
87        indices: IntTensor<Self>,
88    ) -> QuantizedTensor<Self> {
89        binary_op!(
90            (tensor, quantized),
91            (indices, int),
92            |tensor, indices| B::q_select(tensor, dim, indices) => Quantized
93        )
94    }
95
96    fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
97        unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized)
98    }
99
100    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
101        // TODO: this would be much cleaner if we consolidated tensor primitive types
102        match (lhs, rhs) {
103            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
104                if matches!(lhs.propagation(), QuantPropagation::Propagate) {
105                    let out = binary_op!(
106                        (lhs, quantized),
107                        (rhs, quantized),
108                        |lhs, rhs| {
109                            if let TensorPrimitive::QFloat(out) = B::q_matmul(
110                                TensorPrimitive::QFloat(lhs),
111                                TensorPrimitive::QFloat(rhs),
112                            ) {
113                                out
114                            } else {
115                                unreachable!()
116                            }
117                        } => Quantized
118                    );
119                    TensorPrimitive::QFloat(out)
120                } else {
121                    let out = binary_op!(
122                        (lhs, quantized),
123                        (rhs, quantized),
124                        |lhs, rhs| {
125                            if let TensorPrimitive::Float(out) = B::q_matmul(
126                                TensorPrimitive::QFloat(lhs),
127                                TensorPrimitive::QFloat(rhs),
128                            ) {
129                                out
130                            } else {
131                                unreachable!()
132                            }
133                        } => Float
134                    );
135                    TensorPrimitive::Float(out)
136                }
137            }
138            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
139                if matches!(rhs.propagation(), QuantPropagation::Propagate) {
140                    let out = binary_op!(
141                        (lhs, float),
142                        (rhs, quantized),
143                        |lhs, rhs| {
144                            if let TensorPrimitive::QFloat(out) = B::q_matmul(
145                                TensorPrimitive::Float(lhs),
146                                TensorPrimitive::QFloat(rhs),
147                            ) {
148                                out
149                            } else {
150                                unreachable!()
151                            }
152                        } => Quantized
153                    );
154                    TensorPrimitive::QFloat(out)
155                } else {
156                    let out = binary_op!(
157                        (lhs, float),
158                        (rhs, quantized),
159                        |lhs, rhs| {
160                            if let TensorPrimitive::Float(out) = B::q_matmul(
161                                TensorPrimitive::Float(lhs),
162                                TensorPrimitive::QFloat(rhs),
163                            ) {
164                                out
165                            } else {
166                                unreachable!()
167                            }
168                        } => Float
169                    );
170                    TensorPrimitive::Float(out)
171                }
172            }
173            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
174                if matches!(lhs.propagation(), QuantPropagation::Propagate) {
175                    let out = binary_op!(
176                        (lhs, quantized),
177                        (rhs, float),
178                        |lhs, rhs| {
179                            if let TensorPrimitive::QFloat(out) = B::q_matmul(
180                                TensorPrimitive::QFloat(lhs),
181                                TensorPrimitive::Float(rhs),
182                            ) {
183                                out
184                            } else {
185                                unreachable!()
186                            }
187                        } => Quantized
188                    );
189                    TensorPrimitive::QFloat(out)
190                } else {
191                    let out = binary_op!(
192                        (lhs, quantized),
193                        (rhs, float),
194                        |lhs, rhs| {
195                            if let TensorPrimitive::Float(out) = B::q_matmul(
196                                TensorPrimitive::QFloat(lhs),
197                                TensorPrimitive::Float(rhs),
198                            ) {
199                                out
200                            } else {
201                                unreachable!()
202                            }
203                        } => Float
204                    );
205                    TensorPrimitive::Float(out)
206                }
207            }
208            _ => unreachable!(),
209        }
210    }
211}