1use crate::{
2 Fusion, FusionBackend, get_client,
3 stream::{OperationStreams, StreamId, execution::Operation},
4};
5use burn_ir::{
6 BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
7 InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
8 SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
9};
10use burn_tensor::ops::unfold::calculate_unfold_shape;
11use burn_tensor::{
12 Device, Element, Shape, Slice, TensorData, TensorMetadata,
13 ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor},
14};
15use std::marker::PhantomData;
16
17use super::NoOp;
18
19impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
20 fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
21 #[derive(new, Debug)]
22 struct EmptyOps<B: FusionBackend> {
23 desc: TensorIr,
24 device: Device<B>,
25 }
26
27 impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
28 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
29 let output = B::bool_empty(self.desc.shape.clone(), &self.device);
30 handles.register_bool_tensor::<B>(&self.desc.id, output);
31 }
32 }
33
34 let client = get_client::<B>(&device.clone());
35 let out = client.tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
36
37 let desc = out.to_ir_out();
38
39 client.register(
40 OperationStreams::default(),
41 OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
42 EmptyOps::<B>::new(desc, device.clone()),
43 );
44
45 out
46 }
47
48 fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
49 #[derive(new, Debug)]
50 struct ZerosOps<B: FusionBackend> {
51 desc: TensorIr,
52 device: Device<B>,
53 }
54
55 impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
56 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
57 let output = B::bool_zeros(self.desc.shape.clone(), &self.device);
58 handles.register_bool_tensor::<B>(&self.desc.id, output);
59 }
60 }
61
62 let client = get_client::<B>(&device.clone());
63 let out = client.tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
64
65 let desc = out.to_ir_out();
66
67 client.register(
68 OperationStreams::default(),
69 OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
70 ZerosOps::<B>::new(desc, device.clone()),
71 );
72
73 out
74 }
75
76 fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
77 #[derive(new, Debug)]
78 struct OnesOps<B: FusionBackend> {
79 desc: TensorIr,
80 device: Device<B>,
81 }
82
83 impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
84 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
85 let output = B::bool_ones(self.desc.shape.clone(), &self.device);
86 handles.register_bool_tensor::<B>(&self.desc.id, output);
87 }
88 }
89
90 let client = get_client::<B>(&device.clone());
91 let out = client.tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
92
93 let desc = out.to_ir_out();
94
95 client.register(
96 OperationStreams::default(),
97 OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
98 OnesOps::<B>::new(desc, device.clone()),
99 );
100
101 out
102 }
103
104 async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
105 tensor.bool_into_data::<B>().await
106 }
107
108 fn bool_from_data(data: burn_tensor::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
109 let stream = StreamId::current();
110 let client = get_client::<B>(&device.clone());
111 let tensor = B::bool_from_data(data, device);
112 let shape = tensor.shape();
113
114 let handle = B::bool_tensor_handle(tensor);
115 let out = client.register_tensor(handle, shape, stream, B::BoolElem::dtype());
116 let desc = out.to_ir_out();
117
118 client.register(
119 OperationStreams::default(),
120 OperationIr::Init(InitOperationIr { out: desc }),
121 NoOp::<B>::new(),
122 );
123
124 out
125 }
126
127 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
128 #[derive(new, Debug)]
129 struct IntoIntOps<B: FusionBackend> {
130 desc: UnaryOpIr,
131 _b: PhantomData<B>,
132 }
133
134 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
135 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
136 let input = handles.get_bool_tensor::<B>(&self.desc.input);
137 let output = B::bool_into_int(input);
138 handles.register_int_tensor::<B>(&self.desc.out.id, output);
139 }
140 }
141
142 let mut streams = OperationStreams::default();
143 streams.tensor(&tensor);
144
145 let out = tensor
146 .client
147 .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype());
148
149 let desc = UnaryOpIr {
150 input: tensor.into_ir(),
151 out: out.to_ir_out(),
152 };
153
154 out.client.register(
155 streams,
156 OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),
157 IntoIntOps::<B>::new(desc),
158 );
159
160 out
161 }
162
163 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
164 #[derive(new, Debug)]
165 struct IntoFloatOps<B: FusionBackend> {
166 desc: UnaryOpIr,
167 _b: PhantomData<B>,
168 }
169
170 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
171 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
172 let input = handles.get_bool_tensor::<B>(&self.desc.input);
173 let output = B::bool_into_float(input);
174 handles.register_float_tensor::<B>(&self.desc.out.id, output);
175 }
176 }
177
178 let mut streams = OperationStreams::default();
179 streams.tensor(&tensor);
180
181 let out = tensor
182 .client
183 .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
184
185 let desc = UnaryOpIr {
186 input: tensor.into_ir(),
187 out: out.to_ir_out(),
188 };
189 out.client.register(
190 streams,
191 OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),
192 IntoFloatOps::<B>::new(desc),
193 );
194
195 out
196 }
197
198 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
199 tensor.client.device().clone()
200 }
201
202 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
203 let device_original: &B::Device = tensor.client.device();
204 let device_target: B::Device = device.clone();
205
206 if device_original == &device_target {
207 return tensor;
208 }
209
210 let id = tensor.stream;
211 let client_target = get_client::<B>(&device_target);
212 let client_original = tensor.client.clone();
213
214 client_original
215 .clone()
216 .change_client_bool::<B>(tensor.into_ir(), client_target, id)
217 }
218
219 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
220 if tensor.shape == shape {
221 return tensor;
222 }
223
224 #[derive(new, Debug)]
225 struct ReshapeDimsOps<B: FusionBackend> {
226 desc: UnaryOpIr,
227 _b: PhantomData<B>,
228 }
229
230 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
231 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
232 let input = handles.get_bool_tensor::<B>(&self.desc.input);
233 let output = B::bool_reshape(input, self.desc.out.shape.clone());
234 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
235 }
236 }
237
238 let mut streams = OperationStreams::default();
239 streams.tensor(&tensor);
240
241 let out = tensor
242 .client
243 .tensor_uninitialized(shape, B::BoolElem::dtype());
244
245 let desc = UnaryOpIr {
246 input: tensor.into_ir(),
247 out: out.to_ir_out(),
248 };
249 out.client.register(
250 streams,
251 OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),
252 ReshapeDimsOps::<B>::new(desc),
253 );
254
255 out
256 }
257
258 fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {
259 #[derive(new, Debug)]
260 struct SliceOps<B: FusionBackend> {
261 desc: SliceOpIr,
262 _b: PhantomData<B>,
263 }
264
265 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
266 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
267 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
268
269 let output = B::bool_slice(tensor, self.desc.ranges.as_slice());
270
271 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
272 }
273 }
274
275 let shape = tensor.shape.clone().slice(slices).unwrap();
276
277 let mut streams = OperationStreams::default();
278 streams.tensor(&tensor);
279
280 let out = tensor
281 .client
282 .tensor_uninitialized(shape, B::BoolElem::dtype());
283
284 let desc = SliceOpIr {
285 tensor: tensor.into_ir(),
286 ranges: slices.to_vec(),
287 out: out.to_ir_out(),
288 };
289 out.client.register(
290 streams,
291 OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),
292 SliceOps::<B>::new(desc),
293 );
294
295 out
296 }
297
298 fn bool_slice_assign(
299 tensor: BoolTensor<Self>,
300 ranges: &[burn_tensor::Slice],
301 value: BoolTensor<Self>,
302 ) -> BoolTensor<Self> {
303 #[derive(new, Debug)]
304 struct SliceAssignOps<B: FusionBackend> {
305 desc: SliceAssignOpIr,
306 _b: PhantomData<B>,
307 }
308
309 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
310 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
311 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
312 let value = handles.get_bool_tensor::<B>(&self.desc.value);
313
314 let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);
315
316 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
317 }
318 }
319
320 let shape = tensor.shape.clone();
321 let mut streams = OperationStreams::default();
322 streams.tensor(&tensor);
323 streams.tensor(&value);
324
325 let out = tensor
326 .client
327 .tensor_uninitialized(shape, B::BoolElem::dtype());
328
329 let desc = SliceAssignOpIr {
330 tensor: tensor.into_ir(),
331 ranges: ranges.to_vec(),
332 value: value.into_ir(),
333 out: out.to_ir_out(),
334 };
335
336 out.client.register(
337 streams,
338 OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),
339 SliceAssignOps::<B>::new(desc),
340 );
341
342 out
343 }
344
345 fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
346 #[derive(new, Debug)]
347 struct CatOps<B: FusionBackend> {
348 desc: CatOpIr,
349 _b: PhantomData<B>,
350 }
351
352 impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
353 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
354 let tensors = self
355 .desc
356 .tensors
357 .iter()
358 .map(|tensor| handles.get_bool_tensor::<B>(tensor))
359 .collect();
360
361 let output = B::bool_cat(tensors, self.desc.dim);
362
363 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
364 }
365 }
366
367 let tensor_first = tensors.first().unwrap();
368 let client = tensor_first.client.clone();
369
370 let shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap();
372 let mut streams = OperationStreams::default();
373 tensors.iter().for_each(|t| streams.tensor(t));
374
375 let out = client.tensor_uninitialized(shape, B::BoolElem::dtype());
376
377 let desc = CatOpIr {
378 tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
379 dim,
380 out: out.to_ir_out(),
381 };
382 client.register(
383 streams,
384 OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),
385 CatOps::<B>::new(desc),
386 );
387
388 out
389 }
390
391 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
392 #[derive(new, Debug)]
393 struct EqualOps<B: FusionBackend> {
394 desc: BinaryOpIr,
395 _b: PhantomData<B>,
396 }
397
398 impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
399 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
400 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
401 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
402 let output = B::bool_equal(lhs, rhs);
403 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
404 }
405 }
406
407 let mut streams = OperationStreams::default();
408 streams.tensor(&lhs);
409 streams.tensor(&rhs);
410
411 let out = lhs.client.tensor_uninitialized(
412 lhs.shape.broadcast(&rhs.shape).unwrap(),
413 B::BoolElem::dtype(),
414 );
415
416 let desc = BinaryOpIr {
417 lhs: lhs.into_ir(),
418 rhs: rhs.into_ir(),
419 out: out.to_ir_out(),
420 };
421 out.client.register(
422 streams,
423 OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),
424 EqualOps::<B>::new(desc),
425 );
426
427 out
428 }
429
430 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
431 #[derive(new, Debug)]
432 struct NotOps<B: FusionBackend> {
433 desc: UnaryOpIr,
434 _b: PhantomData<B>,
435 }
436
437 impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
438 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
439 let input = handles.get_bool_tensor::<B>(&self.desc.input);
440 let output = B::bool_not(input);
441 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
442 }
443 }
444
445 let mut streams = OperationStreams::default();
446 streams.tensor(&tensor);
447
448 let out = tensor
449 .client
450 .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
451
452 let desc = UnaryOpIr {
453 input: tensor.into_ir(),
454 out: out.to_ir_out(),
455 };
456
457 out.client.register(
458 streams,
459 OperationIr::Bool(BoolOperationIr::Not(desc.clone())),
460 NotOps::<B>::new(desc),
461 );
462
463 out
464 }
465
466 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
467 #[derive(new, Debug)]
468 struct AndOps<B: FusionBackend> {
469 desc: BinaryOpIr,
470 _b: PhantomData<B>,
471 }
472
473 impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
474 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
475 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
476 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
477 let output = B::bool_and(lhs, rhs);
478 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
479 }
480 }
481
482 let mut streams = OperationStreams::default();
483 streams.tensor(&lhs);
484 streams.tensor(&rhs);
485
486 let out = lhs.client.tensor_uninitialized(
487 lhs.shape.broadcast(&rhs.shape).unwrap(),
488 B::BoolElem::dtype(),
489 );
490
491 let desc = BinaryOpIr {
492 lhs: lhs.into_ir(),
493 rhs: rhs.into_ir(),
494 out: out.to_ir_out(),
495 };
496 out.client.register(
497 streams,
498 OperationIr::Bool(BoolOperationIr::And(desc.clone())),
499 AndOps::<B>::new(desc),
500 );
501
502 out
503 }
504
505 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
506 #[derive(new, Debug)]
507 struct OrOps<B: FusionBackend> {
508 desc: BinaryOpIr,
509 _b: PhantomData<B>,
510 }
511
512 impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
513 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
514 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
515 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
516 let output = B::bool_or(lhs, rhs);
517 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
518 }
519 }
520
521 let mut streams = OperationStreams::default();
522 streams.tensor(&lhs);
523 streams.tensor(&rhs);
524
525 let out = lhs.client.tensor_uninitialized(
526 lhs.shape.broadcast(&rhs.shape).unwrap(),
527 B::BoolElem::dtype(),
528 );
529
530 let desc = BinaryOpIr {
531 lhs: lhs.into_ir(),
532 rhs: rhs.into_ir(),
533 out: out.to_ir_out(),
534 };
535 out.client.register(
536 streams,
537 OperationIr::Bool(BoolOperationIr::Or(desc.clone())),
538 OrOps::<B>::new(desc),
539 );
540
541 out
542 }
543
544 fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
545 #[derive(new, Debug)]
546 struct SwapDimsOps<B: FusionBackend> {
547 desc: SwapDimsOpIr,
548 _b: PhantomData<B>,
549 }
550
551 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
552 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
553 let input = handles.get_bool_tensor::<B>(&self.desc.input);
554 let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
555 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
556 }
557 }
558
559 let mut streams = OperationStreams::default();
560 streams.tensor(&tensor);
561
562 let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
563 let out = tensor
564 .client
565 .tensor_uninitialized(shape, B::BoolElem::dtype());
566
567 let desc = SwapDimsOpIr {
568 input: tensor.into_ir(),
569 dim1,
570 dim2,
571 out: out.to_ir_out(),
572 };
573 out.client.register(
574 streams,
575 OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),
576 SwapDimsOps::<B>::new(desc),
577 );
578
579 out
580 }
581
582 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
583 #[derive(new, Debug)]
584 struct PermuteDimsOps<B: FusionBackend> {
585 desc: PermuteOpIr,
586 _b: PhantomData<B>,
587 }
588
589 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
590 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
591 let input = handles.get_bool_tensor::<B>(&self.desc.input);
592 let output = B::bool_permute(input, self.desc.axes.as_slice());
593 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
594 }
595 }
596
597 let mut streams = OperationStreams::default();
598 streams.tensor(&tensor);
599
600 let shape = tensor.shape.clone().permute(axes).unwrap();
602 let out = tensor
603 .client
604 .tensor_uninitialized(shape, B::BoolElem::dtype());
605
606 let desc = PermuteOpIr {
607 input: tensor.into_ir(),
608 axes: axes.to_vec(),
609 out: out.to_ir_out(),
610 };
611
612 out.client.register(
613 streams,
614 OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
615 PermuteDimsOps::<B>::new(desc),
616 );
617
618 out
619 }
620
621 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
622 #[derive(new, Debug)]
623 struct ExpandOps<B: FusionBackend> {
624 desc: ExpandOpIr,
625 _b: PhantomData<B>,
626 }
627
628 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
629 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
630 let input = handles.get_bool_tensor::<B>(&self.desc.input);
631 let output = B::bool_expand(input, self.desc.shape.clone());
632
633 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
634 }
635 }
636
637 let mut streams = OperationStreams::default();
638 streams.tensor(&tensor);
639
640 let out = tensor
641 .client
642 .tensor_uninitialized(shape.clone(), B::BoolElem::dtype());
643
644 let desc = ExpandOpIr {
645 input: tensor.into_ir(),
646 shape,
647 out: out.to_ir_out(),
648 };
649
650 out.client.register(
651 streams,
652 OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),
653 ExpandOps::<B>::new(desc),
654 );
655
656 out
657 }
658
659 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
660 #[derive(new, Debug)]
661 struct FlipOps<B: FusionBackend> {
662 desc: FlipOpIr,
663 _b: PhantomData<B>,
664 }
665
666 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
667 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
668 let input = handles.get_bool_tensor::<B>(&self.desc.input);
669 let output = B::bool_flip(input, self.desc.axes.as_slice());
670 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
671 }
672 }
673
674 let mut streams = OperationStreams::default();
675 streams.tensor(&tensor);
676
677 let out = tensor
678 .client
679 .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
680
681 let desc = FlipOpIr {
682 input: tensor.into_ir(),
683 out: out.to_ir_out(),
684 axes: axes.to_vec(),
685 };
686
687 out.client.register(
688 streams,
689 OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),
690 FlipOps::<B>::new(desc),
691 );
692
693 out
694 }
695
696 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
697 #[derive(new, Debug)]
698 struct RepeatDimOps<B: FusionBackend> {
699 desc: RepeatDimOpIr,
700 _b: PhantomData<B>,
701 }
702
703 impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
704 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
705 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
706
707 let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);
708
709 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
710 }
711 }
712
713 let mut streams = OperationStreams::default();
714 streams.tensor(&tensor);
715
716 let shape = tensor.shape.clone().repeat(dim, times);
717 let out = tensor
718 .client
719 .tensor_uninitialized(shape, B::BoolElem::dtype());
720
721 let desc = RepeatDimOpIr {
722 tensor: tensor.into_ir(),
723 dim,
724 times,
725 out: out.to_ir_out(),
726 };
727 out.client.register(
728 streams,
729 OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),
730 RepeatDimOps::<B>::new(desc),
731 );
732
733 out
734 }
735
736 fn bool_unfold(
737 tensor: BoolTensor<Self>,
738 dim: usize,
739 size: usize,
740 step: usize,
741 ) -> BoolTensor<Self> {
742 #[derive(new, Debug)]
743 struct UnfoldOps<B: FusionBackend> {
744 desc: UnfoldOpIr,
745 _b: PhantomData<B>,
746 }
747
748 impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
749 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
750 let input = handles.get_bool_tensor::<B>(&self.desc.input);
751 let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
752
753 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
754 }
755 }
756
757 let mut streams = OperationStreams::default();
758 streams.tensor(&tensor);
759
760 let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
761 let out = tensor
762 .client
763 .tensor_uninitialized(Shape::from(shape), tensor.dtype);
764
765 let desc = UnfoldOpIr {
766 input: tensor.into_ir(),
767 out: out.to_ir_out(),
768 dim,
769 size,
770 step,
771 };
772
773 out.client.register(
774 streams,
775 OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),
776 UnfoldOps::<B>::new(desc),
777 );
778
779 out
780 }
781}