Skip to main content

burn_fusion/ops/
int_tensor.rs

1use super::NoOp;
2use crate::{
3    Fusion, FusionBackend, binary_int_cmp_ops, binary_int_ops,
4    client::GlobalFusionClient,
5    get_client, reduce_int_ops, scalar_int_cmp_ops, scalar_int_ops,
6    stream::{OperationStreams, execution::Operation},
7    unary_int_ops,
8};
9use burn_backend::{
10    BoolDType, Distribution, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice,
11    TensorData,
12    ops::IntTensorOps,
13    tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
14};
15use burn_ir::*;
16use std::marker::PhantomData;
17
18impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
19    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
20        #[derive(new, Debug)]
21        struct EmptyOps<B: FusionBackend> {
22            desc: TensorIr,
23            device: Device<B>,
24        }
25
26        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
27            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
28                let output = B::int_empty(
29                    self.desc.shape.clone(),
30                    &self.device,
31                    self.desc.dtype.into(),
32                );
33                handles.register_int_tensor::<B>(&self.desc.id, output);
34            }
35        }
36
37        let client = get_client::<B>(device);
38        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
39
40        client
41            .register(
42                OperationStreams::default(),
43                OperationIr::BaseInt(BaseOperationIr::Empty(desc.clone())),
44                EmptyOps::<B>::new(desc.out, device.clone()),
45            )
46            .output()
47    }
48
49    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
50        tensor.int_into_data::<B>().await
51    }
52
53    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
54        let client = get_client::<B>(device);
55        let dtype = data.dtype;
56        let tensor = B::int_from_data(data, device);
57        let shape = burn_backend::TensorMetadata::shape(&tensor);
58
59        let handle = B::int_tensor_handle(tensor);
60        let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
61
62        client
63            .register(
64                OperationStreams::default(),
65                OperationIr::Init(desc),
66                NoOp::<B>::new(),
67            )
68            .output()
69    }
70
71    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
72        tensor.client.device().clone()
73    }
74
75    fn int_to_device(tensor: IntTensor<Self>, device_dst: &Device<Self>) -> IntTensor<Self> {
76        let device_src: &B::Device = tensor.client.device();
77
78        if device_src == device_dst {
79            return tensor;
80        }
81
82        let id = tensor.stream;
83        let client_dst = get_client::<B>(device_dst);
84        let client_src = tensor.client.clone();
85
86        GlobalFusionClient::change_client_int::<B>(tensor.into_ir(), client_src, client_dst, id)
87    }
88
89    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
90        if tensor.shape == shape {
91            return tensor;
92        }
93
94        #[derive(new, Debug)]
95        struct ReshapeDimsOps<B: FusionBackend> {
96            desc: ShapeOpIr,
97            _b: PhantomData<B>,
98        }
99
100        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
101            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
102                let input = handles.get_int_tensor::<B>(&self.desc.input);
103                let output = B::int_reshape(input, self.desc.out.shape.clone());
104                handles.register_int_tensor::<B>(&self.desc.out.id, output);
105            }
106        }
107
108        let streams = OperationStreams::with_inputs([&tensor]);
109
110        let client = tensor.client.clone();
111        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
112
113        client
114            .register(
115                streams,
116                OperationIr::BaseInt(BaseOperationIr::Reshape(desc.clone())),
117                ReshapeDimsOps::<B>::new(desc),
118            )
119            .output()
120    }
121
122    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
123        #[derive(new, Debug)]
124        struct SliceOps<B: FusionBackend> {
125            desc: SliceOpIr,
126            _b: PhantomData<B>,
127        }
128
129        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
130            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
131                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
132
133                let output = B::int_slice(tensor, self.desc.ranges.as_slice());
134
135                handles.register_int_tensor::<B>(&self.desc.out.id, output);
136            }
137        }
138
139        let streams = OperationStreams::with_inputs([&tensor]);
140
141        let client = tensor.client.clone();
142        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
143            client.create_empty_handle()
144        });
145
146        client
147            .register(
148                streams,
149                OperationIr::BaseInt(BaseOperationIr::Slice(desc.clone())),
150                SliceOps::<B>::new(desc),
151            )
152            .output()
153    }
154
155    fn int_slice_assign(
156        tensor: IntTensor<Self>,
157        slices: &[burn_backend::Slice],
158        value: IntTensor<Self>,
159    ) -> IntTensor<Self> {
160        #[derive(new, Debug)]
161        struct SliceAssignOps<B: FusionBackend> {
162            desc: SliceAssignOpIr,
163            _b: PhantomData<B>,
164        }
165
166        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
167            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
168                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
169                let value = handles.get_int_tensor::<B>(&self.desc.value);
170
171                let output = B::int_slice_assign(tensor, self.desc.ranges.as_slice(), value);
172
173                handles.register_int_tensor::<B>(&self.desc.out.id, output);
174            }
175        }
176
177        let streams = OperationStreams::with_inputs([&tensor, &value]);
178
179        let client = tensor.client.clone();
180        let desc =
181            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
182                client.create_empty_handle()
183            });
184
185        client
186            .register(
187                streams,
188                OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc.clone())),
189                SliceAssignOps::<B>::new(desc),
190            )
191            .output()
192    }
193
194    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
195        binary_int_ops!(MatmulOps, B::int_matmul);
196
197        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
198
199        let client = lhs.client.clone();
200        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
201            client.create_empty_handle()
202        });
203
204        client
205            .register(
206                streams,
207                OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())),
208                MatmulOps::<B>::new(desc.into()),
209            )
210            .output()
211    }
212
213    fn int_mask_where(
214        tensor: IntTensor<Self>,
215        mask: BoolTensor<Self>,
216        value: IntTensor<Self>,
217    ) -> IntTensor<Self> {
218        #[derive(new, Debug)]
219        struct MaskWhereOps<B: FusionBackend> {
220            desc: MaskWhereOpIr,
221            _b: PhantomData<B>,
222        }
223
224        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
225            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
226                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
227                let value = handles.get_int_tensor::<B>(&self.desc.value);
228                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
229
230                let output = B::int_mask_where(tensor, mask, value);
231
232                handles.register_int_tensor::<B>(&self.desc.out.id, output);
233            }
234        }
235
236        let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
237
238        let client = tensor.client.clone();
239        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
240            client.create_empty_handle()
241        });
242
243        client
244            .register(
245                streams,
246                OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc.clone())),
247                MaskWhereOps::<B>::new(desc),
248            )
249            .output()
250    }
251
252    fn int_mask_fill(
253        tensor: IntTensor<Self>,
254        mask: BoolTensor<Self>,
255        value: Scalar,
256    ) -> IntTensor<Self> {
257        #[derive(new, Debug)]
258        struct MaskFillOps<B: FusionBackend> {
259            desc: MaskFillOpIr,
260            _b: PhantomData<B>,
261        }
262
263        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
264            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
265                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
266                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
267
268                let output = B::int_mask_fill(tensor, mask, self.desc.value.into());
269
270                handles.register_int_tensor::<B>(&self.desc.out.id, output);
271            }
272        }
273
274        let streams = OperationStreams::with_inputs([&tensor, &mask]);
275
276        let client = tensor.client.clone();
277        let value = value.into();
278        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
279            client.create_empty_handle()
280        });
281
282        client
283            .register(
284                streams,
285                OperationIr::BaseInt(BaseOperationIr::MaskFill(desc.clone())),
286                MaskFillOps::<B>::new(desc),
287            )
288            .output()
289    }
290
291    fn int_gather(
292        dim: usize,
293        tensor: IntTensor<Self>,
294        indices: IntTensor<Self>,
295    ) -> IntTensor<Self> {
296        #[derive(new, Debug)]
297        struct GatherOps<B: FusionBackend> {
298            desc: GatherOpIr,
299            _b: PhantomData<B>,
300        }
301
302        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
303            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
304                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
305                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
306
307                let output = B::int_gather(self.desc.dim, tensor, indices);
308                handles.register_int_tensor::<B>(&self.desc.out.id, output);
309            }
310        }
311
312        let streams = OperationStreams::with_inputs([&tensor, &indices]);
313
314        let client = tensor.client.clone();
315        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
316            client.create_empty_handle()
317        });
318
319        client
320            .register(
321                streams,
322                OperationIr::BaseInt(BaseOperationIr::Gather(desc.clone())),
323                GatherOps::<B>::new(desc),
324            )
325            .output()
326    }
327
328    fn int_scatter_add(
329        dim: usize,
330        tensor: IntTensor<Self>,
331        indices: IntTensor<Self>,
332        value: IntTensor<Self>,
333    ) -> IntTensor<Self> {
334        #[derive(new, Debug)]
335        struct ScatterOps<B: FusionBackend> {
336            desc: ScatterOpIr,
337            _b: PhantomData<B>,
338        }
339
340        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
341            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
342                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
343                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
344                let value = handles.get_int_tensor::<B>(&self.desc.value);
345
346                let output = B::int_scatter_add(self.desc.dim, tensor, indices, value);
347
348                handles.register_int_tensor::<B>(&self.desc.out.id, output);
349            }
350        }
351
352        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
353
354        let client = tensor.client.clone();
355        let desc = ScatterOpIr::create(
356            tensor.into_ir(),
357            dim,
358            indices.into_ir(),
359            value.into_ir(),
360            IndexingUpdateOp::Add,
361            || client.create_empty_handle(),
362        );
363
364        client
365            .register(
366                streams,
367                OperationIr::BaseInt(BaseOperationIr::Scatter(desc.clone())),
368                ScatterOps::<B>::new(desc),
369            )
370            .output()
371    }
372
373    fn int_scatter_nd(
374        data: IntTensor<Self>,
375        indices: IntTensor<Self>,
376        values: IntTensor<Self>,
377        reduction: IndexingUpdateOp,
378    ) -> IntTensor<Self> {
379        #[derive(new, Debug)]
380        struct ScatterNdOps<B: FusionBackend> {
381            desc: ScatterNdOpIr,
382            _b: PhantomData<B>,
383        }
384
385        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterNdOps<B> {
386            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
387                let data = handles.get_int_tensor::<B>(&self.desc.data);
388                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
389                let values = handles.get_int_tensor::<B>(&self.desc.values);
390
391                let output = B::int_scatter_nd(data, indices, values, self.desc.reduction);
392
393                handles.register_int_tensor::<B>(&self.desc.out.id, output);
394            }
395        }
396
397        let streams = OperationStreams::with_inputs([&data, &indices, &values]);
398
399        let client = data.client.clone();
400        let desc = ScatterNdOpIr::create(
401            data.into_ir(),
402            indices.into_ir(),
403            values.into_ir(),
404            reduction,
405            || client.create_empty_handle(),
406        );
407
408        client
409            .register(
410                streams,
411                OperationIr::BaseInt(BaseOperationIr::ScatterNd(desc.clone())),
412                ScatterNdOps::<B>::new(desc),
413            )
414            .output()
415    }
416
417    fn int_gather_nd(data: IntTensor<Self>, indices: IntTensor<Self>) -> IntTensor<Self> {
418        #[derive(new, Debug)]
419        struct GatherNdOps<B: FusionBackend> {
420            desc: GatherNdOpIr,
421            _b: PhantomData<B>,
422        }
423
424        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherNdOps<B> {
425            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
426                let data = handles.get_int_tensor::<B>(&self.desc.data);
427                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
428
429                let output = B::int_gather_nd(data, indices);
430                handles.register_int_tensor::<B>(&self.desc.out.id, output);
431            }
432        }
433
434        let streams = OperationStreams::with_inputs([&data, &indices]);
435
436        let client = data.client.clone();
437        let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
438            client.create_empty_handle()
439        });
440
441        client
442            .register(
443                streams,
444                OperationIr::BaseInt(BaseOperationIr::GatherNd(desc.clone())),
445                GatherNdOps::<B>::new(desc),
446            )
447            .output()
448    }
449
450    fn int_select(
451        tensor: IntTensor<Self>,
452        dim: usize,
453        indices: IntTensor<Self>,
454    ) -> IntTensor<Self> {
455        #[derive(new, Debug)]
456        struct SelectOps<B: FusionBackend> {
457            desc: SelectOpIr,
458            _b: PhantomData<B>,
459        }
460
461        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
462            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
463                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
464                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
465
466                let output = B::int_select(tensor, self.desc.dim, indices);
467
468                handles.register_int_tensor::<B>(&self.desc.out.id, output);
469            }
470        }
471
472        let streams = OperationStreams::with_inputs([&tensor, &indices]);
473
474        let client = tensor.client.clone();
475        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
476            client.create_empty_handle()
477        });
478
479        client
480            .register(
481                streams,
482                OperationIr::BaseInt(BaseOperationIr::Select(desc.clone())),
483                SelectOps::<B>::new(desc),
484            )
485            .output()
486    }
487
488    fn int_select_add(
489        tensor: IntTensor<Self>,
490        dim: usize,
491        indices: IntTensor<Self>,
492        value: IntTensor<Self>,
493    ) -> IntTensor<Self> {
494        #[derive(new, Debug)]
495        struct SelectAssignOps<B: FusionBackend> {
496            desc: SelectAssignOpIr,
497            _b: PhantomData<B>,
498        }
499
500        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
501            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
502                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
503                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
504                let value = handles.get_int_tensor::<B>(&self.desc.value);
505
506                let output = B::int_select_add(tensor, self.desc.dim, indices, value);
507
508                handles.register_int_tensor::<B>(&self.desc.out.id, output);
509            }
510        }
511
512        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
513
514        let client = tensor.client.clone();
515        let desc = SelectAssignOpIr::create(
516            tensor.into_ir(),
517            dim,
518            indices.into_ir(),
519            value.into_ir(),
520            IndexingUpdateOp::Add,
521            || client.create_empty_handle(),
522        );
523
524        client
525            .register(
526                streams,
527                OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc.clone())),
528                SelectAssignOps::<B>::new(desc),
529            )
530            .output()
531    }
532
533    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
534        #[derive(new, Debug)]
535        struct CatOps<B: FusionBackend> {
536            desc: CatOpIr,
537            _b: PhantomData<B>,
538        }
539
540        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
541            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
542                let tensors = self
543                    .desc
544                    .tensors
545                    .iter()
546                    .map(|tensor| handles.get_int_tensor::<B>(tensor))
547                    .collect();
548
549                let output = B::int_cat(tensors, self.desc.dim);
550
551                handles.register_int_tensor::<B>(&self.desc.out.id, output);
552            }
553        }
554
555        let streams = OperationStreams::with_inputs(&tensors);
556
557        let client = tensors.first().unwrap().client.clone();
558        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
559        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
560
561        client
562            .register(
563                streams,
564                OperationIr::BaseInt(BaseOperationIr::Cat(desc.clone())),
565                CatOps::<B>::new(desc),
566            )
567            .output()
568    }
569
570    fn int_equal(
571        lhs: IntTensor<Self>,
572        rhs: IntTensor<Self>,
573        out_dtype: BoolDType,
574    ) -> BoolTensor<Self> {
575        binary_int_cmp_ops!(EqualOps, B::int_equal);
576
577        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
578
579        let client = lhs.client.clone();
580        let desc =
581            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
582                client.create_empty_handle()
583            });
584
585        client
586            .register(
587                streams,
588                OperationIr::BaseInt(BaseOperationIr::Equal(desc.clone())),
589                EqualOps::<B>::new(desc),
590            )
591            .output()
592    }
593
594    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
595        scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem);
596
597        let streams = OperationStreams::with_inputs([&lhs]);
598
599        let client = lhs.client.clone();
600        let rhs = rhs.into();
601        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
602            client.create_empty_handle()
603        });
604
605        client
606            .register(
607                streams,
608                OperationIr::BaseInt(BaseOperationIr::EqualElem(desc.clone())),
609                EqualElemOps::<B>::new(desc),
610            )
611            .output()
612    }
613
614    fn int_greater(
615        lhs: IntTensor<Self>,
616        rhs: IntTensor<Self>,
617        out_dtype: BoolDType,
618    ) -> BoolTensor<Self> {
619        binary_int_cmp_ops!(GreaterOps, B::int_greater);
620
621        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
622
623        let client = lhs.client.clone();
624        let desc =
625            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
626                client.create_empty_handle()
627            });
628
629        client
630            .register(
631                streams,
632                OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Greater(desc.clone())),
633                GreaterOps::<B>::new(desc),
634            )
635            .output()
636    }
637
638    fn int_greater_elem(
639        lhs: IntTensor<Self>,
640        rhs: Scalar,
641        out_dtype: BoolDType,
642    ) -> BoolTensor<Self> {
643        scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem);
644
645        let streams = OperationStreams::with_inputs([&lhs]);
646
647        let client = lhs.client.clone();
648        let rhs = rhs.into();
649        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
650            client.create_empty_handle()
651        });
652
653        client
654            .register(
655                streams,
656                OperationIr::NumericInt(
657                    desc.lhs.dtype,
658                    NumericOperationIr::GreaterElem(desc.clone()),
659                ),
660                GreaterElemOps::<B>::new(desc),
661            )
662            .output()
663    }
664
665    fn int_greater_equal(
666        lhs: IntTensor<Self>,
667        rhs: IntTensor<Self>,
668        out_dtype: BoolDType,
669    ) -> BoolTensor<Self> {
670        binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal);
671
672        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
673
674        let client = lhs.client.clone();
675        let desc =
676            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
677                client.create_empty_handle()
678            });
679
680        client
681            .register(
682                streams,
683                OperationIr::NumericInt(
684                    desc.lhs.dtype,
685                    NumericOperationIr::GreaterEqual(desc.clone()),
686                ),
687                GreaterEqualOps::<B>::new(desc),
688            )
689            .output()
690    }
691
692    fn int_greater_equal_elem(
693        lhs: IntTensor<Self>,
694        rhs: Scalar,
695        out_dtype: BoolDType,
696    ) -> BoolTensor<Self> {
697        scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem);
698
699        let streams = OperationStreams::with_inputs([&lhs]);
700
701        let client = lhs.client.clone();
702        let rhs = rhs.into();
703        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
704            client.create_empty_handle()
705        });
706
707        client
708            .register(
709                streams,
710                OperationIr::NumericInt(
711                    desc.lhs.dtype,
712                    NumericOperationIr::GreaterEqualElem(desc.clone()),
713                ),
714                GreaterEqualElemOps::<B>::new(desc),
715            )
716            .output()
717    }
718
719    fn int_lower(
720        lhs: IntTensor<Self>,
721        rhs: IntTensor<Self>,
722        out_dtype: BoolDType,
723    ) -> BoolTensor<Self> {
724        binary_int_cmp_ops!(LowerOps, B::int_lower);
725
726        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
727
728        let client = lhs.client.clone();
729        let desc =
730            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
731                client.create_empty_handle()
732            });
733
734        client
735            .register(
736                streams,
737                OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())),
738                LowerOps::<B>::new(desc),
739            )
740            .output()
741    }
742
743    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
744        scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem);
745
746        let streams = OperationStreams::with_inputs([&lhs]);
747
748        let client = lhs.client.clone();
749        let rhs = rhs.into();
750        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
751            client.create_empty_handle()
752        });
753
754        client
755            .register(
756                streams,
757                OperationIr::NumericInt(
758                    desc.lhs.dtype,
759                    NumericOperationIr::LowerElem(desc.clone()),
760                ),
761                LowerElemOps::<B>::new(desc),
762            )
763            .output()
764    }
765
766    fn int_lower_equal(
767        lhs: IntTensor<Self>,
768        rhs: IntTensor<Self>,
769        out_dtype: BoolDType,
770    ) -> BoolTensor<Self> {
771        binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal);
772
773        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
774
775        let client = lhs.client.clone();
776        let desc =
777            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
778                client.create_empty_handle()
779            });
780
781        client
782            .register(
783                streams,
784                OperationIr::NumericInt(
785                    desc.lhs.dtype,
786                    NumericOperationIr::LowerEqual(desc.clone()),
787                ),
788                LowerEqualOps::<B>::new(desc),
789            )
790            .output()
791    }
792
793    fn int_lower_equal_elem(
794        lhs: IntTensor<Self>,
795        rhs: Scalar,
796        out_dtype: BoolDType,
797    ) -> BoolTensor<Self> {
798        scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem);
799
800        let streams = OperationStreams::with_inputs([&lhs]);
801
802        let client = lhs.client.clone();
803        let rhs = rhs.into();
804        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
805            client.create_empty_handle()
806        });
807
808        client
809            .register(
810                streams,
811                OperationIr::NumericInt(
812                    desc.lhs.dtype,
813                    NumericOperationIr::LowerEqualElem(desc.clone()),
814                ),
815                LowerEqualElemOps::<B>::new(desc),
816            )
817            .output()
818    }
819
820    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
821        binary_int_ops!(AddOps, B::int_add);
822
823        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
824
825        let client = lhs.client.clone();
826        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
827            client.create_empty_handle()
828        });
829
830        client
831            .register(
832                streams,
833                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Add(desc.clone())),
834                AddOps::<B>::new(desc),
835            )
836            .output()
837    }
838
839    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
840        scalar_int_ops!(AddOps, B::int_add_scalar);
841
842        let streams = OperationStreams::with_inputs([&lhs]);
843
844        let client = lhs.client.clone();
845        let rhs = rhs.into();
846        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
847
848        client
849            .register(
850                streams,
851                OperationIr::NumericInt(
852                    desc.out.dtype,
853                    NumericOperationIr::AddScalar(desc.clone()),
854                ),
855                AddOps::<B>::new(desc),
856            )
857            .output()
858    }
859
860    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
861        binary_int_ops!(SubOps, B::int_sub);
862
863        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
864
865        let client = lhs.client.clone();
866        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
867            client.create_empty_handle()
868        });
869
870        client
871            .register(
872                streams,
873                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sub(desc.clone())),
874                SubOps::<B>::new(desc),
875            )
876            .output()
877    }
878
879    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
880        scalar_int_ops!(SubOps, B::int_sub_scalar);
881
882        let streams = OperationStreams::with_inputs([&lhs]);
883
884        let client = lhs.client.clone();
885        let rhs = rhs.into();
886        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
887
888        client
889            .register(
890                streams,
891                OperationIr::NumericInt(
892                    desc.out.dtype,
893                    NumericOperationIr::SubScalar(desc.clone()),
894                ),
895                SubOps::<B>::new(desc),
896            )
897            .output()
898    }
899
900    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
901        binary_int_ops!(MulOps, B::int_mul);
902
903        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
904
905        let client = lhs.client.clone();
906        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
907            client.create_empty_handle()
908        });
909
910        client
911            .register(
912                streams,
913                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mul(desc.clone())),
914                MulOps::<B>::new(desc),
915            )
916            .output()
917    }
918
919    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
920        scalar_int_ops!(MulOps, B::int_mul_scalar);
921
922        let streams = OperationStreams::with_inputs([&lhs]);
923
924        let client = lhs.client.clone();
925        let rhs = rhs.into();
926        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
927
928        client
929            .register(
930                streams,
931                OperationIr::NumericInt(
932                    desc.out.dtype,
933                    NumericOperationIr::MulScalar(desc.clone()),
934                ),
935                MulOps::<B>::new(desc),
936            )
937            .output()
938    }
939
940    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
941        binary_int_ops!(DivOps, B::int_div);
942
943        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
944
945        let client = lhs.client.clone();
946        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
947            client.create_empty_handle()
948        });
949
950        client
951            .register(
952                streams,
953                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Div(desc.clone())),
954                DivOps::<B>::new(desc),
955            )
956            .output()
957    }
958
959    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
960        scalar_int_ops!(DivOps, B::int_div_scalar);
961
962        let streams = OperationStreams::with_inputs([&lhs]);
963
964        let client = lhs.client.clone();
965        let rhs = rhs.into();
966        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
967
968        client
969            .register(
970                streams,
971                OperationIr::NumericInt(
972                    desc.out.dtype,
973                    NumericOperationIr::DivScalar(desc.clone()),
974                ),
975                DivOps::<B>::new(desc),
976            )
977            .output()
978    }
979
980    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
981        binary_int_ops!(ModOps, B::int_remainder);
982
983        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
984
985        let client = lhs.client.clone();
986        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
987            client.create_empty_handle()
988        });
989
990        client
991            .register(
992                streams,
993                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Rem(desc.clone())),
994                ModOps::<B>::new(desc),
995            )
996            .output()
997    }
998
999    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1000        scalar_int_ops!(ModOps, B::int_remainder_scalar);
1001
1002        let streams = OperationStreams::with_inputs([&lhs]);
1003
1004        let client = lhs.client.clone();
1005        let rhs = rhs.into();
1006        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1007
1008        client
1009            .register(
1010                streams,
1011                OperationIr::NumericInt(
1012                    desc.out.dtype,
1013                    NumericOperationIr::RemScalar(desc.clone()),
1014                ),
1015                ModOps::<B>::new(desc),
1016            )
1017            .output()
1018    }
1019
1020    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1021        #[derive(new, Debug)]
1022        struct ZerosOps<B: FusionBackend> {
1023            desc: TensorIr,
1024            device: Device<B>,
1025        }
1026
1027        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
1028            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1029                let shape = self.desc.shape.clone();
1030                let output = B::int_zeros(shape, &self.device, self.desc.dtype.into());
1031                handles.register_int_tensor::<B>(&self.desc.id, output);
1032            }
1033        }
1034
1035        let client = get_client::<B>(device);
1036        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
1037
1038        client
1039            .register(
1040                OperationStreams::default(),
1041                OperationIr::BaseInt(BaseOperationIr::Zeros(desc.clone())),
1042                ZerosOps::<B>::new(desc.out, device.clone()),
1043            )
1044            .output()
1045    }
1046
1047    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1048        #[derive(new, Debug)]
1049        struct OnesOps<B: FusionBackend> {
1050            desc: TensorIr,
1051            device: Device<B>,
1052        }
1053
1054        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
1055            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1056                let shape = self.desc.shape.clone();
1057                let output = B::int_ones(shape, &self.device, self.desc.dtype.into());
1058                handles.register_int_tensor::<B>(&self.desc.id, output);
1059            }
1060        }
1061        let client = get_client::<B>(device);
1062        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
1063
1064        client
1065            .register(
1066                OperationStreams::default(),
1067                OperationIr::BaseInt(BaseOperationIr::Ones(desc.clone())),
1068                OnesOps::<B>::new(desc.out, device.clone()),
1069            )
1070            .output()
1071    }
1072
1073    fn int_full(
1074        shape: Shape,
1075        fill_value: Scalar,
1076        device: &Device<Self>,
1077        dtype: IntDType,
1078    ) -> IntTensor<Self> {
1079        #[derive(new, Debug)]
1080        struct FullOps<B: FusionBackend> {
1081            out: TensorIr,
1082            elem: ScalarIr,
1083            device: Device<B>,
1084        }
1085
1086        impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
1087            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1088                let shape = self.out.shape.clone();
1089                let output =
1090                    B::int_full(shape, self.elem.into(), &self.device, self.out.dtype.into());
1091                handles.register_int_tensor::<B>(&self.out.id, output);
1092            }
1093        }
1094
1095        let client = get_client::<B>(device);
1096        let dtype = dtype.into();
1097        let value = fill_value.into();
1098        let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
1099
1100        client
1101            .register(
1102                OperationStreams::default(),
1103                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Full(desc.clone())),
1104                FullOps::<B>::new(desc.out, desc.value, device.clone()),
1105            )
1106            .output()
1107    }
1108
1109    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
1110        unary_int_ops!(SumOps, B::int_sum, reduce);
1111
1112        let streams = OperationStreams::with_inputs([&tensor]);
1113
1114        let client = tensor.client.clone();
1115        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1116
1117        client
1118            .register(
1119                streams,
1120                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sum(desc.clone())),
1121                SumOps::<B>::new(desc.into()),
1122            )
1123            .output()
1124    }
1125
1126    fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
1127        reduce_int_ops!(SumDimOps, |tensor, axis, _| B::int_sum_dim(tensor, axis));
1128
1129        let streams = OperationStreams::with_inputs([&tensor]);
1130
1131        let client = tensor.client.clone();
1132        let desc =
1133            ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
1134
1135        client
1136            .register(
1137                streams,
1138                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())),
1139                SumDimOps::<B>::new(desc),
1140            )
1141            .output()
1142    }
1143
1144    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
1145        unary_int_ops!(ProdOps, B::int_prod, reduce);
1146
1147        let streams = OperationStreams::with_inputs([&tensor]);
1148
1149        let client = tensor.client.clone();
1150        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1151
1152        client
1153            .register(
1154                streams,
1155                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Prod(desc.clone())),
1156                ProdOps::<B>::new(desc.into()),
1157            )
1158            .output()
1159    }
1160
1161    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1162        reduce_int_ops!(ProdDimOps, |tensor, axis, _| B::int_prod_dim(tensor, axis));
1163
1164        let streams = OperationStreams::with_inputs([&tensor]);
1165
1166        let client = tensor.client.clone();
1167        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1168
1169        client
1170            .register(
1171                streams,
1172                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ProdDim(desc.clone())),
1173                ProdDimOps::<B>::new(desc),
1174            )
1175            .output()
1176    }
1177
1178    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
1179        unary_int_ops!(MeanOps, B::int_mean, reduce);
1180
1181        let streams = OperationStreams::with_inputs([&tensor]);
1182
1183        let client = tensor.client.clone();
1184        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1185
1186        client
1187            .register(
1188                streams,
1189                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mean(desc.clone())),
1190                MeanOps::<B>::new(desc.into()),
1191            )
1192            .output()
1193    }
1194
1195    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1196        reduce_int_ops!(MeanDimOps, |tensor, axis, _| B::int_mean_dim(tensor, axis));
1197
1198        let streams = OperationStreams::with_inputs([&tensor]);
1199
1200        let client = tensor.client.clone();
1201        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1202
1203        client
1204            .register(
1205                streams,
1206                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MeanDim(desc.clone())),
1207                MeanDimOps::<B>::new(desc),
1208            )
1209            .output()
1210    }
1211
1212    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1213        #[derive(new, Debug)]
1214        struct CumsumOps<B: FusionBackend> {
1215            desc: DimOpIr,
1216            _b: PhantomData<B>,
1217        }
1218
1219        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
1220            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1221                let input = handles.get_int_tensor::<B>(&self.desc.input);
1222                let output = B::int_cumsum(input, self.desc.axis);
1223                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1224            }
1225        }
1226
1227        let streams = OperationStreams::with_inputs([&tensor]);
1228
1229        let client = tensor.client.clone();
1230        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1231
1232        client
1233            .register(
1234                streams,
1235                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())),
1236                CumsumOps::<B>::new(desc),
1237            )
1238            .output()
1239    }
1240
1241    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1242        #[derive(new, Debug)]
1243        struct CumprodOps<B: FusionBackend> {
1244            desc: DimOpIr,
1245            _b: PhantomData<B>,
1246        }
1247
1248        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
1249            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1250                let input = handles.get_int_tensor::<B>(&self.desc.input);
1251                let output = B::int_cumprod(input, self.desc.axis);
1252                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1253            }
1254        }
1255
1256        let streams = OperationStreams::with_inputs([&tensor]);
1257
1258        let client = tensor.client.clone();
1259        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1260
1261        client
1262            .register(
1263                streams,
1264                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumProd(desc.clone())),
1265                CumprodOps::<B>::new(desc),
1266            )
1267            .output()
1268    }
1269
1270    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1271        #[derive(new, Debug)]
1272        struct CumminOps<B: FusionBackend> {
1273            desc: DimOpIr,
1274            _b: PhantomData<B>,
1275        }
1276
1277        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
1278            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1279                let input = handles.get_int_tensor::<B>(&self.desc.input);
1280                let output = B::int_cummin(input, self.desc.axis);
1281                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1282            }
1283        }
1284
1285        let streams = OperationStreams::with_inputs([&tensor]);
1286
1287        let client = tensor.client.clone();
1288        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1289
1290        client
1291            .register(
1292                streams,
1293                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())),
1294                CumminOps::<B>::new(desc),
1295            )
1296            .output()
1297    }
1298
1299    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1300        #[derive(new, Debug)]
1301        struct CummaxOps<B: FusionBackend> {
1302            desc: DimOpIr,
1303            _b: PhantomData<B>,
1304        }
1305
1306        impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
1307            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1308                let input = handles.get_int_tensor::<B>(&self.desc.input);
1309                let output = B::int_cummax(input, self.desc.axis);
1310                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1311            }
1312        }
1313
1314        let streams = OperationStreams::with_inputs([&tensor]);
1315
1316        let client = tensor.client.clone();
1317        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1318
1319        client
1320            .register(
1321                streams,
1322                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())),
1323                CummaxOps::<B>::new(desc),
1324            )
1325            .output()
1326    }
1327
1328    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1329        reduce_int_ops!(ArgMaxOps, |tensor, axis, _| B::int_argmax(tensor, axis));
1330
1331        let streams = OperationStreams::with_inputs([&tensor]);
1332
1333        let client = tensor.client.clone();
1334        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1335
1336        client
1337            .register(
1338                streams,
1339                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMax(desc.clone())),
1340                ArgMaxOps::<B>::new(desc),
1341            )
1342            .output()
1343    }
1344
1345    fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
1346        reduce_int_ops!(ArgTopKOps, B::int_argtopk);
1347
1348        let streams = OperationStreams::with_inputs([&tensor]);
1349
1350        let client = tensor.client.clone();
1351        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1352
1353        client
1354            .register(
1355                streams,
1356                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgTopK(desc.clone())),
1357                ArgTopKOps::<B>::new(desc),
1358            )
1359            .output()
1360    }
1361
1362    fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
1363        reduce_int_ops!(TopKOps, B::int_topk);
1364
1365        let streams = OperationStreams::with_inputs([&tensor]);
1366
1367        let client = tensor.client.clone();
1368        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1369
1370        client
1371            .register(
1372                streams,
1373                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::TopK(desc.clone())),
1374                TopKOps::<B>::new(desc),
1375            )
1376            .output()
1377    }
1378
1379    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1380        reduce_int_ops!(ArgMinOps, |tensor, axis, _| B::int_argmin(tensor, axis));
1381
1382        let streams = OperationStreams::with_inputs([&tensor]);
1383
1384        let client = tensor.client.clone();
1385        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1386
1387        client
1388            .register(
1389                streams,
1390                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMin(desc.clone())),
1391                ArgMinOps::<B>::new(desc),
1392            )
1393            .output()
1394    }
1395
1396    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
1397        #[derive(new, Debug)]
1398        struct ClampOps<B: FusionBackend> {
1399            desc: ClampOpIr,
1400            _b: PhantomData<B>,
1401        }
1402
1403        impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
1404            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1405                let input = handles.get_int_tensor::<B>(&self.desc.tensor);
1406                let output = B::int_clamp(input, self.desc.min.into(), self.desc.max.into());
1407
1408                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1409            }
1410        }
1411
1412        let streams = OperationStreams::with_inputs([&tensor]);
1413
1414        let client = tensor.client.clone();
1415        let min = min.into();
1416        let max = max.into();
1417        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
1418
1419        client
1420            .register(
1421                streams,
1422                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Clamp(desc.clone())),
1423                ClampOps::<B>::new(desc),
1424            )
1425            .output()
1426    }
1427
1428    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1429        unary_int_ops!(AbsOps, B::int_abs);
1430
1431        let streams = OperationStreams::with_inputs([&tensor]);
1432
1433        let client = tensor.client.clone();
1434        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1435
1436        client
1437            .register(
1438                streams,
1439                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Abs(desc.clone())),
1440                AbsOps::<B>::new(desc),
1441            )
1442            .output()
1443    }
1444
1445    fn int_into_float(tensor: IntTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
1446        #[derive(new, Debug)]
1447        struct IntoFloatOps<B: FusionBackend> {
1448            desc: CastOpIr,
1449            _b: PhantomData<B>,
1450        }
1451
1452        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
1453            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1454                let input = handles.get_int_tensor::<B>(&self.desc.input);
1455                let output = B::int_into_float(input, self.desc.out.dtype.into());
1456                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1457            }
1458        }
1459
1460        let streams = OperationStreams::with_inputs([&tensor]);
1461
1462        let client = tensor.client.clone();
1463        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
1464            client.create_empty_handle()
1465        });
1466
1467        client
1468            .register(
1469                streams,
1470                OperationIr::Int(IntOperationIr::IntoFloat(desc.clone())),
1471                IntoFloatOps::<B>::new(desc),
1472            )
1473            .output()
1474    }
1475
1476    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
1477        #[derive(new, Debug)]
1478        struct SwapDimsOps<B: FusionBackend> {
1479            desc: SwapDimsOpIr,
1480            _b: PhantomData<B>,
1481        }
1482
1483        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
1484            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1485                let input = handles.get_int_tensor::<B>(&self.desc.input);
1486                let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
1487                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1488            }
1489        }
1490        let streams = OperationStreams::with_inputs([&tensor]);
1491
1492        let client = tensor.client.clone();
1493        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
1494            client.create_empty_handle()
1495        });
1496
1497        client
1498            .register(
1499                streams,
1500                OperationIr::BaseInt(BaseOperationIr::SwapDims(desc.clone())),
1501                SwapDimsOps::<B>::new(desc),
1502            )
1503            .output()
1504    }
1505
1506    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
1507        unary_int_ops!(MaxOps, B::int_max, reduce);
1508
1509        let streams = OperationStreams::with_inputs([&tensor]);
1510
1511        let client = tensor.client.clone();
1512        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1513
1514        client
1515            .register(
1516                streams,
1517                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Max(desc.clone())),
1518                MaxOps::<B>::new(desc.into()),
1519            )
1520            .output()
1521    }
1522
1523    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1524        reduce_int_ops!(MaxDimOps, |tensor, axis, _| B::int_max_dim(tensor, axis));
1525
1526        let streams = OperationStreams::with_inputs([&tensor]);
1527
1528        let client = tensor.client.clone();
1529        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1530
1531        client
1532            .register(
1533                streams,
1534                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())),
1535                MaxDimOps::<B>::new(desc),
1536            )
1537            .output()
1538    }
1539
1540    fn int_max_dim_with_indices(
1541        tensor: IntTensor<Self>,
1542        dim: usize,
1543    ) -> (IntTensor<Self>, IntTensor<Self>) {
1544        #[derive(new, Debug)]
1545        struct MaxDimWithIndicesOps<B: FusionBackend> {
1546            desc: ReduceDimWithIndicesOpIr,
1547            _b: PhantomData<B>,
1548        }
1549
1550        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
1551            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1552                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1553                let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);
1554
1555                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1556                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
1557            }
1558        }
1559
1560        let streams = OperationStreams::with_inputs([&tensor]);
1561
1562        let client = tensor.client.clone();
1563        let dtype = tensor.dtype;
1564        let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || {
1565            client.create_empty_handle()
1566        });
1567
1568        client
1569            .register(
1570                streams,
1571                OperationIr::NumericInt(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())),
1572                MaxDimWithIndicesOps::<B>::new(desc),
1573            )
1574            .outputs()
1575            .into()
1576    }
1577
1578    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
1579        unary_int_ops!(MinOps, B::int_min, reduce);
1580
1581        let streams = OperationStreams::with_inputs([&tensor]);
1582
1583        let client = tensor.client.clone();
1584        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1585
1586        client
1587            .register(
1588                streams,
1589                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Min(desc.clone())),
1590                MinOps::<B>::new(desc.into()),
1591            )
1592            .output()
1593    }
1594
1595    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1596        unary_int_ops!(MaxAbsOps, B::int_max_abs, reduce);
1597
1598        let streams = OperationStreams::with_inputs([&tensor]);
1599
1600        let client = tensor.client.clone();
1601        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1602
1603        client
1604            .register(
1605                streams,
1606                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())),
1607                MaxAbsOps::<B>::new(desc.into()),
1608            )
1609            .output()
1610    }
1611
1612    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1613        reduce_int_ops!(MaxAbsDimOps, |tensor, axis, _| B::int_max_abs_dim(
1614            tensor, axis
1615        ));
1616
1617        let streams = OperationStreams::with_inputs([&tensor]);
1618
1619        let client = tensor.client.clone();
1620        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1621
1622        client
1623            .register(
1624                streams,
1625                OperationIr::NumericInt(
1626                    desc.out.dtype,
1627                    NumericOperationIr::MaxAbsDim(desc.clone()),
1628                ),
1629                MaxAbsDimOps::<B>::new(desc),
1630            )
1631            .output()
1632    }
1633
1634    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1635        reduce_int_ops!(MinDimOps, |tensor, axis, _| B::int_min_dim(tensor, axis));
1636
1637        let streams = OperationStreams::with_inputs([&tensor]);
1638
1639        let client = tensor.client.clone();
1640        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1641
1642        client
1643            .register(
1644                streams,
1645                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())),
1646                MinDimOps::<B>::new(desc),
1647            )
1648            .output()
1649    }
1650
1651    fn int_min_dim_with_indices(
1652        tensor: IntTensor<Self>,
1653        dim: usize,
1654    ) -> (IntTensor<Self>, IntTensor<Self>) {
1655        #[derive(new, Debug)]
1656        struct MinDimWithIndicesOps<B: FusionBackend> {
1657            desc: ReduceDimWithIndicesOpIr,
1658            _b: PhantomData<B>,
1659        }
1660
1661        impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
1662            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1663                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1664                let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);
1665
1666                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1667                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
1668            }
1669        }
1670
1671        let streams = OperationStreams::with_inputs([&tensor]);
1672
1673        let client = tensor.client.clone();
1674        let dtype = tensor.dtype;
1675        let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || {
1676            client.create_empty_handle()
1677        });
1678
1679        client
1680            .register(
1681                streams,
1682                OperationIr::NumericInt(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())),
1683                MinDimWithIndicesOps::<B>::new(desc),
1684            )
1685            .outputs()
1686            .into()
1687    }
1688
1689    fn int_random(
1690        shape: Shape,
1691        distribution: Distribution,
1692        device: &Device<Self>,
1693        dtype: IntDType,
1694    ) -> IntTensor<Self> {
1695        #[derive(new, Debug)]
1696        struct IntRandomOps<B: FusionBackend> {
1697            desc: RandomOpIr,
1698            device: Device<B>,
1699        }
1700
1701        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntRandomOps<B> {
1702            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1703                let shape = self.desc.out.shape.clone();
1704                let output = B::int_random(
1705                    shape,
1706                    self.desc.distribution,
1707                    &self.device,
1708                    self.desc.out.dtype.into(),
1709                );
1710                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1711            }
1712        }
1713
1714        let dtype = dtype.into();
1715        let client = get_client::<B>(device);
1716        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
1717
1718        client
1719            .register(
1720                OperationStreams::default(),
1721                OperationIr::NumericInt(dtype, NumericOperationIr::IntRandom(desc.clone())),
1722                IntRandomOps::<B>::new(desc, device.clone()),
1723            )
1724            .output()
1725    }
1726
1727    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
1728        #[derive(new, Debug)]
1729        struct PermuteDimsOps<B: FusionBackend> {
1730            desc: PermuteOpIr,
1731            _b: PhantomData<B>,
1732        }
1733
1734        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
1735            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1736                let input = handles.get_int_tensor::<B>(&self.desc.input);
1737                let output = B::int_permute(input, self.desc.axes.as_slice());
1738                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1739            }
1740        }
1741
1742        let streams = OperationStreams::with_inputs([&tensor]);
1743
1744        let client = tensor.client.clone();
1745        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
1746            client.create_empty_handle()
1747        });
1748
1749        client
1750            .register(
1751                streams,
1752                OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
1753                PermuteDimsOps::<B>::new(desc),
1754            )
1755            .output()
1756    }
1757
1758    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
1759        #[derive(new, Debug)]
1760        struct ExpandOps<B: FusionBackend> {
1761            desc: ShapeOpIr,
1762            _b: PhantomData<B>,
1763        }
1764
1765        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
1766            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1767                let input = handles.get_int_tensor::<B>(&self.desc.input);
1768                let output = B::int_expand(input, self.desc.out.shape.clone());
1769                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1770            }
1771        }
1772
1773        let streams = OperationStreams::with_inputs([&tensor]);
1774
1775        let client = tensor.client.clone();
1776        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
1777
1778        client
1779            .register(
1780                streams,
1781                OperationIr::BaseInt(BaseOperationIr::Expand(desc.clone())),
1782                ExpandOps::<B>::new(desc),
1783            )
1784            .output()
1785    }
1786
1787    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
1788        #[derive(new, Debug)]
1789        struct FlipDimsOps<B: FusionBackend> {
1790            desc: FlipOpIr,
1791            _b: PhantomData<B>,
1792        }
1793
1794        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipDimsOps<B> {
1795            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1796                let input = handles.get_int_tensor::<B>(&self.desc.input);
1797                let axes = &self.desc.axes;
1798                let output = B::int_flip(input, axes);
1799                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1800            }
1801        }
1802
1803        let streams = OperationStreams::with_inputs([&tensor]);
1804
1805        let client = tensor.client.clone();
1806        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
1807            client.create_empty_handle()
1808        });
1809
1810        client
1811            .register(
1812                streams,
1813                OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
1814                FlipDimsOps::<B>::new(desc),
1815            )
1816            .output()
1817    }
1818
1819    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
1820        #[derive(new, Debug)]
1821        struct RepeatDimOps<B: FusionBackend> {
1822            desc: RepeatDimOpIr,
1823            _b: PhantomData<B>,
1824        }
1825
1826        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
1827            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1828                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1829
1830                let output = B::int_repeat_dim(tensor, self.desc.dim, self.desc.times);
1831
1832                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1833            }
1834        }
1835
1836        let streams = OperationStreams::with_inputs([&tensor]);
1837
1838        let client = tensor.client.clone();
1839        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1840            client.create_empty_handle()
1841        });
1842
1843        client
1844            .register(
1845                streams,
1846                OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc.clone())),
1847                RepeatDimOps::<B>::new(desc),
1848            )
1849            .output()
1850    }
1851
1852    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1853        binary_int_ops!(BitwiseAndOps, B::bitwise_and);
1854
1855        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1856
1857        let client = lhs.client.clone();
1858        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1859            client.create_empty_handle()
1860        });
1861
1862        client
1863            .register(
1864                streams,
1865                OperationIr::Int(IntOperationIr::BitwiseAnd(desc.clone())),
1866                BitwiseAndOps::<B>::new(desc),
1867            )
1868            .output()
1869    }
1870
1871    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1872        scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar);
1873
1874        let streams = OperationStreams::with_inputs([&lhs]);
1875
1876        let client = lhs.client.clone();
1877        let rhs = rhs.into();
1878        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1879
1880        client
1881            .register(
1882                streams,
1883                OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc.clone())),
1884                BitwiseAndOps::<B>::new(desc),
1885            )
1886            .output()
1887    }
1888
1889    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1890        binary_int_ops!(BitwiseOrOps, B::bitwise_or);
1891
1892        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1893
1894        let client = lhs.client.clone();
1895        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1896            client.create_empty_handle()
1897        });
1898
1899        client
1900            .register(
1901                streams,
1902                OperationIr::Int(IntOperationIr::BitwiseOr(desc.clone())),
1903                BitwiseOrOps::<B>::new(desc),
1904            )
1905            .output()
1906    }
1907
1908    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1909        scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar);
1910
1911        let streams = OperationStreams::with_inputs([&lhs]);
1912
1913        let client = lhs.client.clone();
1914        let rhs = rhs.into();
1915        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1916
1917        client
1918            .register(
1919                streams,
1920                OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc.clone())),
1921                BitwiseOrOps::<B>::new(desc),
1922            )
1923            .output()
1924    }
1925
1926    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1927        binary_int_ops!(BitwiseXorOps, B::bitwise_xor);
1928
1929        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1930
1931        let client = lhs.client.clone();
1932        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1933            client.create_empty_handle()
1934        });
1935
1936        client
1937            .register(
1938                streams,
1939                OperationIr::Int(IntOperationIr::BitwiseXor(desc.clone())),
1940                BitwiseXorOps::<B>::new(desc),
1941            )
1942            .output()
1943    }
1944
1945    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1946        scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar);
1947
1948        let streams = OperationStreams::with_inputs([&lhs]);
1949
1950        let client = lhs.client.clone();
1951        let rhs = rhs.into();
1952        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1953
1954        client
1955            .register(
1956                streams,
1957                OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc.clone())),
1958                BitwiseXorOps::<B>::new(desc),
1959            )
1960            .output()
1961    }
1962
1963    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
1964        unary_int_ops!(BitwiseNotOps, B::bitwise_not);
1965
1966        let streams = OperationStreams::with_inputs([&tensor]);
1967
1968        let client = tensor.client.clone();
1969        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1970
1971        client
1972            .register(
1973                streams,
1974                OperationIr::Int(IntOperationIr::BitwiseNot(desc.clone())),
1975                BitwiseNotOps::<B>::new(desc),
1976            )
1977            .output()
1978    }
1979
1980    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1981        binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift);
1982
1983        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1984
1985        let client = lhs.client.clone();
1986        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1987            client.create_empty_handle()
1988        });
1989
1990        client
1991            .register(
1992                streams,
1993                OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc.clone())),
1994                BitwiseLeftShiftOps::<B>::new(desc),
1995            )
1996            .output()
1997    }
1998
1999    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
2000        scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar);
2001
2002        let streams = OperationStreams::with_inputs([&lhs]);
2003
2004        let client = lhs.client.clone();
2005        let rhs = rhs.into();
2006        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
2007
2008        client
2009            .register(
2010                streams,
2011                OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(desc.clone())),
2012                BitwiseLeftShiftOps::<B>::new(desc),
2013            )
2014            .output()
2015    }
2016
2017    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2018        binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift);
2019
2020        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
2021
2022        let client = lhs.client.clone();
2023        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
2024            client.create_empty_handle()
2025        });
2026
2027        client
2028            .register(
2029                streams,
2030                OperationIr::Int(IntOperationIr::BitwiseRightShift(desc.clone())),
2031                BitwiseRightShiftOps::<B>::new(desc),
2032            )
2033            .output()
2034    }
2035
2036    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
2037        scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar);
2038
2039        let streams = OperationStreams::with_inputs([&lhs]);
2040
2041        let client = lhs.client.clone();
2042        let rhs = rhs.into();
2043        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
2044
2045        client
2046            .register(
2047                streams,
2048                OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(desc.clone())),
2049                BitwiseRightShiftOps::<B>::new(desc),
2050            )
2051            .output()
2052    }
2053
2054    fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
2055        #[derive(new, Debug)]
2056        struct CastOps<B: FusionBackend> {
2057            desc: CastOpIr,
2058            dtype: burn_backend::IntDType,
2059            _b: PhantomData<B>,
2060        }
2061
2062        impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2063            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2064                let input = handles.get_int_tensor::<B>(&self.desc.input);
2065                let output: B::IntTensorPrimitive = B::int_cast(input, self.dtype);
2066                handles.register_int_tensor::<B>(&self.desc.out.id, output);
2067            }
2068        }
2069
2070        let streams = OperationStreams::with_inputs([&tensor]);
2071
2072        let client = tensor.client.clone();
2073        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
2074            client.create_empty_handle()
2075        });
2076
2077        client
2078            .register(
2079                streams,
2080                OperationIr::BaseInt(BaseOperationIr::Cast(desc.clone())),
2081                CastOps::<B>::new(desc, dtype),
2082            )
2083            .output()
2084    }
2085
2086    fn int_unfold(
2087        tensor: IntTensor<Self>,
2088        dim: usize,
2089        size: usize,
2090        step: usize,
2091    ) -> IntTensor<Self> {
2092        #[derive(new, Debug)]
2093        struct UnfoldOps<B: FusionBackend> {
2094            desc: UnfoldOpIr,
2095            _b: PhantomData<B>,
2096        }
2097
2098        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2099            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2100                let input = handles.get_int_tensor::<B>(&self.desc.input);
2101                let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2102
2103                handles.register_int_tensor::<B>(&self.desc.out.id, output);
2104            }
2105        }
2106
2107        let streams = OperationStreams::with_inputs([&tensor]);
2108
2109        let client = tensor.client.clone();
2110        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
2111            client.create_empty_handle()
2112        });
2113
2114        client
2115            .register(
2116                streams,
2117                OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())),
2118                UnfoldOps::<B>::new(desc),
2119            )
2120            .output()
2121    }
2122
2123    fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2124        binary_int_ops!(PowOps, B::int_powi);
2125
2126        let streams = OperationStreams::with_inputs([&lhs, &rhs]);
2127
2128        let client = lhs.client.clone();
2129        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
2130            client.create_empty_handle()
2131        });
2132
2133        client
2134            .register(
2135                streams,
2136                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Powi(desc.clone())),
2137                PowOps::<B>::new(desc),
2138            )
2139            .output()
2140    }
2141
2142    fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
2143        scalar_int_ops!(PowiOps, B::int_powi_scalar);
2144
2145        let streams = OperationStreams::with_inputs([&lhs]);
2146
2147        let client = lhs.client.clone();
2148        let rhs = rhs.into();
2149        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
2150
2151        client
2152            .register(
2153                streams,
2154                OperationIr::NumericInt(
2155                    desc.out.dtype,
2156                    NumericOperationIr::PowiScalar(desc.clone()),
2157                ),
2158                PowiOps::<B>::new(desc),
2159            )
2160            .output()
2161    }
2162}