//! Integer tensor operations ([`IntTensorOps`]) for the CubeCL backend.
//!
//! File: `burn_cubecl/ops/int_tensor.rs`
use self::unary_basic_int::BasicIntUnaryKind;

use super::{expand, numeric, permute, unfold};
use crate::kernel::{
    BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
    launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
};
use crate::{
    CubeBackend, CubeRuntime, FloatElement, IntElement,
    kernel::{
        self,
        matmul::{MatmulStrategy, matmul},
    },
};
use crate::{
    element::BoolElement,
    kernel::prng::{random_bernoulli, random_normal, random_uniform},
};
use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
use burn_backend::{DType, IntDType, Slice, ops::IntTensorOps};
use burn_backend::{Distribution, ElementConversion, Shape, TensorData, get_device_settings};
use burn_backend::{ExecutionError, Scalar};
use burn_std::{BoolDType, FloatDType};
use cubecl::frontend::Numeric;
use cubecl::prelude::*;
use cubek::reduce::components::instructions::ReduceOperationConfig;
use std::ops::Range;

29impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
30where
31    R: CubeRuntime,
32    F: FloatElement,
33    I: IntElement,
34    BT: BoolElement,
35{
36    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
37        let dtype = dtype.into();
38        super::empty(shape, device, dtype)
39    }
40
41    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
42        super::into_data(tensor).await
43    }
44
45    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
46        match data.dtype {
47            DType::I64
48            | DType::I32
49            | DType::I16
50            | DType::I8
51            | DType::U64
52            | DType::U32
53            | DType::U16
54            | DType::U8 => super::from_data(data, device),
55            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
56        }
57    }
58
59    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
60        tensor.device.clone()
61    }
62
63    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
64        super::to_device(tensor, device)
65    }
66
67    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
68        super::reshape(tensor, shape)
69    }
70
71    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
72        // Check if all steps are 1
73        let all_steps_one = slices.iter().all(|info| info.step == 1);
74
75        if all_steps_one {
76            // Use optimized slice for step=1
77            let simple_ranges: Vec<Range<usize>> = slices
78                .iter()
79                .enumerate()
80                .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
81                .collect();
82
83            kernel::slice(tensor, &simple_ranges)
84        } else {
85            // Use slice with steps kernel
86            kernel::slice_with_steps(tensor, slices)
87        }
88    }
89
90    fn int_slice_assign(
91        tensor: IntTensor<Self>,
92        ranges: &[Slice],
93        value: IntTensor<Self>,
94    ) -> IntTensor<Self> {
95        kernel::slice_assign(tensor, ranges, value)
96    }
97
98    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
99        let dtype = lhs.dtype;
100        matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
101    }
102
103    fn int_mask_where(
104        tensor: IntTensor<Self>,
105        mask: BoolTensor<Self>,
106        value: IntTensor<Self>,
107    ) -> IntTensor<Self> {
108        let bool_dtype = mask.dtype;
109        kernel::mask_where_auto(tensor, mask, value, bool_dtype)
110    }
111
112    fn int_mask_fill(
113        tensor: IntTensor<Self>,
114        mask: BoolTensor<Self>,
115        value: Scalar,
116    ) -> IntTensor<Self> {
117        let dtype = tensor.dtype;
118        let bool_dtype = mask.dtype;
119        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), bool_dtype)
120    }
121
122    fn int_gather(
123        dim: usize,
124        tensor: IntTensor<Self>,
125        indices: IntTensor<Self>,
126    ) -> IntTensor<Self> {
127        kernel::gather(dim, tensor, indices)
128    }
129
130    fn int_scatter_add(
131        dim: usize,
132        tensor: IntTensor<Self>,
133        indices: IntTensor<Self>,
134        value: IntTensor<Self>,
135    ) -> IntTensor<Self> {
136        kernel::scatter(dim, tensor, indices, value, false)
137    }
138
139    fn int_scatter_nd(
140        data: IntTensor<Self>,
141        indices: IntTensor<Self>,
142        values: IntTensor<Self>,
143        reduction: burn_backend::tensor::IndexingUpdateOp,
144    ) -> IntTensor<Self> {
145        kernel::scatter_nd(data, indices, values, reduction)
146    }
147
148    fn int_gather_nd(data: IntTensor<Self>, indices: IntTensor<Self>) -> IntTensor<Self> {
149        kernel::gather_nd(data, indices)
150    }
151
152    fn int_select(
153        tensor: IntTensor<Self>,
154        dim: usize,
155        indices: IntTensor<Self>,
156    ) -> IntTensor<Self> {
157        kernel::select(tensor, dim, indices)
158    }
159
160    fn int_select_add(
161        tensor: IntTensor<Self>,
162        dim: usize,
163        indices: IntTensor<Self>,
164        value: IntTensor<Self>,
165    ) -> IntTensor<Self> {
166        kernel::select_assign(tensor, dim, indices, value, false)
167    }
168
169    fn int_equal(
170        lhs: IntTensor<Self>,
171        rhs: IntTensor<Self>,
172        out_dtype: BoolDType,
173    ) -> BoolTensor<Self> {
174        kernel::equal(lhs, rhs, out_dtype.into())
175    }
176
177    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
178        let dtype = lhs.dtype;
179        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
180    }
181
182    fn int_greater(
183        lhs: IntTensor<Self>,
184        rhs: IntTensor<Self>,
185        out_dtype: BoolDType,
186    ) -> BoolTensor<Self> {
187        kernel::greater(lhs, rhs, out_dtype.into())
188    }
189
190    fn int_greater_elem(
191        lhs: IntTensor<Self>,
192        rhs: Scalar,
193        out_dtype: BoolDType,
194    ) -> BoolTensor<Self> {
195        let dtype = lhs.dtype;
196        kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
197    }
198
199    fn int_greater_equal(
200        lhs: IntTensor<Self>,
201        rhs: IntTensor<Self>,
202        out_dtype: BoolDType,
203    ) -> BoolTensor<Self> {
204        kernel::greater_equal(lhs, rhs, out_dtype.into())
205    }
206
207    fn int_greater_equal_elem(
208        lhs: IntTensor<Self>,
209        rhs: Scalar,
210        out_dtype: BoolDType,
211    ) -> BoolTensor<Self> {
212        let dtype = lhs.dtype;
213        kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
214    }
215
216    fn int_lower(
217        lhs: IntTensor<Self>,
218        rhs: IntTensor<Self>,
219        out_dtype: BoolDType,
220    ) -> BoolTensor<Self> {
221        kernel::lower(lhs, rhs, out_dtype.into())
222    }
223
224    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
225        let dtype = lhs.dtype;
226        kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
227    }
228
229    fn int_lower_equal(
230        lhs: IntTensor<Self>,
231        rhs: IntTensor<Self>,
232        out_dtype: BoolDType,
233    ) -> BoolTensor<Self> {
234        kernel::lower_equal(lhs, rhs, out_dtype.into())
235    }
236
237    fn int_lower_equal_elem(
238        lhs: IntTensor<Self>,
239        rhs: Scalar,
240        out_dtype: BoolDType,
241    ) -> BoolTensor<Self> {
242        let dtype = lhs.dtype;
243        kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
244    }
245
246    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
247        numeric::add(lhs, rhs)
248    }
249
250    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
251        let dtype = lhs.dtype;
252        numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
253    }
254
255    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
256        numeric::sub(lhs, rhs)
257    }
258
259    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
260        let dtype = lhs.dtype;
261        numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
262    }
263
264    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
265        numeric::mul(lhs, rhs)
266    }
267
268    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
269        let dtype = lhs.dtype;
270        numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
271    }
272
273    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
274        numeric::div(lhs, rhs)
275    }
276
277    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
278        let dtype = lhs.dtype;
279        numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
280    }
281
282    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
283        numeric::remainder(lhs, rhs)
284    }
285
286    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
287        let dtype = lhs.dtype;
288        numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
289    }
290
291    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
292        let dtype = dtype.into();
293        numeric::zeros(device.clone(), shape, dtype)
294    }
295
296    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
297        let dtype = dtype.into();
298        numeric::ones(device.clone(), shape, dtype)
299    }
300
301    fn int_full(
302        shape: Shape,
303        fill_value: Scalar,
304        device: &Device<Self>,
305        dtype: IntDType,
306    ) -> IntTensor<Self> {
307        let dtype: DType = dtype.into();
308        let client = R::client(device);
309        numeric::full_device_dtype(
310            client,
311            shape,
312            device.clone(),
313            InputScalar::new(fill_value, dtype),
314            dtype,
315        )
316    }
317
318    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
319        reduce::sum_fallback(tensor, Default::default()).unwrap()
320    }
321
322    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
323        reduce::reduce_dim(
324            tensor,
325            None,
326            dim,
327            Default::default(),
328            ReduceOperationConfig::Sum,
329        )
330        .unwrap()
331    }
332
333    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
334        reduce::reduce(
335            tensor,
336            None,
337            Default::default(),
338            ReduceOperationConfig::Prod,
339        )
340        .unwrap()
341    }
342
343    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
344        reduce::reduce_dim(
345            tensor,
346            None,
347            dim,
348            Default::default(),
349            ReduceOperationConfig::Prod,
350        )
351        .unwrap()
352    }
353
354    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
355        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
356    }
357
358    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
359        reduce::reduce_dim(
360            tensor,
361            None,
362            dim,
363            Default::default(),
364            ReduceOperationConfig::Max,
365        )
366        .unwrap()
367    }
368
369    fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
370        reduce::reduce_dim(
371            tensor,
372            None,
373            dim,
374            Default::default(),
375            ReduceOperationConfig::TopK(k),
376        )
377        .unwrap()
378    }
379
380    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
381        reduce::reduce(
382            tensor,
383            None,
384            Default::default(),
385            ReduceOperationConfig::MaxAbs,
386        )
387        .unwrap()
388    }
389
390    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
391        reduce::reduce_dim(
392            tensor,
393            None,
394            dim,
395            Default::default(),
396            ReduceOperationConfig::MaxAbs,
397        )
398        .unwrap()
399    }
400
401    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
402        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
403    }
404
405    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
406        reduce::reduce_dim(
407            tensor,
408            None,
409            dim,
410            Default::default(),
411            ReduceOperationConfig::Min,
412        )
413        .unwrap()
414    }
415
416    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
417        reduce::reduce_dim(
418            tensor,
419            None,
420            dim,
421            Default::default(),
422            ReduceOperationConfig::Mean,
423        )
424        .unwrap()
425    }
426
427    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
428        numeric::cumsum(tensor, dim)
429    }
430
431    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
432        numeric::cumprod(tensor, dim)
433    }
434
435    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
436        numeric::cummin(tensor, dim)
437    }
438
439    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
440        numeric::cummax(tensor, dim)
441    }
442
443    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
444        let dtype = tensor.dtype;
445        reduce::reduce_dim(
446            tensor,
447            Some(dtype),
448            dim,
449            Default::default(),
450            ReduceOperationConfig::ArgMax,
451        )
452        .unwrap()
453    }
454
455    fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
456        let dtype = tensor.dtype;
457        reduce::reduce_dim(
458            tensor,
459            Some(dtype),
460            dim,
461            Default::default(),
462            ReduceOperationConfig::ArgTopK(k),
463        )
464        .unwrap()
465    }
466
467    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
468        let dtype = tensor.dtype;
469        reduce::reduce_dim(
470            tensor,
471            Some(dtype),
472            dim,
473            Default::default(),
474            ReduceOperationConfig::ArgMin,
475        )
476        .unwrap()
477    }
478
479    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
480        let dtype = tensor.dtype;
481        kernel::clamp(
482            tensor,
483            InputScalar::new(min, dtype),
484            InputScalar::new(max, dtype),
485        )
486    }
487
488    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
489        struct Abs;
490
491        #[cube]
492        impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Abs {
493            type Options = ();
494
495            fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {
496                Vector::abs(input)
497            }
498        }
499
500        impl NumericUnaryOpFamily for Abs {
501            type Options = ();
502            type Unary<T: Numeric, N: Size> = Self;
503        }
504
505        launch_unary_numeric::<R, Abs, _>(tensor, |_| ())
506    }
507
508    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
509        unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::Sign)
510    }
511
512    fn int_into_float(tensor: IntTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
513        kernel::cast(tensor, out_dtype.into())
514    }
515
516    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
517        tensor.meta.swap(dim1, dim2);
518
519        tensor
520    }
521
522    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
523        kernel::repeat_dim(tensor, dim, times)
524    }
525
526    fn int_random(
527        shape: Shape,
528        distribution: Distribution,
529        device: &Device<Self>,
530        dtype: IntDType,
531    ) -> IntTensor<Self> {
532        let dtype = dtype.into();
533        match distribution {
534            Distribution::Default => random_uniform(shape, device, 0., 255., dtype),
535            Distribution::Uniform(low, high) => {
536                random_uniform(shape, device, low.elem(), high.elem(), dtype)
537            }
538            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
539            Distribution::Normal(mean, std) => {
540                random_normal(shape, device, mean.elem(), std.elem(), dtype)
541            }
542        }
543    }
544
545    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
546        permute(tensor, axes)
547    }
548
549    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
550        expand(tensor, shape)
551    }
552
553    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
554        let bool_dtype = get_device_settings::<Self>(&tensor.device).bool_dtype;
555        kernel::flip(tensor, axes, bool_dtype.into())
556    }
557
558    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
559        numeric::bitwise_and(lhs, rhs)
560    }
561
562    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
563        let dtype = lhs.dtype;
564        numeric::bitwise_and_scalar(lhs, InputScalar::new(rhs, dtype))
565    }
566
567    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
568        numeric::bitwise_or(lhs, rhs)
569    }
570
571    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
572        let dtype = lhs.dtype;
573        numeric::bitwise_or_scalar(lhs, InputScalar::new(rhs, dtype))
574    }
575
576    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
577        numeric::bitwise_xor(lhs, rhs)
578    }
579
580    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
581        let dtype = lhs.dtype;
582        numeric::bitwise_xor_scalar(lhs, InputScalar::new(rhs, dtype))
583    }
584
585    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
586        unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::BitwiseNot)
587    }
588
589    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
590        launch_binop_int::<R, kernel::BitwiseShlOp>(lhs, rhs)
591    }
592
593    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
594        let dtype = lhs.dtype;
595        launch_scalar_binop_int::<R, BitwiseShlOp>(lhs, InputScalar::new(rhs, dtype))
596    }
597
598    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
599        launch_binop_int::<R, BitwiseShrOp>(lhs, rhs)
600    }
601
602    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
603        let dtype = lhs.dtype;
604        launch_scalar_binop_int::<R, BitwiseShrOp>(lhs, InputScalar::new(rhs, dtype))
605    }
606
607    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
608        kernel::cast(tensor, dtype.into())
609    }
610
611    fn int_unfold(
612        tensor: FloatTensor<Self>,
613        dim: usize,
614        size: usize,
615        step: usize,
616    ) -> FloatTensor<Self> {
617        unfold(tensor, dim, size, step)
618    }
619
620    // TODO
621    // fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
622    //     todo!()
623    // }
624
625    // fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
626    //     todo!()
627    // }
628}