Skip to main content

burn_ndarray/ops/
int_tensor.rs

1// Language
2use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_backend::backend::ExecutionError;
5use burn_backend::ops::IntTensorOps;
6use burn_backend::tensor::{FloatTensor, IntTensor};
7use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};
8
9use burn_backend::ElementConversion;
10use burn_std::{BoolDType, FloatDType};
11
12// Current crate
13use crate::{ExpElement, NdArrayDevice, SEED, execute_with_int_out_dtype, slice};
14use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
15use crate::{SharedArray, element::QuantElement};
16use crate::{cat_with_dtype, execute_with_float_out_dtype};
17use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
18use crate::{element::IntNdArrayElement, execute_with_int_dtype};
19
20// Workspace crates
21use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
22use burn_backend::{DType, Shape, TensorData, backend::Backend};
23
24impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
25    for NdArray<E, I, Q>
26where
27    NdArrayTensor: From<SharedArray<E>>,
28    NdArrayTensor: From<SharedArray<I>>,
29{
30    fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
31        if data.dtype.is_int() || data.dtype.is_uint() {
32            NdArrayTensor::from_data(data)
33        } else {
34            unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
35        }
36    }
37
38    async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
39        Ok(tensor.into_data())
40    }
41
42    fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
43        tensor
44    }
45
46    fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
47        execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
48    }
49
50    fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
51        slice!(tensor, slices)
52    }
53
54    fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
55        NdArrayDevice::Cpu
56    }
57
58    fn int_empty(
59        shape: Shape,
60        device: &<NdArray<E> as Backend>::Device,
61        dtype: IntDType,
62    ) -> NdArrayTensor {
63        Self::int_zeros(shape, device, dtype)
64    }
65
66    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
67        execute_with_int_dtype!((lhs, rhs), matmul)
68    }
69
70    fn int_mask_where(
71        tensor: NdArrayTensor,
72        mask: NdArrayTensor,
73        source: NdArrayTensor,
74    ) -> NdArrayTensor {
75        execute_with_int_dtype!((tensor, source), |tensor, source| {
76            NdArrayOps::mask_where(tensor, mask.bool(), source)
77        })
78    }
79
80    fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {
81        execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
82            array,
83            mask.bool(),
84            value.elem()
85        ))
86    }
87
88    fn int_slice_assign(
89        tensor: NdArrayTensor,
90        slices: &[burn_backend::Slice],
91        value: NdArrayTensor,
92    ) -> NdArrayTensor {
93        execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
94            tensor, slices, value
95        ))
96    }
97
98    fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
99        cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
100    }
101
102    fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
103        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
104    }
105
106    fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
107        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
108    }
109
110    fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
111        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
112    }
113
114    fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
115        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
116    }
117
118    fn int_greater_equal(
119        lhs: NdArrayTensor,
120        rhs: NdArrayTensor,
121        _out_dtype: BoolDType,
122    ) -> NdArrayTensor {
123        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
124    }
125
126    fn int_greater_equal_elem(
127        lhs: NdArrayTensor,
128        rhs: Scalar,
129        _out_dtype: BoolDType,
130    ) -> NdArrayTensor {
131        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
132            array,
133            rhs.elem()
134        ))
135    }
136
137    fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor {
138        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
139    }
140
141    fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor {
142        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
143    }
144
145    fn int_lower_equal(
146        lhs: NdArrayTensor,
147        rhs: NdArrayTensor,
148        _out_dtype: BoolDType,
149    ) -> NdArrayTensor {
150        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
151    }
152
153    fn int_lower_equal_elem(
154        lhs: NdArrayTensor,
155        rhs: Scalar,
156        _out_dtype: BoolDType,
157    ) -> NdArrayTensor {
158        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
159            array,
160            rhs.elem()
161        ))
162    }
163
164    fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
165        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
166    }
167
168    fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
169        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
170    }
171
172    fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
173        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
174    }
175
176    fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
177        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
178    }
179
180    fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
181        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
182    }
183
184    fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
185        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
186    }
187
188    fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
189        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
190    }
191
192    fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
193        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
194    }
195
196    fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
197        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
198    }
199
200    fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
201        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
202            array,
203            rhs.elem()
204        ))
205    }
206
207    fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
208        // Use view() for zero-copy on borrowed storage
209        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
210            array.view()
211        ))
212    }
213
214    fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
215        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
216    }
217
218    fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
219        // Use view() for zero-copy on borrowed storage
220        execute_with_int_dtype!(
221            tensor,
222            E,
223            |array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
224        )
225    }
226
227    fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
228        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
229    }
230
231    fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
232        // Use view() for zero-copy on borrowed storage
233        execute_with_int_dtype!(
234            tensor,
235            E,
236            |array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
237        )
238    }
239
240    fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
241        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
242    }
243
244    fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
245        // Use view() for zero-copy on borrowed storage
246        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
247            array.view()
248        ))
249    }
250
251    fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
252        // Use view() for zero-copy on borrowed storage
253        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
254            array.view()
255        ))
256    }
257
258    fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
259        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
260    }
261
262    fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
263        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
264    }
265
266    fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
267        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
268    }
269
270    fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
271        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
272    }
273
274    fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
275        execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
276            execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
277                dim, array, idx_array
278            ))
279        })
280    }
281
282    fn int_scatter_add(
283        dim: usize,
284        tensor: NdArrayTensor,
285        indices: NdArrayTensor,
286        value: NdArrayTensor,
287    ) -> NdArrayTensor {
288        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
289            execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
290                dim, tensor, idx_array, value
291            ))
292        })
293    }
294
295    fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
296        execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
297            execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
298                array, dim, idx_array
299            ))
300        })
301    }
302
303    fn int_select_add(
304        tensor: NdArrayTensor,
305        dim: usize,
306        indices: NdArrayTensor,
307        value: NdArrayTensor,
308    ) -> NdArrayTensor {
309        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
310            execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
311                tensor, dim, idx_array, value
312            ))
313        })
314    }
315    fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
316        // Use view() for zero-copy on borrowed storage
317        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
318            NdArrayMathOps::argmax_view::<I>(array.view(), dim)
319        })
320    }
321
322    fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
323        // Use view() for zero-copy on borrowed storage
324        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
325            NdArrayMathOps::argmin_view::<I>(array.view(), dim)
326        })
327    }
328
329    fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {
330        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
331    }
332
333    fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {
334        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
335    }
336
337    fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {
338        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
339            array,
340            min.elem(),
341            max.elem()
342        ))
343    }
344
345    fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
346        match tensor.dtype() {
347            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
348                execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
349                    I64 => i64, I32 => i32, I16 => i16, I8 => i8
350                ])
351            }
352            // Already unsigned
353            DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
354            other => panic!("Unsupported dtype: {other:?}"),
355        }
356    }
357
358    fn int_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor<Self> {
359        execute_with_float_out_dtype!(out_dtype, F, {
360            execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| {
361                array.mapv(|a: IntElem| a.elem::<F>()).into_shared()
362            })
363        })
364    }
365
366    fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
367        execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
368    }
369
370    fn int_random(
371        shape: Shape,
372        distribution: Distribution,
373        device: &NdArrayDevice,
374        dtype: IntDType,
375    ) -> NdArrayTensor {
376        let mut seed = SEED.lock().unwrap();
377        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
378
379        let effective_distribution = if distribution == Distribution::Default {
380            Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant
381        } else {
382            distribution
383        };
384
385        let tensor = execute_with_int_out_dtype!(
386            dtype,
387            I,
388            Self::int_from_data(
389                TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
390                device,
391            )
392        );
393        *seed = Some(rng);
394        tensor
395    }
396
397    fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
398        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
399            lhs,
400            rhs,
401            |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
402        ))
403    }
404
405    fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
406        execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
407    }
408
409    fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
410        execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
411    }
412
413    fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
414        match tensor.dtype() {
415            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
416                execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
417                    I64 => i64, I32 => i32, I16 => i16, I8 => i8
418                ])
419            }
420            DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
421                Self::int_greater_elem(tensor, 0.into(), BoolDType::Native)
422            }
423            other => panic!("Unsupported dtype: {other:?}"),
424        }
425    }
426
427    fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
428        execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
429    }
430
431    fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
432        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
433    }
434
435    fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
436        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
437    }
438
439    fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
440        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
441    }
442
443    fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
444        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
445    }
446
447    fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
448        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
449    }
450
451    fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
452        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
453    }
454
455    fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
456        execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
457    }
458
459    fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
460        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
461            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
462                (a.elem::<i64>() << (b.elem::<u32>())).elem()
463            })
464        })
465    }
466
467    fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
468        execute_with_int_dtype!(lhs, I, |array| {
469            NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
470                (a.elem::<i64>() << rhs.elem::<u32>()).elem()
471            })
472        })
473    }
474
475    fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
476        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
477            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
478                (a.elem::<i64>() >> (b.elem::<u32>())).elem()
479            })
480        })
481    }
482
483    fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
484        execute_with_int_dtype!(lhs, I, |array| {
485            NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
486                (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
487            })
488        })
489    }
490
491    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
492        execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
493    }
494
495    fn int_unfold(
496        tensor: IntTensor<Self>,
497        dim: usize,
498        size: usize,
499        step: usize,
500    ) -> IntTensor<Self> {
501        execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
502    }
503
504    fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {
505        execute_with_int_dtype!(lhs, I, |array| {
506            NdArrayMathOps::elementwise_op_scalar(array, |a: I| a.powi_elem(rhs.elem()))
507        })
508    }
509}