burn_ndarray/ops/int_tensor.rs

1// Language
2use crate::rand::get_seeded_rng;
3use alloc::vec::Vec;
4use burn_tensor::{Distribution, ops::IntTensor};
5use burn_tensor::{IntDType, ops::IntTensorOps};
6use burn_tensor::{TensorMetadata, ops::FloatTensor};
7
8use burn_tensor::ElementConversion;
9
10// Current crate
11use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};
12use crate::{NdArrayDevice, SEED};
13use crate::{SharedArray, element::QuantElement};
14use crate::{cat_with_dtype, execute_with_float_dtype};
15use crate::{element::FloatNdArrayElement, ops::matmul::matmul};
16use crate::{element::IntNdArrayElement, execute_with_int_dtype};
17
18// Workspace crates
19use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
20use burn_tensor::{DType, Shape, TensorData, backend::Backend};
21
22impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
23    for NdArray<E, I, Q>
24where
25    NdArrayTensor: From<SharedArray<E>>,
26    NdArrayTensor: From<SharedArray<I>>,
27{
28    fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {
29        if data.dtype.is_int() || data.dtype.is_uint() {
30            NdArrayTensor::from_data(data)
31        } else {
32            unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype)
33        }
34    }
35
36    async fn int_into_data(tensor: NdArrayTensor) -> TensorData {
37        tensor.into_data()
38    }
39
40    fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {
41        tensor
42    }
43
44    fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
45        execute_with_int_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
46    }
47
48    fn int_slice(tensor: NdArrayTensor, slices: &[burn_tensor::Slice]) -> NdArrayTensor {
49        execute_with_int_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, slices))
50    }
51
52    fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {
53        NdArrayDevice::Cpu
54    }
55
56    fn int_empty(
57        shape: Shape,
58        device: &<NdArray<E> as Backend>::Device,
59        dtype: IntDType,
60    ) -> NdArrayTensor {
61        Self::int_zeros(shape, device, dtype)
62    }
63
64    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
65        execute_with_int_dtype!((lhs, rhs), matmul)
66    }
67
68    fn int_mask_where(
69        tensor: NdArrayTensor,
70        mask: NdArrayTensor,
71        source: NdArrayTensor,
72    ) -> NdArrayTensor {
73        execute_with_int_dtype!((tensor, source), |tensor, source| {
74            NdArrayMathOps::mask_where(tensor, mask.bool(), source)
75        })
76    }
77
78    fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: I) -> NdArrayTensor {
79        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
80            tensor,
81            mask.bool(),
82            value.elem()
83        ))
84    }
85
86    fn int_slice_assign(
87        tensor: NdArrayTensor,
88        slices: &[burn_tensor::Slice],
89        value: NdArrayTensor,
90    ) -> NdArrayTensor {
91        execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(
92            tensor, slices, value
93        ))
94    }
95
96    fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {
97        cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])
98    }
99
100    fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
101        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)
102    }
103
104    fn int_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
105        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::equal_elem(lhs, rhs.elem()))
106    }
107
108    fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
109        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)
110    }
111
112    fn int_greater_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
113        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::greater_elem(lhs, rhs.elem()))
114    }
115
116    fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
117        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)
118    }
119
120    fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
121        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::greater_equal_elem(
122            lhs,
123            rhs.elem()
124        ))
125    }
126
127    fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
128        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)
129    }
130
131    fn int_lower_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
132        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::lower_elem(lhs, rhs.elem()))
133    }
134
135    fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
136        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)
137    }
138
139    fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
140        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::lower_equal_elem(lhs, rhs.elem()))
141    }
142
143    fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
144        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)
145    }
146
147    fn int_add_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
148        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
149    }
150
151    fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
152        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)
153    }
154
155    fn int_sub_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
156        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
157    }
158
159    fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
160        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)
161    }
162
163    fn int_mul_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
164        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
165    }
166
167    fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
168        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)
169    }
170
171    fn int_div_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
172        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
173    }
174
175    fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
176        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)
177    }
178
179    fn int_remainder_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
180        execute_with_int_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
181    }
182
183    fn int_neg(tensor: NdArrayTensor) -> NdArrayTensor {
184        Self::int_mul_scalar(tensor, (-1).elem())
185    }
186
187    fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {
188        execute_with_int_dtype!(tensor, NdArrayMathOps::sum)
189    }
190
191    fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
192        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
193    }
194
195    fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {
196        execute_with_int_dtype!(tensor, NdArrayMathOps::prod)
197    }
198
199    fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
200        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
201    }
202
203    fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {
204        execute_with_int_dtype!(tensor, NdArrayMathOps::mean)
205    }
206
207    fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
208        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
209    }
210
211    fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
212        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
213    }
214
215    fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
216        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cumprod(tensor, dim))
217    }
218
219    fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
220        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
221    }
222
223    fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
224        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cummax(tensor, dim))
225    }
226
227    fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
228        execute_with_int_dtype!(tensor, E, |tensor: SharedArray<E>| -> NdArrayTensor {
229            execute_with_int_dtype!(indices, |indices| NdArrayMathOps::gather(
230                dim, tensor, indices
231            ))
232        })
233    }
234
235    fn int_scatter(
236        dim: usize,
237        tensor: NdArrayTensor,
238        indices: NdArrayTensor,
239        value: NdArrayTensor,
240    ) -> NdArrayTensor {
241        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
242            execute_with_int_dtype!(indices, |indices| NdArrayMathOps::<I>::scatter(
243                dim, tensor, indices, value
244            ))
245        })
246    }
247
248    fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
249        execute_with_int_dtype!(tensor, E, |tensor: SharedArray<E>| -> NdArrayTensor {
250            execute_with_int_dtype!(indices, |indices| NdArrayMathOps::select(
251                tensor, dim, indices
252            ))
253        })
254    }
255
256    fn int_select_assign(
257        tensor: NdArrayTensor,
258        dim: usize,
259        indices: NdArrayTensor,
260        value: NdArrayTensor,
261    ) -> NdArrayTensor {
262        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {
263            execute_with_int_dtype!(indices, |indices| NdArrayMathOps::<I>::select_assign(
264                tensor, dim, indices, value
265            ))
266        })
267    }
268    fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
269        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::argmax::<I>(tensor, dim))
270    }
271
272    fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
273        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::argmin::<I>(tensor, dim))
274    }
275
276    fn int_clamp_min(tensor: NdArrayTensor, min: I) -> NdArrayTensor {
277        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
278            tensor,
279            min.elem()
280        ))
281    }
282
283    fn int_clamp_max(tensor: NdArrayTensor, max: I) -> NdArrayTensor {
284        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
285            tensor,
286            max.elem()
287        ))
288    }
289
290    fn int_clamp(tensor: NdArrayTensor, min: I, max: I) -> NdArrayTensor {
291        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
292            tensor,
293            min.elem(),
294            max.elem()
295        ))
296    }
297
298    fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {
299        match tensor.dtype() {
300            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
301                execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [
302                    I64 => i64, I32 => i32, I16 => i16, I8 => i8
303                ])
304            }
305            // Already unsigned
306            DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,
307            other => panic!("Unsupported dtype: {other:?}"),
308        }
309    }
310
311    fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {
312        execute_with_int_dtype!(tensor, I, |t: SharedArray<I>| t
313            .mapv(|a| a.elem::<E>())
314            .into_shared())
315    }
316
317    fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {
318        execute_with_int_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
319    }
320
321    fn int_random(
322        shape: Shape,
323        distribution: Distribution,
324        device: &NdArrayDevice,
325    ) -> NdArrayTensor {
326        let mut seed = SEED.lock().unwrap();
327        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
328            rng_seeded.clone()
329        } else {
330            get_seeded_rng()
331        };
332
333        let effective_distribution = if distribution == Distribution::Default {
334            Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant
335        } else {
336            distribution
337        };
338
339        let tensor = Self::int_from_data(
340            TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
341            device,
342        );
343        *seed = Some(rng);
344        tensor
345    }
346
347    fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
348        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(
349            lhs,
350            rhs,
351            |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }
352        ))
353    }
354
355    fn int_powf(lhs: NdArrayTensor, rhs: FloatTensor<Self>) -> NdArrayTensor {
356        execute_with_int_dtype!(lhs, I, |lhs| -> NdArrayTensor {
357            execute_with_float_dtype!(rhs, E, |rhs| {
358                NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &E| {
359                    (a.elem::<i64>().pow(*b as u32)).elem()
360                })
361            })
362        })
363    }
364
365    fn int_powf_scalar_impl(lhs: NdArrayTensor, rhs: f32) -> NdArrayTensor {
366        execute_with_int_dtype!(lhs, I, |lhs| {
367            NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
368                (a.elem::<i64>().pow(rhs as u32)).elem()
369            })
370        })
371    }
372
373    fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
374        execute_with_int_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
375    }
376
377    fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
378        execute_with_int_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
379    }
380
381    fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {
382        match tensor.dtype() {
383            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
384                execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [
385                    I64 => i64, I32 => i32, I16 => i16, I8 => i8
386                ])
387            }
388            DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
389                Self::int_greater_elem(tensor, 0.elem())
390            }
391            other => panic!("Unsupported dtype: {other:?}"),
392        }
393    }
394
395    fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {
396        execute_with_int_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
397    }
398
399    fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
400        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)
401    }
402
403    fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
404        execute_with_int_dtype!(lhs, |lhs| NdArrayBitOps::bitand_scalar(lhs, rhs.elem()))
405    }
406
407    fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
408        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)
409    }
410
411    fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
412        execute_with_int_dtype!(lhs, |lhs| NdArrayBitOps::bitor_scalar(lhs, rhs.elem()))
413    }
414
415    fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
416        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)
417    }
418
419    fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
420        execute_with_int_dtype!(lhs, |lhs| NdArrayBitOps::bitxor_scalar(lhs, rhs.elem()))
421    }
422
423    fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {
424        execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)
425    }
426
427    fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
428        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
429            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
430                (a.elem::<i64>() << (b.elem::<u32>())).elem()
431            })
432        })
433    }
434
435    fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
436        execute_with_int_dtype!(lhs, I, |lhs| {
437            NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
438                (a.elem::<i64>() << rhs.elem::<u32>()).elem()
439            })
440        })
441    }
442
443    fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {
444        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {
445            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
446                (a.elem::<i64>() >> (b.elem::<u32>())).elem()
447            })
448        })
449    }
450
451    fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor {
452        execute_with_int_dtype!(lhs, I, |lhs| {
453            NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
454                (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
455            })
456        })
457    }
458
459    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
460        execute_with_int_dtype!(tensor, |tensor| cast_to_dtype(tensor, dtype.into()))
461    }
462
463    fn int_unfold(
464        tensor: IntTensor<Self>,
465        dim: usize,
466        size: usize,
467        step: usize,
468    ) -> IntTensor<Self> {
469        execute_with_int_dtype!(tensor, |tensor| NdArrayOps::unfold(tensor, dim, size, step))
470    }
471}