// burn_ndarray/ops/int_tensor.rs

1// Language
2use alloc::vec;
3use alloc::vec::Vec;
4use burn_common::rand::get_seeded_rng;
5use burn_tensor::ops::FloatTensor;
6use burn_tensor::ops::IntTensorOps;
7use burn_tensor::Distribution;
8
9use burn_tensor::ElementConversion;
10use core::ops::Range;
11use ndarray::IntoDimension;
12use ndarray::Zip;
13
14// Current crate
15use crate::element::FloatNdArrayElement;
16use crate::element::IntNdArrayElement;
17use crate::element::QuantElement;
18use crate::execute_with_float_dtype;
19use crate::new_tensor_float;
20use crate::{tensor::NdArrayTensor, NdArray};
21use crate::{NdArrayDevice, SEED};
22
23// Workspace crates
24use burn_tensor::{backend::Backend, Shape, TensorData};
25
26use super::{NdArrayMathOps, NdArrayOps};
27
28impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
29    for NdArray<E, I, Q>
30{
31    fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<I> {
32        NdArrayTensor::from_data(data)
33    }
34
35    async fn int_into_data(tensor: NdArrayTensor<I>) -> TensorData {
36        NdArrayOps::into_data(tensor)
37    }
38
39    fn int_to_device(tensor: NdArrayTensor<I>, _device: &NdArrayDevice) -> NdArrayTensor<I> {
40        tensor
41    }
42
43    fn int_reshape(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
44        NdArrayOps::reshape(tensor, shape)
45    }
46
47    fn int_slice(tensor: NdArrayTensor<I>, ranges: &[Range<usize>]) -> NdArrayTensor<I> {
48        NdArrayOps::slice(tensor, ranges)
49    }
50
51    fn int_device(_tensor: &NdArrayTensor<I>) -> <NdArray<E> as Backend>::Device {
52        NdArrayDevice::Cpu
53    }
54
55    fn int_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
56        let values = vec![0; shape.num_elements()];
57        NdArrayTensor::from_data(TensorData::new(values, shape))
58    }
59
60    fn int_mask_where(
61        tensor: NdArrayTensor<I>,
62        mask: NdArrayTensor<bool>,
63        source: NdArrayTensor<I>,
64    ) -> NdArrayTensor<I> {
65        NdArrayMathOps::mask_where(tensor, mask, source)
66    }
67
68    fn int_mask_fill(
69        tensor: NdArrayTensor<I>,
70        mask: NdArrayTensor<bool>,
71        value: I,
72    ) -> NdArrayTensor<I> {
73        NdArrayMathOps::mask_fill(tensor, mask, value)
74    }
75
76    fn int_slice_assign(
77        tensor: NdArrayTensor<I>,
78        ranges: &[Range<usize>],
79        value: NdArrayTensor<I>,
80    ) -> NdArrayTensor<I> {
81        NdArrayOps::slice_assign(tensor, ranges, value)
82    }
83
84    fn int_cat(tensors: Vec<NdArrayTensor<I>>, dim: usize) -> NdArrayTensor<I> {
85        NdArrayOps::cat(tensors, dim)
86    }
87
88    fn int_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
89        let output = Zip::from(&lhs.array)
90            .and(&rhs.array)
91            .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
92            .into_shared();
93        NdArrayTensor::new(output)
94    }
95
96    fn int_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
97        let array = lhs.array.mapv(|a| a == rhs).into_shared();
98        NdArrayTensor { array }
99    }
100
101    fn int_greater(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
102        let tensor = Self::int_sub(lhs, rhs);
103        Self::int_greater_elem(tensor, 0.elem())
104    }
105
106    fn int_greater_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
107        let array = lhs.array.mapv(|a| a > rhs).into_shared();
108        NdArrayTensor::new(array)
109    }
110
111    fn int_greater_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
112        let tensor = Self::int_sub(lhs, rhs);
113        Self::int_greater_equal_elem(tensor, 0.elem())
114    }
115
116    fn int_greater_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
117        let array = lhs.array.mapv(|a| a >= rhs).into_shared();
118        NdArrayTensor::new(array)
119    }
120
121    fn int_lower(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
122        let tensor = Self::int_sub(lhs, rhs);
123        Self::int_lower_elem(tensor, 0.elem())
124    }
125
126    fn int_lower_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
127        let array = lhs.array.mapv(|a| a < rhs).into_shared();
128        NdArrayTensor::new(array)
129    }
130
131    fn int_lower_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
132        let tensor = Self::int_sub(lhs, rhs);
133        Self::int_lower_equal_elem(tensor, 0.elem())
134    }
135
136    fn int_lower_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
137        let array = lhs.array.mapv(|a| a <= rhs).into_shared();
138        NdArrayTensor::new(array)
139    }
140
141    fn int_add(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
142        NdArrayMathOps::add(lhs, rhs)
143    }
144
145    fn int_add_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
146        NdArrayMathOps::add_scalar(lhs, rhs)
147    }
148
149    fn int_sub(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
150        NdArrayMathOps::sub(lhs, rhs)
151    }
152
153    fn int_sub_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
154        NdArrayMathOps::sub_scalar(lhs, rhs)
155    }
156
157    fn int_mul(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
158        NdArrayMathOps::mul(lhs, rhs)
159    }
160
161    fn int_mul_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
162        NdArrayMathOps::mul_scalar(lhs, rhs)
163    }
164
165    fn int_div(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
166        NdArrayMathOps::div(lhs, rhs)
167    }
168
169    fn int_div_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
170        NdArrayMathOps::div_scalar(lhs, rhs)
171    }
172
173    fn int_remainder(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
174        NdArrayMathOps::remainder(lhs, rhs)
175    }
176
177    fn int_remainder_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
178        NdArrayMathOps::remainder_scalar(lhs, rhs)
179    }
180
181    fn int_neg(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
182        Self::int_mul_scalar(tensor, (-1).elem())
183    }
184
185    fn int_zeros(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
186        Self::int_from_data(TensorData::zeros::<i64, _>(shape), device)
187    }
188
189    fn int_ones(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
190        Self::int_from_data(TensorData::ones::<i64, _>(shape), device)
191    }
192
193    fn int_full(
194        shape: Shape,
195        fill_value: I,
196        device: &<NdArray<E> as Backend>::Device,
197    ) -> NdArrayTensor<I> {
198        Self::int_from_data(TensorData::full(shape, fill_value), device)
199    }
200
201    fn int_sum(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
202        NdArrayMathOps::sum(tensor)
203    }
204
205    fn int_sum_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
206        NdArrayMathOps::sum_dim(tensor, dim)
207    }
208
209    fn int_prod(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
210        NdArrayMathOps::prod(tensor)
211    }
212
213    fn int_prod_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
214        NdArrayMathOps::prod_dim(tensor, dim)
215    }
216
217    fn int_mean(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
218        NdArrayMathOps::mean(tensor)
219    }
220
221    fn int_mean_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
222        NdArrayMathOps::mean_dim(tensor, dim)
223    }
224
225    fn int_gather(
226        dim: usize,
227        tensor: NdArrayTensor<I>,
228        indices: NdArrayTensor<I>,
229    ) -> NdArrayTensor<I> {
230        NdArrayMathOps::gather(dim, tensor, indices)
231    }
232
233    fn int_scatter(
234        dim: usize,
235        tensor: NdArrayTensor<I>,
236        indices: NdArrayTensor<I>,
237        value: NdArrayTensor<I>,
238    ) -> NdArrayTensor<I> {
239        NdArrayMathOps::scatter(dim, tensor, indices, value)
240    }
241
242    fn int_select(
243        tensor: NdArrayTensor<I>,
244        dim: usize,
245        indices: NdArrayTensor<I>,
246    ) -> NdArrayTensor<I> {
247        NdArrayMathOps::select(tensor, dim, indices)
248    }
249
250    fn int_select_assign(
251        tensor: NdArrayTensor<I>,
252        dim: usize,
253        indices: NdArrayTensor<I>,
254        value: NdArrayTensor<I>,
255    ) -> NdArrayTensor<I> {
256        NdArrayMathOps::select_assign(tensor, dim, indices, value)
257    }
258    fn int_argmax(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
259        NdArrayMathOps::argmax(tensor, dim)
260    }
261
262    fn int_argmin(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
263        NdArrayMathOps::argmin(tensor, dim)
264    }
265
266    fn int_clamp_min(tensor: NdArrayTensor<I>, min: I) -> NdArrayTensor<I> {
267        NdArrayMathOps::clamp_min(tensor, min)
268    }
269
270    fn int_clamp_max(tensor: NdArrayTensor<I>, max: I) -> NdArrayTensor<I> {
271        NdArrayMathOps::clamp_max(tensor, max)
272    }
273
274    fn int_clamp(tensor: NdArrayTensor<I>, min: I, max: I) -> NdArrayTensor<I> {
275        NdArrayMathOps::clamp(tensor, min, max)
276    }
277
278    fn int_abs(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
279        let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared();
280
281        NdArrayTensor::new(array)
282    }
283
284    fn int_into_float(tensor: NdArrayTensor<I>) -> FloatTensor<Self> {
285        new_tensor_float!(NdArrayTensor {
286            array: tensor.array.mapv(|a| a.elem()).into_shared()
287        })
288    }
289
290    fn int_swap_dims(tensor: NdArrayTensor<I>, dim1: usize, dim2: usize) -> NdArrayTensor<I> {
291        NdArrayOps::swap_dims(tensor, dim1, dim2)
292    }
293
294    fn int_random(
295        shape: Shape,
296        distribution: Distribution,
297        device: &NdArrayDevice,
298    ) -> NdArrayTensor<I> {
299        let mut seed = SEED.lock().unwrap();
300        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
301            rng_seeded.clone()
302        } else {
303            get_seeded_rng()
304        };
305
306        let effective_distribution = if distribution == Distribution::Default {
307            Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant
308        } else {
309            distribution
310        };
311
312        let tensor = Self::int_from_data(
313            TensorData::random::<i64, _, _>(shape, effective_distribution, &mut rng),
314            device,
315        );
316        *seed = Some(rng);
317        tensor
318    }
319
320    fn int_powi(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
321        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
322            (a.elem::<i64>().pow(b.elem::<u32>())).elem()
323        })
324    }
325
326    fn int_powf(lhs: NdArrayTensor<I>, rhs: FloatTensor<Self>) -> NdArrayTensor<I> {
327        execute_with_float_dtype!(rhs => |rhs| {
328            NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| {
329                (a.elem::<i64>().pow(*b as u32)).elem()
330            })
331        })
332    }
333
334    fn int_powf_scalar(lhs: NdArrayTensor<I>, rhs: f32) -> NdArrayTensor<I> {
335        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| (a.elem::<i64>().pow(rhs as u32)).elem())
336    }
337
338    fn int_permute(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
339        let array = tensor.array.permuted_axes(axes.into_dimension());
340        NdArrayTensor { array }
341    }
342
343    fn int_flip(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
344        NdArrayOps::flip(tensor, axes)
345    }
346
347    fn int_sign(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
348        NdArrayMathOps::sign_op(tensor)
349    }
350
351    fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
352        NdArrayOps::expand(tensor, shape)
353    }
354}