burn_cubecl/ops/int_ops.rs

1use self::unary_basic_int::BasicIntUnaryKind;
2
3use super::{expand, numeric, permute, unfold};
4use crate::{
5    CubeBackend, CubeRuntime, FloatElement, IntElement,
6    kernel::{
7        self,
8        matmul::{MatmulStrategy, matmul},
9    },
10};
11use crate::{
12    element::BoolElement,
13    kernel::prng::{random_bernoulli, random_normal, random_uniform},
14};
15use crate::{
16    execute_with_dtype,
17    kernel::{
18        BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
19        launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
20    },
21};
22use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
23use burn_tensor::{DType, IntDType};
24use burn_tensor::{Distribution, ElementConversion, Shape, TensorData, ops::IntTensorOps};
25use cubecl::frontend::Numeric;
26use cubecl::prelude::*;
27use cubecl::reduce::ReducePrecision;
28use cubecl::reduce::instructions::ReduceFnConfig;
29use std::ops::Range;
30
31impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
32where
33    R: CubeRuntime,
34    F: FloatElement,
35    I: IntElement,
36    BT: BoolElement,
37{
38    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
39        let dtype = dtype.into();
40        execute_with_dtype!(int(dtype), I, super::empty::<R, I>(shape, device))
41    }
42
43    async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
44        execute_with_dtype!(int(tensor.dtype), I, super::into_data::<R, I>(tensor).await)
45    }
46
47    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
48        match data.dtype {
49            DType::I64
50            | DType::I32
51            | DType::I16
52            | DType::I8
53            | DType::U64
54            | DType::U32
55            | DType::U16
56            | DType::U8 => super::from_data::<R>(data, device),
57            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
58        }
59    }
60
61    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
62        tensor.device.clone()
63    }
64
65    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
66        super::to_device(tensor, device)
67    }
68
69    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
70        super::reshape(tensor, shape)
71    }
72
73    fn int_slice(tensor: IntTensor<Self>, slices: &[burn_tensor::Slice]) -> IntTensor<Self> {
74        // Check if all steps are 1
75        let all_steps_one = slices.iter().all(|info| info.step == 1);
76
77        if all_steps_one {
78            // Use optimized slice for step=1
79            let simple_ranges: Vec<Range<usize>> = slices
80                .iter()
81                .enumerate()
82                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
83                .collect();
84
85            execute_with_dtype!(
86                int(tensor.dtype),
87                I,
88                kernel::slice::<R, I>(tensor, &simple_ranges)
89            )
90        } else {
91            // Use slice with steps kernel
92            execute_with_dtype!(
93                int(tensor.dtype),
94                I,
95                kernel::slice_with_steps::<R, I>(tensor, slices)
96            )
97        }
98    }
99
100    fn int_slice_assign(
101        tensor: IntTensor<Self>,
102        ranges: &[burn_tensor::Slice],
103        value: IntTensor<Self>,
104    ) -> IntTensor<Self> {
105        execute_with_dtype!(
106            int(tensor.dtype),
107            I,
108            kernel::slice_assign::<R, I>(tensor, ranges, value)
109        )
110    }
111
112    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
113        let dtype = lhs.dtype;
114        execute_with_dtype!(
115            int(dtype),
116            E,
117            matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
118        )
119    }
120
121    fn int_mask_where(
122        tensor: IntTensor<Self>,
123        mask: BoolTensor<Self>,
124        value: IntTensor<Self>,
125    ) -> IntTensor<Self> {
126        execute_with_dtype!(
127            int(tensor.dtype),
128            I,
129            kernel::mask_where_auto::<R, I, BT>(tensor, mask, value)
130        )
131    }
132
133    fn int_mask_fill(
134        tensor: IntTensor<Self>,
135        mask: BoolTensor<Self>,
136        value: IntElem<Self>,
137    ) -> IntTensor<Self> {
138        execute_with_dtype!(
139            int(tensor.dtype),
140            I,
141            kernel::mask_fill_auto::<R, I, BT>(tensor, mask, value.elem())
142        )
143    }
144
145    fn int_gather(
146        dim: usize,
147        tensor: IntTensor<Self>,
148        indices: IntTensor<Self>,
149    ) -> IntTensor<Self> {
150        execute_with_dtype!(
151            int(tensor.dtype),
152            E,
153            execute_with_dtype!(
154                int(tensor.dtype),
155                I,
156                kernel::gather::<R, E, I>(dim, tensor, indices)
157            )
158        )
159    }
160
161    fn int_scatter(
162        dim: usize,
163        tensor: IntTensor<Self>,
164        indices: IntTensor<Self>,
165        value: IntTensor<Self>,
166    ) -> IntTensor<Self> {
167        execute_with_dtype!(
168            int(tensor.dtype),
169            E,
170            execute_with_dtype!(
171                int(indices.dtype),
172                I,
173                kernel::scatter::<R, E, I>(dim, tensor, indices, value)
174            )
175        )
176    }
177
178    fn int_select(
179        tensor: IntTensor<Self>,
180        dim: usize,
181        indices: IntTensor<Self>,
182    ) -> IntTensor<Self> {
183        execute_with_dtype!(
184            int(tensor.dtype),
185            E,
186            execute_with_dtype!(
187                int(indices.dtype),
188                I,
189                kernel::select::<R, E, I>(tensor, dim, indices)
190            )
191        )
192    }
193
194    fn int_select_assign(
195        tensor: IntTensor<Self>,
196        dim: usize,
197        indices: IntTensor<Self>,
198        value: IntTensor<Self>,
199    ) -> IntTensor<Self> {
200        execute_with_dtype!(
201            int(tensor.dtype),
202            E,
203            execute_with_dtype!(
204                int(indices.dtype),
205                I,
206                kernel::select_assign::<R, E, I>(tensor, dim, indices, value, false)
207            )
208        )
209    }
210
211    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
212        execute_with_dtype!(int(lhs.dtype), I, kernel::equal::<R, I, BT>(lhs, rhs))
213    }
214
215    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
216        execute_with_dtype!(
217            int(lhs.dtype),
218            I,
219            kernel::equal_elem::<R, I, BT>(lhs, rhs.elem())
220        )
221    }
222
223    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
224        execute_with_dtype!(int(lhs.dtype), I, kernel::greater::<R, I, BT>(lhs, rhs))
225    }
226
227    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
228        execute_with_dtype!(
229            int(lhs.dtype),
230            I,
231            kernel::greater_elem::<R, I, BT>(lhs, rhs.elem())
232        )
233    }
234
235    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
236        execute_with_dtype!(
237            int(lhs.dtype),
238            I,
239            kernel::greater_equal::<R, I, BT>(lhs, rhs)
240        )
241    }
242
243    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
244        execute_with_dtype!(
245            int(lhs.dtype),
246            I,
247            kernel::greater_equal_elem::<R, I, BT>(lhs, rhs.elem())
248        )
249    }
250
251    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
252        execute_with_dtype!(int(lhs.dtype), I, kernel::lower::<R, I, BT>(lhs, rhs))
253    }
254
255    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
256        execute_with_dtype!(
257            int(lhs.dtype),
258            I,
259            kernel::lower_elem::<R, I, BT>(lhs, rhs.elem())
260        )
261    }
262
263    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
264        execute_with_dtype!(int(lhs.dtype), I, kernel::lower_equal::<R, I, BT>(lhs, rhs))
265    }
266
267    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
268        execute_with_dtype!(
269            int(lhs.dtype),
270            I,
271            kernel::lower_equal_elem::<R, I, BT>(lhs, rhs.elem())
272        )
273    }
274
275    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
276        execute_with_dtype!(int(lhs.dtype), I, numeric::add::<R, I>(lhs, rhs))
277    }
278
279    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
280        execute_with_dtype!(
281            int(lhs.dtype),
282            I,
283            numeric::add_scalar::<R, I>(lhs, rhs.elem())
284        )
285    }
286
287    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
288        execute_with_dtype!(int(lhs.dtype), I, numeric::sub::<R, I>(lhs, rhs))
289    }
290
291    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
292        execute_with_dtype!(
293            int(lhs.dtype),
294            I,
295            numeric::sub_scalar::<R, I>(lhs, rhs.elem())
296        )
297    }
298
299    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
300        execute_with_dtype!(int(lhs.dtype), I, numeric::mul::<R, I>(lhs, rhs))
301    }
302
303    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
304        execute_with_dtype!(
305            int(lhs.dtype),
306            I,
307            numeric::mul_scalar::<R, I>(lhs, rhs.elem())
308        )
309    }
310
311    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
312        execute_with_dtype!(int(lhs.dtype), I, numeric::div::<R, I>(lhs, rhs))
313    }
314
315    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
316        execute_with_dtype!(
317            int(lhs.dtype),
318            I,
319            numeric::div_scalar::<R, I>(lhs, rhs.elem())
320        )
321    }
322
323    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
324        execute_with_dtype!(int(lhs.dtype), I, numeric::remainder::<R, I>(lhs, rhs))
325    }
326
327    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
328        execute_with_dtype!(
329            int(lhs.dtype),
330            I,
331            numeric::remainder_scalar::<R, I>(lhs, rhs.elem())
332        )
333    }
334
335    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
336        let dtype = dtype.into();
337        execute_with_dtype!(int(dtype), I, numeric::zeros::<R, I>(shape, device))
338    }
339
340    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
341        let dtype = dtype.into();
342        execute_with_dtype!(int(dtype), I, numeric::ones::<R, I>(shape, device))
343    }
344
345    fn int_full(
346        shape: Shape,
347        fill_value: IntElem<Self>,
348        device: &Device<Self>,
349        dtype: IntDType,
350    ) -> IntTensor<Self> {
351        let dtype = dtype.into();
352        execute_with_dtype!(
353            int(dtype),
354            I,
355            numeric::full::<R, I>(shape, device, fill_value.elem())
356        )
357    }
358
359    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
360        execute_with_dtype!(
361            int(tensor.dtype),
362            I,
363            reduce::sum_fallback::<R, I>(tensor, Default::default()).unwrap()
364        )
365    }
366
367    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
368        execute_with_dtype!(
369            int(tensor.dtype),
370            I,
371            reduce::reduce_dim::<R, I, I, <I as ReducePrecision>::EA>(
372                tensor,
373                dim,
374                Default::default(),
375                ReduceFnConfig::Sum,
376            )
377            .unwrap()
378        )
379    }
380
381    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
382        execute_with_dtype!(
383            int(tensor.dtype),
384            I,
385            reduce::reduce::<R, I, I, <I as ReducePrecision>::EA>(
386                tensor,
387                Default::default(),
388                ReduceFnConfig::Prod,
389            )
390            .unwrap()
391        )
392    }
393
394    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
395        execute_with_dtype!(
396            int(tensor.dtype),
397            I,
398            reduce::reduce_dim::<R, I, I, <I as ReducePrecision>::EA>(
399                tensor,
400                dim,
401                Default::default(),
402                ReduceFnConfig::Prod,
403            )
404            .unwrap()
405        )
406    }
407
408    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
409        execute_with_dtype!(
410            int(tensor.dtype),
411            I,
412            reduce::reduce::<R, I, I, I>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
413        )
414    }
415
416    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
417        execute_with_dtype!(
418            int(tensor.dtype),
419            I,
420            reduce::reduce_dim::<R, I, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Max)
421                .unwrap()
422        )
423    }
424
425    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
426        execute_with_dtype!(
427            int(tensor.dtype),
428            I,
429            reduce::reduce::<R, I, I, I>(tensor, Default::default(), ReduceFnConfig::MaxAbs)
430                .unwrap()
431        )
432    }
433
434    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
435        execute_with_dtype!(
436            int(tensor.dtype),
437            I,
438            reduce::reduce_dim::<R, I, I, I>(
439                tensor,
440                dim,
441                Default::default(),
442                ReduceFnConfig::MaxAbs
443            )
444            .unwrap()
445        )
446    }
447
448    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
449        execute_with_dtype!(
450            int(tensor.dtype),
451            I,
452            reduce::reduce::<R, I, I, I>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
453        )
454    }
455
456    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
457        execute_with_dtype!(
458            int(tensor.dtype),
459            I,
460            reduce::reduce_dim::<R, I, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Min)
461                .unwrap()
462        )
463    }
464
465    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
466        execute_with_dtype!(
467            int(tensor.dtype),
468            I,
469            reduce::reduce_dim::<R, I, I, <I as ReducePrecision>::EA>(
470                tensor,
471                dim,
472                Default::default(),
473                ReduceFnConfig::Mean,
474            )
475            .unwrap()
476        )
477    }
478
479    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
480        execute_with_dtype!(int(tensor.dtype), I, numeric::cumsum::<R, I>(tensor, dim))
481    }
482
483    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
484        execute_with_dtype!(int(tensor.dtype), I, numeric::cumprod::<R, I>(tensor, dim))
485    }
486
487    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
488        execute_with_dtype!(int(tensor.dtype), I, numeric::cummin::<R, I>(tensor, dim))
489    }
490
491    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
492        execute_with_dtype!(int(tensor.dtype), I, numeric::cummax::<R, I>(tensor, dim))
493    }
494
495    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
496        execute_with_dtype!(
497            int(tensor.dtype),
498            I,
499            reduce::reduce_dim::<R, I, I, I>(
500                tensor,
501                dim,
502                Default::default(),
503                ReduceFnConfig::ArgMax
504            )
505            .unwrap()
506        )
507    }
508
509    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
510        execute_with_dtype!(
511            int(tensor.dtype),
512            I,
513            reduce::reduce_dim::<R, I, I, I>(
514                tensor,
515                dim,
516                Default::default(),
517                ReduceFnConfig::ArgMin
518            )
519            .unwrap()
520        )
521    }
522
523    fn int_clamp(
524        tensor: IntTensor<Self>,
525        min: IntElem<Self>,
526        max: IntElem<Self>,
527    ) -> IntTensor<Self> {
528        execute_with_dtype!(
529            int(tensor.dtype),
530            I,
531            kernel::clamp::<R, I>(tensor, min.elem(), max.elem())
532        )
533    }
534
535    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
536        struct Abs;
537
538        #[cube]
539        impl<N: Numeric> NumericUnaryOp<N> for Abs {
540            type Options = ();
541
542            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
543                Line::abs(input)
544            }
545        }
546
547        impl NumericUnaryOpFamily for Abs {
548            type Options<N: Numeric> = ();
549            type Unary<N: Numeric> = Self;
550        }
551
552        execute_with_dtype!(
553            int(tensor.dtype),
554            I,
555            launch_unary_numeric::<R, I, Abs, _>(tensor, |_| ())
556        )
557    }
558
559    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
560        execute_with_dtype!(int(tensor.dtype), I, kernel::cast::<R, I, F>(tensor))
561    }
562
563    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
564        tensor.strides.swap(dim1, dim2);
565        tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();
566
567        tensor
568    }
569
570    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
571        execute_with_dtype!(
572            int(tensor.dtype),
573            I,
574            kernel::repeat_dim::<R, I>(tensor, dim, times)
575        )
576    }
577
578    fn int_random(
579        shape: Shape,
580        distribution: Distribution,
581        device: &Device<Self>,
582    ) -> IntTensor<Self> {
583        match distribution {
584            Distribution::Default => random_uniform(shape, device, 0.elem::<I>(), 255.elem()),
585            Distribution::Uniform(low, high) => {
586                random_uniform(shape, device, low.elem::<I>(), high.elem())
587            }
588            Distribution::Bernoulli(prob) => random_bernoulli::<R, I>(shape, device, prob as f32),
589            Distribution::Normal(mean, std) => {
590                random_normal(shape, device, mean.elem::<I>(), std.elem())
591            }
592        }
593    }
594
595    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
596        permute(tensor, axes)
597    }
598
599    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
600        expand(tensor, shape)
601    }
602
603    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
604        execute_with_dtype!(int(tensor.dtype), I, kernel::flip::<R, I, BT>(tensor, axes))
605    }
606
607    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
608        execute_with_dtype!(int(lhs.dtype), I, numeric::bitwise_and::<R, I>(lhs, rhs))
609    }
610
611    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
612        execute_with_dtype!(
613            int(lhs.dtype),
614            I,
615            numeric::bitwise_and_scalar::<R, I>(lhs, rhs.elem())
616        )
617    }
618
619    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
620        execute_with_dtype!(int(lhs.dtype), I, numeric::bitwise_or::<R, I>(lhs, rhs))
621    }
622
623    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
624        execute_with_dtype!(
625            int(lhs.dtype),
626            I,
627            numeric::bitwise_or_scalar::<R, I>(lhs, rhs.elem())
628        )
629    }
630
631    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
632        execute_with_dtype!(int(lhs.dtype), I, numeric::bitwise_xor::<R, I>(lhs, rhs))
633    }
634
635    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
636        execute_with_dtype!(
637            int(lhs.dtype),
638            I,
639            numeric::bitwise_xor_scalar::<R, I>(lhs, rhs.elem())
640        )
641    }
642
643    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
644        execute_with_dtype!(
645            int(tensor.dtype),
646            I,
647            unary_basic_int::launch::<R, _, I>(tensor, |_| BasicIntUnaryKind::BitwiseNot)
648        )
649    }
650
651    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
652        execute_with_dtype!(
653            int(lhs.dtype),
654            I,
655            launch_binop_int::<R, I, kernel::BitwiseShlOp>(lhs, rhs)
656        )
657    }
658
659    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
660        execute_with_dtype!(
661            int(lhs.dtype),
662            I,
663            launch_scalar_binop_int::<R, I, BitwiseShlOp>(lhs, rhs.elem())
664        )
665    }
666
667    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
668        execute_with_dtype!(
669            int(lhs.dtype),
670            I,
671            launch_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
672        )
673    }
674
675    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
676        execute_with_dtype!(
677            int(lhs.dtype),
678            I,
679            launch_scalar_binop_int::<R, I, BitwiseShrOp>(lhs, rhs.elem())
680        )
681    }
682
683    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
684        if tensor.dtype == dtype.into() {
685            return tensor;
686        }
687
688        execute_with_dtype!(
689            int(tensor.dtype),
690            I,
691            match dtype {
692                IntDType::I64 => kernel::cast::<R, I, i64>(tensor),
693                IntDType::I32 => kernel::cast::<R, I, i32>(tensor),
694                IntDType::I16 => kernel::cast::<R, I, i16>(tensor),
695                IntDType::I8 => kernel::cast::<R, I, i8>(tensor),
696                IntDType::U64 => kernel::cast::<R, I, u64>(tensor),
697                IntDType::U32 => kernel::cast::<R, I, u32>(tensor),
698                IntDType::U16 => kernel::cast::<R, I, u16>(tensor),
699                IntDType::U8 => kernel::cast::<R, I, u8>(tensor),
700            }
701        )
702    }
703
704    fn int_unfold(
705        tensor: FloatTensor<Self>,
706        dim: usize,
707        size: usize,
708        step: usize,
709    ) -> FloatTensor<Self> {
710        unfold(tensor, dim, size, step)
711    }
712}