Skip to main content

burn_router/ops/
int_tensor.rs

1use alloc::vec::Vec;
2use burn_backend::backend::{Backend, ExecutionError};
3
4use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
5use burn_backend::tensor::{
6    BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
7};
8use burn_backend::{Distribution, Element, IntDType, Shape, Slice, TensorData, ops::IntTensorOps};
9use burn_ir::{
10    BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, DimOpIr, FlipOpIr,
11    GatherOpIr, InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr, MatmulOpIr,
12    NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr, ReduceDimOpIr,
13    ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarIr, ScalarOpIr, ScatterOpIr,
14    SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr,
15    UnfoldOpIr,
16};
17
18impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
19    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
20        let client = get_client::<R>(device);
21        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
22
23        client
24            .register(OperationIr::BaseInt(BaseOperationIr::Empty(desc)))
25            .output()
26    }
27
28    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
29        Ok(tensor
30            .into_data()
31            .await?
32            // Since underlying backends can have different data types, we convert to the current elem
33            .convert::<<Self as Backend>::IntElem>())
34    }
35
36    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
37        let client = get_client::<R>(device);
38        let out = client.register_tensor_data(data);
39        let desc = InitOperationIr {
40            out: out.to_ir_out(),
41        };
42
43        // Call register op when output is already initialized
44        client.register_op(OperationIr::Init(desc));
45
46        out
47    }
48
49    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
50        tensor.client.device()
51    }
52
53    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
54        if &tensor.client.device() == device {
55            return tensor;
56        }
57        R::change_client_backend(tensor, device)
58    }
59
60    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
61        let client = tensor.client.clone();
62        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
63
64        client
65            .register(OperationIr::BaseInt(BaseOperationIr::Reshape(desc)))
66            .output()
67    }
68
69    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
70        let client = tensor.client.clone();
71        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
72            client.create_empty_handle()
73        });
74
75        client
76            .register(OperationIr::BaseInt(BaseOperationIr::Slice(desc)))
77            .output()
78    }
79
80    fn int_slice_assign(
81        tensor: IntTensor<Self>,
82        slices: &[burn_backend::Slice],
83        value: IntTensor<Self>,
84    ) -> IntTensor<Self> {
85        let client = tensor.client.clone();
86        let desc =
87            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
88                client.create_empty_handle()
89            });
90
91        client
92            .register(OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc)))
93            .output()
94    }
95
96    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
97        let client = lhs.client.clone();
98        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
99            client.create_empty_handle()
100        });
101
102        client
103            .register(OperationIr::Int(IntOperationIr::Matmul(desc)))
104            .output()
105    }
106
107    fn int_mask_where(
108        tensor: IntTensor<Self>,
109        mask: BoolTensor<Self>,
110        value: IntTensor<Self>,
111    ) -> IntTensor<Self> {
112        let client = tensor.client.clone();
113        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
114            client.create_empty_handle()
115        });
116
117        client
118            .register(OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc)))
119            .output()
120    }
121
122    fn int_mask_fill(
123        tensor: IntTensor<Self>,
124        mask: BoolTensor<Self>,
125        value: IntElem<Self>,
126    ) -> IntTensor<Self> {
127        let client = tensor.client.clone();
128        let value = ScalarIr::with_dtype(value, &tensor.dtype);
129        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
130            client.create_empty_handle()
131        });
132
133        client
134            .register(OperationIr::BaseInt(BaseOperationIr::MaskFill(desc)))
135            .output()
136    }
137
138    fn int_gather(
139        dim: usize,
140        tensor: IntTensor<Self>,
141        indices: IntTensor<Self>,
142    ) -> IntTensor<Self> {
143        let client = tensor.client.clone();
144        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
145            client.create_empty_handle()
146        });
147
148        client
149            .register(OperationIr::BaseInt(BaseOperationIr::Gather(desc)))
150            .output()
151    }
152
153    fn int_scatter_add(
154        dim: usize,
155        tensor: IntTensor<Self>,
156        indices: IntTensor<Self>,
157        value: IntTensor<Self>,
158    ) -> IntTensor<Self> {
159        let client = tensor.client.clone();
160        let desc = ScatterOpIr::create(
161            tensor.into_ir(),
162            dim,
163            indices.into_ir(),
164            value.into_ir(),
165            IndexingUpdateOp::Add,
166            || client.create_empty_handle(),
167        );
168
169        client
170            .register(OperationIr::BaseInt(BaseOperationIr::Scatter(desc)))
171            .output()
172    }
173
174    fn int_select(
175        tensor: IntTensor<Self>,
176        dim: usize,
177        indices: IntTensor<Self>,
178    ) -> IntTensor<Self> {
179        let client = tensor.client.clone();
180        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
181            client.create_empty_handle()
182        });
183
184        client
185            .register(OperationIr::BaseInt(BaseOperationIr::Select(desc)))
186            .output()
187    }
188
189    fn int_select_add(
190        tensor: IntTensor<Self>,
191        dim: usize,
192        indices: IntTensor<Self>,
193        value: IntTensor<Self>,
194    ) -> IntTensor<Self> {
195        let client = tensor.client.clone();
196        let desc = SelectAssignOpIr::create(
197            tensor.into_ir(),
198            dim,
199            indices.into_ir(),
200            value.into_ir(),
201            IndexingUpdateOp::Add,
202            || client.create_empty_handle(),
203        );
204
205        client
206            .register(OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc)))
207            .output()
208    }
209
210    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
211        let client = tensors.first().unwrap().client.clone();
212        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
213        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
214
215        client
216            .register(OperationIr::BaseInt(BaseOperationIr::Cat(desc)))
217            .output()
218    }
219
220    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
221        let client = lhs.client.clone();
222        let desc = BinaryOpIr::create_comparison(
223            lhs.into_ir(),
224            rhs.into_ir(),
225            R::BoolElem::dtype(),
226            || client.create_empty_handle(),
227        );
228
229        client
230            .register(OperationIr::BaseInt(BaseOperationIr::Equal(desc)))
231            .output()
232    }
233
234    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
235        let client = lhs.client.clone();
236        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
237        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
238            client.create_empty_handle()
239        });
240
241        client
242            .register(OperationIr::BaseInt(BaseOperationIr::EqualElem(desc)))
243            .output()
244    }
245
246    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
247        let client = lhs.client.clone();
248        let desc = BinaryOpIr::create_comparison(
249            lhs.into_ir(),
250            rhs.into_ir(),
251            R::BoolElem::dtype(),
252            || client.create_empty_handle(),
253        );
254
255        client
256            .register(OperationIr::NumericInt(
257                desc.lhs.dtype,
258                NumericOperationIr::Greater(desc),
259            ))
260            .output()
261    }
262
263    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
264        let client = lhs.client.clone();
265        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
266        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
267            client.create_empty_handle()
268        });
269
270        client
271            .register(OperationIr::NumericInt(
272                desc.lhs.dtype,
273                NumericOperationIr::GreaterElem(desc),
274            ))
275            .output()
276    }
277
278    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
279        let client = lhs.client.clone();
280        let desc = BinaryOpIr::create_comparison(
281            lhs.into_ir(),
282            rhs.into_ir(),
283            R::BoolElem::dtype(),
284            || client.create_empty_handle(),
285        );
286
287        client
288            .register(OperationIr::NumericInt(
289                desc.lhs.dtype,
290                NumericOperationIr::GreaterEqual(desc),
291            ))
292            .output()
293    }
294
295    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
296        let client = lhs.client.clone();
297        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
298        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
299            client.create_empty_handle()
300        });
301
302        client
303            .register(OperationIr::NumericInt(
304                desc.lhs.dtype,
305                NumericOperationIr::GreaterEqualElem(desc),
306            ))
307            .output()
308    }
309
310    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
311        let client = lhs.client.clone();
312        let desc = BinaryOpIr::create_comparison(
313            lhs.into_ir(),
314            rhs.into_ir(),
315            R::BoolElem::dtype(),
316            || client.create_empty_handle(),
317        );
318
319        client
320            .register(OperationIr::NumericInt(
321                desc.lhs.dtype,
322                NumericOperationIr::Lower(desc),
323            ))
324            .output()
325    }
326
327    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
328        let client = lhs.client.clone();
329        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
330        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
331            client.create_empty_handle()
332        });
333
334        client
335            .register(OperationIr::NumericInt(
336                desc.lhs.dtype,
337                NumericOperationIr::LowerElem(desc),
338            ))
339            .output()
340    }
341
342    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
343        let client = lhs.client.clone();
344        let desc = BinaryOpIr::create_comparison(
345            lhs.into_ir(),
346            rhs.into_ir(),
347            R::BoolElem::dtype(),
348            || client.create_empty_handle(),
349        );
350
351        client
352            .register(OperationIr::NumericInt(
353                desc.lhs.dtype,
354                NumericOperationIr::LowerEqual(desc),
355            ))
356            .output()
357    }
358
359    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
360        let client = lhs.client.clone();
361        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
362        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
363            client.create_empty_handle()
364        });
365
366        client
367            .register(OperationIr::NumericInt(
368                desc.lhs.dtype,
369                NumericOperationIr::LowerEqualElem(desc),
370            ))
371            .output()
372    }
373
374    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
375        let client = lhs.client.clone();
376        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
377            client.create_empty_handle()
378        });
379
380        client
381            .register(OperationIr::NumericInt(
382                desc.out.dtype,
383                NumericOperationIr::Add(desc),
384            ))
385            .output()
386    }
387
388    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
389        let client = lhs.client.clone();
390        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
391        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
392
393        client
394            .register(OperationIr::NumericInt(
395                desc.out.dtype,
396                NumericOperationIr::AddScalar(desc),
397            ))
398            .output()
399    }
400
401    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
402        let client = lhs.client.clone();
403        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
404            client.create_empty_handle()
405        });
406
407        client
408            .register(OperationIr::NumericInt(
409                desc.out.dtype,
410                NumericOperationIr::Sub(desc),
411            ))
412            .output()
413    }
414
415    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
416        let client = lhs.client.clone();
417        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
418        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
419
420        client
421            .register(OperationIr::NumericInt(
422                desc.out.dtype,
423                NumericOperationIr::SubScalar(desc),
424            ))
425            .output()
426    }
427
428    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
429        let client = lhs.client.clone();
430        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
431            client.create_empty_handle()
432        });
433
434        client
435            .register(OperationIr::NumericInt(
436                desc.out.dtype,
437                NumericOperationIr::Mul(desc),
438            ))
439            .output()
440    }
441
442    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
443        let client = lhs.client.clone();
444        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
445        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
446
447        client
448            .register(OperationIr::NumericInt(
449                desc.out.dtype,
450                NumericOperationIr::MulScalar(desc),
451            ))
452            .output()
453    }
454
455    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
456        let client = lhs.client.clone();
457        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
458            client.create_empty_handle()
459        });
460
461        client
462            .register(OperationIr::NumericInt(
463                desc.out.dtype,
464                NumericOperationIr::Div(desc),
465            ))
466            .output()
467    }
468
469    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
470        let client = lhs.client.clone();
471        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
472        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
473
474        client
475            .register(OperationIr::NumericInt(
476                desc.out.dtype,
477                NumericOperationIr::DivScalar(desc),
478            ))
479            .output()
480    }
481
482    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
483        let client = lhs.client.clone();
484        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
485            client.create_empty_handle()
486        });
487
488        client
489            .register(OperationIr::NumericInt(
490                desc.out.dtype,
491                NumericOperationIr::Rem(desc),
492            ))
493            .output()
494    }
495
496    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
497        let client = lhs.client.clone();
498        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
499        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
500
501        client
502            .register(OperationIr::NumericInt(
503                desc.out.dtype,
504                NumericOperationIr::RemScalar(desc),
505            ))
506            .output()
507    }
508
509    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
510        let client = get_client::<R>(device);
511        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
512
513        client
514            .register(OperationIr::BaseInt(BaseOperationIr::Zeros(desc)))
515            .output()
516    }
517
518    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
519        let client = get_client::<R>(device);
520        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
521
522        client
523            .register(OperationIr::BaseInt(BaseOperationIr::Ones(desc)))
524            .output()
525    }
526
527    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
528        let client = tensor.client.clone();
529        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
530
531        client
532            .register(OperationIr::NumericInt(
533                desc.out.dtype,
534                NumericOperationIr::Sum(desc),
535            ))
536            .output()
537    }
538
539    fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
540        let client = tensor.client.clone();
541        let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());
542
543        client
544            .register(OperationIr::NumericInt(
545                desc.out.dtype,
546                NumericOperationIr::SumDim(desc),
547            ))
548            .output()
549    }
550
551    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
552        let client = tensor.client.clone();
553        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
554
555        client
556            .register(OperationIr::NumericInt(
557                desc.out.dtype,
558                NumericOperationIr::Prod(desc),
559            ))
560            .output()
561    }
562
563    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
564        let client = tensor.client.clone();
565        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
566
567        client
568            .register(OperationIr::NumericInt(
569                desc.out.dtype,
570                NumericOperationIr::ProdDim(desc),
571            ))
572            .output()
573    }
574
575    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
576        let client = tensor.client.clone();
577        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
578
579        client
580            .register(OperationIr::NumericInt(
581                desc.out.dtype,
582                NumericOperationIr::Mean(desc),
583            ))
584            .output()
585    }
586
587    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
588        let client = tensor.client.clone();
589        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
590
591        client
592            .register(OperationIr::NumericInt(
593                desc.out.dtype,
594                NumericOperationIr::MeanDim(desc),
595            ))
596            .output()
597    }
598
599    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
600        let client = tensor.client.clone();
601        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
602
603        client
604            .register(OperationIr::NumericInt(
605                desc.out.dtype,
606                NumericOperationIr::CumSum(desc),
607            ))
608            .output()
609    }
610
611    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
612        let client = tensor.client.clone();
613        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
614
615        client
616            .register(OperationIr::NumericInt(
617                desc.out.dtype,
618                NumericOperationIr::CumProd(desc),
619            ))
620            .output()
621    }
622
623    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
624        let client = tensor.client.clone();
625        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
626
627        client
628            .register(OperationIr::NumericInt(
629                desc.out.dtype,
630                NumericOperationIr::CumMin(desc),
631            ))
632            .output()
633    }
634
635    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
636        let client = tensor.client.clone();
637        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
638
639        client
640            .register(OperationIr::NumericInt(
641                desc.out.dtype,
642                NumericOperationIr::CumMax(desc),
643            ))
644            .output()
645    }
646
647    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
648        let client = tensor.client.clone();
649        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
650
651        client
652            .register(OperationIr::NumericInt(
653                desc.out.dtype,
654                NumericOperationIr::ArgMax(desc),
655            ))
656            .output()
657    }
658
659    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
660        let client = tensor.client.clone();
661        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
662
663        client
664            .register(OperationIr::NumericInt(
665                desc.out.dtype,
666                NumericOperationIr::ArgMin(desc),
667            ))
668            .output()
669    }
670
671    fn int_clamp(
672        tensor: IntTensor<Self>,
673        min: IntElem<Self>,
674        max: IntElem<Self>,
675    ) -> IntTensor<Self> {
676        let client = tensor.client.clone();
677        let min = ScalarIr::with_dtype(min, &tensor.dtype);
678        let max = ScalarIr::with_dtype(max, &tensor.dtype);
679        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
680
681        client
682            .register(OperationIr::NumericInt(
683                desc.out.dtype,
684                NumericOperationIr::Clamp(desc),
685            ))
686            .output()
687    }
688
689    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
690        let client = tensor.client.clone();
691        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
692
693        client
694            .register(OperationIr::NumericInt(
695                desc.out.dtype,
696                NumericOperationIr::Abs(desc),
697            ))
698            .output()
699    }
700
701    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
702        let client = tensor.client.clone();
703        let desc = CastOpIr::create(tensor.into_ir(), FloatElem::<Self>::dtype(), || {
704            client.create_empty_handle()
705        });
706
707        client
708            .register(OperationIr::Int(IntOperationIr::IntoFloat(desc)))
709            .output()
710    }
711
712    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
713        let client = tensor.client.clone();
714        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
715            client.create_empty_handle()
716        });
717
718        client
719            .register(OperationIr::BaseInt(BaseOperationIr::SwapDims(desc)))
720            .output()
721    }
722
723    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
724        let client = tensor.client.clone();
725        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
726
727        client
728            .register(OperationIr::NumericInt(
729                desc.out.dtype,
730                NumericOperationIr::Max(desc),
731            ))
732            .output()
733    }
734
735    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
736        let client = tensor.client.clone();
737        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
738
739        client
740            .register(OperationIr::NumericInt(
741                desc.out.dtype,
742                NumericOperationIr::MaxDim(desc),
743            ))
744            .output()
745    }
746
747    fn int_max_dim_with_indices(
748        tensor: IntTensor<Self>,
749        dim: usize,
750    ) -> (IntTensor<Self>, IntTensor<Self>) {
751        let client = tensor.client.clone();
752        let desc = ReduceDimWithIndicesOpIr::create(
753            tensor.into_ir(),
754            dim,
755            IntElem::<Self>::dtype(),
756            || client.create_empty_handle(),
757        );
758
759        client
760            .register(OperationIr::NumericInt(
761                desc.tensor.dtype,
762                NumericOperationIr::MaxDimWithIndices(desc),
763            ))
764            .outputs()
765            .into()
766    }
767
768    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
769        let client = tensor.client.clone();
770        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
771
772        client
773            .register(OperationIr::NumericInt(
774                desc.out.dtype,
775                NumericOperationIr::MaxAbs(desc),
776            ))
777            .output()
778    }
779
780    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
781        let client = tensor.client.clone();
782        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
783
784        client
785            .register(OperationIr::NumericInt(
786                desc.out.dtype,
787                NumericOperationIr::MaxAbsDim(desc),
788            ))
789            .output()
790    }
791
792    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
793        let client = tensor.client.clone();
794        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
795
796        client
797            .register(OperationIr::NumericInt(
798                desc.out.dtype,
799                NumericOperationIr::Min(desc),
800            ))
801            .output()
802    }
803
804    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
805        let client = tensor.client.clone();
806        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
807
808        client
809            .register(OperationIr::NumericInt(
810                desc.out.dtype,
811                NumericOperationIr::MinDim(desc),
812            ))
813            .output()
814    }
815
816    fn int_min_dim_with_indices(
817        tensor: IntTensor<Self>,
818        dim: usize,
819    ) -> (IntTensor<Self>, IntTensor<Self>) {
820        let client = tensor.client.clone();
821        let desc = ReduceDimWithIndicesOpIr::create(
822            tensor.into_ir(),
823            dim,
824            IntElem::<Self>::dtype(),
825            || client.create_empty_handle(),
826        );
827
828        client
829            .register(OperationIr::NumericInt(
830                desc.out.dtype,
831                NumericOperationIr::MinDimWithIndices(desc),
832            ))
833            .outputs()
834            .into()
835    }
836
837    fn int_random(
838        shape: Shape,
839        distribution: Distribution,
840        device: &Device<Self>,
841    ) -> IntTensor<Self> {
842        let client = get_client::<R>(device);
843        let dtype = IntElem::<Self>::dtype();
844        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
845
846        client
847            .register(OperationIr::NumericInt(
848                dtype,
849                NumericOperationIr::IntRandom(desc),
850            ))
851            .output()
852    }
853
854    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
855        let client = tensor.client.clone();
856        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
857            client.create_empty_handle()
858        });
859
860        client
861            .register(OperationIr::BaseInt(BaseOperationIr::Permute(desc)))
862            .output()
863    }
864
865    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
866        let client = tensor.client.clone();
867        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
868
869        client
870            .register(OperationIr::BaseInt(BaseOperationIr::Expand(desc)))
871            .output()
872    }
873
874    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
875        let client = tensor.client.clone();
876        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
877            client.create_empty_handle()
878        });
879
880        client
881            .register(OperationIr::BaseInt(BaseOperationIr::Flip(desc)))
882            .output()
883    }
884
885    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
886        let client = tensor.client.clone();
887        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
888            client.create_empty_handle()
889        });
890
891        client
892            .register(OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc)))
893            .output()
894    }
895
896    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
897        let client = lhs.client.clone();
898        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
899            client.create_empty_handle()
900        });
901
902        client
903            .register(OperationIr::Int(IntOperationIr::BitwiseAnd(desc)))
904            .output()
905    }
906
907    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
908        let client = lhs.client.clone();
909        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
910            client.create_empty_handle()
911        });
912
913        client
914            .register(OperationIr::Int(IntOperationIr::BitwiseOr(desc)))
915            .output()
916    }
917
918    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
919        let client = lhs.client.clone();
920        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
921            client.create_empty_handle()
922        });
923
924        client
925            .register(OperationIr::Int(IntOperationIr::BitwiseXor(desc)))
926            .output()
927    }
928
929    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
930        let client = tensor.client.clone();
931        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
932
933        client
934            .register(OperationIr::Int(IntOperationIr::BitwiseNot(desc)))
935            .output()
936    }
937
938    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
939        let client = lhs.client.clone();
940        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
941        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
942
943        client
944            .register(OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc)))
945            .output()
946    }
947
948    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
949        let client = lhs.client.clone();
950        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
951        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
952
953        client
954            .register(OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc)))
955            .output()
956    }
957
958    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
959        let client = lhs.client.clone();
960        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
961        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
962
963        client
964            .register(OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc)))
965            .output()
966    }
967
968    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
969        let client = lhs.client.clone();
970        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
971            client.create_empty_handle()
972        });
973
974        client
975            .register(OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc)))
976            .output()
977    }
978
979    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
980        let client = lhs.client.clone();
981        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
982        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
983
984        client
985            .register(OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(
986                desc,
987            )))
988            .output()
989    }
990
991    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
992        let client = lhs.client.clone();
993        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
994            client.create_empty_handle()
995        });
996
997        client
998            .register(OperationIr::Int(IntOperationIr::BitwiseRightShift(desc)))
999            .output()
1000    }
1001
1002    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
1003        let client = lhs.client.clone();
1004        let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
1005        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1006
1007        client
1008            .register(OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(
1009                desc,
1010            )))
1011            .output()
1012    }
1013
1014    fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
1015        let client = tensor.client.clone();
1016        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1017            client.create_empty_handle()
1018        });
1019
1020        client
1021            .register(OperationIr::BaseInt(BaseOperationIr::Cast(desc)))
1022            .output()
1023    }
1024
1025    fn int_unfold(
1026        tensor: IntTensor<Self>,
1027        dim: usize,
1028        size: usize,
1029        step: usize,
1030    ) -> IntTensor<Self> {
1031        let client = tensor.client.clone();
1032        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1033            client.create_empty_handle()
1034        });
1035
1036        client
1037            .register(OperationIr::BaseInt(BaseOperationIr::Unfold(desc)))
1038            .output()
1039    }
1040}