1use alloc::vec;
3use alloc::vec::Vec;
4use burn_common::rand::get_seeded_rng;
5use burn_tensor::ops::FloatTensor;
6use burn_tensor::ops::IntTensorOps;
7use burn_tensor::Distribution;
8
9use burn_tensor::ElementConversion;
10use core::ops::Range;
11use ndarray::IntoDimension;
12use ndarray::Zip;
13
14use crate::element::FloatNdArrayElement;
16use crate::element::IntNdArrayElement;
17use crate::element::QuantElement;
18use crate::execute_with_float_dtype;
19use crate::new_tensor_float;
20use crate::{tensor::NdArrayTensor, NdArray};
21use crate::{NdArrayDevice, SEED};
22
23use burn_tensor::{backend::Backend, Shape, TensorData};
25
26use super::{NdArrayMathOps, NdArrayOps};
27
28impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
29 for NdArray<E, I, Q>
30{
31 fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<I> {
32 NdArrayTensor::from_data(data)
33 }
34
35 async fn int_into_data(tensor: NdArrayTensor<I>) -> TensorData {
36 NdArrayOps::into_data(tensor)
37 }
38
39 fn int_to_device(tensor: NdArrayTensor<I>, _device: &NdArrayDevice) -> NdArrayTensor<I> {
40 tensor
41 }
42
43 fn int_reshape(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
44 NdArrayOps::reshape(tensor, shape)
45 }
46
47 fn int_slice(tensor: NdArrayTensor<I>, ranges: &[Range<usize>]) -> NdArrayTensor<I> {
48 NdArrayOps::slice(tensor, ranges)
49 }
50
51 fn int_device(_tensor: &NdArrayTensor<I>) -> <NdArray<E> as Backend>::Device {
52 NdArrayDevice::Cpu
53 }
54
55 fn int_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
56 let values = vec![0; shape.num_elements()];
57 NdArrayTensor::from_data(TensorData::new(values, shape))
58 }
59
60 fn int_mask_where(
61 tensor: NdArrayTensor<I>,
62 mask: NdArrayTensor<bool>,
63 source: NdArrayTensor<I>,
64 ) -> NdArrayTensor<I> {
65 NdArrayMathOps::mask_where(tensor, mask, source)
66 }
67
68 fn int_mask_fill(
69 tensor: NdArrayTensor<I>,
70 mask: NdArrayTensor<bool>,
71 value: I,
72 ) -> NdArrayTensor<I> {
73 NdArrayMathOps::mask_fill(tensor, mask, value)
74 }
75
76 fn int_slice_assign(
77 tensor: NdArrayTensor<I>,
78 ranges: &[Range<usize>],
79 value: NdArrayTensor<I>,
80 ) -> NdArrayTensor<I> {
81 NdArrayOps::slice_assign(tensor, ranges, value)
82 }
83
84 fn int_cat(tensors: Vec<NdArrayTensor<I>>, dim: usize) -> NdArrayTensor<I> {
85 NdArrayOps::cat(tensors, dim)
86 }
87
88 fn int_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
89 let output = Zip::from(&lhs.array)
90 .and(&rhs.array)
91 .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
92 .into_shared();
93 NdArrayTensor::new(output)
94 }
95
96 fn int_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
97 let array = lhs.array.mapv(|a| a == rhs).into_shared();
98 NdArrayTensor { array }
99 }
100
101 fn int_greater(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
102 let tensor = Self::int_sub(lhs, rhs);
103 Self::int_greater_elem(tensor, 0.elem())
104 }
105
106 fn int_greater_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
107 let array = lhs.array.mapv(|a| a > rhs).into_shared();
108 NdArrayTensor::new(array)
109 }
110
111 fn int_greater_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
112 let tensor = Self::int_sub(lhs, rhs);
113 Self::int_greater_equal_elem(tensor, 0.elem())
114 }
115
116 fn int_greater_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
117 let array = lhs.array.mapv(|a| a >= rhs).into_shared();
118 NdArrayTensor::new(array)
119 }
120
121 fn int_lower(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
122 let tensor = Self::int_sub(lhs, rhs);
123 Self::int_lower_elem(tensor, 0.elem())
124 }
125
126 fn int_lower_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
127 let array = lhs.array.mapv(|a| a < rhs).into_shared();
128 NdArrayTensor::new(array)
129 }
130
131 fn int_lower_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
132 let tensor = Self::int_sub(lhs, rhs);
133 Self::int_lower_equal_elem(tensor, 0.elem())
134 }
135
136 fn int_lower_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
137 let array = lhs.array.mapv(|a| a <= rhs).into_shared();
138 NdArrayTensor::new(array)
139 }
140
141 fn int_add(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
142 NdArrayMathOps::add(lhs, rhs)
143 }
144
145 fn int_add_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
146 NdArrayMathOps::add_scalar(lhs, rhs)
147 }
148
149 fn int_sub(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
150 NdArrayMathOps::sub(lhs, rhs)
151 }
152
153 fn int_sub_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
154 NdArrayMathOps::sub_scalar(lhs, rhs)
155 }
156
157 fn int_mul(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
158 NdArrayMathOps::mul(lhs, rhs)
159 }
160
161 fn int_mul_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
162 NdArrayMathOps::mul_scalar(lhs, rhs)
163 }
164
165 fn int_div(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
166 NdArrayMathOps::div(lhs, rhs)
167 }
168
169 fn int_div_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
170 NdArrayMathOps::div_scalar(lhs, rhs)
171 }
172
173 fn int_remainder(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
174 NdArrayMathOps::remainder(lhs, rhs)
175 }
176
177 fn int_remainder_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
178 NdArrayMathOps::remainder_scalar(lhs, rhs)
179 }
180
181 fn int_neg(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
182 Self::int_mul_scalar(tensor, (-1).elem())
183 }
184
185 fn int_zeros(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
186 Self::int_from_data(TensorData::zeros::<i64, _>(shape), device)
187 }
188
189 fn int_ones(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
190 Self::int_from_data(TensorData::ones::<i64, _>(shape), device)
191 }
192
193 fn int_full(
194 shape: Shape,
195 fill_value: I,
196 device: &<NdArray<E> as Backend>::Device,
197 ) -> NdArrayTensor<I> {
198 Self::int_from_data(TensorData::full(shape, fill_value), device)
199 }
200
201 fn int_sum(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
202 NdArrayMathOps::sum(tensor)
203 }
204
205 fn int_sum_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
206 NdArrayMathOps::sum_dim(tensor, dim)
207 }
208
209 fn int_prod(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
210 NdArrayMathOps::prod(tensor)
211 }
212
213 fn int_prod_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
214 NdArrayMathOps::prod_dim(tensor, dim)
215 }
216
217 fn int_mean(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
218 NdArrayMathOps::mean(tensor)
219 }
220
221 fn int_mean_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
222 NdArrayMathOps::mean_dim(tensor, dim)
223 }
224
225 fn int_gather(
226 dim: usize,
227 tensor: NdArrayTensor<I>,
228 indices: NdArrayTensor<I>,
229 ) -> NdArrayTensor<I> {
230 NdArrayMathOps::gather(dim, tensor, indices)
231 }
232
233 fn int_scatter(
234 dim: usize,
235 tensor: NdArrayTensor<I>,
236 indices: NdArrayTensor<I>,
237 value: NdArrayTensor<I>,
238 ) -> NdArrayTensor<I> {
239 NdArrayMathOps::scatter(dim, tensor, indices, value)
240 }
241
242 fn int_select(
243 tensor: NdArrayTensor<I>,
244 dim: usize,
245 indices: NdArrayTensor<I>,
246 ) -> NdArrayTensor<I> {
247 NdArrayMathOps::select(tensor, dim, indices)
248 }
249
250 fn int_select_assign(
251 tensor: NdArrayTensor<I>,
252 dim: usize,
253 indices: NdArrayTensor<I>,
254 value: NdArrayTensor<I>,
255 ) -> NdArrayTensor<I> {
256 NdArrayMathOps::select_assign(tensor, dim, indices, value)
257 }
258 fn int_argmax(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
259 NdArrayMathOps::argmax(tensor, dim)
260 }
261
262 fn int_argmin(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
263 NdArrayMathOps::argmin(tensor, dim)
264 }
265
266 fn int_clamp_min(tensor: NdArrayTensor<I>, min: I) -> NdArrayTensor<I> {
267 NdArrayMathOps::clamp_min(tensor, min)
268 }
269
270 fn int_clamp_max(tensor: NdArrayTensor<I>, max: I) -> NdArrayTensor<I> {
271 NdArrayMathOps::clamp_max(tensor, max)
272 }
273
274 fn int_clamp(tensor: NdArrayTensor<I>, min: I, max: I) -> NdArrayTensor<I> {
275 NdArrayMathOps::clamp(tensor, min, max)
276 }
277
278 fn int_abs(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
279 let array = tensor.array.mapv_into(|a| a.int_abs_elem()).into_shared();
280
281 NdArrayTensor::new(array)
282 }
283
284 fn int_into_float(tensor: NdArrayTensor<I>) -> FloatTensor<Self> {
285 new_tensor_float!(NdArrayTensor {
286 array: tensor.array.mapv(|a| a.elem()).into_shared()
287 })
288 }
289
290 fn int_swap_dims(tensor: NdArrayTensor<I>, dim1: usize, dim2: usize) -> NdArrayTensor<I> {
291 NdArrayOps::swap_dims(tensor, dim1, dim2)
292 }
293
294 fn int_random(
295 shape: Shape,
296 distribution: Distribution,
297 device: &NdArrayDevice,
298 ) -> NdArrayTensor<I> {
299 let mut seed = SEED.lock().unwrap();
300 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
301 rng_seeded.clone()
302 } else {
303 get_seeded_rng()
304 };
305
306 let effective_distribution = if distribution == Distribution::Default {
307 Distribution::Uniform(0.0, 255.0) } else {
309 distribution
310 };
311
312 let tensor = Self::int_from_data(
313 TensorData::random::<i64, _, _>(shape, effective_distribution, &mut rng),
314 device,
315 );
316 *seed = Some(rng);
317 tensor
318 }
319
320 fn int_powi(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
321 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
322 (a.elem::<i64>().pow(b.elem::<u32>())).elem()
323 })
324 }
325
326 fn int_powf(lhs: NdArrayTensor<I>, rhs: FloatTensor<Self>) -> NdArrayTensor<I> {
327 execute_with_float_dtype!(rhs => |rhs| {
328 NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| {
329 (a.elem::<i64>().pow(*b as u32)).elem()
330 })
331 })
332 }
333
334 fn int_powf_scalar(lhs: NdArrayTensor<I>, rhs: f32) -> NdArrayTensor<I> {
335 NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| (a.elem::<i64>().pow(rhs as u32)).elem())
336 }
337
338 fn int_permute(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
339 let array = tensor.array.permuted_axes(axes.into_dimension());
340 NdArrayTensor { array }
341 }
342
343 fn int_flip(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
344 NdArrayOps::flip(tensor, axes)
345 }
346
347 fn int_sign(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
348 NdArrayMathOps::sign_op(tensor)
349 }
350
351 fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
352 NdArrayOps::expand(tensor, shape)
353 }
354}