1use self::unary_basic_int::BasicIntUnaryKind;
2
3use super::{expand, numeric, permute, unfold};
4use crate::kernel::{
5 BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,
6 launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
7};
8use crate::{
9 CubeBackend, CubeRuntime, FloatElement, IntElement,
10 kernel::{
11 self,
12 matmul::{MatmulStrategy, matmul},
13 },
14};
15use crate::{
16 element::BoolElement,
17 kernel::prng::{random_bernoulli, random_normal, random_uniform},
18};
19use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
20use burn_backend::{DType, IntDType, Slice, ops::IntTensorOps};
21use burn_backend::{Distribution, ElementConversion, Shape, TensorData, get_device_settings};
22use burn_backend::{ExecutionError, Scalar};
23use burn_std::{BoolDType, FloatDType};
24use cubecl::frontend::Numeric;
25use cubecl::prelude::*;
26use cubek::reduce::components::instructions::ReduceOperationConfig;
27use std::ops::Range;
28
29impl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>
30where
31 R: CubeRuntime,
32 F: FloatElement,
33 I: IntElement,
34 BT: BoolElement,
35{
36 fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
37 let dtype = dtype.into();
38 super::empty(shape, device, dtype)
39 }
40
41 async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {
42 super::into_data(tensor).await
43 }
44
45 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
46 match data.dtype {
47 DType::I64
48 | DType::I32
49 | DType::I16
50 | DType::I8
51 | DType::U64
52 | DType::U32
53 | DType::U16
54 | DType::U8 => super::from_data(data, device),
55 _ => unimplemented!("Unsupported dtype for `int_from_data`"),
56 }
57 }
58
59 fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
60 tensor.device.clone()
61 }
62
63 fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
64 super::to_device(tensor, device)
65 }
66
67 fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
68 super::reshape(tensor, shape)
69 }
70
71 fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {
72 let all_steps_one = slices.iter().all(|info| info.step == 1);
74
75 if all_steps_one {
76 let simple_ranges: Vec<Range<usize>> = slices
78 .iter()
79 .enumerate()
80 .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
81 .collect();
82
83 kernel::slice(tensor, &simple_ranges)
84 } else {
85 kernel::slice_with_steps(tensor, slices)
87 }
88 }
89
90 fn int_slice_assign(
91 tensor: IntTensor<Self>,
92 ranges: &[Slice],
93 value: IntTensor<Self>,
94 ) -> IntTensor<Self> {
95 kernel::slice_assign(tensor, ranges, value)
96 }
97
98 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
99 let dtype = lhs.dtype;
100 matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()
101 }
102
103 fn int_mask_where(
104 tensor: IntTensor<Self>,
105 mask: BoolTensor<Self>,
106 value: IntTensor<Self>,
107 ) -> IntTensor<Self> {
108 let bool_dtype = mask.dtype;
109 kernel::mask_where_auto(tensor, mask, value, bool_dtype)
110 }
111
112 fn int_mask_fill(
113 tensor: IntTensor<Self>,
114 mask: BoolTensor<Self>,
115 value: Scalar,
116 ) -> IntTensor<Self> {
117 let dtype = tensor.dtype;
118 let bool_dtype = mask.dtype;
119 kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), bool_dtype)
120 }
121
122 fn int_gather(
123 dim: usize,
124 tensor: IntTensor<Self>,
125 indices: IntTensor<Self>,
126 ) -> IntTensor<Self> {
127 kernel::gather(dim, tensor, indices)
128 }
129
130 fn int_scatter_add(
131 dim: usize,
132 tensor: IntTensor<Self>,
133 indices: IntTensor<Self>,
134 value: IntTensor<Self>,
135 ) -> IntTensor<Self> {
136 kernel::scatter(dim, tensor, indices, value, false)
137 }
138
139 fn int_scatter_nd(
140 data: IntTensor<Self>,
141 indices: IntTensor<Self>,
142 values: IntTensor<Self>,
143 reduction: burn_backend::tensor::IndexingUpdateOp,
144 ) -> IntTensor<Self> {
145 kernel::scatter_nd(data, indices, values, reduction)
146 }
147
148 fn int_gather_nd(data: IntTensor<Self>, indices: IntTensor<Self>) -> IntTensor<Self> {
149 kernel::gather_nd(data, indices)
150 }
151
152 fn int_select(
153 tensor: IntTensor<Self>,
154 dim: usize,
155 indices: IntTensor<Self>,
156 ) -> IntTensor<Self> {
157 kernel::select(tensor, dim, indices)
158 }
159
160 fn int_select_add(
161 tensor: IntTensor<Self>,
162 dim: usize,
163 indices: IntTensor<Self>,
164 value: IntTensor<Self>,
165 ) -> IntTensor<Self> {
166 kernel::select_assign(tensor, dim, indices, value, false)
167 }
168
169 fn int_equal(
170 lhs: IntTensor<Self>,
171 rhs: IntTensor<Self>,
172 out_dtype: BoolDType,
173 ) -> BoolTensor<Self> {
174 kernel::equal(lhs, rhs, out_dtype.into())
175 }
176
177 fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
178 let dtype = lhs.dtype;
179 kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
180 }
181
182 fn int_greater(
183 lhs: IntTensor<Self>,
184 rhs: IntTensor<Self>,
185 out_dtype: BoolDType,
186 ) -> BoolTensor<Self> {
187 kernel::greater(lhs, rhs, out_dtype.into())
188 }
189
190 fn int_greater_elem(
191 lhs: IntTensor<Self>,
192 rhs: Scalar,
193 out_dtype: BoolDType,
194 ) -> BoolTensor<Self> {
195 let dtype = lhs.dtype;
196 kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
197 }
198
199 fn int_greater_equal(
200 lhs: IntTensor<Self>,
201 rhs: IntTensor<Self>,
202 out_dtype: BoolDType,
203 ) -> BoolTensor<Self> {
204 kernel::greater_equal(lhs, rhs, out_dtype.into())
205 }
206
207 fn int_greater_equal_elem(
208 lhs: IntTensor<Self>,
209 rhs: Scalar,
210 out_dtype: BoolDType,
211 ) -> BoolTensor<Self> {
212 let dtype = lhs.dtype;
213 kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
214 }
215
216 fn int_lower(
217 lhs: IntTensor<Self>,
218 rhs: IntTensor<Self>,
219 out_dtype: BoolDType,
220 ) -> BoolTensor<Self> {
221 kernel::lower(lhs, rhs, out_dtype.into())
222 }
223
224 fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar, out_dtype: BoolDType) -> BoolTensor<Self> {
225 let dtype = lhs.dtype;
226 kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
227 }
228
229 fn int_lower_equal(
230 lhs: IntTensor<Self>,
231 rhs: IntTensor<Self>,
232 out_dtype: BoolDType,
233 ) -> BoolTensor<Self> {
234 kernel::lower_equal(lhs, rhs, out_dtype.into())
235 }
236
237 fn int_lower_equal_elem(
238 lhs: IntTensor<Self>,
239 rhs: Scalar,
240 out_dtype: BoolDType,
241 ) -> BoolTensor<Self> {
242 let dtype = lhs.dtype;
243 kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), out_dtype.into())
244 }
245
246 fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
247 numeric::add(lhs, rhs)
248 }
249
250 fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
251 let dtype = lhs.dtype;
252 numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))
253 }
254
255 fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
256 numeric::sub(lhs, rhs)
257 }
258
259 fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
260 let dtype = lhs.dtype;
261 numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))
262 }
263
264 fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
265 numeric::mul(lhs, rhs)
266 }
267
268 fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
269 let dtype = lhs.dtype;
270 numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))
271 }
272
273 fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
274 numeric::div(lhs, rhs)
275 }
276
277 fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
278 let dtype = lhs.dtype;
279 numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))
280 }
281
282 fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
283 numeric::remainder(lhs, rhs)
284 }
285
286 fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
287 let dtype = lhs.dtype;
288 numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))
289 }
290
291 fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
292 let dtype = dtype.into();
293 numeric::zeros(device.clone(), shape, dtype)
294 }
295
296 fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
297 let dtype = dtype.into();
298 numeric::ones(device.clone(), shape, dtype)
299 }
300
301 fn int_full(
302 shape: Shape,
303 fill_value: Scalar,
304 device: &Device<Self>,
305 dtype: IntDType,
306 ) -> IntTensor<Self> {
307 let dtype: DType = dtype.into();
308 let client = R::client(device);
309 numeric::full_device_dtype(
310 client,
311 shape,
312 device.clone(),
313 InputScalar::new(fill_value, dtype),
314 dtype,
315 )
316 }
317
318 fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
319 reduce::sum_fallback(tensor, Default::default()).unwrap()
320 }
321
322 fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
323 reduce::reduce_dim(
324 tensor,
325 None,
326 dim,
327 Default::default(),
328 ReduceOperationConfig::Sum,
329 )
330 .unwrap()
331 }
332
333 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
334 reduce::reduce(
335 tensor,
336 None,
337 Default::default(),
338 ReduceOperationConfig::Prod,
339 )
340 .unwrap()
341 }
342
343 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
344 reduce::reduce_dim(
345 tensor,
346 None,
347 dim,
348 Default::default(),
349 ReduceOperationConfig::Prod,
350 )
351 .unwrap()
352 }
353
354 fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {
355 reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()
356 }
357
358 fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
359 reduce::reduce_dim(
360 tensor,
361 None,
362 dim,
363 Default::default(),
364 ReduceOperationConfig::Max,
365 )
366 .unwrap()
367 }
368
369 fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
370 reduce::reduce_dim(
371 tensor,
372 None,
373 dim,
374 Default::default(),
375 ReduceOperationConfig::TopK(k),
376 )
377 .unwrap()
378 }
379
380 fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
381 reduce::reduce(
382 tensor,
383 None,
384 Default::default(),
385 ReduceOperationConfig::MaxAbs,
386 )
387 .unwrap()
388 }
389
390 fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
391 reduce::reduce_dim(
392 tensor,
393 None,
394 dim,
395 Default::default(),
396 ReduceOperationConfig::MaxAbs,
397 )
398 .unwrap()
399 }
400
401 fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {
402 reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()
403 }
404
405 fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
406 reduce::reduce_dim(
407 tensor,
408 None,
409 dim,
410 Default::default(),
411 ReduceOperationConfig::Min,
412 )
413 .unwrap()
414 }
415
416 fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
417 reduce::reduce_dim(
418 tensor,
419 None,
420 dim,
421 Default::default(),
422 ReduceOperationConfig::Mean,
423 )
424 .unwrap()
425 }
426
427 fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
428 numeric::cumsum(tensor, dim)
429 }
430
431 fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
432 numeric::cumprod(tensor, dim)
433 }
434
435 fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
436 numeric::cummin(tensor, dim)
437 }
438
439 fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
440 numeric::cummax(tensor, dim)
441 }
442
443 fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
444 let dtype = tensor.dtype;
445 reduce::reduce_dim(
446 tensor,
447 Some(dtype),
448 dim,
449 Default::default(),
450 ReduceOperationConfig::ArgMax,
451 )
452 .unwrap()
453 }
454
455 fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
456 let dtype = tensor.dtype;
457 reduce::reduce_dim(
458 tensor,
459 Some(dtype),
460 dim,
461 Default::default(),
462 ReduceOperationConfig::ArgTopK(k),
463 )
464 .unwrap()
465 }
466
467 fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
468 let dtype = tensor.dtype;
469 reduce::reduce_dim(
470 tensor,
471 Some(dtype),
472 dim,
473 Default::default(),
474 ReduceOperationConfig::ArgMin,
475 )
476 .unwrap()
477 }
478
479 fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {
480 let dtype = tensor.dtype;
481 kernel::clamp(
482 tensor,
483 InputScalar::new(min, dtype),
484 InputScalar::new(max, dtype),
485 )
486 }
487
488 fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
489 struct Abs;
490
491 #[cube]
492 impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Abs {
493 type Options = ();
494
495 fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {
496 Vector::abs(input)
497 }
498 }
499
500 impl NumericUnaryOpFamily for Abs {
501 type Options = ();
502 type Unary<T: Numeric, N: Size> = Self;
503 }
504
505 launch_unary_numeric::<R, Abs, _>(tensor, |_| ())
506 }
507
508 fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
509 unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::Sign)
510 }
511
512 fn int_into_float(tensor: IntTensor<Self>, out_dtype: FloatDType) -> FloatTensor<Self> {
513 kernel::cast(tensor, out_dtype.into())
514 }
515
516 fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
517 tensor.meta.swap(dim1, dim2);
518
519 tensor
520 }
521
522 fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
523 kernel::repeat_dim(tensor, dim, times)
524 }
525
526 fn int_random(
527 shape: Shape,
528 distribution: Distribution,
529 device: &Device<Self>,
530 dtype: IntDType,
531 ) -> IntTensor<Self> {
532 let dtype = dtype.into();
533 match distribution {
534 Distribution::Default => random_uniform(shape, device, 0., 255., dtype),
535 Distribution::Uniform(low, high) => {
536 random_uniform(shape, device, low.elem(), high.elem(), dtype)
537 }
538 Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),
539 Distribution::Normal(mean, std) => {
540 random_normal(shape, device, mean.elem(), std.elem(), dtype)
541 }
542 }
543 }
544
545 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
546 permute(tensor, axes)
547 }
548
549 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
550 expand(tensor, shape)
551 }
552
553 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
554 let bool_dtype = get_device_settings::<Self>(&tensor.device).bool_dtype;
555 kernel::flip(tensor, axes, bool_dtype.into())
556 }
557
558 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
559 numeric::bitwise_and(lhs, rhs)
560 }
561
562 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
563 let dtype = lhs.dtype;
564 numeric::bitwise_and_scalar(lhs, InputScalar::new(rhs, dtype))
565 }
566
567 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
568 numeric::bitwise_or(lhs, rhs)
569 }
570
571 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
572 let dtype = lhs.dtype;
573 numeric::bitwise_or_scalar(lhs, InputScalar::new(rhs, dtype))
574 }
575
576 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
577 numeric::bitwise_xor(lhs, rhs)
578 }
579
580 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
581 let dtype = lhs.dtype;
582 numeric::bitwise_xor_scalar(lhs, InputScalar::new(rhs, dtype))
583 }
584
585 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
586 unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::BitwiseNot)
587 }
588
589 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
590 launch_binop_int::<R, kernel::BitwiseShlOp>(lhs, rhs)
591 }
592
593 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
594 let dtype = lhs.dtype;
595 launch_scalar_binop_int::<R, BitwiseShlOp>(lhs, InputScalar::new(rhs, dtype))
596 }
597
598 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
599 launch_binop_int::<R, BitwiseShrOp>(lhs, rhs)
600 }
601
602 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
603 let dtype = lhs.dtype;
604 launch_scalar_binop_int::<R, BitwiseShrOp>(lhs, InputScalar::new(rhs, dtype))
605 }
606
607 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
608 kernel::cast(tensor, dtype.into())
609 }
610
611 fn int_unfold(
612 tensor: FloatTensor<Self>,
613 dim: usize,
614 size: usize,
615 step: usize,
616 ) -> FloatTensor<Self> {
617 unfold(tensor, dim, size, step)
618 }
619
620 }