1use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_backend::backend::ExecutionError;
5use burn_backend::ops::IntTensorOps;
6use burn_backend::tensor::{FloatTensor, IntTensor};
7use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};
8
9use burn_backend::ElementConversion;
10
11use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
13use crate::{NdArrayDevice, SEED, slice};
14use crate::{SharedArray, element::QuantElement};
15use crate::{cat_with_dtype, execute_with_float_dtype};
16use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
17use crate::{element::IntNdArrayElement, execute_with_int_dtype};
18
19use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
21use burn_backend::{DType, Shape, TensorData, backend::Backend};
22
23impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
24 for NdArray<E, I, Q>
25where
26 NdArrayTensor: From<SharedArray<E>>,
27 NdArrayTensor: From<SharedArray<I>>,
28{
29 fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
30 if data.dtype.is_int() || data.dtype.is_uint() {
31 NdArrayTensor::from_data(data)
32 } else {
33 unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
34 }
35 }
36
37 async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
38 Ok(tensor.into_data())
39 }
40
41 fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
42 tensor
43 }
44
45 fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
46 execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
47 }
48
49 fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
50 slice!(tensor, slices)
51 }
52
53 fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
54 NdArrayDevice::Cpu
55 }
56
57 fn int_empty(
58 shape: Shape,
59 device: &<NdArray<E> as Backend>::Device,
60 dtype: IntDType,
61 ) -> NdArrayTensor {
62 Self::int_zeros(shape, device, dtype)
63 }
64
65 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
66 execute_with_int_dtype!((lhs, rhs), matmul)
67 }
68
69 fn int_mask_where(
70 tensor: NdArrayTensor,
71 mask: NdArrayTensor,
72 source: NdArrayTensor,
73 ) -> NdArrayTensor {
74 execute_with_int_dtype!((tensor, source), |tensor, source| {
75 NdArrayOps::mask_where(tensor, mask.bool(), source)
76 })
77 }
78
79 fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {
80 execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
81 array,
82 mask.bool(),
83 value.elem()
84 ))
85 }
86
87 fn int_slice_assign(
88 tensor: NdArrayTensor,
89 slices: &[burn_backend::Slice],
90 value: NdArrayTensor,
91 ) -> NdArrayTensor {
92 execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
93 tensor, slices, value
94 ))
95 }
96
97 fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
98 cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
99 }
100
101 fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
102 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
103 }
104
105 fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
106 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
107 }
108
109 fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
110 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
111 }
112
113 fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
114 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
115 }
116
117 fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
118 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
119 }
120
121 fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
122 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
123 array,
124 rhs.elem()
125 ))
126 }
127
128 fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
129 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
130 }
131
132 fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
133 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
134 }
135
136 fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
137 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
138 }
139
140 fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
141 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
142 array,
143 rhs.elem()
144 ))
145 }
146
147 fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
148 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
149 }
150
151 fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
152 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
153 }
154
155 fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
156 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
157 }
158
159 fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
160 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
161 }
162
163 fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
164 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
165 }
166
167 fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
168 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
169 }
170
171 fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
172 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
173 }
174
175 fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
176 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
177 }
178
179 fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
180 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
181 }
182
183 fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
184 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
185 array,
186 rhs.elem()
187 ))
188 }
189
190 fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
191 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
193 array.view()
194 ))
195 }
196
197 fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
198 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
199 }
200
201 fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
202 execute_with_int_dtype!(
204 tensor,
205 E,
206 |array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
207 )
208 }
209
210 fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
211 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
212 }
213
214 fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
215 execute_with_int_dtype!(
217 tensor,
218 E,
219 |array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
220 )
221 }
222
223 fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
224 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
225 }
226
227 fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
228 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
230 array.view()
231 ))
232 }
233
234 fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
235 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
237 array.view()
238 ))
239 }
240
241 fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
242 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
243 }
244
245 fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
246 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
247 }
248
249 fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
250 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
251 }
252
253 fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
254 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
255 }
256
257 fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
258 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
259 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
260 dim, array, idx_array
261 ))
262 })
263 }
264
265 fn int_scatter_add(
266 dim: usize,
267 tensor: NdArrayTensor,
268 indices: NdArrayTensor,
269 value: NdArrayTensor,
270 ) -> NdArrayTensor {
271 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
272 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
273 dim, tensor, idx_array, value
274 ))
275 })
276 }
277
278 fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
279 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
280 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
281 array, dim, idx_array
282 ))
283 })
284 }
285
286 fn int_select_add(
287 tensor: NdArrayTensor,
288 dim: usize,
289 indices: NdArrayTensor,
290 value: NdArrayTensor,
291 ) -> NdArrayTensor {
292 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
293 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
294 tensor, dim, idx_array, value
295 ))
296 })
297 }
298 fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
299 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
301 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
302 })
303 }
304
305 fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
306 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
308 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
309 })
310 }
311
312 fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {
313 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
314 }
315
316 fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {
317 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
318 }
319
320 fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {
321 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
322 array,
323 min.elem(),
324 max.elem()
325 ))
326 }
327
328 fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
329 match tensor.dtype() {
330 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
331 execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
332 I64 => i64, I32 => i32, I16 => i16, I8 => i8
333 ])
334 }
335 DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
337 other => panic!("Unsupported dtype: {other:?}"),
338 }
339 }
340
341 fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
342 execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| array
343 .mapv(|a: IntElem| a.elem::<E>())
344 .into_shared())
345 }
346
347 fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
348 execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
349 }
350
351 fn int_random(
352 shape: Shape,
353 distribution: Distribution,
354 device: &NdArrayDevice,
355 ) -> NdArrayTensor {
356 let mut seed = SEED.lock().unwrap();
357 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
358 rng_seeded.clone()
359 } else {
360 get_seeded_rng()
361 };
362
363 let effective_distribution = if distribution == Distribution::Default {
364 Distribution::Uniform(0.0, 255.0) } else {
366 distribution
367 };
368
369 let tensor = Self::int_from_data(
370 TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
371 device,
372 );
373 *seed = Some(rng);
374 tensor
375 }
376
377 fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
378 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
379 lhs,
380 rhs,
381 |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
382 ))
383 }
384
385 fn int_powf(lhs: NdArrayTensor, rhs: FloatTensor<Self>) -> NdArrayTensor {
386 execute_with_int_dtype!(lhs, I, |array| -> NdArrayTensor {
387 execute_with_float_dtype!(rhs, E, |rhs_array| {
388 NdArrayMathOps::elementwise_op(array, rhs_array, |a: &I, b: &E| {
389 (a.elem::<i64>().pow(*b as u32)).elem()
390 })
391 })
392 })
393 }
394
395 fn int_powf_scalar_impl(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
396 execute_with_int_dtype!(lhs, I, |array| {
397 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
398 (a.elem::<i64>().pow(rhs.elem())).elem()
399 })
400 })
401 }
402
403 fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
404 execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
405 }
406
407 fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
408 execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
409 }
410
411 fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
412 match tensor.dtype() {
413 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
414 execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
415 I64 => i64, I32 => i32, I16 => i16, I8 => i8
416 ])
417 }
418 DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
419 Self::int_greater_elem(tensor, 0.into())
420 }
421 other => panic!("Unsupported dtype: {other:?}"),
422 }
423 }
424
425 fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
426 execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
427 }
428
429 fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
430 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
431 }
432
433 fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
434 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
435 }
436
437 fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
438 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
439 }
440
441 fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
442 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
443 }
444
445 fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
446 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
447 }
448
449 fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
450 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
451 }
452
453 fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
454 execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
455 }
456
457 fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
458 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
459 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
460 (a.elem::<i64>() << (b.elem::<u32>())).elem()
461 })
462 })
463 }
464
465 fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
466 execute_with_int_dtype!(lhs, I, |array| {
467 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
468 (a.elem::<i64>() << rhs.elem::<u32>()).elem()
469 })
470 })
471 }
472
473 fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
474 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
475 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
476 (a.elem::<i64>() >> (b.elem::<u32>())).elem()
477 })
478 })
479 }
480
481 fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
482 execute_with_int_dtype!(lhs, I, |array| {
483 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
484 (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
485 })
486 })
487 }
488
489 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
490 execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
491 }
492
493 fn int_unfold(
494 tensor: IntTensor<Self>,
495 dim: usize,
496 size: usize,
497 step: usize,
498 ) -> IntTensor<Self> {
499 execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
500 }
501}