burn_fusion/ops/
boolean.rs

1use burn_ir::{
2    BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
3    InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
4    SwapDimsOpIr, TensorIr, UnaryOpIr,
5};
6use burn_tensor::{
7    Device, Element, Shape, TensorData, TensorMetadata,
8    ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape},
9};
10use std::marker::PhantomData;
11
12use crate::{
13    Fusion, FusionBackend,
14    client::FusionClient,
15    get_client,
16    stream::{OperationStreams, StreamId, execution::Operation},
17};
18
19use super::NoOp;
20
21impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
22    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
23        #[derive(new, Debug)]
24        struct EmptyOps<B: FusionBackend> {
25            desc: TensorIr,
26            device: Device<B>,
27        }
28
29        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
30            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
31                let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device);
32                handles.register_bool_tensor::<B>(&self.desc.id, output);
33            }
34        }
35
36        let client = get_client::<B>(&device.clone());
37        let out = client.tensor_uninitialized(shape.dims.clone(), B::BoolElem::dtype());
38
39        let desc = out.to_ir_out();
40
41        client.register(
42            OperationStreams::default(),
43            OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
44            EmptyOps::<B>::new(desc, device.clone()),
45        );
46
47        out
48    }
49
50    async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
51        tensor.bool_into_data::<B>().await
52    }
53
54    fn bool_from_data(data: burn_tensor::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
55        let stream = StreamId::current();
56        let client = get_client::<B>(&device.clone());
57        let tensor = B::bool_from_data(data, device);
58        let shape = tensor.shape();
59
60        let handle = B::bool_tensor_handle(tensor);
61        let out = client.register_tensor(handle, shape.dims, stream, B::BoolElem::dtype());
62        let desc = out.to_ir_out();
63
64        client.register(
65            OperationStreams::default(),
66            OperationIr::Init(InitOperationIr { out: desc }),
67            NoOp::<B>::new(),
68        );
69
70        out
71    }
72
73    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
74        #[derive(new, Debug)]
75        struct IntoIntOps<B: FusionBackend> {
76            desc: UnaryOpIr,
77            _b: PhantomData<B>,
78        }
79
80        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
81            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
82                let input = handles.get_bool_tensor::<B>(&self.desc.input);
83                let output = B::bool_into_int(input);
84                handles.register_int_tensor::<B>(&self.desc.out.id, output);
85            }
86        }
87
88        let mut streams = OperationStreams::default();
89        streams.tensor(&tensor);
90
91        let out = tensor
92            .client
93            .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype());
94
95        let desc = UnaryOpIr {
96            input: tensor.into_ir(),
97            out: out.to_ir_out(),
98        };
99
100        out.client.register(
101            streams,
102            OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),
103            IntoIntOps::<B>::new(desc),
104        );
105
106        out
107    }
108
109    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
110        #[derive(new, Debug)]
111        struct IntoFloatOps<B: FusionBackend> {
112            desc: UnaryOpIr,
113            _b: PhantomData<B>,
114        }
115
116        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
117            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
118                let input = handles.get_bool_tensor::<B>(&self.desc.input);
119                let output = B::bool_into_float(input);
120                handles.register_float_tensor::<B>(&self.desc.out.id, output);
121            }
122        }
123
124        let mut streams = OperationStreams::default();
125        streams.tensor(&tensor);
126
127        let out = tensor
128            .client
129            .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
130
131        let desc = UnaryOpIr {
132            input: tensor.into_ir(),
133            out: out.to_ir_out(),
134        };
135        out.client.register(
136            streams,
137            OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),
138            IntoFloatOps::<B>::new(desc),
139        );
140
141        out
142    }
143
144    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
145        tensor.client.device().clone()
146    }
147
148    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
149        let device_original: &B::Device = tensor.client.device();
150        let device_target: B::Device = device.clone();
151
152        if device_original == &device_target {
153            return tensor;
154        }
155
156        let id = tensor.stream;
157        let client_target = get_client::<B>(&device_target);
158        let client_original = tensor.client.clone();
159
160        client_original
161            .clone()
162            .change_client_bool::<B>(tensor.into_ir(), client_target, id)
163    }
164
165    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
166        #[derive(new, Debug)]
167        struct ReshapeDimsOps<B: FusionBackend> {
168            desc: UnaryOpIr,
169            _b: PhantomData<B>,
170        }
171
172        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
173            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
174                let input = handles.get_bool_tensor::<B>(&self.desc.input);
175                let output = B::bool_reshape(input, Shape::from(&self.desc.out.shape));
176                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
177            }
178        }
179
180        let mut streams = OperationStreams::default();
181        streams.tensor(&tensor);
182
183        let out = tensor
184            .client
185            .tensor_uninitialized(shape.dims, B::BoolElem::dtype());
186
187        let desc = UnaryOpIr {
188            input: tensor.into_ir(),
189            out: out.to_ir_out(),
190        };
191        out.client.register(
192            streams,
193            OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),
194            ReshapeDimsOps::<B>::new(desc),
195        );
196
197        out
198    }
199
200    fn bool_slice(tensor: BoolTensor<Self>, ranges: &[std::ops::Range<usize>]) -> BoolTensor<Self> {
201        #[derive(new, Debug)]
202        struct SliceOps<B: FusionBackend> {
203            desc: SliceOpIr,
204            _b: PhantomData<B>,
205        }
206
207        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
208            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
209                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
210
211                let output = B::bool_slice(tensor, self.desc.ranges.as_slice());
212
213                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
214            }
215        }
216
217        let ndims = burn_tensor::TensorMetadata::shape(&tensor).num_dims();
218        let mut shape: Vec<usize> = ranges.iter().map(|range| range.end - range.start).collect();
219
220        for i in shape.len()..ndims {
221            shape.push(tensor.shape[i]);
222        }
223
224        let mut streams = OperationStreams::default();
225        streams.tensor(&tensor);
226
227        let out = tensor
228            .client
229            .tensor_uninitialized(shape, B::BoolElem::dtype());
230
231        let desc = SliceOpIr {
232            tensor: tensor.into_ir(),
233            ranges: ranges.into(),
234            out: out.to_ir_out(),
235        };
236        out.client.register(
237            streams,
238            OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),
239            SliceOps::<B>::new(desc),
240        );
241
242        out
243    }
244
245    fn bool_slice_assign(
246        tensor: BoolTensor<Self>,
247        ranges: &[std::ops::Range<usize>],
248        value: BoolTensor<Self>,
249    ) -> BoolTensor<Self> {
250        #[derive(new, Debug)]
251        struct SliceAssignOps<B: FusionBackend> {
252            desc: SliceAssignOpIr,
253            _b: PhantomData<B>,
254        }
255
256        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
257            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
258                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
259                let value = handles.get_bool_tensor::<B>(&self.desc.value);
260
261                let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);
262
263                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
264            }
265        }
266
267        let shape: Vec<usize> = tensor.shape.clone();
268        let mut streams = OperationStreams::default();
269        streams.tensor(&tensor);
270        streams.tensor(&value);
271
272        let out = tensor
273            .client
274            .tensor_uninitialized(shape, B::BoolElem::dtype());
275
276        let desc = SliceAssignOpIr {
277            tensor: tensor.into_ir(),
278            ranges: ranges.into(),
279            value: value.into_ir(),
280            out: out.to_ir_out(),
281        };
282
283        out.client.register(
284            streams,
285            OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),
286            SliceAssignOps::<B>::new(desc),
287        );
288
289        out
290    }
291
292    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
293        #[derive(new, Debug)]
294        struct CatOps<B: FusionBackend> {
295            desc: CatOpIr,
296            _b: PhantomData<B>,
297        }
298
299        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
300            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
301                let tensors = self
302                    .desc
303                    .tensors
304                    .iter()
305                    .map(|tensor| handles.get_bool_tensor::<B>(tensor))
306                    .collect();
307
308                let output = B::bool_cat(tensors, self.desc.dim);
309
310                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
311            }
312        }
313
314        let tensor_first = tensors.first().unwrap();
315        let client = tensor_first.client.clone();
316
317        // Calculate the output shape
318        let mut shape: Vec<usize> = tensor_first.shape.clone();
319        let mut streams = OperationStreams::default();
320        tensors.iter().for_each(|t| streams.tensor(t));
321
322        shape[dim] = 0;
323        for tensor in tensors.iter() {
324            shape[dim] += tensor.shape[dim];
325        }
326
327        let out = client.tensor_uninitialized(shape, B::BoolElem::dtype());
328
329        let desc = CatOpIr {
330            tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
331            dim,
332            out: out.to_ir_out(),
333        };
334        client.register(
335            streams,
336            OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),
337            CatOps::<B>::new(desc),
338        );
339
340        out
341    }
342
343    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
344        #[derive(new, Debug)]
345        struct EqualOps<B: FusionBackend> {
346            desc: BinaryOpIr,
347            _b: PhantomData<B>,
348        }
349
350        impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
351            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
352                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
353                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
354                let output = B::bool_equal(lhs, rhs);
355                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
356            }
357        }
358
359        let mut streams = OperationStreams::default();
360        streams.tensor(&lhs);
361        streams.tensor(&rhs);
362
363        let out = lhs.client.tensor_uninitialized(
364            binary_ops_shape(&lhs.shape, &rhs.shape),
365            B::BoolElem::dtype(),
366        );
367
368        let desc = BinaryOpIr {
369            lhs: lhs.into_ir(),
370            rhs: rhs.into_ir(),
371            out: out.to_ir_out(),
372        };
373        out.client.register(
374            streams,
375            OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),
376            EqualOps::<B>::new(desc),
377        );
378
379        out
380    }
381
382    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
383        #[derive(new, Debug)]
384        struct NotOps<B: FusionBackend> {
385            desc: UnaryOpIr,
386            _b: PhantomData<B>,
387        }
388
389        impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
390            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
391                let input = handles.get_bool_tensor::<B>(&self.desc.input);
392                let output = B::bool_not(input);
393                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
394            }
395        }
396
397        let mut streams = OperationStreams::default();
398        streams.tensor(&tensor);
399
400        let out = tensor
401            .client
402            .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
403
404        let desc = UnaryOpIr {
405            input: tensor.into_ir(),
406            out: out.to_ir_out(),
407        };
408
409        out.client.register(
410            streams,
411            OperationIr::Bool(BoolOperationIr::Not(desc.clone())),
412            NotOps::<B>::new(desc),
413        );
414
415        out
416    }
417
418    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
419        #[derive(new, Debug)]
420        struct AndOps<B: FusionBackend> {
421            desc: BinaryOpIr,
422            _b: PhantomData<B>,
423        }
424
425        impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
426            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
427                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
428                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
429                let output = B::bool_and(lhs, rhs);
430                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
431            }
432        }
433
434        let mut streams = OperationStreams::default();
435        streams.tensor(&lhs);
436        streams.tensor(&rhs);
437
438        let out = lhs.client.tensor_uninitialized(
439            binary_ops_shape(&lhs.shape, &rhs.shape),
440            B::BoolElem::dtype(),
441        );
442
443        let desc = BinaryOpIr {
444            lhs: lhs.into_ir(),
445            rhs: rhs.into_ir(),
446            out: out.to_ir_out(),
447        };
448        out.client.register(
449            streams,
450            OperationIr::Bool(BoolOperationIr::And(desc.clone())),
451            AndOps::<B>::new(desc),
452        );
453
454        out
455    }
456
457    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
458        #[derive(new, Debug)]
459        struct OrOps<B: FusionBackend> {
460            desc: BinaryOpIr,
461            _b: PhantomData<B>,
462        }
463
464        impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
465            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
466                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
467                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
468                let output = B::bool_or(lhs, rhs);
469                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
470            }
471        }
472
473        let mut streams = OperationStreams::default();
474        streams.tensor(&lhs);
475        streams.tensor(&rhs);
476
477        let out = lhs.client.tensor_uninitialized(
478            binary_ops_shape(&lhs.shape, &rhs.shape),
479            B::BoolElem::dtype(),
480        );
481
482        let desc = BinaryOpIr {
483            lhs: lhs.into_ir(),
484            rhs: rhs.into_ir(),
485            out: out.to_ir_out(),
486        };
487        out.client.register(
488            streams,
489            OperationIr::Bool(BoolOperationIr::Or(desc.clone())),
490            OrOps::<B>::new(desc),
491        );
492
493        out
494    }
495
496    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
497        #[derive(new, Debug)]
498        struct SwapDimsOps<B: FusionBackend> {
499            desc: SwapDimsOpIr,
500            _b: PhantomData<B>,
501        }
502
503        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
504            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
505                let input = handles.get_bool_tensor::<B>(&self.desc.input);
506                let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
507                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
508            }
509        }
510
511        let mut streams = OperationStreams::default();
512        streams.tensor(&tensor);
513
514        let mut shape = tensor.shape.clone();
515        shape[dim1] = tensor.shape[dim2];
516        shape[dim2] = tensor.shape[dim1];
517
518        let out = tensor
519            .client
520            .tensor_uninitialized(shape, B::BoolElem::dtype());
521
522        let desc = SwapDimsOpIr {
523            input: tensor.into_ir(),
524            dim1,
525            dim2,
526            out: out.to_ir_out(),
527        };
528        out.client.register(
529            streams,
530            OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),
531            SwapDimsOps::<B>::new(desc),
532        );
533
534        out
535    }
536
537    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
538        #[derive(new, Debug)]
539        struct PermuteDimsOps<B: FusionBackend> {
540            desc: PermuteOpIr,
541            _b: PhantomData<B>,
542        }
543
544        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
545            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
546                let input = handles.get_bool_tensor::<B>(&self.desc.input);
547                let output = B::bool_permute(input, self.desc.axes.as_slice());
548                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
549            }
550        }
551
552        let mut streams = OperationStreams::default();
553        streams.tensor(&tensor);
554
555        // Change the shape of the tensor to match the new axes
556        let shape = axes.iter().map(|x| tensor.shape[*x]).collect();
557
558        let out = tensor
559            .client
560            .tensor_uninitialized(shape, B::BoolElem::dtype());
561
562        let desc = PermuteOpIr {
563            input: tensor.into_ir(),
564            axes: axes.to_vec(),
565            out: out.to_ir_out(),
566        };
567
568        out.client.register(
569            streams,
570            OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
571            PermuteDimsOps::<B>::new(desc),
572        );
573
574        out
575    }
576
577    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
578        #[derive(new, Debug)]
579        struct ExpandOps<B: FusionBackend> {
580            desc: ExpandOpIr,
581            _b: PhantomData<B>,
582        }
583
584        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
585            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
586                let input = handles.get_bool_tensor::<B>(&self.desc.input);
587                let output = B::bool_expand(input, self.desc.shape.clone().into());
588
589                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
590            }
591        }
592
593        let mut streams = OperationStreams::default();
594        streams.tensor(&tensor);
595
596        let out = tensor
597            .client
598            .tensor_uninitialized(shape.dims.clone(), B::BoolElem::dtype());
599
600        let desc = ExpandOpIr {
601            input: tensor.into_ir(),
602            shape: shape.dims,
603            out: out.to_ir_out(),
604        };
605
606        out.client.register(
607            streams,
608            OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),
609            ExpandOps::<B>::new(desc),
610        );
611
612        out
613    }
614
615    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
616        #[derive(new, Debug)]
617        struct FlipOps<B: FusionBackend> {
618            desc: FlipOpIr,
619            _b: PhantomData<B>,
620        }
621
622        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
623            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
624                let input = handles.get_bool_tensor::<B>(&self.desc.input);
625                let output = B::bool_flip(input, self.desc.axes.as_slice());
626                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
627            }
628        }
629
630        let mut streams = OperationStreams::default();
631        streams.tensor(&tensor);
632
633        let out = tensor
634            .client
635            .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
636
637        let desc = FlipOpIr {
638            input: tensor.into_ir(),
639            out: out.to_ir_out(),
640            axes: axes.to_vec(),
641        };
642
643        out.client.register(
644            streams,
645            OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),
646            FlipOps::<B>::new(desc),
647        );
648
649        out
650    }
651
652    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
653        #[derive(new, Debug)]
654        struct RepeatDimOps<B: FusionBackend> {
655            desc: RepeatDimOpIr,
656            _b: PhantomData<B>,
657        }
658
659        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
660            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
661                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
662
663                let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);
664
665                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
666            }
667        }
668
669        let mut streams = OperationStreams::default();
670        streams.tensor(&tensor);
671
672        let mut shape = tensor.shape.clone();
673        shape[dim] *= times;
674        let out = tensor
675            .client
676            .tensor_uninitialized(shape, B::BoolElem::dtype());
677
678        let desc = RepeatDimOpIr {
679            tensor: tensor.into_ir(),
680            dim,
681            times,
682            out: out.to_ir_out(),
683        };
684        out.client.register(
685            streams,
686            OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),
687            RepeatDimOps::<B>::new(desc),
688        );
689
690        out
691    }
692}