Skip to main content

burn_router/ops/
int_tensor.rs

1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3use burn_std::{BoolDType, FloatDType};
4
5use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
6use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntElem, IntTensor};
7use burn_backend::{
8    Distribution, Element, IntDType, Scalar, Shape, Slice, TensorData, ops::IntTensorOps,
9};
10use burn_ir::{
11    BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, DimOpIr, FlipOpIr,
12    GatherNdOpIr, GatherOpIr, InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr,
13    MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr,
14    ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr, ScatterNdOpIr,
15    ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr,
16    UnaryOpIr, UnfoldOpIr,
17};
18
19impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
20    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
21        let client = get_client::<R>(device);
22        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
23
24        client
25            .register(OperationIr::BaseInt(BaseOperationIr::Empty(desc)))
26            .output()
27    }
28
29    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
30        Ok(tensor
31            .into_data()
32            .await?
33            // Since underlying backends can have different data types, we convert to the current elem
34            .convert::<IntElem<Self>>())
35    }
36
37    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
38        let client = get_client::<R>(device);
39        let out = client.register_tensor_data(data);
40        let desc = InitOperationIr {
41            out: out.to_ir_out(),
42        };
43
44        // Call register op when output is already initialized
45        client.register_op(OperationIr::Init(desc));
46
47        out
48    }
49
50    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
51        tensor.client.device()
52    }
53
54    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
55        if &tensor.client.device() == device {
56            return tensor;
57        }
58        R::change_client_backend(tensor, device)
59    }
60
61    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
62        let client = tensor.client.clone();
63        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
64
65        client
66            .register(OperationIr::BaseInt(BaseOperationIr::Reshape(desc)))
67            .output()
68    }
69
70    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
71        let client = tensor.client.clone();
72        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
73            client.create_empty_handle()
74        });
75
76        client
77            .register(OperationIr::BaseInt(BaseOperationIr::Slice(desc)))
78            .output()
79    }
80
81    fn int_slice_assign(
82        tensor: IntTensor<Self>,
83        slices: &[burn_backend::Slice],
84        value: IntTensor<Self>,
85    ) -> IntTensor<Self> {
86        let client = tensor.client.clone();
87        let desc =
88            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
89                client.create_empty_handle()
90            });
91
92        client
93            .register(OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc)))
94            .output()
95    }
96
97    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
98        let client = lhs.client.clone();
99        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
100            client.create_empty_handle()
101        });
102
103        client
104            .register(OperationIr::Int(IntOperationIr::Matmul(desc)))
105            .output()
106    }
107
108    fn int_mask_where(
109        tensor: IntTensor<Self>,
110        mask: BoolTensor<Self>,
111        value: IntTensor<Self>,
112    ) -> IntTensor<Self> {
113        let client = tensor.client.clone();
114        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
115            client.create_empty_handle()
116        });
117
118        client
119            .register(OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc)))
120            .output()
121    }
122
123    fn int_mask_fill(
124        tensor: IntTensor<Self>,
125        mask: BoolTensor<Self>,
126        value: Scalar,
127    ) -> IntTensor<Self> {
128        let client = tensor.client.clone();
129        let value = value.into();
130        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
131            client.create_empty_handle()
132        });
133
134        client
135            .register(OperationIr::BaseInt(BaseOperationIr::MaskFill(desc)))
136            .output()
137    }
138
139    fn int_gather(
140        dim: usize,
141        tensor: IntTensor<Self>,
142        indices: IntTensor<Self>,
143    ) -> IntTensor<Self> {
144        let client = tensor.client.clone();
145        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
146            client.create_empty_handle()
147        });
148
149        client
150            .register(OperationIr::BaseInt(BaseOperationIr::Gather(desc)))
151            .output()
152    }
153
154    fn int_scatter_add(
155        dim: usize,
156        tensor: IntTensor<Self>,
157        indices: IntTensor<Self>,
158        value: IntTensor<Self>,
159    ) -> IntTensor<Self> {
160        let client = tensor.client.clone();
161        let desc = ScatterOpIr::create(
162            tensor.into_ir(),
163            dim,
164            indices.into_ir(),
165            value.into_ir(),
166            IndexingUpdateOp::Add,
167            || client.create_empty_handle(),
168        );
169
170        client
171            .register(OperationIr::BaseInt(BaseOperationIr::Scatter(desc)))
172            .output()
173    }
174
175    fn int_scatter_nd(
176        data: IntTensor<Self>,
177        indices: IntTensor<Self>,
178        values: IntTensor<Self>,
179        reduction: IndexingUpdateOp,
180    ) -> IntTensor<Self> {
181        let client = data.client.clone();
182        let desc = ScatterNdOpIr::create(
183            data.into_ir(),
184            indices.into_ir(),
185            values.into_ir(),
186            reduction,
187            || client.create_empty_handle(),
188        );
189
190        client
191            .register(OperationIr::BaseInt(BaseOperationIr::ScatterNd(desc)))
192            .output()
193    }
194
195    fn int_gather_nd(data: IntTensor<Self>, indices: IntTensor<Self>) -> IntTensor<Self> {
196        let client = data.client.clone();
197        let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
198            client.create_empty_handle()
199        });
200
201        client
202            .register(OperationIr::BaseInt(BaseOperationIr::GatherNd(desc)))
203            .output()
204    }
205
206    fn int_select(
207        tensor: IntTensor<Self>,
208        dim: usize,
209        indices: IntTensor<Self>,
210    ) -> IntTensor<Self> {
211        let client = tensor.client.clone();
212        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
213            client.create_empty_handle()
214        });
215
216        client
217            .register(OperationIr::BaseInt(BaseOperationIr::Select(desc)))
218            .output()
219    }
220
221    fn int_select_add(
222        tensor: IntTensor<Self>,
223        dim: usize,
224        indices: IntTensor<Self>,
225        value: IntTensor<Self>,
226    ) -> IntTensor<Self> {
227        let client = tensor.client.clone();
228        let desc = SelectAssignOpIr::create(
229            tensor.into_ir(),
230            dim,
231            indices.into_ir(),
232            value.into_ir(),
233            IndexingUpdateOp::Add,
234            || client.create_empty_handle(),
235        );
236
237        client
238            .register(OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc)))
239            .output()
240    }
241
242    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
243        let client = tensors.first().unwrap().client.clone();
244        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
245        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
246
247        client
248            .register(OperationIr::BaseInt(BaseOperationIr::Cat(desc)))
249            .output()
250    }
251
252    fn int_equal(
253        lhs: IntTensor<Self>,
254        rhs: IntTensor<Self>,
255        out_dtype: BoolDType,
256    ) -> BoolTensor<Self> {
257        let client = lhs.client.clone();
258        let desc =
259            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
260                client.create_empty_handle()
261            });
262
263        client
264            .register(OperationIr::BaseInt(BaseOperationIr::Equal(desc)))
265            .output()
266    }
267
268    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
269        let client = lhs.client.clone();
270        let rhs = rhs.into();
271        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
272            client.create_empty_handle()
273        });
274
275        client
276            .register(OperationIr::BaseInt(BaseOperationIr::EqualElem(desc)))
277            .output()
278    }
279
280    fn int_greater(
281        lhs: IntTensor<Self>,
282        rhs: IntTensor<Self>,
283        out_dtype: BoolDType,
284    ) -> BoolTensor<Self> {
285        let client = lhs.client.clone();
286        let desc =
287            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
288                client.create_empty_handle()
289            });
290
291        client
292            .register(OperationIr::NumericInt(
293                desc.lhs.dtype,
294                NumericOperationIr::Greater(desc),
295            ))
296            .output()
297    }
298
299    fn int_greater_elem(
300        lhs: IntTensor<Self>,
301        rhs: Scalar,
302        out_dtype: BoolDType,
303    ) -> BoolTensor<Self> {
304        let client = lhs.client.clone();
305        let rhs = rhs.into();
306        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
307            client.create_empty_handle()
308        });
309
310        client
311            .register(OperationIr::NumericInt(
312                desc.lhs.dtype,
313                NumericOperationIr::GreaterElem(desc),
314            ))
315            .output()
316    }
317
318    fn int_greater_equal(
319        lhs: IntTensor<Self>,
320        rhs: IntTensor<Self>,
321        out_dtype: BoolDType,
322    ) -> BoolTensor<Self> {
323        let client = lhs.client.clone();
324        let desc =
325            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
326                client.create_empty_handle()
327            });
328
329        client
330            .register(OperationIr::NumericInt(
331                desc.lhs.dtype,
332                NumericOperationIr::GreaterEqual(desc),
333            ))
334            .output()
335    }
336
337    fn int_greater_equal_elem(
338        lhs: IntTensor<Self>,
339        rhs: Scalar,
340        out_dtype: BoolDType,
341    ) -> BoolTensor<Self> {
342        let client = lhs.client.clone();
343        let rhs = rhs.into();
344        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
345            client.create_empty_handle()
346        });
347
348        client
349            .register(OperationIr::NumericInt(
350                desc.lhs.dtype,
351                NumericOperationIr::GreaterEqualElem(desc),
352            ))
353            .output()
354    }
355
356    fn int_lower(
357        lhs: IntTensor<Self>,
358        rhs: IntTensor<Self>,
359        out_dtype: BoolDType,
360    ) -> BoolTensor<Self> {
361        let client = lhs.client.clone();
362        let desc =
363            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
364                client.create_empty_handle()
365            });
366
367        client
368            .register(OperationIr::NumericInt(
369                desc.lhs.dtype,
370                NumericOperationIr::Lower(desc),
371            ))
372            .output()
373    }
374
375    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
376        let client = lhs.client.clone();
377        let rhs = rhs.into();
378        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
379            client.create_empty_handle()
380        });
381
382        client
383            .register(OperationIr::NumericInt(
384                desc.lhs.dtype,
385                NumericOperationIr::LowerElem(desc),
386            ))
387            .output()
388    }
389
390    fn int_lower_equal(
391        lhs: IntTensor<Self>,
392        rhs: IntTensor<Self>,
393        out_dtype: BoolDType,
394    ) -> BoolTensor<Self> {
395        let client = lhs.client.clone();
396        let desc =
397            BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
398                client.create_empty_handle()
399            });
400
401        client
402            .register(OperationIr::NumericInt(
403                desc.lhs.dtype,
404                NumericOperationIr::LowerEqual(desc),
405            ))
406            .output()
407    }
408
409    fn int_lower_equal_elem(
410        lhs: IntTensor<Self>,
411        rhs: Scalar,
412        out_dtype: BoolDType,
413    ) -> BoolTensor<Self> {
414        let client = lhs.client.clone();
415        let rhs = rhs.into();
416        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
417            client.create_empty_handle()
418        });
419
420        client
421            .register(OperationIr::NumericInt(
422                desc.lhs.dtype,
423                NumericOperationIr::LowerEqualElem(desc),
424            ))
425            .output()
426    }
427
428    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
429        let client = lhs.client.clone();
430        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
431            client.create_empty_handle()
432        });
433
434        client
435            .register(OperationIr::NumericInt(
436                desc.out.dtype,
437                NumericOperationIr::Add(desc),
438            ))
439            .output()
440    }
441
442    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
443        let client = lhs.client.clone();
444        let rhs = rhs.into();
445        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
446
447        client
448            .register(OperationIr::NumericInt(
449                desc.out.dtype,
450                NumericOperationIr::AddScalar(desc),
451            ))
452            .output()
453    }
454
455    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
456        let client = lhs.client.clone();
457        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
458            client.create_empty_handle()
459        });
460
461        client
462            .register(OperationIr::NumericInt(
463                desc.out.dtype,
464                NumericOperationIr::Sub(desc),
465            ))
466            .output()
467    }
468
469    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
470        let client = lhs.client.clone();
471        let rhs = rhs.into();
472        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
473
474        client
475            .register(OperationIr::NumericInt(
476                desc.out.dtype,
477                NumericOperationIr::SubScalar(desc),
478            ))
479            .output()
480    }
481
482    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
483        let client = lhs.client.clone();
484        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
485            client.create_empty_handle()
486        });
487
488        client
489            .register(OperationIr::NumericInt(
490                desc.out.dtype,
491                NumericOperationIr::Mul(desc),
492            ))
493            .output()
494    }
495
496    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
497        let client = lhs.client.clone();
498        let rhs = rhs.into();
499        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
500
501        client
502            .register(OperationIr::NumericInt(
503                desc.out.dtype,
504                NumericOperationIr::MulScalar(desc),
505            ))
506            .output()
507    }
508
509    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
510        let client = lhs.client.clone();
511        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
512            client.create_empty_handle()
513        });
514
515        client
516            .register(OperationIr::NumericInt(
517                desc.out.dtype,
518                NumericOperationIr::Div(desc),
519            ))
520            .output()
521    }
522
523    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
524        let client = lhs.client.clone();
525        let rhs = rhs.into();
526        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
527
528        client
529            .register(OperationIr::NumericInt(
530                desc.out.dtype,
531                NumericOperationIr::DivScalar(desc),
532            ))
533            .output()
534    }
535
536    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
537        let client = lhs.client.clone();
538        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
539            client.create_empty_handle()
540        });
541
542        client
543            .register(OperationIr::NumericInt(
544                desc.out.dtype,
545                NumericOperationIr::Rem(desc),
546            ))
547            .output()
548    }
549
550    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
551        let client = lhs.client.clone();
552        let rhs = rhs.into();
553        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
554
555        client
556            .register(OperationIr::NumericInt(
557                desc.out.dtype,
558                NumericOperationIr::RemScalar(desc),
559            ))
560            .output()
561    }
562
563    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
564        let client = get_client::<R>(device);
565        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
566
567        client
568            .register(OperationIr::BaseInt(BaseOperationIr::Zeros(desc)))
569            .output()
570    }
571
572    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
573        let client = get_client::<R>(device);
574        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
575
576        client
577            .register(OperationIr::BaseInt(BaseOperationIr::Ones(desc)))
578            .output()
579    }
580
581    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
582        let client = tensor.client.clone();
583        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
584
585        client
586            .register(OperationIr::NumericInt(
587                desc.out.dtype,
588                NumericOperationIr::Sum(desc),
589            ))
590            .output()
591    }
592
593    fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
594        let client = tensor.client.clone();
595        let desc =
596            ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
597
598        client
599            .register(OperationIr::NumericInt(
600                desc.out.dtype,
601                NumericOperationIr::SumDim(desc),
602            ))
603            .output()
604    }
605
606    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
607        let client = tensor.client.clone();
608        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
609
610        client
611            .register(OperationIr::NumericInt(
612                desc.out.dtype,
613                NumericOperationIr::Prod(desc),
614            ))
615            .output()
616    }
617
618    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
619        let client = tensor.client.clone();
620        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
621
622        client
623            .register(OperationIr::NumericInt(
624                desc.out.dtype,
625                NumericOperationIr::ProdDim(desc),
626            ))
627            .output()
628    }
629
630    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
631        let client = tensor.client.clone();
632        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
633
634        client
635            .register(OperationIr::NumericInt(
636                desc.out.dtype,
637                NumericOperationIr::Mean(desc),
638            ))
639            .output()
640    }
641
642    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
643        let client = tensor.client.clone();
644        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
645
646        client
647            .register(OperationIr::NumericInt(
648                desc.out.dtype,
649                NumericOperationIr::MeanDim(desc),
650            ))
651            .output()
652    }
653
654    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
655        let client = tensor.client.clone();
656        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
657
658        client
659            .register(OperationIr::NumericInt(
660                desc.out.dtype,
661                NumericOperationIr::CumSum(desc),
662            ))
663            .output()
664    }
665
666    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
667        let client = tensor.client.clone();
668        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
669
670        client
671            .register(OperationIr::NumericInt(
672                desc.out.dtype,
673                NumericOperationIr::CumProd(desc),
674            ))
675            .output()
676    }
677
678    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
679        let client = tensor.client.clone();
680        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
681
682        client
683            .register(OperationIr::NumericInt(
684                desc.out.dtype,
685                NumericOperationIr::CumMin(desc),
686            ))
687            .output()
688    }
689
690    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
691        let client = tensor.client.clone();
692        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
693
694        client
695            .register(OperationIr::NumericInt(
696                desc.out.dtype,
697                NumericOperationIr::CumMax(desc),
698            ))
699            .output()
700    }
701
702    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
703        let client = tensor.client.clone();
704        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
705
706        client
707            .register(OperationIr::NumericInt(
708                desc.out.dtype,
709                NumericOperationIr::ArgMax(desc),
710            ))
711            .output()
712    }
713
714    fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
715        let client = tensor.client.clone();
716        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
717
718        client
719            .register(OperationIr::NumericInt(
720                desc.out.dtype,
721                NumericOperationIr::ArgTopK(desc),
722            ))
723            .output()
724    }
725
726    fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
727        let client = tensor.client.clone();
728        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
729
730        client
731            .register(OperationIr::NumericInt(
732                desc.out.dtype,
733                NumericOperationIr::TopK(desc),
734            ))
735            .output()
736    }
737
738    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
739        let client = tensor.client.clone();
740        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
741
742        client
743            .register(OperationIr::NumericInt(
744                desc.out.dtype,
745                NumericOperationIr::ArgMin(desc),
746            ))
747            .output()
748    }
749
750    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
751        let client = tensor.client.clone();
752        let min = min.into();
753        let max = max.into();
754        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
755
756        client
757            .register(OperationIr::NumericInt(
758                desc.out.dtype,
759                NumericOperationIr::Clamp(desc),
760            ))
761            .output()
762    }
763
764    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
765        let client = tensor.client.clone();
766        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
767
768        client
769            .register(OperationIr::NumericInt(
770                desc.out.dtype,
771                NumericOperationIr::Abs(desc),
772            ))
773            .output()
774    }
775
776    fn int_into_float(tensor: IntTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
777        let client = tensor.client.clone();
778        let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
779            client.create_empty_handle()
780        });
781
782        client
783            .register(OperationIr::Int(IntOperationIr::IntoFloat(desc)))
784            .output()
785    }
786
787    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
788        let client = tensor.client.clone();
789        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
790            client.create_empty_handle()
791        });
792
793        client
794            .register(OperationIr::BaseInt(BaseOperationIr::SwapDims(desc)))
795            .output()
796    }
797
798    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
799        let client = tensor.client.clone();
800        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
801
802        client
803            .register(OperationIr::NumericInt(
804                desc.out.dtype,
805                NumericOperationIr::Max(desc),
806            ))
807            .output()
808    }
809
810    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
811        let client = tensor.client.clone();
812        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
813
814        client
815            .register(OperationIr::NumericInt(
816                desc.out.dtype,
817                NumericOperationIr::MaxDim(desc),
818            ))
819            .output()
820    }
821
822    fn int_max_dim_with_indices(
823        tensor: IntTensor<Self>,
824        dim: usize,
825    ) -> (IntTensor<Self>, IntTensor<Self>) {
826        let client = tensor.client.clone();
827        let desc = ReduceDimWithIndicesOpIr::create(
828            tensor.into_ir(),
829            dim,
830            IntElem::<Self>::dtype(),
831            || client.create_empty_handle(),
832        );
833
834        client
835            .register(OperationIr::NumericInt(
836                desc.tensor.dtype,
837                NumericOperationIr::MaxDimWithIndices(desc),
838            ))
839            .outputs()
840            .into()
841    }
842
843    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
844        let client = tensor.client.clone();
845        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
846
847        client
848            .register(OperationIr::NumericInt(
849                desc.out.dtype,
850                NumericOperationIr::MaxAbs(desc),
851            ))
852            .output()
853    }
854
855    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
856        let client = tensor.client.clone();
857        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
858
859        client
860            .register(OperationIr::NumericInt(
861                desc.out.dtype,
862                NumericOperationIr::MaxAbsDim(desc),
863            ))
864            .output()
865    }
866
867    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
868        let client = tensor.client.clone();
869        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
870
871        client
872            .register(OperationIr::NumericInt(
873                desc.out.dtype,
874                NumericOperationIr::Min(desc),
875            ))
876            .output()
877    }
878
879    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
880        let client = tensor.client.clone();
881        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
882
883        client
884            .register(OperationIr::NumericInt(
885                desc.out.dtype,
886                NumericOperationIr::MinDim(desc),
887            ))
888            .output()
889    }
890
891    fn int_min_dim_with_indices(
892        tensor: IntTensor<Self>,
893        dim: usize,
894    ) -> (IntTensor<Self>, IntTensor<Self>) {
895        let client = tensor.client.clone();
896        let desc = ReduceDimWithIndicesOpIr::create(
897            tensor.into_ir(),
898            dim,
899            IntElem::<Self>::dtype(),
900            || client.create_empty_handle(),
901        );
902
903        client
904            .register(OperationIr::NumericInt(
905                desc.out.dtype,
906                NumericOperationIr::MinDimWithIndices(desc),
907            ))
908            .outputs()
909            .into()
910    }
911
912    fn int_random(
913        shape: Shape,
914        distribution: Distribution,
915        device: &Device<Self>,
916        dtype: IntDType,
917    ) -> IntTensor<Self> {
918        let client = get_client::<R>(device);
919        let dtype = dtype.into();
920        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
921
922        client
923            .register(OperationIr::NumericInt(
924                dtype,
925                NumericOperationIr::IntRandom(desc),
926            ))
927            .output()
928    }
929
930    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
931        let client = tensor.client.clone();
932        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
933            client.create_empty_handle()
934        });
935
936        client
937            .register(OperationIr::BaseInt(BaseOperationIr::Permute(desc)))
938            .output()
939    }
940
941    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
942        let client = tensor.client.clone();
943        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
944
945        client
946            .register(OperationIr::BaseInt(BaseOperationIr::Expand(desc)))
947            .output()
948    }
949
950    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
951        let client = tensor.client.clone();
952        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
953            client.create_empty_handle()
954        });
955
956        client
957            .register(OperationIr::BaseInt(BaseOperationIr::Flip(desc)))
958            .output()
959    }
960
961    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
962        let client = tensor.client.clone();
963        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
964            client.create_empty_handle()
965        });
966
967        client
968            .register(OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc)))
969            .output()
970    }
971
972    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
973        let client = lhs.client.clone();
974        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
975            client.create_empty_handle()
976        });
977
978        client
979            .register(OperationIr::Int(IntOperationIr::BitwiseAnd(desc)))
980            .output()
981    }
982
983    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
984        let client = lhs.client.clone();
985        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
986            client.create_empty_handle()
987        });
988
989        client
990            .register(OperationIr::Int(IntOperationIr::BitwiseOr(desc)))
991            .output()
992    }
993
994    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
995        let client = lhs.client.clone();
996        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
997            client.create_empty_handle()
998        });
999
1000        client
1001            .register(OperationIr::Int(IntOperationIr::BitwiseXor(desc)))
1002            .output()
1003    }
1004
1005    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
1006        let client = tensor.client.clone();
1007        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1008
1009        client
1010            .register(OperationIr::Int(IntOperationIr::BitwiseNot(desc)))
1011            .output()
1012    }
1013
1014    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1015        let client = lhs.client.clone();
1016        let rhs = rhs.into();
1017        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1018
1019        client
1020            .register(OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc)))
1021            .output()
1022    }
1023
1024    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1025        let client = lhs.client.clone();
1026        let rhs = rhs.into();
1027        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1028
1029        client
1030            .register(OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc)))
1031            .output()
1032    }
1033
1034    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1035        let client = lhs.client.clone();
1036        let rhs = rhs.into();
1037        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1038
1039        client
1040            .register(OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc)))
1041            .output()
1042    }
1043
1044    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1045        let client = lhs.client.clone();
1046        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1047            client.create_empty_handle()
1048        });
1049
1050        client
1051            .register(OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc)))
1052            .output()
1053    }
1054
1055    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1056        let client = lhs.client.clone();
1057        let rhs = rhs.into();
1058        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1059
1060        client
1061            .register(OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(
1062                desc,
1063            )))
1064            .output()
1065    }
1066
1067    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1068        let client = lhs.client.clone();
1069        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1070            client.create_empty_handle()
1071        });
1072
1073        client
1074            .register(OperationIr::Int(IntOperationIr::BitwiseRightShift(desc)))
1075            .output()
1076    }
1077
1078    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1079        let client = lhs.client.clone();
1080        let rhs = rhs.into();
1081        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1082
1083        client
1084            .register(OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(
1085                desc,
1086            )))
1087            .output()
1088    }
1089
1090    fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
1091        let client = tensor.client.clone();
1092        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1093            client.create_empty_handle()
1094        });
1095
1096        client
1097            .register(OperationIr::BaseInt(BaseOperationIr::Cast(desc)))
1098            .output()
1099    }
1100
1101    fn int_unfold(
1102        tensor: IntTensor<Self>,
1103        dim: usize,
1104        size: usize,
1105        step: usize,
1106    ) -> IntTensor<Self> {
1107        let client = tensor.client.clone();
1108        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1109            client.create_empty_handle()
1110        });
1111
1112        client
1113            .register(OperationIr::BaseInt(BaseOperationIr::Unfold(desc)))
1114            .output()
1115    }
1116
1117    fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1118        let client = lhs.client.clone();
1119        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1120            client.create_empty_handle()
1121        });
1122
1123        client
1124            .register(OperationIr::NumericInt(
1125                desc.out.dtype,
1126                NumericOperationIr::Powi(desc),
1127            ))
1128            .output()
1129    }
1130
1131    fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1132        let client = lhs.client.clone();
1133        let rhs = rhs.into();
1134        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1135
1136        client
1137            .register(OperationIr::NumericInt(
1138                desc.out.dtype,
1139                NumericOperationIr::PowiScalar(desc),
1140            ))
1141            .output()
1142    }
1143}