1use self::unary_basic_int::BasicIntUnaryKind;
2
3use super::{expand, numeric, permute};
4use crate::kernel::{
5 BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
6 launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
7};
8use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel};
9use crate::{
10 element::BoolElement,
11 kernel::prng::{random_bernoulli, random_normal, random_uniform},
12};
13use burn_tensor::DType;
14use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
15use burn_tensor::{Distribution, ElementConversion, Shape, TensorData, ops::IntTensorOps};
16use cubecl::frontend::Numeric;
17use cubecl::prelude::*;
18use cubecl::reduce::instructions::ReduceFnConfig;
19use std::ops::Range;
20
21impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
22where
23 R: CubeRuntime,
24 F: FloatElement,
25 I: IntElement,
26 BT: BoolElement,
27{
28 fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
29 super::empty::<R, I>(shape, device)
30 }
31
32 async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
33 super::into_data::<R, I>(tensor).await
34 }
35
36 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
37 match data.dtype {
38 DType::I64 | DType::I32 | DType::I16 | DType::I8 | DType::U32 => {
39 super::from_data::<R>(data, device)
40 }
41 _ => unimplemented!("Unsupported dtype for `int_from_data`"),
42 }
43 }
44
45 fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
46 tensor.device.clone()
47 }
48
49 fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
50 super::to_device(tensor, device)
51 }
52
53 fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
54 super::reshape(tensor, shape)
55 }
56
57 fn int_slice(tensor: IntTensor<Self>, ranges: &[Range<usize>]) -> IntTensor<Self> {
58 kernel::slice::<R, I>(tensor, ranges)
59 }
60
61 fn int_slice_assign(
62 tensor: IntTensor<Self>,
63 ranges: &[Range<usize>],
64 value: IntTensor<Self>,
65 ) -> IntTensor<Self> {
66 kernel::slice_assign::<R, I>(tensor, ranges, value)
67 }
68
69 fn int_mask_where(
70 tensor: IntTensor<Self>,
71 mask: BoolTensor<Self>,
72 value: IntTensor<Self>,
73 ) -> IntTensor<Self> {
74 kernel::mask_where_auto::<R, I, BT>(tensor, mask, value)
75 }
76
77 fn int_mask_fill(
78 tensor: IntTensor<Self>,
79 mask: BoolTensor<Self>,
80 value: IntElem<Self>,
81 ) -> IntTensor<Self> {
82 kernel::mask_fill_auto::<R, I, BT>(tensor, mask, value)
83 }
84
85 fn int_gather(
86 dim: usize,
87 tensor: IntTensor<Self>,
88 indices: IntTensor<Self>,
89 ) -> IntTensor<Self> {
90 kernel::gather::<R, I, I>(dim, tensor, indices)
91 }
92
93 fn int_scatter(
94 dim: usize,
95 tensor: IntTensor<Self>,
96 indices: IntTensor<Self>,
97 value: IntTensor<Self>,
98 ) -> IntTensor<Self> {
99 kernel::scatter::<R, I, I>(dim, tensor, indices, value)
100 }
101
102 fn int_select(
103 tensor: IntTensor<Self>,
104 dim: usize,
105 indices: IntTensor<Self>,
106 ) -> IntTensor<Self> {
107 kernel::select::<R, I, I>(tensor, dim, indices)
108 }
109
110 fn int_select_assign(
111 tensor: IntTensor<Self>,
112 dim: usize,
113 indices: IntTensor<Self>,
114 value: IntTensor<Self>,
115 ) -> IntTensor<Self> {
116 kernel::select_assign::<R, I, I>(tensor, dim, indices, value)
117 }
118
119 fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
120 kernel::equal::<R, I, BT>(lhs, rhs)
121 }
122
123 fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
124 kernel::equal_elem::<R, I, BT>(lhs, rhs)
125 }
126
127 fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
128 kernel::greater::<R, I, BT>(lhs, rhs)
129 }
130
131 fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
132 kernel::greater_elem::<R, I, BT>(lhs, rhs)
133 }
134
135 fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
136 kernel::greater_equal::<R, I, BT>(lhs, rhs)
137 }
138
139 fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
140 kernel::greater_equal_elem::<R, I, BT>(lhs, rhs)
141 }
142
143 fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
144 kernel::lower::<R, I, BT>(lhs, rhs)
145 }
146
147 fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
148 kernel::lower_elem::<R, I, BT>(lhs, rhs)
149 }
150
151 fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
152 kernel::lower_equal::<R, I, BT>(lhs, rhs)
153 }
154
155 fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
156 kernel::lower_equal_elem::<R, I, BT>(lhs, rhs)
157 }
158
159 fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
160 numeric::add::<R, I>(lhs, rhs)
161 }
162
163 fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
164 numeric::add_scalar::<R, I>(lhs, rhs)
165 }
166
167 fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
168 numeric::sub::<R, I>(lhs, rhs)
169 }
170
171 fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
172 numeric::sub_scalar::<R, I>(lhs, rhs)
173 }
174
175 fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
176 numeric::mul::<R, I>(lhs, rhs)
177 }
178
179 fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
180 numeric::mul_scalar::<R, I>(lhs, rhs)
181 }
182
183 fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
184 numeric::div::<R, I>(lhs, rhs)
185 }
186
187 fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
188 numeric::div_scalar::<R, I>(lhs, rhs)
189 }
190
191 fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
192 numeric::remainder::<R, I>(lhs, rhs)
193 }
194
195 fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
196 numeric::remainder_scalar::<R, I>(lhs, rhs)
197 }
198
199 fn int_zeros(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
200 numeric::zeros::<R, I>(shape, device)
201 }
202
203 fn int_ones(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
204 numeric::ones::<R, I>(shape, device)
205 }
206
207 fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
208 reduce::sum::<R, I>(tensor, Default::default()).unwrap()
209 }
210
211 fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
212 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Sum).unwrap()
213 }
214
215 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
216 reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::Prod).unwrap()
217 }
218
219 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
220 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Prod)
221 .unwrap()
222 }
223
224 fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
225 reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::Max).unwrap()
226 }
227
228 fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
229 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Max).unwrap()
230 }
231
232 fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
233 reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::MaxAbs).unwrap()
234 }
235
236 fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
237 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::MaxAbs)
238 .unwrap()
239 }
240
241 fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
242 reduce::reduce::<R, I, I>(tensor, Default::default(), ReduceFnConfig::Min).unwrap()
243 }
244
245 fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
246 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Min).unwrap()
247 }
248
249 fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
250 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::Mean)
251 .unwrap()
252 }
253
254 fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
255 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMax)
256 .unwrap()
257 }
258
259 fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
260 reduce::reduce_dim::<R, I, I>(tensor, dim, Default::default(), ReduceFnConfig::ArgMin)
261 .unwrap()
262 }
263
264 fn int_clamp(
265 tensor: IntTensor<Self>,
266 min: IntElem<Self>,
267 max: IntElem<Self>,
268 ) -> IntTensor<Self> {
269 kernel::clamp::<R, I>(tensor, min, max)
270 }
271
272 fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
273 struct Abs;
274
275 #[cube]
276 impl<N: Numeric> NumericUnaryOp<N> for Abs {
277 type Options = ();
278
279 fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
280 Line::abs(input)
281 }
282 }
283
284 impl NumericUnaryOpFamily for Abs {
285 type Options<N: Numeric> = ();
286 type Unary<N: Numeric> = Self;
287 }
288
289 launch_unary_numeric::<R, I, Abs, _>(tensor, |_| ())
290 }
291
292 fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
293 kernel::cast::<R, I, F>(tensor)
294 }
295
296 fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
297 tensor.strides.swap(dim1, dim2);
298 tensor.shape.dims.swap(dim1, dim2);
299
300 tensor
301 }
302
303 fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
304 kernel::repeat_dim::<R, I>(tensor, dim, times)
305 }
306
307 fn int_random(
308 shape: Shape,
309 distribution: Distribution,
310 device: &Device<Self>,
311 ) -> IntTensor<Self> {
312 let float_tensor = match distribution {
313 Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 255.elem()),
314 Distribution::Uniform(low, high) => {
315 random_uniform(shape, device, low.elem::<F>(), high.elem())
316 }
317 Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
318 Distribution::Normal(mean, std) => {
319 random_normal(shape, device, mean.elem::<F>(), std.elem())
320 }
321 };
322
323 kernel::cast::<R, F, I>(float_tensor)
324 }
325
326 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
327 permute(tensor, axes)
328 }
329
330 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
331 expand(tensor, shape)
332 }
333
334 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
335 kernel::flip::<R, I, BT>(tensor, axes)
336 }
337
338 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
339 numeric::bitwise_and::<R, I>(lhs, rhs)
340 }
341
342 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
343 numeric::bitwise_and_scalar::<R, I>(lhs, rhs)
344 }
345
346 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
347 numeric::bitwise_or::<R, I>(lhs, rhs)
348 }
349
350 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
351 numeric::bitwise_or_scalar(lhs, rhs)
352 }
353
354 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
355 numeric::bitwise_xor::<R, I>(lhs, rhs)
356 }
357
358 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
359 numeric::bitwise_xor_scalar(lhs, rhs)
360 }
361
362 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
363 unary_basic_int::launch::<R, _, I>(tensor, |_| &BasicIntUnaryKind::BitwiseNot)
364 }
365
366 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
367 launch_binop_int::<R, I, kernel::BitwiseShlOp>(lhs, rhs)
368 }
369
370 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
371 launch_scalar_binop_int::<R, I, BitwiseShlOp>(lhs, rhs)
372 }
373
374 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
375 launch_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
376 }
377
378 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
379 launch_scalar_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
380 }
381}