1use super::NoOp;
2use crate::{
3 Fusion, FusionBackend, binary_int_cmp_ops, binary_int_ops,
4 client::GlobalFusionClient,
5 get_client, reduce_int_ops, scalar_int_cmp_ops, scalar_int_ops,
6 stream::{OperationStreams, execution::Operation},
7 unary_int_ops,
8};
9use burn_backend::{
10 BoolDType, Distribution, ExecutionError, FloatDType, IntDType, Scalar, Shape, Slice,
11 TensorData,
12 ops::IntTensorOps,
13 tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},
14};
15use burn_ir::*;
16use std::marker::PhantomData;
17
18impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
19 fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
20 #[derive(new, Debug)]
21 struct EmptyOps<B: FusionBackend> {
22 desc: TensorIr,
23 device: Device<B>,
24 }
25
26 impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
27 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
28 let output = B::int_empty(
29 self.desc.shape.clone(),
30 &self.device,
31 self.desc.dtype.into(),
32 );
33 handles.register_int_tensor::<B>(&self.desc.id, output);
34 }
35 }
36
37 let client = get_client::<B>(device);
38 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
39
40 client
41 .register(
42 OperationStreams::default(),
43 OperationIr::BaseInt(BaseOperationIr::Empty(desc.clone())),
44 EmptyOps::<B>::new(desc.out, device.clone()),
45 )
46 .output()
47 }
48
49 async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
50 tensor.int_into_data::<B>().await
51 }
52
53 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
54 let client = get_client::<B>(device);
55 let dtype = data.dtype;
56 let tensor = B::int_from_data(data, device);
57 let shape = burn_backend::TensorMetadata::shape(&tensor);
58
59 let handle = B::int_tensor_handle(tensor);
60 let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));
61
62 client
63 .register(
64 OperationStreams::default(),
65 OperationIr::Init(desc),
66 NoOp::<B>::new(),
67 )
68 .output()
69 }
70
71 fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
72 tensor.client.device().clone()
73 }
74
75 fn int_to_device(tensor: IntTensor<Self>, device_dst: &Device<Self>) -> IntTensor<Self> {
76 let device_src: &B::Device = tensor.client.device();
77
78 if device_src == device_dst {
79 return tensor;
80 }
81
82 let id = tensor.stream;
83 let client_dst = get_client::<B>(device_dst);
84 let client_src = tensor.client.clone();
85
86 GlobalFusionClient::change_client_int::<B>(tensor.into_ir(), client_src, client_dst, id)
87 }
88
89 fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
90 if tensor.shape == shape {
91 return tensor;
92 }
93
94 #[derive(new, Debug)]
95 struct ReshapeDimsOps<B: FusionBackend> {
96 desc: ShapeOpIr,
97 _b: PhantomData<B>,
98 }
99
100 impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
101 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
102 let input = handles.get_int_tensor::<B>(&self.desc.input);
103 let output = B::int_reshape(input, self.desc.out.shape.clone());
104 handles.register_int_tensor::<B>(&self.desc.out.id, output);
105 }
106 }
107
108 let streams = OperationStreams::with_inputs([&tensor]);
109
110 let client = tensor.client.clone();
111 let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());
112
113 client
114 .register(
115 streams,
116 OperationIr::BaseInt(BaseOperationIr::Reshape(desc.clone())),
117 ReshapeDimsOps::<B>::new(desc),
118 )
119 .output()
120 }
121
122 fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
123 #[derive(new, Debug)]
124 struct SliceOps<B: FusionBackend> {
125 desc: SliceOpIr,
126 _b: PhantomData<B>,
127 }
128
129 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
130 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
131 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
132
133 let output = B::int_slice(tensor, self.desc.ranges.as_slice());
134
135 handles.register_int_tensor::<B>(&self.desc.out.id, output);
136 }
137 }
138
139 let streams = OperationStreams::with_inputs([&tensor]);
140
141 let client = tensor.client.clone();
142 let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {
143 client.create_empty_handle()
144 });
145
146 client
147 .register(
148 streams,
149 OperationIr::BaseInt(BaseOperationIr::Slice(desc.clone())),
150 SliceOps::<B>::new(desc),
151 )
152 .output()
153 }
154
155 fn int_slice_assign(
156 tensor: IntTensor<Self>,
157 slices: &[burn_backend::Slice],
158 value: IntTensor<Self>,
159 ) -> IntTensor<Self> {
160 #[derive(new, Debug)]
161 struct SliceAssignOps<B: FusionBackend> {
162 desc: SliceAssignOpIr,
163 _b: PhantomData<B>,
164 }
165
166 impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
167 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
168 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
169 let value = handles.get_int_tensor::<B>(&self.desc.value);
170
171 let output = B::int_slice_assign(tensor, self.desc.ranges.as_slice(), value);
172
173 handles.register_int_tensor::<B>(&self.desc.out.id, output);
174 }
175 }
176
177 let streams = OperationStreams::with_inputs([&tensor, &value]);
178
179 let client = tensor.client.clone();
180 let desc =
181 SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {
182 client.create_empty_handle()
183 });
184
185 client
186 .register(
187 streams,
188 OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc.clone())),
189 SliceAssignOps::<B>::new(desc),
190 )
191 .output()
192 }
193
194 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
195 binary_int_ops!(MatmulOps, B::int_matmul);
196
197 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
198
199 let client = lhs.client.clone();
200 let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
201 client.create_empty_handle()
202 });
203
204 client
205 .register(
206 streams,
207 OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())),
208 MatmulOps::<B>::new(desc.into()),
209 )
210 .output()
211 }
212
213 fn int_mask_where(
214 tensor: IntTensor<Self>,
215 mask: BoolTensor<Self>,
216 value: IntTensor<Self>,
217 ) -> IntTensor<Self> {
218 #[derive(new, Debug)]
219 struct MaskWhereOps<B: FusionBackend> {
220 desc: MaskWhereOpIr,
221 _b: PhantomData<B>,
222 }
223
224 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
225 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
226 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
227 let value = handles.get_int_tensor::<B>(&self.desc.value);
228 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
229
230 let output = B::int_mask_where(tensor, mask, value);
231
232 handles.register_int_tensor::<B>(&self.desc.out.id, output);
233 }
234 }
235
236 let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);
237
238 let client = tensor.client.clone();
239 let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {
240 client.create_empty_handle()
241 });
242
243 client
244 .register(
245 streams,
246 OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc.clone())),
247 MaskWhereOps::<B>::new(desc),
248 )
249 .output()
250 }
251
252 fn int_mask_fill(
253 tensor: IntTensor<Self>,
254 mask: BoolTensor<Self>,
255 value: Scalar,
256 ) -> IntTensor<Self> {
257 #[derive(new, Debug)]
258 struct MaskFillOps<B: FusionBackend> {
259 desc: MaskFillOpIr,
260 _b: PhantomData<B>,
261 }
262
263 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
264 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
265 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
266 let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
267
268 let output = B::int_mask_fill(tensor, mask, self.desc.value.into());
269
270 handles.register_int_tensor::<B>(&self.desc.out.id, output);
271 }
272 }
273
274 let streams = OperationStreams::with_inputs([&tensor, &mask]);
275
276 let client = tensor.client.clone();
277 let value = value.into();
278 let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {
279 client.create_empty_handle()
280 });
281
282 client
283 .register(
284 streams,
285 OperationIr::BaseInt(BaseOperationIr::MaskFill(desc.clone())),
286 MaskFillOps::<B>::new(desc),
287 )
288 .output()
289 }
290
291 fn int_gather(
292 dim: usize,
293 tensor: IntTensor<Self>,
294 indices: IntTensor<Self>,
295 ) -> IntTensor<Self> {
296 #[derive(new, Debug)]
297 struct GatherOps<B: FusionBackend> {
298 desc: GatherOpIr,
299 _b: PhantomData<B>,
300 }
301
302 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
303 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
304 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
305 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
306
307 let output = B::int_gather(self.desc.dim, tensor, indices);
308 handles.register_int_tensor::<B>(&self.desc.out.id, output);
309 }
310 }
311
312 let streams = OperationStreams::with_inputs([&tensor, &indices]);
313
314 let client = tensor.client.clone();
315 let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
316 client.create_empty_handle()
317 });
318
319 client
320 .register(
321 streams,
322 OperationIr::BaseInt(BaseOperationIr::Gather(desc.clone())),
323 GatherOps::<B>::new(desc),
324 )
325 .output()
326 }
327
328 fn int_scatter_add(
329 dim: usize,
330 tensor: IntTensor<Self>,
331 indices: IntTensor<Self>,
332 value: IntTensor<Self>,
333 ) -> IntTensor<Self> {
334 #[derive(new, Debug)]
335 struct ScatterOps<B: FusionBackend> {
336 desc: ScatterOpIr,
337 _b: PhantomData<B>,
338 }
339
340 impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
341 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
342 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
343 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
344 let value = handles.get_int_tensor::<B>(&self.desc.value);
345
346 let output = B::int_scatter_add(self.desc.dim, tensor, indices, value);
347
348 handles.register_int_tensor::<B>(&self.desc.out.id, output);
349 }
350 }
351
352 let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
353
354 let client = tensor.client.clone();
355 let desc = ScatterOpIr::create(
356 tensor.into_ir(),
357 dim,
358 indices.into_ir(),
359 value.into_ir(),
360 IndexingUpdateOp::Add,
361 || client.create_empty_handle(),
362 );
363
364 client
365 .register(
366 streams,
367 OperationIr::BaseInt(BaseOperationIr::Scatter(desc.clone())),
368 ScatterOps::<B>::new(desc),
369 )
370 .output()
371 }
372
373 fn int_scatter_nd(
374 data: IntTensor<Self>,
375 indices: IntTensor<Self>,
376 values: IntTensor<Self>,
377 reduction: IndexingUpdateOp,
378 ) -> IntTensor<Self> {
379 #[derive(new, Debug)]
380 struct ScatterNdOps<B: FusionBackend> {
381 desc: ScatterNdOpIr,
382 _b: PhantomData<B>,
383 }
384
385 impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterNdOps<B> {
386 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
387 let data = handles.get_int_tensor::<B>(&self.desc.data);
388 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
389 let values = handles.get_int_tensor::<B>(&self.desc.values);
390
391 let output = B::int_scatter_nd(data, indices, values, self.desc.reduction);
392
393 handles.register_int_tensor::<B>(&self.desc.out.id, output);
394 }
395 }
396
397 let streams = OperationStreams::with_inputs([&data, &indices, &values]);
398
399 let client = data.client.clone();
400 let desc = ScatterNdOpIr::create(
401 data.into_ir(),
402 indices.into_ir(),
403 values.into_ir(),
404 reduction,
405 || client.create_empty_handle(),
406 );
407
408 client
409 .register(
410 streams,
411 OperationIr::BaseInt(BaseOperationIr::ScatterNd(desc.clone())),
412 ScatterNdOps::<B>::new(desc),
413 )
414 .output()
415 }
416
417 fn int_gather_nd(data: IntTensor<Self>, indices: IntTensor<Self>) -> IntTensor<Self> {
418 #[derive(new, Debug)]
419 struct GatherNdOps<B: FusionBackend> {
420 desc: GatherNdOpIr,
421 _b: PhantomData<B>,
422 }
423
424 impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherNdOps<B> {
425 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
426 let data = handles.get_int_tensor::<B>(&self.desc.data);
427 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
428
429 let output = B::int_gather_nd(data, indices);
430 handles.register_int_tensor::<B>(&self.desc.out.id, output);
431 }
432 }
433
434 let streams = OperationStreams::with_inputs([&data, &indices]);
435
436 let client = data.client.clone();
437 let desc = GatherNdOpIr::create(data.into_ir(), indices.into_ir(), || {
438 client.create_empty_handle()
439 });
440
441 client
442 .register(
443 streams,
444 OperationIr::BaseInt(BaseOperationIr::GatherNd(desc.clone())),
445 GatherNdOps::<B>::new(desc),
446 )
447 .output()
448 }
449
450 fn int_select(
451 tensor: IntTensor<Self>,
452 dim: usize,
453 indices: IntTensor<Self>,
454 ) -> IntTensor<Self> {
455 #[derive(new, Debug)]
456 struct SelectOps<B: FusionBackend> {
457 desc: SelectOpIr,
458 _b: PhantomData<B>,
459 }
460
461 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
462 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
463 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
464 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
465
466 let output = B::int_select(tensor, self.desc.dim, indices);
467
468 handles.register_int_tensor::<B>(&self.desc.out.id, output);
469 }
470 }
471
472 let streams = OperationStreams::with_inputs([&tensor, &indices]);
473
474 let client = tensor.client.clone();
475 let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {
476 client.create_empty_handle()
477 });
478
479 client
480 .register(
481 streams,
482 OperationIr::BaseInt(BaseOperationIr::Select(desc.clone())),
483 SelectOps::<B>::new(desc),
484 )
485 .output()
486 }
487
488 fn int_select_add(
489 tensor: IntTensor<Self>,
490 dim: usize,
491 indices: IntTensor<Self>,
492 value: IntTensor<Self>,
493 ) -> IntTensor<Self> {
494 #[derive(new, Debug)]
495 struct SelectAssignOps<B: FusionBackend> {
496 desc: SelectAssignOpIr,
497 _b: PhantomData<B>,
498 }
499
500 impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
501 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
502 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
503 let indices = handles.get_int_tensor::<B>(&self.desc.indices);
504 let value = handles.get_int_tensor::<B>(&self.desc.value);
505
506 let output = B::int_select_add(tensor, self.desc.dim, indices, value);
507
508 handles.register_int_tensor::<B>(&self.desc.out.id, output);
509 }
510 }
511
512 let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);
513
514 let client = tensor.client.clone();
515 let desc = SelectAssignOpIr::create(
516 tensor.into_ir(),
517 dim,
518 indices.into_ir(),
519 value.into_ir(),
520 IndexingUpdateOp::Add,
521 || client.create_empty_handle(),
522 );
523
524 client
525 .register(
526 streams,
527 OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc.clone())),
528 SelectAssignOps::<B>::new(desc),
529 )
530 .output()
531 }
532
533 fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
534 #[derive(new, Debug)]
535 struct CatOps<B: FusionBackend> {
536 desc: CatOpIr,
537 _b: PhantomData<B>,
538 }
539
540 impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
541 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
542 let tensors = self
543 .desc
544 .tensors
545 .iter()
546 .map(|tensor| handles.get_int_tensor::<B>(tensor))
547 .collect();
548
549 let output = B::int_cat(tensors, self.desc.dim);
550
551 handles.register_int_tensor::<B>(&self.desc.out.id, output);
552 }
553 }
554
555 let streams = OperationStreams::with_inputs(&tensors);
556
557 let client = tensors.first().unwrap().client.clone();
558 let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();
559 let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());
560
561 client
562 .register(
563 streams,
564 OperationIr::BaseInt(BaseOperationIr::Cat(desc.clone())),
565 CatOps::<B>::new(desc),
566 )
567 .output()
568 }
569
570 fn int_equal(
571 lhs: IntTensor<Self>,
572 rhs: IntTensor<Self>,
573 out_dtype: BoolDType,
574 ) -> BoolTensor<Self> {
575 binary_int_cmp_ops!(EqualOps, B::int_equal);
576
577 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
578
579 let client = lhs.client.clone();
580 let desc =
581 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
582 client.create_empty_handle()
583 });
584
585 client
586 .register(
587 streams,
588 OperationIr::BaseInt(BaseOperationIr::Equal(desc.clone())),
589 EqualOps::<B>::new(desc),
590 )
591 .output()
592 }
593
594 fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
595 scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem);
596
597 let streams = OperationStreams::with_inputs([&lhs]);
598
599 let client = lhs.client.clone();
600 let rhs = rhs.into();
601 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
602 client.create_empty_handle()
603 });
604
605 client
606 .register(
607 streams,
608 OperationIr::BaseInt(BaseOperationIr::EqualElem(desc.clone())),
609 EqualElemOps::<B>::new(desc),
610 )
611 .output()
612 }
613
614 fn int_greater(
615 lhs: IntTensor<Self>,
616 rhs: IntTensor<Self>,
617 out_dtype: BoolDType,
618 ) -> BoolTensor<Self> {
619 binary_int_cmp_ops!(GreaterOps, B::int_greater);
620
621 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
622
623 let client = lhs.client.clone();
624 let desc =
625 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
626 client.create_empty_handle()
627 });
628
629 client
630 .register(
631 streams,
632 OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Greater(desc.clone())),
633 GreaterOps::<B>::new(desc),
634 )
635 .output()
636 }
637
638 fn int_greater_elem(
639 lhs: IntTensor<Self>,
640 rhs: Scalar,
641 out_dtype: BoolDType,
642 ) -> BoolTensor<Self> {
643 scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem);
644
645 let streams = OperationStreams::with_inputs([&lhs]);
646
647 let client = lhs.client.clone();
648 let rhs = rhs.into();
649 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
650 client.create_empty_handle()
651 });
652
653 client
654 .register(
655 streams,
656 OperationIr::NumericInt(
657 desc.lhs.dtype,
658 NumericOperationIr::GreaterElem(desc.clone()),
659 ),
660 GreaterElemOps::<B>::new(desc),
661 )
662 .output()
663 }
664
665 fn int_greater_equal(
666 lhs: IntTensor<Self>,
667 rhs: IntTensor<Self>,
668 out_dtype: BoolDType,
669 ) -> BoolTensor<Self> {
670 binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal);
671
672 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
673
674 let client = lhs.client.clone();
675 let desc =
676 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
677 client.create_empty_handle()
678 });
679
680 client
681 .register(
682 streams,
683 OperationIr::NumericInt(
684 desc.lhs.dtype,
685 NumericOperationIr::GreaterEqual(desc.clone()),
686 ),
687 GreaterEqualOps::<B>::new(desc),
688 )
689 .output()
690 }
691
692 fn int_greater_equal_elem(
693 lhs: IntTensor<Self>,
694 rhs: Scalar,
695 out_dtype: BoolDType,
696 ) -> BoolTensor<Self> {
697 scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem);
698
699 let streams = OperationStreams::with_inputs([&lhs]);
700
701 let client = lhs.client.clone();
702 let rhs = rhs.into();
703 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
704 client.create_empty_handle()
705 });
706
707 client
708 .register(
709 streams,
710 OperationIr::NumericInt(
711 desc.lhs.dtype,
712 NumericOperationIr::GreaterEqualElem(desc.clone()),
713 ),
714 GreaterEqualElemOps::<B>::new(desc),
715 )
716 .output()
717 }
718
719 fn int_lower(
720 lhs: IntTensor<Self>,
721 rhs: IntTensor<Self>,
722 out_dtype: BoolDType,
723 ) -> BoolTensor<Self> {
724 binary_int_cmp_ops!(LowerOps, B::int_lower);
725
726 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
727
728 let client = lhs.client.clone();
729 let desc =
730 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
731 client.create_empty_handle()
732 });
733
734 client
735 .register(
736 streams,
737 OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())),
738 LowerOps::<B>::new(desc),
739 )
740 .output()
741 }
742
743 fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
744 scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem);
745
746 let streams = OperationStreams::with_inputs([&lhs]);
747
748 let client = lhs.client.clone();
749 let rhs = rhs.into();
750 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
751 client.create_empty_handle()
752 });
753
754 client
755 .register(
756 streams,
757 OperationIr::NumericInt(
758 desc.lhs.dtype,
759 NumericOperationIr::LowerElem(desc.clone()),
760 ),
761 LowerElemOps::<B>::new(desc),
762 )
763 .output()
764 }
765
766 fn int_lower_equal(
767 lhs: IntTensor<Self>,
768 rhs: IntTensor<Self>,
769 out_dtype: BoolDType,
770 ) -> BoolTensor<Self> {
771 binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal);
772
773 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
774
775 let client = lhs.client.clone();
776 let desc =
777 BinaryOpIr::create_comparison(lhs.into_ir(), rhs.into_ir(), out_dtype.into(), || {
778 client.create_empty_handle()
779 });
780
781 client
782 .register(
783 streams,
784 OperationIr::NumericInt(
785 desc.lhs.dtype,
786 NumericOperationIr::LowerEqual(desc.clone()),
787 ),
788 LowerEqualOps::<B>::new(desc),
789 )
790 .output()
791 }
792
793 fn int_lower_equal_elem(
794 lhs: IntTensor<Self>,
795 rhs: Scalar,
796 out_dtype: BoolDType,
797 ) -> BoolTensor<Self> {
798 scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem);
799
800 let streams = OperationStreams::with_inputs([&lhs]);
801
802 let client = lhs.client.clone();
803 let rhs = rhs.into();
804 let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, out_dtype.into(), || {
805 client.create_empty_handle()
806 });
807
808 client
809 .register(
810 streams,
811 OperationIr::NumericInt(
812 desc.lhs.dtype,
813 NumericOperationIr::LowerEqualElem(desc.clone()),
814 ),
815 LowerEqualElemOps::<B>::new(desc),
816 )
817 .output()
818 }
819
820 fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
821 binary_int_ops!(AddOps, B::int_add);
822
823 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
824
825 let client = lhs.client.clone();
826 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
827 client.create_empty_handle()
828 });
829
830 client
831 .register(
832 streams,
833 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Add(desc.clone())),
834 AddOps::<B>::new(desc),
835 )
836 .output()
837 }
838
839 fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
840 scalar_int_ops!(AddOps, B::int_add_scalar);
841
842 let streams = OperationStreams::with_inputs([&lhs]);
843
844 let client = lhs.client.clone();
845 let rhs = rhs.into();
846 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
847
848 client
849 .register(
850 streams,
851 OperationIr::NumericInt(
852 desc.out.dtype,
853 NumericOperationIr::AddScalar(desc.clone()),
854 ),
855 AddOps::<B>::new(desc),
856 )
857 .output()
858 }
859
860 fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
861 binary_int_ops!(SubOps, B::int_sub);
862
863 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
864
865 let client = lhs.client.clone();
866 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
867 client.create_empty_handle()
868 });
869
870 client
871 .register(
872 streams,
873 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sub(desc.clone())),
874 SubOps::<B>::new(desc),
875 )
876 .output()
877 }
878
879 fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
880 scalar_int_ops!(SubOps, B::int_sub_scalar);
881
882 let streams = OperationStreams::with_inputs([&lhs]);
883
884 let client = lhs.client.clone();
885 let rhs = rhs.into();
886 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
887
888 client
889 .register(
890 streams,
891 OperationIr::NumericInt(
892 desc.out.dtype,
893 NumericOperationIr::SubScalar(desc.clone()),
894 ),
895 SubOps::<B>::new(desc),
896 )
897 .output()
898 }
899
900 fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
901 binary_int_ops!(MulOps, B::int_mul);
902
903 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
904
905 let client = lhs.client.clone();
906 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
907 client.create_empty_handle()
908 });
909
910 client
911 .register(
912 streams,
913 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mul(desc.clone())),
914 MulOps::<B>::new(desc),
915 )
916 .output()
917 }
918
919 fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
920 scalar_int_ops!(MulOps, B::int_mul_scalar);
921
922 let streams = OperationStreams::with_inputs([&lhs]);
923
924 let client = lhs.client.clone();
925 let rhs = rhs.into();
926 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
927
928 client
929 .register(
930 streams,
931 OperationIr::NumericInt(
932 desc.out.dtype,
933 NumericOperationIr::MulScalar(desc.clone()),
934 ),
935 MulOps::<B>::new(desc),
936 )
937 .output()
938 }
939
940 fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
941 binary_int_ops!(DivOps, B::int_div);
942
943 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
944
945 let client = lhs.client.clone();
946 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
947 client.create_empty_handle()
948 });
949
950 client
951 .register(
952 streams,
953 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Div(desc.clone())),
954 DivOps::<B>::new(desc),
955 )
956 .output()
957 }
958
959 fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
960 scalar_int_ops!(DivOps, B::int_div_scalar);
961
962 let streams = OperationStreams::with_inputs([&lhs]);
963
964 let client = lhs.client.clone();
965 let rhs = rhs.into();
966 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
967
968 client
969 .register(
970 streams,
971 OperationIr::NumericInt(
972 desc.out.dtype,
973 NumericOperationIr::DivScalar(desc.clone()),
974 ),
975 DivOps::<B>::new(desc),
976 )
977 .output()
978 }
979
980 fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
981 binary_int_ops!(ModOps, B::int_remainder);
982
983 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
984
985 let client = lhs.client.clone();
986 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
987 client.create_empty_handle()
988 });
989
990 client
991 .register(
992 streams,
993 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Rem(desc.clone())),
994 ModOps::<B>::new(desc),
995 )
996 .output()
997 }
998
999 fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1000 scalar_int_ops!(ModOps, B::int_remainder_scalar);
1001
1002 let streams = OperationStreams::with_inputs([&lhs]);
1003
1004 let client = lhs.client.clone();
1005 let rhs = rhs.into();
1006 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1007
1008 client
1009 .register(
1010 streams,
1011 OperationIr::NumericInt(
1012 desc.out.dtype,
1013 NumericOperationIr::RemScalar(desc.clone()),
1014 ),
1015 ModOps::<B>::new(desc),
1016 )
1017 .output()
1018 }
1019
1020 fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1021 #[derive(new, Debug)]
1022 struct ZerosOps<B: FusionBackend> {
1023 desc: TensorIr,
1024 device: Device<B>,
1025 }
1026
1027 impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
1028 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1029 let shape = self.desc.shape.clone();
1030 let output = B::int_zeros(shape, &self.device, self.desc.dtype.into());
1031 handles.register_int_tensor::<B>(&self.desc.id, output);
1032 }
1033 }
1034
1035 let client = get_client::<B>(device);
1036 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
1037
1038 client
1039 .register(
1040 OperationStreams::default(),
1041 OperationIr::BaseInt(BaseOperationIr::Zeros(desc.clone())),
1042 ZerosOps::<B>::new(desc.out, device.clone()),
1043 )
1044 .output()
1045 }
1046
1047 fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1048 #[derive(new, Debug)]
1049 struct OnesOps<B: FusionBackend> {
1050 desc: TensorIr,
1051 device: Device<B>,
1052 }
1053
1054 impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
1055 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1056 let shape = self.desc.shape.clone();
1057 let output = B::int_ones(shape, &self.device, self.desc.dtype.into());
1058 handles.register_int_tensor::<B>(&self.desc.id, output);
1059 }
1060 }
1061 let client = get_client::<B>(device);
1062 let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());
1063
1064 client
1065 .register(
1066 OperationStreams::default(),
1067 OperationIr::BaseInt(BaseOperationIr::Ones(desc.clone())),
1068 OnesOps::<B>::new(desc.out, device.clone()),
1069 )
1070 .output()
1071 }
1072
1073 fn int_full(
1074 shape: Shape,
1075 fill_value: Scalar,
1076 device: &Device<Self>,
1077 dtype: IntDType,
1078 ) -> IntTensor<Self> {
1079 #[derive(new, Debug)]
1080 struct FullOps<B: FusionBackend> {
1081 out: TensorIr,
1082 elem: ScalarIr,
1083 device: Device<B>,
1084 }
1085
1086 impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
1087 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1088 let shape = self.out.shape.clone();
1089 let output =
1090 B::int_full(shape, self.elem.into(), &self.device, self.out.dtype.into());
1091 handles.register_int_tensor::<B>(&self.out.id, output);
1092 }
1093 }
1094
1095 let client = get_client::<B>(device);
1096 let dtype = dtype.into();
1097 let value = fill_value.into();
1098 let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());
1099
1100 client
1101 .register(
1102 OperationStreams::default(),
1103 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Full(desc.clone())),
1104 FullOps::<B>::new(desc.out, desc.value, device.clone()),
1105 )
1106 .output()
1107 }
1108
1109 fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
1110 unary_int_ops!(SumOps, B::int_sum, reduce);
1111
1112 let streams = OperationStreams::with_inputs([&tensor]);
1113
1114 let client = tensor.client.clone();
1115 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1116
1117 client
1118 .register(
1119 streams,
1120 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sum(desc.clone())),
1121 SumOps::<B>::new(desc.into()),
1122 )
1123 .output()
1124 }
1125
1126 fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
1127 reduce_int_ops!(SumDimOps, |tensor, axis, _| B::int_sum_dim(tensor, axis));
1128
1129 let streams = OperationStreams::with_inputs([&tensor]);
1130
1131 let client = tensor.client.clone();
1132 let desc =
1133 ReduceDimOpIr::create(tensor.into_ir(), axis, 1, || client.create_empty_handle());
1134
1135 client
1136 .register(
1137 streams,
1138 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())),
1139 SumDimOps::<B>::new(desc),
1140 )
1141 .output()
1142 }
1143
1144 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
1145 unary_int_ops!(ProdOps, B::int_prod, reduce);
1146
1147 let streams = OperationStreams::with_inputs([&tensor]);
1148
1149 let client = tensor.client.clone();
1150 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1151
1152 client
1153 .register(
1154 streams,
1155 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Prod(desc.clone())),
1156 ProdOps::<B>::new(desc.into()),
1157 )
1158 .output()
1159 }
1160
1161 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1162 reduce_int_ops!(ProdDimOps, |tensor, axis, _| B::int_prod_dim(tensor, axis));
1163
1164 let streams = OperationStreams::with_inputs([&tensor]);
1165
1166 let client = tensor.client.clone();
1167 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1168
1169 client
1170 .register(
1171 streams,
1172 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ProdDim(desc.clone())),
1173 ProdDimOps::<B>::new(desc),
1174 )
1175 .output()
1176 }
1177
1178 fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
1179 unary_int_ops!(MeanOps, B::int_mean, reduce);
1180
1181 let streams = OperationStreams::with_inputs([&tensor]);
1182
1183 let client = tensor.client.clone();
1184 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1185
1186 client
1187 .register(
1188 streams,
1189 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mean(desc.clone())),
1190 MeanOps::<B>::new(desc.into()),
1191 )
1192 .output()
1193 }
1194
1195 fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1196 reduce_int_ops!(MeanDimOps, |tensor, axis, _| B::int_mean_dim(tensor, axis));
1197
1198 let streams = OperationStreams::with_inputs([&tensor]);
1199
1200 let client = tensor.client.clone();
1201 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1202
1203 client
1204 .register(
1205 streams,
1206 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MeanDim(desc.clone())),
1207 MeanDimOps::<B>::new(desc),
1208 )
1209 .output()
1210 }
1211
1212 fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1213 #[derive(new, Debug)]
1214 struct CumsumOps<B: FusionBackend> {
1215 desc: DimOpIr,
1216 _b: PhantomData<B>,
1217 }
1218
1219 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
1220 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1221 let input = handles.get_int_tensor::<B>(&self.desc.input);
1222 let output = B::int_cumsum(input, self.desc.axis);
1223 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1224 }
1225 }
1226
1227 let streams = OperationStreams::with_inputs([&tensor]);
1228
1229 let client = tensor.client.clone();
1230 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1231
1232 client
1233 .register(
1234 streams,
1235 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())),
1236 CumsumOps::<B>::new(desc),
1237 )
1238 .output()
1239 }
1240
1241 fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1242 #[derive(new, Debug)]
1243 struct CumprodOps<B: FusionBackend> {
1244 desc: DimOpIr,
1245 _b: PhantomData<B>,
1246 }
1247
1248 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
1249 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1250 let input = handles.get_int_tensor::<B>(&self.desc.input);
1251 let output = B::int_cumprod(input, self.desc.axis);
1252 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1253 }
1254 }
1255
1256 let streams = OperationStreams::with_inputs([&tensor]);
1257
1258 let client = tensor.client.clone();
1259 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1260
1261 client
1262 .register(
1263 streams,
1264 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumProd(desc.clone())),
1265 CumprodOps::<B>::new(desc),
1266 )
1267 .output()
1268 }
1269
1270 fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1271 #[derive(new, Debug)]
1272 struct CumminOps<B: FusionBackend> {
1273 desc: DimOpIr,
1274 _b: PhantomData<B>,
1275 }
1276
1277 impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
1278 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1279 let input = handles.get_int_tensor::<B>(&self.desc.input);
1280 let output = B::int_cummin(input, self.desc.axis);
1281 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1282 }
1283 }
1284
1285 let streams = OperationStreams::with_inputs([&tensor]);
1286
1287 let client = tensor.client.clone();
1288 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1289
1290 client
1291 .register(
1292 streams,
1293 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())),
1294 CumminOps::<B>::new(desc),
1295 )
1296 .output()
1297 }
1298
1299 fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1300 #[derive(new, Debug)]
1301 struct CummaxOps<B: FusionBackend> {
1302 desc: DimOpIr,
1303 _b: PhantomData<B>,
1304 }
1305
1306 impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
1307 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1308 let input = handles.get_int_tensor::<B>(&self.desc.input);
1309 let output = B::int_cummax(input, self.desc.axis);
1310 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1311 }
1312 }
1313
1314 let streams = OperationStreams::with_inputs([&tensor]);
1315
1316 let client = tensor.client.clone();
1317 let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());
1318
1319 client
1320 .register(
1321 streams,
1322 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())),
1323 CummaxOps::<B>::new(desc),
1324 )
1325 .output()
1326 }
1327
1328 fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1329 reduce_int_ops!(ArgMaxOps, |tensor, axis, _| B::int_argmax(tensor, axis));
1330
1331 let streams = OperationStreams::with_inputs([&tensor]);
1332
1333 let client = tensor.client.clone();
1334 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1335
1336 client
1337 .register(
1338 streams,
1339 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMax(desc.clone())),
1340 ArgMaxOps::<B>::new(desc),
1341 )
1342 .output()
1343 }
1344
1345 fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
1346 reduce_int_ops!(ArgTopKOps, B::int_argtopk);
1347
1348 let streams = OperationStreams::with_inputs([&tensor]);
1349
1350 let client = tensor.client.clone();
1351 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1352
1353 client
1354 .register(
1355 streams,
1356 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgTopK(desc.clone())),
1357 ArgTopKOps::<B>::new(desc),
1358 )
1359 .output()
1360 }
1361
1362 fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
1363 reduce_int_ops!(TopKOps, B::int_topk);
1364
1365 let streams = OperationStreams::with_inputs([&tensor]);
1366
1367 let client = tensor.client.clone();
1368 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, k, || client.create_empty_handle());
1369
1370 client
1371 .register(
1372 streams,
1373 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::TopK(desc.clone())),
1374 TopKOps::<B>::new(desc),
1375 )
1376 .output()
1377 }
1378
1379 fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1380 reduce_int_ops!(ArgMinOps, |tensor, axis, _| B::int_argmin(tensor, axis));
1381
1382 let streams = OperationStreams::with_inputs([&tensor]);
1383
1384 let client = tensor.client.clone();
1385 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1386
1387 client
1388 .register(
1389 streams,
1390 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMin(desc.clone())),
1391 ArgMinOps::<B>::new(desc),
1392 )
1393 .output()
1394 }
1395
1396 fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
1397 #[derive(new, Debug)]
1398 struct ClampOps<B: FusionBackend> {
1399 desc: ClampOpIr,
1400 _b: PhantomData<B>,
1401 }
1402
1403 impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
1404 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1405 let input = handles.get_int_tensor::<B>(&self.desc.tensor);
1406 let output = B::int_clamp(input, self.desc.min.into(), self.desc.max.into());
1407
1408 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1409 }
1410 }
1411
1412 let streams = OperationStreams::with_inputs([&tensor]);
1413
1414 let client = tensor.client.clone();
1415 let min = min.into();
1416 let max = max.into();
1417 let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());
1418
1419 client
1420 .register(
1421 streams,
1422 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Clamp(desc.clone())),
1423 ClampOps::<B>::new(desc),
1424 )
1425 .output()
1426 }
1427
1428 fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1429 unary_int_ops!(AbsOps, B::int_abs);
1430
1431 let streams = OperationStreams::with_inputs([&tensor]);
1432
1433 let client = tensor.client.clone();
1434 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1435
1436 client
1437 .register(
1438 streams,
1439 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Abs(desc.clone())),
1440 AbsOps::<B>::new(desc),
1441 )
1442 .output()
1443 }
1444
1445 fn int_into_float(tensor: IntTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
1446 #[derive(new, Debug)]
1447 struct IntoFloatOps<B: FusionBackend> {
1448 desc: CastOpIr,
1449 _b: PhantomData<B>,
1450 }
1451
1452 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
1453 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1454 let input = handles.get_int_tensor::<B>(&self.desc.input);
1455 let output = B::int_into_float(input, self.desc.out.dtype.into());
1456 handles.register_float_tensor::<B>(&self.desc.out.id, output);
1457 }
1458 }
1459
1460 let streams = OperationStreams::with_inputs([&tensor]);
1461
1462 let client = tensor.client.clone();
1463 let desc = CastOpIr::create(tensor.into_ir(), out_dtype.into(), || {
1464 client.create_empty_handle()
1465 });
1466
1467 client
1468 .register(
1469 streams,
1470 OperationIr::Int(IntOperationIr::IntoFloat(desc.clone())),
1471 IntoFloatOps::<B>::new(desc),
1472 )
1473 .output()
1474 }
1475
1476 fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
1477 #[derive(new, Debug)]
1478 struct SwapDimsOps<B: FusionBackend> {
1479 desc: SwapDimsOpIr,
1480 _b: PhantomData<B>,
1481 }
1482
1483 impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
1484 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1485 let input = handles.get_int_tensor::<B>(&self.desc.input);
1486 let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
1487 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1488 }
1489 }
1490 let streams = OperationStreams::with_inputs([&tensor]);
1491
1492 let client = tensor.client.clone();
1493 let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {
1494 client.create_empty_handle()
1495 });
1496
1497 client
1498 .register(
1499 streams,
1500 OperationIr::BaseInt(BaseOperationIr::SwapDims(desc.clone())),
1501 SwapDimsOps::<B>::new(desc),
1502 )
1503 .output()
1504 }
1505
1506 fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
1507 unary_int_ops!(MaxOps, B::int_max, reduce);
1508
1509 let streams = OperationStreams::with_inputs([&tensor]);
1510
1511 let client = tensor.client.clone();
1512 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1513
1514 client
1515 .register(
1516 streams,
1517 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Max(desc.clone())),
1518 MaxOps::<B>::new(desc.into()),
1519 )
1520 .output()
1521 }
1522
1523 fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1524 reduce_int_ops!(MaxDimOps, |tensor, axis, _| B::int_max_dim(tensor, axis));
1525
1526 let streams = OperationStreams::with_inputs([&tensor]);
1527
1528 let client = tensor.client.clone();
1529 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1530
1531 client
1532 .register(
1533 streams,
1534 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())),
1535 MaxDimOps::<B>::new(desc),
1536 )
1537 .output()
1538 }
1539
1540 fn int_max_dim_with_indices(
1541 tensor: IntTensor<Self>,
1542 dim: usize,
1543 ) -> (IntTensor<Self>, IntTensor<Self>) {
1544 #[derive(new, Debug)]
1545 struct MaxDimWithIndicesOps<B: FusionBackend> {
1546 desc: ReduceDimWithIndicesOpIr,
1547 _b: PhantomData<B>,
1548 }
1549
1550 impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
1551 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1552 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1553 let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);
1554
1555 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1556 handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
1557 }
1558 }
1559
1560 let streams = OperationStreams::with_inputs([&tensor]);
1561
1562 let client = tensor.client.clone();
1563 let dtype = tensor.dtype;
1564 let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || {
1565 client.create_empty_handle()
1566 });
1567
1568 client
1569 .register(
1570 streams,
1571 OperationIr::NumericInt(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())),
1572 MaxDimWithIndicesOps::<B>::new(desc),
1573 )
1574 .outputs()
1575 .into()
1576 }
1577
1578 fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
1579 unary_int_ops!(MinOps, B::int_min, reduce);
1580
1581 let streams = OperationStreams::with_inputs([&tensor]);
1582
1583 let client = tensor.client.clone();
1584 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1585
1586 client
1587 .register(
1588 streams,
1589 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Min(desc.clone())),
1590 MinOps::<B>::new(desc.into()),
1591 )
1592 .output()
1593 }
1594
1595 fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1596 unary_int_ops!(MaxAbsOps, B::int_max_abs, reduce);
1597
1598 let streams = OperationStreams::with_inputs([&tensor]);
1599
1600 let client = tensor.client.clone();
1601 let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1602
1603 client
1604 .register(
1605 streams,
1606 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())),
1607 MaxAbsOps::<B>::new(desc.into()),
1608 )
1609 .output()
1610 }
1611
1612 fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1613 reduce_int_ops!(MaxAbsDimOps, |tensor, axis, _| B::int_max_abs_dim(
1614 tensor, axis
1615 ));
1616
1617 let streams = OperationStreams::with_inputs([&tensor]);
1618
1619 let client = tensor.client.clone();
1620 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1621
1622 client
1623 .register(
1624 streams,
1625 OperationIr::NumericInt(
1626 desc.out.dtype,
1627 NumericOperationIr::MaxAbsDim(desc.clone()),
1628 ),
1629 MaxAbsDimOps::<B>::new(desc),
1630 )
1631 .output()
1632 }
1633
1634 fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1635 reduce_int_ops!(MinDimOps, |tensor, axis, _| B::int_min_dim(tensor, axis));
1636
1637 let streams = OperationStreams::with_inputs([&tensor]);
1638
1639 let client = tensor.client.clone();
1640 let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, 1, || client.create_empty_handle());
1641
1642 client
1643 .register(
1644 streams,
1645 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())),
1646 MinDimOps::<B>::new(desc),
1647 )
1648 .output()
1649 }
1650
1651 fn int_min_dim_with_indices(
1652 tensor: IntTensor<Self>,
1653 dim: usize,
1654 ) -> (IntTensor<Self>, IntTensor<Self>) {
1655 #[derive(new, Debug)]
1656 struct MinDimWithIndicesOps<B: FusionBackend> {
1657 desc: ReduceDimWithIndicesOpIr,
1658 _b: PhantomData<B>,
1659 }
1660
1661 impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
1662 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1663 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1664 let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);
1665
1666 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1667 handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
1668 }
1669 }
1670
1671 let streams = OperationStreams::with_inputs([&tensor]);
1672
1673 let client = tensor.client.clone();
1674 let dtype = tensor.dtype;
1675 let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || {
1676 client.create_empty_handle()
1677 });
1678
1679 client
1680 .register(
1681 streams,
1682 OperationIr::NumericInt(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())),
1683 MinDimWithIndicesOps::<B>::new(desc),
1684 )
1685 .outputs()
1686 .into()
1687 }
1688
1689 fn int_random(
1690 shape: Shape,
1691 distribution: Distribution,
1692 device: &Device<Self>,
1693 dtype: IntDType,
1694 ) -> IntTensor<Self> {
1695 #[derive(new, Debug)]
1696 struct IntRandomOps<B: FusionBackend> {
1697 desc: RandomOpIr,
1698 device: Device<B>,
1699 }
1700
1701 impl<B: FusionBackend> Operation<B::FusionRuntime> for IntRandomOps<B> {
1702 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1703 let shape = self.desc.out.shape.clone();
1704 let output = B::int_random(
1705 shape,
1706 self.desc.distribution,
1707 &self.device,
1708 self.desc.out.dtype.into(),
1709 );
1710 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1711 }
1712 }
1713
1714 let dtype = dtype.into();
1715 let client = get_client::<B>(device);
1716 let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());
1717
1718 client
1719 .register(
1720 OperationStreams::default(),
1721 OperationIr::NumericInt(dtype, NumericOperationIr::IntRandom(desc.clone())),
1722 IntRandomOps::<B>::new(desc, device.clone()),
1723 )
1724 .output()
1725 }
1726
1727 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
1728 #[derive(new, Debug)]
1729 struct PermuteDimsOps<B: FusionBackend> {
1730 desc: PermuteOpIr,
1731 _b: PhantomData<B>,
1732 }
1733
1734 impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
1735 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1736 let input = handles.get_int_tensor::<B>(&self.desc.input);
1737 let output = B::int_permute(input, self.desc.axes.as_slice());
1738 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1739 }
1740 }
1741
1742 let streams = OperationStreams::with_inputs([&tensor]);
1743
1744 let client = tensor.client.clone();
1745 let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {
1746 client.create_empty_handle()
1747 });
1748
1749 client
1750 .register(
1751 streams,
1752 OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
1753 PermuteDimsOps::<B>::new(desc),
1754 )
1755 .output()
1756 }
1757
1758 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
1759 #[derive(new, Debug)]
1760 struct ExpandOps<B: FusionBackend> {
1761 desc: ShapeOpIr,
1762 _b: PhantomData<B>,
1763 }
1764
1765 impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
1766 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1767 let input = handles.get_int_tensor::<B>(&self.desc.input);
1768 let output = B::int_expand(input, self.desc.out.shape.clone());
1769 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1770 }
1771 }
1772
1773 let streams = OperationStreams::with_inputs([&tensor]);
1774
1775 let client = tensor.client.clone();
1776 let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());
1777
1778 client
1779 .register(
1780 streams,
1781 OperationIr::BaseInt(BaseOperationIr::Expand(desc.clone())),
1782 ExpandOps::<B>::new(desc),
1783 )
1784 .output()
1785 }
1786
1787 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
1788 #[derive(new, Debug)]
1789 struct FlipDimsOps<B: FusionBackend> {
1790 desc: FlipOpIr,
1791 _b: PhantomData<B>,
1792 }
1793
1794 impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipDimsOps<B> {
1795 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1796 let input = handles.get_int_tensor::<B>(&self.desc.input);
1797 let axes = &self.desc.axes;
1798 let output = B::int_flip(input, axes);
1799 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1800 }
1801 }
1802
1803 let streams = OperationStreams::with_inputs([&tensor]);
1804
1805 let client = tensor.client.clone();
1806 let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {
1807 client.create_empty_handle()
1808 });
1809
1810 client
1811 .register(
1812 streams,
1813 OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
1814 FlipDimsOps::<B>::new(desc),
1815 )
1816 .output()
1817 }
1818
1819 fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
1820 #[derive(new, Debug)]
1821 struct RepeatDimOps<B: FusionBackend> {
1822 desc: RepeatDimOpIr,
1823 _b: PhantomData<B>,
1824 }
1825
1826 impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
1827 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1828 let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1829
1830 let output = B::int_repeat_dim(tensor, self.desc.dim, self.desc.times);
1831
1832 handles.register_int_tensor::<B>(&self.desc.out.id, output);
1833 }
1834 }
1835
1836 let streams = OperationStreams::with_inputs([&tensor]);
1837
1838 let client = tensor.client.clone();
1839 let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {
1840 client.create_empty_handle()
1841 });
1842
1843 client
1844 .register(
1845 streams,
1846 OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc.clone())),
1847 RepeatDimOps::<B>::new(desc),
1848 )
1849 .output()
1850 }
1851
1852 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1853 binary_int_ops!(BitwiseAndOps, B::bitwise_and);
1854
1855 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1856
1857 let client = lhs.client.clone();
1858 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1859 client.create_empty_handle()
1860 });
1861
1862 client
1863 .register(
1864 streams,
1865 OperationIr::Int(IntOperationIr::BitwiseAnd(desc.clone())),
1866 BitwiseAndOps::<B>::new(desc),
1867 )
1868 .output()
1869 }
1870
1871 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1872 scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar);
1873
1874 let streams = OperationStreams::with_inputs([&lhs]);
1875
1876 let client = lhs.client.clone();
1877 let rhs = rhs.into();
1878 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1879
1880 client
1881 .register(
1882 streams,
1883 OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc.clone())),
1884 BitwiseAndOps::<B>::new(desc),
1885 )
1886 .output()
1887 }
1888
1889 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1890 binary_int_ops!(BitwiseOrOps, B::bitwise_or);
1891
1892 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1893
1894 let client = lhs.client.clone();
1895 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1896 client.create_empty_handle()
1897 });
1898
1899 client
1900 .register(
1901 streams,
1902 OperationIr::Int(IntOperationIr::BitwiseOr(desc.clone())),
1903 BitwiseOrOps::<B>::new(desc),
1904 )
1905 .output()
1906 }
1907
1908 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1909 scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar);
1910
1911 let streams = OperationStreams::with_inputs([&lhs]);
1912
1913 let client = lhs.client.clone();
1914 let rhs = rhs.into();
1915 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1916
1917 client
1918 .register(
1919 streams,
1920 OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc.clone())),
1921 BitwiseOrOps::<B>::new(desc),
1922 )
1923 .output()
1924 }
1925
1926 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1927 binary_int_ops!(BitwiseXorOps, B::bitwise_xor);
1928
1929 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1930
1931 let client = lhs.client.clone();
1932 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1933 client.create_empty_handle()
1934 });
1935
1936 client
1937 .register(
1938 streams,
1939 OperationIr::Int(IntOperationIr::BitwiseXor(desc.clone())),
1940 BitwiseXorOps::<B>::new(desc),
1941 )
1942 .output()
1943 }
1944
1945 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
1946 scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar);
1947
1948 let streams = OperationStreams::with_inputs([&lhs]);
1949
1950 let client = lhs.client.clone();
1951 let rhs = rhs.into();
1952 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
1953
1954 client
1955 .register(
1956 streams,
1957 OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc.clone())),
1958 BitwiseXorOps::<B>::new(desc),
1959 )
1960 .output()
1961 }
1962
1963 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
1964 unary_int_ops!(BitwiseNotOps, B::bitwise_not);
1965
1966 let streams = OperationStreams::with_inputs([&tensor]);
1967
1968 let client = tensor.client.clone();
1969 let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());
1970
1971 client
1972 .register(
1973 streams,
1974 OperationIr::Int(IntOperationIr::BitwiseNot(desc.clone())),
1975 BitwiseNotOps::<B>::new(desc),
1976 )
1977 .output()
1978 }
1979
1980 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1981 binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift);
1982
1983 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
1984
1985 let client = lhs.client.clone();
1986 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
1987 client.create_empty_handle()
1988 });
1989
1990 client
1991 .register(
1992 streams,
1993 OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc.clone())),
1994 BitwiseLeftShiftOps::<B>::new(desc),
1995 )
1996 .output()
1997 }
1998
1999 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
2000 scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar);
2001
2002 let streams = OperationStreams::with_inputs([&lhs]);
2003
2004 let client = lhs.client.clone();
2005 let rhs = rhs.into();
2006 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
2007
2008 client
2009 .register(
2010 streams,
2011 OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(desc.clone())),
2012 BitwiseLeftShiftOps::<B>::new(desc),
2013 )
2014 .output()
2015 }
2016
2017 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2018 binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift);
2019
2020 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
2021
2022 let client = lhs.client.clone();
2023 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
2024 client.create_empty_handle()
2025 });
2026
2027 client
2028 .register(
2029 streams,
2030 OperationIr::Int(IntOperationIr::BitwiseRightShift(desc.clone())),
2031 BitwiseRightShiftOps::<B>::new(desc),
2032 )
2033 .output()
2034 }
2035
2036 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
2037 scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar);
2038
2039 let streams = OperationStreams::with_inputs([&lhs]);
2040
2041 let client = lhs.client.clone();
2042 let rhs = rhs.into();
2043 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
2044
2045 client
2046 .register(
2047 streams,
2048 OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(desc.clone())),
2049 BitwiseRightShiftOps::<B>::new(desc),
2050 )
2051 .output()
2052 }
2053
2054 fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {
2055 #[derive(new, Debug)]
2056 struct CastOps<B: FusionBackend> {
2057 desc: CastOpIr,
2058 dtype: burn_backend::IntDType,
2059 _b: PhantomData<B>,
2060 }
2061
2062 impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2063 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2064 let input = handles.get_int_tensor::<B>(&self.desc.input);
2065 let output: B::IntTensorPrimitive = B::int_cast(input, self.dtype);
2066 handles.register_int_tensor::<B>(&self.desc.out.id, output);
2067 }
2068 }
2069
2070 let streams = OperationStreams::with_inputs([&tensor]);
2071
2072 let client = tensor.client.clone();
2073 let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {
2074 client.create_empty_handle()
2075 });
2076
2077 client
2078 .register(
2079 streams,
2080 OperationIr::BaseInt(BaseOperationIr::Cast(desc.clone())),
2081 CastOps::<B>::new(desc, dtype),
2082 )
2083 .output()
2084 }
2085
2086 fn int_unfold(
2087 tensor: IntTensor<Self>,
2088 dim: usize,
2089 size: usize,
2090 step: usize,
2091 ) -> IntTensor<Self> {
2092 #[derive(new, Debug)]
2093 struct UnfoldOps<B: FusionBackend> {
2094 desc: UnfoldOpIr,
2095 _b: PhantomData<B>,
2096 }
2097
2098 impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2099 fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2100 let input = handles.get_int_tensor::<B>(&self.desc.input);
2101 let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2102
2103 handles.register_int_tensor::<B>(&self.desc.out.id, output);
2104 }
2105 }
2106
2107 let streams = OperationStreams::with_inputs([&tensor]);
2108
2109 let client = tensor.client.clone();
2110 let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {
2111 client.create_empty_handle()
2112 });
2113
2114 client
2115 .register(
2116 streams,
2117 OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())),
2118 UnfoldOps::<B>::new(desc),
2119 )
2120 .output()
2121 }
2122
2123 fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2124 binary_int_ops!(PowOps, B::int_powi);
2125
2126 let streams = OperationStreams::with_inputs([&lhs, &rhs]);
2127
2128 let client = lhs.client.clone();
2129 let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {
2130 client.create_empty_handle()
2131 });
2132
2133 client
2134 .register(
2135 streams,
2136 OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Powi(desc.clone())),
2137 PowOps::<B>::new(desc),
2138 )
2139 .output()
2140 }
2141
2142 fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
2143 scalar_int_ops!(PowiOps, B::int_powi_scalar);
2144
2145 let streams = OperationStreams::with_inputs([&lhs]);
2146
2147 let client = lhs.client.clone();
2148 let rhs = rhs.into();
2149 let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());
2150
2151 client
2152 .register(
2153 streams,
2154 OperationIr::NumericInt(
2155 desc.out.dtype,
2156 NumericOperationIr::PowiScalar(desc.clone()),
2157 ),
2158 PowiOps::<B>::new(desc),
2159 )
2160 .output()
2161 }
2162}