burn_cubecl/ops/
int_tensor.rs

1use self::unary_basic_int::BasicIntUnaryKind;
2
3use super::{expand, numeric, permute, unfold};
4use crate::kernel::{
5    BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
6    launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
7};
8use crate::{
9    CubeBackend, CubeRuntime, FloatElement, IntElement,
10    kernel::{
11        self,
12        matmul::{MatmulStrategy, matmul},
13    },
14};
15use crate::{
16    element::BoolElement,
17    kernel::prng::{random_bernoulli, random_normal, random_uniform},
18};
19use burn_backend::ExecutionError;
20use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
21use burn_backend::{DType, IntDType, Slice, ops::IntTensorOps};
22use burn_backend::{Distribution, ElementConversion, Shape, TensorData};
23use cubecl::frontend::Numeric;
24use cubecl::prelude::*;
25use cubek::reduce::components::instructions::ReduceOperationConfig;
26use std::ops::Range;
27
28impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
29where
30    R: CubeRuntime,
31    F: FloatElement,
32    I: IntElement,
33    BT: BoolElement,
34{
35    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
36        let dtype = dtype.into();
37        super::empty(shape, device, dtype)
38    }
39
40    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
41        super::into_data(tensor).await
42    }
43
44    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
45        match data.dtype {
46            DType::I64
47            | DType::I32
48            | DType::I16
49            | DType::I8
50            | DType::U64
51            | DType::U32
52            | DType::U16
53            | DType::U8 => super::from_data(data, device),
54            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
55        }
56    }
57
58    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
59        tensor.device.clone()
60    }
61
62    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
63        super::to_device(tensor, device)
64    }
65
66    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
67        super::reshape(tensor, shape)
68    }
69
70    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
71        // Check if all steps are 1
72        let all_steps_one = slices.iter().all(|info| info.step == 1);
73
74        if all_steps_one {
75            // Use optimized slice for step=1
76            let simple_ranges: Vec<Range<usize>> = slices
77                .iter()
78                .enumerate()
79                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
80                .collect();
81
82            kernel::slice(tensor, &simple_ranges)
83        } else {
84            // Use slice with steps kernel
85            kernel::slice_with_steps(tensor, slices)
86        }
87    }
88
89    fn int_slice_assign(
90        tensor: IntTensor<Self>,
91        ranges: &[Slice],
92        value: IntTensor<Self>,
93    ) -> IntTensor<Self> {
94        kernel::slice_assign(tensor, ranges, value)
95    }
96
97    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
98        let dtype = lhs.dtype;
99        matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
100    }
101
102    fn int_mask_where(
103        tensor: IntTensor<Self>,
104        mask: BoolTensor<Self>,
105        value: IntTensor<Self>,
106    ) -> IntTensor<Self> {
107        kernel::mask_where_auto(tensor, mask, value, BT::dtype())
108    }
109
110    fn int_mask_fill(
111        tensor: IntTensor<Self>,
112        mask: BoolTensor<Self>,
113        value: IntElem<Self>,
114    ) -> IntTensor<Self> {
115        let dtype = tensor.dtype;
116        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), BT::dtype())
117    }
118
119    fn int_gather(
120        dim: usize,
121        tensor: IntTensor<Self>,
122        indices: IntTensor<Self>,
123    ) -> IntTensor<Self> {
124        kernel::gather(dim, tensor, indices)
125    }
126
127    fn int_scatter_add(
128        dim: usize,
129        tensor: IntTensor<Self>,
130        indices: IntTensor<Self>,
131        value: IntTensor<Self>,
132    ) -> IntTensor<Self> {
133        kernel::scatter(dim, tensor, indices, value, false)
134    }
135
136    fn int_select(
137        tensor: IntTensor<Self>,
138        dim: usize,
139        indices: IntTensor<Self>,
140    ) -> IntTensor<Self> {
141        kernel::select(tensor, dim, indices)
142    }
143
144    fn int_select_add(
145        tensor: IntTensor<Self>,
146        dim: usize,
147        indices: IntTensor<Self>,
148        value: IntTensor<Self>,
149    ) -> IntTensor<Self> {
150        kernel::select_assign(tensor, dim, indices, value, false)
151    }
152
153    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
154        kernel::equal(lhs, rhs, BT::dtype())
155    }
156
157    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
158        let dtype = lhs.dtype;
159        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
160    }
161
162    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
163        kernel::greater(lhs, rhs, BT::dtype())
164    }
165
166    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
167        let dtype = lhs.dtype;
168        kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
169    }
170
171    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
172        kernel::greater_equal(lhs, rhs, BT::dtype())
173    }
174
175    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
176        let dtype = lhs.dtype;
177        kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
178    }
179
180    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
181        kernel::lower(lhs, rhs, BT::dtype())
182    }
183
184    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
185        let dtype = lhs.dtype;
186        kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
187    }
188
189    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
190        kernel::lower_equal(lhs, rhs, BT::dtype())
191    }
192
193    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
194        let dtype = lhs.dtype;
195        kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
196    }
197
198    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
199        numeric::add(lhs, rhs)
200    }
201
202    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
203        let dtype = lhs.dtype;
204        numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
205    }
206
207    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
208        numeric::sub(lhs, rhs)
209    }
210
211    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
212        let dtype = lhs.dtype;
213        numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
214    }
215
216    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
217        numeric::mul(lhs, rhs)
218    }
219
220    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
221        let dtype = lhs.dtype;
222        numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
223    }
224
225    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
226        numeric::div(lhs, rhs)
227    }
228
229    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
230        let dtype = lhs.dtype;
231        numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
232    }
233
234    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
235        numeric::remainder(lhs, rhs)
236    }
237
238    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
239        let dtype = lhs.dtype;
240        numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
241    }
242
243    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
244        let dtype = dtype.into();
245        numeric::zeros(device.clone(), shape, dtype)
246    }
247
248    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
249        let dtype = dtype.into();
250        numeric::ones(device.clone(), shape, dtype)
251    }
252
253    fn int_full(
254        shape: Shape,
255        fill_value: IntElem<Self>,
256        device: &Device<Self>,
257        dtype: IntDType,
258    ) -> IntTensor<Self> {
259        let dtype: DType = dtype.into();
260        let client = R::client(device);
261        numeric::full_device_dtype(
262            client,
263            shape,
264            device.clone(),
265            InputScalar::new(fill_value, dtype),
266            dtype,
267        )
268    }
269
270    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
271        reduce::sum_fallback(tensor, Default::default()).unwrap()
272    }
273
274    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
275        reduce::reduce_dim(
276            tensor,
277            None,
278            dim,
279            Default::default(),
280            ReduceOperationConfig::Sum,
281        )
282        .unwrap()
283    }
284
285    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
286        reduce::reduce(
287            tensor,
288            None,
289            Default::default(),
290            ReduceOperationConfig::Prod,
291        )
292        .unwrap()
293    }
294
295    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
296        reduce::reduce_dim(
297            tensor,
298            None,
299            dim,
300            Default::default(),
301            ReduceOperationConfig::Prod,
302        )
303        .unwrap()
304    }
305
306    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
307        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
308    }
309
310    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
311        reduce::reduce_dim(
312            tensor,
313            None,
314            dim,
315            Default::default(),
316            ReduceOperationConfig::Max,
317        )
318        .unwrap()
319    }
320
321    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
322        reduce::reduce(
323            tensor,
324            None,
325            Default::default(),
326            ReduceOperationConfig::MaxAbs,
327        )
328        .unwrap()
329    }
330
331    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
332        reduce::reduce_dim(
333            tensor,
334            None,
335            dim,
336            Default::default(),
337            ReduceOperationConfig::MaxAbs,
338        )
339        .unwrap()
340    }
341
342    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
343        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
344    }
345
346    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
347        reduce::reduce_dim(
348            tensor,
349            None,
350            dim,
351            Default::default(),
352            ReduceOperationConfig::Min,
353        )
354        .unwrap()
355    }
356
357    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
358        reduce::reduce_dim(
359            tensor,
360            None,
361            dim,
362            Default::default(),
363            ReduceOperationConfig::Mean,
364        )
365        .unwrap()
366    }
367
368    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
369        numeric::cumsum(tensor, dim)
370    }
371
372    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
373        numeric::cumprod(tensor, dim)
374    }
375
376    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
377        numeric::cummin(tensor, dim)
378    }
379
380    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
381        numeric::cummax(tensor, dim)
382    }
383
384    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
385        let dtype = tensor.dtype;
386        reduce::reduce_dim(
387            tensor,
388            Some(dtype),
389            dim,
390            Default::default(),
391            ReduceOperationConfig::ArgMax,
392        )
393        .unwrap()
394    }
395
396    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
397        let dtype = tensor.dtype;
398        reduce::reduce_dim(
399            tensor,
400            Some(dtype),
401            dim,
402            Default::default(),
403            ReduceOperationConfig::ArgMin,
404        )
405        .unwrap()
406    }
407
408    fn int_clamp(
409        tensor: IntTensor<Self>,
410        min: IntElem<Self>,
411        max: IntElem<Self>,
412    ) -> IntTensor<Self> {
413        let dtype = tensor.dtype;
414        kernel::clamp(
415            tensor,
416            InputScalar::new(min, dtype),
417            InputScalar::new(max, dtype),
418        )
419    }
420
421    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
422        struct Abs;
423
424        #[cube]
425        impl<N: Numeric> NumericUnaryOp<N> for Abs {
426            type Options = ();
427
428            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
429                Line::abs(input)
430            }
431        }
432
433        impl NumericUnaryOpFamily for Abs {
434            type Options = ();
435            type Unary<N: Numeric> = Self;
436        }
437
438        launch_unary_numeric::<R, Abs, _>(tensor, |_| ())
439    }
440
441    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
442        kernel::cast(tensor, F::dtype())
443    }
444
445    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
446        tensor.strides.swap(dim1, dim2);
447        tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();
448
449        tensor
450    }
451
452    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
453        kernel::repeat_dim(tensor, dim, times)
454    }
455
456    fn int_random(
457        shape: Shape,
458        distribution: Distribution,
459        device: &Device<Self>,
460    ) -> IntTensor<Self> {
461        let dtype = IntElem::<Self>::dtype();
462        match distribution {
463            Distribution::Default => random_uniform(shape, device, 0., 255., dtype),
464            Distribution::Uniform(low, high) => {
465                random_uniform(shape, device, low.elem(), high.elem(), dtype)
466            }
467            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
468            Distribution::Normal(mean, std) => {
469                random_normal(shape, device, mean.elem(), std.elem(), dtype)
470            }
471        }
472    }
473
474    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
475        permute(tensor, axes)
476    }
477
478    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
479        expand(tensor, shape)
480    }
481
482    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
483        kernel::flip(tensor, axes, BT::dtype())
484    }
485
486    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
487        numeric::bitwise_and(lhs, rhs)
488    }
489
490    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
491        let dtype = lhs.dtype;
492        numeric::bitwise_and_scalar(lhs, InputScalar::new(rhs, dtype))
493    }
494
495    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
496        numeric::bitwise_or(lhs, rhs)
497    }
498
499    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
500        let dtype = lhs.dtype;
501        numeric::bitwise_or_scalar(lhs, InputScalar::new(rhs, dtype))
502    }
503
504    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
505        numeric::bitwise_xor(lhs, rhs)
506    }
507
508    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
509        let dtype = lhs.dtype;
510        numeric::bitwise_xor_scalar(lhs, InputScalar::new(rhs, dtype))
511    }
512
513    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
514        unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::BitwiseNot)
515    }
516
517    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
518        launch_binop_int::<R, kernel::BitwiseShlOp>(lhs, rhs)
519    }
520
521    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
522        let dtype = lhs.dtype;
523        launch_scalar_binop_int::<R, BitwiseShlOp>(lhs, InputScalar::new(rhs, dtype))
524    }
525
526    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
527        launch_binop_int::<R, BitwiseShrOp>(lhs, rhs)
528    }
529
530    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
531        let dtype = lhs.dtype;
532        launch_scalar_binop_int::<R, BitwiseShrOp>(lhs, InputScalar::new(rhs, dtype))
533    }
534
535    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
536        kernel::cast(tensor, dtype.into())
537    }
538
539    fn int_unfold(
540        tensor: FloatTensor<Self>,
541        dim: usize,
542        size: usize,
543        step: usize,
544    ) -> FloatTensor<Self> {
545        unfold(tensor, dim, size, step)
546    }
547}