burn_fusion/ops/
boolean.rs

1use crate::{
2    Fusion, FusionBackend, get_client,
3    stream::{OperationStreams, StreamId, execution::Operation},
4};
5use burn_ir::{
6    BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
7    InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
8    SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
9};
10use burn_tensor::ops::unfold::calculate_unfold_shape;
11use burn_tensor::{
12    Device, Element, Shape, Slice, TensorData, TensorMetadata,
13    ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor},
14};
15use std::marker::PhantomData;
16
17use super::NoOp;
18
19impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
20    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
21        #[derive(new, Debug)]
22        struct EmptyOps<B: FusionBackend> {
23            desc: TensorIr,
24            device: Device<B>,
25        }
26
27        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
28            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
29                let output = B::bool_empty(self.desc.shape.clone(), &self.device);
30                handles.register_bool_tensor::<B>(&self.desc.id, output);
31            }
32        }
33
34        let client = get_client::<B>(&device.clone());
35        let out = client.tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
36
37        let desc = out.to_ir_out();
38
39        client.register(
40            OperationStreams::default(),
41            OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
42            EmptyOps::<B>::new(desc, device.clone()),
43        );
44
45        out
46    }
47
48    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
49        #[derive(new, Debug)]
50        struct ZerosOps<B: FusionBackend> {
51            desc: TensorIr,
52            device: Device<B>,
53        }
54
55        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
56            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
57                let output = B::bool_zeros(self.desc.shape.clone(), &self.device);
58                handles.register_bool_tensor::<B>(&self.desc.id, output);
59            }
60        }
61
62        let client = get_client::<B>(&device.clone());
63        let out = client.tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
64
65        let desc = out.to_ir_out();
66
67        client.register(
68            OperationStreams::default(),
69            OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
70            ZerosOps::<B>::new(desc, device.clone()),
71        );
72
73        out
74    }
75
76    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
77        #[derive(new, Debug)]
78        struct OnesOps<B: FusionBackend> {
79            desc: TensorIr,
80            device: Device<B>,
81        }
82
83        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
84            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
85                let output = B::bool_ones(self.desc.shape.clone(), &self.device);
86                handles.register_bool_tensor::<B>(&self.desc.id, output);
87            }
88        }
89
90        let client = get_client::<B>(&device.clone());
91        let out = client.tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
92
93        let desc = out.to_ir_out();
94
95        client.register(
96            OperationStreams::default(),
97            OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
98            OnesOps::<B>::new(desc, device.clone()),
99        );
100
101        out
102    }
103
104    async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
105        tensor.bool_into_data::<B>().await
106    }
107
108    fn bool_from_data(data: burn_tensor::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
109        let stream = StreamId::current();
110        let client = get_client::<B>(&device.clone());
111        let tensor = B::bool_from_data(data, device);
112        let shape = tensor.shape();
113
114        let handle = B::bool_tensor_handle(tensor);
115        let out = client.register_tensor(handle, shape, stream, B::BoolElem::dtype());
116        let desc = out.to_ir_out();
117
118        client.register(
119            OperationStreams::default(),
120            OperationIr::Init(InitOperationIr { out: desc }),
121            NoOp::<B>::new(),
122        );
123
124        out
125    }
126
127    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
128        #[derive(new, Debug)]
129        struct IntoIntOps<B: FusionBackend> {
130            desc: UnaryOpIr,
131            _b: PhantomData<B>,
132        }
133
134        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
135            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
136                let input = handles.get_bool_tensor::<B>(&self.desc.input);
137                let output = B::bool_into_int(input);
138                handles.register_int_tensor::<B>(&self.desc.out.id, output);
139            }
140        }
141
142        let mut streams = OperationStreams::default();
143        streams.tensor(&tensor);
144
145        let out = tensor
146            .client
147            .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype());
148
149        let desc = UnaryOpIr {
150            input: tensor.into_ir(),
151            out: out.to_ir_out(),
152        };
153
154        out.client.register(
155            streams,
156            OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),
157            IntoIntOps::<B>::new(desc),
158        );
159
160        out
161    }
162
163    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
164        #[derive(new, Debug)]
165        struct IntoFloatOps<B: FusionBackend> {
166            desc: UnaryOpIr,
167            _b: PhantomData<B>,
168        }
169
170        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
171            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
172                let input = handles.get_bool_tensor::<B>(&self.desc.input);
173                let output = B::bool_into_float(input);
174                handles.register_float_tensor::<B>(&self.desc.out.id, output);
175            }
176        }
177
178        let mut streams = OperationStreams::default();
179        streams.tensor(&tensor);
180
181        let out = tensor
182            .client
183            .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
184
185        let desc = UnaryOpIr {
186            input: tensor.into_ir(),
187            out: out.to_ir_out(),
188        };
189        out.client.register(
190            streams,
191            OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),
192            IntoFloatOps::<B>::new(desc),
193        );
194
195        out
196    }
197
198    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
199        tensor.client.device().clone()
200    }
201
202    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
203        let device_original: &B::Device = tensor.client.device();
204        let device_target: B::Device = device.clone();
205
206        if device_original == &device_target {
207            return tensor;
208        }
209
210        let id = tensor.stream;
211        let client_target = get_client::<B>(&device_target);
212        let client_original = tensor.client.clone();
213
214        client_original
215            .clone()
216            .change_client_bool::<B>(tensor.into_ir(), client_target, id)
217    }
218
219    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
220        if tensor.shape == shape {
221            return tensor;
222        }
223
224        #[derive(new, Debug)]
225        struct ReshapeDimsOps<B: FusionBackend> {
226            desc: UnaryOpIr,
227            _b: PhantomData<B>,
228        }
229
230        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
231            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
232                let input = handles.get_bool_tensor::<B>(&self.desc.input);
233                let output = B::bool_reshape(input, self.desc.out.shape.clone());
234                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
235            }
236        }
237
238        let mut streams = OperationStreams::default();
239        streams.tensor(&tensor);
240
241        let out = tensor
242            .client
243            .tensor_uninitialized(shape, B::BoolElem::dtype());
244
245        let desc = UnaryOpIr {
246            input: tensor.into_ir(),
247            out: out.to_ir_out(),
248        };
249        out.client.register(
250            streams,
251            OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),
252            ReshapeDimsOps::<B>::new(desc),
253        );
254
255        out
256    }
257
258    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
259        #[derive(new, Debug)]
260        struct SliceOps<B: FusionBackend> {
261            desc: SliceOpIr,
262            _b: PhantomData<B>,
263        }
264
265        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
266            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
267                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
268
269                let output = B::bool_slice(tensor, self.desc.ranges.as_slice());
270
271                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
272            }
273        }
274
275        let shape = tensor.shape.clone().slice(slices).unwrap();
276
277        let mut streams = OperationStreams::default();
278        streams.tensor(&tensor);
279
280        let out = tensor
281            .client
282            .tensor_uninitialized(shape, B::BoolElem::dtype());
283
284        let desc = SliceOpIr {
285            tensor: tensor.into_ir(),
286            ranges: slices.to_vec(),
287            out: out.to_ir_out(),
288        };
289        out.client.register(
290            streams,
291            OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),
292            SliceOps::<B>::new(desc),
293        );
294
295        out
296    }
297
298    fn bool_slice_assign(
299        tensor: BoolTensor<Self>,
300        ranges: &[burn_tensor::Slice],
301        value: BoolTensor<Self>,
302    ) -> BoolTensor<Self> {
303        #[derive(new, Debug)]
304        struct SliceAssignOps<B: FusionBackend> {
305            desc: SliceAssignOpIr,
306            _b: PhantomData<B>,
307        }
308
309        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
310            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
311                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
312                let value = handles.get_bool_tensor::<B>(&self.desc.value);
313
314                let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);
315
316                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
317            }
318        }
319
320        let shape = tensor.shape.clone();
321        let mut streams = OperationStreams::default();
322        streams.tensor(&tensor);
323        streams.tensor(&value);
324
325        let out = tensor
326            .client
327            .tensor_uninitialized(shape, B::BoolElem::dtype());
328
329        let desc = SliceAssignOpIr {
330            tensor: tensor.into_ir(),
331            ranges: ranges.to_vec(),
332            value: value.into_ir(),
333            out: out.to_ir_out(),
334        };
335
336        out.client.register(
337            streams,
338            OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),
339            SliceAssignOps::<B>::new(desc),
340        );
341
342        out
343    }
344
345    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
346        #[derive(new, Debug)]
347        struct CatOps<B: FusionBackend> {
348            desc: CatOpIr,
349            _b: PhantomData<B>,
350        }
351
352        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
353            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
354                let tensors = self
355                    .desc
356                    .tensors
357                    .iter()
358                    .map(|tensor| handles.get_bool_tensor::<B>(tensor))
359                    .collect();
360
361                let output = B::bool_cat(tensors, self.desc.dim);
362
363                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
364            }
365        }
366
367        let tensor_first = tensors.first().unwrap();
368        let client = tensor_first.client.clone();
369
370        // Calculate the output shape
371        let shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap();
372        let mut streams = OperationStreams::default();
373        tensors.iter().for_each(|t| streams.tensor(t));
374
375        let out = client.tensor_uninitialized(shape, B::BoolElem::dtype());
376
377        let desc = CatOpIr {
378            tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
379            dim,
380            out: out.to_ir_out(),
381        };
382        client.register(
383            streams,
384            OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),
385            CatOps::<B>::new(desc),
386        );
387
388        out
389    }
390
391    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
392        #[derive(new, Debug)]
393        struct EqualOps<B: FusionBackend> {
394            desc: BinaryOpIr,
395            _b: PhantomData<B>,
396        }
397
398        impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
399            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
400                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
401                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
402                let output = B::bool_equal(lhs, rhs);
403                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
404            }
405        }
406
407        let mut streams = OperationStreams::default();
408        streams.tensor(&lhs);
409        streams.tensor(&rhs);
410
411        let out = lhs.client.tensor_uninitialized(
412            lhs.shape.broadcast(&rhs.shape).unwrap(),
413            B::BoolElem::dtype(),
414        );
415
416        let desc = BinaryOpIr {
417            lhs: lhs.into_ir(),
418            rhs: rhs.into_ir(),
419            out: out.to_ir_out(),
420        };
421        out.client.register(
422            streams,
423            OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),
424            EqualOps::<B>::new(desc),
425        );
426
427        out
428    }
429
430    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
431        #[derive(new, Debug)]
432        struct NotOps<B: FusionBackend> {
433            desc: UnaryOpIr,
434            _b: PhantomData<B>,
435        }
436
437        impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
438            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
439                let input = handles.get_bool_tensor::<B>(&self.desc.input);
440                let output = B::bool_not(input);
441                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
442            }
443        }
444
445        let mut streams = OperationStreams::default();
446        streams.tensor(&tensor);
447
448        let out = tensor
449            .client
450            .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
451
452        let desc = UnaryOpIr {
453            input: tensor.into_ir(),
454            out: out.to_ir_out(),
455        };
456
457        out.client.register(
458            streams,
459            OperationIr::Bool(BoolOperationIr::Not(desc.clone())),
460            NotOps::<B>::new(desc),
461        );
462
463        out
464    }
465
466    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
467        #[derive(new, Debug)]
468        struct AndOps<B: FusionBackend> {
469            desc: BinaryOpIr,
470            _b: PhantomData<B>,
471        }
472
473        impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
474            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
475                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
476                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
477                let output = B::bool_and(lhs, rhs);
478                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
479            }
480        }
481
482        let mut streams = OperationStreams::default();
483        streams.tensor(&lhs);
484        streams.tensor(&rhs);
485
486        let out = lhs.client.tensor_uninitialized(
487            lhs.shape.broadcast(&rhs.shape).unwrap(),
488            B::BoolElem::dtype(),
489        );
490
491        let desc = BinaryOpIr {
492            lhs: lhs.into_ir(),
493            rhs: rhs.into_ir(),
494            out: out.to_ir_out(),
495        };
496        out.client.register(
497            streams,
498            OperationIr::Bool(BoolOperationIr::And(desc.clone())),
499            AndOps::<B>::new(desc),
500        );
501
502        out
503    }
504
505    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
506        #[derive(new, Debug)]
507        struct OrOps<B: FusionBackend> {
508            desc: BinaryOpIr,
509            _b: PhantomData<B>,
510        }
511
512        impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
513            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
514                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
515                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
516                let output = B::bool_or(lhs, rhs);
517                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
518            }
519        }
520
521        let mut streams = OperationStreams::default();
522        streams.tensor(&lhs);
523        streams.tensor(&rhs);
524
525        let out = lhs.client.tensor_uninitialized(
526            lhs.shape.broadcast(&rhs.shape).unwrap(),
527            B::BoolElem::dtype(),
528        );
529
530        let desc = BinaryOpIr {
531            lhs: lhs.into_ir(),
532            rhs: rhs.into_ir(),
533            out: out.to_ir_out(),
534        };
535        out.client.register(
536            streams,
537            OperationIr::Bool(BoolOperationIr::Or(desc.clone())),
538            OrOps::<B>::new(desc),
539        );
540
541        out
542    }
543
544    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
545        #[derive(new, Debug)]
546        struct SwapDimsOps<B: FusionBackend> {
547            desc: SwapDimsOpIr,
548            _b: PhantomData<B>,
549        }
550
551        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
552            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
553                let input = handles.get_bool_tensor::<B>(&self.desc.input);
554                let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
555                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
556            }
557        }
558
559        let mut streams = OperationStreams::default();
560        streams.tensor(&tensor);
561
562        let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
563        let out = tensor
564            .client
565            .tensor_uninitialized(shape, B::BoolElem::dtype());
566
567        let desc = SwapDimsOpIr {
568            input: tensor.into_ir(),
569            dim1,
570            dim2,
571            out: out.to_ir_out(),
572        };
573        out.client.register(
574            streams,
575            OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),
576            SwapDimsOps::<B>::new(desc),
577        );
578
579        out
580    }
581
582    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
583        #[derive(new, Debug)]
584        struct PermuteDimsOps<B: FusionBackend> {
585            desc: PermuteOpIr,
586            _b: PhantomData<B>,
587        }
588
589        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
590            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
591                let input = handles.get_bool_tensor::<B>(&self.desc.input);
592                let output = B::bool_permute(input, self.desc.axes.as_slice());
593                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
594            }
595        }
596
597        let mut streams = OperationStreams::default();
598        streams.tensor(&tensor);
599
600        // Change the shape of the tensor to match the new axes
601        let shape = tensor.shape.clone().permute(axes).unwrap();
602        let out = tensor
603            .client
604            .tensor_uninitialized(shape, B::BoolElem::dtype());
605
606        let desc = PermuteOpIr {
607            input: tensor.into_ir(),
608            axes: axes.to_vec(),
609            out: out.to_ir_out(),
610        };
611
612        out.client.register(
613            streams,
614            OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
615            PermuteDimsOps::<B>::new(desc),
616        );
617
618        out
619    }
620
621    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
622        #[derive(new, Debug)]
623        struct ExpandOps<B: FusionBackend> {
624            desc: ExpandOpIr,
625            _b: PhantomData<B>,
626        }
627
628        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
629            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
630                let input = handles.get_bool_tensor::<B>(&self.desc.input);
631                let output = B::bool_expand(input, self.desc.shape.clone());
632
633                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
634            }
635        }
636
637        let mut streams = OperationStreams::default();
638        streams.tensor(&tensor);
639
640        let out = tensor
641            .client
642            .tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
643
644        let desc = ExpandOpIr {
645            input: tensor.into_ir(),
646            shape,
647            out: out.to_ir_out(),
648        };
649
650        out.client.register(
651            streams,
652            OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),
653            ExpandOps::<B>::new(desc),
654        );
655
656        out
657    }
658
659    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
660        #[derive(new, Debug)]
661        struct FlipOps<B: FusionBackend> {
662            desc: FlipOpIr,
663            _b: PhantomData<B>,
664        }
665
666        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
667            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
668                let input = handles.get_bool_tensor::<B>(&self.desc.input);
669                let output = B::bool_flip(input, self.desc.axes.as_slice());
670                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
671            }
672        }
673
674        let mut streams = OperationStreams::default();
675        streams.tensor(&tensor);
676
677        let out = tensor
678            .client
679            .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
680
681        let desc = FlipOpIr {
682            input: tensor.into_ir(),
683            out: out.to_ir_out(),
684            axes: axes.to_vec(),
685        };
686
687        out.client.register(
688            streams,
689            OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),
690            FlipOps::<B>::new(desc),
691        );
692
693        out
694    }
695
696    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
697        #[derive(new, Debug)]
698        struct RepeatDimOps<B: FusionBackend> {
699            desc: RepeatDimOpIr,
700            _b: PhantomData<B>,
701        }
702
703        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
704            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
705                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
706
707                let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);
708
709                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
710            }
711        }
712
713        let mut streams = OperationStreams::default();
714        streams.tensor(&tensor);
715
716        let shape = tensor.shape.clone().repeat(dim, times);
717        let out = tensor
718            .client
719            .tensor_uninitialized(shape, B::BoolElem::dtype());
720
721        let desc = RepeatDimOpIr {
722            tensor: tensor.into_ir(),
723            dim,
724            times,
725            out: out.to_ir_out(),
726        };
727        out.client.register(
728            streams,
729            OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),
730            RepeatDimOps::<B>::new(desc),
731        );
732
733        out
734    }
735
736    fn bool_unfold(
737        tensor: BoolTensor<Self>,
738        dim: usize,
739        size: usize,
740        step: usize,
741    ) -> BoolTensor<Self> {
742        #[derive(new, Debug)]
743        struct UnfoldOps<B: FusionBackend> {
744            desc: UnfoldOpIr,
745            _b: PhantomData<B>,
746        }
747
748        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
749            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
750                let input = handles.get_bool_tensor::<B>(&self.desc.input);
751                let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
752
753                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
754            }
755        }
756
757        let mut streams = OperationStreams::default();
758        streams.tensor(&tensor);
759
760        let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
761        let out = tensor
762            .client
763            .tensor_uninitialized(Shape::from(shape), tensor.dtype);
764
765        let desc = UnfoldOpIr {
766            input: tensor.into_ir(),
767            out: out.to_ir_out(),
768            dim,
769            size,
770            step,
771        };
772
773        out.client.register(
774            streams,
775            OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),
776            UnfoldOps::<B>::new(desc),
777        );
778
779        out
780    }
781}