Skip to main content

burn_fusion/ops/
qtensor.rs

1use std::marker::PhantomData;
2
3use burn_backend::{
4    DType, Element, ExecutionError, FloatDType, QTensorPrimitive, Shape, Slice, TensorData,
5    TensorPrimitive,
6    ops::QTensorOps,
7    quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive},
8    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
9};
10use burn_ir::{
11    BaseOperationIr, DequantizeOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, HandleContainer,
12    InitOperationIr, MatmulOpIr, OperationIr, OperationOutput, PermuteOpIr,
13    QuantizationParametersIr, QuantizeOpIr, SelectOpIr, ShapeOpIr, SliceOpIr, SwapDimsOpIr,
14};
15
16use crate::{
17    Fusion, FusionBackend,
18    client::GlobalFusionClient,
19    get_client,
20    stream::{OperationStreams, execution::Operation},
21};
22
23use super::NoOp;
24
25impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
26    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
27        let client = get_client::<B>(device);
28        let dtype = data.dtype;
29        let tensor = B::q_from_data(data, device);
30        let shape = burn_backend::TensorMetadata::shape(&tensor);
31
32        let handle = B::quantized_tensor_handle(tensor);
33        let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
34
35        client
36            .register(
37                OperationStreams::default(),
38                OperationIr::Init(desc),
39                NoOp::<B>::new(),
40            )
41            .output()
42    }
43
44    fn quantize(
45        tensor: FloatTensor<Self>,
46        scheme: &QuantScheme,
47        qparams: QuantizationParametersPrimitive<Self>,
48    ) -> QuantizedTensor<Self> {
49        #[derive(new, Debug)]
50        struct QuantizeOp<B: FusionBackend> {
51            desc: QuantizeOpIr,
52            _b: PhantomData<B>,
53        }
54
55        impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {
56            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
57                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);
58                let scales = handles.get_float_tensor::<B>(&self.desc.qparams.scales);
59
60                let qparams = QuantizationParametersPrimitive { scales };
61                let output = B::quantize(tensor, &self.desc.scheme, qparams);
62                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
63            }
64        }
65
66        let streams = OperationStreams::with_inputs([&tensor, &qparams.scales]);
67
68        let client = tensor.client.clone();
69        let qparams = QuantizationParametersIr {
70            scales: qparams.scales.into_ir(),
71        };
72        let desc = QuantizeOpIr::create(tensor.into_ir(), qparams, *scheme, || {
73            client.create_empty_handle()
74        });
75
76        client
77            .register(
78                streams,
79                OperationIr::Float(desc.tensor.dtype, FloatOperationIr::Quantize(desc.clone())),
80                QuantizeOp::<B>::new(desc),
81            )
82            .output()
83    }
84
85    fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
86        #[derive(new, Debug)]
87        struct DequantizeOp<B: FusionBackend> {
88            desc: DequantizeOpIr,
89            _b: PhantomData<B>,
90        }
91
92        impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {
93            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
94                let tensor = handles.get_quantized_tensor::<B>(&self.desc.input);
95
96                let output = B::dequantize(tensor, self.desc.out.dtype.into());
97                handles.register_float_tensor::<B>(&self.desc.out.id, output);
98            }
99        }
100
101        let streams = OperationStreams::with_inputs([&tensor]);
102
103        let client = tensor.client.clone();
104        let dtype = dtype.into();
105        let desc = DequantizeOpIr::create(tensor.into_ir(), dtype, || client.create_empty_handle());
106
107        client
108            .register(
109                streams,
110                OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())),
111                DequantizeOp::<B>::new(desc),
112            )
113            .output()
114    }
115
116    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
117        tensor.client.device().clone()
118    }
119
120    fn q_to_device(
121        tensor: QuantizedTensor<Self>,
122        device_dst: &Device<Self>,
123    ) -> QuantizedTensor<Self> {
124        let device_src: &B::Device = tensor.client.device();
125        let device_dst: B::Device = device_dst.clone();
126
127        if device_src == &device_dst {
128            return tensor;
129        }
130
131        let id = tensor.stream;
132        let client_dst = get_client::<B>(&device_dst);
133        let client_src = tensor.client.clone();
134
135        GlobalFusionClient::change_client_quantized::<B>(
136            tensor.into_ir(),
137            client_src,
138            client_dst,
139            id,
140        )
141    }
142
143    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
144        if tensor.shape == shape {
145            return tensor;
146        }
147
148        #[derive(new, Debug)]
149        struct ReshapeDimsOps<B: FusionBackend> {
150            desc: ShapeOpIr,
151            _b: PhantomData<B>,
152        }
153
154        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
155            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
156                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
157                let output = B::q_reshape(input, self.desc.out.shape.clone());
158                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
159            }
160        }
161
162        let streams = OperationStreams::with_inputs([&tensor]);
163
164        let client = tensor.client.clone();
165        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
166
167        client
168            .register(
169                streams,
170                OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),
171                ReshapeDimsOps::<B>::new(desc),
172            )
173            .output()
174    }
175
176    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
177        tensor.q_into_data::<B>().await
178    }
179
180    fn q_swap_dims(
181        tensor: QuantizedTensor<Self>,
182        dim1: usize,
183        dim2: usize,
184    ) -> QuantizedTensor<Self> {
185        #[derive(new, Debug)]
186        struct SwapDimsOps<B: FusionBackend> {
187            desc: SwapDimsOpIr,
188            _b: PhantomData<B>,
189        }
190
191        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
192            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
193                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
194                let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2);
195                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
196            }
197        }
198
199        let streams = OperationStreams::with_inputs([&tensor]);
200
201        let client = tensor.client.clone();
202        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
203            client.create_empty_handle()
204        });
205
206        client
207            .register(
208                streams,
209                OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),
210                SwapDimsOps::<B>::new(desc),
211            )
212            .output()
213    }
214
215    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
216        #[derive(new, Debug)]
217        struct PermuteDimsOps<B: FusionBackend> {
218            desc: PermuteOpIr,
219            _b: PhantomData<B>,
220        }
221
222        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
223            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
224                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
225                let output = B::q_permute(input, self.desc.axes.as_slice());
226                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
227            }
228        }
229
230        let streams = OperationStreams::with_inputs([&tensor]);
231
232        let client = tensor.client.clone();
233        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
234            client.create_empty_handle()
235        });
236
237        client
238            .register(
239                streams,
240                OperationIr::BaseFloat(BaseOperationIr::Permute(desc.clone())),
241                PermuteDimsOps::<B>::new(desc),
242            )
243            .output()
244    }
245
246    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
247        #[derive(new, Debug)]
248        struct FlipOps<B: FusionBackend> {
249            desc: FlipOpIr,
250            _b: PhantomData<B>,
251        }
252
253        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
254            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
255                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
256                let output = B::q_flip(input, &self.desc.axes);
257                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
258            }
259        }
260
261        let streams = OperationStreams::with_inputs([&tensor]);
262
263        let client = tensor.client.clone();
264        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
265            client.create_empty_handle()
266        });
267
268        client
269            .register(
270                streams,
271                OperationIr::BaseFloat(BaseOperationIr::Flip(desc.clone())),
272                FlipOps::<B>::new(desc),
273            )
274            .output()
275    }
276
277    fn q_gather(
278        dim: usize,
279        tensor: QuantizedTensor<Self>,
280        indices: IntTensor<Self>,
281    ) -> QuantizedTensor<Self> {
282        #[derive(new, Debug)]
283        struct GatherOps<B: FusionBackend> {
284            desc: GatherOpIr,
285            _b: PhantomData<B>,
286        }
287
288        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
289            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
290                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
291                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
292
293                let output = B::q_gather(self.desc.dim, tensor, indices);
294                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
295            }
296        }
297
298        let streams = OperationStreams::with_inputs([&tensor]);
299
300        let client = tensor.client.clone();
301        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
302            client.create_empty_handle()
303        });
304
305        client
306            .register(
307                streams,
308                OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),
309                GatherOps::<B>::new(desc),
310            )
311            .output()
312    }
313
314    fn q_select(
315        tensor: QuantizedTensor<Self>,
316        dim: usize,
317        indices: IntTensor<Self>,
318    ) -> QuantizedTensor<Self> {
319        #[derive(new, Debug)]
320        struct SelectOps<B: FusionBackend> {
321            desc: SelectOpIr,
322            _b: PhantomData<B>,
323        }
324
325        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
326            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
327                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
328                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
329
330                let output = B::q_select(tensor, self.desc.dim, indices);
331
332                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
333            }
334        }
335
336        let streams = OperationStreams::with_inputs([&tensor]);
337
338        let client = tensor.client.clone();
339        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
340            client.create_empty_handle()
341        });
342
343        client
344            .register(
345                streams,
346                OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),
347                SelectOps::<B>::new(desc),
348            )
349            .output()
350    }
351
352    fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {
353        #[derive(new, Debug)]
354        struct SliceOps<B: FusionBackend> {
355            desc: SliceOpIr,
356            _b: PhantomData<B>,
357        }
358
359        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
360            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
361                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);
362
363                let output = B::q_slice(tensor, self.desc.ranges.as_slice());
364
365                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
366            }
367        }
368
369        let streams = OperationStreams::with_inputs([&tensor]);
370
371        let client = tensor.client.clone();
372        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
373            client.create_empty_handle()
374        });
375
376        client
377            .register(
378                streams,
379                OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),
380                SliceOps::<B>::new(desc),
381            )
382            .output()
383    }
384
385    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
386        #[derive(new, Debug)]
387        struct ExpandOps<B: FusionBackend> {
388            desc: ShapeOpIr,
389            _b: PhantomData<B>,
390        }
391
392        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
393            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
394                let input = handles.get_quantized_tensor::<B>(&self.desc.input);
395                let output = B::q_expand(input, self.desc.out.shape.clone());
396
397                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
398            }
399        }
400
401        let streams = OperationStreams::with_inputs([&tensor]);
402
403        let client = tensor.client.clone();
404        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
405
406        client
407            .register(
408                streams,
409                OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),
410                ExpandOps::<B>::new(desc),
411            )
412            .output()
413    }
414
415    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
416        #[derive(new, Debug)]
417        struct MatmulOps<B: FusionBackend> {
418            desc: MatmulOpIr,
419            lhs_quantized: bool,
420            rhs_quantized: bool,
421            _b: PhantomData<B>,
422        }
423
424        impl<B: FusionBackend> Operation<B::FusionRuntime> for MatmulOps<B> {
425            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
426                let lhs = match self.lhs_quantized {
427                    true => {
428                        TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.lhs))
429                    }
430                    false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.lhs)),
431                };
432                let rhs = match self.rhs_quantized {
433                    true => {
434                        TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.rhs))
435                    }
436                    false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.rhs)),
437                };
438                let output = B::q_matmul(lhs, rhs);
439                match output {
440                    TensorPrimitive::Float(output) => {
441                        handles.register_float_tensor::<B>(&self.desc.out.id, output);
442                    }
443                    TensorPrimitive::QFloat(output) => {
444                        handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
445                    }
446                }
447            }
448        }
449
450        let mut propagation = QuantPropagation::Inhibit;
451        let mut scheme = QuantScheme::default();
452        let mut streams = OperationStreams::default();
453        let mut lhs_quantized = false;
454        let mut rhs_quantized = false;
455        match &lhs {
456            TensorPrimitive::QFloat(lhs) => {
457                propagation = lhs.propagation();
458                scheme = *lhs.scheme();
459                lhs_quantized = true;
460                streams.tensor(lhs);
461            }
462            TensorPrimitive::Float(lhs) => {
463                streams.tensor(lhs);
464            }
465        }
466        match &rhs {
467            TensorPrimitive::QFloat(rhs) => {
468                propagation = rhs.propagation();
469                scheme = *rhs.scheme();
470                rhs_quantized = true;
471                streams.tensor(rhs);
472            }
473            TensorPrimitive::Float(rhs) => {
474                streams.tensor(rhs);
475            }
476        }
477
478        let dtype = match propagation {
479            QuantPropagation::Propagate => DType::QFloat(scheme),
480            QuantPropagation::Inhibit => B::FloatElem::dtype(),
481        };
482
483        let client = match &lhs {
484            TensorPrimitive::Float(lhs) => lhs.client.clone(),
485            TensorPrimitive::QFloat(lhs) => lhs.client.clone(),
486        };
487
488        let lhs = match lhs {
489            TensorPrimitive::Float(lhs) => lhs.into_ir(),
490            TensorPrimitive::QFloat(lhs) => lhs.into_ir(),
491        };
492        let rhs = match rhs {
493            TensorPrimitive::Float(rhs) => rhs.into_ir(),
494            TensorPrimitive::QFloat(rhs) => rhs.into_ir(),
495        };
496
497        let desc = MatmulOpIr::create_mixed(lhs, rhs, dtype, || client.create_empty_handle());
498
499        let out = client
500            .register(
501                streams,
502                OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
503                MatmulOps::<B>::new(desc, lhs_quantized, rhs_quantized),
504            )
505            .output();
506
507        match propagation {
508            QuantPropagation::Propagate => TensorPrimitive::QFloat(out),
509            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
510        }
511    }
512}