1use alloc::vec::Vec;
2use burn_backend::backend::ExecutionError;
3use burn_std::{BoolDType, FloatDType};
4
5use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
6use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntElem, IntTensor};
7use burn_backend::{
8 Distribution, Element, IntDType, Scalar, Shape, Slice, TensorData, ops::IntTensorOps,
9};
10use burn_ir::{
11 BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, DimOpIr, FlipOpIr,
12 GatherNdOpIr, GatherOpIr, InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr,
13 MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr,
14 ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr, ScatterNdOpIr,
15 ScatterOpIr, SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr,
16 UnaryOpIr, UnfoldOpIr,
17};
18
19impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
20 fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
21 let client = get_client::<R>(device);
22 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
23
24 client
25 .register(OperationIr::BaseInt(BaseOperationIr::Empty(desc)))
26 .output()
27 }
28
29 async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
30 Ok(tensor
31 .into_data()
32 .await?
33 .convert::<IntElem<Self>>())
35 }
36
37 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
38 let client = get_client::<R>(device);
39 let out = client.register_tensor_data(data);
40 let desc = InitOperationIr {
41 out: out.to_ir_out(),
42 };
43
44 client.register_op(OperationIr::Init(desc));
46
47 out
48 }
49
50 fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
51 tensor.client.device()
52 }
53
54 fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
55 if &tensor.client.device() == device {
56 return tensor;
57 }
58 R::change_client_backend(tensor, device)
59 }
60
61 fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
62 let client = tensor.client.clone();
63 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
64
65 client
66 .register(OperationIr::BaseInt(BaseOperationIr::Reshape(desc)))
67 .output()
68 }
69
70 fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
71 let client = tensor.client.clone();
72 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
73 client.create_empty_handle()
74 });
75
76 client
77 .register(OperationIr::BaseInt(BaseOperationIr::Slice(desc)))
78 .output()
79 }
80
81 fn int_slice_assign(
82 tensor: IntTensor<Self>,
83 slices: &[burn_backend::Slice],
84 value: IntTensor<Self>,
85 ) -> IntTensor<Self> {
86 let client = tensor.client.clone();
87 let desc =
88 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
89 client.create_empty_handle()
90 });
91
92 client
93 .register(OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc)))
94 .output()
95 }
96
97 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
98 let client = lhs.client.clone();
99 let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
100 client.create_empty_handle()
101 });
102
103 client
104 .register(OperationIr::Int(IntOperationIr::Matmul(desc)))
105 .output()
106 }
107
108 fn int_mask_where(
109 tensor: IntTensor<Self>,
110 mask: BoolTensor<Self>,
111 value: IntTensor<Self>,
112 ) -> IntTensor<Self> {
113 let client = tensor.client.clone();
114 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
115 client.create_empty_handle()
116 });
117
118 client
119 .register(OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc)))
120 .output()
121 }
122
123 fn int_mask_fill(
124 tensor: IntTensor<Self>,
125 mask: BoolTensor<Self>,
126 value: Scalar,
127 ) -> IntTensor<Self> {
128 let client = tensor.client.clone();
129 let value = value.into();
130 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
131 client.create_empty_handle()
132 });
133
134 client
135 .register(OperationIr::BaseInt(BaseOperationIr::MaskFill(desc)))
136 .output()
137 }
138
139 fn int_gather(
140 dim: usize,
141 tensor: IntTensor<Self>,
142 indices: IntTensor<Self>,
143 ) -> IntTensor<Self> {
144 let client = tensor.client.clone();
145 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
146 client.create_empty_handle()
147 });
148
149 client
150 .register(OperationIr::BaseInt(BaseOperationIr::Gather(desc)))
151 .output()
152 }
153
154 fn int_scatter_add(
155 dim: usize,
156 tensor: IntTensor<Self>,
157 indices: IntTensor<Self>,
158 value: IntTensor<Self>,
159 ) -> IntTensor<Self> {
160 let client = tensor.client.clone();
161 let desc = ScatterOpIr::create(
162 tensor.into_ir(),
163 dim,
164 indices.into_ir(),
165 value.into_ir(),
166 IndexingUpdateOp::Add,
167 || client.create_empty_handle(),
168 );
169
170 client
171 .register(OperationIr::BaseInt(BaseOperationIr::Scatter(desc)))
172 .output()
173 }
174
175 fn int_scatter_nd(
176 data: IntTensor<Self>,
177 indices: IntTensor<Self>,
178 values: IntTensor<Self>,
179 reduction: IndexingUpdateOp,
180 ) -> IntTensor<Self> {
181 let client = data.client.clone();
182 let desc = ScatterNdOpIr::create(
183 data.into_ir(),
184 indices.into_ir(),
185 values.into_ir(),
186 reduction,
187 || client.create_empty_handle(),
188 );
189
190 client
191 .register(OperationIr::BaseInt(BaseOperationIr::ScatterNd(desc)))
192 .output()
193 }
194
195 fn int_gather_nd(data: IntTensor<Self>, indices: IntTensor<Self>) -> IntTensor<Self> {
196 let client = data.client.clone();
197 let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
198 client.create_empty_handle()
199 });
200
201 client
202 .register(OperationIr::BaseInt(BaseOperationIr::GatherNd(desc)))
203 .output()
204 }
205
206 fn int_select(
207 tensor: IntTensor<Self>,
208 dim: usize,
209 indices: IntTensor<Self>,
210 ) -> IntTensor<Self> {
211 let client = tensor.client.clone();
212 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
213 client.create_empty_handle()
214 });
215
216 client
217 .register(OperationIr::BaseInt(BaseOperationIr::Select(desc)))
218 .output()
219 }
220
221 fn int_select_add(
222 tensor: IntTensor<Self>,
223 dim: usize,
224 indices: IntTensor<Self>,
225 value: IntTensor<Self>,
226 ) -> IntTensor<Self> {
227 let client = tensor.client.clone();
228 let desc = SelectAssignOpIr::create(
229 tensor.into_ir(),
230 dim,
231 indices.into_ir(),
232 value.into_ir(),
233 IndexingUpdateOp::Add,
234 || client.create_empty_handle(),
235 );
236
237 client
238 .register(OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc)))
239 .output()
240 }
241
242 fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
243 let client = tensors.first().unwrap().client.clone();
244 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
245 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
246
247 client
248 .register(OperationIr::BaseInt(BaseOperationIr::Cat(desc)))
249 .output()
250 }
251
252 fn int_equal(
253 lhs: IntTensor<Self>,
254 rhs: IntTensor<Self>,
255 out_dtype: BoolDType,
256 ) -> BoolTensor<Self> {
257 let client = lhs.client.clone();
258 let desc =
259 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
260 client.create_empty_handle()
261 });
262
263 client
264 .register(OperationIr::BaseInt(BaseOperationIr::Equal(desc)))
265 .output()
266 }
267
268 fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
269 let client = lhs.client.clone();
270 let rhs = rhs.into();
271 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
272 client.create_empty_handle()
273 });
274
275 client
276 .register(OperationIr::BaseInt(BaseOperationIr::EqualElem(desc)))
277 .output()
278 }
279
280 fn int_greater(
281 lhs: IntTensor<Self>,
282 rhs: IntTensor<Self>,
283 out_dtype: BoolDType,
284 ) -> BoolTensor<Self> {
285 let client = lhs.client.clone();
286 let desc =
287 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
288 client.create_empty_handle()
289 });
290
291 client
292 .register(OperationIr::NumericInt(
293 desc.lhs.dtype,
294 NumericOperationIr::Greater(desc),
295 ))
296 .output()
297 }
298
299 fn int_greater_elem(
300 lhs: IntTensor<Self>,
301 rhs: Scalar,
302 out_dtype: BoolDType,
303 ) -> BoolTensor<Self> {
304 let client = lhs.client.clone();
305 let rhs = rhs.into();
306 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
307 client.create_empty_handle()
308 });
309
310 client
311 .register(OperationIr::NumericInt(
312 desc.lhs.dtype,
313 NumericOperationIr::GreaterElem(desc),
314 ))
315 .output()
316 }
317
318 fn int_greater_equal(
319 lhs: IntTensor<Self>,
320 rhs: IntTensor<Self>,
321 out_dtype: BoolDType,
322 ) -> BoolTensor<Self> {
323 let client = lhs.client.clone();
324 let desc =
325 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
326 client.create_empty_handle()
327 });
328
329 client
330 .register(OperationIr::NumericInt(
331 desc.lhs.dtype,
332 NumericOperationIr::GreaterEqual(desc),
333 ))
334 .output()
335 }
336
337 fn int_greater_equal_elem(
338 lhs: IntTensor<Self>,
339 rhs: Scalar,
340 out_dtype: BoolDType,
341 ) -> BoolTensor<Self> {
342 let client = lhs.client.clone();
343 let rhs = rhs.into();
344 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
345 client.create_empty_handle()
346 });
347
348 client
349 .register(OperationIr::NumericInt(
350 desc.lhs.dtype,
351 NumericOperationIr::GreaterEqualElem(desc),
352 ))
353 .output()
354 }
355
356 fn int_lower(
357 lhs: IntTensor<Self>,
358 rhs: IntTensor<Self>,
359 out_dtype: BoolDType,
360 ) -> BoolTensor<Self> {
361 let client = lhs.client.clone();
362 let desc =
363 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
364 client.create_empty_handle()
365 });
366
367 client
368 .register(OperationIr::NumericInt(
369 desc.lhs.dtype,
370 NumericOperationIr::Lower(desc),
371 ))
372 .output()
373 }
374
375 fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
376 let client = lhs.client.clone();
377 let rhs = rhs.into();
378 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
379 client.create_empty_handle()
380 });
381
382 client
383 .register(OperationIr::NumericInt(
384 desc.lhs.dtype,
385 NumericOperationIr::LowerElem(desc),
386 ))
387 .output()
388 }
389
390 fn int_lower_equal(
391 lhs: IntTensor<Self>,
392 rhs: IntTensor<Self>,
393 out_dtype: BoolDType,
394 ) -> BoolTensor<Self> {
395 let client = lhs.client.clone();
396 let desc =
397 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
398 client.create_empty_handle()
399 });
400
401 client
402 .register(OperationIr::NumericInt(
403 desc.lhs.dtype,
404 NumericOperationIr::LowerEqual(desc),
405 ))
406 .output()
407 }
408
409 fn int_lower_equal_elem(
410 lhs: IntTensor<Self>,
411 rhs: Scalar,
412 out_dtype: BoolDType,
413 ) -> BoolTensor<Self> {
414 let client = lhs.client.clone();
415 let rhs = rhs.into();
416 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
417 client.create_empty_handle()
418 });
419
420 client
421 .register(OperationIr::NumericInt(
422 desc.lhs.dtype,
423 NumericOperationIr::LowerEqualElem(desc),
424 ))
425 .output()
426 }
427
428 fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
429 let client = lhs.client.clone();
430 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
431 client.create_empty_handle()
432 });
433
434 client
435 .register(OperationIr::NumericInt(
436 desc.out.dtype,
437 NumericOperationIr::Add(desc),
438 ))
439 .output()
440 }
441
442 fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
443 let client = lhs.client.clone();
444 let rhs = rhs.into();
445 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
446
447 client
448 .register(OperationIr::NumericInt(
449 desc.out.dtype,
450 NumericOperationIr::AddScalar(desc),
451 ))
452 .output()
453 }
454
455 fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
456 let client = lhs.client.clone();
457 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
458 client.create_empty_handle()
459 });
460
461 client
462 .register(OperationIr::NumericInt(
463 desc.out.dtype,
464 NumericOperationIr::Sub(desc),
465 ))
466 .output()
467 }
468
469 fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
470 let client = lhs.client.clone();
471 let rhs = rhs.into();
472 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
473
474 client
475 .register(OperationIr::NumericInt(
476 desc.out.dtype,
477 NumericOperationIr::SubScalar(desc),
478 ))
479 .output()
480 }
481
482 fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
483 let client = lhs.client.clone();
484 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
485 client.create_empty_handle()
486 });
487
488 client
489 .register(OperationIr::NumericInt(
490 desc.out.dtype,
491 NumericOperationIr::Mul(desc),
492 ))
493 .output()
494 }
495
496 fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
497 let client = lhs.client.clone();
498 let rhs = rhs.into();
499 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
500
501 client
502 .register(OperationIr::NumericInt(
503 desc.out.dtype,
504 NumericOperationIr::MulScalar(desc),
505 ))
506 .output()
507 }
508
509 fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
510 let client = lhs.client.clone();
511 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
512 client.create_empty_handle()
513 });
514
515 client
516 .register(OperationIr::NumericInt(
517 desc.out.dtype,
518 NumericOperationIr::Div(desc),
519 ))
520 .output()
521 }
522
523 fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
524 let client = lhs.client.clone();
525 let rhs = rhs.into();
526 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
527
528 client
529 .register(OperationIr::NumericInt(
530 desc.out.dtype,
531 NumericOperationIr::DivScalar(desc),
532 ))
533 .output()
534 }
535
536 fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
537 let client = lhs.client.clone();
538 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
539 client.create_empty_handle()
540 });
541
542 client
543 .register(OperationIr::NumericInt(
544 desc.out.dtype,
545 NumericOperationIr::Rem(desc),
546 ))
547 .output()
548 }
549
550 fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
551 let client = lhs.client.clone();
552 let rhs = rhs.into();
553 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
554
555 client
556 .register(OperationIr::NumericInt(
557 desc.out.dtype,
558 NumericOperationIr::RemScalar(desc),
559 ))
560 .output()
561 }
562
563 fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
564 let client = get_client::<R>(device);
565 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
566
567 client
568 .register(OperationIr::BaseInt(BaseOperationIr::Zeros(desc)))
569 .output()
570 }
571
572 fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
573 let client = get_client::<R>(device);
574 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
575
576 client
577 .register(OperationIr::BaseInt(BaseOperationIr::Ones(desc)))
578 .output()
579 }
580
581 fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
582 let client = tensor.client.clone();
583 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
584
585 client
586 .register(OperationIr::NumericInt(
587 desc.out.dtype,
588 NumericOperationIr::Sum(desc),
589 ))
590 .output()
591 }
592
593 fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
594 let client = tensor.client.clone();
595 let desc =
596 ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
597
598 client
599 .register(OperationIr::NumericInt(
600 desc.out.dtype,
601 NumericOperationIr::SumDim(desc),
602 ))
603 .output()
604 }
605
606 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
607 let client = tensor.client.clone();
608 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
609
610 client
611 .register(OperationIr::NumericInt(
612 desc.out.dtype,
613 NumericOperationIr::Prod(desc),
614 ))
615 .output()
616 }
617
618 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
619 let client = tensor.client.clone();
620 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
621
622 client
623 .register(OperationIr::NumericInt(
624 desc.out.dtype,
625 NumericOperationIr::ProdDim(desc),
626 ))
627 .output()
628 }
629
630 fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
631 let client = tensor.client.clone();
632 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
633
634 client
635 .register(OperationIr::NumericInt(
636 desc.out.dtype,
637 NumericOperationIr::Mean(desc),
638 ))
639 .output()
640 }
641
642 fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
643 let client = tensor.client.clone();
644 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
645
646 client
647 .register(OperationIr::NumericInt(
648 desc.out.dtype,
649 NumericOperationIr::MeanDim(desc),
650 ))
651 .output()
652 }
653
654 fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
655 let client = tensor.client.clone();
656 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
657
658 client
659 .register(OperationIr::NumericInt(
660 desc.out.dtype,
661 NumericOperationIr::CumSum(desc),
662 ))
663 .output()
664 }
665
666 fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
667 let client = tensor.client.clone();
668 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
669
670 client
671 .register(OperationIr::NumericInt(
672 desc.out.dtype,
673 NumericOperationIr::CumProd(desc),
674 ))
675 .output()
676 }
677
678 fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
679 let client = tensor.client.clone();
680 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
681
682 client
683 .register(OperationIr::NumericInt(
684 desc.out.dtype,
685 NumericOperationIr::CumMin(desc),
686 ))
687 .output()
688 }
689
690 fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
691 let client = tensor.client.clone();
692 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
693
694 client
695 .register(OperationIr::NumericInt(
696 desc.out.dtype,
697 NumericOperationIr::CumMax(desc),
698 ))
699 .output()
700 }
701
702 fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
703 let client = tensor.client.clone();
704 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
705
706 client
707 .register(OperationIr::NumericInt(
708 desc.out.dtype,
709 NumericOperationIr::ArgMax(desc),
710 ))
711 .output()
712 }
713
714 fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
715 let client = tensor.client.clone();
716 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
717
718 client
719 .register(OperationIr::NumericInt(
720 desc.out.dtype,
721 NumericOperationIr::ArgTopK(desc),
722 ))
723 .output()
724 }
725
726 fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
727 let client = tensor.client.clone();
728 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
729
730 client
731 .register(OperationIr::NumericInt(
732 desc.out.dtype,
733 NumericOperationIr::TopK(desc),
734 ))
735 .output()
736 }
737
738 fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
739 let client = tensor.client.clone();
740 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
741
742 client
743 .register(OperationIr::NumericInt(
744 desc.out.dtype,
745 NumericOperationIr::ArgMin(desc),
746 ))
747 .output()
748 }
749
750 fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
751 let client = tensor.client.clone();
752 let min = min.into();
753 let max = max.into();
754 let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
755
756 client
757 .register(OperationIr::NumericInt(
758 desc.out.dtype,
759 NumericOperationIr::Clamp(desc),
760 ))
761 .output()
762 }
763
764 fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
765 let client = tensor.client.clone();
766 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
767
768 client
769 .register(OperationIr::NumericInt(
770 desc.out.dtype,
771 NumericOperationIr::Abs(desc),
772 ))
773 .output()
774 }
775
776 fn int_into_float(tensor: IntTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
777 let client = tensor.client.clone();
778 let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
779 client.create_empty_handle()
780 });
781
782 client
783 .register(OperationIr::Int(IntOperationIr::IntoFloat(desc)))
784 .output()
785 }
786
787 fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
788 let client = tensor.client.clone();
789 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
790 client.create_empty_handle()
791 });
792
793 client
794 .register(OperationIr::BaseInt(BaseOperationIr::SwapDims(desc)))
795 .output()
796 }
797
798 fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
799 let client = tensor.client.clone();
800 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
801
802 client
803 .register(OperationIr::NumericInt(
804 desc.out.dtype,
805 NumericOperationIr::Max(desc),
806 ))
807 .output()
808 }
809
810 fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
811 let client = tensor.client.clone();
812 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
813
814 client
815 .register(OperationIr::NumericInt(
816 desc.out.dtype,
817 NumericOperationIr::MaxDim(desc),
818 ))
819 .output()
820 }
821
822 fn int_max_dim_with_indices(
823 tensor: IntTensor<Self>,
824 dim: usize,
825 ) -> (IntTensor<Self>, IntTensor<Self>) {
826 let client = tensor.client.clone();
827 let desc = ReduceDimWithIndicesOpIr::create(
828 tensor.into_ir(),
829 dim,
830 IntElem::<Self>::dtype(),
831 || client.create_empty_handle(),
832 );
833
834 client
835 .register(OperationIr::NumericInt(
836 desc.tensor.dtype,
837 NumericOperationIr::MaxDimWithIndices(desc),
838 ))
839 .outputs()
840 .into()
841 }
842
843 fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
844 let client = tensor.client.clone();
845 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
846
847 client
848 .register(OperationIr::NumericInt(
849 desc.out.dtype,
850 NumericOperationIr::MaxAbs(desc),
851 ))
852 .output()
853 }
854
855 fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
856 let client = tensor.client.clone();
857 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
858
859 client
860 .register(OperationIr::NumericInt(
861 desc.out.dtype,
862 NumericOperationIr::MaxAbsDim(desc),
863 ))
864 .output()
865 }
866
867 fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
868 let client = tensor.client.clone();
869 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
870
871 client
872 .register(OperationIr::NumericInt(
873 desc.out.dtype,
874 NumericOperationIr::Min(desc),
875 ))
876 .output()
877 }
878
879 fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
880 let client = tensor.client.clone();
881 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
882
883 client
884 .register(OperationIr::NumericInt(
885 desc.out.dtype,
886 NumericOperationIr::MinDim(desc),
887 ))
888 .output()
889 }
890
891 fn int_min_dim_with_indices(
892 tensor: IntTensor<Self>,
893 dim: usize,
894 ) -> (IntTensor<Self>, IntTensor<Self>) {
895 let client = tensor.client.clone();
896 let desc = ReduceDimWithIndicesOpIr::create(
897 tensor.into_ir(),
898 dim,
899 IntElem::<Self>::dtype(),
900 || client.create_empty_handle(),
901 );
902
903 client
904 .register(OperationIr::NumericInt(
905 desc.out.dtype,
906 NumericOperationIr::MinDimWithIndices(desc),
907 ))
908 .outputs()
909 .into()
910 }
911
912 fn int_random(
913 shape: Shape,
914 distribution: Distribution,
915 device: &Device<Self>,
916 dtype: IntDType,
917 ) -> IntTensor<Self> {
918 let client = get_client::<R>(device);
919 let dtype = dtype.into();
920 let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
921
922 client
923 .register(OperationIr::NumericInt(
924 dtype,
925 NumericOperationIr::IntRandom(desc),
926 ))
927 .output()
928 }
929
930 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
931 let client = tensor.client.clone();
932 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
933 client.create_empty_handle()
934 });
935
936 client
937 .register(OperationIr::BaseInt(BaseOperationIr::Permute(desc)))
938 .output()
939 }
940
941 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
942 let client = tensor.client.clone();
943 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
944
945 client
946 .register(OperationIr::BaseInt(BaseOperationIr::Expand(desc)))
947 .output()
948 }
949
950 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
951 let client = tensor.client.clone();
952 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
953 client.create_empty_handle()
954 });
955
956 client
957 .register(OperationIr::BaseInt(BaseOperationIr::Flip(desc)))
958 .output()
959 }
960
961 fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
962 let client = tensor.client.clone();
963 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
964 client.create_empty_handle()
965 });
966
967 client
968 .register(OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc)))
969 .output()
970 }
971
972 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
973 let client = lhs.client.clone();
974 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
975 client.create_empty_handle()
976 });
977
978 client
979 .register(OperationIr::Int(IntOperationIr::BitwiseAnd(desc)))
980 .output()
981 }
982
983 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
984 let client = lhs.client.clone();
985 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
986 client.create_empty_handle()
987 });
988
989 client
990 .register(OperationIr::Int(IntOperationIr::BitwiseOr(desc)))
991 .output()
992 }
993
994 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
995 let client = lhs.client.clone();
996 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
997 client.create_empty_handle()
998 });
999
1000 client
1001 .register(OperationIr::Int(IntOperationIr::BitwiseXor(desc)))
1002 .output()
1003 }
1004
1005 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
1006 let client = tensor.client.clone();
1007 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1008
1009 client
1010 .register(OperationIr::Int(IntOperationIr::BitwiseNot(desc)))
1011 .output()
1012 }
1013
1014 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1015 let client = lhs.client.clone();
1016 let rhs = rhs.into();
1017 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1018
1019 client
1020 .register(OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc)))
1021 .output()
1022 }
1023
1024 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1025 let client = lhs.client.clone();
1026 let rhs = rhs.into();
1027 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1028
1029 client
1030 .register(OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc)))
1031 .output()
1032 }
1033
1034 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1035 let client = lhs.client.clone();
1036 let rhs = rhs.into();
1037 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1038
1039 client
1040 .register(OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc)))
1041 .output()
1042 }
1043
1044 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1045 let client = lhs.client.clone();
1046 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1047 client.create_empty_handle()
1048 });
1049
1050 client
1051 .register(OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc)))
1052 .output()
1053 }
1054
1055 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1056 let client = lhs.client.clone();
1057 let rhs = rhs.into();
1058 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1059
1060 client
1061 .register(OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(
1062 desc,
1063 )))
1064 .output()
1065 }
1066
1067 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1068 let client = lhs.client.clone();
1069 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1070 client.create_empty_handle()
1071 });
1072
1073 client
1074 .register(OperationIr::Int(IntOperationIr::BitwiseRightShift(desc)))
1075 .output()
1076 }
1077
1078 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1079 let client = lhs.client.clone();
1080 let rhs = rhs.into();
1081 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1082
1083 client
1084 .register(OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(
1085 desc,
1086 )))
1087 .output()
1088 }
1089
1090 fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
1091 let client = tensor.client.clone();
1092 let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1093 client.create_empty_handle()
1094 });
1095
1096 client
1097 .register(OperationIr::BaseInt(BaseOperationIr::Cast(desc)))
1098 .output()
1099 }
1100
1101 fn int_unfold(
1102 tensor: IntTensor<Self>,
1103 dim: usize,
1104 size: usize,
1105 step: usize,
1106 ) -> IntTensor<Self> {
1107 let client = tensor.client.clone();
1108 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1109 client.create_empty_handle()
1110 });
1111
1112 client
1113 .register(OperationIr::BaseInt(BaseOperationIr::Unfold(desc)))
1114 .output()
1115 }
1116
1117 fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1118 let client = lhs.client.clone();
1119 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1120 client.create_empty_handle()
1121 });
1122
1123 client
1124 .register(OperationIr::NumericInt(
1125 desc.out.dtype,
1126 NumericOperationIr::Powi(desc),
1127 ))
1128 .output()
1129 }
1130
1131 fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1132 let client = lhs.client.clone();
1133 let rhs = rhs.into();
1134 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1135
1136 client
1137 .register(OperationIr::NumericInt(
1138 desc.out.dtype,
1139 NumericOperationIr::PowiScalar(desc),
1140 ))
1141 .output()
1142 }
1143}