1use std::marker::PhantomData;
2
3use burn_ir::{
4 BaseOperationIr, BinaryOpIr, DequantizeOpIr, ExpandOpIr, FlipOpIr, FloatOperationIr,
5 GatherOpIr, HandleContainer, InitOperationIr, NumericOperationIr, OperationIr, PermuteOpIr,
6 QuantizationParametersIr, QuantizeOpIr, SelectOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr,
7};
8use burn_tensor::{
9 DType, Device, Element, Shape, Slice, TensorData, TensorMetadata, TensorPrimitive,
10 ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
11 quantization::{
12 QTensorPrimitive, QuantPropagation, QuantScheme, QuantizationParametersPrimitive,
13 },
14};
15
16use crate::{
17 Fusion, FusionBackend, get_client,
18 stream::{OperationStreams, StreamId, execution::Operation},
19};
20
21use super::NoOp;
22
23impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
24 fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
25 let stream = StreamId::current();
26 let client = get_client::<B>(&device.clone());
27 let dtype = data.dtype;
28 let tensor = B::q_from_data(data, device);
29 let shape = tensor.shape();
30
31 let handle = B::quantized_tensor_handle(tensor);
32 let out = client.register_tensor(handle, shape, stream, dtype);
33 let desc = out.to_ir_out();
34
35 client.register(
36 OperationStreams::default(),
37 OperationIr::Init(InitOperationIr { out: desc }),
38 NoOp::<B>::new(),
39 );
40
41 out
42 }
43
44 fn quantize(
45 tensor: FloatTensor<Self>,
46 scheme: &QuantScheme,
47 qparams: QuantizationParametersPrimitive<Self>,
48 ) -> QuantizedTensor<Self> {
49 #[derive(new, Debug)]
50 struct QuantizeOp<B: FusionBackend> {
51 desc: QuantizeOpIr,
52 _b: PhantomData<B>,
53 }
54
55 impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
56 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
57 let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
58 let scales = handles.get_float_tensor::<B>(&self.desc.qparams.scales);
59
60 let qparams = QuantizationParametersPrimitive { scales };
61 let output = B::quantize(tensor, &self.desc.scheme, qparams);
62 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
63 }
64 }
65
66 let shape = tensor.shape.clone();
67 let dtype = tensor.dtype;
68 let out = tensor
69 .client
70 .tensor_uninitialized(shape, DType::QFloat(*scheme));
71
72 let mut streams = OperationStreams::default();
73 streams.tensor(&tensor);
74 streams.tensor(&qparams.scales);
75
76 let desc = QuantizeOpIr {
77 tensor: tensor.into_ir(),
78 qparams: QuantizationParametersIr {
79 scales: qparams.scales.clone().into_ir(),
80 },
81 scheme: *scheme,
82 out: out.to_ir_out(),
83 };
84
85 out.client.register(
86 streams,
87 OperationIr::Float(dtype, FloatOperationIr::Quantize(desc.clone())),
88 QuantizeOp::<B>::new(desc),
89 );
90
91 out
92 }
93
94 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
95 #[derive(new, Debug)]
96 struct DequantizeOp<B: FusionBackend> {
97 desc: DequantizeOpIr,
98 _b: PhantomData<B>,
99 }
100
101 impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
102 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
103 let tensor = handles.get_quantized_tensor::<B>(&self.desc.input);
104
105 let output = B::dequantize(tensor);
106 handles.register_float_tensor::<B>(&self.desc.out.id, output);
107 }
108 }
109
110 let mut streams = OperationStreams::default();
111 streams.tensor(&tensor);
112
113 let shape = tensor.shape.clone();
114 let dtype = B::FloatElem::dtype();
115 let out = tensor.client.tensor_uninitialized(shape, dtype);
116
117 let desc = DequantizeOpIr {
118 input: tensor.into_ir(),
119 out: out.to_ir_out(),
120 };
121
122 out.client.register(
123 streams,
124 OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())),
125 DequantizeOp::<B>::new(desc),
126 );
127
128 out
129 }
130
131 fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
132 tensor.client.device().clone()
133 }
134
135 fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
136 let device_original: &B::Device = tensor.client.device();
137 let device_target: B::Device = device.clone();
138
139 if device_original == &device_target {
140 return tensor;
141 }
142
143 let id = tensor.stream;
144 let client_target = get_client::<B>(&device_target);
145 let client_original = tensor.client.clone();
146
147 client_original.change_client_quantized::<B>(tensor.into_ir(), client_target, id)
148 }
149
150 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
151 if tensor.shape == shape {
152 return tensor;
153 }
154
155 #[derive(new, Debug)]
156 struct ReshapeDimsOps<B: FusionBackend> {
157 desc: UnaryOpIr,
158 _b: PhantomData<B>,
159 }
160
161 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
162 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
163 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
164 let output = B::q_reshape(input, self.desc.out.shape.clone());
165 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
166 }
167 }
168
169 let mut streams = OperationStreams::default();
170 streams.tensor(&tensor);
171
172 let dtype = tensor.dtype;
173 let out = tensor.client.tensor_uninitialized(shape, dtype);
174
175 let desc = UnaryOpIr {
176 input: tensor.into_ir(),
177 out: out.to_ir_out(),
178 };
179 out.client.register(
180 streams,
181 OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
182 ReshapeDimsOps::<B>::new(desc),
183 );
184
185 out
186 }
187
188 async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
189 tensor.q_into_data::<B>().await
190 }
191
192 fn q_swap_dims(
193 tensor: QuantizedTensor<Self>,
194 dim1: usize,
195 dim2: usize,
196 ) -> QuantizedTensor<Self> {
197 #[derive(new, Debug)]
198 struct SwapDimsOps<B: FusionBackend> {
199 desc: SwapDimsOpIr,
200 _b: PhantomData<B>,
201 }
202
203 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
204 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
205 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
206 let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2);
207 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
208 }
209 }
210
211 let mut streams = OperationStreams::default();
212 streams.tensor(&tensor);
213
214 let dtype = tensor.dtype;
215 let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
216
217 let mut out = tensor.client.tensor_uninitialized(shape, dtype);
218
219 let desc = SwapDimsOpIr {
220 input: tensor.into_ir(),
221 dim1,
222 dim2,
223 out: out.to_ir_out(),
224 };
225 out.client.register(
226 streams,
227 OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
228 SwapDimsOps::<B>::new(desc),
229 );
230 out.stream = StreamId::current();
231
232 out
233 }
234
235 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
236 #[derive(new, Debug)]
237 struct PermuteDimsOps<B: FusionBackend> {
238 desc: PermuteOpIr,
239 _b: PhantomData<B>,
240 }
241
242 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
243 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
244 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
245 let output = B::q_permute(input, self.desc.axes.as_slice());
246 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
247 }
248 }
249
250 let mut streams = OperationStreams::default();
251 streams.tensor(&tensor);
252
253 let shape = tensor.shape.clone().permute(axes).unwrap();
255
256 let out = tensor.client.tensor_uninitialized(shape, tensor.dtype);
257
258 let desc = PermuteOpIr {
259 input: tensor.into_ir(),
260 axes: axes.to_vec(),
261 out: out.to_ir_out(),
262 };
263
264 out.client.register(
265 streams,
266 OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
267 PermuteDimsOps::<B>::new(desc),
268 );
269
270 out
271 }
272
273 fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
274 #[derive(new, Debug)]
275 struct FlipOps<B: FusionBackend> {
276 desc: FlipOpIr,
277 _b: PhantomData<B>,
278 }
279
280 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
281 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
282 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
283 let output = B::q_flip(input, &self.desc.axes);
284 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
285 }
286 }
287
288 let mut streams = OperationStreams::default();
289 streams.tensor(&tensor);
290 let out = tensor
291 .client
292 .tensor_uninitialized(tensor.shape.clone(), tensor.dtype);
293
294 let desc = FlipOpIr {
295 input: tensor.into_ir(),
296 axes: axes.to_vec(),
297 out: out.to_ir_out(),
298 };
299
300 out.client.register(
301 streams,
302 OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
303 FlipOps::<B>::new(desc),
304 );
305
306 out
307 }
308
309 fn q_gather(
310 dim: usize,
311 tensor: QuantizedTensor<Self>,
312 indices: IntTensor<Self>,
313 ) -> QuantizedTensor<Self> {
314 #[derive(new, Debug)]
315 struct GatherOps<B: FusionBackend> {
316 desc: GatherOpIr,
317 _b: PhantomData<B>,
318 }
319
320 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
321 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
322 let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
323 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
324
325 let output = B::q_gather(self.desc.dim, tensor, indices);
326 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
327 }
328 }
329
330 let mut streams = OperationStreams::default();
331 streams.tensor(&tensor);
332 streams.tensor(&indices);
333
334 let dtype = tensor.dtype;
335 let shape = indices.shape.clone();
336 let out = tensor.client.tensor_uninitialized(shape, dtype);
337
338 let desc = GatherOpIr {
339 tensor: tensor.into_ir(),
340 dim,
341 indices: indices.into_ir(),
342 out: out.to_ir_out(),
343 };
344 out.client.register(
345 streams,
346 OperationIr::NumericFloat(dtype, NumericOperationIr::Gather(desc.clone())),
347 GatherOps::<B>::new(desc),
348 );
349
350 out
351 }
352
353 fn q_select(
354 tensor: QuantizedTensor<Self>,
355 dim: usize,
356 indices: IntTensor<Self>,
357 ) -> QuantizedTensor<Self> {
358 #[derive(new, Debug)]
359 struct SelectOps<B: FusionBackend> {
360 desc: SelectOpIr,
361 _b: PhantomData<B>,
362 }
363
364 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
365 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
366 let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
367 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
368
369 let output = B::q_select(tensor, self.desc.dim, indices);
370
371 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
372 }
373 }
374
375 let mut streams = OperationStreams::default();
376 streams.tensor(&tensor);
377 streams.tensor(&indices);
378
379 let dtype = tensor.dtype;
380 let mut shape = tensor.shape.clone();
381 shape[dim] = indices.shape[0];
382 let out = tensor.client.tensor_uninitialized(shape, dtype);
383 let desc = SelectOpIr {
384 tensor: tensor.into_ir(),
385 dim,
386 indices: indices.into_ir(),
387 out: out.to_ir_out(),
388 };
389 out.client.register(
390 streams,
391 OperationIr::NumericFloat(dtype, NumericOperationIr::Select(desc.clone())),
392 SelectOps::<B>::new(desc),
393 );
394
395 out
396 }
397
398 fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
399 #[derive(new, Debug)]
400 struct SliceOps<B: FusionBackend> {
401 desc: SliceOpIr,
402 _b: PhantomData<B>,
403 }
404
405 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
406 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
407 let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
408
409 let output = B::q_slice(tensor, self.desc.ranges.as_slice());
410
411 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
412 }
413 }
414 let mut streams = OperationStreams::default();
415 streams.tensor(&tensor);
416 let dtype = tensor.dtype;
417 let shape = tensor.shape.clone().slice(slices).unwrap();
418
419 let out = tensor.client.tensor_uninitialized(shape, dtype);
420
421 let desc = SliceOpIr {
422 tensor: tensor.into_ir(),
423 ranges: slices.into(),
424 out: out.to_ir_out(),
425 };
426 out.client.register(
427 streams,
428 OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
429 SliceOps::<B>::new(desc),
430 );
431
432 out
433 }
434
435 fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
436 #[derive(new, Debug)]
437 struct ExpandOps<B: FusionBackend> {
438 desc: ExpandOpIr,
439 _b: PhantomData<B>,
440 }
441
442 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
443 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
444 let input = handles.get_quantized_tensor::<B>(&self.desc.input);
445 let output = B::q_expand(input, self.desc.shape.clone());
446
447 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
448 }
449 }
450
451 let mut streams = OperationStreams::default();
452 streams.tensor(&tensor);
453
454 let out = tensor
455 .client
456 .tensor_uninitialized(shape.clone(), tensor.dtype);
457
458 let desc = ExpandOpIr {
459 input: tensor.into_ir(),
460 shape,
461 out: out.to_ir_out(),
462 };
463
464 out.client.register(
465 streams,
466 OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
467 ExpandOps::<B>::new(desc),
468 );
469
470 out
471 }
472
473 fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
474 #[derive(new, Debug)]
475 struct MatmulOps<B: FusionBackend> {
476 desc: BinaryOpIr,
477 lhs_quantized: bool,
478 rhs_quantized: bool,
479 _b: PhantomData<B>,
480 }
481
482 impl<B: FusionBackend> Operation<B::FusionRuntime> for MatmulOps<B> {
483 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
484 let lhs = match self.lhs_quantized {
485 true => {
486 TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.lhs))
487 }
488 false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.lhs)),
489 };
490 let rhs = match self.rhs_quantized {
491 true => {
492 TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.rhs))
493 }
494 false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.rhs)),
495 };
496 let output = B::q_matmul(lhs, rhs);
497 match output {
498 TensorPrimitive::Float(output) => {
499 handles.register_float_tensor::<B>(&self.desc.out.id, output);
500 }
501 TensorPrimitive::QFloat(output) => {
502 handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
503 }
504 }
505 }
506 }
507
508 let mut propagation = QuantPropagation::Inhibit;
509 let mut scheme = QuantScheme::default();
510 let mut streams = OperationStreams::default();
511 let mut lhs_quantized = false;
512 let mut rhs_quantized = false;
513 match &lhs {
514 TensorPrimitive::QFloat(lhs) => {
515 propagation = lhs.propagation();
516 scheme = *lhs.scheme();
517 lhs_quantized = true;
518 streams.tensor(lhs);
519 }
520 TensorPrimitive::Float(lhs) => {
521 streams.tensor(lhs);
522 }
523 }
524 match &rhs {
525 TensorPrimitive::QFloat(rhs) => {
526 propagation = rhs.propagation();
527 scheme = *rhs.scheme();
528 rhs_quantized = true;
529 streams.tensor(rhs);
530 }
531 TensorPrimitive::Float(rhs) => {
532 streams.tensor(rhs);
533 }
534 }
535
536 let dtype = match propagation {
537 QuantPropagation::Propagate => DType::QFloat(scheme),
538 QuantPropagation::Inhibit => B::FloatElem::dtype(),
539 };
540 let shape = Shape::matmul(&lhs.shape(), &rhs.shape()).unwrap();
541
542 let client = match &lhs {
543 TensorPrimitive::Float(lhs) => lhs.client.clone(),
544 TensorPrimitive::QFloat(lhs) => lhs.client.clone(),
545 };
546
547 let lhs = match lhs {
548 TensorPrimitive::Float(lhs) => lhs.into_ir(),
549 TensorPrimitive::QFloat(lhs) => lhs.into_ir(),
550 };
551 let rhs = match rhs {
552 TensorPrimitive::Float(rhs) => rhs.into_ir(),
553 TensorPrimitive::QFloat(rhs) => rhs.into_ir(),
554 };
555
556 let out = client.tensor_uninitialized(shape, dtype);
557 let desc = BinaryOpIr {
558 lhs,
559 rhs,
560 out: out.to_ir_out(),
561 };
562
563 out.client.register(
564 streams,
565 OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
566 MatmulOps::<B>::new(desc, lhs_quantized, rhs_quantized),
567 );
568
569 match propagation {
570 QuantPropagation::Propagate => TensorPrimitive::QFloat(out),
571 QuantPropagation::Inhibit => TensorPrimitive::Float(out),
572 }
573 }
574}