burn_fusion/ops/
int.rs

1use super::NoOp;
2use crate::{
3    Fusion, FusionBackend, binary_int_cmp_ops, binary_int_ops, get_client, reduce_int_ops,
4    scalar_int_cmp_ops, scalar_int_ops,
5    stream::{OperationStreams, StreamId, execution::Operation},
6    unary_int_ops,
7};
8use burn_ir::*;
9use burn_tensor::ops::unfold::calculate_unfold_shape;
10use burn_tensor::{
11    Device, Distribution, Element, IntDType, Shape, Slice, TensorData, TensorMetadata,
12    ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps},
13};
14use std::marker::PhantomData;
15
16impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
17    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
18        #[derive(new, Debug)]
19        struct EmptyOps<B: FusionBackend> {
20            desc: TensorIr,
21            device: Device<B>,
22        }
23
24        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
25            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
26                let output = B::int_empty(
27                    self.desc.shape.clone(),
28                    &self.device,
29                    self.desc.dtype.into(),
30                );
31                handles.register_int_tensor::<B>(&self.desc.id, output);
32            }
33        }
34
35        let client = get_client::<B>(&device.clone());
36        let out = client.tensor_uninitialized(shape, dtype.into());
37
38        let desc = out.to_ir_out();
39
40        client.register(
41            OperationStreams::default(),
42            OperationIr::BaseInt(BaseOperationIr::Empty(desc.clone())),
43            EmptyOps::<B>::new(desc, device.clone()),
44        );
45
46        out
47    }
48
49    async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
50        tensor.int_into_data::<B>().await
51    }
52
53    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
54        let stream = StreamId::current();
55        let client = get_client::<B>(&device.clone());
56        let dtype = data.dtype;
57        let tensor = B::int_from_data(data, device);
58        let shape = tensor.shape();
59
60        let handle = B::int_tensor_handle(tensor);
61        let out = client.register_tensor(handle, shape, stream, dtype);
62        let desc = out.to_ir_out();
63
64        client.register(
65            OperationStreams::default(),
66            OperationIr::Init(InitOperationIr { out: desc }),
67            NoOp::<B>::new(),
68        );
69
70        out
71    }
72
73    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
74        tensor.client.device().clone()
75    }
76
77    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
78        let device_original: &B::Device = tensor.client.device();
79        let device_target: B::Device = device.clone();
80
81        if device_original == &device_target {
82            return tensor;
83        }
84
85        let id = tensor.stream;
86        let client_target = get_client::<B>(&device_target);
87        let client_original = tensor.client.clone();
88
89        client_original
90            .clone()
91            .change_client_int::<B>(tensor.into_ir(), client_target, id)
92    }
93
94    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
95        if tensor.shape == shape {
96            return tensor;
97        }
98
99        #[derive(new, Debug)]
100        struct ReshapeDimsOps<B: FusionBackend> {
101            desc: UnaryOpIr,
102            _b: PhantomData<B>,
103        }
104
105        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {
106            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
107                let input = handles.get_int_tensor::<B>(&self.desc.input);
108                let output = B::int_reshape(input, self.desc.out.shape.clone());
109                handles.register_int_tensor::<B>(&self.desc.out.id, output);
110            }
111        }
112
113        let dtype = tensor.dtype;
114        let mut streams = OperationStreams::default();
115        streams.tensor(&tensor);
116        let out = tensor.client.tensor_uninitialized(shape, dtype);
117
118        let desc = UnaryOpIr {
119            input: tensor.into_ir(),
120            out: out.to_ir_out(),
121        };
122        out.client.register(
123            streams,
124            OperationIr::BaseInt(BaseOperationIr::Reshape(desc.clone())),
125            ReshapeDimsOps::<B>::new(desc),
126        );
127
128        out
129    }
130
131    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
132        #[derive(new, Debug)]
133        struct SliceOps<B: FusionBackend> {
134            desc: SliceOpIr,
135            _b: PhantomData<B>,
136        }
137
138        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {
139            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
140                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
141
142                let output = B::int_slice(tensor, self.desc.ranges.as_slice());
143
144                handles.register_int_tensor::<B>(&self.desc.out.id, output);
145            }
146        }
147
148        let mut streams = OperationStreams::default();
149        streams.tensor(&tensor);
150
151        let dtype = tensor.dtype;
152        let shape = tensor.shape.clone().slice(slices).unwrap();
153        let out = tensor.client.tensor_uninitialized(shape, dtype);
154
155        let desc = SliceOpIr {
156            tensor: tensor.into_ir(),
157            ranges: slices.to_vec(),
158            out: out.to_ir_out(),
159        };
160        out.client.register(
161            streams,
162            OperationIr::BaseInt(BaseOperationIr::Slice(desc.clone())),
163            SliceOps::<B>::new(desc),
164        );
165
166        out
167    }
168
169    fn int_slice_assign(
170        tensor: IntTensor<Self>,
171        ranges: &[burn_tensor::Slice],
172        value: IntTensor<Self>,
173    ) -> IntTensor<Self> {
174        #[derive(new, Debug)]
175        struct SliceAssignOps<B: FusionBackend> {
176            desc: SliceAssignOpIr,
177            _b: PhantomData<B>,
178        }
179
180        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {
181            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
182                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
183                let value = handles.get_int_tensor::<B>(&self.desc.value);
184
185                let output = B::int_slice_assign(tensor, self.desc.ranges.as_slice(), value);
186
187                handles.register_int_tensor::<B>(&self.desc.out.id, output);
188            }
189        }
190
191        let mut streams = OperationStreams::default();
192        streams.tensor(&tensor);
193        streams.tensor(&value);
194
195        let dtype = tensor.dtype;
196        let shape = tensor.shape.clone();
197        let out = tensor.client.tensor_uninitialized(shape, dtype);
198        let desc = SliceAssignOpIr {
199            tensor: tensor.into_ir(),
200            ranges: ranges.to_vec(),
201            value: value.into_ir(),
202            out: out.to_ir_out(),
203        };
204        out.client.register(
205            streams,
206            OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc.clone())),
207            SliceAssignOps::<B>::new(desc),
208        );
209
210        out
211    }
212
213    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
214        binary_int_ops!(MatmulOps, B::int_matmul);
215
216        let mut streams = OperationStreams::default();
217        streams.tensor(&lhs);
218        streams.tensor(&rhs);
219        let dtype = lhs.dtype;
220        let shape = Shape::matmul(&lhs.shape, &rhs.shape).unwrap();
221
222        let out = lhs.client.tensor_uninitialized(shape, dtype);
223        let desc = BinaryOpIr {
224            lhs: lhs.into_ir(),
225            rhs: rhs.into_ir(),
226            out: out.to_ir_out(),
227        };
228
229        out.client.register(
230            streams,
231            OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),
232            MatmulOps::<B>::new(desc),
233        );
234
235        out
236    }
237
238    fn int_mask_where(
239        tensor: IntTensor<Self>,
240        mask: BoolTensor<Self>,
241        value: IntTensor<Self>,
242    ) -> IntTensor<Self> {
243        #[derive(new, Debug)]
244        struct MaskWhereOps<B: FusionBackend> {
245            desc: MaskWhereOpIr,
246            _b: PhantomData<B>,
247        }
248
249        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {
250            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
251                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
252                let value = handles.get_int_tensor::<B>(&self.desc.value);
253                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
254
255                let output = B::int_mask_where(tensor, mask, value);
256
257                handles.register_int_tensor::<B>(&self.desc.out.id, output);
258            }
259        }
260
261        let mut streams = OperationStreams::default();
262        streams.tensor(&tensor);
263        streams.tensor(&value);
264        streams.tensor(&mask);
265
266        let dtype = tensor.dtype;
267        let out = tensor
268            .client
269            .tensor_uninitialized(tensor.shape.broadcast(&mask.shape).unwrap(), dtype);
270
271        let desc = MaskWhereOpIr {
272            tensor: tensor.into_ir(),
273            value: value.into_ir(),
274            mask: mask.into_ir(),
275            out: out.to_ir_out(),
276        };
277        out.client.register(
278            streams,
279            OperationIr::NumericInt(dtype, NumericOperationIr::MaskWhere(desc.clone())),
280            MaskWhereOps::<B>::new(desc),
281        );
282
283        out
284    }
285
286    fn int_mask_fill(
287        tensor: IntTensor<Self>,
288        mask: BoolTensor<Self>,
289        value: IntElem<Self>,
290    ) -> IntTensor<Self> {
291        #[derive(new, Debug)]
292        struct MaskFillOps<B: FusionBackend> {
293            desc: MaskFillOpIr,
294            _b: PhantomData<B>,
295        }
296
297        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {
298            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
299                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
300                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);
301
302                let output = B::int_mask_fill(tensor, mask, self.desc.value.elem());
303
304                handles.register_int_tensor::<B>(&self.desc.out.id, output);
305            }
306        }
307
308        let mut streams = OperationStreams::default();
309        streams.tensor(&tensor);
310        streams.tensor(&mask);
311
312        let dtype = tensor.dtype;
313        let shape = tensor.shape.clone();
314        let out = tensor.client.tensor_uninitialized(shape, dtype);
315        let desc = MaskFillOpIr {
316            tensor: tensor.into_ir(),
317            value: ScalarIr::with_dtype(value, &dtype),
318            mask: mask.into_ir(),
319            out: out.to_ir_out(),
320        };
321        out.client.register(
322            streams,
323            OperationIr::NumericInt(dtype, NumericOperationIr::MaskFill(desc.clone())),
324            MaskFillOps::<B>::new(desc),
325        );
326
327        out
328    }
329
330    fn int_gather(
331        dim: usize,
332        tensor: IntTensor<Self>,
333        indices: IntTensor<Self>,
334    ) -> IntTensor<Self> {
335        #[derive(new, Debug)]
336        struct GatherOps<B: FusionBackend> {
337            desc: GatherOpIr,
338            _b: PhantomData<B>,
339        }
340
341        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {
342            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
343                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
344                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
345
346                let output = B::int_gather(self.desc.dim, tensor, indices);
347                handles.register_int_tensor::<B>(&self.desc.out.id, output);
348            }
349        }
350
351        let mut streams = OperationStreams::default();
352        streams.tensor(&tensor);
353        streams.tensor(&indices);
354
355        let dtype = tensor.dtype;
356        let shape = indices.shape.clone();
357        let out = tensor.client.tensor_uninitialized(shape, dtype);
358        let desc = GatherOpIr {
359            tensor: tensor.into_ir(),
360            dim,
361            indices: indices.into_ir(),
362            out: out.to_ir_out(),
363        };
364        out.client.register(
365            streams,
366            OperationIr::NumericInt(dtype, NumericOperationIr::Gather(desc.clone())),
367            GatherOps::<B>::new(desc),
368        );
369
370        out
371    }
372
373    fn int_scatter(
374        dim: usize,
375        tensor: IntTensor<Self>,
376        indices: IntTensor<Self>,
377        value: IntTensor<Self>,
378    ) -> IntTensor<Self> {
379        #[derive(new, Debug)]
380        struct ScatterOps<B: FusionBackend> {
381            desc: ScatterOpIr,
382            _b: PhantomData<B>,
383        }
384
385        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {
386            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
387                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
388                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
389                let value = handles.get_int_tensor::<B>(&self.desc.value);
390
391                let output = B::int_scatter(self.desc.dim, tensor, indices, value);
392
393                handles.register_int_tensor::<B>(&self.desc.out.id, output);
394            }
395        }
396
397        let mut streams = OperationStreams::default();
398        streams.tensor(&tensor);
399        streams.tensor(&indices);
400        streams.tensor(&value);
401
402        let dtype = tensor.dtype;
403        let shape = tensor.shape.clone();
404        let out = tensor.client.tensor_uninitialized(shape, dtype);
405        let desc = ScatterOpIr {
406            tensor: tensor.into_ir(),
407            dim,
408            indices: indices.into_ir(),
409            value: value.into_ir(),
410            out: out.to_ir_out(),
411        };
412        out.client.register(
413            streams,
414            OperationIr::NumericInt(dtype, NumericOperationIr::Scatter(desc.clone())),
415            ScatterOps::<B>::new(desc),
416        );
417
418        out
419    }
420
421    fn int_select(
422        tensor: IntTensor<Self>,
423        dim: usize,
424        indices: IntTensor<Self>,
425    ) -> IntTensor<Self> {
426        #[derive(new, Debug)]
427        struct SelectOps<B: FusionBackend> {
428            desc: SelectOpIr,
429            _b: PhantomData<B>,
430        }
431
432        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {
433            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
434                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
435                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
436
437                let output = B::int_select(tensor, self.desc.dim, indices);
438
439                handles.register_int_tensor::<B>(&self.desc.out.id, output);
440            }
441        }
442
443        let mut streams = OperationStreams::default();
444        streams.tensor(&tensor);
445        streams.tensor(&indices);
446
447        let dtype = tensor.dtype;
448        let mut shape = tensor.shape.clone();
449        shape[dim] = indices.shape[0];
450        let out = tensor.client.tensor_uninitialized(shape, dtype);
451        let desc = SelectOpIr {
452            tensor: tensor.into_ir(),
453            dim,
454            indices: indices.into_ir(),
455            out: out.to_ir_out(),
456        };
457        out.client.register(
458            streams,
459            OperationIr::NumericInt(dtype, NumericOperationIr::Select(desc.clone())),
460            SelectOps::<B>::new(desc),
461        );
462
463        out
464    }
465
466    fn int_select_assign(
467        tensor: IntTensor<Self>,
468        dim: usize,
469        indices: IntTensor<Self>,
470        value: IntTensor<Self>,
471    ) -> IntTensor<Self> {
472        #[derive(new, Debug)]
473        struct SelectAssignOps<B: FusionBackend> {
474            desc: SelectAssignOpIr,
475            _b: PhantomData<B>,
476        }
477
478        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {
479            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
480                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
481                let indices = handles.get_int_tensor::<B>(&self.desc.indices);
482                let value = handles.get_int_tensor::<B>(&self.desc.value);
483
484                let output = B::int_select_assign(tensor, self.desc.dim, indices, value);
485
486                handles.register_int_tensor::<B>(&self.desc.out.id, output);
487            }
488        }
489
490        let mut streams = OperationStreams::default();
491        streams.tensor(&tensor);
492        streams.tensor(&indices);
493        streams.tensor(&value);
494
495        let dtype = tensor.dtype;
496        let shape = tensor.shape.clone();
497        let out = tensor.client.tensor_uninitialized(shape, dtype);
498        let desc = SelectAssignOpIr {
499            tensor: tensor.into_ir(),
500            dim,
501            indices: indices.into_ir(),
502            value: value.into_ir(),
503            out: out.to_ir_out(),
504        };
505        out.client.register(
506            streams,
507            OperationIr::NumericInt(dtype, NumericOperationIr::SelectAssign(desc.clone())),
508            SelectAssignOps::<B>::new(desc),
509        );
510
511        out
512    }
513
514    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {
515        #[derive(new, Debug)]
516        struct CatOps<B: FusionBackend> {
517            desc: CatOpIr,
518            _b: PhantomData<B>,
519        }
520
521        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {
522            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
523                let tensors = self
524                    .desc
525                    .tensors
526                    .iter()
527                    .map(|tensor| handles.get_int_tensor::<B>(tensor))
528                    .collect();
529
530                let output = B::int_cat(tensors, self.desc.dim);
531
532                handles.register_int_tensor::<B>(&self.desc.out.id, output);
533            }
534        }
535
536        let tensor_first = tensors.first().unwrap();
537        let client = tensor_first.client.clone();
538
539        let mut streams = OperationStreams::default();
540        tensors.iter().for_each(|tensor| streams.tensor(tensor));
541
542        // Calculate the output shape
543        let shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap();
544
545        let dtype = tensor_first.dtype;
546        let out = client.tensor_uninitialized(shape, dtype);
547
548        let desc = CatOpIr {
549            tensors: tensors.into_iter().map(|t| t.into_ir()).collect(),
550            dim,
551            out: out.to_ir_out(),
552        };
553        client.register(
554            streams,
555            OperationIr::BaseInt(BaseOperationIr::Cat(desc.clone())),
556            CatOps::<B>::new(desc),
557        );
558
559        out
560    }
561
562    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
563        binary_int_cmp_ops!(EqualOps, B::int_equal);
564
565        let mut streams = OperationStreams::default();
566        streams.tensor(&lhs);
567        streams.tensor(&rhs);
568        let out = lhs.client.tensor_uninitialized(
569            lhs.shape.broadcast(&rhs.shape).unwrap(),
570            B::BoolElem::dtype(),
571        );
572
573        let desc = BinaryOpIr {
574            lhs: lhs.into_ir(),
575            rhs: rhs.into_ir(),
576            out: out.to_ir_out(),
577        };
578        out.client.register(
579            streams,
580            OperationIr::BaseInt(BaseOperationIr::Equal(desc.clone())),
581            EqualOps::<B>::new(desc),
582        );
583
584        out
585    }
586
587    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
588        scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem);
589
590        let mut streams = OperationStreams::default();
591        streams.tensor(&lhs);
592        let out = lhs
593            .client
594            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
595
596        let dtype = lhs.dtype;
597        let desc = ScalarOpIr {
598            lhs: lhs.into_ir(),
599            rhs: ScalarIr::with_dtype(rhs, &dtype),
600            out: out.to_ir_out(),
601        };
602        out.client.register(
603            streams,
604            OperationIr::NumericInt(dtype, NumericOperationIr::EqualElem(desc.clone())),
605            EqualElemOps::<B>::new(desc),
606        );
607
608        out
609    }
610
611    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
612        binary_int_cmp_ops!(GreaterOps, B::int_greater);
613
614        let mut streams = OperationStreams::default();
615        streams.tensor(&lhs);
616        streams.tensor(&rhs);
617        let out = lhs.client.tensor_uninitialized(
618            lhs.shape.broadcast(&rhs.shape).unwrap(),
619            B::BoolElem::dtype(),
620        );
621
622        let dtype = lhs.dtype;
623        let desc = BinaryOpIr {
624            lhs: lhs.into_ir(),
625            rhs: rhs.into_ir(),
626            out: out.to_ir_out(),
627        };
628        out.client.register(
629            streams,
630            OperationIr::NumericInt(dtype, NumericOperationIr::Greater(desc.clone())),
631            GreaterOps::<B>::new(desc),
632        );
633
634        out
635    }
636
637    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
638        scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem);
639
640        let mut streams = OperationStreams::default();
641        streams.tensor(&lhs);
642        let out = lhs
643            .client
644            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
645
646        let dtype = lhs.dtype;
647        let desc = ScalarOpIr {
648            lhs: lhs.into_ir(),
649            rhs: ScalarIr::with_dtype(rhs, &dtype),
650            out: out.to_ir_out(),
651        };
652        out.client.register(
653            streams,
654            OperationIr::NumericInt(dtype, NumericOperationIr::GreaterElem(desc.clone())),
655            GreaterElemOps::<B>::new(desc),
656        );
657
658        out
659    }
660
661    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
662        binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal);
663
664        let mut streams = OperationStreams::default();
665        streams.tensor(&lhs);
666        streams.tensor(&rhs);
667        let out = lhs.client.tensor_uninitialized(
668            lhs.shape.broadcast(&rhs.shape).unwrap(),
669            B::BoolElem::dtype(),
670        );
671
672        let dtype = lhs.dtype;
673        let desc = BinaryOpIr {
674            lhs: lhs.into_ir(),
675            rhs: rhs.into_ir(),
676            out: out.to_ir_out(),
677        };
678        out.client.register(
679            streams,
680            OperationIr::NumericInt(dtype, NumericOperationIr::GreaterEqual(desc.clone())),
681            GreaterEqualOps::<B>::new(desc),
682        );
683
684        out
685    }
686
687    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
688        scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem);
689
690        let mut streams = OperationStreams::default();
691        streams.tensor(&lhs);
692        let out = lhs
693            .client
694            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
695
696        let dtype = lhs.dtype;
697        let desc = ScalarOpIr {
698            lhs: lhs.into_ir(),
699            rhs: ScalarIr::with_dtype(rhs, &dtype),
700            out: out.to_ir_out(),
701        };
702        out.client.register(
703            streams,
704            OperationIr::NumericInt(dtype, NumericOperationIr::GreaterEqualElem(desc.clone())),
705            GreaterEqualElemOps::<B>::new(desc),
706        );
707
708        out
709    }
710
711    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
712        binary_int_cmp_ops!(LowerOps, B::int_lower);
713
714        let mut streams = OperationStreams::default();
715        streams.tensor(&lhs);
716        streams.tensor(&rhs);
717        let out = lhs.client.tensor_uninitialized(
718            lhs.shape.broadcast(&rhs.shape).unwrap(),
719            B::BoolElem::dtype(),
720        );
721
722        let dtype = lhs.dtype;
723        let desc = BinaryOpIr {
724            lhs: lhs.into_ir(),
725            rhs: rhs.into_ir(),
726            out: out.to_ir_out(),
727        };
728        out.client.register(
729            streams,
730            OperationIr::NumericInt(dtype, NumericOperationIr::Lower(desc.clone())),
731            LowerOps::<B>::new(desc),
732        );
733
734        out
735    }
736
737    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
738        scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem);
739
740        let mut streams = OperationStreams::default();
741        streams.tensor(&lhs);
742        let out = lhs
743            .client
744            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
745
746        let dtype = lhs.dtype;
747        let desc = ScalarOpIr {
748            lhs: lhs.into_ir(),
749            rhs: ScalarIr::with_dtype(rhs, &dtype),
750            out: out.to_ir_out(),
751        };
752        out.client.register(
753            streams,
754            OperationIr::NumericInt(dtype, NumericOperationIr::LowerElem(desc.clone())),
755            LowerElemOps::<B>::new(desc),
756        );
757
758        out
759    }
760
761    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
762        binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal);
763
764        let mut streams = OperationStreams::default();
765        streams.tensor(&lhs);
766        streams.tensor(&rhs);
767        let out = lhs.client.tensor_uninitialized(
768            lhs.shape.broadcast(&rhs.shape).unwrap(),
769            B::BoolElem::dtype(),
770        );
771
772        let dtype = lhs.dtype;
773        let desc = BinaryOpIr {
774            lhs: lhs.into_ir(),
775            rhs: rhs.into_ir(),
776            out: out.to_ir_out(),
777        };
778        out.client.register(
779            streams,
780            OperationIr::NumericInt(dtype, NumericOperationIr::LowerEqual(desc.clone())),
781            LowerEqualOps::<B>::new(desc),
782        );
783
784        out
785    }
786
787    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
788        scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem);
789
790        let mut streams = OperationStreams::default();
791        streams.tensor(&lhs);
792        let out = lhs
793            .client
794            .tensor_uninitialized(lhs.shape.clone(), B::BoolElem::dtype());
795
796        let dtype = lhs.dtype;
797        let desc = ScalarOpIr {
798            lhs: lhs.into_ir(),
799            rhs: ScalarIr::with_dtype(rhs, &dtype),
800            out: out.to_ir_out(),
801        };
802        out.client.register(
803            streams,
804            OperationIr::NumericInt(dtype, NumericOperationIr::LowerEqualElem(desc.clone())),
805            LowerEqualElemOps::<B>::new(desc),
806        );
807
808        out
809    }
810
811    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
812        binary_int_ops!(AddOps, B::int_add);
813
814        let dtype = lhs.dtype;
815        let mut streams = OperationStreams::default();
816        streams.tensor(&lhs);
817        streams.tensor(&rhs);
818        let out = lhs
819            .client
820            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
821
822        let desc = BinaryOpIr {
823            lhs: lhs.into_ir(),
824            rhs: rhs.into_ir(),
825            out: out.to_ir_out(),
826        };
827        out.client.register(
828            streams,
829            OperationIr::NumericInt(dtype, NumericOperationIr::Add(desc.clone())),
830            AddOps::<B>::new(desc),
831        );
832
833        out
834    }
835
836    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
837        scalar_int_ops!(AddOps, B::int_add_scalar);
838
839        let dtype = lhs.dtype;
840        let mut streams = OperationStreams::default();
841        streams.tensor(&lhs);
842        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
843
844        let desc = ScalarOpIr {
845            lhs: lhs.into_ir(),
846            rhs: ScalarIr::with_dtype(rhs, &dtype),
847            out: out.to_ir_out(),
848        };
849        out.client.register(
850            streams,
851            OperationIr::NumericInt(dtype, NumericOperationIr::AddScalar(desc.clone())),
852            AddOps::<B>::new(desc),
853        );
854
855        out
856    }
857
858    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
859        binary_int_ops!(SubOps, B::int_sub);
860
861        let dtype = lhs.dtype;
862        let mut streams = OperationStreams::default();
863        streams.tensor(&lhs);
864        streams.tensor(&rhs);
865        let out = lhs
866            .client
867            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
868
869        let desc = BinaryOpIr {
870            lhs: lhs.into_ir(),
871            rhs: rhs.into_ir(),
872            out: out.to_ir_out(),
873        };
874        out.client.register(
875            streams,
876            OperationIr::NumericInt(dtype, NumericOperationIr::Sub(desc.clone())),
877            SubOps::<B>::new(desc),
878        );
879
880        out
881    }
882
883    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
884        scalar_int_ops!(SubOps, B::int_sub_scalar);
885
886        let dtype = lhs.dtype;
887        let mut streams = OperationStreams::default();
888        streams.tensor(&lhs);
889        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
890
891        let desc = ScalarOpIr {
892            lhs: lhs.into_ir(),
893            rhs: ScalarIr::with_dtype(rhs, &dtype),
894            out: out.to_ir_out(),
895        };
896        out.client.register(
897            streams,
898            OperationIr::NumericInt(dtype, NumericOperationIr::SubScalar(desc.clone())),
899            SubOps::<B>::new(desc),
900        );
901
902        out
903    }
904
905    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
906        binary_int_ops!(MulOps, B::int_mul);
907
908        let dtype = lhs.dtype;
909        let mut streams = OperationStreams::default();
910        streams.tensor(&lhs);
911        streams.tensor(&rhs);
912        let out = lhs
913            .client
914            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
915
916        let desc = BinaryOpIr {
917            lhs: lhs.into_ir(),
918            rhs: rhs.into_ir(),
919            out: out.to_ir_out(),
920        };
921        out.client.register(
922            streams,
923            OperationIr::NumericInt(dtype, NumericOperationIr::Mul(desc.clone())),
924            MulOps::<B>::new(desc),
925        );
926
927        out
928    }
929
930    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
931        scalar_int_ops!(MulOps, B::int_mul_scalar);
932
933        let dtype = lhs.dtype;
934        let mut streams = OperationStreams::default();
935        streams.tensor(&lhs);
936        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
937
938        let desc = ScalarOpIr {
939            lhs: lhs.into_ir(),
940            rhs: ScalarIr::with_dtype(rhs, &dtype),
941            out: out.to_ir_out(),
942        };
943        out.client.register(
944            streams,
945            OperationIr::NumericInt(dtype, NumericOperationIr::MulScalar(desc.clone())),
946            MulOps::<B>::new(desc),
947        );
948
949        out
950    }
951
952    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
953        binary_int_ops!(DivOps, B::int_div);
954
955        let dtype = lhs.dtype;
956        let mut streams = OperationStreams::default();
957        streams.tensor(&lhs);
958        streams.tensor(&rhs);
959        let out = lhs
960            .client
961            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
962
963        let desc = BinaryOpIr {
964            lhs: lhs.into_ir(),
965            rhs: rhs.into_ir(),
966            out: out.to_ir_out(),
967        };
968        out.client.register(
969            streams,
970            OperationIr::NumericInt(dtype, NumericOperationIr::Div(desc.clone())),
971            DivOps::<B>::new(desc),
972        );
973
974        out
975    }
976
977    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
978        scalar_int_ops!(DivOps, B::int_div_scalar);
979
980        let dtype = lhs.dtype;
981        let mut streams = OperationStreams::default();
982        streams.tensor(&lhs);
983        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
984
985        let desc = ScalarOpIr {
986            lhs: lhs.into_ir(),
987            rhs: ScalarIr::with_dtype(rhs, &dtype),
988            out: out.to_ir_out(),
989        };
990        out.client.register(
991            streams,
992            OperationIr::NumericInt(dtype, NumericOperationIr::DivScalar(desc.clone())),
993            DivOps::<B>::new(desc),
994        );
995
996        out
997    }
998
999    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
1000        binary_int_ops!(ModOps, B::int_remainder);
1001
1002        let dtype = lhs.dtype;
1003        let mut streams = OperationStreams::default();
1004        streams.tensor(&lhs);
1005        streams.tensor(&rhs);
1006        let out = lhs
1007            .client
1008            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
1009
1010        let desc = BinaryOpIr {
1011            lhs: lhs.into_ir(),
1012            rhs: rhs.into_ir(),
1013            out: out.to_ir_out(),
1014        };
1015        out.client.register(
1016            streams,
1017            OperationIr::NumericInt(dtype, NumericOperationIr::Rem(desc.clone())),
1018            ModOps::<B>::new(desc),
1019        );
1020
1021        out
1022    }
1023
1024    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
1025        scalar_int_ops!(ModOps, B::int_remainder_scalar);
1026
1027        let dtype = lhs.dtype;
1028        let mut streams = OperationStreams::default();
1029        streams.tensor(&lhs);
1030        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
1031
1032        let desc = ScalarOpIr {
1033            lhs: lhs.into_ir(),
1034            rhs: ScalarIr::with_dtype(rhs, &dtype),
1035            out: out.to_ir_out(),
1036        };
1037        out.client.register(
1038            streams,
1039            OperationIr::NumericInt(dtype, NumericOperationIr::RemScalar(desc.clone())),
1040            ModOps::<B>::new(desc),
1041        );
1042
1043        out
1044    }
1045
1046    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1047        #[derive(new, Debug)]
1048        struct ZerosOps<B: FusionBackend> {
1049            desc: TensorIr,
1050            device: Device<B>,
1051        }
1052
1053        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {
1054            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1055                let shape = self.desc.shape.clone();
1056                let output = B::int_zeros(shape, &self.device, self.desc.dtype.into());
1057                handles.register_int_tensor::<B>(&self.desc.id, output);
1058            }
1059        }
1060
1061        let dtype = dtype.into();
1062        let client = get_client::<B>(&device.clone());
1063        let out = client.tensor_uninitialized(shape, dtype);
1064        let desc = out.to_ir_out();
1065        client.register(
1066            OperationStreams::default(),
1067            OperationIr::NumericInt(dtype, NumericOperationIr::Zeros(desc.clone())),
1068            ZerosOps::<B>::new(desc, device.clone()),
1069        );
1070
1071        out
1072    }
1073
1074    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
1075        #[derive(new, Debug)]
1076        struct OnesOps<B: FusionBackend> {
1077            desc: TensorIr,
1078            device: Device<B>,
1079        }
1080
1081        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {
1082            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1083                let shape = self.desc.shape.clone();
1084                let output = B::int_ones(shape, &self.device, self.desc.dtype.into());
1085                handles.register_int_tensor::<B>(&self.desc.id, output);
1086            }
1087        }
1088
1089        let dtype = dtype.into();
1090        let client = get_client::<B>(&device.clone());
1091        let out = client.tensor_uninitialized(shape, dtype);
1092
1093        let desc = out.to_ir_out();
1094        client.register(
1095            OperationStreams::default(),
1096            OperationIr::NumericInt(dtype, NumericOperationIr::Ones(desc.clone())),
1097            OnesOps::<B>::new(desc, device.clone()),
1098        );
1099
1100        out
1101    }
1102
1103    fn int_full(
1104        shape: Shape,
1105        fill_value: IntElem<Self>,
1106        device: &Device<Self>,
1107        dtype: IntDType,
1108    ) -> IntTensor<Self> {
1109        #[derive(new, Debug)]
1110        struct FullOps<B: FusionBackend> {
1111            out: TensorIr,
1112            elem: ScalarIr,
1113            device: Device<B>,
1114        }
1115
1116        impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {
1117            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1118                let shape = self.out.shape.clone();
1119                let output =
1120                    B::int_full(shape, self.elem.elem(), &self.device, self.out.dtype.into());
1121                handles.register_int_tensor::<B>(&self.out.id, output);
1122            }
1123        }
1124
1125        let dtype = dtype.into();
1126        let client = get_client::<B>(&device.clone());
1127        let out = client.tensor_uninitialized(shape, dtype);
1128
1129        let desc = (out.to_ir_out(), ScalarIr::with_dtype(fill_value, &dtype));
1130        client.register(
1131            OperationStreams::default(),
1132            OperationIr::NumericInt(dtype, NumericOperationIr::Full(desc.clone())),
1133            FullOps::<B>::new(desc.0, desc.1, device.clone()),
1134        );
1135
1136        out
1137    }
1138
1139    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
1140        unary_int_ops!(SumOps, B::int_sum, reduce);
1141
1142        let dtype = tensor.dtype;
1143        let mut streams = OperationStreams::default();
1144        streams.tensor(&tensor);
1145        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1146
1147        let desc = UnaryOpIr {
1148            input: tensor.into_ir(),
1149            out: out.to_ir_out(),
1150        };
1151        out.client.register(
1152            streams,
1153            OperationIr::NumericInt(dtype, NumericOperationIr::Sum(desc.clone())),
1154            SumOps::<B>::new(desc),
1155        );
1156
1157        out
1158    }
1159
1160    fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {
1161        reduce_int_ops!(SumDimOps, B::int_sum_dim);
1162
1163        let dtype = tensor.dtype;
1164        let mut streams = OperationStreams::default();
1165        streams.tensor(&tensor);
1166        let mut shape = tensor.shape.clone();
1167        shape[axis] = 1;
1168        let out = tensor.client.tensor_uninitialized(shape, dtype);
1169
1170        let desc = ReduceDimOpIr {
1171            out: out.to_ir_out(),
1172            input: tensor.into_ir(),
1173            axis,
1174        };
1175        out.client.register(
1176            streams,
1177            OperationIr::NumericInt(dtype, NumericOperationIr::SumDim(desc.clone())),
1178            SumDimOps::<B>::new(desc),
1179        );
1180
1181        out
1182    }
1183
1184    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
1185        unary_int_ops!(ProdOps, B::int_prod, reduce);
1186
1187        let dtype = tensor.dtype;
1188        let mut streams = OperationStreams::default();
1189        streams.tensor(&tensor);
1190        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1191
1192        let desc = UnaryOpIr {
1193            input: tensor.into_ir(),
1194            out: out.to_ir_out(),
1195        };
1196        out.client.register(
1197            streams,
1198            OperationIr::NumericInt(dtype, NumericOperationIr::Prod(desc.clone())),
1199            ProdOps::<B>::new(desc),
1200        );
1201
1202        out
1203    }
1204
1205    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1206        reduce_int_ops!(ProdDimOps, B::int_prod_dim);
1207
1208        let dtype = tensor.dtype;
1209        let mut streams = OperationStreams::default();
1210        streams.tensor(&tensor);
1211        let mut shape = tensor.shape.clone();
1212        shape[dim] = 1;
1213        let out = tensor.client.tensor_uninitialized(shape, dtype);
1214
1215        let desc = ReduceDimOpIr {
1216            input: tensor.into_ir(),
1217            axis: dim,
1218            out: out.to_ir_out(),
1219        };
1220        out.client.register(
1221            streams,
1222            OperationIr::NumericInt(dtype, NumericOperationIr::ProdDim(desc.clone())),
1223            ProdDimOps::<B>::new(desc),
1224        );
1225
1226        out
1227    }
1228
1229    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {
1230        unary_int_ops!(MeanOps, B::int_mean, reduce);
1231
1232        let dtype = tensor.dtype;
1233        let mut streams = OperationStreams::default();
1234        streams.tensor(&tensor);
1235        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1236
1237        let desc = UnaryOpIr {
1238            input: tensor.into_ir(),
1239            out: out.to_ir_out(),
1240        };
1241        out.client.register(
1242            streams,
1243            OperationIr::NumericInt(dtype, NumericOperationIr::Mean(desc.clone())),
1244            MeanOps::<B>::new(desc),
1245        );
1246
1247        out
1248    }
1249
1250    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1251        reduce_int_ops!(MeanDimOps, B::int_mean_dim);
1252
1253        let dtype = tensor.dtype;
1254        let mut streams = OperationStreams::default();
1255        streams.tensor(&tensor);
1256        let mut shape = tensor.shape.clone();
1257        shape[dim] = 1;
1258        let out = tensor.client.tensor_uninitialized(shape, dtype);
1259
1260        let desc = ReduceDimOpIr {
1261            input: tensor.into_ir(),
1262            axis: dim,
1263            out: out.to_ir_out(),
1264        };
1265        out.client.register(
1266            streams,
1267            OperationIr::NumericInt(dtype, NumericOperationIr::MeanDim(desc.clone())),
1268            MeanDimOps::<B>::new(desc),
1269        );
1270
1271        out
1272    }
1273
1274    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1275        #[derive(new, Debug)]
1276        struct CumsumOps<B: FusionBackend> {
1277            desc: DimOpIr,
1278            _b: PhantomData<B>,
1279        }
1280
1281        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {
1282            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1283                let input = handles.get_int_tensor::<B>(&self.desc.input);
1284                let output = B::int_cumsum(input, self.desc.axis);
1285                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1286            }
1287        }
1288
1289        let dtype = tensor.dtype;
1290        let mut streams = OperationStreams::default();
1291        streams.tensor(&tensor);
1292        let shape = tensor.shape.clone();
1293        let out = tensor.client.tensor_uninitialized(shape, dtype);
1294
1295        let desc = DimOpIr {
1296            out: out.to_ir_out(),
1297            input: tensor.into_ir(),
1298            axis: dim,
1299        };
1300        out.client.register(
1301            streams,
1302            OperationIr::BaseInt(BaseOperationIr::CumSum(desc.clone())),
1303            CumsumOps::<B>::new(desc),
1304        );
1305
1306        out
1307    }
1308
1309    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1310        #[derive(new, Debug)]
1311        struct CumprodOps<B: FusionBackend> {
1312            desc: DimOpIr,
1313            _b: PhantomData<B>,
1314        }
1315
1316        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {
1317            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1318                let input = handles.get_int_tensor::<B>(&self.desc.input);
1319                let output = B::int_cumprod(input, self.desc.axis);
1320                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1321            }
1322        }
1323
1324        let dtype = tensor.dtype;
1325        let mut streams = OperationStreams::default();
1326        streams.tensor(&tensor);
1327        let shape = tensor.shape.clone();
1328        let out = tensor.client.tensor_uninitialized(shape, dtype);
1329
1330        let desc = DimOpIr {
1331            out: out.to_ir_out(),
1332            input: tensor.into_ir(),
1333            axis: dim,
1334        };
1335        out.client.register(
1336            streams,
1337            OperationIr::BaseInt(BaseOperationIr::CumProd(desc.clone())),
1338            CumprodOps::<B>::new(desc),
1339        );
1340
1341        out
1342    }
1343
1344    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1345        #[derive(new, Debug)]
1346        struct CumminOps<B: FusionBackend> {
1347            desc: DimOpIr,
1348            _b: PhantomData<B>,
1349        }
1350
1351        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
1352            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1353                let input = handles.get_int_tensor::<B>(&self.desc.input);
1354                let output = B::int_cummin(input, self.desc.axis);
1355                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1356            }
1357        }
1358
1359        let dtype = tensor.dtype;
1360        let mut streams = OperationStreams::default();
1361        streams.tensor(&tensor);
1362        let shape = tensor.shape.clone();
1363        let out = tensor.client.tensor_uninitialized(shape, dtype);
1364
1365        let desc = DimOpIr {
1366            out: out.to_ir_out(),
1367            input: tensor.into_ir(),
1368            axis: dim,
1369        };
1370        out.client.register(
1371            streams,
1372            OperationIr::BaseInt(BaseOperationIr::CumMin(desc.clone())),
1373            CumminOps::<B>::new(desc),
1374        );
1375
1376        out
1377    }
1378
1379    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1380        #[derive(new, Debug)]
1381        struct CummaxOps<B: FusionBackend> {
1382            desc: DimOpIr,
1383            _b: PhantomData<B>,
1384        }
1385
1386        impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {
1387            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1388                let input = handles.get_int_tensor::<B>(&self.desc.input);
1389                let output = B::int_cummax(input, self.desc.axis);
1390                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1391            }
1392        }
1393
1394        let dtype = tensor.dtype;
1395        let mut streams = OperationStreams::default();
1396        streams.tensor(&tensor);
1397        let shape = tensor.shape.clone();
1398        let out = tensor.client.tensor_uninitialized(shape, dtype);
1399
1400        let desc = DimOpIr {
1401            out: out.to_ir_out(),
1402            input: tensor.into_ir(),
1403            axis: dim,
1404        };
1405        out.client.register(
1406            streams,
1407            OperationIr::BaseInt(BaseOperationIr::CumMax(desc.clone())),
1408            CummaxOps::<B>::new(desc),
1409        );
1410
1411        out
1412    }
1413
1414    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1415        reduce_int_ops!(ArgMaxOps, B::int_argmax);
1416
1417        let dtype = tensor.dtype;
1418        let mut streams = OperationStreams::default();
1419        streams.tensor(&tensor);
1420        let mut shape = tensor.shape.clone();
1421        shape[dim] = 1;
1422        let out = tensor.client.tensor_uninitialized(shape, dtype);
1423
1424        let desc = ReduceDimOpIr {
1425            input: tensor.into_ir(),
1426            axis: dim,
1427            out: out.to_ir_out(),
1428        };
1429        out.client.register(
1430            streams,
1431            OperationIr::NumericInt(dtype, NumericOperationIr::ArgMax(desc.clone())),
1432            ArgMaxOps::<B>::new(desc),
1433        );
1434
1435        out
1436    }
1437
1438    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1439        reduce_int_ops!(ArgMinOps, B::int_argmin);
1440
1441        let dtype = tensor.dtype;
1442        let mut streams = OperationStreams::default();
1443        streams.tensor(&tensor);
1444        let mut shape = tensor.shape.clone();
1445        shape[dim] = 1;
1446        let out = tensor.client.tensor_uninitialized(shape, dtype);
1447
1448        let desc = ReduceDimOpIr {
1449            input: tensor.into_ir(),
1450            axis: dim,
1451            out: out.to_ir_out(),
1452        };
1453        out.client.register(
1454            streams,
1455            OperationIr::NumericInt(dtype, NumericOperationIr::ArgMin(desc.clone())),
1456            ArgMinOps::<B>::new(desc),
1457        );
1458
1459        out
1460    }
1461
1462    fn int_clamp(
1463        tensor: IntTensor<Self>,
1464        min: IntElem<Self>,
1465        max: IntElem<Self>,
1466    ) -> IntTensor<Self> {
1467        #[derive(new, Debug)]
1468        struct ClampOps<B: FusionBackend> {
1469            desc: ClampOpIr,
1470            _b: PhantomData<B>,
1471        }
1472
1473        impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {
1474            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1475                let input = handles.get_int_tensor::<B>(&self.desc.tensor);
1476                let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem());
1477
1478                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1479            }
1480        }
1481
1482        let dtype = tensor.dtype;
1483        let mut streams = OperationStreams::default();
1484        streams.tensor(&tensor);
1485        let out = tensor
1486            .client
1487            .tensor_uninitialized(tensor.shape.clone(), dtype);
1488        let desc = ClampOpIr {
1489            tensor: tensor.into_ir(),
1490            min: ScalarIr::with_dtype(min, &dtype),
1491            max: ScalarIr::with_dtype(max, &dtype),
1492            out: out.to_ir_out(),
1493        };
1494        out.client.register(
1495            streams,
1496            OperationIr::NumericInt(dtype, NumericOperationIr::Clamp(desc.clone())),
1497            ClampOps::<B>::new(desc),
1498        );
1499
1500        out
1501    }
1502
1503    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1504        unary_int_ops!(AbsOps, B::int_abs);
1505
1506        let dtype = tensor.dtype;
1507        let mut streams = OperationStreams::default();
1508        streams.tensor(&tensor);
1509        let out = tensor
1510            .client
1511            .tensor_uninitialized(tensor.shape.clone(), dtype);
1512
1513        let desc = UnaryOpIr {
1514            input: tensor.into_ir(),
1515            out: out.to_ir_out(),
1516        };
1517        out.client.register(
1518            streams,
1519            OperationIr::NumericInt(dtype, NumericOperationIr::Abs(desc.clone())),
1520            AbsOps::<B>::new(desc),
1521        );
1522
1523        out
1524    }
1525
1526    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
1527        #[derive(new, Debug)]
1528        struct IntoFloatOps<B: FusionBackend> {
1529            desc: UnaryOpIr,
1530            _b: PhantomData<B>,
1531        }
1532
1533        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {
1534            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1535                let input = handles.get_int_tensor::<B>(&self.desc.input);
1536                let output = B::int_into_float(input);
1537                handles.register_float_tensor::<B>(&self.desc.out.id, output);
1538            }
1539        }
1540
1541        let mut streams = OperationStreams::default();
1542        streams.tensor(&tensor);
1543        let out = tensor
1544            .client
1545            .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
1546        let desc = UnaryOpIr {
1547            input: tensor.into_ir(),
1548            out: out.to_ir_out(),
1549        };
1550        out.client.register(
1551            streams,
1552            OperationIr::Int(IntOperationIr::IntoFloat(desc.clone())),
1553            IntoFloatOps::<B>::new(desc),
1554        );
1555
1556        out
1557    }
1558
1559    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
1560        #[derive(new, Debug)]
1561        struct SwapDimsOps<B: FusionBackend> {
1562            desc: SwapDimsOpIr,
1563            _b: PhantomData<B>,
1564        }
1565
1566        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {
1567            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1568                let input = handles.get_int_tensor::<B>(&self.desc.input);
1569                let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);
1570                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1571            }
1572        }
1573
1574        let mut streams = OperationStreams::default();
1575        streams.tensor(&tensor);
1576        let shape = tensor.shape.clone().swap(dim1, dim2).unwrap();
1577        let dtype = tensor.dtype;
1578        let out = tensor.client.tensor_uninitialized(shape, dtype);
1579
1580        let desc = SwapDimsOpIr {
1581            input: tensor.into_ir(),
1582            dim1,
1583            dim2,
1584            out: out.to_ir_out(),
1585        };
1586        out.client.register(
1587            streams,
1588            OperationIr::BaseInt(BaseOperationIr::SwapDims(desc.clone())),
1589            SwapDimsOps::<B>::new(desc),
1590        );
1591
1592        out
1593    }
1594
1595    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
1596        unary_int_ops!(MaxOps, B::int_max, reduce);
1597
1598        let dtype = tensor.dtype;
1599        let mut streams = OperationStreams::default();
1600        streams.tensor(&tensor);
1601        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1602
1603        let desc = UnaryOpIr {
1604            input: tensor.into_ir(),
1605            out: out.to_ir_out(),
1606        };
1607        out.client.register(
1608            streams,
1609            OperationIr::NumericInt(dtype, NumericOperationIr::Max(desc.clone())),
1610            MaxOps::<B>::new(desc),
1611        );
1612
1613        out
1614    }
1615
1616    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1617        reduce_int_ops!(MaxDimOps, B::int_max_dim);
1618
1619        let dtype = tensor.dtype;
1620        let mut streams = OperationStreams::default();
1621        streams.tensor(&tensor);
1622        let mut shape = tensor.shape.clone();
1623        shape[dim] = 1;
1624        let out = tensor.client.tensor_uninitialized(shape, dtype);
1625
1626        let desc = ReduceDimOpIr {
1627            input: tensor.into_ir(),
1628            axis: dim,
1629            out: out.to_ir_out(),
1630        };
1631        out.client.register(
1632            streams,
1633            OperationIr::NumericInt(dtype, NumericOperationIr::MaxDim(desc.clone())),
1634            MaxDimOps::<B>::new(desc),
1635        );
1636
1637        out
1638    }
1639
1640    fn int_max_dim_with_indices(
1641        tensor: IntTensor<Self>,
1642        dim: usize,
1643    ) -> (IntTensor<Self>, IntTensor<Self>) {
1644        #[derive(new, Debug)]
1645        struct MaxDimWithIndicesOps<B: FusionBackend> {
1646            desc: ReduceDimWithIndicesOpIr,
1647            _b: PhantomData<B>,
1648        }
1649
1650        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {
1651            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1652                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1653                let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);
1654
1655                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1656                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
1657            }
1658        }
1659
1660        let dtype = tensor.dtype;
1661        let mut streams = OperationStreams::default();
1662        streams.tensor(&tensor);
1663        let mut shape = tensor.shape.clone();
1664        shape[dim] = 1;
1665        let client = tensor.client.clone();
1666        let out = client.tensor_uninitialized(shape.clone(), dtype);
1667        let out_indices = client.tensor_uninitialized(shape, dtype);
1668        let desc = ReduceDimWithIndicesOpIr {
1669            tensor: tensor.into_ir(),
1670            dim,
1671            out: out.to_ir_out(),
1672            out_indices: out_indices.to_ir_out(),
1673        };
1674        client.register(
1675            streams,
1676            OperationIr::NumericInt(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())),
1677            MaxDimWithIndicesOps::<B>::new(desc),
1678        );
1679
1680        (out, out_indices)
1681    }
1682
1683    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
1684        unary_int_ops!(MinOps, B::int_min, reduce);
1685
1686        let dtype = tensor.dtype;
1687        let mut streams = OperationStreams::default();
1688        streams.tensor(&tensor);
1689        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1690
1691        let desc = UnaryOpIr {
1692            input: tensor.into_ir(),
1693            out: out.to_ir_out(),
1694        };
1695        out.client.register(
1696            streams,
1697            OperationIr::NumericInt(dtype, NumericOperationIr::Min(desc.clone())),
1698            MinOps::<B>::new(desc),
1699        );
1700
1701        out
1702    }
1703
1704    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
1705        unary_int_ops!(MaxAbsOps, B::int_max_abs, reduce);
1706
1707        let dtype = tensor.dtype;
1708        let mut streams = OperationStreams::default();
1709        streams.tensor(&tensor);
1710        let out = tensor.client.tensor_uninitialized(Shape::new([1]), dtype);
1711
1712        let desc = UnaryOpIr {
1713            input: tensor.into_ir(),
1714            out: out.to_ir_out(),
1715        };
1716        out.client.register(
1717            streams,
1718            OperationIr::NumericInt(dtype, NumericOperationIr::MaxAbs(desc.clone())),
1719            MaxAbsOps::<B>::new(desc),
1720        );
1721
1722        out
1723    }
1724
1725    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1726        reduce_int_ops!(MaxAbsDimOps, B::int_max_abs_dim);
1727
1728        let dtype = tensor.dtype;
1729        let mut streams = OperationStreams::default();
1730        streams.tensor(&tensor);
1731        let mut shape = tensor.shape.clone();
1732        shape[dim] = 1;
1733        let out = tensor.client.tensor_uninitialized(shape, dtype);
1734
1735        let desc = ReduceDimOpIr {
1736            input: tensor.into_ir(),
1737            axis: dim,
1738            out: out.to_ir_out(),
1739        };
1740        out.client.register(
1741            streams,
1742            OperationIr::NumericInt(dtype, NumericOperationIr::MaxAbsDim(desc.clone())),
1743            MaxAbsDimOps::<B>::new(desc),
1744        );
1745
1746        out
1747    }
1748
1749    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1750        reduce_int_ops!(MinDimOps, B::int_min_dim);
1751
1752        let dtype = tensor.dtype;
1753        let mut streams = OperationStreams::default();
1754        streams.tensor(&tensor);
1755        let mut shape = tensor.shape.clone();
1756        shape[dim] = 1;
1757        let out = tensor.client.tensor_uninitialized(shape, dtype);
1758
1759        let desc = ReduceDimOpIr {
1760            input: tensor.into_ir(),
1761            axis: dim,
1762            out: out.to_ir_out(),
1763        };
1764
1765        out.client.register(
1766            streams,
1767            OperationIr::NumericInt(dtype, NumericOperationIr::MinDim(desc.clone())),
1768            MinDimOps::<B>::new(desc),
1769        );
1770
1771        out
1772    }
1773
1774    fn int_min_dim_with_indices(
1775        tensor: IntTensor<Self>,
1776        dim: usize,
1777    ) -> (IntTensor<Self>, IntTensor<Self>) {
1778        #[derive(new, Debug)]
1779        struct MinDimWithIndicesOps<B: FusionBackend> {
1780            desc: ReduceDimWithIndicesOpIr,
1781            _b: PhantomData<B>,
1782        }
1783
1784        impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {
1785            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1786                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1787                let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);
1788
1789                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1790                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);
1791            }
1792        }
1793
1794        let dtype = tensor.dtype;
1795        let mut streams = OperationStreams::default();
1796        streams.tensor(&tensor);
1797        let mut shape = tensor.shape.clone();
1798        shape[dim] = 1;
1799        let client = tensor.client.clone();
1800        let out = client.tensor_uninitialized(shape.clone(), dtype);
1801        let out_indices = client.tensor_uninitialized(shape, dtype);
1802        let desc = ReduceDimWithIndicesOpIr {
1803            tensor: tensor.into_ir(),
1804            dim,
1805            out: out.to_ir_out(),
1806            out_indices: out_indices.to_ir_out(),
1807        };
1808        client.register(
1809            streams,
1810            OperationIr::NumericInt(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())),
1811            MinDimWithIndicesOps::<B>::new(desc),
1812        );
1813
1814        (out, out_indices)
1815    }
1816
1817    fn int_random(
1818        shape: Shape,
1819        distribution: Distribution,
1820        device: &Device<Self>,
1821    ) -> IntTensor<Self> {
1822        #[derive(new, Debug)]
1823        struct IntRandomOps<B: FusionBackend> {
1824            desc: RandomOpIr,
1825            device: Device<B>,
1826        }
1827
1828        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntRandomOps<B> {
1829            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1830                let shape = self.desc.out.shape.clone();
1831                let output = B::int_random(shape, self.desc.distribution, &self.device);
1832                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1833            }
1834        }
1835
1836        let client = get_client::<B>(&device.clone());
1837        let out = client.tensor_uninitialized(shape, B::IntElem::dtype());
1838
1839        let desc = RandomOpIr {
1840            out: out.to_ir_out(),
1841            distribution,
1842        };
1843        client.register(
1844            OperationStreams::default(),
1845            OperationIr::NumericInt(
1846                IntElem::<Self>::dtype(),
1847                NumericOperationIr::IntRandom(desc.clone()),
1848            ),
1849            IntRandomOps::<B>::new(desc, device.clone()),
1850        );
1851
1852        out
1853    }
1854
1855    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
1856        #[derive(new, Debug)]
1857        struct PermuteDimsOps<B: FusionBackend> {
1858            desc: PermuteOpIr,
1859            _b: PhantomData<B>,
1860        }
1861
1862        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {
1863            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1864                let input = handles.get_int_tensor::<B>(&self.desc.input);
1865                let output = B::int_permute(input, self.desc.axes.as_slice());
1866                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1867            }
1868        }
1869
1870        let mut streams = OperationStreams::default();
1871        streams.tensor(&tensor);
1872
1873        // Change the shape of the tensor to match the new axes
1874        let shape = tensor.shape.clone().permute(axes).unwrap();
1875        let dtype = tensor.dtype;
1876        let out = tensor.client.tensor_uninitialized(shape, dtype);
1877
1878        let desc = PermuteOpIr {
1879            input: tensor.into_ir(),
1880            axes: axes.to_vec(),
1881            out: out.to_ir_out(),
1882        };
1883
1884        out.client.register(
1885            streams,
1886            OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),
1887            PermuteDimsOps::<B>::new(desc),
1888        );
1889
1890        out
1891    }
1892
1893    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
1894        #[derive(new, Debug)]
1895        struct ExpandOps<B: FusionBackend> {
1896            desc: ExpandOpIr,
1897            _b: PhantomData<B>,
1898        }
1899
1900        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {
1901            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1902                let input = handles.get_int_tensor::<B>(&self.desc.input);
1903                let output = B::int_expand(input, self.desc.shape.clone());
1904                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1905            }
1906        }
1907
1908        let mut streams = OperationStreams::default();
1909        streams.tensor(&tensor);
1910
1911        let dtype = tensor.dtype;
1912        let out = tensor.client.tensor_uninitialized(shape.clone(), dtype);
1913
1914        let desc = ExpandOpIr {
1915            input: tensor.into_ir(),
1916            shape,
1917            out: out.to_ir_out(),
1918        };
1919
1920        out.client.register(
1921            streams,
1922            OperationIr::BaseInt(BaseOperationIr::Expand(desc.clone())),
1923            ExpandOps::<B>::new(desc),
1924        );
1925
1926        out
1927    }
1928
1929    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
1930        #[derive(new, Debug)]
1931        struct FlipDimsOps<B: FusionBackend> {
1932            desc: FlipOpIr,
1933            _b: PhantomData<B>,
1934        }
1935
1936        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipDimsOps<B> {
1937            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1938                let input = handles.get_int_tensor::<B>(&self.desc.input);
1939                let axes = &self.desc.axes;
1940                let output = B::int_flip(input, axes);
1941                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1942            }
1943        }
1944
1945        let mut streams = OperationStreams::default();
1946        streams.tensor(&tensor);
1947
1948        let dtype = tensor.dtype;
1949        let out = tensor
1950            .client
1951            .tensor_uninitialized(tensor.shape.clone(), dtype);
1952
1953        let desc = FlipOpIr {
1954            input: tensor.into_ir(),
1955            axes: axes.to_vec(),
1956            out: out.to_ir_out(),
1957        };
1958
1959        out.client.register(
1960            streams,
1961            OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),
1962            FlipDimsOps::<B>::new(desc),
1963        );
1964
1965        out
1966    }
1967
1968    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
1969        #[derive(new, Debug)]
1970        struct RepeatDimOps<B: FusionBackend> {
1971            desc: RepeatDimOpIr,
1972            _b: PhantomData<B>,
1973        }
1974
1975        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {
1976            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
1977                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);
1978
1979                let output = B::int_repeat_dim(tensor, self.desc.dim, self.desc.times);
1980
1981                handles.register_int_tensor::<B>(&self.desc.out.id, output);
1982            }
1983        }
1984
1985        let dtype = tensor.dtype;
1986        let mut streams = OperationStreams::default();
1987        streams.tensor(&tensor);
1988        let shape = tensor.shape.clone().repeat(dim, times);
1989        let out = tensor.client.tensor_uninitialized(shape, dtype);
1990
1991        let desc = RepeatDimOpIr {
1992            tensor: tensor.into_ir(),
1993            dim,
1994            times,
1995            out: out.to_ir_out(),
1996        };
1997        out.client.register(
1998            streams,
1999            OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc.clone())),
2000            RepeatDimOps::<B>::new(desc),
2001        );
2002
2003        out
2004    }
2005
2006    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2007        binary_int_ops!(BitwiseAndOps, B::bitwise_and);
2008
2009        let dtype = lhs.dtype;
2010        let mut streams = OperationStreams::default();
2011        streams.tensor(&lhs);
2012        streams.tensor(&rhs);
2013        let out = lhs
2014            .client
2015            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2016
2017        let desc = BinaryOpIr {
2018            lhs: lhs.into_ir(),
2019            rhs: rhs.into_ir(),
2020            out: out.to_ir_out(),
2021        };
2022        out.client.register(
2023            streams,
2024            OperationIr::Int(IntOperationIr::BitwiseAnd(desc.clone())),
2025            BitwiseAndOps::<B>::new(desc),
2026        );
2027
2028        out
2029    }
2030
2031    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2032        scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar);
2033
2034        let dtype = lhs.dtype;
2035        let mut streams = OperationStreams::default();
2036        streams.tensor(&lhs);
2037        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2038
2039        let desc = ScalarOpIr {
2040            lhs: lhs.into_ir(),
2041            rhs: ScalarIr::with_dtype(rhs, &dtype),
2042            out: out.to_ir_out(),
2043        };
2044        out.client.register(
2045            streams,
2046            OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc.clone())),
2047            BitwiseAndOps::<B>::new(desc),
2048        );
2049
2050        out
2051    }
2052
2053    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2054        binary_int_ops!(BitwiseOrOps, B::bitwise_or);
2055
2056        let dtype = lhs.dtype;
2057        let mut streams = OperationStreams::default();
2058        streams.tensor(&lhs);
2059        streams.tensor(&rhs);
2060        let out = lhs
2061            .client
2062            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2063
2064        let desc = BinaryOpIr {
2065            lhs: lhs.into_ir(),
2066            rhs: rhs.into_ir(),
2067            out: out.to_ir_out(),
2068        };
2069        out.client.register(
2070            streams,
2071            OperationIr::Int(IntOperationIr::BitwiseOr(desc.clone())),
2072            BitwiseOrOps::<B>::new(desc),
2073        );
2074
2075        out
2076    }
2077
2078    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2079        scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar);
2080
2081        let dtype = lhs.dtype;
2082        let mut streams = OperationStreams::default();
2083        streams.tensor(&lhs);
2084        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2085
2086        let desc = ScalarOpIr {
2087            lhs: lhs.into_ir(),
2088            rhs: ScalarIr::with_dtype(rhs, &dtype),
2089            out: out.to_ir_out(),
2090        };
2091        out.client.register(
2092            streams,
2093            OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc.clone())),
2094            BitwiseOrOps::<B>::new(desc),
2095        );
2096
2097        out
2098    }
2099
2100    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2101        binary_int_ops!(BitwiseXorOps, B::bitwise_xor);
2102
2103        let dtype = lhs.dtype;
2104        let mut streams = OperationStreams::default();
2105        streams.tensor(&lhs);
2106        streams.tensor(&rhs);
2107        let out = lhs
2108            .client
2109            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2110
2111        let desc = BinaryOpIr {
2112            lhs: lhs.into_ir(),
2113            rhs: rhs.into_ir(),
2114            out: out.to_ir_out(),
2115        };
2116        out.client.register(
2117            streams,
2118            OperationIr::Int(IntOperationIr::BitwiseXor(desc.clone())),
2119            BitwiseXorOps::<B>::new(desc),
2120        );
2121
2122        out
2123    }
2124
2125    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2126        scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar);
2127
2128        let dtype = lhs.dtype;
2129        let mut streams = OperationStreams::default();
2130        streams.tensor(&lhs);
2131        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2132
2133        let desc = ScalarOpIr {
2134            lhs: lhs.into_ir(),
2135            rhs: ScalarIr::with_dtype(rhs, &dtype),
2136            out: out.to_ir_out(),
2137        };
2138        out.client.register(
2139            streams,
2140            OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc.clone())),
2141            BitwiseXorOps::<B>::new(desc),
2142        );
2143
2144        out
2145    }
2146
2147    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
2148        unary_int_ops!(BitwiseNotOps, B::bitwise_not);
2149
2150        let dtype = tensor.dtype;
2151        let mut streams = OperationStreams::default();
2152        streams.tensor(&tensor);
2153        let out = tensor
2154            .client
2155            .tensor_uninitialized(tensor.shape.clone(), dtype);
2156
2157        let desc = UnaryOpIr {
2158            input: tensor.into_ir(),
2159            out: out.to_ir_out(),
2160        };
2161        out.client.register(
2162            streams,
2163            OperationIr::Int(IntOperationIr::BitwiseNot(desc.clone())),
2164            BitwiseNotOps::<B>::new(desc),
2165        );
2166
2167        out
2168    }
2169
2170    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2171        binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift);
2172
2173        let dtype = lhs.dtype;
2174        let mut streams = OperationStreams::default();
2175        streams.tensor(&lhs);
2176        streams.tensor(&rhs);
2177        let out = lhs
2178            .client
2179            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2180
2181        let desc = BinaryOpIr {
2182            lhs: lhs.into_ir(),
2183            rhs: rhs.into_ir(),
2184            out: out.to_ir_out(),
2185        };
2186        out.client.register(
2187            streams,
2188            OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc.clone())),
2189            BitwiseLeftShiftOps::<B>::new(desc),
2190        );
2191
2192        out
2193    }
2194
2195    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2196        scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar);
2197
2198        let dtype = lhs.dtype;
2199        let mut streams = OperationStreams::default();
2200        streams.tensor(&lhs);
2201        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2202
2203        let desc = ScalarOpIr {
2204            lhs: lhs.into_ir(),
2205            rhs: ScalarIr::with_dtype(rhs, &dtype),
2206            out: out.to_ir_out(),
2207        };
2208        out.client.register(
2209            streams,
2210            OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(desc.clone())),
2211            BitwiseLeftShiftOps::<B>::new(desc),
2212        );
2213
2214        out
2215    }
2216
2217    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
2218        binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift);
2219
2220        let dtype = lhs.dtype;
2221        let mut streams = OperationStreams::default();
2222        streams.tensor(&lhs);
2223        streams.tensor(&rhs);
2224        let out = lhs
2225            .client
2226            .tensor_uninitialized(lhs.shape.broadcast(&rhs.shape).unwrap(), dtype);
2227
2228        let desc = BinaryOpIr {
2229            lhs: lhs.into_ir(),
2230            rhs: rhs.into_ir(),
2231            out: out.to_ir_out(),
2232        };
2233        out.client.register(
2234            streams,
2235            OperationIr::Int(IntOperationIr::BitwiseRightShift(desc.clone())),
2236            BitwiseRightShiftOps::<B>::new(desc),
2237        );
2238
2239        out
2240    }
2241
2242    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
2243        scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar);
2244
2245        let dtype = lhs.dtype;
2246        let mut streams = OperationStreams::default();
2247        streams.tensor(&lhs);
2248        let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype);
2249
2250        let desc = ScalarOpIr {
2251            lhs: lhs.into_ir(),
2252            rhs: ScalarIr::with_dtype(rhs, &dtype),
2253            out: out.to_ir_out(),
2254        };
2255        out.client.register(
2256            streams,
2257            OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(desc.clone())),
2258            BitwiseRightShiftOps::<B>::new(desc),
2259        );
2260
2261        out
2262    }
2263
2264    fn int_cast(tensor: IntTensor<Self>, dtype: burn_tensor::IntDType) -> IntTensor<Self> {
2265        #[derive(new, Debug)]
2266        struct CastOps<B: FusionBackend> {
2267            desc: UnaryOpIr,
2268            dtype: burn_tensor::IntDType,
2269            _b: PhantomData<B>,
2270        }
2271
2272        impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {
2273            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2274                let input = handles.get_int_tensor::<B>(&self.desc.input);
2275                let output: B::IntTensorPrimitive = B::int_cast(input, self.dtype);
2276                handles.register_int_tensor::<B>(&self.desc.out.id, output);
2277            }
2278        }
2279
2280        let mut streams = OperationStreams::default();
2281        streams.tensor(&tensor);
2282        let out = tensor
2283            .client
2284            .tensor_uninitialized(tensor.shape.clone(), dtype.into());
2285
2286        let desc = UnaryOpIr {
2287            input: tensor.into_ir(),
2288            out: out.to_ir_out(),
2289        };
2290        out.client.register(
2291            streams,
2292            OperationIr::BaseInt(BaseOperationIr::Cast(desc.clone())),
2293            CastOps::<B>::new(desc, dtype),
2294        );
2295
2296        out
2297    }
2298
2299    fn int_unfold(
2300        tensor: IntTensor<Self>,
2301        dim: usize,
2302        size: usize,
2303        step: usize,
2304    ) -> IntTensor<Self> {
2305        #[derive(new, Debug)]
2306        struct UnfoldOps<B: FusionBackend> {
2307            desc: UnfoldOpIr,
2308            _b: PhantomData<B>,
2309        }
2310
2311        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
2312            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
2313                let input = handles.get_int_tensor::<B>(&self.desc.input);
2314                let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
2315
2316                handles.register_int_tensor::<B>(&self.desc.out.id, output);
2317            }
2318        }
2319
2320        let mut streams = OperationStreams::default();
2321        streams.tensor(&tensor);
2322
2323        let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
2324        let out = tensor
2325            .client
2326            .tensor_uninitialized(Shape::from(shape), tensor.dtype);
2327
2328        let desc = UnfoldOpIr {
2329            input: tensor.into_ir(),
2330            out: out.to_ir_out(),
2331            dim,
2332            size,
2333            step,
2334        };
2335
2336        out.client.register(
2337            streams,
2338            OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())),
2339            UnfoldOps::<B>::new(desc),
2340        );
2341
2342        out
2343    }
2344}