1use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_backend::backend::ExecutionError;
5use burn_backend::ops::IntTensorOps;
6use burn_backend::tensor::{FloatTensor, IntTensor};
7use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};
8
9use burn_backend::ElementConversion;
10use burn_std::{BoolDType, FloatDType};
11
12use crate::{ExpElement, NdArrayDevice, SEED, execute_with_int_out_dtype, slice};
14use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
15use crate::{SharedArray, element::QuantElement};
16use crate::{cat_with_dtype, execute_with_float_out_dtype};
17use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
18use crate::{element::IntNdArrayElement, execute_with_int_dtype};
19
20use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
22use burn_backend::{DType, Shape, TensorData};
23
24impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
25 for NdArray<E, I, Q>
26where
27 NdArrayTensor: From<SharedArray<E>>,
28 NdArrayTensor: From<SharedArray<I>>,
29{
30 fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
31 if data.dtype.is_int() || data.dtype.is_uint() {
32 NdArrayTensor::from_data(data)
33 } else {
34 unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
35 }
36 }
37
38 async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
39 Ok(tensor.into_data())
40 }
41
42 fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
43 tensor
44 }
45
46 fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
47 execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
48 }
49
50 fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
51 slice!(tensor, slices)
52 }
53
54 fn int_device(_tensor: &NdArrayTensor) -> NdArrayDevice {
55 NdArrayDevice::Cpu
56 }
57
58 fn int_empty(shape: Shape, device: &NdArrayDevice, dtype: IntDType) -> NdArrayTensor {
59 Self::int_zeros(shape, device, dtype)
60 }
61
62 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
63 execute_with_int_dtype!((lhs, rhs), matmul)
64 }
65
66 fn int_mask_where(
67 tensor: NdArrayTensor,
68 mask: NdArrayTensor,
69 source: NdArrayTensor,
70 ) -> NdArrayTensor {
71 execute_with_int_dtype!((tensor, source), |tensor, source| {
72 NdArrayOps::mask_where(tensor, mask.bool(), source)
73 })
74 }
75
76 fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {
77 execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
78 array,
79 mask.bool(),
80 value.elem()
81 ))
82 }
83
84 fn int_slice_assign(
85 tensor: NdArrayTensor,
86 slices: &[burn_backend::Slice],
87 value: NdArrayTensor,
88 ) -> NdArrayTensor {
89 execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
90 tensor, slices, value
91 ))
92 }
93
94 fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
95 cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
96 }
97
98 fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
99 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
100 }
101
102 fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
103 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
104 }
105
106 fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
107 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
108 }
109
110 fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
111 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
112 }
113
114 fn int_greater_equal(
115 lhs: NdArrayTensor,
116 rhs: NdArrayTensor,
117 _out_dtype: BoolDType,
118 ) -> NdArrayTensor {
119 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
120 }
121
122 fn int_greater_equal_elem(
123 lhs: NdArrayTensor,
124 rhs: Scalar,
125 _out_dtype: BoolDType,
126 ) -> NdArrayTensor {
127 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
128 array,
129 rhs.elem()
130 ))
131 }
132
133 fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
134 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
135 }
136
137 fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
138 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
139 }
140
141 fn int_lower_equal(
142 lhs: NdArrayTensor,
143 rhs: NdArrayTensor,
144 _out_dtype: BoolDType,
145 ) -> NdArrayTensor {
146 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
147 }
148
149 fn int_lower_equal_elem(
150 lhs: NdArrayTensor,
151 rhs: Scalar,
152 _out_dtype: BoolDType,
153 ) -> NdArrayTensor {
154 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
155 array,
156 rhs.elem()
157 ))
158 }
159
160 fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
161 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
162 }
163
164 fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
165 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
166 }
167
168 fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
169 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
170 }
171
172 fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
173 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
174 }
175
176 fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
177 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
178 }
179
180 fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
181 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
182 }
183
184 fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
185 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
186 }
187
188 fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
189 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
190 }
191
192 fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
193 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
194 }
195
196 fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
197 execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
198 array,
199 rhs.elem()
200 ))
201 }
202
203 fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
204 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
206 array.view()
207 ))
208 }
209
210 fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
211 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
212 }
213
214 fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
215 execute_with_int_dtype!(
217 tensor,
218 E,
219 |array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
220 )
221 }
222
223 fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
224 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
225 }
226
227 fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
228 execute_with_int_dtype!(
230 tensor,
231 E,
232 |array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
233 )
234 }
235
236 fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
237 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
238 }
239
240 fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
241 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
243 array.view()
244 ))
245 }
246
247 fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
248 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
250 array.view()
251 ))
252 }
253
254 fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
255 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
256 }
257
258 fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
259 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
260 }
261
262 fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
263 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
264 }
265
266 fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
267 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
268 }
269
270 fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
271 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
272 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
273 dim, array, idx_array
274 ))
275 })
276 }
277
278 fn int_scatter_add(
279 dim: usize,
280 tensor: NdArrayTensor,
281 indices: NdArrayTensor,
282 value: NdArrayTensor,
283 ) -> NdArrayTensor {
284 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
285 execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
286 dim, tensor, idx_array, value
287 ))
288 })
289 }
290
291 fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
292 execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
293 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
294 array, dim, idx_array
295 ))
296 })
297 }
298
299 fn int_select_add(
300 tensor: NdArrayTensor,
301 dim: usize,
302 indices: NdArrayTensor,
303 value: NdArrayTensor,
304 ) -> NdArrayTensor {
305 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
306 execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
307 tensor, dim, idx_array, value
308 ))
309 })
310 }
311 fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
312 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
314 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
315 })
316 }
317
318 fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
319 execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
321 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
322 })
323 }
324
325 fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {
326 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
327 }
328
329 fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {
330 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
331 }
332
333 fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {
334 execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
335 array,
336 min.elem(),
337 max.elem()
338 ))
339 }
340
341 fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
342 match tensor.dtype() {
343 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
344 execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
345 I64 => i64, I32 => i32, I16 => i16, I8 => i8
346 ])
347 }
348 DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
350 other => panic!("Unsupported dtype: {other:?}"),
351 }
352 }
353
354 fn int_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor<Self> {
355 execute_with_float_out_dtype!(out_dtype, F, {
356 execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| {
357 array.mapv(|a: IntElem| a.elem::<F>()).into_shared()
358 })
359 })
360 }
361
362 fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
363 execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
364 }
365
366 fn int_random(
367 shape: Shape,
368 distribution: Distribution,
369 device: &NdArrayDevice,
370 dtype: IntDType,
371 ) -> NdArrayTensor {
372 let mut seed = SEED.lock().unwrap();
373 let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
374
375 let effective_distribution = if distribution == Distribution::Default {
376 Distribution::Uniform(0.0, 255.0) } else {
378 distribution
379 };
380
381 let tensor = execute_with_int_out_dtype!(
382 dtype,
383 I,
384 Self::int_from_data(
385 TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
386 device,
387 )
388 );
389 *seed = Some(rng);
390 tensor
391 }
392
393 fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
394 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
395 lhs,
396 rhs,
397 |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
398 ))
399 }
400
401 fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
402 execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
403 }
404
405 fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
406 execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
407 }
408
409 fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
410 match tensor.dtype() {
411 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
412 execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
413 I64 => i64, I32 => i32, I16 => i16, I8 => i8
414 ])
415 }
416 DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
417 Self::int_greater_elem(tensor, 0.into(), BoolDType::Native)
418 }
419 other => panic!("Unsupported dtype: {other:?}"),
420 }
421 }
422
423 fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
424 execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
425 }
426
427 fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
428 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
429 }
430
431 fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
432 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
433 }
434
435 fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
436 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
437 }
438
439 fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
440 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
441 }
442
443 fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
444 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
445 }
446
447 fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
448 execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
449 }
450
451 fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
452 execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
453 }
454
455 fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
456 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
457 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
458 (a.elem::<i64>() << (b.elem::<u32>())).elem()
459 })
460 })
461 }
462
463 fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
464 execute_with_int_dtype!(lhs, I, |array| {
465 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
466 (a.elem::<i64>() << rhs.elem::<u32>()).elem()
467 })
468 })
469 }
470
471 fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
472 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
473 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
474 (a.elem::<i64>() >> (b.elem::<u32>())).elem()
475 })
476 })
477 }
478
479 fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
480 execute_with_int_dtype!(lhs, I, |array| {
481 NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
482 (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
483 })
484 })
485 }
486
487 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
488 execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
489 }
490
491 fn int_unfold(
492 tensor: IntTensor<Self>,
493 dim: usize,
494 size: usize,
495 step: usize,
496 ) -> IntTensor<Self> {
497 execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
498 }
499
500 fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
501 execute_with_int_dtype!(lhs, I, |array| {
502 NdArrayMathOps::elementwise_op_scalar(array, |a: I| a.powi_elem(rhs.elem()))
503 })
504 }
505}