1use std::marker::PhantomData;
2
3use burn_backend::{
4 DType, Element, ExecutionError, FloatDType, QTensorPrimitive, Shape, Slice, TensorData,
5 TensorPrimitive,
6 ops::QTensorOps,
7 quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive},
8 tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
9};
10use burn_ir::{
11 BaseOperationIr, DequantizeOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, HandleContainer,
12 InitOperationIr, MatmulOpIr, OperationIr, OperationOutput, PermuteOpIr,
13 QuantizationParametersIr, QuantizeOpIr, SelectOpIr, ShapeOpIr, SliceOpIr, SwapDimsOpIr,
14};
15
16use crate::{
17 Fusion, FusionBackend,
18 client::GlobalFusionClient,
19 get_client,
20 stream::{OperationStreams, execution::Operation},
21};
22
23use super::NoOp;
24
25impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
26 fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
27 let client = get_client::<B>(device);
28 let dtype = data.dtype;
29 let tensor = B::q_from_data(data, device);
30 let shape = burn_backend::TensorMetadata::shape(&tensor);
31
32 let handle = B::quantized_tensor_handle(tensor);
33 let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
34
35 client
36 .register(
37 OperationStreams::default(),
38 OperationIr::Init(desc),
39 NoOp::<B>::new(),
40 )
41 .output()
42 }
43
44 fn quantize(
45 tensor: FloatTensor<Self>,
46 scheme: &QuantScheme,
47 qparams: QuantizationParametersPrimitive<Self>,
48 ) -> QuantizedTensor<Self> {
49 #[derive(new, Debug)]
50 struct QuantizeOp<B: FusionBackend> {
51 desc: QuantizeOpIr,
52 _b: PhantomData<B>,
53 }
54
55 impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
56 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
57 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
58 let scales = handles.get_float_tensor::<B>(&self.desc.qparams.scales);
59
60 let qparams = QuantizationParametersPrimitive { scales };
61 let output = B::quantize(tensor, &self.desc.scheme, qparams);
62 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
63 }
64 }
65
66 let streams = OperationStreams::with_inputs([&tensor, &qparams.scales]);
67
68 let client = tensor.client.clone();
69 let qparams = QuantizationParametersIr {
70 scales: qparams.scales.into_ir(),
71 };
72 let desc = QuantizeOpIr::create(tensor.into_ir(), qparams, *scheme, || {
73 client.create_empty_handle()
74 });
75
76 client
77 .register(
78 streams,
79 OperationIr::Float(desc.tensor.dtype, FloatOperationIr::Quantize(desc.clone())),
80 QuantizeOp::<B>::new(desc),
81 )
82 .output()
83 }
84
85 fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
86 #[derive(new, Debug)]
87 struct DequantizeOp<B: FusionBackend> {
88 desc: DequantizeOpIr,
89 _b: PhantomData<B>,
90 }
91
92 impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
93 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
94 let tensor = handles.get_quantized_tensor::<B>(&self.desc.input);
95
96 let output = B::dequantize(tensor, self.desc.out.dtype.into());
97 handles.register_float_tensor::<B>(&self.desc.out.id, output);
98 }
99 }
100
101 let streams = OperationStreams::with_inputs([&tensor]);
102
103 let client = tensor.client.clone();
104 let dtype = dtype.into();
105 let desc = DequantizeOpIr::create(tensor.into_ir(), dtype, || client.create_empty_handle());
106
107 client
108 .register(
109 streams,
110 OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())),
111 DequantizeOp::<B>::new(desc),
112 )
113 .output()
114 }
115
116 fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
117 tensor.client.device().clone()
118 }
119
120 fn q_to_device(
121 tensor: QuantizedTensor<Self>,
122 device_dst: &Device<Self>,
123 ) -> QuantizedTensor<Self> {
124 let device_src: &B::Device = tensor.client.device();
125 let device_dst: B::Device = device_dst.clone();
126
127 if device_src == &device_dst {
128 return tensor;
129 }
130
131 let id = tensor.stream;
132 let client_dst = get_client::<B>(&device_dst);
133 let client_src = tensor.client.clone();
134
135 GlobalFusionClient::change_client_quantized::<B>(
136 tensor.into_ir(),
137 client_src,
138 client_dst,
139 id,
140 )
141 }
142
143 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
144 if tensor.shape == shape {
145 return tensor;
146 }
147
148 #[derive(new, Debug)]
149 struct ReshapeDimsOps<B: FusionBackend> {
150 desc: ShapeOpIr,
151 _b: PhantomData<B>,
152 }
153
154 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
155 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
156 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
157 let output = B::q_reshape(input, self.desc.out.shape.clone());
158 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
159 }
160 }
161
162 let streams = OperationStreams::with_inputs([&tensor]);
163
164 let client = tensor.client.clone();
165 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
166
167 client
168 .register(
169 streams,
170 OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
171 ReshapeDimsOps::<B>::new(desc),
172 )
173 .output()
174 }
175
176 async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
177 tensor.q_into_data::<B>().await
178 }
179
180 fn q_swap_dims(
181 tensor: QuantizedTensor<Self>,
182 dim1: usize,
183 dim2: usize,
184 ) -> QuantizedTensor<Self> {
185 #[derive(new, Debug)]
186 struct SwapDimsOps<B: FusionBackend> {
187 desc: SwapDimsOpIr,
188 _b: PhantomData<B>,
189 }
190
191 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
192 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
193 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
194 let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2);
195 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
196 }
197 }
198
199 let streams = OperationStreams::with_inputs([&tensor]);
200
201 let client = tensor.client.clone();
202 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
203 client.create_empty_handle()
204 });
205
206 client
207 .register(
208 streams,
209 OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
210 SwapDimsOps::<B>::new(desc),
211 )
212 .output()
213 }
214
215 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
216 #[derive(new, Debug)]
217 struct PermuteDimsOps<B: FusionBackend> {
218 desc: PermuteOpIr,
219 _b: PhantomData<B>,
220 }
221
222 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
223 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
224 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
225 let output = B::q_permute(input, self.desc.axes.as_slice());
226 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
227 }
228 }
229
230 let streams = OperationStreams::with_inputs([&tensor]);
231
232 let client = tensor.client.clone();
233 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
234 client.create_empty_handle()
235 });
236
237 client
238 .register(
239 streams,
240 OperationIr::BaseFloat(BaseOperationIr::Permute(desc.clone())),
241 PermuteDimsOps::<B>::new(desc),
242 )
243 .output()
244 }
245
246 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
247 #[derive(new, Debug)]
248 struct FlipOps<B: FusionBackend> {
249 desc: FlipOpIr,
250 _b: PhantomData<B>,
251 }
252
253 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
254 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
255 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
256 let output = B::q_flip(input, &self.desc.axes);
257 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
258 }
259 }
260
261 let streams = OperationStreams::with_inputs([&tensor]);
262
263 let client = tensor.client.clone();
264 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
265 client.create_empty_handle()
266 });
267
268 client
269 .register(
270 streams,
271 OperationIr::BaseFloat(BaseOperationIr::Flip(desc.clone())),
272 FlipOps::<B>::new(desc),
273 )
274 .output()
275 }
276
277 fn q_gather(
278 dim: usize,
279 tensor: QuantizedTensor<Self>,
280 indices: IntTensor<Self>,
281 ) -> QuantizedTensor<Self> {
282 #[derive(new, Debug)]
283 struct GatherOps<B: FusionBackend> {
284 desc: GatherOpIr,
285 _b: PhantomData<B>,
286 }
287
288 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
289 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
290 let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
291 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
292
293 let output = B::q_gather(self.desc.dim, tensor, indices);
294 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
295 }
296 }
297
298 let streams = OperationStreams::with_inputs([&tensor]);
299
300 let client = tensor.client.clone();
301 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
302 client.create_empty_handle()
303 });
304
305 client
306 .register(
307 streams,
308 OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),
309 GatherOps::<B>::new(desc),
310 )
311 .output()
312 }
313
314 fn q_select(
315 tensor: QuantizedTensor<Self>,
316 dim: usize,
317 indices: IntTensor<Self>,
318 ) -> QuantizedTensor<Self> {
319 #[derive(new, Debug)]
320 struct SelectOps<B: FusionBackend> {
321 desc: SelectOpIr,
322 _b: PhantomData<B>,
323 }
324
325 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
326 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
327 let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
328 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
329
330 let output = B::q_select(tensor, self.desc.dim, indices);
331
332 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
333 }
334 }
335
336 let streams = OperationStreams::with_inputs([&tensor]);
337
338 let client = tensor.client.clone();
339 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
340 client.create_empty_handle()
341 });
342
343 client
344 .register(
345 streams,
346 OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),
347 SelectOps::<B>::new(desc),
348 )
349 .output()
350 }
351
352 fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
353 #[derive(new, Debug)]
354 struct SliceOps<B: FusionBackend> {
355 desc: SliceOpIr,
356 _b: PhantomData<B>,
357 }
358
359 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
360 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
361 let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
362
363 let output = B::q_slice(tensor, self.desc.ranges.as_slice());
364
365 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
366 }
367 }
368
369 let streams = OperationStreams::with_inputs([&tensor]);
370
371 let client = tensor.client.clone();
372 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
373 client.create_empty_handle()
374 });
375
376 client
377 .register(
378 streams,
379 OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
380 SliceOps::<B>::new(desc),
381 )
382 .output()
383 }
384
385 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
386 #[derive(new, Debug)]
387 struct ExpandOps<B: FusionBackend> {
388 desc: ShapeOpIr,
389 _b: PhantomData<B>,
390 }
391
392 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
393 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
394 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
395 let output = B::q_expand(input, self.desc.out.shape.clone());
396
397 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
398 }
399 }
400
401 let streams = OperationStreams::with_inputs([&tensor]);
402
403 let client = tensor.client.clone();
404 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
405
406 client
407 .register(
408 streams,
409 OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
410 ExpandOps::<B>::new(desc),
411 )
412 .output()
413 }
414
415 fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
416 #[derive(new, Debug)]
417 struct MatmulOps<B: FusionBackend> {
418 desc: MatmulOpIr,
419 lhs_quantized: bool,
420 rhs_quantized: bool,
421 _b: PhantomData<B>,
422 }
423
424 impl<B: FusionBackend> Operation<B::FusionRuntime> for MatmulOps<B> {
425 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
426 let lhs = match self.lhs_quantized {
427 true => {
428 TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.lhs))
429 }
430 false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.lhs)),
431 };
432 let rhs = match self.rhs_quantized {
433 true => {
434 TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.rhs))
435 }
436 false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.rhs)),
437 };
438 let output = B::q_matmul(lhs, rhs);
439 match output {
440 TensorPrimitive::Float(output) => {
441 handles.register_float_tensor::<B>(&self.desc.out.id, output);
442 }
443 TensorPrimitive::QFloat(output) => {
444 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
445 }
446 }
447 }
448 }
449
450 let mut propagation = QuantPropagation::Inhibit;
451 let mut scheme = QuantScheme::default();
452 let mut streams = OperationStreams::default();
453 let mut lhs_quantized = false;
454 let mut rhs_quantized = false;
455 match &lhs {
456 TensorPrimitive::QFloat(lhs) => {
457 propagation = lhs.propagation();
458 scheme = *lhs.scheme();
459 lhs_quantized = true;
460 streams.tensor(lhs);
461 }
462 TensorPrimitive::Float(lhs) => {
463 streams.tensor(lhs);
464 }
465 }
466 match &rhs {
467 TensorPrimitive::QFloat(rhs) => {
468 propagation = rhs.propagation();
469 scheme = *rhs.scheme();
470 rhs_quantized = true;
471 streams.tensor(rhs);
472 }
473 TensorPrimitive::Float(rhs) => {
474 streams.tensor(rhs);
475 }
476 }
477
478 let dtype = match propagation {
479 QuantPropagation::Propagate => DType::QFloat(scheme),
480 QuantPropagation::Inhibit => B::FloatElem::dtype(),
481 };
482
483 let client = match &lhs {
484 TensorPrimitive::Float(lhs) => lhs.client.clone(),
485 TensorPrimitive::QFloat(lhs) => lhs.client.clone(),
486 };
487
488 let lhs = match lhs {
489 TensorPrimitive::Float(lhs) => lhs.into_ir(),
490 TensorPrimitive::QFloat(lhs) => lhs.into_ir(),
491 };
492 let rhs = match rhs {
493 TensorPrimitive::Float(rhs) => rhs.into_ir(),
494 TensorPrimitive::QFloat(rhs) => rhs.into_ir(),
495 };
496
497 let desc = MatmulOpIr::create_mixed(lhs, rhs, dtype, || client.create_empty_handle());
498
499 let out = client
500 .register(
501 streams,
502 OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
503 MatmulOps::<B>::new(desc, lhs_quantized, rhs_quantized),
504 )
505 .output();
506
507 match propagation {
508 QuantPropagation::Propagate => TensorPrimitive::QFloat(out),
509 QuantPropagation::Inhibit => TensorPrimitive::Float(out),
510 }
511 }
512}