1use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_backend::backend::ExecutionError;
5use burn_backend::ops::IntTensorOps;
6use burn_backend::tensor::{FloatTensor, IntTensor};
7use burn_backend::{Distribution, IntDType, TensorMetadata};
8
9use burn_backend::ElementConversion;
10
11use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
13use crate::{NdArrayDevice, SEED};
14use crate::{SharedArray, element::QuantElement};
15use crate::{cat_with_dtype, execute_with_float_dtype};
16use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
17use crate::{element::IntNdArrayElement, execute_with_int_dtype};
18
19use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
21use burn_backend::{DType, Shape, TensorData, backend::Backend};
22
23impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
24 for NdArray<E, I, Q>
25where
26 NdArrayTensor: From<SharedArray<E>>,
27 NdArrayTensor: From<SharedArray<I>>,
28{
29 fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
30 if data.dtype.is_int() || data.dtype.is_uint() {
31 NdArrayTensor::from_data(data)
32 } else {
33 unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
34 }
35 }
36
37 async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
38 Ok(tensor.into_data())
39 }
40
41 fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
42 tensor
43 }
44
45 fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
46 execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
47 }
48
49 fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
50 execute_with_int_dtype!(tensor, |array| NdArrayOps::slice(array, slices))
51 }
52
53 fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
54 NdArrayDevice::Cpu
55 }
56
57 fn int_empty(
58 shape: Shape,
59 device: &<NdArray<E> as Backend>::Device,
60 dtype: IntDType,
61 ) -> NdArrayTensor {
62 Self::int_zeros(shape, device, dtype)
63 }
64
65 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
66 execute_with_int_dtype!((lhs, rhs), matmul)
67 }
68
69 fn int_mask_where(
70 tensor: NdArrayTensor,
71 mask: NdArrayTensor,
72 source: NdArrayTensor,
73 ) -> NdArrayTensor {
74 execute_with_int_dtype!((tensor, source), |tensor, source| {
75 NdArrayOps::mask_where(tensor, mask.bool(), source)
76 })
77 }
78
79 fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: I) -> NdArrayTensor {
80 execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
81 array,
82 mask.bool(),
83 value.elem()
84 ))
85 }
86
87 fn int_slice_assign(
88 tensor: NdArrayTensor,
89 slices: &[burn_backend::Slice],
90 value: NdArrayTensor,
91 ) -> NdArrayTensor {
92 execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
93 tensor, slices, value
94 ))
95 }
96
97 fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
98 cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
99 }
100
101 fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
102 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
103 }
104
105 fn int_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
106 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
107 }
108
109 fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
110 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
111 }
112
113 fn int_greater_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
114 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
115 }
116
117 fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
118 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
119 }
120
121 fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
122 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
123 array,
124 rhs.elem()
125 ))
126 }
127
128 fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
129 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
130 }
131
132 fn int_lower_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
133 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
134 }
135
136 fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
137 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
138 }
139
140 fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
141 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
142 array,
143 rhs.elem()
144 ))
145 }
146
147 fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
148 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
149 }
150
151 fn int_add_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
152 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
153 }
154
155 fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
156 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
157 }
158
159 fn int_sub_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
160 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
161 }
162
163 fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
164 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
165 }
166
167 fn int_mul_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
168 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
169 }
170
171 fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
172 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
173 }
174
175 fn int_div_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
176 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
177 }
178
179 fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
180 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
181 }
182
183 fn int_remainder_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
184 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
185 array,
186 rhs.elem()
187 ))
188 }
189
190 fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor {
191 Self::int_mul_scalar(tensor, (-1).elem())
192 }
193
194 fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
195 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
197 array.view()
198 ))
199 }
200
201 fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
202 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
203 }
204
205 fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
206 execute_with_int_dtype!(
208 tensor,
209 E,
210 |array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
211 )
212 }
213
214 fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
215 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
216 }
217
218 fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
219 execute_with_int_dtype!(
221 tensor,
222 E,
223 |array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
224 )
225 }
226
227 fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
228 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
229 }
230
231 fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
232 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
234 array.view()
235 ))
236 }
237
238 fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
239 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
241 array.view()
242 ))
243 }
244
245 fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
246 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
247 }
248
249 fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
250 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
251 }
252
253 fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
254 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
255 }
256
257 fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
258 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
259 }
260
261 fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
262 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
263 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
264 dim, array, idx_array
265 ))
266 })
267 }
268
269 fn int_scatter_add(
270 dim: usize,
271 tensor: NdArrayTensor,
272 indices: NdArrayTensor,
273 value: NdArrayTensor,
274 ) -> NdArrayTensor {
275 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
276 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
277 dim, tensor, idx_array, value
278 ))
279 })
280 }
281
282 fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
283 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
284 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
285 array, dim, idx_array
286 ))
287 })
288 }
289
290 fn int_select_add(
291 tensor: NdArrayTensor,
292 dim: usize,
293 indices: NdArrayTensor,
294 value: NdArrayTensor,
295 ) -> NdArrayTensor {
296 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
297 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
298 tensor, dim, idx_array, value
299 ))
300 })
301 }
302 fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
303 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
305 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
306 })
307 }
308
309 fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
310 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
312 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
313 })
314 }
315
316 fn int_clamp_min(tensor: NdArrayTensor, min: I) -> NdArrayTensor {
317 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
318 }
319
320 fn int_clamp_max(tensor: NdArrayTensor, max: I) -> NdArrayTensor {
321 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
322 }
323
324 fn int_clamp(tensor: NdArrayTensor, min: I, max: I) -> NdArrayTensor {
325 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
326 array,
327 min.elem(),
328 max.elem()
329 ))
330 }
331
332 fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
333 match tensor.dtype() {
334 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
335 execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
336 I64 => i64, I32 => i32, I16 => i16, I8 => i8
337 ])
338 }
339 DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
341 other => panic!("Unsupported dtype: {other:?}"),
342 }
343 }
344
345 fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
346 execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| array
347 .mapv(|a: IntElem| a.elem::<E>())
348 .into_shared())
349 }
350
351 fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
352 execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
353 }
354
355 fn int_random(
356 shape: Shape,
357 distribution: Distribution,
358 device: &NdArrayDevice,
359 ) -> NdArrayTensor {
360 let mut seed = SEED.lock().unwrap();
361 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
362 rng_seeded.clone()
363 } else {
364 get_seeded_rng()
365 };
366
367 let effective_distribution = if distribution == Distribution::Default {
368 Distribution::Uniform(0.0, 255.0) } else {
370 distribution
371 };
372
373 let tensor = Self::int_from_data(
374 TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
375 device,
376 );
377 *seed = Some(rng);
378 tensor
379 }
380
381 fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
382 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
383 lhs,
384 rhs,
385 |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
386 ))
387 }
388
389 fn int_powf(lhs: NdArrayTensor, rhs: FloatTensor<Self>) -> NdArrayTensor {
390 execute_with_int_dtype!(lhs, I, |array| -> NdArrayTensor {
391 execute_with_float_dtype!(rhs, E, |rhs_array| {
392 NdArrayMathOps::elementwise_op(array, rhs_array, |a: &I, b: &E| {
393 (a.elem::<i64>().pow(*b as u32)).elem()
394 })
395 })
396 })
397 }
398
399 fn int_powf_scalar_impl(lhs: NdArrayTensor, rhs: f32) -> NdArrayTensor {
400 execute_with_int_dtype!(lhs, I, |array| {
401 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
402 (a.elem::<i64>().pow(rhs as u32)).elem()
403 })
404 })
405 }
406
407 fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
408 execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
409 }
410
411 fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
412 execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
413 }
414
415 fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
416 match tensor.dtype() {
417 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
418 execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
419 I64 => i64, I32 => i32, I16 => i16, I8 => i8
420 ])
421 }
422 DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
423 Self::int_greater_elem(tensor, 0.elem())
424 }
425 other => panic!("Unsupported dtype: {other:?}"),
426 }
427 }
428
429 fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
430 execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
431 }
432
433 fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
434 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
435 }
436
437 fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
438 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
439 }
440
441 fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
442 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
443 }
444
445 fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
446 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
447 }
448
449 fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
450 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
451 }
452
453 fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
454 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
455 }
456
457 fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
458 execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
459 }
460
461 fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
462 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
463 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
464 (a.elem::<i64>() << (b.elem::<u32>())).elem()
465 })
466 })
467 }
468
469 fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
470 execute_with_int_dtype!(lhs, I, |array| {
471 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
472 (a.elem::<i64>() << rhs.elem::<u32>()).elem()
473 })
474 })
475 }
476
477 fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
478 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
479 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
480 (a.elem::<i64>() >> (b.elem::<u32>())).elem()
481 })
482 })
483 }
484
485 fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
486 execute_with_int_dtype!(lhs, I, |array| {
487 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
488 (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
489 })
490 })
491 }
492
493 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
494 execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
495 }
496
497 fn int_unfold(
498 tensor: IntTensor<Self>,
499 dim: usize,
500 size: usize,
501 step: usize,
502 ) -> IntTensor<Self> {
503 execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
504 }
505}