1use burn_ir::{
2 BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
3 InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
4 SwapDimsOpIr, TensorIr, UnaryOpIr,
5};
6use burn_tensor::{
7 Device, Element, Shape, TensorData, TensorMetadata,
8 ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape},
9};
10use std::marker::PhantomData;
11
12use crate::{
13 Fusion, FusionBackend,
14 client::FusionClient,
15 get_client,
16 stream::{OperationStreams, StreamId, execution::Operation},
17};
18
19use super::NoOp;
20
21impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
22 fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
23 #[derive(new, Debug)]
24 struct EmptyOps<B: FusionBackend> {
25 desc: TensorIr,
26 device: Device<B>,
27 }
28
29 impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
30 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
31 let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device);
32 handles.register_bool_tensor::<B>(&self.desc.id, output);
33 }
34 }
35
36 let client = get_client::<B>(&device.clone());
37 let out = client.tensor_uninitialized(shape.dims.clone(), B::BoolElem::dtype());
38
39 let desc = out.to_ir_out();
40
41 client.register(
42 OperationStreams::default(),
43 OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),
44 EmptyOps::<B>::new(desc, device.clone()),
45 );
46
47 out
48 }
49
50 async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
51 tensor.bool_into_data::<B>().await
52 }
53
54 fn bool_from_data(data: burn_tensor::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
55 let stream = StreamId::current();
56 let client = get_client::<B>(&device.clone());
57 let tensor = B::bool_from_data(data, device);
58 let shape = tensor.shape();
59
60 let handle = B::bool_tensor_handle(tensor);
61 let out = client.register_tensor(handle, shape.dims, stream, B::BoolElem::dtype());
62 let desc = out.to_ir_out();
63
64 client.register(
65 OperationStreams::default(),
66 OperationIr::Init(InitOperationIr { out: desc }),
67 NoOp::<B>::new(),
68 );
69
70 out
71 }
72
73 fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
74 #[derive(new, Debug)]
75 struct IntoIntOps<B: FusionBackend> {
76 desc: UnaryOpIr,
77 _b: PhantomData<B>,
78 }
79
80 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {
81 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
82 let input = handles.get_bool_tensor::<B>(&self.desc.input);
83 let output = B::bool_into_int(input);
84 handles.register_int_tensor::<B>(&self.desc.out.id, output);
85 }
86 }
87
88 let mut streams = OperationStreams::default();
89 streams.tensor(&tensor);
90
91 let out = tensor
92 .client
93 .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype());
94
95 let desc = UnaryOpIr {
96 input: tensor.into_ir(),
97 out: out.to_ir_out(),
98 };
99
100 out.client.register(
101 streams,
102 OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),
103 IntoIntOps::<B>::new(desc),
104 );
105
106 out
107 }
108
109 fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {
110 #[derive(new, Debug)]
111 struct IntoFloatOps<B: FusionBackend> {
112 desc: UnaryOpIr,
113 _b: PhantomData<B>,
114 }
115
116 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
117 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
118 let input = handles.get_bool_tensor::<B>(&self.desc.input);
119 let output = B::bool_into_float(input);
120 handles.register_float_tensor::<B>(&self.desc.out.id, output);
121 }
122 }
123
124 let mut streams = OperationStreams::default();
125 streams.tensor(&tensor);
126
127 let out = tensor
128 .client
129 .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
130
131 let desc = UnaryOpIr {
132 input: tensor.into_ir(),
133 out: out.to_ir_out(),
134 };
135 out.client.register(
136 streams,
137 OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),
138 IntoFloatOps::<B>::new(desc),
139 );
140
141 out
142 }
143
144 fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {
145 tensor.client.device().clone()
146 }
147
148 fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {
149 let device_original: &B::Device = tensor.client.device();
150 let device_target: B::Device = device.clone();
151
152 if device_original == &device_target {
153 return tensor;
154 }
155
156 let id = tensor.stream;
157 let client_target = get_client::<B>(&device_target);
158 let client_original = tensor.client.clone();
159
160 client_original
161 .clone()
162 .change_client_bool::<B>(tensor.into_ir(), client_target, id)
163 }
164
165 fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
166 #[derive(new, Debug)]
167 struct ReshapeDimsOps<B: FusionBackend> {
168 desc: UnaryOpIr,
169 _b: PhantomData<B>,
170 }
171
172 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
173 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
174 let input = handles.get_bool_tensor::<B>(&self.desc.input);
175 let output = B::bool_reshape(input, Shape::from(&self.desc.out.shape));
176 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
177 }
178 }
179
180 let mut streams = OperationStreams::default();
181 streams.tensor(&tensor);
182
183 let out = tensor
184 .client
185 .tensor_uninitialized(shape.dims, B::BoolElem::dtype());
186
187 let desc = UnaryOpIr {
188 input: tensor.into_ir(),
189 out: out.to_ir_out(),
190 };
191 out.client.register(
192 streams,
193 OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),
194 ReshapeDimsOps::<B>::new(desc),
195 );
196
197 out
198 }
199
200 fn bool_slice(tensor: BoolTensor<Self>, ranges: &[std::ops::Range<usize>]) -> BoolTensor<Self> {
201 #[derive(new, Debug)]
202 struct SliceOps<B: FusionBackend> {
203 desc: SliceOpIr,
204 _b: PhantomData<B>,
205 }
206
207 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
208 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
209 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
210
211 let output = B::bool_slice(tensor, self.desc.ranges.as_slice());
212
213 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
214 }
215 }
216
217 let ndims = burn_tensor::TensorMetadata::shape(&tensor).num_dims();
218 let mut shape: Vec<usize> = ranges.iter().map(|range| range.end - range.start).collect();
219
220 for i in shape.len()..ndims {
221 shape.push(tensor.shape[i]);
222 }
223
224 let mut streams = OperationStreams::default();
225 streams.tensor(&tensor);
226
227 let out = tensor
228 .client
229 .tensor_uninitialized(shape, B::BoolElem::dtype());
230
231 let desc = SliceOpIr {
232 tensor: tensor.into_ir(),
233 ranges: ranges.into(),
234 out: out.to_ir_out(),
235 };
236 out.client.register(
237 streams,
238 OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),
239 SliceOps::<B>::new(desc),
240 );
241
242 out
243 }
244
245 fn bool_slice_assign(
246 tensor: BoolTensor<Self>,
247 ranges: &[std::ops::Range<usize>],
248 value: BoolTensor<Self>,
249 ) -> BoolTensor<Self> {
250 #[derive(new, Debug)]
251 struct SliceAssignOps<B: FusionBackend> {
252 desc: SliceAssignOpIr,
253 _b: PhantomData<B>,
254 }
255
256 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
257 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
258 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
259 let value = handles.get_bool_tensor::<B>(&self.desc.value);
260
261 let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);
262
263 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
264 }
265 }
266
267 let shape: Vec<usize> = tensor.shape.clone();
268 let mut streams = OperationStreams::default();
269 streams.tensor(&tensor);
270 streams.tensor(&value);
271
272 let out = tensor
273 .client
274 .tensor_uninitialized(shape, B::BoolElem::dtype());
275
276 let desc = SliceAssignOpIr {
277 tensor: tensor.into_ir(),
278 ranges: ranges.into(),
279 value: value.into_ir(),
280 out: out.to_ir_out(),
281 };
282
283 out.client.register(
284 streams,
285 OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),
286 SliceAssignOps::<B>::new(desc),
287 );
288
289 out
290 }
291
292 fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {
293 #[derive(new, Debug)]
294 struct CatOps<B: FusionBackend> {
295 desc: CatOpIr,
296 _b: PhantomData<B>,
297 }
298
299 impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
300 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
301 let tensors = self
302 .desc
303 .tensors
304 .iter()
305 .map(|tensor| handles.get_bool_tensor::<B>(tensor))
306 .collect();
307
308 let output = B::bool_cat(tensors, self.desc.dim);
309
310 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
311 }
312 }
313
314 let tensor_first = tensors.first().unwrap();
315 let client = tensor_first.client.clone();
316
317 let mut shape: Vec<usize> = tensor_first.shape.clone();
319 let mut streams = OperationStreams::default();
320 tensors.iter().for_each(|t| streams.tensor(t));
321
322 shape[dim] = 0;
323 for tensor in tensors.iter() {
324 shape[dim] += tensor.shape[dim];
325 }
326
327 let out = client.tensor_uninitialized(shape, B::BoolElem::dtype());
328
329 let desc = CatOpIr {
330 tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
331 dim,
332 out: out.to_ir_out(),
333 };
334 client.register(
335 streams,
336 OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),
337 CatOps::<B>::new(desc),
338 );
339
340 out
341 }
342
343 fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
344 #[derive(new, Debug)]
345 struct EqualOps<B: FusionBackend> {
346 desc: BinaryOpIr,
347 _b: PhantomData<B>,
348 }
349
350 impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {
351 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
352 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
353 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
354 let output = B::bool_equal(lhs, rhs);
355 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
356 }
357 }
358
359 let mut streams = OperationStreams::default();
360 streams.tensor(&lhs);
361 streams.tensor(&rhs);
362
363 let out = lhs.client.tensor_uninitialized(
364 binary_ops_shape(&lhs.shape, &rhs.shape),
365 B::BoolElem::dtype(),
366 );
367
368 let desc = BinaryOpIr {
369 lhs: lhs.into_ir(),
370 rhs: rhs.into_ir(),
371 out: out.to_ir_out(),
372 };
373 out.client.register(
374 streams,
375 OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),
376 EqualOps::<B>::new(desc),
377 );
378
379 out
380 }
381
382 fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {
383 #[derive(new, Debug)]
384 struct NotOps<B: FusionBackend> {
385 desc: UnaryOpIr,
386 _b: PhantomData<B>,
387 }
388
389 impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {
390 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
391 let input = handles.get_bool_tensor::<B>(&self.desc.input);
392 let output = B::bool_not(input);
393 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
394 }
395 }
396
397 let mut streams = OperationStreams::default();
398 streams.tensor(&tensor);
399
400 let out = tensor
401 .client
402 .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
403
404 let desc = UnaryOpIr {
405 input: tensor.into_ir(),
406 out: out.to_ir_out(),
407 };
408
409 out.client.register(
410 streams,
411 OperationIr::Bool(BoolOperationIr::Not(desc.clone())),
412 NotOps::<B>::new(desc),
413 );
414
415 out
416 }
417
418 fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
419 #[derive(new, Debug)]
420 struct AndOps<B: FusionBackend> {
421 desc: BinaryOpIr,
422 _b: PhantomData<B>,
423 }
424
425 impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {
426 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
427 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
428 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
429 let output = B::bool_and(lhs, rhs);
430 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
431 }
432 }
433
434 let mut streams = OperationStreams::default();
435 streams.tensor(&lhs);
436 streams.tensor(&rhs);
437
438 let out = lhs.client.tensor_uninitialized(
439 binary_ops_shape(&lhs.shape, &rhs.shape),
440 B::BoolElem::dtype(),
441 );
442
443 let desc = BinaryOpIr {
444 lhs: lhs.into_ir(),
445 rhs: rhs.into_ir(),
446 out: out.to_ir_out(),
447 };
448 out.client.register(
449 streams,
450 OperationIr::Bool(BoolOperationIr::And(desc.clone())),
451 AndOps::<B>::new(desc),
452 );
453
454 out
455 }
456
457 fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {
458 #[derive(new, Debug)]
459 struct OrOps<B: FusionBackend> {
460 desc: BinaryOpIr,
461 _b: PhantomData<B>,
462 }
463
464 impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {
465 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
466 let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);
467 let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);
468 let output = B::bool_or(lhs, rhs);
469 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
470 }
471 }
472
473 let mut streams = OperationStreams::default();
474 streams.tensor(&lhs);
475 streams.tensor(&rhs);
476
477 let out = lhs.client.tensor_uninitialized(
478 binary_ops_shape(&lhs.shape, &rhs.shape),
479 B::BoolElem::dtype(),
480 );
481
482 let desc = BinaryOpIr {
483 lhs: lhs.into_ir(),
484 rhs: rhs.into_ir(),
485 out: out.to_ir_out(),
486 };
487 out.client.register(
488 streams,
489 OperationIr::Bool(BoolOperationIr::Or(desc.clone())),
490 OrOps::<B>::new(desc),
491 );
492
493 out
494 }
495
496 fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {
497 #[derive(new, Debug)]
498 struct SwapDimsOps<B: FusionBackend> {
499 desc: SwapDimsOpIr,
500 _b: PhantomData<B>,
501 }
502
503 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
504 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
505 let input = handles.get_bool_tensor::<B>(&self.desc.input);
506 let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);
507 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
508 }
509 }
510
511 let mut streams = OperationStreams::default();
512 streams.tensor(&tensor);
513
514 let mut shape = tensor.shape.clone();
515 shape[dim1] = tensor.shape[dim2];
516 shape[dim2] = tensor.shape[dim1];
517
518 let out = tensor
519 .client
520 .tensor_uninitialized(shape, B::BoolElem::dtype());
521
522 let desc = SwapDimsOpIr {
523 input: tensor.into_ir(),
524 dim1,
525 dim2,
526 out: out.to_ir_out(),
527 };
528 out.client.register(
529 streams,
530 OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),
531 SwapDimsOps::<B>::new(desc),
532 );
533
534 out
535 }
536
537 fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
538 #[derive(new, Debug)]
539 struct PermuteDimsOps<B: FusionBackend> {
540 desc: PermuteOpIr,
541 _b: PhantomData<B>,
542 }
543
544 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
545 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
546 let input = handles.get_bool_tensor::<B>(&self.desc.input);
547 let output = B::bool_permute(input, self.desc.axes.as_slice());
548 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
549 }
550 }
551
552 let mut streams = OperationStreams::default();
553 streams.tensor(&tensor);
554
555 let shape = axes.iter().map(|x| tensor.shape[*x]).collect();
557
558 let out = tensor
559 .client
560 .tensor_uninitialized(shape, B::BoolElem::dtype());
561
562 let desc = PermuteOpIr {
563 input: tensor.into_ir(),
564 axes: axes.to_vec(),
565 out: out.to_ir_out(),
566 };
567
568 out.client.register(
569 streams,
570 OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
571 PermuteDimsOps::<B>::new(desc),
572 );
573
574 out
575 }
576
577 fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
578 #[derive(new, Debug)]
579 struct ExpandOps<B: FusionBackend> {
580 desc: ExpandOpIr,
581 _b: PhantomData<B>,
582 }
583
584 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
585 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
586 let input = handles.get_bool_tensor::<B>(&self.desc.input);
587 let output = B::bool_expand(input, self.desc.shape.clone().into());
588
589 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
590 }
591 }
592
593 let mut streams = OperationStreams::default();
594 streams.tensor(&tensor);
595
596 let out = tensor
597 .client
598 .tensor_uninitialized(shape.dims.clone(), B::BoolElem::dtype());
599
600 let desc = ExpandOpIr {
601 input: tensor.into_ir(),
602 shape: shape.dims,
603 out: out.to_ir_out(),
604 };
605
606 out.client.register(
607 streams,
608 OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),
609 ExpandOps::<B>::new(desc),
610 );
611
612 out
613 }
614
615 fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
616 #[derive(new, Debug)]
617 struct FlipOps<B: FusionBackend> {
618 desc: FlipOpIr,
619 _b: PhantomData<B>,
620 }
621
622 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {
623 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
624 let input = handles.get_bool_tensor::<B>(&self.desc.input);
625 let output = B::bool_flip(input, self.desc.axes.as_slice());
626 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
627 }
628 }
629
630 let mut streams = OperationStreams::default();
631 streams.tensor(&tensor);
632
633 let out = tensor
634 .client
635 .tensor_uninitialized(tensor.shape.clone(), B::BoolElem::dtype());
636
637 let desc = FlipOpIr {
638 input: tensor.into_ir(),
639 out: out.to_ir_out(),
640 axes: axes.to_vec(),
641 };
642
643 out.client.register(
644 streams,
645 OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),
646 FlipOps::<B>::new(desc),
647 );
648
649 out
650 }
651
652 fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {
653 #[derive(new, Debug)]
654 struct RepeatDimOps<B: FusionBackend> {
655 desc: RepeatDimOpIr,
656 _b: PhantomData<B>,
657 }
658
659 impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
660 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
661 let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);
662
663 let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);
664
665 handles.register_bool_tensor::<B>(&self.desc.out.id, output);
666 }
667 }
668
669 let mut streams = OperationStreams::default();
670 streams.tensor(&tensor);
671
672 let mut shape = tensor.shape.clone();
673 shape[dim] *= times;
674 let out = tensor
675 .client
676 .tensor_uninitialized(shape, B::BoolElem::dtype());
677
678 let desc = RepeatDimOpIr {
679 tensor: tensor.into_ir(),
680 dim,
681 times,
682 out: out.to_ir_out(),
683 };
684 out.client.register(
685 streams,
686 OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),
687 RepeatDimOps::<B>::new(desc),
688 );
689
690 out
691 }
692}