Skip to main content

burn_ndarray/ops/
int_tensor.rs

1// Language
2use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_backend::backend::ExecutionError;
5use burn_backend::ops::IntTensorOps;
6use burn_backend::tensor::{FloatTensor, IntTensor};
7use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};
8
9use burn_backend::ElementConversion;
10
11// Current crate
12use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
13use crate::{NdArrayDevice, SEED, slice};
14use crate::{SharedArray, element::QuantElement};
15use crate::{cat_with_dtype, execute_with_float_dtype};
16use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
17use crate::{element::IntNdArrayElement, execute_with_int_dtype};
18
19// Workspace crates
20use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
21use burn_backend::{DType, Shape, TensorData, backend::Backend};
22
23impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
24    for NdArray<E, I, Q>
25where
26    NdArrayTensor: From<SharedArray<E>>,
27    NdArrayTensor: From<SharedArray<I>>,
28{
29    fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
30        if data.dtype.is_int() || data.dtype.is_uint() {
31            NdArrayTensor::from_data(data)
32        } else {
33            unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
34        }
35    }
36
37    async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {
38        Ok(tensor.into_data())
39    }
40
41    fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
42        tensor
43    }
44
45    fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
46        execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))
47    }
48
49    fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {
50        slice!(tensor, slices)
51    }
52
53    fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
54        NdArrayDevice::Cpu
55    }
56
57    fn int_empty(
58        shape: Shape,
59        device: &<NdArray<E> as Backend>::Device,
60        dtype: IntDType,
61    ) -> NdArrayTensor {
62        Self::int_zeros(shape, device, dtype)
63    }
64
65    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
66        execute_with_int_dtype!((lhs, rhs), matmul)
67    }
68
69    fn int_mask_where(
70        tensor: NdArrayTensor,
71        mask: NdArrayTensor,
72        source: NdArrayTensor,
73    ) -> NdArrayTensor {
74        execute_with_int_dtype!((tensor, source), |tensor, source| {
75            NdArrayOps::mask_where(tensor, mask.bool(), source)
76        })
77    }
78
79    fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {
80        execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(
81            array,
82            mask.bool(),
83            value.elem()
84        ))
85    }
86
87    fn int_slice_assign(
88        tensor: NdArrayTensor,
89        slices: &[burn_backend::Slice],
90        value: NdArrayTensor,
91    ) -> NdArrayTensor {
92        execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
93            tensor, slices, value
94        ))
95    }
96
97    fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
98        cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
99    }
100
101    fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
102        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
103    }
104
105    fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
106        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))
107    }
108
109    fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
110        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
111    }
112
113    fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
114        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))
115    }
116
117    fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
118        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
119    }
120
121    fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
122        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(
123            array,
124            rhs.elem()
125        ))
126    }
127
128    fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
129        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
130    }
131
132    fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
133        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))
134    }
135
136    fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
137        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
138    }
139
140    fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
141        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(
142            array,
143            rhs.elem()
144        ))
145    }
146
147    fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
148        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
149    }
150
151    fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
152        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))
153    }
154
155    fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
156        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
157    }
158
159    fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
160        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))
161    }
162
163    fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
164        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
165    }
166
167    fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
168        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))
169    }
170
171    fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
172        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
173    }
174
175    fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
176        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))
177    }
178
179    fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
180        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
181    }
182
183    fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
184        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(
185            array,
186            rhs.elem()
187        ))
188    }
189
190    fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
191        // Use view() for zero-copy on borrowed storage
192        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(
193            array.view()
194        ))
195    }
196
197    fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
198        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))
199    }
200
201    fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
202        // Use view() for zero-copy on borrowed storage
203        execute_with_int_dtype!(
204            tensor,
205            E,
206            |array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())
207        )
208    }
209
210    fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
211        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))
212    }
213
214    fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
215        // Use view() for zero-copy on borrowed storage
216        execute_with_int_dtype!(
217            tensor,
218            E,
219            |array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())
220        )
221    }
222
223    fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
224        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))
225    }
226
227    fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {
228        // Use view() for zero-copy on borrowed storage
229        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(
230            array.view()
231        ))
232    }
233
234    fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {
235        // Use view() for zero-copy on borrowed storage
236        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(
237            array.view()
238        ))
239    }
240
241    fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
242        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))
243    }
244
245    fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
246        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))
247    }
248
249    fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
250        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))
251    }
252
253    fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
254        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))
255    }
256
257    fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
258        execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
259            execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(
260                dim, array, idx_array
261            ))
262        })
263    }
264
265    fn int_scatter_add(
266        dim: usize,
267        tensor: NdArrayTensor,
268        indices: NdArrayTensor,
269        value: NdArrayTensor,
270    ) -> NdArrayTensor {
271        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
272            execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(
273                dim, tensor, idx_array, value
274            ))
275        })
276    }
277
278    fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
279        execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {
280            execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(
281                array, dim, idx_array
282            ))
283        })
284    }
285
286    fn int_select_add(
287        tensor: NdArrayTensor,
288        dim: usize,
289        indices: NdArrayTensor,
290        value: NdArrayTensor,
291    ) -> NdArrayTensor {
292        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
293            execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(
294                tensor, dim, idx_array, value
295            ))
296        })
297    }
298    fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
299        // Use view() for zero-copy on borrowed storage
300        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
301            NdArrayMathOps::argmax_view::<I>(array.view(), dim)
302        })
303    }
304
305    fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
306        // Use view() for zero-copy on borrowed storage
307        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
308            NdArrayMathOps::argmin_view::<I>(array.view(), dim)
309        })
310    }
311
312    fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {
313        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))
314    }
315
316    fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {
317        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))
318    }
319
320    fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {
321        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(
322            array,
323            min.elem(),
324            max.elem()
325        ))
326    }
327
328    fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
329        match tensor.dtype() {
330            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
331                execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
332                    I64 => i64, I32 => i32, I16 => i16, I8 => i8
333                ])
334            }
335            // Already unsigned
336            DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
337            other => panic!("Unsupported dtype: {other:?}"),
338        }
339    }
340
341    fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
342        execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| array
343            .mapv(|a: IntElem| a.elem::<E>())
344            .into_shared())
345    }
346
347    fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
348        execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))
349    }
350
351    fn int_random(
352        shape: Shape,
353        distribution: Distribution,
354        device: &NdArrayDevice,
355    ) -> NdArrayTensor {
356        let mut seed = SEED.lock().unwrap();
357        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
358            rng_seeded.clone()
359        } else {
360            get_seeded_rng()
361        };
362
363        let effective_distribution = if distribution == Distribution::Default {
364            Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant
365        } else {
366            distribution
367        };
368
369        let tensor = Self::int_from_data(
370            TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
371            device,
372        );
373        *seed = Some(rng);
374        tensor
375    }
376
377    fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
378        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
379            lhs,
380            rhs,
381            |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
382        ))
383    }
384
385    fn int_powf(lhs: NdArrayTensor, rhs: FloatTensor<Self>) -> NdArrayTensor {
386        execute_with_int_dtype!(lhs, I, |array| -> NdArrayTensor {
387            execute_with_float_dtype!(rhs, E, |rhs_array| {
388                NdArrayMathOps::elementwise_op(array, rhs_array, |a: &I, b: &E| {
389                    (a.elem::<i64>().pow(*b as u32)).elem()
390                })
391            })
392        })
393    }
394
395    fn int_powf_scalar_impl(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
396        execute_with_int_dtype!(lhs, I, |array| {
397            NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
398                (a.elem::<i64>().pow(rhs.elem())).elem()
399            })
400        })
401    }
402
403    fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
404        execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))
405    }
406
407    fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
408        execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))
409    }
410
411    fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
412        match tensor.dtype() {
413            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
414                execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
415                    I64 => i64, I32 => i32, I16 => i16, I8 => i8
416                ])
417            }
418            DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
419                Self::int_greater_elem(tensor, 0.into())
420            }
421            other => panic!("Unsupported dtype: {other:?}"),
422        }
423    }
424
425    fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
426        execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))
427    }
428
429    fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
430        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
431    }
432
433    fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
434        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))
435    }
436
437    fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
438        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
439    }
440
441    fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
442        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))
443    }
444
445    fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
446        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
447    }
448
449    fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
450        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))
451    }
452
453    fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
454        execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
455    }
456
457    fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
458        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
459            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
460                (a.elem::<i64>() << (b.elem::<u32>())).elem()
461            })
462        })
463    }
464
465    fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
466        execute_with_int_dtype!(lhs, I, |array| {
467            NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
468                (a.elem::<i64>() << rhs.elem::<u32>()).elem()
469            })
470        })
471    }
472
473    fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
474        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
475            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
476                (a.elem::<i64>() >> (b.elem::<u32>())).elem()
477            })
478        })
479    }
480
481    fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {
482        execute_with_int_dtype!(lhs, I, |array| {
483            NdArrayMathOps::elementwise_op_scalar(array, |a: I| {
484                (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
485            })
486        })
487    }
488
489    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
490        execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))
491    }
492
493    fn int_unfold(
494        tensor: IntTensor<Self>,
495        dim: usize,
496        size: usize,
497        step: usize,
498    ) -> IntTensor<Self> {
499        execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))
500    }
501}