1use alloc::vec::Vec;
3use burn_common::rand::get_seeded_rng;
4use burn_tensor::Distribution;
5use burn_tensor::ops::FloatTensor;
6use burn_tensor::ops::IntTensorOps;
7
8use burn_tensor::ElementConversion;
9use core::ops::Range;
10use ndarray::IntoDimension;
11
12use crate::element::FloatNdArrayElement;
14use crate::element::IntNdArrayElement;
15use crate::element::QuantElement;
16use crate::execute_with_float_dtype;
17use crate::new_tensor_float;
18use crate::{NdArray, tensor::NdArrayTensor};
19use crate::{NdArrayDevice, SEED};
20
21use burn_tensor::{DType, Shape, TensorData, backend::Backend};
23
24use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};
25
26impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>
27 for NdArray<E, I, Q>
28{
29 fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor<I> {
30 match data.dtype {
31 DType::I64 | DType::I32 => NdArrayTensor::from_data(data),
32 _ => unimplemented!("Unsupported dtype for `int_from_data`"),
33 }
34 }
35
36 async fn int_into_data(tensor: NdArrayTensor<I>) -> TensorData {
37 NdArrayOps::into_data(tensor)
38 }
39
40 fn int_to_device(tensor: NdArrayTensor<I>, _device: &NdArrayDevice) -> NdArrayTensor<I> {
41 tensor
42 }
43
44 fn int_reshape(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
45 NdArrayOps::reshape(tensor, shape)
46 }
47
48 fn int_slice(tensor: NdArrayTensor<I>, ranges: &[Range<usize>]) -> NdArrayTensor<I> {
49 NdArrayOps::slice(tensor, ranges)
50 }
51
52 fn int_device(_tensor: &NdArrayTensor<I>) -> <NdArray<E> as Backend>::Device {
53 NdArrayDevice::Cpu
54 }
55
56 fn int_empty(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
57 Self::int_zeros(shape, device)
58 }
59
60 fn int_mask_where(
61 tensor: NdArrayTensor<I>,
62 mask: NdArrayTensor<bool>,
63 source: NdArrayTensor<I>,
64 ) -> NdArrayTensor<I> {
65 NdArrayMathOps::mask_where(tensor, mask, source)
66 }
67
68 fn int_mask_fill(
69 tensor: NdArrayTensor<I>,
70 mask: NdArrayTensor<bool>,
71 value: I,
72 ) -> NdArrayTensor<I> {
73 NdArrayMathOps::mask_fill(tensor, mask, value)
74 }
75
76 fn int_slice_assign(
77 tensor: NdArrayTensor<I>,
78 ranges: &[Range<usize>],
79 value: NdArrayTensor<I>,
80 ) -> NdArrayTensor<I> {
81 NdArrayOps::slice_assign(tensor, ranges, value)
82 }
83
84 fn int_cat(tensors: Vec<NdArrayTensor<I>>, dim: usize) -> NdArrayTensor<I> {
85 NdArrayOps::cat(tensors, dim)
86 }
87
88 fn int_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
89 NdArrayMathOps::equal(lhs, rhs)
90 }
91
92 fn int_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
93 NdArrayMathOps::equal_elem(lhs, rhs)
94 }
95
96 fn int_greater(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
97 NdArrayMathOps::greater(lhs, rhs)
98 }
99
100 fn int_greater_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
101 NdArrayMathOps::greater_elem(lhs, rhs)
102 }
103
104 fn int_greater_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
105 NdArrayMathOps::greater_equal(lhs, rhs)
106 }
107
108 fn int_greater_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
109 NdArrayMathOps::greater_equal_elem(lhs, rhs)
110 }
111
112 fn int_lower(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
113 NdArrayMathOps::lower(lhs, rhs)
114 }
115
116 fn int_lower_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
117 NdArrayMathOps::lower_elem(lhs, rhs)
118 }
119
120 fn int_lower_equal(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<bool> {
121 NdArrayMathOps::lower_equal(lhs, rhs)
122 }
123
124 fn int_lower_equal_elem(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<bool> {
125 NdArrayMathOps::lower_equal_elem(lhs, rhs)
126 }
127
128 fn int_add(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
129 NdArrayMathOps::add(lhs, rhs)
130 }
131
132 fn int_add_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
133 NdArrayMathOps::add_scalar(lhs, rhs)
134 }
135
136 fn int_sub(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
137 NdArrayMathOps::sub(lhs, rhs)
138 }
139
140 fn int_sub_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
141 NdArrayMathOps::sub_scalar(lhs, rhs)
142 }
143
144 fn int_mul(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
145 NdArrayMathOps::mul(lhs, rhs)
146 }
147
148 fn int_mul_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
149 NdArrayMathOps::mul_scalar(lhs, rhs)
150 }
151
152 fn int_div(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
153 NdArrayMathOps::div(lhs, rhs)
154 }
155
156 fn int_div_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
157 NdArrayMathOps::div_scalar(lhs, rhs)
158 }
159
160 fn int_remainder(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
161 NdArrayMathOps::remainder(lhs, rhs)
162 }
163
164 fn int_remainder_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
165 NdArrayMathOps::remainder_scalar(lhs, rhs)
166 }
167
168 fn int_neg(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
169 Self::int_mul_scalar(tensor, (-1).elem())
170 }
171
172 fn int_zeros(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
173 Self::int_from_data(TensorData::zeros::<I, _>(shape), device)
174 }
175
176 fn int_ones(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor<I> {
177 Self::int_from_data(TensorData::ones::<I, _>(shape), device)
178 }
179
180 fn int_full(
181 shape: Shape,
182 fill_value: I,
183 device: &<NdArray<E> as Backend>::Device,
184 ) -> NdArrayTensor<I> {
185 Self::int_from_data(TensorData::full(shape, fill_value), device)
186 }
187
188 fn int_sum(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
189 NdArrayMathOps::sum(tensor)
190 }
191
192 fn int_sum_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
193 NdArrayMathOps::sum_dim(tensor, dim)
194 }
195
196 fn int_prod(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
197 NdArrayMathOps::prod(tensor)
198 }
199
200 fn int_prod_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
201 NdArrayMathOps::prod_dim(tensor, dim)
202 }
203
204 fn int_mean(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
205 NdArrayMathOps::mean(tensor)
206 }
207
208 fn int_mean_dim(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
209 NdArrayMathOps::mean_dim(tensor, dim)
210 }
211
212 fn int_gather(
213 dim: usize,
214 tensor: NdArrayTensor<I>,
215 indices: NdArrayTensor<I>,
216 ) -> NdArrayTensor<I> {
217 NdArrayMathOps::gather(dim, tensor, indices)
218 }
219
220 fn int_scatter(
221 dim: usize,
222 tensor: NdArrayTensor<I>,
223 indices: NdArrayTensor<I>,
224 value: NdArrayTensor<I>,
225 ) -> NdArrayTensor<I> {
226 NdArrayMathOps::scatter(dim, tensor, indices, value)
227 }
228
229 fn int_select(
230 tensor: NdArrayTensor<I>,
231 dim: usize,
232 indices: NdArrayTensor<I>,
233 ) -> NdArrayTensor<I> {
234 NdArrayMathOps::select(tensor, dim, indices)
235 }
236
237 fn int_select_assign(
238 tensor: NdArrayTensor<I>,
239 dim: usize,
240 indices: NdArrayTensor<I>,
241 value: NdArrayTensor<I>,
242 ) -> NdArrayTensor<I> {
243 NdArrayMathOps::select_assign(tensor, dim, indices, value)
244 }
245 fn int_argmax(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
246 NdArrayMathOps::argmax(tensor, dim)
247 }
248
249 fn int_argmin(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
250 NdArrayMathOps::argmin(tensor, dim)
251 }
252
253 fn int_clamp_min(tensor: NdArrayTensor<I>, min: I) -> NdArrayTensor<I> {
254 NdArrayMathOps::clamp_min(tensor, min)
255 }
256
257 fn int_clamp_max(tensor: NdArrayTensor<I>, max: I) -> NdArrayTensor<I> {
258 NdArrayMathOps::clamp_max(tensor, max)
259 }
260
261 fn int_clamp(tensor: NdArrayTensor<I>, min: I, max: I) -> NdArrayTensor<I> {
262 NdArrayMathOps::clamp(tensor, min, max)
263 }
264
265 fn int_abs(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
266 NdArrayMathOps::abs(tensor)
267 }
268
269 fn int_into_float(tensor: NdArrayTensor<I>) -> FloatTensor<Self> {
270 new_tensor_float!(NdArrayTensor {
271 array: tensor.array.mapv(|a| a.elem()).into_shared()
272 })
273 }
274
275 fn int_swap_dims(tensor: NdArrayTensor<I>, dim1: usize, dim2: usize) -> NdArrayTensor<I> {
276 NdArrayOps::swap_dims(tensor, dim1, dim2)
277 }
278
279 fn int_random(
280 shape: Shape,
281 distribution: Distribution,
282 device: &NdArrayDevice,
283 ) -> NdArrayTensor<I> {
284 let mut seed = SEED.lock().unwrap();
285 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
286 rng_seeded.clone()
287 } else {
288 get_seeded_rng()
289 };
290
291 let effective_distribution = if distribution == Distribution::Default {
292 Distribution::Uniform(0.0, 255.0) } else {
294 distribution
295 };
296
297 let tensor = Self::int_from_data(
298 TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),
299 device,
300 );
301 *seed = Some(rng);
302 tensor
303 }
304
305 fn int_powi(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
306 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
307 (a.elem::<i64>().pow(b.elem::<u32>())).elem()
308 })
309 }
310
311 fn int_powf(lhs: NdArrayTensor<I>, rhs: FloatTensor<Self>) -> NdArrayTensor<I> {
312 execute_with_float_dtype!(rhs => |rhs| {
313 NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| {
314 (a.elem::<i64>().pow(*b as u32)).elem()
315 })
316 })
317 }
318
319 fn int_powf_scalar(lhs: NdArrayTensor<I>, rhs: f32) -> NdArrayTensor<I> {
320 NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| (a.elem::<i64>().pow(rhs as u32)).elem())
321 }
322
323 fn int_permute(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
324 let array = tensor.array.permuted_axes(axes.into_dimension());
325 NdArrayTensor { array }
326 }
327
328 fn int_flip(tensor: NdArrayTensor<I>, axes: &[usize]) -> NdArrayTensor<I> {
329 NdArrayOps::flip(tensor, axes)
330 }
331
332 fn int_sign(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
333 NdArrayMathOps::sign_op(tensor)
334 }
335
336 fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
337 NdArrayOps::expand(tensor, shape)
338 }
339
340 fn bitwise_and(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
341 NdArrayBitOps::bitand(lhs, rhs)
342 }
343
344 fn bitwise_and_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
345 NdArrayBitOps::bitand_scalar(lhs, rhs)
346 }
347
348 fn bitwise_or(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
349 NdArrayBitOps::bitor(lhs, rhs)
350 }
351
352 fn bitwise_or_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
353 NdArrayBitOps::bitor_scalar(lhs, rhs)
354 }
355
356 fn bitwise_xor(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
357 NdArrayBitOps::bitxor(lhs, rhs)
358 }
359
360 fn bitwise_xor_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
361 NdArrayBitOps::bitxor_scalar(lhs, rhs)
362 }
363
364 fn bitwise_not(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
365 NdArrayBitOps::bitnot(tensor)
366 }
367
368 fn bitwise_left_shift(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
369 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
370 (a.elem::<i64>() << (b.elem::<u32>())).elem()
371 })
372 }
373
374 fn bitwise_left_shift_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
375 NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
376 (a.elem::<i64>() << rhs.elem::<u32>()).elem()
377 })
378 }
379
380 fn bitwise_right_shift(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
381 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
382 (a.elem::<i64>() >> (b.elem::<u32>())).elem()
383 })
384 }
385
386 fn bitwise_right_shift_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
387 NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
388 (a.elem::<i64>() >> rhs.elem::<u32>()).elem()
389 })
390 }
391}