1use crate::{
2 Fusion, FusionBackend,
3 client::GlobalFusionClient,
4 get_client,
5 stream::{OperationStreams, execution::Operation},
6};
7use burn_backend::{
8 BoolDType, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice, TensorData,
9 ops::BoolTensorOps,
10 tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
11};
12use burn_ir::{
13 BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,
14 GatherOpIr, HandleContainer, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr,
15 OperationOutput, PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr,
16 SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr,
17 UnfoldOpIr,
18};
19use std::marker::PhantomData;
20
21use super::NoOp;
22
23impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
24 fn bool_empty(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
25 #[derive(new, Debug)]
26 struct EmptyOps<B: FusionBackend> {
27 desc: TensorIr,
28 device: Device<B>,
29 }
30
31 impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
32 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
33 let output = B::bool_empty(
34 self.desc.shape.clone(),
35 &self.device,
36 self.desc.dtype.into(),
37 );
38 handles.register_bool_tensor::<B>(&self.desc.id, output);
39 }
40 }
41
42 let client = get_client::<B>(device);
43 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
44
45 client
46 .register(
47 OperationStreams::default(),
48 OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
49 EmptyOps::<B>::new(desc.out, device.clone()),
50 )
51 .output()
52 }
53
54 fn bool_zeros(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
55 #[derive(new, Debug)]
56 struct ZerosOps<B: FusionBackend> {
57 desc: TensorIr,
58 device: Device<B>,
59 }
60
61 impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
62 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
63 let output = B::bool_zeros(
64 self.desc.shape.clone(),
65 &self.device,
66 self.desc.dtype.into(),
67 );
68 handles.register_bool_tensor::<B>(&self.desc.id, output);
69 }
70 }
71
72 let client = get_client::<B>(device);
73 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
74
75 client
76 .register(
77 OperationStreams::default(),
78 OperationIr::BaseBool(BaseOperationIr::Zeros(desc.clone())),
79 ZerosOps::<B>::new(desc.out, device.clone()),
80 )
81 .output()
82 }
83
84 fn bool_ones(shape: Shape, device: &Device<Self>, dtype: BoolDType) -> BoolTensor<Self> {
85 #[derive(new, Debug)]
86 struct OnesOps<B: FusionBackend> {
87 desc: TensorIr,
88 device: Device<B>,
89 }
90
91 impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
92 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
93 let output = B::bool_ones(
94 self.desc.shape.clone(),
95 &self.device,
96 self.desc.dtype.into(),
97 );
98 handles.register_bool_tensor::<B>(&self.desc.id, output);
99 }
100 }
101
102 let client = get_client::<B>(device);
103 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
104
105 client
106 .register(
107 OperationStreams::default(),
108 OperationIr::BaseBool(BaseOperationIr::Ones(desc.clone())),
109 OnesOps::<B>::new(desc.out, device.clone()),
110 )
111 .output()
112 }
113
114 async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {
115 tensor.bool_into_data::<B>().await
116 }
117
118 fn bool_from_data(data: burn_backend::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
119 let client = get_client::<B>(device);
120 let dtype = data.dtype;
121 let tensor = B::bool_from_data(data, device);
122 let shape = burn_backend::TensorMetadata::shape(&tensor);
123
124 let handle = B::bool_tensor_handle(tensor);
125 let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
126
127 client
128 .register(
129 OperationStreams::default(),
130 OperationIr::Init(desc),
131 NoOp::<B>::new(),
132 )
133 .output()
134 }
135
136 fn bool_into_int(tensor: BoolTensor<Self>, out_dtype: IntDType) -> IntTensor<Self> {
137 #[derive(new, Debug)]
138 struct IntoIntOps<B: FusionBackend> {
139 desc: CastOpIr,
140 _b: PhantomData<B>,
141 }
142
143 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
144 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
145 let input = handles.get_bool_tensor::<B>(&self.desc.input);
146 let output = B::bool_into_int(input, self.desc.out.dtype.into());
147 handles.register_int_tensor::<B>(&self.desc.out.id, output);
148 }
149 }
150
151 let streams = OperationStreams::with_inputs([&tensor]);
152
153 let client = tensor.client.clone();
154 let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
155 client.create_empty_handle()
156 });
157
158 client
159 .register(
160 streams,
161 OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),
162 IntoIntOps::<B>::new(desc),
163 )
164 .output()
165 }
166
167 fn bool_into_float(tensor: BoolTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
168 #[derive(new, Debug)]
169 struct IntoFloatOps<B: FusionBackend> {
170 desc: CastOpIr,
171 _b: PhantomData<B>,
172 }
173
174 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
175 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
176 let input = handles.get_bool_tensor::<B>(&self.desc.input);
177 let output = B::bool_into_float(input, self.desc.out.dtype.into());
178 handles.register_float_tensor::<B>(&self.desc.out.id, output);
179 }
180 }
181
182 let streams = OperationStreams::with_inputs([&tensor]);
183
184 let client = tensor.client.clone();
185 let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
186 client.create_empty_handle()
187 });
188
189 client
190 .register(
191 streams,
192 OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),
193 IntoFloatOps::<B>::new(desc),
194 )
195 .output()
196 }
197
198 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
199 tensor.client.device().clone()
200 }
201
202 fn bool_to_device(tensor: BoolTensor<Self>, device_dst: &Device<Self>) -> BoolTensor<Self> {
203 let device_src: &B::Device = tensor.client.device();
204
205 if device_src == device_dst {
206 return tensor;
207 }
208
209 let id = tensor.stream;
210 let client_dst = get_client::<B>(device_dst);
211 let client_src = tensor.client.clone();
212
213 GlobalFusionClient::change_client_bool::<B>(tensor.into_ir(), client_src, client_dst, id)
214 }
215
216 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
217 if tensor.shape == shape {
218 return tensor;
219 }
220
221 #[derive(new, Debug)]
222 struct ReshapeDimsOps<B: FusionBackend> {
223 desc: ShapeOpIr,
224 _b: PhantomData<B>,
225 }
226
227 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
228 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
229 let input = handles.get_bool_tensor::<B>(&self.desc.input);
230 let output = B::bool_reshape(input, self.desc.out.shape.clone());
231 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
232 }
233 }
234
235 let streams = OperationStreams::with_inputs([&tensor]);
236
237 let client = tensor.client.clone();
238 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
239
240 client
241 .register(
242 streams,
243 OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),
244 ReshapeDimsOps::<B>::new(desc),
245 )
246 .output()
247 }
248
249 fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
250 #[derive(new, Debug)]
251 struct SliceOps<B: FusionBackend> {
252 desc: SliceOpIr,
253 _b: PhantomData<B>,
254 }
255
256 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
257 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
258 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
259
260 let output = B::bool_slice(tensor, self.desc.ranges.as_slice());
261
262 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
263 }
264 }
265
266 let streams = OperationStreams::with_inputs([&tensor]);
267
268 let client = tensor.client.clone();
269 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
270 client.create_empty_handle()
271 });
272
273 client
274 .register(
275 streams,
276 OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),
277 SliceOps::<B>::new(desc),
278 )
279 .output()
280 }
281
282 fn bool_slice_assign(
283 tensor: BoolTensor<Self>,
284 slices: &[Slice],
285 value: BoolTensor<Self>,
286 ) -> BoolTensor<Self> {
287 #[derive(new, Debug)]
288 struct SliceAssignOps<B: FusionBackend> {
289 desc: SliceAssignOpIr,
290 _b: PhantomData<B>,
291 }
292
293 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
294 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
295 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
296 let value = handles.get_bool_tensor::<B>(&self.desc.value);
297
298 let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);
299
300 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
301 }
302 }
303
304 let streams = OperationStreams::with_inputs([&tensor, &value]);
305
306 let client = tensor.client.clone();
307 let desc =
308 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
309 client.create_empty_handle()
310 });
311
312 client
313 .register(
314 streams,
315 OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),
316 SliceAssignOps::<B>::new(desc),
317 )
318 .output()
319 }
320
321 fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
322 #[derive(new, Debug)]
323 struct CatOps<B: FusionBackend> {
324 desc: CatOpIr,
325 _b: PhantomData<B>,
326 }
327
328 impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
329 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
330 let tensors = self
331 .desc
332 .tensors
333 .iter()
334 .map(|tensor| handles.get_bool_tensor::<B>(tensor))
335 .collect();
336
337 let output = B::bool_cat(tensors, self.desc.dim);
338
339 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
340 }
341 }
342
343 let streams = OperationStreams::with_inputs(&tensors);
344
345 let client = tensors.first().unwrap().client.clone();
346 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
347 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
348
349 client
350 .register(
351 streams,
352 OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),
353 CatOps::<B>::new(desc),
354 )
355 .output()
356 }
357
358 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
359 #[derive(new, Debug)]
360 struct EqualOps<B: FusionBackend> {
361 desc: BinaryOpIr,
362 _b: PhantomData<B>,
363 }
364
365 impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
366 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
367 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
368 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
369 let output = B::bool_equal(lhs, rhs);
370 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
371 }
372 }
373
374 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
375
376 let client = lhs.client.clone();
377 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
378 client.create_empty_handle()
379 });
380
381 client
382 .register(
383 streams,
384 OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),
385 EqualOps::<B>::new(desc),
386 )
387 .output()
388 }
389
390 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
391 #[derive(new, Debug)]
392 struct NotOps<B: FusionBackend> {
393 desc: UnaryOpIr,
394 _b: PhantomData<B>,
395 }
396
397 impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
398 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
399 let input = handles.get_bool_tensor::<B>(&self.desc.input);
400 let output = B::bool_not(input);
401 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
402 }
403 }
404
405 let streams = OperationStreams::with_inputs([&tensor]);
406
407 let client = tensor.client.clone();
408 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
409
410 client
411 .register(
412 streams,
413 OperationIr::Bool(BoolOperationIr::Not(desc.clone())),
414 NotOps::<B>::new(desc),
415 )
416 .output()
417 }
418
419 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
420 #[derive(new, Debug)]
421 struct AndOps<B: FusionBackend> {
422 desc: BinaryOpIr,
423 _b: PhantomData<B>,
424 }
425
426 impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
427 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
428 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
429 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
430 let output = B::bool_and(lhs, rhs);
431 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
432 }
433 }
434
435 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
436
437 let client = lhs.client.clone();
438 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
439 client.create_empty_handle()
440 });
441
442 client
443 .register(
444 streams,
445 OperationIr::Bool(BoolOperationIr::And(desc.clone())),
446 AndOps::<B>::new(desc),
447 )
448 .output()
449 }
450
451 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
452 #[derive(new, Debug)]
453 struct OrOps<B: FusionBackend> {
454 desc: BinaryOpIr,
455 _b: PhantomData<B>,
456 }
457
458 impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
459 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
460 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
461 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
462 let output = B::bool_or(lhs, rhs);
463 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
464 }
465 }
466
467 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
468
469 let client = lhs.client.clone();
470 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
471 client.create_empty_handle()
472 });
473 client
474 .register(
475 streams,
476 OperationIr::Bool(BoolOperationIr::Or(desc.clone())),
477 OrOps::<B>::new(desc),
478 )
479 .output()
480 }
481
482 fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
483 #[derive(new, Debug)]
484 struct SwapDimsOps<B: FusionBackend> {
485 desc: SwapDimsOpIr,
486 _b: PhantomData<B>,
487 }
488
489 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
490 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
491 let input = handles.get_bool_tensor::<B>(&self.desc.input);
492 let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
493 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
494 }
495 }
496
497 let streams = OperationStreams::with_inputs([&tensor]);
498
499 let client = tensor.client.clone();
500 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
501 client.create_empty_handle()
502 });
503
504 client
505 .register(
506 streams,
507 OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),
508 SwapDimsOps::<B>::new(desc),
509 )
510 .output()
511 }
512
513 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
514 #[derive(new, Debug)]
515 struct PermuteDimsOps<B: FusionBackend> {
516 desc: PermuteOpIr,
517 _b: PhantomData<B>,
518 }
519
520 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
521 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
522 let input = handles.get_bool_tensor::<B>(&self.desc.input);
523 let output = B::bool_permute(input, self.desc.axes.as_slice());
524 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
525 }
526 }
527
528 let streams = OperationStreams::with_inputs([&tensor]);
529
530 let client = tensor.client.clone();
531 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
532 client.create_empty_handle()
533 });
534
535 client
536 .register(
537 streams,
538 OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
539 PermuteDimsOps::<B>::new(desc),
540 )
541 .output()
542 }
543
544 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
545 #[derive(new, Debug)]
546 struct ExpandOps<B: FusionBackend> {
547 desc: ShapeOpIr,
548 _b: PhantomData<B>,
549 }
550
551 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
552 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
553 let input = handles.get_bool_tensor::<B>(&self.desc.input);
554 let output = B::bool_expand(input, self.desc.out.shape.clone());
555
556 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
557 }
558 }
559
560 let streams = OperationStreams::with_inputs([&tensor]);
561
562 let client = tensor.client.clone();
563 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
564
565 client
566 .register(
567 streams,
568 OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),
569 ExpandOps::<B>::new(desc),
570 )
571 .output()
572 }
573
574 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
575 #[derive(new, Debug)]
576 struct FlipOps<B: FusionBackend> {
577 desc: FlipOpIr,
578 _b: PhantomData<B>,
579 }
580
581 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
582 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
583 let input = handles.get_bool_tensor::<B>(&self.desc.input);
584 let output = B::bool_flip(input, self.desc.axes.as_slice());
585 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
586 }
587 }
588
589 let streams = OperationStreams::with_inputs([&tensor]);
590
591 let client = tensor.client.clone();
592 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
593 client.create_empty_handle()
594 });
595
596 client
597 .register(
598 streams,
599 OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),
600 FlipOps::<B>::new(desc),
601 )
602 .output()
603 }
604
605 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
606 #[derive(new, Debug)]
607 struct RepeatDimOps<B: FusionBackend> {
608 desc: RepeatDimOpIr,
609 _b: PhantomData<B>,
610 }
611
612 impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
613 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
614 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
615
616 let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);
617
618 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
619 }
620 }
621
622 let streams = OperationStreams::with_inputs([&tensor]);
623
624 let client = tensor.client.clone();
625 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
626 client.create_empty_handle()
627 });
628
629 client
630 .register(
631 streams,
632 OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),
633 RepeatDimOps::<B>::new(desc),
634 )
635 .output()
636 }
637
638 fn bool_unfold(
639 tensor: BoolTensor<Self>,
640 dim: usize,
641 size: usize,
642 step: usize,
643 ) -> BoolTensor<Self> {
644 #[derive(new, Debug)]
645 struct UnfoldOps<B: FusionBackend> {
646 desc: UnfoldOpIr,
647 _b: PhantomData<B>,
648 }
649
650 impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
651 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
652 let input = handles.get_bool_tensor::<B>(&self.desc.input);
653 let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
654
655 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
656 }
657 }
658
659 let streams = OperationStreams::with_inputs([&tensor]);
660
661 let client = tensor.client.clone();
662 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
663 client.create_empty_handle()
664 });
665
666 client
667 .register(
668 streams,
669 OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),
670 UnfoldOps::<B>::new(desc),
671 )
672 .output()
673 }
674
675 fn bool_mask_where(
676 tensor: BoolTensor<Self>,
677 mask: BoolTensor<Self>,
678 value: BoolTensor<Self>,
679 ) -> BoolTensor<Self> {
680 #[derive(new, Debug)]
681 struct MaskWhereOps<B: FusionBackend> {
682 desc: MaskWhereOpIr,
683 _b: PhantomData<B>,
684 }
685
686 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
687 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
688 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
689 let value = handles.get_bool_tensor::<B>(&self.desc.value);
690 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
691
692 let output = B::bool_mask_where(tensor, mask, value);
693
694 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
695 }
696 }
697
698 let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
699
700 let client = tensor.client.clone();
701 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
702 client.create_empty_handle()
703 });
704
705 client
706 .register(
707 streams,
708 OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc.clone())),
709 MaskWhereOps::<B>::new(desc),
710 )
711 .output()
712 }
713
714 fn bool_mask_fill(
715 tensor: BoolTensor<Self>,
716 mask: BoolTensor<Self>,
717 value: Scalar,
718 ) -> BoolTensor<Self> {
719 #[derive(new, Debug)]
720 struct MaskFillOps<B: FusionBackend> {
721 desc: MaskFillOpIr,
722 _b: PhantomData<B>,
723 }
724
725 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
726 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
727 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
728 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
729
730 let output = B::bool_mask_fill(tensor, mask, self.desc.value.into());
731
732 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
733 }
734 }
735
736 let streams = OperationStreams::with_inputs([&tensor, &mask]);
737
738 let client = tensor.client.clone();
739 let value = value.into();
740 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
741 client.create_empty_handle()
742 });
743
744 client
745 .register(
746 streams,
747 OperationIr::BaseBool(BaseOperationIr::MaskFill(desc.clone())),
748 MaskFillOps::<B>::new(desc),
749 )
750 .output()
751 }
752
753 fn bool_gather(
754 dim: usize,
755 tensor: BoolTensor<Self>,
756 indices: IntTensor<Self>,
757 ) -> BoolTensor<Self> {
758 #[derive(new, Debug)]
759 struct GatherOps<B: FusionBackend> {
760 desc: GatherOpIr,
761 _b: PhantomData<B>,
762 }
763
764 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
765 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
766 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
767 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
768
769 let output = B::bool_gather(self.desc.dim, tensor, indices);
770 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
771 }
772 }
773
774 let streams = OperationStreams::with_inputs([&tensor, &indices]);
775
776 let client = tensor.client.clone();
777 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
778 client.create_empty_handle()
779 });
780
781 client
782 .register(
783 streams,
784 OperationIr::BaseBool(BaseOperationIr::Gather(desc.clone())),
785 GatherOps::<B>::new(desc),
786 )
787 .output()
788 }
789
790 fn bool_scatter_or(
791 dim: usize,
792 tensor: BoolTensor<Self>,
793 indices: IntTensor<Self>,
794 value: BoolTensor<Self>,
795 ) -> BoolTensor<Self> {
796 #[derive(new, Debug)]
797 struct ScatterOps<B: FusionBackend> {
798 desc: ScatterOpIr,
799 _b: PhantomData<B>,
800 }
801
802 impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
803 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
804 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
805 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
806 let value = handles.get_bool_tensor::<B>(&self.desc.value);
807
808 let output = B::bool_scatter_or(self.desc.dim, tensor, indices, value);
809
810 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
811 }
812 }
813
814 let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
815
816 let client = tensor.client.clone();
817 let desc = ScatterOpIr::create(
818 tensor.into_ir(),
819 dim,
820 indices.into_ir(),
821 value.into_ir(),
822 IndexingUpdateOp::Add,
823 || client.create_empty_handle(),
824 );
825
826 client
827 .register(
828 streams,
829 OperationIr::BaseBool(BaseOperationIr::Scatter(desc.clone())),
830 ScatterOps::<B>::new(desc),
831 )
832 .output()
833 }
834
835 fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {
836 #[derive(new, Debug)]
837 struct EqualElemOps<B: FusionBackend> {
838 desc: ScalarOpIr,
839 _b: PhantomData<B>,
840 }
841 impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualElemOps<B> {
842 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
843 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
844 let output = B::bool_equal_elem(lhs, self.desc.rhs.into());
845 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
846 }
847 }
848
849 let streams = OperationStreams::with_inputs([&lhs]);
850
851 let dtype = lhs.dtype;
852 let client = lhs.client.clone();
853 let rhs = rhs.into();
854 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, dtype, || {
855 client.create_empty_handle()
856 });
857
858 client
859 .register(
860 streams,
861 OperationIr::BaseBool(BaseOperationIr::EqualElem(desc.clone())),
862 EqualElemOps::<B>::new(desc),
863 )
864 .output()
865 }
866
867 fn bool_select(
868 tensor: BoolTensor<Self>,
869 dim: usize,
870 indices: IntTensor<Self>,
871 ) -> BoolTensor<Self> {
872 #[derive(new, Debug)]
873 struct SelectOps<B: FusionBackend> {
874 desc: SelectOpIr,
875 _b: PhantomData<B>,
876 }
877
878 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
879 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
880 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
881 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
882
883 let output = B::bool_select(tensor, self.desc.dim, indices);
884
885 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
886 }
887 }
888
889 let streams = OperationStreams::with_inputs([&tensor, &indices]);
890
891 let client = tensor.client.clone();
892 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
893 client.create_empty_handle()
894 });
895
896 client
897 .register(
898 streams,
899 OperationIr::BaseBool(BaseOperationIr::Select(desc.clone())),
900 SelectOps::<B>::new(desc),
901 )
902 .output()
903 }
904
905 fn bool_select_or(
906 tensor: BoolTensor<Self>,
907 dim: usize,
908 indices: IntTensor<Self>,
909 value: BoolTensor<Self>,
910 ) -> BoolTensor<Self> {
911 #[derive(new, Debug)]
912 struct SelectAssignOps<B: FusionBackend> {
913 desc: SelectAssignOpIr,
914 _b: PhantomData<B>,
915 }
916
917 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
918 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
919 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
920 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
921 let value = handles.get_bool_tensor::<B>(&self.desc.value);
922
923 let output = B::bool_select_or(tensor, self.desc.dim, indices, value);
924
925 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
926 }
927 }
928
929 let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
930
931 let client = tensor.client.clone();
932 let desc = SelectAssignOpIr::create(
933 tensor.into_ir(),
934 dim,
935 indices.into_ir(),
936 value.into_ir(),
937 IndexingUpdateOp::Add,
938 || client.create_empty_handle(),
939 );
940
941 client
942 .register(
943 streams,
944 OperationIr::BaseBool(BaseOperationIr::SelectAssign(desc.clone())),
945 SelectAssignOps::<B>::new(desc),
946 )
947 .output()
948 }
949}