1use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_backend::backend::ExecutionError;
5use burn_backend::ops::IntTensorOps;
6use burn_backend::tensor::{FloatTensor, IntTensor};
7use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};
8
9use burn_backend::ElementConversion;
10use burn_std::{BoolDType, FloatDType};
11
12use crate::{ExpElement, NdArrayDevice, SEED, execute_with_int_out_dtype, slice};
14use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
15use crate::{SharedArray, element::QuantElement};
16use crate::{cat_with_dtype, execute_with_float_out_dtype};
17use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
18use crate::{element::IntNdArrayElement, execute_with_int_dtype};
19
20use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
22use burn_backend::{DType, Shape, TensorData, backend::Backend};
23
24impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
25 for NdArray<E, I, Q>
26where
27 NdArrayTensor: From<SharedArray<E>>,
28 NdArrayTensor: From<SharedArray<I>>,
29{
30 fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
31 if data.dtype.is_int() || data.dtype.is_uint() {
32 NdArrayTensor::from_data(data)
33 } else {
34 unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
35 }
36 }
37
38 async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
39 Ok(tensor.into_data())
40 }
41
42 fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
43 tensor
44 }
45
46 fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
47 execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
48 }
49
50 fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
51 slice!(tensor, slices)
52 }
53
54 fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
55 NdArrayDevice::Cpu
56 }
57
58 fn int_empty(
59 shape: Shape,
60 device: &<NdArray<E> as Backend>::Device,
61 dtype: IntDType,
62 ) -> NdArrayTensor {
63 Self::int_zeros(shape, device, dtype)
64 }
65
66 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
67 execute_with_int_dtype!((lhs, rhs), matmul)
68 }
69
70 fn int_mask_where(
71 tensor: NdArrayTensor,
72 mask: NdArrayTensor,
73 source: NdArrayTensor,
74 ) -> NdArrayTensor {
75 execute_with_int_dtype!((tensor, source), |tensor, source| {
76 NdArrayOps::mask_where(tensor, mask.bool(), source)
77 })
78 }
79
80 fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {
81 execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
82 array,
83 mask.bool(),
84 value.elem()
85 ))
86 }
87
88 fn int_slice_assign(
89 tensor: NdArrayTensor,
90 slices: &[burn_backend::Slice],
91 value: NdArrayTensor,
92 ) -> NdArrayTensor {
93 execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
94 tensor, slices, value
95 ))
96 }
97
98 fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
99 cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
100 }
101
102 fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
103 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
104 }
105
106 fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
107 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
108 }
109
110 fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
111 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
112 }
113
114 fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
115 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
116 }
117
118 fn int_greater_equal(
119 lhs: NdArrayTensor,
120 rhs: NdArrayTensor,
121 _out_dtype: BoolDType,
122 ) -> NdArrayTensor {
123 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
124 }
125
126 fn int_greater_equal_elem(
127 lhs: NdArrayTensor,
128 rhs: Scalar,
129 _out_dtype: BoolDType,
130 ) -> NdArrayTensor {
131 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
132 array,
133 rhs.elem()
134 ))
135 }
136
137 fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
138 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
139 }
140
141 fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
142 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
143 }
144
145 fn int_lower_equal(
146 lhs: NdArrayTensor,
147 rhs: NdArrayTensor,
148 _out_dtype: BoolDType,
149 ) -> NdArrayTensor {
150 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
151 }
152
153 fn int_lower_equal_elem(
154 lhs: NdArrayTensor,
155 rhs: Scalar,
156 _out_dtype: BoolDType,
157 ) -> NdArrayTensor {
158 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
159 array,
160 rhs.elem()
161 ))
162 }
163
164 fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
165 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
166 }
167
168 fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
169 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
170 }
171
172 fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
173 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
174 }
175
176 fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
177 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
178 }
179
180 fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
181 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
182 }
183
184 fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
185 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
186 }
187
188 fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
189 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
190 }
191
192 fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
193 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
194 }
195
196 fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
197 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
198 }
199
200 fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
201 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
202 array,
203 rhs.elem()
204 ))
205 }
206
207 fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
208 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
210 array.view()
211 ))
212 }
213
214 fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
215 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
216 }
217
218 fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
219 execute_with_int_dtype!(
221 tensor,
222 E,
223 |array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
224 )
225 }
226
227 fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
228 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
229 }
230
231 fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
232 execute_with_int_dtype!(
234 tensor,
235 E,
236 |array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
237 )
238 }
239
240 fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
241 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
242 }
243
244 fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
245 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
247 array.view()
248 ))
249 }
250
251 fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
252 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
254 array.view()
255 ))
256 }
257
258 fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
259 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
260 }
261
262 fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
263 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
264 }
265
266 fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
267 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
268 }
269
270 fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
271 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
272 }
273
274 fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
275 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
276 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
277 dim, array, idx_array
278 ))
279 })
280 }
281
282 fn int_scatter_add(
283 dim: usize,
284 tensor: NdArrayTensor,
285 indices: NdArrayTensor,
286 value: NdArrayTensor,
287 ) -> NdArrayTensor {
288 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
289 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
290 dim, tensor, idx_array, value
291 ))
292 })
293 }
294
295 fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
296 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
297 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
298 array, dim, idx_array
299 ))
300 })
301 }
302
303 fn int_select_add(
304 tensor: NdArrayTensor,
305 dim: usize,
306 indices: NdArrayTensor,
307 value: NdArrayTensor,
308 ) -> NdArrayTensor {
309 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
310 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
311 tensor, dim, idx_array, value
312 ))
313 })
314 }
315 fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
316 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
318 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
319 })
320 }
321
322 fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
323 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
325 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
326 })
327 }
328
329 fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {
330 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
331 }
332
333 fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {
334 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
335 }
336
337 fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {
338 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
339 array,
340 min.elem(),
341 max.elem()
342 ))
343 }
344
345 fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
346 match tensor.dtype() {
347 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
348 execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
349 I64 => i64, I32 => i32, I16 => i16, I8 => i8
350 ])
351 }
352 DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
354 other => panic!("Unsupported dtype: {other:?}"),
355 }
356 }
357
358 fn int_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor<Self> {
359 execute_with_float_out_dtype!(out_dtype, F, {
360 execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| {
361 array.mapv(|a: IntElem| a.elem::<F>()).into_shared()
362 })
363 })
364 }
365
366 fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
367 execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
368 }
369
370 fn int_random(
371 shape: Shape,
372 distribution: Distribution,
373 device: &NdArrayDevice,
374 dtype: IntDType,
375 ) -> NdArrayTensor {
376 let mut seed = SEED.lock().unwrap();
377 let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
378
379 let effective_distribution = if distribution == Distribution::Default {
380 Distribution::Uniform(0.0, 255.0) } else {
382 distribution
383 };
384
385 let tensor = execute_with_int_out_dtype!(
386 dtype,
387 I,
388 Self::int_from_data(
389 TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
390 device,
391 )
392 );
393 *seed = Some(rng);
394 tensor
395 }
396
397 fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
398 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
399 lhs,
400 rhs,
401 |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
402 ))
403 }
404
405 fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
406 execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
407 }
408
409 fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
410 execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
411 }
412
413 fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
414 match tensor.dtype() {
415 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
416 execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
417 I64 => i64, I32 => i32, I16 => i16, I8 => i8
418 ])
419 }
420 DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
421 Self::int_greater_elem(tensor, 0.into(), BoolDType::Native)
422 }
423 other => panic!("Unsupported dtype: {other:?}"),
424 }
425 }
426
427 fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
428 execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
429 }
430
431 fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
432 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
433 }
434
435 fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
436 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
437 }
438
439 fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
440 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
441 }
442
443 fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
444 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
445 }
446
447 fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
448 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
449 }
450
451 fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
452 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
453 }
454
455 fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
456 execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
457 }
458
459 fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
460 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
461 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
462 (a.elem::<i64>() << (b.elem::<u32>())).elem()
463 })
464 })
465 }
466
467 fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
468 execute_with_int_dtype!(lhs, I, |array| {
469 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
470 (a.elem::<i64>() << rhs.elem::<u32>()).elem()
471 })
472 })
473 }
474
475 fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
476 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
477 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
478 (a.elem::<i64>() >> (b.elem::<u32>())).elem()
479 })
480 })
481 }
482
483 fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
484 execute_with_int_dtype!(lhs, I, |array| {
485 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
486 (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
487 })
488 })
489 }
490
491 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
492 execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
493 }
494
495 fn int_unfold(
496 tensor: IntTensor<Self>,
497 dim: usize,
498 size: usize,
499 step: usize,
500 ) -> IntTensor<Self> {
501 execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
502 }
503
504 fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
505 execute_with_int_dtype!(lhs, I, |array| {
506 NdArrayMathOps::elementwise_op_scalar(array, |a: I| a.powi_elem(rhs.elem()))
507 })
508 }
509}