1use alloc::vec::Vec;
2use burn_backend::backend::{Backend, ExecutionError};
3
4use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
5use burn_backend::tensor::{
6 BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,
7};
8use burn_backend::{Distribution, Element, IntDType, Shape, Slice, TensorData, ops::IntTensorOps};
9use burn_ir::{
10 BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, DimOpIr, FlipOpIr,
11 GatherOpIr, InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr, MatmulOpIr,
12 NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr, ReduceDimOpIr,
13 ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarIr, ScalarOpIr, ScatterOpIr,
14 SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr,
15 UnfoldOpIr,
16};
17
18impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
19 fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
20 let client = get_client::<R>(device);
21 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
22
23 client
24 .register(OperationIr::BaseInt(BaseOperationIr::Empty(desc)))
25 .output()
26 }
27
28 async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
29 Ok(tensor
30 .into_data()
31 .await?
32 .convert::<<Self as Backend>::IntElem>())
34 }
35
36 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
37 let client = get_client::<R>(device);
38 let out = client.register_tensor_data(data);
39 let desc = InitOperationIr {
40 out: out.to_ir_out(),
41 };
42
43 client.register_op(OperationIr::Init(desc));
45
46 out
47 }
48
49 fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
50 tensor.client.device()
51 }
52
53 fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
54 if &tensor.client.device() == device {
55 return tensor;
56 }
57 R::change_client_backend(tensor, device)
58 }
59
60 fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
61 let client = tensor.client.clone();
62 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
63
64 client
65 .register(OperationIr::BaseInt(BaseOperationIr::Reshape(desc)))
66 .output()
67 }
68
69 fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
70 let client = tensor.client.clone();
71 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
72 client.create_empty_handle()
73 });
74
75 client
76 .register(OperationIr::BaseInt(BaseOperationIr::Slice(desc)))
77 .output()
78 }
79
80 fn int_slice_assign(
81 tensor: IntTensor<Self>,
82 slices: &[burn_backend::Slice],
83 value: IntTensor<Self>,
84 ) -> IntTensor<Self> {
85 let client = tensor.client.clone();
86 let desc =
87 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
88 client.create_empty_handle()
89 });
90
91 client
92 .register(OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc)))
93 .output()
94 }
95
96 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
97 let client = lhs.client.clone();
98 let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
99 client.create_empty_handle()
100 });
101
102 client
103 .register(OperationIr::Int(IntOperationIr::Matmul(desc)))
104 .output()
105 }
106
107 fn int_mask_where(
108 tensor: IntTensor<Self>,
109 mask: BoolTensor<Self>,
110 value: IntTensor<Self>,
111 ) -> IntTensor<Self> {
112 let client = tensor.client.clone();
113 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
114 client.create_empty_handle()
115 });
116
117 client
118 .register(OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc)))
119 .output()
120 }
121
122 fn int_mask_fill(
123 tensor: IntTensor<Self>,
124 mask: BoolTensor<Self>,
125 value: IntElem<Self>,
126 ) -> IntTensor<Self> {
127 let client = tensor.client.clone();
128 let value = ScalarIr::with_dtype(value, &tensor.dtype);
129 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
130 client.create_empty_handle()
131 });
132
133 client
134 .register(OperationIr::BaseInt(BaseOperationIr::MaskFill(desc)))
135 .output()
136 }
137
138 fn int_gather(
139 dim: usize,
140 tensor: IntTensor<Self>,
141 indices: IntTensor<Self>,
142 ) -> IntTensor<Self> {
143 let client = tensor.client.clone();
144 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
145 client.create_empty_handle()
146 });
147
148 client
149 .register(OperationIr::BaseInt(BaseOperationIr::Gather(desc)))
150 .output()
151 }
152
153 fn int_scatter_add(
154 dim: usize,
155 tensor: IntTensor<Self>,
156 indices: IntTensor<Self>,
157 value: IntTensor<Self>,
158 ) -> IntTensor<Self> {
159 let client = tensor.client.clone();
160 let desc = ScatterOpIr::create(
161 tensor.into_ir(),
162 dim,
163 indices.into_ir(),
164 value.into_ir(),
165 IndexingUpdateOp::Add,
166 || client.create_empty_handle(),
167 );
168
169 client
170 .register(OperationIr::BaseInt(BaseOperationIr::Scatter(desc)))
171 .output()
172 }
173
174 fn int_select(
175 tensor: IntTensor<Self>,
176 dim: usize,
177 indices: IntTensor<Self>,
178 ) -> IntTensor<Self> {
179 let client = tensor.client.clone();
180 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
181 client.create_empty_handle()
182 });
183
184 client
185 .register(OperationIr::BaseInt(BaseOperationIr::Select(desc)))
186 .output()
187 }
188
189 fn int_select_add(
190 tensor: IntTensor<Self>,
191 dim: usize,
192 indices: IntTensor<Self>,
193 value: IntTensor<Self>,
194 ) -> IntTensor<Self> {
195 let client = tensor.client.clone();
196 let desc = SelectAssignOpIr::create(
197 tensor.into_ir(),
198 dim,
199 indices.into_ir(),
200 value.into_ir(),
201 IndexingUpdateOp::Add,
202 || client.create_empty_handle(),
203 );
204
205 client
206 .register(OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc)))
207 .output()
208 }
209
210 fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
211 let client = tensors.first().unwrap().client.clone();
212 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
213 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
214
215 client
216 .register(OperationIr::BaseInt(BaseOperationIr::Cat(desc)))
217 .output()
218 }
219
220 fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
221 let client = lhs.client.clone();
222 let desc = BinaryOpIr::create_comparison(
223 lhs.into_ir(),
224 rhs.into_ir(),
225 R::BoolElem::dtype(),
226 || client.create_empty_handle(),
227 );
228
229 client
230 .register(OperationIr::BaseInt(BaseOperationIr::Equal(desc)))
231 .output()
232 }
233
234 fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
235 let client = lhs.client.clone();
236 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
237 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
238 client.create_empty_handle()
239 });
240
241 client
242 .register(OperationIr::BaseInt(BaseOperationIr::EqualElem(desc)))
243 .output()
244 }
245
246 fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
247 let client = lhs.client.clone();
248 let desc = BinaryOpIr::create_comparison(
249 lhs.into_ir(),
250 rhs.into_ir(),
251 R::BoolElem::dtype(),
252 || client.create_empty_handle(),
253 );
254
255 client
256 .register(OperationIr::NumericInt(
257 desc.lhs.dtype,
258 NumericOperationIr::Greater(desc),
259 ))
260 .output()
261 }
262
263 fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
264 let client = lhs.client.clone();
265 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
266 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
267 client.create_empty_handle()
268 });
269
270 client
271 .register(OperationIr::NumericInt(
272 desc.lhs.dtype,
273 NumericOperationIr::GreaterElem(desc),
274 ))
275 .output()
276 }
277
278 fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
279 let client = lhs.client.clone();
280 let desc = BinaryOpIr::create_comparison(
281 lhs.into_ir(),
282 rhs.into_ir(),
283 R::BoolElem::dtype(),
284 || client.create_empty_handle(),
285 );
286
287 client
288 .register(OperationIr::NumericInt(
289 desc.lhs.dtype,
290 NumericOperationIr::GreaterEqual(desc),
291 ))
292 .output()
293 }
294
295 fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
296 let client = lhs.client.clone();
297 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
298 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
299 client.create_empty_handle()
300 });
301
302 client
303 .register(OperationIr::NumericInt(
304 desc.lhs.dtype,
305 NumericOperationIr::GreaterEqualElem(desc),
306 ))
307 .output()
308 }
309
310 fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
311 let client = lhs.client.clone();
312 let desc = BinaryOpIr::create_comparison(
313 lhs.into_ir(),
314 rhs.into_ir(),
315 R::BoolElem::dtype(),
316 || client.create_empty_handle(),
317 );
318
319 client
320 .register(OperationIr::NumericInt(
321 desc.lhs.dtype,
322 NumericOperationIr::Lower(desc),
323 ))
324 .output()
325 }
326
327 fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
328 let client = lhs.client.clone();
329 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
330 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
331 client.create_empty_handle()
332 });
333
334 client
335 .register(OperationIr::NumericInt(
336 desc.lhs.dtype,
337 NumericOperationIr::LowerElem(desc),
338 ))
339 .output()
340 }
341
342 fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
343 let client = lhs.client.clone();
344 let desc = BinaryOpIr::create_comparison(
345 lhs.into_ir(),
346 rhs.into_ir(),
347 R::BoolElem::dtype(),
348 || client.create_empty_handle(),
349 );
350
351 client
352 .register(OperationIr::NumericInt(
353 desc.lhs.dtype,
354 NumericOperationIr::LowerEqual(desc),
355 ))
356 .output()
357 }
358
359 fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
360 let client = lhs.client.clone();
361 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
362 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {
363 client.create_empty_handle()
364 });
365
366 client
367 .register(OperationIr::NumericInt(
368 desc.lhs.dtype,
369 NumericOperationIr::LowerEqualElem(desc),
370 ))
371 .output()
372 }
373
374 fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
375 let client = lhs.client.clone();
376 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
377 client.create_empty_handle()
378 });
379
380 client
381 .register(OperationIr::NumericInt(
382 desc.out.dtype,
383 NumericOperationIr::Add(desc),
384 ))
385 .output()
386 }
387
388 fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
389 let client = lhs.client.clone();
390 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
391 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
392
393 client
394 .register(OperationIr::NumericInt(
395 desc.out.dtype,
396 NumericOperationIr::AddScalar(desc),
397 ))
398 .output()
399 }
400
401 fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
402 let client = lhs.client.clone();
403 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
404 client.create_empty_handle()
405 });
406
407 client
408 .register(OperationIr::NumericInt(
409 desc.out.dtype,
410 NumericOperationIr::Sub(desc),
411 ))
412 .output()
413 }
414
415 fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
416 let client = lhs.client.clone();
417 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
418 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
419
420 client
421 .register(OperationIr::NumericInt(
422 desc.out.dtype,
423 NumericOperationIr::SubScalar(desc),
424 ))
425 .output()
426 }
427
428 fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
429 let client = lhs.client.clone();
430 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
431 client.create_empty_handle()
432 });
433
434 client
435 .register(OperationIr::NumericInt(
436 desc.out.dtype,
437 NumericOperationIr::Mul(desc),
438 ))
439 .output()
440 }
441
442 fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
443 let client = lhs.client.clone();
444 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
445 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
446
447 client
448 .register(OperationIr::NumericInt(
449 desc.out.dtype,
450 NumericOperationIr::MulScalar(desc),
451 ))
452 .output()
453 }
454
455 fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
456 let client = lhs.client.clone();
457 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
458 client.create_empty_handle()
459 });
460
461 client
462 .register(OperationIr::NumericInt(
463 desc.out.dtype,
464 NumericOperationIr::Div(desc),
465 ))
466 .output()
467 }
468
469 fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
470 let client = lhs.client.clone();
471 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
472 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
473
474 client
475 .register(OperationIr::NumericInt(
476 desc.out.dtype,
477 NumericOperationIr::DivScalar(desc),
478 ))
479 .output()
480 }
481
482 fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
483 let client = lhs.client.clone();
484 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
485 client.create_empty_handle()
486 });
487
488 client
489 .register(OperationIr::NumericInt(
490 desc.out.dtype,
491 NumericOperationIr::Rem(desc),
492 ))
493 .output()
494 }
495
496 fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
497 let client = lhs.client.clone();
498 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
499 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
500
501 client
502 .register(OperationIr::NumericInt(
503 desc.out.dtype,
504 NumericOperationIr::RemScalar(desc),
505 ))
506 .output()
507 }
508
509 fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
510 let client = get_client::<R>(device);
511 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
512
513 client
514 .register(OperationIr::BaseInt(BaseOperationIr::Zeros(desc)))
515 .output()
516 }
517
518 fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
519 let client = get_client::<R>(device);
520 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
521
522 client
523 .register(OperationIr::BaseInt(BaseOperationIr::Ones(desc)))
524 .output()
525 }
526
527 fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
528 let client = tensor.client.clone();
529 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
530
531 client
532 .register(OperationIr::NumericInt(
533 desc.out.dtype,
534 NumericOperationIr::Sum(desc),
535 ))
536 .output()
537 }
538
539 fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
540 let client = tensor.client.clone();
541 let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());
542
543 client
544 .register(OperationIr::NumericInt(
545 desc.out.dtype,
546 NumericOperationIr::SumDim(desc),
547 ))
548 .output()
549 }
550
551 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
552 let client = tensor.client.clone();
553 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
554
555 client
556 .register(OperationIr::NumericInt(
557 desc.out.dtype,
558 NumericOperationIr::Prod(desc),
559 ))
560 .output()
561 }
562
563 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
564 let client = tensor.client.clone();
565 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
566
567 client
568 .register(OperationIr::NumericInt(
569 desc.out.dtype,
570 NumericOperationIr::ProdDim(desc),
571 ))
572 .output()
573 }
574
575 fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
576 let client = tensor.client.clone();
577 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
578
579 client
580 .register(OperationIr::NumericInt(
581 desc.out.dtype,
582 NumericOperationIr::Mean(desc),
583 ))
584 .output()
585 }
586
587 fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
588 let client = tensor.client.clone();
589 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
590
591 client
592 .register(OperationIr::NumericInt(
593 desc.out.dtype,
594 NumericOperationIr::MeanDim(desc),
595 ))
596 .output()
597 }
598
599 fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
600 let client = tensor.client.clone();
601 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
602
603 client
604 .register(OperationIr::NumericInt(
605 desc.out.dtype,
606 NumericOperationIr::CumSum(desc),
607 ))
608 .output()
609 }
610
611 fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
612 let client = tensor.client.clone();
613 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
614
615 client
616 .register(OperationIr::NumericInt(
617 desc.out.dtype,
618 NumericOperationIr::CumProd(desc),
619 ))
620 .output()
621 }
622
623 fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
624 let client = tensor.client.clone();
625 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
626
627 client
628 .register(OperationIr::NumericInt(
629 desc.out.dtype,
630 NumericOperationIr::CumMin(desc),
631 ))
632 .output()
633 }
634
635 fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
636 let client = tensor.client.clone();
637 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
638
639 client
640 .register(OperationIr::NumericInt(
641 desc.out.dtype,
642 NumericOperationIr::CumMax(desc),
643 ))
644 .output()
645 }
646
647 fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
648 let client = tensor.client.clone();
649 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
650
651 client
652 .register(OperationIr::NumericInt(
653 desc.out.dtype,
654 NumericOperationIr::ArgMax(desc),
655 ))
656 .output()
657 }
658
659 fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
660 let client = tensor.client.clone();
661 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
662
663 client
664 .register(OperationIr::NumericInt(
665 desc.out.dtype,
666 NumericOperationIr::ArgMin(desc),
667 ))
668 .output()
669 }
670
671 fn int_clamp(
672 tensor: IntTensor<Self>,
673 min: IntElem<Self>,
674 max: IntElem<Self>,
675 ) -> IntTensor<Self> {
676 let client = tensor.client.clone();
677 let min = ScalarIr::with_dtype(min, &tensor.dtype);
678 let max = ScalarIr::with_dtype(max, &tensor.dtype);
679 let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
680
681 client
682 .register(OperationIr::NumericInt(
683 desc.out.dtype,
684 NumericOperationIr::Clamp(desc),
685 ))
686 .output()
687 }
688
689 fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
690 let client = tensor.client.clone();
691 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
692
693 client
694 .register(OperationIr::NumericInt(
695 desc.out.dtype,
696 NumericOperationIr::Abs(desc),
697 ))
698 .output()
699 }
700
701 fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
702 let client = tensor.client.clone();
703 let desc = CastOpIr::create(tensor.into_ir(), FloatElem::<Self>::dtype(), || {
704 client.create_empty_handle()
705 });
706
707 client
708 .register(OperationIr::Int(IntOperationIr::IntoFloat(desc)))
709 .output()
710 }
711
712 fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
713 let client = tensor.client.clone();
714 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
715 client.create_empty_handle()
716 });
717
718 client
719 .register(OperationIr::BaseInt(BaseOperationIr::SwapDims(desc)))
720 .output()
721 }
722
723 fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
724 let client = tensor.client.clone();
725 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
726
727 client
728 .register(OperationIr::NumericInt(
729 desc.out.dtype,
730 NumericOperationIr::Max(desc),
731 ))
732 .output()
733 }
734
735 fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
736 let client = tensor.client.clone();
737 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
738
739 client
740 .register(OperationIr::NumericInt(
741 desc.out.dtype,
742 NumericOperationIr::MaxDim(desc),
743 ))
744 .output()
745 }
746
747 fn int_max_dim_with_indices(
748 tensor: IntTensor<Self>,
749 dim: usize,
750 ) -> (IntTensor<Self>, IntTensor<Self>) {
751 let client = tensor.client.clone();
752 let desc = ReduceDimWithIndicesOpIr::create(
753 tensor.into_ir(),
754 dim,
755 IntElem::<Self>::dtype(),
756 || client.create_empty_handle(),
757 );
758
759 client
760 .register(OperationIr::NumericInt(
761 desc.tensor.dtype,
762 NumericOperationIr::MaxDimWithIndices(desc),
763 ))
764 .outputs()
765 .into()
766 }
767
768 fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
769 let client = tensor.client.clone();
770 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
771
772 client
773 .register(OperationIr::NumericInt(
774 desc.out.dtype,
775 NumericOperationIr::MaxAbs(desc),
776 ))
777 .output()
778 }
779
780 fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
781 let client = tensor.client.clone();
782 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
783
784 client
785 .register(OperationIr::NumericInt(
786 desc.out.dtype,
787 NumericOperationIr::MaxAbsDim(desc),
788 ))
789 .output()
790 }
791
792 fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
793 let client = tensor.client.clone();
794 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
795
796 client
797 .register(OperationIr::NumericInt(
798 desc.out.dtype,
799 NumericOperationIr::Min(desc),
800 ))
801 .output()
802 }
803
804 fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
805 let client = tensor.client.clone();
806 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
807
808 client
809 .register(OperationIr::NumericInt(
810 desc.out.dtype,
811 NumericOperationIr::MinDim(desc),
812 ))
813 .output()
814 }
815
816 fn int_min_dim_with_indices(
817 tensor: IntTensor<Self>,
818 dim: usize,
819 ) -> (IntTensor<Self>, IntTensor<Self>) {
820 let client = tensor.client.clone();
821 let desc = ReduceDimWithIndicesOpIr::create(
822 tensor.into_ir(),
823 dim,
824 IntElem::<Self>::dtype(),
825 || client.create_empty_handle(),
826 );
827
828 client
829 .register(OperationIr::NumericInt(
830 desc.out.dtype,
831 NumericOperationIr::MinDimWithIndices(desc),
832 ))
833 .outputs()
834 .into()
835 }
836
837 fn int_random(
838 shape: Shape,
839 distribution: Distribution,
840 device: &Device<Self>,
841 ) -> IntTensor<Self> {
842 let client = get_client::<R>(device);
843 let dtype = IntElem::<Self>::dtype();
844 let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
845
846 client
847 .register(OperationIr::NumericInt(
848 dtype,
849 NumericOperationIr::IntRandom(desc),
850 ))
851 .output()
852 }
853
854 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
855 let client = tensor.client.clone();
856 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
857 client.create_empty_handle()
858 });
859
860 client
861 .register(OperationIr::BaseInt(BaseOperationIr::Permute(desc)))
862 .output()
863 }
864
865 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
866 let client = tensor.client.clone();
867 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
868
869 client
870 .register(OperationIr::BaseInt(BaseOperationIr::Expand(desc)))
871 .output()
872 }
873
874 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
875 let client = tensor.client.clone();
876 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
877 client.create_empty_handle()
878 });
879
880 client
881 .register(OperationIr::BaseInt(BaseOperationIr::Flip(desc)))
882 .output()
883 }
884
885 fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
886 let client = tensor.client.clone();
887 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
888 client.create_empty_handle()
889 });
890
891 client
892 .register(OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc)))
893 .output()
894 }
895
896 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
897 let client = lhs.client.clone();
898 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
899 client.create_empty_handle()
900 });
901
902 client
903 .register(OperationIr::Int(IntOperationIr::BitwiseAnd(desc)))
904 .output()
905 }
906
907 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
908 let client = lhs.client.clone();
909 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
910 client.create_empty_handle()
911 });
912
913 client
914 .register(OperationIr::Int(IntOperationIr::BitwiseOr(desc)))
915 .output()
916 }
917
918 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
919 let client = lhs.client.clone();
920 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
921 client.create_empty_handle()
922 });
923
924 client
925 .register(OperationIr::Int(IntOperationIr::BitwiseXor(desc)))
926 .output()
927 }
928
929 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
930 let client = tensor.client.clone();
931 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
932
933 client
934 .register(OperationIr::Int(IntOperationIr::BitwiseNot(desc)))
935 .output()
936 }
937
938 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
939 let client = lhs.client.clone();
940 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
941 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
942
943 client
944 .register(OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc)))
945 .output()
946 }
947
948 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
949 let client = lhs.client.clone();
950 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
951 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
952
953 client
954 .register(OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc)))
955 .output()
956 }
957
958 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
959 let client = lhs.client.clone();
960 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
961 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
962
963 client
964 .register(OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc)))
965 .output()
966 }
967
968 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
969 let client = lhs.client.clone();
970 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
971 client.create_empty_handle()
972 });
973
974 client
975 .register(OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc)))
976 .output()
977 }
978
979 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
980 let client = lhs.client.clone();
981 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
982 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
983
984 client
985 .register(OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(
986 desc,
987 )))
988 .output()
989 }
990
991 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
992 let client = lhs.client.clone();
993 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
994 client.create_empty_handle()
995 });
996
997 client
998 .register(OperationIr::Int(IntOperationIr::BitwiseRightShift(desc)))
999 .output()
1000 }
1001
1002 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
1003 let client = lhs.client.clone();
1004 let rhs = ScalarIr::with_dtype(rhs, &lhs.dtype);
1005 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1006
1007 client
1008 .register(OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(
1009 desc,
1010 )))
1011 .output()
1012 }
1013
1014 fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
1015 let client = tensor.client.clone();
1016 let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
1017 client.create_empty_handle()
1018 });
1019
1020 client
1021 .register(OperationIr::BaseInt(BaseOperationIr::Cast(desc)))
1022 .output()
1023 }
1024
1025 fn int_unfold(
1026 tensor: IntTensor<Self>,
1027 dim: usize,
1028 size: usize,
1029 step: usize,
1030 ) -> IntTensor<Self> {
1031 let client = tensor.client.clone();
1032 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
1033 client.create_empty_handle()
1034 });
1035
1036 client
1037 .register(OperationIr::BaseInt(BaseOperationIr::Unfold(desc)))
1038 .output()
1039 }
1040}