1use burn_backend::{
2 ExecutionError, QTensorPrimitive, TensorData, TensorPrimitive,
3 ops::QTensorOps,
4 quantization::QuantizationParametersPrimitive,
5 tensor::{FloatTensor, IntTensor, QuantizedTensor},
6};
7use burn_std::{QuantPropagation, Shape, Slice};
8
9use crate::backends::*;
10use crate::{Dispatch, DispatchDevice};
11
12impl QTensorOps<Self> for Dispatch {
13 fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor<Self> {
14 creation_op!(Quantized, device, |device| B::q_from_data(data, device))
15 }
16
17 fn quantize(
18 tensor: FloatTensor<Self>,
19 scheme: &burn_std::QuantScheme,
20 qparams: QuantizationParametersPrimitive<Self>,
21 ) -> QuantizedTensor<Self> {
22 binary_op!(
23 (tensor, float),
24 (qparams.scales, float),
25 |tensor, scales| {
26 B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales })
27 } => Quantized
28 )
29 }
30
31 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
32 unary_op!(tensor, quantized, |tensor| B::dequantize(tensor) => Float)
33 }
34
35 fn q_device(tensor: &QuantizedTensor<Self>) -> DispatchDevice {
36 tensor.device()
37 }
38
39 fn q_to_device(
40 tensor: QuantizedTensor<Self>,
41 device: &DispatchDevice,
42 ) -> QuantizedTensor<Self> {
43 to_device!(
44 Quantized,
45 quantized,
46 tensor,
47 device,
48 q_to_device,
49 |inner, device| {
50 let data =
51 burn_backend::read_sync(B1::q_into_data(inner)).expect("Should read data");
52 B2::q_from_data(data, device)
53 }
54 )
55 }
56
57 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
58 unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized)
59 }
60
61 async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
62 unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await)
63 }
64
65 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
66 unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized)
67 }
68
69 fn q_swap_dims(
70 tensor: QuantizedTensor<Self>,
71 dim1: usize,
72 dim2: usize,
73 ) -> QuantizedTensor<Self> {
74 unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized)
75 }
76
77 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
78 unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized)
79 }
80
81 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
82 unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized)
83 }
84
85 fn q_select(
86 tensor: QuantizedTensor<Self>,
87 dim: usize,
88 indices: IntTensor<Self>,
89 ) -> QuantizedTensor<Self> {
90 binary_op!(
91 (tensor, quantized),
92 (indices, int),
93 |tensor, indices| B::q_select(tensor, dim, indices) => Quantized
94 )
95 }
96
97 fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
98 unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized)
99 }
100
101 fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
102 match (lhs, rhs) {
104 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
105 if matches!(lhs.propagation(), QuantPropagation::Propagate) {
106 let out = binary_op!(
107 (lhs, quantized),
108 (rhs, quantized),
109 |lhs, rhs| {
110 if let TensorPrimitive::QFloat(out) = B::q_matmul(
111 TensorPrimitive::QFloat(lhs),
112 TensorPrimitive::QFloat(rhs),
113 ) {
114 out
115 } else {
116 unreachable!()
117 }
118 } => Quantized
119 );
120 TensorPrimitive::QFloat(out)
121 } else {
122 let out = binary_op!(
123 (lhs, quantized),
124 (rhs, quantized),
125 |lhs, rhs| {
126 if let TensorPrimitive::Float(out) = B::q_matmul(
127 TensorPrimitive::QFloat(lhs),
128 TensorPrimitive::QFloat(rhs),
129 ) {
130 out
131 } else {
132 unreachable!()
133 }
134 } => Float
135 );
136 TensorPrimitive::Float(out)
137 }
138 }
139 (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
140 if matches!(rhs.propagation(), QuantPropagation::Propagate) {
141 let out = binary_op!(
142 (lhs, float),
143 (rhs, quantized),
144 |lhs, rhs| {
145 if let TensorPrimitive::QFloat(out) = B::q_matmul(
146 TensorPrimitive::Float(lhs),
147 TensorPrimitive::QFloat(rhs),
148 ) {
149 out
150 } else {
151 unreachable!()
152 }
153 } => Quantized
154 );
155 TensorPrimitive::QFloat(out)
156 } else {
157 let out = binary_op!(
158 (lhs, float),
159 (rhs, quantized),
160 |lhs, rhs| {
161 if let TensorPrimitive::Float(out) = B::q_matmul(
162 TensorPrimitive::Float(lhs),
163 TensorPrimitive::QFloat(rhs),
164 ) {
165 out
166 } else {
167 unreachable!()
168 }
169 } => Float
170 );
171 TensorPrimitive::Float(out)
172 }
173 }
174 (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
175 if matches!(lhs.propagation(), QuantPropagation::Propagate) {
176 let out = binary_op!(
177 (lhs, quantized),
178 (rhs, float),
179 |lhs, rhs| {
180 if let TensorPrimitive::QFloat(out) = B::q_matmul(
181 TensorPrimitive::QFloat(lhs),
182 TensorPrimitive::Float(rhs),
183 ) {
184 out
185 } else {
186 unreachable!()
187 }
188 } => Quantized
189 );
190 TensorPrimitive::QFloat(out)
191 } else {
192 let out = binary_op!(
193 (lhs, quantized),
194 (rhs, float),
195 |lhs, rhs| {
196 if let TensorPrimitive::Float(out) = B::q_matmul(
197 TensorPrimitive::QFloat(lhs),
198 TensorPrimitive::Float(rhs),
199 ) {
200 out
201 } else {
202 unreachable!()
203 }
204 } => Float
205 );
206 TensorPrimitive::Float(out)
207 }
208 }
209 _ => unreachable!(),
210 }
211 }
212}