// burn_cubecl/ops/int_ops.rs

1use self::unary_basic_int::BasicIntUnaryKind;
2
3use super::{expand, numeric, permute};
4use crate::kernel::{
5    BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
6    launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
7};
8use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel};
9use crate::{
10    element::BoolElement,
11    kernel::prng::{random_bernoulli, random_normal, random_uniform},
12};
13use burn_tensor::DType;
14use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
15use burn_tensor::{Distribution, ElementConversion, Shape, TensorData, ops::IntTensorOps};
16use cubecl::frontend::Numeric;
17use cubecl::prelude::*;
18use cubecl::reduce::instructions::ReduceFnConfig;
19use std::ops::Range;
20
21impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
22where
23    R: CubeRuntime,
24    F: FloatElement,
25    I: IntElement,
26    BT: BoolElement,
27{
28    fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
29        super::empty::<R, I>(shape, device)
30    }
31
32    async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
33        super::into_data::<R, I>(tensor).await
34    }
35
36    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
37        match data.dtype {
38            DType::I64 | DType::I32 | DType::I16 | DType::I8 | DType::U32 => {
39                super::from_data::<R>(data, device)
40            }
41            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
42        }
43    }
44
45    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
46        tensor.device.clone()
47    }
48
49    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
50        super::to_device(tensor, device)
51    }
52
53    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
54        super::reshape(tensor, shape)
55    }
56
57    fn int_slice(tensor: IntTensor<Self>, ranges: &[Range<usize>]) -> IntTensor<Self> {
58        kernel::slice::<R, I>(tensor, ranges)
59    }
60
61    fn int_slice_assign(
62        tensor: IntTensor<Self>,
63        ranges: &[Range<usize>],
64        value: IntTensor<Self>,
65    ) -> IntTensor<Self> {
66        kernel::slice_assign::<R, I>(tensor, ranges, value)
67    }
68
69    fn int_mask_where(
70        tensor: IntTensor<Self>,
71        mask: BoolTensor<Self>,
72        value: IntTensor<Self>,
73    ) -> IntTensor<Self> {
74        kernel::mask_where_auto::<R, I, BT>(tensor, mask, value)
75    }
76
77    fn int_mask_fill(
78        tensor: IntTensor<Self>,
79        mask: BoolTensor<Self>,
80        value: IntElem<Self>,
81    ) -> IntTensor<Self> {
82        kernel::mask_fill_auto::<R, I, BT>(tensor, mask, value)
83    }
84
85    fn int_gather(
86        dim: usize,
87        tensor: IntTensor<Self>,
88        indices: IntTensor<Self>,
89    ) -> IntTensor<Self> {
90        kernel::gather::<R, I, I>(dim, tensor, indices)
91    }
92
93    fn int_scatter(
94        dim: usize,
95        tensor: IntTensor<Self>,
96        indices: IntTensor<Self>,
97        value: IntTensor<Self>,
98    ) -> IntTensor<Self> {
99        kernel::scatter::<R, I, I>(dim, tensor, indices, value)
100    }
101
102    fn int_select(
103        tensor: IntTensor<Self>,
104        dim: usize,
105        indices: IntTensor<Self>,
106    ) -> IntTensor<Self> {
107        kernel::select::<R, I, I>(tensor, dim, indices)
108    }
109
110    fn int_select_assign(
111        tensor: IntTensor<Self>,
112        dim: usize,
113        indices: IntTensor<Self>,
114        value: IntTensor<Self>,
115    ) -> IntTensor<Self> {
116        kernel::select_assign::<R, I, I>(tensor, dim, indices, value)
117    }
118
119    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
120        kernel::equal::<R, I, BT>(lhs, rhs)
121    }
122
123    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
124        kernel::equal_elem::<R, I, BT>(lhs, rhs)
125    }
126
127    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
128        kernel::greater::<R, I, BT>(lhs, rhs)
129    }
130
131    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
132        kernel::greater_elem::<R, I, BT>(lhs, rhs)
133    }
134
135    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
136        kernel::greater_equal::<R, I, BT>(lhs, rhs)
137    }
138
139    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
140        kernel::greater_equal_elem::<R, I, BT>(lhs, rhs)
141    }
142
143    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
144        kernel::lower::<R, I, BT>(lhs, rhs)
145    }
146
147    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
148        kernel::lower_elem::<R, I, BT>(lhs, rhs)
149    }
150
151    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
152        kernel::lower_equal::<R, I, BT>(lhs, rhs)
153    }
154
155    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
156        kernel::lower_equal_elem::<R, I, BT>(lhs, rhs)
157    }
158
159    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
160        numeric::add::<R, I>(lhs, rhs)
161    }
162
163    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
164        numeric::add_scalar::<R, I>(lhs, rhs)
165    }
166
167    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
168        numeric::sub::<R, I>(lhs, rhs)
169    }
170
171    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
172        numeric::sub_scalar::<R, I>(lhs, rhs)
173    }
174
175    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
176        numeric::mul::<R, I>(lhs, rhs)
177    }
178
179    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
180        numeric::mul_scalar::<R, I>(lhs, rhs)
181    }
182
183    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
184        numeric::div::<R, I>(lhs, rhs)
185    }
186
187    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
188        numeric::div_scalar::<R, I>(lhs, rhs)
189    }
190
191    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
192        numeric::remainder::<R, I>(lhs, rhs)
193    }
194
195    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
196        numeric::remainder_scalar::<R, I>(lhs, rhs)
197    }
198
199    fn int_zeros(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
200        numeric::zeros::<R, I>(shape, device)
201    }
202
203    fn int_ones(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
204        numeric::ones::<R, I>(shape, device)
205    }
206
207    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
208        reduce::sum::<R, I>(tensor, Default::default()).unwrap()
209    }
210
211    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
212        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Sum).unwrap()
213    }
214
215    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
216        reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::Prod).unwrap()
217    }
218
219    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
220        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Prod)
221            .unwrap()
222    }
223
224    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
225        reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
226    }
227
228    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
229        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Max).unwrap()
230    }
231
232    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
233        reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::MaxAbs).unwrap()
234    }
235
236    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
237        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::MaxAbs)
238            .unwrap()
239    }
240
241    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
242        reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
243    }
244
245    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
246        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Min).unwrap()
247    }
248
249    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
250        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Mean)
251            .unwrap()
252    }
253
254    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
255        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMax)
256            .unwrap()
257    }
258
259    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
260        reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMin)
261            .unwrap()
262    }
263
264    fn int_clamp(
265        tensor: IntTensor<Self>,
266        min: IntElem<Self>,
267        max: IntElem<Self>,
268    ) -> IntTensor<Self> {
269        kernel::clamp::<R, I>(tensor, min, max)
270    }
271
272    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
273        struct Abs;
274
275        #[cube]
276        impl<N: Numeric> NumericUnaryOp<N> for Abs {
277            type Options = ();
278
279            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
280                Line::abs(input)
281            }
282        }
283
284        impl NumericUnaryOpFamily for Abs {
285            type Options<N: Numeric> = ();
286            type Unary<N: Numeric> = Self;
287        }
288
289        launch_unary_numeric::<R, I, Abs, _>(tensor, |_| ())
290    }
291
292    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
293        kernel::cast::<R, I, F>(tensor)
294    }
295
296    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
297        tensor.strides.swap(dim1, dim2);
298        tensor.shape.dims.swap(dim1, dim2);
299
300        tensor
301    }
302
303    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
304        kernel::repeat_dim::<R, I>(tensor, dim, times)
305    }
306
307    fn int_random(
308        shape: Shape,
309        distribution: Distribution,
310        device: &Device<Self>,
311    ) -> IntTensor<Self> {
312        let float_tensor = match distribution {
313            Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 255.elem()),
314            Distribution::Uniform(low, high) => {
315                random_uniform(shape, device, low.elem::<F>(), high.elem())
316            }
317            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
318            Distribution::Normal(mean, std) => {
319                random_normal(shape, device, mean.elem::<F>(), std.elem())
320            }
321        };
322
323        kernel::cast::<R, F, I>(float_tensor)
324    }
325
326    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
327        permute(tensor, axes)
328    }
329
330    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
331        expand(tensor, shape)
332    }
333
334    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
335        kernel::flip::<R, I, BT>(tensor, axes)
336    }
337
338    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
339        numeric::bitwise_and::<R, I>(lhs, rhs)
340    }
341
342    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
343        numeric::bitwise_and_scalar::<R, I>(lhs, rhs)
344    }
345
346    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
347        numeric::bitwise_or::<R, I>(lhs, rhs)
348    }
349
350    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
351        numeric::bitwise_or_scalar(lhs, rhs)
352    }
353
354    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
355        numeric::bitwise_xor::<R, I>(lhs, rhs)
356    }
357
358    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
359        numeric::bitwise_xor_scalar(lhs, rhs)
360    }
361
362    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
363        unary_basic_int::launch::<R, _, I>(tensor, |_| &BasicIntUnaryKind::BitwiseNot)
364    }
365
366    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
367        launch_binop_int::<R, I, kernel::BitwiseShlOp>(lhs, rhs)
368    }
369
370    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
371        launch_scalar_binop_int::<R, I, BitwiseShlOp>(lhs, rhs)
372    }
373
374    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
375        launch_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
376    }
377
378    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
379        launch_scalar_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
380    }
381}