use self::unary_basic_int::BasicIntUnaryKind;

use super::{expand, numeric, permute, unfold};
use crate::kernel::{
    BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
    launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
};
use crate::{
    CubeBackend, CubeRuntime, FloatElement, IntElement,
    kernel::{
        self,
        matmul::{MatmulStrategy, matmul},
    },
};
use crate::{
    element::BoolElement,
    kernel::prng::{random_bernoulli, random_normal, random_uniform},
};
use burn_backend::ExecutionError;
use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use burn_backend::{DType, IntDType, Slice, ops::IntTensorOps};
use burn_backend::{Distribution, ElementConversion, Shape, TensorData};
use cubecl::frontend::Numeric;
use cubecl::prelude::*;
use cubek::reduce::components::instructions::ReduceOperationConfig;
use std::ops::Range;

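// Integer tensor ops for the CubeCL backend. Most methods are thin wrappers that
// forward to shared kernels (`crate::kernel`, `numeric`, `reduce`, the PRNG module),
// passing along runtime dtype information where the kernel needs it.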
impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        let dtype = dtype.into();
        super::empty(shape, device, dtype)
    }

    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
        super::into_data(tensor).await
    }

    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
        match data.dtype {
            DType::I64
            | DType::I32
            | DType::I16
            | DType::I8
            | DType::U64
            | DType::U32
            | DType::U16
            | DType::U8 => super::from_data(data, device),
            _ => unimplemented!("Unsupported dtype for `int_from_data`"),
        }
    }

    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
        super::to_device(tensor, device)
    }

    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        super::reshape(tensor, shape)
    }

    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
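        // Fast path: when every slice uses a unit step, the request reduces to plain
        // ranges handled by the range-based slice kernel; otherwise fall back to the
        // strided `slice_with_steps` kernel.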
        let all_steps_one = slices.iter().all(|info| info.step == 1);

        if all_steps_one {
            let simple_ranges: Vec<Range<usize>> = slices
                .iter()
                .enumerate()
                .map(|(i, slice)| slice.to_range(tensor.shape[i]))
                .collect();

            kernel::slice(tensor, &simple_ranges)
        } else {
            kernel::slice_with_steps(tensor, slices)
        }
    }

    fn int_slice_assign(
        tensor: IntTensor<Self>,
        ranges: &[Slice],
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        kernel::slice_assign(tensor, ranges, value)
    }

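    // Integer matmul reuses the generic matmul kernel with the default strategy and
    // the lhs dtype.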
    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
    }

    fn int_mask_where(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        kernel::mask_where_auto(tensor, mask, value, BT::dtype())
    }

    fn int_mask_fill(
        tensor: IntTensor<Self>,
        mask: BoolTensor<Self>,
        value: IntElem<Self>,
    ) -> IntTensor<Self> {
        let dtype = tensor.dtype;
        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), BT::dtype())
    }

    fn int_gather(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
        kernel::gather(dim, tensor, indices)
    }

    fn int_scatter_add(
        dim: usize,
        tensor: IntTensor<Self>,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        kernel::scatter(dim, tensor, indices, value, false)
    }

    fn int_select(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
    ) -> IntTensor<Self> {
        kernel::select(tensor, dim, indices)
    }

    fn int_select_add(
        tensor: IntTensor<Self>,
        dim: usize,
        indices: IntTensor<Self>,
        value: IntTensor<Self>,
    ) -> IntTensor<Self> {
        kernel::select_assign(tensor, dim, indices, value, false)
    }

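    // Comparison kernels return boolean tensors encoded with the backend's bool element
    // type `BT`; scalar variants wrap the right-hand side with the lhs dtype.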
    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        kernel::equal(lhs, rhs, BT::dtype())
    }

    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        kernel::greater(lhs, rhs, BT::dtype())
    }

    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        kernel::greater_equal(lhs, rhs, BT::dtype())
    }

    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        kernel::lower(lhs, rhs, BT::dtype())
    }

    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
        kernel::lower_equal(lhs, rhs, BT::dtype())
    }

    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
        let dtype = lhs.dtype;
        kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), BT::dtype())
    }

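    // Element-wise arithmetic delegates to the shared numeric kernels; scalar operands
    // are tagged with the tensor's runtime dtype via `InputScalar::new`.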
    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::add(lhs, rhs)
    }

    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::sub(lhs, rhs)
    }

    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::mul(lhs, rhs)
    }

    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::div(lhs, rhs)
    }

    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::remainder(lhs, rhs)
    }

    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        let dtype = dtype.into();
        numeric::zeros(device.clone(), shape, dtype)
    }

    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
        let dtype = dtype.into();
        numeric::ones(device.clone(), shape, dtype)
    }

    fn int_full(
        shape: Shape,
        fill_value: IntElem<Self>,
        device: &Device<Self>,
        dtype: IntDType,
    ) -> IntTensor<Self> {
        let dtype: DType = dtype.into();
        let client = R::client(device);
        numeric::full_device_dtype(
            client,
            shape,
            device.clone(),
            InputScalar::new(fill_value, dtype),
            dtype,
        )
    }

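    // Reductions use the shared reduce kernels: `reduce` collapses all axes, `reduce_dim`
    // reduces a single axis, and the `ReduceOperationConfig` variant selects the operation.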
    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
        reduce::sum_fallback(tensor, Default::default()).unwrap()
    }

    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Sum,
        )
        .unwrap()
    }

    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
        reduce::reduce(
            tensor,
            None,
            Default::default(),
            ReduceOperationConfig::Prod,
        )
        .unwrap()
    }

    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Prod,
        )
        .unwrap()
    }

    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
    }

    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Max,
        )
        .unwrap()
    }

    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
        reduce::reduce(
            tensor,
            None,
            Default::default(),
            ReduceOperationConfig::MaxAbs,
        )
        .unwrap()
    }

    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::MaxAbs,
        )
        .unwrap()
    }

    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
    }

    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Min,
        )
        .unwrap()
    }

    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        reduce::reduce_dim(
            tensor,
            None,
            dim,
            Default::default(),
            ReduceOperationConfig::Mean,
        )
        .unwrap()
    }

    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        numeric::cumsum(tensor, dim)
    }

    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        numeric::cumprod(tensor, dim)
    }

    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        numeric::cummin(tensor, dim)
    }

    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        numeric::cummax(tensor, dim)
    }

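    // Arg-reductions pass the input dtype through (the other reductions pass `None`),
    // presumably so the returned indices keep the input's integer type.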
    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        let dtype = tensor.dtype;
        reduce::reduce_dim(
            tensor,
            Some(dtype),
            dim,
            Default::default(),
            ReduceOperationConfig::ArgMax,
        )
        .unwrap()
    }

    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
        let dtype = tensor.dtype;
        reduce::reduce_dim(
            tensor,
            Some(dtype),
            dim,
            Default::default(),
            ReduceOperationConfig::ArgMin,
        )
        .unwrap()
    }

    fn int_clamp(
        tensor: IntTensor<Self>,
        min: IntElem<Self>,
        max: IntElem<Self>,
    ) -> IntTensor<Self> {
        let dtype = tensor.dtype;
        kernel::clamp(
            tensor,
            InputScalar::new(min, dtype),
            InputScalar::new(max, dtype),
        )
    }

    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
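        // `abs` is defined inline as a numeric unary op (expanded by `#[cube]`) and
        // launched through the generic numeric unary kernel.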
        struct Abs;

        #[cube]
        impl<N: Numeric> NumericUnaryOp<N> for Abs {
            type Options = ();

            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
                Line::abs(input)
            }
        }

        impl NumericUnaryOpFamily for Abs {
            type Options = ();
            type Unary<N: Numeric> = Self;
        }

        launch_unary_numeric::<R, Abs, _>(tensor, |_| ())
    }

    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
        kernel::cast(tensor, F::dtype())
    }

    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
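        // Swapping dimensions only touches metadata: the stride and shape entries are
        // exchanged, no device memory is moved.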
        tensor.strides.swap(dim1, dim2);
        tensor.shape = tensor.shape.swap(dim1, dim2).unwrap();

        tensor
    }

    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
        kernel::repeat_dim(tensor, dim, times)
    }

    fn int_random(
        shape: Shape,
        distribution: Distribution,
        device: &Device<Self>,
    ) -> IntTensor<Self> {
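        // Sampling reuses the PRNG kernels with the integer element dtype;
        // `Distribution::Default` maps to a uniform distribution over the 0..255 range.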
        let dtype = IntElem::<Self>::dtype();
        match distribution {
            Distribution::Default => random_uniform(shape, device, 0., 255., dtype),
            Distribution::Uniform(low, high) => {
                random_uniform(shape, device, low.elem(), high.elem(), dtype)
            }
            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
            Distribution::Normal(mean, std) => {
                random_normal(shape, device, mean.elem(), std.elem(), dtype)
            }
        }
    }

    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        permute(tensor, axes)
    }

    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
        expand(tensor, shape)
    }

    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
        kernel::flip(tensor, axes, BT::dtype())
    }

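    // Bitwise and/or/xor reuse the numeric kernels; shifts use the dedicated binop
    // kernels (`BitwiseShlOp` / `BitwiseShrOp`); `not` goes through the basic int
    // unary kernel.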
    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::bitwise_and(lhs, rhs)
    }

    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::bitwise_and_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::bitwise_or(lhs, rhs)
    }

    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::bitwise_or_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        numeric::bitwise_xor(lhs, rhs)
    }

    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        numeric::bitwise_xor_scalar(lhs, InputScalar::new(rhs, dtype))
    }

    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
        unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::BitwiseNot)
    }

    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        launch_binop_int::<R, BitwiseShlOp>(lhs, rhs)
    }

    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        launch_scalar_binop_int::<R, BitwiseShlOp>(lhs, InputScalar::new(rhs, dtype))
    }

    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
        launch_binop_int::<R, BitwiseShrOp>(lhs, rhs)
    }

    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
        let dtype = lhs.dtype;
        launch_scalar_binop_int::<R, BitwiseShrOp>(lhs, InputScalar::new(rhs, dtype))
    }

    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
        kernel::cast(tensor, dtype.into())
    }

    fn int_unfold(
        tensor: IntTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Self> {
        unfold(tensor, dim, size, step)
    }
}