1use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_tensor::{Distribution, ops::IntTensor};
5use burn_tensor::{IntDType, ops::IntTensorOps};
6use burn_tensor::{TensorMetadata, ops::FloatTensor};
7
8use burn_tensor::ElementConversion;
9
10use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
12use crate::{NdArrayDevice, SEED};
13use crate::{SharedArray, element::QuantElement};
14use crate::{cat_with_dtype, execute_with_float_dtype};
15use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
16use crate::{element::IntNdArrayElement, execute_with_int_dtype};
17
18use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
20use burn_tensor::{DType, Shape, TensorData, backend::Backend};
21
22impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
23 for NdArray<E, I, Q>
24where
25 NdArrayTensor: From<SharedArray<E>>,
26 NdArrayTensor: From<SharedArray<I>>,
27{
28 fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
29 if data.dtype.is_int() || data.dtype.is_uint() {
30 NdArrayTensor::from_data(data)
31 } else {
32 unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
33 }
34 }
35
36 async fn int_into_data(tensor: NdArrayTensor) -> TensorData {
37 tensor.into_data()
38 }
39
40 fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
41 tensor
42 }
43
44 fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
45 execute_with_int_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
46 }
47
48 fn int_slice(tensor: NdArrayTensor, slices: &[burn_tensor::Slice]) -> NdArrayTensor {
49 execute_with_int_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, slices))
50 }
51
52 fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
53 NdArrayDevice::Cpu
54 }
55
56 fn int_empty(
57 shape: Shape,
58 device: &<NdArray<E> as Backend>::Device,
59 dtype: IntDType,
60 ) -> NdArrayTensor {
61 Self::int_zeros(shape, device, dtype)
62 }
63
64 fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
65 execute_with_int_dtype!((lhs, rhs), matmul)
66 }
67
68 fn int_mask_where(
69 tensor: NdArrayTensor,
70 mask: NdArrayTensor,
71 source: NdArrayTensor,
72 ) -> NdArrayTensor {
73 execute_with_int_dtype!((tensor, source), |tensor, source| {
74 NdArrayMathOps::mask_where(tensor, mask.bool(), source)
75 })
76 }
77
78 fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: I) -> NdArrayTensor {
79 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
80 tensor,
81 mask.bool(),
82 value.elem()
83 ))
84 }
85
86 fn int_slice_assign(
87 tensor: NdArrayTensor,
88 slices: &[burn_tensor::Slice],
89 value: NdArrayTensor,
90 ) -> NdArrayTensor {
91 execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
92 tensor, slices, value
93 ))
94 }
95
96 fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
97 cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
98 }
99
100 fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
101 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
102 }
103
104 fn int_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
105 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::equal_elem(lhs, rhs.elem()))
106 }
107
108 fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
109 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
110 }
111
112 fn int_greater_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
113 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::greater_elem(lhs, rhs.elem()))
114 }
115
116 fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
117 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
118 }
119
120 fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
121 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::greater_equal_elem(
122 lhs,
123 rhs.elem()
124 ))
125 }
126
127 fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
128 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
129 }
130
131 fn int_lower_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
132 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::lower_elem(lhs, rhs.elem()))
133 }
134
135 fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
136 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
137 }
138
139 fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
140 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::lower_equal_elem(lhs, rhs.elem()))
141 }
142
143 fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
144 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
145 }
146
147 fn int_add_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
148 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
149 }
150
151 fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
152 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
153 }
154
155 fn int_sub_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
156 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
157 }
158
159 fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
160 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
161 }
162
163 fn int_mul_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
164 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
165 }
166
167 fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
168 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
169 }
170
171 fn int_div_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
172 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
173 }
174
175 fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
176 execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
177 }
178
179 fn int_remainder_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
180 execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
181 }
182
183 fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor {
184 Self::int_mul_scalar(tensor, (-1).elem())
185 }
186
187 fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
188 execute_with_int_dtype!(tensor, NdArrayMathOps::sum)
189 }
190
191 fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
192 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
193 }
194
195 fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
196 execute_with_int_dtype!(tensor, NdArrayMathOps::prod)
197 }
198
199 fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
200 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
201 }
202
203 fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
204 execute_with_int_dtype!(tensor, NdArrayMathOps::mean)
205 }
206
207 fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
208 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
209 }
210
211 fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
212 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
213 }
214
215 fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
216 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cumprod(tensor, dim))
217 }
218
219 fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
220 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
221 }
222
223 fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
224 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cummax(tensor, dim))
225 }
226
227 fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
228 execute_with_int_dtype!(tensor, E, |tensor: SharedArray<E>| -> NdArrayTensor {
229 execute_with_int_dtype!(indices, |indices| NdArrayMathOps::gather(
230 dim, tensor, indices
231 ))
232 })
233 }
234
235 fn int_scatter(
236 dim: usize,
237 tensor: NdArrayTensor,
238 indices: NdArrayTensor,
239 value: NdArrayTensor,
240 ) -> NdArrayTensor {
241 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
242 execute_with_int_dtype!(indices, |indices| NdArrayMathOps::<I>::scatter(
243 dim, tensor, indices, value
244 ))
245 })
246 }
247
248 fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
249 execute_with_int_dtype!(tensor, E, |tensor: SharedArray<E>| -> NdArrayTensor {
250 execute_with_int_dtype!(indices, |indices| NdArrayMathOps::select(
251 tensor, dim, indices
252 ))
253 })
254 }
255
256 fn int_select_assign(
257 tensor: NdArrayTensor,
258 dim: usize,
259 indices: NdArrayTensor,
260 value: NdArrayTensor,
261 ) -> NdArrayTensor {
262 execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
263 execute_with_int_dtype!(indices, |indices| NdArrayMathOps::<I>::select_assign(
264 tensor, dim, indices, value
265 ))
266 })
267 }
268 fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
269 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::argmax::<I>(tensor, dim))
270 }
271
272 fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
273 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::argmin::<I>(tensor, dim))
274 }
275
276 fn int_clamp_min(tensor: NdArrayTensor, min: I) -> NdArrayTensor {
277 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
278 tensor,
279 min.elem()
280 ))
281 }
282
283 fn int_clamp_max(tensor: NdArrayTensor, max: I) -> NdArrayTensor {
284 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
285 tensor,
286 max.elem()
287 ))
288 }
289
290 fn int_clamp(tensor: NdArrayTensor, min: I, max: I) -> NdArrayTensor {
291 execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
292 tensor,
293 min.elem(),
294 max.elem()
295 ))
296 }
297
298 fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
299 match tensor.dtype() {
300 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
301 execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
302 I64 => i64, I32 => i32, I16 => i16, I8 => i8
303 ])
304 }
305 DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
307 other => panic!("Unsupported dtype: {other:?}"),
308 }
309 }
310
311 fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
312 execute_with_int_dtype!(tensor, I, |t: SharedArray<I>| t
313 .mapv(|a| a.elem::<E>())
314 .into_shared())
315 }
316
317 fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
318 execute_with_int_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
319 }
320
321 fn int_random(
322 shape: Shape,
323 distribution: Distribution,
324 device: &NdArrayDevice,
325 ) -> NdArrayTensor {
326 let mut seed = SEED.lock().unwrap();
327 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
328 rng_seeded.clone()
329 } else {
330 get_seeded_rng()
331 };
332
333 let effective_distribution = if distribution == Distribution::Default {
334 Distribution::Uniform(0.0, 255.0) } else {
336 distribution
337 };
338
339 let tensor = Self::int_from_data(
340 TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
341 device,
342 );
343 *seed = Some(rng);
344 tensor
345 }
346
347 fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
348 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
349 lhs,
350 rhs,
351 |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
352 ))
353 }
354
355 fn int_powf(lhs: NdArrayTensor, rhs: FloatTensor<Self>) -> NdArrayTensor {
356 execute_with_int_dtype!(lhs, I, |lhs| -> NdArrayTensor {
357 execute_with_float_dtype!(rhs, E, |rhs| {
358 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &E| {
359 (a.elem::<i64>().pow(*b as u32)).elem()
360 })
361 })
362 })
363 }
364
365 fn int_powf_scalar_impl(lhs: NdArrayTensor, rhs: f32) -> NdArrayTensor {
366 execute_with_int_dtype!(lhs, I, |lhs| {
367 NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
368 (a.elem::<i64>().pow(rhs as u32)).elem()
369 })
370 })
371 }
372
373 fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
374 execute_with_int_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
375 }
376
377 fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
378 execute_with_int_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
379 }
380
381 fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
382 match tensor.dtype() {
383 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
384 execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
385 I64 => i64, I32 => i32, I16 => i16, I8 => i8
386 ])
387 }
388 DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
389 Self::int_greater_elem(tensor, 0.elem())
390 }
391 other => panic!("Unsupported dtype: {other:?}"),
392 }
393 }
394
395 fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
396 execute_with_int_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
397 }
398
399 fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
400 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
401 }
402
403 fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
404 execute_with_int_dtype!(lhs, |lhs| NdArrayBitOps::bitand_scalar(lhs, rhs.elem()))
405 }
406
407 fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
408 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
409 }
410
411 fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
412 execute_with_int_dtype!(lhs, |lhs| NdArrayBitOps::bitor_scalar(lhs, rhs.elem()))
413 }
414
415 fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
416 execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
417 }
418
419 fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
420 execute_with_int_dtype!(lhs, |lhs| NdArrayBitOps::bitxor_scalar(lhs, rhs.elem()))
421 }
422
423 fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
424 execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
425 }
426
427 fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
428 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
429 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
430 (a.elem::<i64>() << (b.elem::<u32>())).elem()
431 })
432 })
433 }
434
435 fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
436 execute_with_int_dtype!(lhs, I, |lhs| {
437 NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
438 (a.elem::<i64>() << rhs.elem::<u32>()).elem()
439 })
440 })
441 }
442
443 fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
444 execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
445 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
446 (a.elem::<i64>() >> (b.elem::<u32>())).elem()
447 })
448 })
449 }
450
451 fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
452 execute_with_int_dtype!(lhs, I, |lhs| {
453 NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
454 (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
455 })
456 })
457 }
458
459 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
460 execute_with_int_dtype!(tensor, |tensor| cast_to_dtype(tensor, dtype.into()))
461 }
462
463 fn int_unfold(
464 tensor: IntTensor<Self>,
465 dim: usize,
466 size: usize,
467 step: usize,
468 ) -> IntTensor<Self> {
469 execute_with_int_dtype!(tensor, |tensor| NdArrayOps::unfold(tensor, dim, size, step))
470 }
471}