use self::unary_basic_int::BasicIntUnaryKind;

use super::{expand, numeric, permute, unfold};
use crate::{
    CubeBackend, CubeRuntime, FloatElement, IntElement,
    kernel::{
        self,
        matmul::{MatmulStrategy, matmul},
    },
};
use crate::{
    element::BoolElement,
    kernel::prng::{random_bernoulli, random_normal, random_uniform},
};
use crate::{
    execute_with_dtype,
    kernel::{
        BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
        launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
    },
};
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use burn_tensor::{DType, IntDType};
use burn_tensor::{Distribution, ElementConversion, Shape, TensorData, ops::IntTensorOps};
use cubecl::frontend::Numeric;
use cubecl::prelude::*;
use cubecl::reduce::ReducePrecision;
use cubecl::reduce::instructions::ReduceFnConfig;
use std::ops::Range;

impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
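    // All ops below dispatch on the tensor's runtime integer dtype via
    // `execute_with_dtype!` before launching the matching CubeCL kernel.
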
    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(int(dtype), I, super::empty::<R, I>(shape, device))
    }

    async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
        execute_with_dtype!(int(tensor.dtype), I, super::into_data::<R, I>(tensor).await)
    }

    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
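        // Only integer dtypes are accepted; anything else is unsupported.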
        match data.dtype {
            DType::I64
            | DType::I32
            | DType::I16
            | DType::I8
            | DType::U64
            | DType::U32
            | DType::U16
            | DType::U8 => super::from_data::<R>(data, device),
            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
        }
    }

    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
        super::to_device(tensor, device)
    }

    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        super::reshape(tensor, shape)
    }

    fn int_slice(tensor: IntTensor<Self>, slices: &[burn_tensor::Slice]) -> IntTensor<Self> {
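        // Unit-step slices reduce to plain index ranges and use the simpler
        // `slice` kernel; any non-unit step goes through `slice_with_steps`.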
        let all_steps_one = slices.iter().all(|info| info.step == 1);

        if all_steps_one {
            let simple_ranges: Vec<Range<usize>> = slices
                .iter()
                .enumerate()
                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
                .collect();

            execute_with_dtype!(
                int(tensor.dtype),
                I,
                kernel::slice::<R, I>(tensor, &simple_ranges)
            )
        } else {
            execute_with_dtype!(
                int(tensor.dtype),
                I,
                kernel::slice_with_steps::<R, I>(tensor, slices)
            )
        }
    }

    fn int_slice_assign(
        tensor: IntTensor<Self>,
        ranges: &[burn_tensor::Slice],
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            kernel::slice_assign::<R, I>(tensor, ranges, value)
        )
    }

    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
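        // Dispatches on the lhs dtype and launches the generic matmul entry
        // point with the default strategy.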
        let dtype = lhs.dtype;
        execute_with_dtype!(
            int(dtype),
            E,
            matmul::<R, E>(lhs, rhs, None, MatmulStrategy::default()).unwrap()
        )
    }

    fn int_mask_where(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            kernel::mask_where_auto::<R, I, BT>(tensor, mask, value)
        )
    }

    fn int_mask_fill(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        value: IntElem<Self>,
    ) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            kernel::mask_fill_auto::<R, I, BT>(tensor, mask, value.elem())
        )
    }

    fn int_gather(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
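        // Nested dispatch: the outer macro resolves the element dtype (`E`),
        // the inner one the index dtype (`I`).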
        execute_with_dtype!(
            int(tensor.dtype),
            E,
            execute_with_dtype!(
                int(indices.dtype),
                I,
                kernel::gather::<R, E, I>(dim, tensor, indices)
            )
        )
    }

    fn int_scatter(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            E,
            execute_with_dtype!(
                int(indices.dtype),
                I,
                kernel::scatter::<R, E, I>(dim, tensor, indices, value)
            )
        )
    }

    fn int_select(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            E,
            execute_with_dtype!(
                int(indices.dtype),
                I,
                kernel::select::<R, E, I>(tensor, dim, indices)
            )
        )
    }

    fn int_select_assign(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            E,
            execute_with_dtype!(
                int(indices.dtype),
                I,
                kernel::select_assign::<R, E, I>(tensor, dim, indices, value, false)
            )
        )
    }

    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, kernel::equal::<R, I, BT>(lhs, rhs))
    }

    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            kernel::equal_elem::<R, I, BT>(lhs, rhs.elem())
        )
    }

    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, kernel::greater::<R, I, BT>(lhs, rhs))
    }

    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            kernel::greater_elem::<R, I, BT>(lhs, rhs.elem())
        )
    }

    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            kernel::greater_equal::<R, I, BT>(lhs, rhs)
        )
    }

    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            kernel::greater_equal_elem::<R, I, BT>(lhs, rhs.elem())
        )
    }

    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, kernel::lower::<R, I, BT>(lhs, rhs))
    }

    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            kernel::lower_elem::<R, I, BT>(lhs, rhs.elem())
        )
    }

    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, kernel::lower_equal::<R, I, BT>(lhs, rhs))
    }

    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            kernel::lower_equal_elem::<R, I, BT>(lhs, rhs.elem())
        )
    }

    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::add::<R, I>(lhs, rhs))
    }

    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::add_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::sub::<R, I>(lhs, rhs))
    }

    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::sub_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::mul::<R, I>(lhs, rhs))
    }

    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::mul_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::div::<R, I>(lhs, rhs))
    }

    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::div_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::remainder::<R, I>(lhs, rhs))
    }

    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::remainder_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(int(dtype), I, numeric::zeros::<R, I>(shape, device))
    }

    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(int(dtype), I, numeric::ones::<R, I>(shape, device))
    }

    fn int_full(
        shape: Shape,
        fill_value: IntElem<Self>,
        device: &Device<Self>,
        dtype: IntDType,
    ) -> IntTensor<Self> {
        let dtype = dtype.into();
        execute_with_dtype!(
            int(dtype),
            I,
            numeric::full::<R, I>(shape, device, fill_value.elem())
        )
    }

    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::sum_fallback::<R, I>(tensor, Default::default()).unwrap()
        )
    }

    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
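        // Sum, prod, and mean reductions run with `<I as ReducePrecision>::EA`
        // as the reduce precision; min/max/arg reductions use `I` directly.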
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, <I as ReducePrecision>::EA>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::Sum,
            )
            .unwrap()
        )
    }

    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce::<R, I, I, <I as ReducePrecision>::EA>(
                tensor,
                Default::default(),
                ReduceFnConfig::Prod,
            )
            .unwrap()
        )
    }

    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, <I as ReducePrecision>::EA>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::Prod,
            )
            .unwrap()
        )
    }

    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce::<R, I, I, I>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
        )
    }

    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Max)
                .unwrap()
        )
    }

    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce::<R, I, I, I>(tensor, Default::default(), ReduceFnConfig::MaxAbs)
                .unwrap()
        )
    }

    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, I>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::MaxAbs
            )
            .unwrap()
        )
    }

    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce::<R, I, I, I>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
        )
    }

    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Min)
                .unwrap()
        )
    }

    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, <I as ReducePrecision>::EA>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::Mean,
            )
            .unwrap()
        )
    }

    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(int(tensor.dtype), I, numeric::cumsum::<R, I>(tensor, dim))
    }

    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(int(tensor.dtype), I, numeric::cumprod::<R, I>(tensor, dim))
    }

    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(int(tensor.dtype), I, numeric::cummin::<R, I>(tensor, dim))
    }

    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(int(tensor.dtype), I, numeric::cummax::<R, I>(tensor, dim))
    }

    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, I>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::ArgMax
            )
            .unwrap()
        )
    }

    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            reduce::reduce_dim::<R, I, I, I>(
                tensor,
                dim,
                Default::default(),
                ReduceFnConfig::ArgMin
            )
            .unwrap()
        )
    }

    fn int_clamp(
        tensor: IntTensor<Self>,
        min: IntElem<Self>,
        max: IntElem<Self>,
    ) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            kernel::clamp::<R, I>(tensor, min.elem(), max.elem())
        )
    }

    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
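        // The absolute-value op is defined inline as a numeric unary op and
        // launched through the generic unary-numeric kernel.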
        struct Abs;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Abs {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                Line::abs(input)
            }
        }

        impl NumericUnaryOpFamily for Abs {
            type Options<N: Numeric> = ();
            type Unary<N: Numeric> = Self;
        }

        execute_with_dtype!(
            int(tensor.dtype),
            I,
            launch_unary_numeric::<R, I, Abs, _>(tensor, |_| ())
        )
    }

    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
        execute_with_dtype!(int(tensor.dtype), I, kernel::cast::<R, I, F>(tensor))
    }

    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
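        // Swapping dimensions only permutes the stride and shape metadata; no
        // data is copied.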
        tensor.strides.swap(dim1, dim2);
        tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();

        tensor
    }

    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            kernel::repeat_dim::<R, I>(tensor, dim, times)
        )
    }

    fn int_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<Self>,
    ) -> IntTensor<Self> {
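        // `Distribution::Default` draws uniform integer values with bounds 0
        // and 255.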
        match distribution {
            Distribution::Default => random_uniform(shape, device, 0.elem::<I>(), 255.elem()),
            Distribution::Uniform(low, high) => {
                random_uniform(shape, device, low.elem::<I>(), high.elem())
            }
            Distribution::Bernoulli(prob) => random_bernoulli::<R, I>(shape, device, prob as f32),
            Distribution::Normal(mean, std) => {
                random_normal(shape, device, mean.elem::<I>(), std.elem())
            }
        }
    }

    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        permute(tensor, axes)
    }

    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        expand(tensor, shape)
    }

    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        execute_with_dtype!(int(tensor.dtype), I, kernel::flip::<R, I, BT>(tensor, axes))
    }

    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::bitwise_and::<R, I>(lhs, rhs))
    }

    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::bitwise_and_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::bitwise_or::<R, I>(lhs, rhs))
    }

    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::bitwise_or_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(int(lhs.dtype), I, numeric::bitwise_xor::<R, I>(lhs, rhs))
    }

    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            numeric::bitwise_xor_scalar::<R, I>(lhs, rhs.elem())
        )
    }

    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(tensor.dtype),
            I,
            unary_basic_int::launch::<R, _, I>(tensor, |_| BasicIntUnaryKind::BitwiseNot)
        )
    }

    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
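        // Shifts reuse the generic integer binary/scalar launchers,
        // parameterized by the shift operator type.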
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            launch_binop_int::<R, I, BitwiseShlOp>(lhs, rhs)
        )
    }

    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            launch_scalar_binop_int::<R, I, BitwiseShlOp>(lhs, rhs.elem())
        )
    }

    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            launch_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
        )
    }

    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        execute_with_dtype!(
            int(lhs.dtype),
            I,
            launch_scalar_binop_int::<R, I, BitwiseShrOp>(lhs, rhs.elem())
        )
    }

    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
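        // Early return when the tensor already has the requested dtype.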
        if tensor.dtype == dtype.into() {
            return tensor;
        }

        execute_with_dtype!(
            int(tensor.dtype),
            I,
            match dtype {
                IntDType::I64 => kernel::cast::<R, I, i64>(tensor),
                IntDType::I32 => kernel::cast::<R, I, i32>(tensor),
                IntDType::I16 => kernel::cast::<R, I, i16>(tensor),
                IntDType::I8 => kernel::cast::<R, I, i8>(tensor),
                IntDType::U64 => kernel::cast::<R, I, u64>(tensor),
                IntDType::U32 => kernel::cast::<R, I, u32>(tensor),
                IntDType::U16 => kernel::cast::<R, I, u16>(tensor),
                IntDType::U8 => kernel::cast::<R, I, u8>(tensor),
            }
        )
    }

    fn int_unfold(
        tensor: IntTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Self> {
        unfold(tensor, dim, size, step)
    }
}