burn_fusion/ops/
qtensor.rs

1use std::marker::PhantomData;
2
3use burn_ir::{
4    BaseOperationIr, BinaryOpIr, DequantizeOpIr, ExpandOpIr, FlipOpIr, FloatOperationIr,
5    GatherOpIr, HandleContainer, InitOperationIr, NumericOperationIr, OperationIr, PermuteOpIr,
6    QuantizationParametersIr, QuantizeOpIr, SelectOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr,
7};
8use burn_tensor::{
9    DType, Device, Element, Shape, Slice, TensorData, TensorMetadata, TensorPrimitive,
10    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
11    quantization::{
12        QTensorPrimitive, QuantPropagation, QuantScheme, QuantizationParametersPrimitive,
13    },
14};
15
16use crate::{
17    Fusion, FusionBackend, get_client,
18    stream::{OperationStreams, StreamId, execution::Operation},
19};
20
21use super::NoOp;
22
23impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
24    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
25        let stream = StreamId::current();
26        let client = get_client::<B>(&device.clone());
27        let dtype = data.dtype;
28        let tensor = B::q_from_data(data, device);
29        let shape = tensor.shape();
30
31        let handle = B::quantized_tensor_handle(tensor);
32        let out = client.register_tensor(handle, shape, stream, dtype);
33        let desc = out.to_ir_out();
34
35        client.register(
36            OperationStreams::default(),
37            OperationIr::Init(InitOperationIr { out: desc }),
38            NoOp::<B>::new(),
39        );
40
41        out
42    }
43
44    fn quantize(
45        tensor: FloatTensor<Self>,
46        scheme: &QuantScheme,
47        qparams: QuantizationParametersPrimitive<Self>,
48    ) -> QuantizedTensor<Self> {
49        #[derive(new, Debug)]
50        struct QuantizeOp<B: FusionBackend> {
51            desc: QuantizeOpIr,
52            _b: PhantomData<B>,
53        }
54
55        impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
56            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
57                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
58                let scales = handles.get_float_tensor::<B>(&self.desc.qparams.scales);
59
60                let qparams = QuantizationParametersPrimitive { scales };
61                let output = B::quantize(tensor, &self.desc.scheme, qparams);
62                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
63            }
64        }
65
66        let shape = tensor.shape.clone();
67        let dtype = tensor.dtype;
68        let out = tensor
69            .client
70            .tensor_uninitialized(shape, DType::QFloat(*scheme));
71
72        let mut streams = OperationStreams::default();
73        streams.tensor(&tensor);
74        streams.tensor(&qparams.scales);
75
76        let desc = QuantizeOpIr {
77            tensor: tensor.into_ir(),
78            qparams: QuantizationParametersIr {
79                scales: qparams.scales.clone().into_ir(),
80            },
81            scheme: *scheme,
82            out: out.to_ir_out(),
83        };
84
85        out.client.register(
86            streams,
87            OperationIr::Float(dtype, FloatOperationIr::Quantize(desc.clone())),
88            QuantizeOp::<B>::new(desc),
89        );
90
91        out
92    }
93
94    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
95        #[derive(new, Debug)]
96        struct DequantizeOp<B: FusionBackend> {
97            desc: DequantizeOpIr,
98            _b: PhantomData<B>,
99        }
100
101        impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
102            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
103                let tensor = handles.get_quantized_tensor::<B>(&self.desc.input);
104
105                let output = B::dequantize(tensor);
106                handles.register_float_tensor::<B>(&self.desc.out.id, output);
107            }
108        }
109
110        let mut streams = OperationStreams::default();
111        streams.tensor(&tensor);
112
113        let shape = tensor.shape.clone();
114        let dtype = B::FloatElem::dtype();
115        let out = tensor.client.tensor_uninitialized(shape, dtype);
116
117        let desc = DequantizeOpIr {
118            input: tensor.into_ir(),
119            out: out.to_ir_out(),
120        };
121
122        out.client.register(
123            streams,
124            OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())),
125            DequantizeOp::<B>::new(desc),
126        );
127
128        out
129    }
130
131    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
132        tensor.client.device().clone()
133    }
134
135    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
136        let device_original: &B::Device = tensor.client.device();
137        let device_target: B::Device = device.clone();
138
139        if device_original == &device_target {
140            return tensor;
141        }
142
143        let id = tensor.stream;
144        let client_target = get_client::<B>(&device_target);
145        let client_original = tensor.client.clone();
146
147        client_original.change_client_quantized::<B>(tensor.into_ir(), client_target, id)
148    }
149
150    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
151        if tensor.shape == shape {
152            return tensor;
153        }
154
155        #[derive(new, Debug)]
156        struct ReshapeDimsOps<B: FusionBackend> {
157            desc: UnaryOpIr,
158            _b: PhantomData<B>,
159        }
160
161        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
162            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
163                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
164                let output = B::q_reshape(input, self.desc.out.shape.clone());
165                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
166            }
167        }
168
169        let mut streams = OperationStreams::default();
170        streams.tensor(&tensor);
171
172        let dtype = tensor.dtype;
173        let out = tensor.client.tensor_uninitialized(shape, dtype);
174
175        let desc = UnaryOpIr {
176            input: tensor.into_ir(),
177            out: out.to_ir_out(),
178        };
179        out.client.register(
180            streams,
181            OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
182            ReshapeDimsOps::<B>::new(desc),
183        );
184
185        out
186    }
187
188    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
189        tensor.q_into_data::<B>().await
190    }
191
192    fn q_swap_dims(
193        tensor: QuantizedTensor<Self>,
194        dim1: usize,
195        dim2: usize,
196    ) -> QuantizedTensor<Self> {
197        #[derive(new, Debug)]
198        struct SwapDimsOps<B: FusionBackend> {
199            desc: SwapDimsOpIr,
200            _b: PhantomData<B>,
201        }
202
203        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
204            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
205                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
206                let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2);
207                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
208            }
209        }
210
211        let mut streams = OperationStreams::default();
212        streams.tensor(&tensor);
213
214        let dtype = tensor.dtype;
215        let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
216
217        let mut out = tensor.client.tensor_uninitialized(shape, dtype);
218
219        let desc = SwapDimsOpIr {
220            input: tensor.into_ir(),
221            dim1,
222            dim2,
223            out: out.to_ir_out(),
224        };
225        out.client.register(
226            streams,
227            OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
228            SwapDimsOps::<B>::new(desc),
229        );
230        out.stream = StreamId::current();
231
232        out
233    }
234
235    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
236        #[derive(new, Debug)]
237        struct PermuteDimsOps<B: FusionBackend> {
238            desc: PermuteOpIr,
239            _b: PhantomData<B>,
240        }
241
242        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
243            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
244                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
245                let output = B::q_permute(input, self.desc.axes.as_slice());
246                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
247            }
248        }
249
250        let mut streams = OperationStreams::default();
251        streams.tensor(&tensor);
252
253        // Change the shape of the tensor to match the new axes
254        let shape = tensor.shape.clone().permute(axes).unwrap();
255
256        let out = tensor.client.tensor_uninitialized(shape, tensor.dtype);
257
258        let desc = PermuteOpIr {
259            input: tensor.into_ir(),
260            axes: axes.to_vec(),
261            out: out.to_ir_out(),
262        };
263
264        out.client.register(
265            streams,
266            OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
267            PermuteDimsOps::<B>::new(desc),
268        );
269
270        out
271    }
272
273    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
274        #[derive(new, Debug)]
275        struct FlipOps<B: FusionBackend> {
276            desc: FlipOpIr,
277            _b: PhantomData<B>,
278        }
279
280        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
281            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
282                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
283                let output = B::q_flip(input, &self.desc.axes);
284                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
285            }
286        }
287
288        let mut streams = OperationStreams::default();
289        streams.tensor(&tensor);
290        let out = tensor
291            .client
292            .tensor_uninitialized(tensor.shape.clone(), tensor.dtype);
293
294        let desc = FlipOpIr {
295            input: tensor.into_ir(),
296            axes: axes.to_vec(),
297            out: out.to_ir_out(),
298        };
299
300        out.client.register(
301            streams,
302            OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
303            FlipOps::<B>::new(desc),
304        );
305
306        out
307    }
308
309    fn q_gather(
310        dim: usize,
311        tensor: QuantizedTensor<Self>,
312        indices: IntTensor<Self>,
313    ) -> QuantizedTensor<Self> {
314        #[derive(new, Debug)]
315        struct GatherOps<B: FusionBackend> {
316            desc: GatherOpIr,
317            _b: PhantomData<B>,
318        }
319
320        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
321            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
322                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
323                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
324
325                let output = B::q_gather(self.desc.dim, tensor, indices);
326                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
327            }
328        }
329
330        let mut streams = OperationStreams::default();
331        streams.tensor(&tensor);
332        streams.tensor(&indices);
333
334        let dtype = tensor.dtype;
335        let shape = indices.shape.clone();
336        let out = tensor.client.tensor_uninitialized(shape, dtype);
337
338        let desc = GatherOpIr {
339            tensor: tensor.into_ir(),
340            dim,
341            indices: indices.into_ir(),
342            out: out.to_ir_out(),
343        };
344        out.client.register(
345            streams,
346            OperationIr::NumericFloat(dtype, NumericOperationIr::Gather(desc.clone())),
347            GatherOps::<B>::new(desc),
348        );
349
350        out
351    }
352
353    fn q_select(
354        tensor: QuantizedTensor<Self>,
355        dim: usize,
356        indices: IntTensor<Self>,
357    ) -> QuantizedTensor<Self> {
358        #[derive(new, Debug)]
359        struct SelectOps<B: FusionBackend> {
360            desc: SelectOpIr,
361            _b: PhantomData<B>,
362        }
363
364        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
365            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
366                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
367                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
368
369                let output = B::q_select(tensor, self.desc.dim, indices);
370
371                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
372            }
373        }
374
375        let mut streams = OperationStreams::default();
376        streams.tensor(&tensor);
377        streams.tensor(&indices);
378
379        let dtype = tensor.dtype;
380        let mut shape = tensor.shape.clone();
381        shape[dim] = indices.shape[0];
382        let out = tensor.client.tensor_uninitialized(shape, dtype);
383        let desc = SelectOpIr {
384            tensor: tensor.into_ir(),
385            dim,
386            indices: indices.into_ir(),
387            out: out.to_ir_out(),
388        };
389        out.client.register(
390            streams,
391            OperationIr::NumericFloat(dtype, NumericOperationIr::Select(desc.clone())),
392            SelectOps::<B>::new(desc),
393        );
394
395        out
396    }
397
398    fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
399        #[derive(new, Debug)]
400        struct SliceOps<B: FusionBackend> {
401            desc: SliceOpIr,
402            _b: PhantomData<B>,
403        }
404
405        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
406            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
407                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
408
409                let output = B::q_slice(tensor, self.desc.ranges.as_slice());
410
411                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
412            }
413        }
414        let mut streams = OperationStreams::default();
415        streams.tensor(&tensor);
416        let dtype = tensor.dtype;
417        let shape = tensor.shape.clone().slice(slices).unwrap();
418
419        let out = tensor.client.tensor_uninitialized(shape, dtype);
420
421        let desc = SliceOpIr {
422            tensor: tensor.into_ir(),
423            ranges: slices.into(),
424            out: out.to_ir_out(),
425        };
426        out.client.register(
427            streams,
428            OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
429            SliceOps::<B>::new(desc),
430        );
431
432        out
433    }
434
435    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
436        #[derive(new, Debug)]
437        struct ExpandOps<B: FusionBackend> {
438            desc: ExpandOpIr,
439            _b: PhantomData<B>,
440        }
441
442        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
443            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
444                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
445                let output = B::q_expand(input, self.desc.shape.clone());
446
447                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
448            }
449        }
450
451        let mut streams = OperationStreams::default();
452        streams.tensor(&tensor);
453
454        let out = tensor
455            .client
456            .tensor_uninitialized(shape.clone(), tensor.dtype);
457
458        let desc = ExpandOpIr {
459            input: tensor.into_ir(),
460            shape,
461            out: out.to_ir_out(),
462        };
463
464        out.client.register(
465            streams,
466            OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
467            ExpandOps::<B>::new(desc),
468        );
469
470        out
471    }
472
473    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
474        #[derive(new, Debug)]
475        struct MatmulOps<B: FusionBackend> {
476            desc: BinaryOpIr,
477            lhs_quantized: bool,
478            rhs_quantized: bool,
479            _b: PhantomData<B>,
480        }
481
482        impl<B: FusionBackend> Operation<B::FusionRuntime> for MatmulOps<B> {
483            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
484                let lhs = match self.lhs_quantized {
485                    true => {
486                        TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.lhs))
487                    }
488                    false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.lhs)),
489                };
490                let rhs = match self.rhs_quantized {
491                    true => {
492                        TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.rhs))
493                    }
494                    false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.rhs)),
495                };
496                let output = B::q_matmul(lhs, rhs);
497                match output {
498                    TensorPrimitive::Float(output) => {
499                        handles.register_float_tensor::<B>(&self.desc.out.id, output);
500                    }
501                    TensorPrimitive::QFloat(output) => {
502                        handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
503                    }
504                }
505            }
506        }
507
508        let mut propagation = QuantPropagation::Inhibit;
509        let mut scheme = QuantScheme::default();
510        let mut streams = OperationStreams::default();
511        let mut lhs_quantized = false;
512        let mut rhs_quantized = false;
513        match &lhs {
514            TensorPrimitive::QFloat(lhs) => {
515                propagation = lhs.propagation();
516                scheme = *lhs.scheme();
517                lhs_quantized = true;
518                streams.tensor(lhs);
519            }
520            TensorPrimitive::Float(lhs) => {
521                streams.tensor(lhs);
522            }
523        }
524        match &rhs {
525            TensorPrimitive::QFloat(rhs) => {
526                propagation = rhs.propagation();
527                scheme = *rhs.scheme();
528                rhs_quantized = true;
529                streams.tensor(rhs);
530            }
531            TensorPrimitive::Float(rhs) => {
532                streams.tensor(rhs);
533            }
534        }
535
536        let dtype = match propagation {
537            QuantPropagation::Propagate => DType::QFloat(scheme),
538            QuantPropagation::Inhibit => B::FloatElem::dtype(),
539        };
540        let shape = Shape::matmul(&lhs.shape(), &rhs.shape()).unwrap();
541
542        let client = match &lhs {
543            TensorPrimitive::Float(lhs) => lhs.client.clone(),
544            TensorPrimitive::QFloat(lhs) => lhs.client.clone(),
545        };
546
547        let lhs = match lhs {
548            TensorPrimitive::Float(lhs) => lhs.into_ir(),
549            TensorPrimitive::QFloat(lhs) => lhs.into_ir(),
550        };
551        let rhs = match rhs {
552            TensorPrimitive::Float(rhs) => rhs.into_ir(),
553            TensorPrimitive::QFloat(rhs) => rhs.into_ir(),
554        };
555
556        let out = client.tensor_uninitialized(shape, dtype);
557        let desc = BinaryOpIr {
558            lhs,
559            rhs,
560            out: out.to_ir_out(),
561        };
562
563        out.client.register(
564            streams,
565            OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
566            MatmulOps::<B>::new(desc, lhs_quantized, rhs_quantized),
567        );
568
569        match propagation {
570            QuantPropagation::Propagate => TensorPrimitive::QFloat(out),
571            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
572        }
573    }
574}