1use burn_backend::{
2 ExecutionError, FloatDType, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive,
3 ops::QTensorOps,
4 quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive},
5 tensor::{FloatTensor, IntTensor, QuantizedTensor},
6};
7
8use crate::backends::*;
9use crate::{Dispatch, DispatchDevice};
10
11impl QTensorOps<Self> for Dispatch {
12 fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor<Self> {
13 creation_op!(Quantized, device, |device| B::q_from_data(data, device))
14 }
15
16 fn quantize(
17 tensor: FloatTensor<Self>,
18 scheme: &QuantScheme,
19 qparams: QuantizationParametersPrimitive<Self>,
20 ) -> QuantizedTensor<Self> {
21 binary_op!(
22 (tensor, float),
23 (qparams.scales, float),
24 |tensor, scales| {
25 B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales })
26 } => Quantized
27 )
28 }
29
30 fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
31 unary_op!(tensor, quantized, |tensor| B::dequantize(tensor, dtype) => Float)
32 }
33
34 fn q_device(tensor: &QuantizedTensor<Self>) -> DispatchDevice {
35 tensor.device()
36 }
37
38 fn q_to_device(
39 tensor: QuantizedTensor<Self>,
40 device: &DispatchDevice,
41 ) -> QuantizedTensor<Self> {
42 to_device!(
43 Quantized,
44 quantized,
45 tensor,
46 device,
47 q_to_device,
48 |inner, device| {
49 let data =
50 burn_backend::read_sync(B1::q_into_data(inner)).expect("Should read data");
51 B2::q_from_data(data, device)
52 }
53 )
54 }
55
56 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
57 unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized)
58 }
59
60 async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
61 unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await)
62 }
63
64 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
65 unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized)
66 }
67
68 fn q_swap_dims(
69 tensor: QuantizedTensor<Self>,
70 dim1: usize,
71 dim2: usize,
72 ) -> QuantizedTensor<Self> {
73 unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized)
74 }
75
76 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
77 unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized)
78 }
79
80 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
81 unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized)
82 }
83
84 fn q_select(
85 tensor: QuantizedTensor<Self>,
86 dim: usize,
87 indices: IntTensor<Self>,
88 ) -> QuantizedTensor<Self> {
89 binary_op!(
90 (tensor, quantized),
91 (indices, int),
92 |tensor, indices| B::q_select(tensor, dim, indices) => Quantized
93 )
94 }
95
96 fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
97 unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized)
98 }
99
100 fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
101 match (lhs, rhs) {
103 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
104 if matches!(lhs.propagation(), QuantPropagation::Propagate) {
105 let out = binary_op!(
106 (lhs, quantized),
107 (rhs, quantized),
108 |lhs, rhs| {
109 if let TensorPrimitive::QFloat(out) = B::q_matmul(
110 TensorPrimitive::QFloat(lhs),
111 TensorPrimitive::QFloat(rhs),
112 ) {
113 out
114 } else {
115 unreachable!()
116 }
117 } => Quantized
118 );
119 TensorPrimitive::QFloat(out)
120 } else {
121 let out = binary_op!(
122 (lhs, quantized),
123 (rhs, quantized),
124 |lhs, rhs| {
125 if let TensorPrimitive::Float(out) = B::q_matmul(
126 TensorPrimitive::QFloat(lhs),
127 TensorPrimitive::QFloat(rhs),
128 ) {
129 out
130 } else {
131 unreachable!()
132 }
133 } => Float
134 );
135 TensorPrimitive::Float(out)
136 }
137 }
138 (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
139 if matches!(rhs.propagation(), QuantPropagation::Propagate) {
140 let out = binary_op!(
141 (lhs, float),
142 (rhs, quantized),
143 |lhs, rhs| {
144 if let TensorPrimitive::QFloat(out) = B::q_matmul(
145 TensorPrimitive::Float(lhs),
146 TensorPrimitive::QFloat(rhs),
147 ) {
148 out
149 } else {
150 unreachable!()
151 }
152 } => Quantized
153 );
154 TensorPrimitive::QFloat(out)
155 } else {
156 let out = binary_op!(
157 (lhs, float),
158 (rhs, quantized),
159 |lhs, rhs| {
160 if let TensorPrimitive::Float(out) = B::q_matmul(
161 TensorPrimitive::Float(lhs),
162 TensorPrimitive::QFloat(rhs),
163 ) {
164 out
165 } else {
166 unreachable!()
167 }
168 } => Float
169 );
170 TensorPrimitive::Float(out)
171 }
172 }
173 (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {
174 if matches!(lhs.propagation(), QuantPropagation::Propagate) {
175 let out = binary_op!(
176 (lhs, quantized),
177 (rhs, float),
178 |lhs, rhs| {
179 if let TensorPrimitive::QFloat(out) = B::q_matmul(
180 TensorPrimitive::QFloat(lhs),
181 TensorPrimitive::Float(rhs),
182 ) {
183 out
184 } else {
185 unreachable!()
186 }
187 } => Quantized
188 );
189 TensorPrimitive::QFloat(out)
190 } else {
191 let out = binary_op!(
192 (lhs, quantized),
193 (rhs, float),
194 |lhs, rhs| {
195 if let TensorPrimitive::Float(out) = B::q_matmul(
196 TensorPrimitive::QFloat(lhs),
197 TensorPrimitive::Float(rhs),
198 ) {
199 out
200 } else {
201 unreachable!()
202 }
203 } => Float
204 );
205 TensorPrimitive::Float(out)
206 }
207 }
208 _ => unreachable!(),
209 }
210 }
211}