Skip to main content

burn_fusion/ops/
bool_tensor.rs

1use crate::{
2    Fusion, FusionBackend,
3    client::GlobalFusionClient,
4    get_client,
5    stream::{OperationStreams, execution::Operation},
6};
7use burn_backend::{
8    BoolDType, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice, TensorData,
9    ops::BoolTensorOps,
10    tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
11};
12use burn_ir::{
13    BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
14    GatherOpIr, HandleContainer, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr,
15    OperationOutput, PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr,
16    SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr,
17    UnfoldOpIr,
18};
19use std::marker::PhantomData;
20
21use super::NoOp;
22
23impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
24    fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
25        #[derive(new, Debug)]
26        struct EmptyOps<B: FusionBackend> {
27            desc: TensorIr,
28            device: Device<B>,
29        }
30
31        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
32            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
33                let output = B::bool_empty(
34                    self.desc.shape.clone(),
35                    &self.device,
36                    self.desc.dtype.into(),
37                );
38                handles.register_bool_tensor::<B>(&self.desc.id, output);
39            }
40        }
41
42        let client = get_client::<B>(device);
43        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
44
45        client
46            .register(
47                OperationStreams::default(),
48                OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
49                EmptyOps::<B>::new(desc.out, device.clone()),
50            )
51            .output()
52    }
53
54    fn bool_zeros(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
55        #[derive(new, Debug)]
56        struct ZerosOps<B: FusionBackend> {
57            desc: TensorIr,
58            device: Device<B>,
59        }
60
61        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
62            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
63                let output = B::bool_zeros(
64                    self.desc.shape.clone(),
65                    &self.device,
66                    self.desc.dtype.into(),
67                );
68                handles.register_bool_tensor::<B>(&self.desc.id, output);
69            }
70        }
71
72        let client = get_client::<B>(device);
73        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
74
75        client
76            .register(
77                OperationStreams::default(),
78                OperationIr::BaseBool(BaseOperationIr::Zeros(desc.clone())),
79                ZerosOps::<B>::new(desc.out, device.clone()),
80            )
81            .output()
82    }
83
84    fn bool_ones(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
85        #[derive(new, Debug)]
86        struct OnesOps<B: FusionBackend> {
87            desc: TensorIr,
88            device: Device<B>,
89        }
90
91        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
92            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
93                let output = B::bool_ones(
94                    self.desc.shape.clone(),
95                    &self.device,
96                    self.desc.dtype.into(),
97                );
98                handles.register_bool_tensor::<B>(&self.desc.id, output);
99            }
100        }
101
102        let client = get_client::<B>(device);
103        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
104
105        client
106            .register(
107                OperationStreams::default(),
108                OperationIr::BaseBool(BaseOperationIr::Ones(desc.clone())),
109                OnesOps::<B>::new(desc.out, device.clone()),
110            )
111            .output()
112    }
113
114    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
115        tensor.bool_into_data::<B>().await
116    }
117
118    fn bool_from_data(data: burn_backend::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
119        let client = get_client::<B>(device);
120        let dtype = data.dtype;
121        let tensor = B::bool_from_data(data, device);
122        let shape = burn_backend::TensorMetadata::shape(&tensor);
123
124        let handle = B::bool_tensor_handle(tensor);
125        let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
126
127        client
128            .register(
129                OperationStreams::default(),
130                OperationIr::Init(desc),
131                NoOp::<B>::new(),
132            )
133            .output()
134    }
135
136    fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
137        #[derive(new, Debug)]
138        struct IntoIntOps<B: FusionBackend> {
139            desc: CastOpIr,
140            _b: PhantomData<B>,
141        }
142
143        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
144            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
145                let input = handles.get_bool_tensor::<B>(&self.desc.input);
146                let output = B::bool_into_int(input, self.desc.out.dtype.into());
147                handles.register_int_tensor::<B>(&self.desc.out.id, output);
148            }
149        }
150
151        let streams = OperationStreams::with_inputs([&tensor]);
152
153        let client = tensor.client.clone();
154        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
155            client.create_empty_handle()
156        });
157
158        client
159            .register(
160                streams,
161                OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),
162                IntoIntOps::<B>::new(desc),
163            )
164            .output()
165    }
166
167    fn bool_into_float(tensor: BoolTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
168        #[derive(new, Debug)]
169        struct IntoFloatOps<B: FusionBackend> {
170            desc: CastOpIr,
171            _b: PhantomData<B>,
172        }
173
174        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
175            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
176                let input = handles.get_bool_tensor::<B>(&self.desc.input);
177                let output = B::bool_into_float(input, self.desc.out.dtype.into());
178                handles.register_float_tensor::<B>(&self.desc.out.id, output);
179            }
180        }
181
182        let streams = OperationStreams::with_inputs([&tensor]);
183
184        let client = tensor.client.clone();
185        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
186            client.create_empty_handle()
187        });
188
189        client
190            .register(
191                streams,
192                OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),
193                IntoFloatOps::<B>::new(desc),
194            )
195            .output()
196    }
197
198    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
199        tensor.client.device().clone()
200    }
201
202    fn bool_to_device(tensor: BoolTensor<Self>, device_dst: &Device<Self>) -> BoolTensor<Self> {
203        let device_src: &B::Device = tensor.client.device();
204
205        if device_src == device_dst {
206            return tensor;
207        }
208
209        let id = tensor.stream;
210        let client_dst = get_client::<B>(device_dst);
211        let client_src = tensor.client.clone();
212
213        GlobalFusionClient::change_client_bool::<B>(tensor.into_ir(), client_src, client_dst, id)
214    }
215
216    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
217        if tensor.shape == shape {
218            return tensor;
219        }
220
221        #[derive(new, Debug)]
222        struct ReshapeDimsOps<B: FusionBackend> {
223            desc: ShapeOpIr,
224            _b: PhantomData<B>,
225        }
226
227        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
228            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
229                let input = handles.get_bool_tensor::<B>(&self.desc.input);
230                let output = B::bool_reshape(input, self.desc.out.shape.clone());
231                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
232            }
233        }
234
235        let streams = OperationStreams::with_inputs([&tensor]);
236
237        let client = tensor.client.clone();
238        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
239
240        client
241            .register(
242                streams,
243                OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),
244                ReshapeDimsOps::<B>::new(desc),
245            )
246            .output()
247    }
248
249    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
250        #[derive(new, Debug)]
251        struct SliceOps<B: FusionBackend> {
252            desc: SliceOpIr,
253            _b: PhantomData<B>,
254        }
255
256        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
257            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
258                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
259
260                let output = B::bool_slice(tensor, self.desc.ranges.as_slice());
261
262                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
263            }
264        }
265
266        let streams = OperationStreams::with_inputs([&tensor]);
267
268        let client = tensor.client.clone();
269        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
270            client.create_empty_handle()
271        });
272
273        client
274            .register(
275                streams,
276                OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),
277                SliceOps::<B>::new(desc),
278            )
279            .output()
280    }
281
282    fn bool_slice_assign(
283        tensor: BoolTensor<Self>,
284        slices: &[Slice],
285        value: BoolTensor<Self>,
286    ) -> BoolTensor<Self> {
287        #[derive(new, Debug)]
288        struct SliceAssignOps<B: FusionBackend> {
289            desc: SliceAssignOpIr,
290            _b: PhantomData<B>,
291        }
292
293        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
294            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
295                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
296                let value = handles.get_bool_tensor::<B>(&self.desc.value);
297
298                let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);
299
300                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
301            }
302        }
303
304        let streams = OperationStreams::with_inputs([&tensor, &value]);
305
306        let client = tensor.client.clone();
307        let desc =
308            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
309                client.create_empty_handle()
310            });
311
312        client
313            .register(
314                streams,
315                OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),
316                SliceAssignOps::<B>::new(desc),
317            )
318            .output()
319    }
320
321    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
322        #[derive(new, Debug)]
323        struct CatOps<B: FusionBackend> {
324            desc: CatOpIr,
325            _b: PhantomData<B>,
326        }
327
328        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
329            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
330                let tensors = self
331                    .desc
332                    .tensors
333                    .iter()
334                    .map(|tensor| handles.get_bool_tensor::<B>(tensor))
335                    .collect();
336
337                let output = B::bool_cat(tensors, self.desc.dim);
338
339                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
340            }
341        }
342
343        let streams = OperationStreams::with_inputs(&tensors);
344
345        let client = tensors.first().unwrap().client.clone();
346        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
347        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
348
349        client
350            .register(
351                streams,
352                OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),
353                CatOps::<B>::new(desc),
354            )
355            .output()
356    }
357
358    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
359        #[derive(new, Debug)]
360        struct EqualOps<B: FusionBackend> {
361            desc: BinaryOpIr,
362            _b: PhantomData<B>,
363        }
364
365        impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
366            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
367                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
368                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
369                let output = B::bool_equal(lhs, rhs);
370                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
371            }
372        }
373
374        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
375
376        let client = lhs.client.clone();
377        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
378            client.create_empty_handle()
379        });
380
381        client
382            .register(
383                streams,
384                OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),
385                EqualOps::<B>::new(desc),
386            )
387            .output()
388    }
389
390    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
391        #[derive(new, Debug)]
392        struct NotOps<B: FusionBackend> {
393            desc: UnaryOpIr,
394            _b: PhantomData<B>,
395        }
396
397        impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
398            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
399                let input = handles.get_bool_tensor::<B>(&self.desc.input);
400                let output = B::bool_not(input);
401                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
402            }
403        }
404
405        let streams = OperationStreams::with_inputs([&tensor]);
406
407        let client = tensor.client.clone();
408        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
409
410        client
411            .register(
412                streams,
413                OperationIr::Bool(BoolOperationIr::Not(desc.clone())),
414                NotOps::<B>::new(desc),
415            )
416            .output()
417    }
418
419    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
420        #[derive(new, Debug)]
421        struct AndOps<B: FusionBackend> {
422            desc: BinaryOpIr,
423            _b: PhantomData<B>,
424        }
425
426        impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
427            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
428                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
429                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
430                let output = B::bool_and(lhs, rhs);
431                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
432            }
433        }
434
435        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
436
437        let client = lhs.client.clone();
438        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
439            client.create_empty_handle()
440        });
441
442        client
443            .register(
444                streams,
445                OperationIr::Bool(BoolOperationIr::And(desc.clone())),
446                AndOps::<B>::new(desc),
447            )
448            .output()
449    }
450
451    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
452        #[derive(new, Debug)]
453        struct OrOps<B: FusionBackend> {
454            desc: BinaryOpIr,
455            _b: PhantomData<B>,
456        }
457
458        impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
459            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
460                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
461                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
462                let output = B::bool_or(lhs, rhs);
463                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
464            }
465        }
466
467        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
468
469        let client = lhs.client.clone();
470        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
471            client.create_empty_handle()
472        });
473        client
474            .register(
475                streams,
476                OperationIr::Bool(BoolOperationIr::Or(desc.clone())),
477                OrOps::<B>::new(desc),
478            )
479            .output()
480    }
481
482    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
483        #[derive(new, Debug)]
484        struct SwapDimsOps<B: FusionBackend> {
485            desc: SwapDimsOpIr,
486            _b: PhantomData<B>,
487        }
488
489        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
490            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
491                let input = handles.get_bool_tensor::<B>(&self.desc.input);
492                let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
493                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
494            }
495        }
496
497        let streams = OperationStreams::with_inputs([&tensor]);
498
499        let client = tensor.client.clone();
500        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
501            client.create_empty_handle()
502        });
503
504        client
505            .register(
506                streams,
507                OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),
508                SwapDimsOps::<B>::new(desc),
509            )
510            .output()
511    }
512
513    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
514        #[derive(new, Debug)]
515        struct PermuteDimsOps<B: FusionBackend> {
516            desc: PermuteOpIr,
517            _b: PhantomData<B>,
518        }
519
520        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
521            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
522                let input = handles.get_bool_tensor::<B>(&self.desc.input);
523                let output = B::bool_permute(input, self.desc.axes.as_slice());
524                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
525            }
526        }
527
528        let streams = OperationStreams::with_inputs([&tensor]);
529
530        let client = tensor.client.clone();
531        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
532            client.create_empty_handle()
533        });
534
535        client
536            .register(
537                streams,
538                OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
539                PermuteDimsOps::<B>::new(desc),
540            )
541            .output()
542    }
543
544    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
545        #[derive(new, Debug)]
546        struct ExpandOps<B: FusionBackend> {
547            desc: ShapeOpIr,
548            _b: PhantomData<B>,
549        }
550
551        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
552            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
553                let input = handles.get_bool_tensor::<B>(&self.desc.input);
554                let output = B::bool_expand(input, self.desc.out.shape.clone());
555
556                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
557            }
558        }
559
560        let streams = OperationStreams::with_inputs([&tensor]);
561
562        let client = tensor.client.clone();
563        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
564
565        client
566            .register(
567                streams,
568                OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),
569                ExpandOps::<B>::new(desc),
570            )
571            .output()
572    }
573
574    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
575        #[derive(new, Debug)]
576        struct FlipOps<B: FusionBackend> {
577            desc: FlipOpIr,
578            _b: PhantomData<B>,
579        }
580
581        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
582            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
583                let input = handles.get_bool_tensor::<B>(&self.desc.input);
584                let output = B::bool_flip(input, self.desc.axes.as_slice());
585                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
586            }
587        }
588
589        let streams = OperationStreams::with_inputs([&tensor]);
590
591        let client = tensor.client.clone();
592        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
593            client.create_empty_handle()
594        });
595
596        client
597            .register(
598                streams,
599                OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),
600                FlipOps::<B>::new(desc),
601            )
602            .output()
603    }
604
605    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
606        #[derive(new, Debug)]
607        struct RepeatDimOps<B: FusionBackend> {
608            desc: RepeatDimOpIr,
609            _b: PhantomData<B>,
610        }
611
612        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
613            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
614                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
615
616                let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);
617
618                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
619            }
620        }
621
622        let streams = OperationStreams::with_inputs([&tensor]);
623
624        let client = tensor.client.clone();
625        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
626            client.create_empty_handle()
627        });
628
629        client
630            .register(
631                streams,
632                OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),
633                RepeatDimOps::<B>::new(desc),
634            )
635            .output()
636    }
637
638    fn bool_unfold(
639        tensor: BoolTensor<Self>,
640        dim: usize,
641        size: usize,
642        step: usize,
643    ) -> BoolTensor<Self> {
644        #[derive(new, Debug)]
645        struct UnfoldOps<B: FusionBackend> {
646            desc: UnfoldOpIr,
647            _b: PhantomData<B>,
648        }
649
650        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
651            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
652                let input = handles.get_bool_tensor::<B>(&self.desc.input);
653                let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
654
655                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
656            }
657        }
658
659        let streams = OperationStreams::with_inputs([&tensor]);
660
661        let client = tensor.client.clone();
662        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
663            client.create_empty_handle()
664        });
665
666        client
667            .register(
668                streams,
669                OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),
670                UnfoldOps::<B>::new(desc),
671            )
672            .output()
673    }
674
675    fn bool_mask_where(
676        tensor: BoolTensor<Self>,
677        mask: BoolTensor<Self>,
678        value: BoolTensor<Self>,
679    ) -> BoolTensor<Self> {
680        #[derive(new, Debug)]
681        struct MaskWhereOps<B: FusionBackend> {
682            desc: MaskWhereOpIr,
683            _b: PhantomData<B>,
684        }
685
686        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
687            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
688                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
689                let value = handles.get_bool_tensor::<B>(&self.desc.value);
690                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
691
692                let output = B::bool_mask_where(tensor, mask, value);
693
694                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
695            }
696        }
697
698        let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
699
700        let client = tensor.client.clone();
701        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
702            client.create_empty_handle()
703        });
704
705        client
706            .register(
707                streams,
708                OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc.clone())),
709                MaskWhereOps::<B>::new(desc),
710            )
711            .output()
712    }
713
714    fn bool_mask_fill(
715        tensor: BoolTensor<Self>,
716        mask: BoolTensor<Self>,
717        value: Scalar,
718    ) -> BoolTensor<Self> {
719        #[derive(new, Debug)]
720        struct MaskFillOps<B: FusionBackend> {
721            desc: MaskFillOpIr,
722            _b: PhantomData<B>,
723        }
724
725        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
726            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
727                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
728                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
729
730                let output = B::bool_mask_fill(tensor, mask, self.desc.value.into());
731
732                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
733            }
734        }
735
736        let streams = OperationStreams::with_inputs([&tensor, &mask]);
737
738        let client = tensor.client.clone();
739        let value = value.into();
740        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
741            client.create_empty_handle()
742        });
743
744        client
745            .register(
746                streams,
747                OperationIr::BaseBool(BaseOperationIr::MaskFill(desc.clone())),
748                MaskFillOps::<B>::new(desc),
749            )
750            .output()
751    }
752
753    fn bool_gather(
754        dim: usize,
755        tensor: BoolTensor<Self>,
756        indices: IntTensor<Self>,
757    ) -> BoolTensor<Self> {
758        #[derive(new, Debug)]
759        struct GatherOps<B: FusionBackend> {
760            desc: GatherOpIr,
761            _b: PhantomData<B>,
762        }
763
764        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
765            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
766                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
767                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
768
769                let output = B::bool_gather(self.desc.dim, tensor, indices);
770                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
771            }
772        }
773
774        let streams = OperationStreams::with_inputs([&tensor, &indices]);
775
776        let client = tensor.client.clone();
777        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
778            client.create_empty_handle()
779        });
780
781        client
782            .register(
783                streams,
784                OperationIr::BaseBool(BaseOperationIr::Gather(desc.clone())),
785                GatherOps::<B>::new(desc),
786            )
787            .output()
788    }
789
790    fn bool_scatter_or(
791        dim: usize,
792        tensor: BoolTensor<Self>,
793        indices: IntTensor<Self>,
794        value: BoolTensor<Self>,
795    ) -> BoolTensor<Self> {
796        #[derive(new, Debug)]
797        struct ScatterOps<B: FusionBackend> {
798            desc: ScatterOpIr,
799            _b: PhantomData<B>,
800        }
801
802        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
803            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
804                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
805                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
806                let value = handles.get_bool_tensor::<B>(&self.desc.value);
807
808                let output = B::bool_scatter_or(self.desc.dim, tensor, indices, value);
809
810                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
811            }
812        }
813
814        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
815
816        let client = tensor.client.clone();
817        let desc = ScatterOpIr::create(
818            tensor.into_ir(),
819            dim,
820            indices.into_ir(),
821            value.into_ir(),
822            IndexingUpdateOp::Add,
823            || client.create_empty_handle(),
824        );
825
826        client
827            .register(
828                streams,
829                OperationIr::BaseBool(BaseOperationIr::Scatter(desc.clone())),
830                ScatterOps::<B>::new(desc),
831            )
832            .output()
833    }
834
835    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
836        #[derive(new, Debug)]
837        struct EqualElemOps<B: FusionBackend> {
838            desc: ScalarOpIr,
839            _b: PhantomData<B>,
840        }
841        impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualElemOps<B> {
842            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
843                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
844                let output = B::bool_equal_elem(lhs, self.desc.rhs.into());
845                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
846            }
847        }
848
849        let streams = OperationStreams::with_inputs([&lhs]);
850
851        let dtype = lhs.dtype;
852        let client = lhs.client.clone();
853        let rhs = rhs.into();
854        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, dtype, || {
855            client.create_empty_handle()
856        });
857
858        client
859            .register(
860                streams,
861                OperationIr::BaseBool(BaseOperationIr::EqualElem(desc.clone())),
862                EqualElemOps::<B>::new(desc),
863            )
864            .output()
865    }
866
867    fn bool_select(
868        tensor: BoolTensor<Self>,
869        dim: usize,
870        indices: IntTensor<Self>,
871    ) -> BoolTensor<Self> {
872        #[derive(new, Debug)]
873        struct SelectOps<B: FusionBackend> {
874            desc: SelectOpIr,
875            _b: PhantomData<B>,
876        }
877
878        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
879            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
880                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
881                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
882
883                let output = B::bool_select(tensor, self.desc.dim, indices);
884
885                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
886            }
887        }
888
889        let streams = OperationStreams::with_inputs([&tensor, &indices]);
890
891        let client = tensor.client.clone();
892        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
893            client.create_empty_handle()
894        });
895
896        client
897            .register(
898                streams,
899                OperationIr::BaseBool(BaseOperationIr::Select(desc.clone())),
900                SelectOps::<B>::new(desc),
901            )
902            .output()
903    }
904
905    fn bool_select_or(
906        tensor: BoolTensor<Self>,
907        dim: usize,
908        indices: IntTensor<Self>,
909        value: BoolTensor<Self>,
910    ) -> BoolTensor<Self> {
911        #[derive(new, Debug)]
912        struct SelectAssignOps<B: FusionBackend> {
913            desc: SelectAssignOpIr,
914            _b: PhantomData<B>,
915        }
916
917        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
918            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
919                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
920                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
921                let value = handles.get_bool_tensor::<B>(&self.desc.value);
922
923                let output = B::bool_select_or(tensor, self.desc.dim, indices, value);
924
925                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
926            }
927        }
928
929        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
930
931        let client = tensor.client.clone();
932        let desc = SelectAssignOpIr::create(
933            tensor.into_ir(),
934            dim,
935            indices.into_ir(),
936            value.into_ir(),
937            IndexingUpdateOp::Add,
938            || client.create_empty_handle(),
939        );
940
941        client
942            .register(
943                streams,
944                OperationIr::BaseBool(BaseOperationIr::SelectAssign(desc.clone())),
945                SelectAssignOps::<B>::new(desc),
946            )
947            .output()
948    }
949}