1use alloc::vec::Vec;
4use burn_backend::{
5 DType, Distribution, ExecutionError, FloatDType, Scalar, TensorData, TensorMetadata,
6 ops::IntTensorOps,
7 tensor::{BoolTensor, Device, FloatTensor, IntTensor},
8};
9use burn_std::{Bytes, IntDType, Shape, Slice, bf16, f16};
10use num_traits::ToPrimitive;
11
12use crate::Layout;
13use crate::ops::binary::{binary_op_typed, int_binary_op, int_scalar_op, scalar_op_typed};
14use crate::{Flex, FlexTensor, ops::matmul};
15
16fn scalar_to_int_pair(dtype: DType, rhs: &Scalar) -> (i64, u64) {
19 if dtype == DType::U64 {
20 (0, rhs.to_u64().unwrap())
21 } else {
22 (rhs.to_i64().unwrap(), 0)
23 }
24}
25
26impl IntTensorOps<Flex> for Flex {
27 fn int_from_data(data: TensorData, _device: &Device<Flex>) -> IntTensor<Flex> {
28 FlexTensor::from_data(data)
29 }
30
31 async fn int_into_data(tensor: IntTensor<Flex>) -> Result<TensorData, ExecutionError> {
32 Ok(tensor.into_data())
33 }
34
35 fn int_device(_tensor: &IntTensor<Flex>) -> Device<Flex> {
36 Default::default()
37 }
38
39 fn int_to_device(tensor: IntTensor<Flex>, _device: &Device<Flex>) -> IntTensor<Flex> {
40 tensor
41 }
42
43 fn int_cat(tensors: Vec<IntTensor<Flex>>, dim: usize) -> IntTensor<Flex> {
44 crate::ops::cat::cat(tensors, dim)
45 }
46
47 fn int_reshape(tensor: IntTensor<Flex>, shape: Shape) -> IntTensor<Flex> {
48 tensor.reshape(shape)
49 }
50
51 fn int_slice(tensor: IntTensor<Flex>, slices: &[Slice]) -> IntTensor<Flex> {
52 crate::ops::slice::slice(tensor, slices)
53 }
54
55 fn int_empty(shape: Shape, _device: &Device<Flex>, dtype: IntDType) -> IntTensor<Flex> {
56 FlexTensor::empty(shape, dtype.into())
57 }
58
    /// Elementwise select between `tensor` and `value` driven by `mask`
    /// (presumably `value` where the mask is set — confirm with `ops::mask`).
    ///
    /// Both integer operands must share a dtype; the match picks the concrete
    /// element type for the generic kernel.
    fn int_mask_where(
        tensor: IntTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        debug_assert_eq!(
            tensor.dtype(),
            value.dtype(),
            "int_mask_where: dtype mismatch"
        );
        match tensor.dtype() {
            DType::I64 => crate::ops::mask::mask_where::<i64>(tensor, mask, value),
            DType::I32 => crate::ops::mask::mask_where::<i32>(tensor, mask, value),
            DType::I16 => crate::ops::mask::mask_where::<i16>(tensor, mask, value),
            DType::I8 => crate::ops::mask::mask_where::<i8>(tensor, mask, value),
            DType::U64 => crate::ops::mask::mask_where::<u64>(tensor, mask, value),
            DType::U32 => crate::ops::mask::mask_where::<u32>(tensor, mask, value),
            DType::U16 => crate::ops::mask::mask_where::<u16>(tensor, mask, value),
            DType::U8 => crate::ops::mask::mask_where::<u8>(tensor, mask, value),
            dt => panic!("int_mask_where: unsupported dtype {:?}", dt),
        }
    }
81
    /// Fills masked positions with a scalar, converted to the tensor's dtype.
    ///
    /// Narrowing uses `as` casts, so an out-of-range scalar silently truncates
    /// to the element width (e.g. 300 -> i8 wraps).
    fn int_mask_fill(
        tensor: IntTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: Scalar,
    ) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap()),
            DType::I32 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap() as i32),
            DType::I16 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap() as i16),
            DType::I8 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap() as i8),
            DType::U64 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap()),
            DType::U32 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap() as u32),
            DType::U16 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap() as u16),
            DType::U8 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap() as u8),
            dt => panic!("int_mask_fill: unsupported dtype {:?}", dt),
        }
    }
99
100 fn int_slice_assign(
101 tensor: IntTensor<Flex>,
102 slices: &[Slice],
103 value: IntTensor<Flex>,
104 ) -> IntTensor<Flex> {
105 crate::ops::slice::slice_assign(tensor, slices, value)
106 }
107
    /// Gathers elements along `dim` using `indices`, dispatching the generic
    /// kernel on the tensor's element type.
    fn int_gather(
        dim: usize,
        tensor: IntTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::gather_scatter::gather::<i64>(tensor, dim, indices),
            DType::I32 => crate::ops::gather_scatter::gather::<i32>(tensor, dim, indices),
            DType::I16 => crate::ops::gather_scatter::gather::<i16>(tensor, dim, indices),
            DType::I8 => crate::ops::gather_scatter::gather::<i8>(tensor, dim, indices),
            DType::U64 => crate::ops::gather_scatter::gather::<u64>(tensor, dim, indices),
            DType::U32 => crate::ops::gather_scatter::gather::<u32>(tensor, dim, indices),
            DType::U16 => crate::ops::gather_scatter::gather::<u16>(tensor, dim, indices),
            DType::U8 => crate::ops::gather_scatter::gather::<u8>(tensor, dim, indices),
            dt => panic!("int_gather: unsupported dtype {:?}", dt),
        }
    }
132
    /// Scatter-add: accumulates `value` into `tensor` at `indices` along `dim`.
    ///
    /// `tensor` and `value` must share a dtype; dispatch picks the concrete
    /// element type for the generic kernel.
    fn int_scatter_add(
        dim: usize,
        tensor: IntTensor<Flex>,
        indices: IntTensor<Flex>,
        value: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        debug_assert_eq!(
            tensor.dtype(),
            value.dtype(),
            "int_scatter_add: dtype mismatch"
        );
        match tensor.dtype() {
            DType::I64 => {
                crate::ops::gather_scatter::scatter_add::<i64>(tensor, dim, indices, value)
            }
            DType::I32 => {
                crate::ops::gather_scatter::scatter_add::<i32>(tensor, dim, indices, value)
            }
            DType::I16 => {
                crate::ops::gather_scatter::scatter_add::<i16>(tensor, dim, indices, value)
            }
            DType::I8 => crate::ops::gather_scatter::scatter_add::<i8>(tensor, dim, indices, value),
            DType::U64 => {
                crate::ops::gather_scatter::scatter_add::<u64>(tensor, dim, indices, value)
            }
            DType::U32 => {
                crate::ops::gather_scatter::scatter_add::<u32>(tensor, dim, indices, value)
            }
            DType::U16 => {
                crate::ops::gather_scatter::scatter_add::<u16>(tensor, dim, indices, value)
            }
            DType::U8 => crate::ops::gather_scatter::scatter_add::<u8>(tensor, dim, indices, value),
            dt => panic!("int_scatter_add: unsupported dtype {:?}", dt),
        }
    }
173
    /// N-dimensional scatter with a caller-chosen update op (assign/add/etc.,
    /// per `IndexingUpdateOp`), dispatched on the data's element type.
    fn int_scatter_nd(
        data: IntTensor<Flex>,
        indices: IntTensor<Flex>,
        values: IntTensor<Flex>,
        reduction: burn_backend::tensor::IndexingUpdateOp,
    ) -> IntTensor<Flex> {
        match data.dtype() {
            DType::I64 => {
                crate::ops::gather_scatter::scatter_nd::<i64>(data, indices, values, reduction)
            }
            DType::I32 => {
                crate::ops::gather_scatter::scatter_nd::<i32>(data, indices, values, reduction)
            }
            DType::I16 => {
                crate::ops::gather_scatter::scatter_nd::<i16>(data, indices, values, reduction)
            }
            DType::I8 => {
                crate::ops::gather_scatter::scatter_nd::<i8>(data, indices, values, reduction)
            }
            DType::U64 => {
                crate::ops::gather_scatter::scatter_nd::<u64>(data, indices, values, reduction)
            }
            DType::U32 => {
                crate::ops::gather_scatter::scatter_nd::<u32>(data, indices, values, reduction)
            }
            DType::U16 => {
                crate::ops::gather_scatter::scatter_nd::<u16>(data, indices, values, reduction)
            }
            DType::U8 => {
                crate::ops::gather_scatter::scatter_nd::<u8>(data, indices, values, reduction)
            }
            dt => panic!("int_scatter_nd: unsupported dtype {:?}", dt),
        }
    }
208
    /// N-dimensional gather, dispatched on the data's element type.
    fn int_gather_nd(data: IntTensor<Flex>, indices: IntTensor<Flex>) -> IntTensor<Flex> {
        match data.dtype() {
            DType::I64 => crate::ops::gather_scatter::gather_nd::<i64>(data, indices),
            DType::I32 => crate::ops::gather_scatter::gather_nd::<i32>(data, indices),
            DType::I16 => crate::ops::gather_scatter::gather_nd::<i16>(data, indices),
            DType::I8 => crate::ops::gather_scatter::gather_nd::<i8>(data, indices),
            DType::U64 => crate::ops::gather_scatter::gather_nd::<u64>(data, indices),
            DType::U32 => crate::ops::gather_scatter::gather_nd::<u32>(data, indices),
            DType::U16 => crate::ops::gather_scatter::gather_nd::<u16>(data, indices),
            DType::U8 => crate::ops::gather_scatter::gather_nd::<u8>(data, indices),
            dt => panic!("int_gather_nd: unsupported dtype {:?}", dt),
        }
    }
222
    /// Selects whole index positions along `dim` (row/column selection),
    /// dispatched on the tensor's element type.
    fn int_select(
        tensor: IntTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::gather_scatter::select::<i64>(tensor, dim, indices),
            DType::I32 => crate::ops::gather_scatter::select::<i32>(tensor, dim, indices),
            DType::I16 => crate::ops::gather_scatter::select::<i16>(tensor, dim, indices),
            DType::I8 => crate::ops::gather_scatter::select::<i8>(tensor, dim, indices),
            DType::U64 => crate::ops::gather_scatter::select::<u64>(tensor, dim, indices),
            DType::U32 => crate::ops::gather_scatter::select::<u32>(tensor, dim, indices),
            DType::U16 => crate::ops::gather_scatter::select::<u16>(tensor, dim, indices),
            DType::U8 => crate::ops::gather_scatter::select::<u8>(tensor, dim, indices),
            dt => panic!("int_select: unsupported dtype {:?}", dt),
        }
    }
244
    /// Accumulates `value` into the positions of `tensor` selected by
    /// `indices` along `dim` (inverse of `int_select`, with addition).
    ///
    /// `tensor` and `value` must share a dtype.
    fn int_select_add(
        tensor: IntTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
        value: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        debug_assert_eq!(
            tensor.dtype(),
            value.dtype(),
            "int_select_add: dtype mismatch"
        );
        match tensor.dtype() {
            DType::I64 => {
                crate::ops::gather_scatter::select_add::<i64>(tensor, dim, indices, value)
            }
            DType::I32 => {
                crate::ops::gather_scatter::select_add::<i32>(tensor, dim, indices, value)
            }
            DType::I16 => {
                crate::ops::gather_scatter::select_add::<i16>(tensor, dim, indices, value)
            }
            DType::I8 => crate::ops::gather_scatter::select_add::<i8>(tensor, dim, indices, value),
            DType::U64 => {
                crate::ops::gather_scatter::select_add::<u64>(tensor, dim, indices, value)
            }
            DType::U32 => {
                crate::ops::gather_scatter::select_add::<u32>(tensor, dim, indices, value)
            }
            DType::U16 => {
                crate::ops::gather_scatter::select_add::<u16>(tensor, dim, indices, value)
            }
            DType::U8 => crate::ops::gather_scatter::select_add::<u8>(tensor, dim, indices, value),
            dt => panic!("int_select_add: unsupported dtype {:?}", dt),
        }
    }
285
286 fn int_equal(
287 lhs: IntTensor<Flex>,
288 rhs: IntTensor<Flex>,
289 out_dtype: burn_std::BoolDType,
290 ) -> BoolTensor<Flex> {
291 crate::ops::comparison::int_equal(lhs, rhs, out_dtype)
292 }
293
294 fn int_equal_elem(
295 lhs: IntTensor<Flex>,
296 rhs: Scalar,
297 out_dtype: burn_std::BoolDType,
298 ) -> BoolTensor<Flex> {
299 let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
300 crate::ops::comparison::int_equal_elem(lhs, i, u, out_dtype)
301 }
302
303 fn int_greater(
304 lhs: IntTensor<Flex>,
305 rhs: IntTensor<Flex>,
306 out_dtype: burn_std::BoolDType,
307 ) -> BoolTensor<Flex> {
308 crate::ops::comparison::int_greater(lhs, rhs, out_dtype)
309 }
310
311 fn int_greater_elem(
312 lhs: IntTensor<Flex>,
313 rhs: Scalar,
314 out_dtype: burn_std::BoolDType,
315 ) -> BoolTensor<Flex> {
316 let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
317 crate::ops::comparison::int_greater_elem(lhs, i, u, out_dtype)
318 }
319
320 fn int_greater_equal(
321 lhs: IntTensor<Flex>,
322 rhs: IntTensor<Flex>,
323 out_dtype: burn_std::BoolDType,
324 ) -> BoolTensor<Flex> {
325 crate::ops::comparison::int_greater_equal(lhs, rhs, out_dtype)
326 }
327
328 fn int_greater_equal_elem(
329 lhs: IntTensor<Flex>,
330 rhs: Scalar,
331 out_dtype: burn_std::BoolDType,
332 ) -> BoolTensor<Flex> {
333 let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
334 crate::ops::comparison::int_greater_equal_elem(lhs, i, u, out_dtype)
335 }
336
337 fn int_lower(
338 lhs: IntTensor<Flex>,
339 rhs: IntTensor<Flex>,
340 out_dtype: burn_std::BoolDType,
341 ) -> BoolTensor<Flex> {
342 crate::ops::comparison::int_lower(lhs, rhs, out_dtype)
343 }
344
345 fn int_lower_elem(
346 lhs: IntTensor<Flex>,
347 rhs: Scalar,
348 out_dtype: burn_std::BoolDType,
349 ) -> BoolTensor<Flex> {
350 let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
351 crate::ops::comparison::int_lower_elem(lhs, i, u, out_dtype)
352 }
353
354 fn int_lower_equal(
355 lhs: IntTensor<Flex>,
356 rhs: IntTensor<Flex>,
357 out_dtype: burn_std::BoolDType,
358 ) -> BoolTensor<Flex> {
359 crate::ops::comparison::int_lower_equal(lhs, rhs, out_dtype)
360 }
361
362 fn int_lower_equal_elem(
363 lhs: IntTensor<Flex>,
364 rhs: Scalar,
365 out_dtype: burn_std::BoolDType,
366 ) -> BoolTensor<Flex> {
367 let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
368 crate::ops::comparison::int_lower_equal_elem(lhs, i, u, out_dtype)
369 }
370
371 fn int_add(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
372 int_binary_op(lhs, rhs, |a, b| a + b)
373 }
374
375 fn int_add_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
376 if lhs.dtype() == DType::U64 {
377 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| {
378 a.wrapping_add(b)
379 });
380 }
381 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a + b)
382 }
383
384 fn int_sub(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
385 int_binary_op(lhs, rhs, |a, b| a - b)
386 }
387
388 fn int_sub_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
389 if lhs.dtype() == DType::U64 {
390 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| {
391 a.wrapping_sub(b)
392 });
393 }
394 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a - b)
395 }
396
397 fn int_mul(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
398 int_binary_op(lhs, rhs, |a, b| a * b)
399 }
400
401 fn int_mul_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
402 if lhs.dtype() == DType::U64 {
403 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| {
404 a.wrapping_mul(b)
405 });
406 }
407 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a * b)
408 }
409
410 fn int_div(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
411 if lhs.dtype() == DType::U64 {
413 let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
414 return binary_op_typed(lhs, &rhs, |a: u64, b: u64| a / b);
415 }
416 int_binary_op(lhs, rhs, |a, b| a / b)
417 }
418
419 fn int_div_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
420 if lhs.dtype() == DType::U64 {
421 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a / b);
422 }
423 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a / b)
424 }
425
426 fn int_remainder(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
427 if lhs.dtype() == DType::U64 {
429 let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
430 return binary_op_typed(lhs, &rhs, |a: u64, b: u64| a % b);
431 }
432 int_binary_op(lhs, rhs, |a, b| ((a % b) + b) % b)
434 }
435
436 fn int_remainder_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
437 if lhs.dtype() == DType::U64 {
438 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a % b);
439 }
440 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| ((a % b) + b) % b)
442 }
443
    /// Converts an integer tensor to the requested float dtype.
    ///
    /// The tensor is made contiguous first so the storage can be read
    /// linearly; each element is cast via `as` (half types go through f32).
    /// Flex32 shares the f32 path.
    fn int_into_float(
        tensor: IntTensor<Flex>,
        out_dtype: burn_std::FloatDType,
    ) -> FloatTensor<Flex> {
        let tensor = tensor.to_contiguous();
        let shape = tensor.layout().shape().clone();
        let src = tensor.dtype();
        let out_dt = DType::from(out_dtype);

        // Reads the storage as the source integer type and applies `$conv`
        // to each element; `$x` is the element binding inside `$conv`.
        macro_rules! read_ints {
            (|$x:ident| $conv:expr) => {
                match src {
                    DType::I64 => tensor.storage::<i64>().iter().map(|&$x| $conv).collect(),
                    DType::I32 => tensor.storage::<i32>().iter().map(|&$x| $conv).collect(),
                    DType::I16 => tensor.storage::<i16>().iter().map(|&$x| $conv).collect(),
                    DType::I8 => tensor.storage::<i8>().iter().map(|&$x| $conv).collect(),
                    DType::U64 => tensor.storage::<u64>().iter().map(|&$x| $conv).collect(),
                    DType::U32 => tensor.storage::<u32>().iter().map(|&$x| $conv).collect(),
                    DType::U16 => tensor.storage::<u16>().iter().map(|&$x| $conv).collect(),
                    DType::U8 => tensor.storage::<u8>().iter().map(|&$x| $conv).collect(),
                    _ => panic!("int_into_float: unsupported source dtype {:?}", src),
                }
            };
        }

        match out_dtype {
            FloatDType::F64 => {
                let data: Vec<f64> = read_ints!(|x| x as f64);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
            FloatDType::F32 | FloatDType::Flex32 => {
                let data: Vec<f32> = read_ints!(|x| x as f32);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
            FloatDType::F16 => {
                // No direct int -> f16 cast; round-trip through f32.
                let data: Vec<f16> = read_ints!(|x| f16::from_f32(x as f32));
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
            FloatDType::BF16 => {
                // No direct int -> bf16 cast; round-trip through f32.
                let data: Vec<bf16> = read_ints!(|x| bf16::from_f32(x as f32));
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
        }
    }
491
492 fn int_swap_dims(tensor: IntTensor<Flex>, dim1: usize, dim2: usize) -> IntTensor<Flex> {
493 tensor.transpose(dim1, dim2)
494 }
495
496 fn int_permute(tensor: IntTensor<Flex>, axes: &[usize]) -> IntTensor<Flex> {
497 tensor.permute(axes)
498 }
499
500 fn int_flip(tensor: IntTensor<Flex>, axes: &[usize]) -> IntTensor<Flex> {
501 crate::ops::flip::flip(tensor, axes)
502 }
503
    /// Fills a new tensor with values drawn from `distribution`.
    ///
    /// Uses the process-global seeded RNG: the generator is taken out of the
    /// `SEED` mutex, advanced, and put back, so successive calls continue the
    /// same random stream.
    fn int_random(
        shape: Shape,
        distribution: Distribution,
        _device: &Device<Flex>,
        dtype: IntDType,
    ) -> IntTensor<Flex> {
        let mut seed = crate::backend::SEED.lock().unwrap();
        // Lazily create the generator on first use.
        let mut rng = seed.take().unwrap_or_else(crate::backend::get_seeded_rng);
        let data = match dtype {
            IntDType::I64 => TensorData::random::<i64, _, _>(shape, distribution, &mut rng),
            IntDType::I32 => TensorData::random::<i32, _, _>(shape, distribution, &mut rng),
            IntDType::I16 => TensorData::random::<i16, _, _>(shape, distribution, &mut rng),
            IntDType::I8 => TensorData::random::<i8, _, _>(shape, distribution, &mut rng),
            IntDType::U64 => TensorData::random::<u64, _, _>(shape, distribution, &mut rng),
            IntDType::U32 => TensorData::random::<u32, _, _>(shape, distribution, &mut rng),
            IntDType::U16 => TensorData::random::<u16, _, _>(shape, distribution, &mut rng),
            IntDType::U8 => TensorData::random::<u8, _, _>(shape, distribution, &mut rng),
        };
        // Store the advanced generator so the next call continues the stream.
        *seed = Some(rng);
        FlexTensor::from_data(data)
    }
525
526 fn int_expand(tensor: IntTensor<Flex>, shape: Shape) -> IntTensor<Flex> {
527 crate::ops::expand::expand(tensor, shape)
528 }
529
530 fn int_matmul(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
531 matmul::int_matmul(lhs, rhs)
532 }
533
534 fn int_sum(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
535 crate::ops::reduce::sum(tensor)
536 }
537
538 fn int_sum_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
539 crate::ops::reduce::sum_dim(tensor, dim)
540 }
541
542 fn int_prod(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
543 crate::ops::reduce::prod(tensor)
544 }
545
546 fn int_prod_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
547 crate::ops::reduce::prod_dim(tensor, dim)
548 }
549
550 fn int_mean_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
551 crate::ops::reduce::mean_dim(tensor, dim)
552 }
553
    /// Cumulative sum along `dim`, dispatched on the element type.
    fn int_cumsum(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cumsum::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cumsum::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cumsum::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cumsum::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cumsum::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cumsum::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cumsum::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cumsum::<u8>(tensor, dim),
            dt => panic!("int_cumsum: unsupported dtype {:?}", dt),
        }
    }
567
    /// Cumulative product along `dim`, dispatched on the element type.
    fn int_cumprod(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cumprod::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cumprod::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cumprod::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cumprod::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cumprod::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cumprod::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cumprod::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cumprod::<u8>(tensor, dim),
            dt => panic!("int_cumprod: unsupported dtype {:?}", dt),
        }
    }
581
    /// Cumulative minimum along `dim`, dispatched on the element type.
    fn int_cummin(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cummin::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cummin::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cummin::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cummin::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cummin::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cummin::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cummin::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cummin::<u8>(tensor, dim),
            dt => panic!("int_cummin: unsupported dtype {:?}", dt),
        }
    }
595
    /// Cumulative maximum along `dim`, dispatched on the element type.
    fn int_cummax(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cummax::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cummax::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cummax::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cummax::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cummax::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cummax::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cummax::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cummax::<u8>(tensor, dim),
            dt => panic!("int_cummax: unsupported dtype {:?}", dt),
        }
    }
609
610 fn int_argmax(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
611 crate::ops::reduce::argmax(tensor, dim)
612 }
613
614 fn int_argtopk(_tensor: IntTensor<Flex>, _dim: usize, _k: usize) -> IntTensor<Flex> {
615 panic!("argtopk not implemented for flex")
616 }
617
618 fn int_argmin(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
619 crate::ops::reduce::argmin(tensor, dim)
620 }
621
622 fn int_abs(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
623 crate::ops::unary::int_abs(tensor)
624 }
625
626 fn bitwise_and(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
627 int_binary_op(lhs, rhs, |a, b| a & b)
628 }
629
630 fn bitwise_and_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
631 if lhs.dtype() == DType::U64 {
632 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a & b);
633 }
634 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a & b)
635 }
636
637 fn bitwise_or(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
638 int_binary_op(lhs, rhs, |a, b| a | b)
639 }
640
641 fn bitwise_or_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
642 if lhs.dtype() == DType::U64 {
643 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a | b);
644 }
645 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a | b)
646 }
647
648 fn bitwise_xor(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
649 int_binary_op(lhs, rhs, |a, b| a ^ b)
650 }
651
652 fn bitwise_xor_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
653 if lhs.dtype() == DType::U64 {
654 return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a ^ b);
655 }
656 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a ^ b)
657 }
658
659 fn bitwise_not(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
660 int_scalar_op(tensor, 0, |a, _| !a)
662 }
663
664 fn bitwise_left_shift(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
666 int_binary_op(lhs, rhs, |a, b| a.wrapping_shl(b as u32))
667 }
668
669 fn bitwise_left_shift_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
670 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a.wrapping_shl(b as u32))
671 }
672
673 fn bitwise_right_shift(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
674 int_binary_op(lhs, rhs, |a, b| a.wrapping_shr(b as u32))
675 }
676
677 fn bitwise_right_shift_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
678 int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a.wrapping_shr(b as u32))
679 }
680
    /// Casts an integer tensor to another integer dtype.
    ///
    /// Same-dtype casts return the input untouched. Otherwise the tensor is
    /// made contiguous and each element converted with an `as` cast, so
    /// narrowing conversions truncate/wrap rather than error.
    fn int_cast(tensor: IntTensor<Flex>, dtype: IntDType) -> IntTensor<Flex> {
        let target_dtype: DType = dtype.into();

        // Fast path: nothing to do.
        if tensor.dtype() == target_dtype {
            return tensor;
        }

        let tensor = tensor.to_contiguous();
        let shape = tensor.layout().shape().clone();

        // Reads storage as `$src_type`, casts each element to `$dst_type`,
        // and rebuilds a contiguous tensor tagged `$dst_dtype`.
        macro_rules! cast_impl {
            ($src_type:ty, $dst_type:ty, $dst_dtype:expr) => {{
                let src: &[$src_type] = tensor.storage();
                let dst: Vec<$dst_type> = src.iter().map(|&x| x as $dst_type).collect();
                FlexTensor::new(
                    Bytes::from_elems(dst),
                    Layout::contiguous(shape),
                    $dst_dtype,
                )
            }};
        }

        // All 8x7 cross-dtype pairs (identity pairs handled above).
        match (tensor.dtype(), target_dtype) {
            (DType::I64, DType::I32) => cast_impl!(i64, i32, DType::I32),
            (DType::I64, DType::I16) => cast_impl!(i64, i16, DType::I16),
            (DType::I64, DType::I8) => cast_impl!(i64, i8, DType::I8),
            (DType::I64, DType::U64) => cast_impl!(i64, u64, DType::U64),
            (DType::I64, DType::U32) => cast_impl!(i64, u32, DType::U32),
            (DType::I64, DType::U16) => cast_impl!(i64, u16, DType::U16),
            (DType::I64, DType::U8) => cast_impl!(i64, u8, DType::U8),

            (DType::I32, DType::I64) => cast_impl!(i32, i64, DType::I64),
            (DType::I32, DType::I16) => cast_impl!(i32, i16, DType::I16),
            (DType::I32, DType::I8) => cast_impl!(i32, i8, DType::I8),
            (DType::I32, DType::U64) => cast_impl!(i32, u64, DType::U64),
            (DType::I32, DType::U32) => cast_impl!(i32, u32, DType::U32),
            (DType::I32, DType::U16) => cast_impl!(i32, u16, DType::U16),
            (DType::I32, DType::U8) => cast_impl!(i32, u8, DType::U8),

            (DType::I16, DType::I64) => cast_impl!(i16, i64, DType::I64),
            (DType::I16, DType::I32) => cast_impl!(i16, i32, DType::I32),
            (DType::I16, DType::I8) => cast_impl!(i16, i8, DType::I8),
            (DType::I16, DType::U64) => cast_impl!(i16, u64, DType::U64),
            (DType::I16, DType::U32) => cast_impl!(i16, u32, DType::U32),
            (DType::I16, DType::U16) => cast_impl!(i16, u16, DType::U16),
            (DType::I16, DType::U8) => cast_impl!(i16, u8, DType::U8),

            (DType::I8, DType::I64) => cast_impl!(i8, i64, DType::I64),
            (DType::I8, DType::I32) => cast_impl!(i8, i32, DType::I32),
            (DType::I8, DType::I16) => cast_impl!(i8, i16, DType::I16),
            (DType::I8, DType::U64) => cast_impl!(i8, u64, DType::U64),
            (DType::I8, DType::U32) => cast_impl!(i8, u32, DType::U32),
            (DType::I8, DType::U16) => cast_impl!(i8, u16, DType::U16),
            (DType::I8, DType::U8) => cast_impl!(i8, u8, DType::U8),

            (DType::U64, DType::I64) => cast_impl!(u64, i64, DType::I64),
            (DType::U64, DType::I32) => cast_impl!(u64, i32, DType::I32),
            (DType::U64, DType::I16) => cast_impl!(u64, i16, DType::I16),
            (DType::U64, DType::I8) => cast_impl!(u64, i8, DType::I8),
            (DType::U64, DType::U32) => cast_impl!(u64, u32, DType::U32),
            (DType::U64, DType::U16) => cast_impl!(u64, u16, DType::U16),
            (DType::U64, DType::U8) => cast_impl!(u64, u8, DType::U8),

            (DType::U32, DType::I64) => cast_impl!(u32, i64, DType::I64),
            (DType::U32, DType::I32) => cast_impl!(u32, i32, DType::I32),
            (DType::U32, DType::I16) => cast_impl!(u32, i16, DType::I16),
            (DType::U32, DType::I8) => cast_impl!(u32, i8, DType::I8),
            (DType::U32, DType::U64) => cast_impl!(u32, u64, DType::U64),
            (DType::U32, DType::U16) => cast_impl!(u32, u16, DType::U16),
            (DType::U32, DType::U8) => cast_impl!(u32, u8, DType::U8),

            (DType::U16, DType::I64) => cast_impl!(u16, i64, DType::I64),
            (DType::U16, DType::I32) => cast_impl!(u16, i32, DType::I32),
            (DType::U16, DType::I16) => cast_impl!(u16, i16, DType::I16),
            (DType::U16, DType::I8) => cast_impl!(u16, i8, DType::I8),
            (DType::U16, DType::U64) => cast_impl!(u16, u64, DType::U64),
            (DType::U16, DType::U32) => cast_impl!(u16, u32, DType::U32),
            (DType::U16, DType::U8) => cast_impl!(u16, u8, DType::U8),

            (DType::U8, DType::I64) => cast_impl!(u8, i64, DType::I64),
            (DType::U8, DType::I32) => cast_impl!(u8, i32, DType::I32),
            (DType::U8, DType::I16) => cast_impl!(u8, i16, DType::I16),
            (DType::U8, DType::I8) => cast_impl!(u8, i8, DType::I8),
            (DType::U8, DType::U64) => cast_impl!(u8, u64, DType::U64),
            (DType::U8, DType::U32) => cast_impl!(u8, u32, DType::U32),
            (DType::U8, DType::U16) => cast_impl!(u8, u16, DType::U16),

            _ => panic!(
                "int_cast: unsupported conversion from {:?} to {:?}",
                tensor.dtype(),
                target_dtype
            ),
        }
    }
787
788 fn int_unfold(
789 tensor: IntTensor<Flex>,
790 dim: usize,
791 size: usize,
792 step: usize,
793 ) -> IntTensor<Flex> {
794 crate::ops::unfold::unfold_int(tensor, dim, size, step)
795 }
796
797 fn int_neg(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
798 int_scalar_op(tensor, 0i64, |a, _| a.wrapping_neg())
799 }
800
801 fn int_clamp(tensor: IntTensor<Flex>, min: Scalar, max: Scalar) -> IntTensor<Flex> {
802 if tensor.dtype() == DType::U64 {
803 let min_val = min.to_u64().unwrap();
804 let max_val = max.to_u64().unwrap();
805 return scalar_op_typed(tensor, 0u64, move |x: u64, _| x.clamp(min_val, max_val));
806 }
807 let min_val = min.to_i64().unwrap();
808 let max_val = max.to_i64().unwrap();
809 int_scalar_op(tensor, 0i64, move |x, _| x.clamp(min_val, max_val))
810 }
811
812 fn int_clamp_min(tensor: IntTensor<Flex>, min: Scalar) -> IntTensor<Flex> {
813 if tensor.dtype() == DType::U64 {
814 let min_val = min.to_u64().unwrap();
815 return scalar_op_typed(tensor, 0u64, move |x: u64, _| x.max(min_val));
816 }
817 let min_val = min.to_i64().unwrap();
818 int_scalar_op(tensor, 0i64, move |x, _| x.max(min_val))
819 }
820
821 fn int_clamp_max(tensor: IntTensor<Flex>, max: Scalar) -> IntTensor<Flex> {
822 if tensor.dtype() == DType::U64 {
823 let max_val = max.to_u64().unwrap();
824 return scalar_op_typed(tensor, 0u64, move |x: u64, _| x.min(max_val));
825 }
826 let max_val = max.to_i64().unwrap();
827 int_scalar_op(tensor, 0i64, move |x, _| x.min(max_val))
828 }
829
830 fn int_sign(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
831 if tensor.dtype() == DType::U64 {
832 return scalar_op_typed(tensor, 0u64, |x: u64, _| if x > 0 { 1 } else { 0 });
833 }
834 int_scalar_op(tensor, 0i64, |x, _| {
835 if x > 0 {
836 1
837 } else if x < 0 {
838 -1
839 } else {
840 0
841 }
842 })
843 }
844
845 fn int_mean(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
846 let n = tensor.layout().num_elements();
847 assert!(n > 0, "int_mean: cannot take mean of empty tensor");
848 let dtype = tensor.dtype();
849 let sum_result = crate::ops::reduce::sum(tensor);
850 macro_rules! compute_mean {
852 ($ty:ty) => {{
853 let data: &[$ty] = sum_result.storage();
854 let mean_val = (data[0] as i64 / n as i64) as $ty;
855 FlexTensor::new(
856 Bytes::from_elems(alloc::vec![mean_val]),
857 Layout::contiguous(Shape::from(alloc::vec![1])),
858 dtype,
859 )
860 }};
861 }
862 match dtype {
863 DType::I64 => compute_mean!(i64),
864 DType::I32 => compute_mean!(i32),
865 DType::I16 => compute_mean!(i16),
866 DType::I8 => compute_mean!(i8),
867 other => panic!("int_mean: unsupported dtype {:?}", other),
868 }
869 }
870
871 fn int_max(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
872 crate::ops::reduce::max(tensor)
873 }
874
875 fn int_max_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
876 crate::ops::reduce::max_dim(tensor, dim)
877 }
878
879 fn int_min(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
880 crate::ops::reduce::min(tensor)
881 }
882
883 fn int_min_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
884 crate::ops::reduce::min_dim(tensor, dim)
885 }
886
    /// Maximum along `dim` together with the argmax indices, as
    /// `(values, indices)`; delegates to the shared reduce kernel.
    fn int_max_dim_with_indices(
        tensor: IntTensor<Flex>,
        dim: usize,
    ) -> (IntTensor<Flex>, IntTensor<Flex>) {
        crate::ops::reduce::max_dim_with_indices(tensor, dim)
    }
893
    /// Minimum along `dim` together with the argmin indices, as
    /// `(values, indices)`; delegates to the shared reduce kernel.
    fn int_min_dim_with_indices(
        tensor: IntTensor<Flex>,
        dim: usize,
    ) -> (IntTensor<Flex>, IntTensor<Flex>) {
        crate::ops::reduce::min_dim_with_indices(tensor, dim)
    }
900
    /// Whole-tensor `any` reduction; `out_dtype` selects the bool storage
    /// type of the result. Delegates to the shared comparison kernel.
    fn int_any(tensor: IntTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::any_int(tensor, out_dtype)
    }
904
    /// `any` reduction along dimension `dim`; `out_dtype` selects the bool
    /// storage type of the result. Delegates to the shared comparison kernel.
    fn int_any_dim(
        tensor: IntTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::any_int_dim(tensor, dim, out_dtype)
    }
912
    /// Whole-tensor `all` reduction; `out_dtype` selects the bool storage
    /// type of the result. Delegates to the shared comparison kernel.
    fn int_all(tensor: IntTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::all_int(tensor, out_dtype)
    }
916
    /// `all` reduction along dimension `dim`; `out_dtype` selects the bool
    /// storage type of the result. Delegates to the shared comparison kernel.
    fn int_all_dim(
        tensor: IntTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::all_int_dim(tensor, dim, out_dtype)
    }
924
    /// Element-wise integer power.
    ///
    /// The exponent element is converted with `as u32`, so negative or
    /// oversized exponents truncate bitwise; the power itself wraps on
    /// overflow (same wrapping convention as `int_powi_scalar_impl`).
    fn int_powi(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a.wrapping_pow(b as u32))
    }
928
    /// Zero-filled tensor of the requested integer dtype. The device
    /// argument is ignored (this backend reports only the default device).
    fn int_zeros(shape: Shape, _device: &Device<Flex>, dtype: IntDType) -> IntTensor<Flex> {
        FlexTensor::zeros(shape, dtype.into())
    }
932
933 fn int_ones(shape: Shape, _device: &Device<Flex>, dtype: IntDType) -> IntTensor<Flex> {
934 let dt: DType = dtype.into();
935 match dt {
936 DType::I64 => FlexTensor::filled_typed(shape, dt, 1i64),
937 DType::I32 => FlexTensor::filled_typed(shape, dt, 1i32),
938 DType::I16 => FlexTensor::filled_typed(shape, dt, 1i16),
939 DType::I8 => FlexTensor::filled_typed(shape, dt, 1i8),
940 DType::U64 => FlexTensor::filled_typed(shape, dt, 1u64),
941 DType::U32 => FlexTensor::filled_typed(shape, dt, 1u32),
942 DType::U16 => FlexTensor::filled_typed(shape, dt, 1u16),
943 DType::U8 => FlexTensor::filled_typed(shape, dt, 1u8),
944 _ => unreachable!(),
945 }
946 }
947
948 fn int_full(
949 shape: Shape,
950 fill_value: burn_backend::Scalar,
951 _device: &Device<Flex>,
952 dtype: IntDType,
953 ) -> IntTensor<Flex> {
954 let dt: DType = dtype.into();
955 let v = fill_value.to_i64().unwrap();
956 match dt {
957 DType::I64 => FlexTensor::filled_typed(shape, dt, v),
958 DType::I32 => FlexTensor::filled_typed(shape, dt, v as i32),
959 DType::I16 => FlexTensor::filled_typed(shape, dt, v as i16),
960 DType::I8 => FlexTensor::filled_typed(shape, dt, v as i8),
961 DType::U64 => FlexTensor::filled_typed(shape, dt, v as u64),
962 DType::U32 => FlexTensor::filled_typed(shape, dt, v as u32),
963 DType::U16 => FlexTensor::filled_typed(shape, dt, v as u16),
964 DType::U8 => FlexTensor::filled_typed(shape, dt, v as u8),
965 _ => unreachable!(),
966 }
967 }
968
969 fn int_transpose(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
970 let ndims = tensor.layout().num_dims();
971 if ndims < 2 {
972 return tensor;
973 }
974 tensor.transpose(ndims - 2, ndims - 1)
975 }
976
    /// Repeats the tensor `times` times along dimension `dim`; delegates to
    /// the shared repeat kernel.
    fn int_repeat_dim(tensor: IntTensor<Flex>, dim: usize, times: usize) -> IntTensor<Flex> {
        crate::ops::repeat_dim::repeat_dim(tensor, dim, times)
    }
980
    /// Element-wise `!=` between two int tensors; `out_dtype` selects the
    /// bool storage type of the result.
    fn int_not_equal(
        lhs: IntTensor<Flex>,
        rhs: IntTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::int_not_equal(lhs, rhs, out_dtype)
    }
988
    /// Element-wise `!=` against a scalar.
    ///
    /// The scalar is split into signed/unsigned channels by
    /// `scalar_to_int_pair` so u64 values above `i64::MAX` survive the
    /// comparison; the kernel picks the channel matching the tensor dtype.
    fn int_not_equal_elem(
        lhs: IntTensor<Flex>,
        rhs: burn_backend::Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
        crate::ops::comparison::int_not_equal_elem(lhs, i, u, out_dtype)
    }
997
    /// Sorts along dimension `dim` (descending when requested); delegates
    /// to the shared sort kernel.
    fn int_sort(tensor: IntTensor<Flex>, dim: usize, descending: bool) -> IntTensor<Flex> {
        crate::ops::sort::sort(tensor, dim, descending)
    }
1001
    /// Sorts along `dim` and also returns the permutation indices, as
    /// `(sorted, indices)`; delegates to the shared sort kernel.
    fn int_sort_with_indices(
        tensor: IntTensor<Flex>,
        dim: usize,
        descending: bool,
    ) -> (IntTensor<Flex>, IntTensor<Flex>) {
        crate::ops::sort::sort_with_indices(tensor, dim, descending)
    }
1009
    /// Indices that would sort the tensor along `dim`; delegates to the
    /// shared sort kernel.
    fn int_argsort(tensor: IntTensor<Flex>, dim: usize, descending: bool) -> IntTensor<Flex> {
        crate::ops::sort::argsort(tensor, dim, descending)
    }
1013
1014 fn int_powi_scalar(lhs: IntTensor<Flex>, rhs: burn_backend::Scalar) -> IntTensor<Flex> {
1015 use num_traits::ToPrimitive;
1016 match rhs.to_i64().unwrap() {
1017 0 => Self::int_ones(lhs.shape(), &Default::default(), lhs.dtype().into()),
1018 1 => lhs,
1019 2 => Self::int_mul(lhs.clone(), lhs),
1020 _ => Self::int_powi_scalar_impl(lhs, rhs),
1021 }
1022 }
1023
1024 fn int_powi_scalar_impl(lhs: IntTensor<Flex>, rhs: burn_backend::Scalar) -> IntTensor<Flex> {
1025 use num_traits::ToPrimitive;
1026 let exp = rhs.to_i64().unwrap() as u32;
1027 if lhs.dtype() == DType::U64 {
1028 return scalar_op_typed(lhs, exp as u64, move |x: u64, _| x.wrapping_pow(exp));
1029 }
1030 int_scalar_op(lhs, exp as i64, move |x, _| x.wrapping_pow(exp))
1031 }
1032
    /// Maximum of the element-wise absolute values over the whole tensor.
    fn int_max_abs(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        let abs = Self::int_abs(tensor);
        crate::ops::reduce::max(abs)
    }
1037
    /// Maximum of the element-wise absolute values along dimension `dim`.
    fn int_max_abs_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        let abs = Self::int_abs(tensor);
        crate::ops::reduce::max_dim(abs, dim)
    }
1042
    /// 1-D tensor of `range` with step 1; delegates to `int_arange_step`.
    /// The device argument is ignored throughout this backend, so passing
    /// `Default::default()` down is equivalent to forwarding `_device`.
    fn int_arange(
        range: core::ops::Range<i64>,
        _device: &Device<Flex>,
        dtype: IntDType,
    ) -> IntTensor<Flex> {
        Self::int_arange_step(range, 1, &Default::default(), dtype)
    }
1050
    /// 1-D tensor of `range` sampled every `step` elements, stored in the
    /// requested integer dtype.
    ///
    /// Values are converted with `as`, so range endpoints outside the target
    /// dtype wrap (e.g. negative values into unsigned dtypes).
    ///
    /// # Panics
    /// Panics if `step == 0` (from `Iterator::step_by`).
    fn int_arange_step(
        range: core::ops::Range<i64>,
        step: usize,
        _device: &Device<Flex>,
        dtype: IntDType,
    ) -> IntTensor<Flex> {
        let dt: DType = dtype.into();

        // One arm per storage type: materialize the stepped range, then
        // wrap the buffer in a contiguous 1-D layout.
        macro_rules! arange_typed {
            ($ty:ty) => {{
                let data: Vec<$ty> = range.step_by(step).map(|v| v as $ty).collect();
                let shape = Shape::from(alloc::vec![data.len()]);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), dt)
            }};
        }

        match dt {
            DType::I64 => arange_typed!(i64),
            DType::I32 => arange_typed!(i32),
            DType::I16 => arange_typed!(i16),
            DType::I8 => arange_typed!(i8),
            DType::U64 => arange_typed!(u64),
            DType::U32 => arange_typed!(u32),
            DType::U16 => arange_typed!(u16),
            DType::U8 => arange_typed!(u8),
            _ => unreachable!(),
        }
    }
1079}
1080
#[cfg(test)]
mod tests {
    //! Edge-case coverage for the integer ops: large-u64 round-trips,
    //! wrapping behaviour at signed extremes, and per-dtype dispatch of the
    //! masking / gather / scan / random kernels.
    use alloc::vec;
    use alloc::vec::Vec;
    use burn_backend::ops::IntTensorOps;
    use burn_backend::{DType, Distribution, Scalar, TensorData};
    use burn_std::{BoolStore, FloatDType, IntDType, Shape};

    use crate::Flex;
    use crate::FlexTensor;

    // -- large-u64 arithmetic -------------------------------------------

    #[test]
    fn test_u64_div_large_values() {
        let a = FlexTensor::from_data(TensorData::new(vec![u64::MAX], [1]));
        let b = FlexTensor::from_data(TensorData::new(vec![2u64], [1]));
        let result = Flex::int_div(a, b);
        let values: Vec<u64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], u64::MAX / 2);
    }

    #[test]
    fn test_u64_remainder_large_values() {
        let a = FlexTensor::from_data(TensorData::new(vec![u64::MAX], [1]));
        let b = FlexTensor::from_data(TensorData::new(vec![2u64], [1]));
        let result = Flex::int_remainder(a, b);
        let values: Vec<u64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], u64::MAX % 2);
    }

    // -- wrapping at signed extremes ------------------------------------

    #[test]
    fn test_int_abs_min_value() {
        let a = FlexTensor::from_data(TensorData::new(vec![i64::MIN], [1]));
        let result = Flex::int_abs(a);
        let values: Vec<i64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], i64::MIN.wrapping_abs());
    }

    #[test]
    fn test_int_neg_min_value() {
        let a = FlexTensor::from_data(TensorData::new(vec![i64::MIN], [1]));
        let result = Flex::int_neg(a);
        let values: Vec<i64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], i64::MIN.wrapping_neg());
    }

    /// Shifting by the full bit width must not panic; the exact result is
    /// backend-defined, so only completion is checked.
    #[test]
    fn test_int_shift_large_amount() {
        let a = FlexTensor::from_data(TensorData::new(vec![1i64], [1]));
        let b = FlexTensor::from_data(TensorData::new(vec![64i64], [1]));
        let _left = Flex::bitwise_left_shift(a.clone(), b.clone());
        let _right = Flex::bitwise_right_shift(a, b);
    }

    // -- conversions and large-u64 scalar channels ----------------------

    #[test]
    fn test_int_into_float_f64() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i64, 2, -3], [3]));
        let result = Flex::int_into_float(t, FloatDType::F64);
        assert_eq!(result.dtype(), DType::F64);
        let data: Vec<f64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1.0f64, 2.0, -3.0]);
    }

    #[test]
    fn test_u64_add_scalar_large() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u64, 2, 3], [3]));
        let big: u64 = (i64::MAX as u64) + 100;
        let result = Flex::int_add_scalar(t, Scalar::from(big));
        let data: Vec<u64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![big + 1, big + 2, big + 3]);
    }

    #[test]
    fn test_u64_greater_elem_large() {
        let big: u64 = (i64::MAX as u64) + 100;
        let t = FlexTensor::from_data(TensorData::new(vec![big, big + 1, big - 1], [3]));
        let result = Flex::int_greater_elem(t, Scalar::from(big), BoolStore::Native);
        let data: Vec<bool> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![false, true, false]);
    }

    // -- mask_fill / mask_where across dtypes ---------------------------

    #[test]
    fn test_int_mask_fill_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i32, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
        let result = Flex::int_mask_fill(t, mask, Scalar::from(0i64));
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![0, 2, 0, 4]);
    }

    #[test]
    fn test_int_mask_fill_i16() {
        let t = FlexTensor::from_data(TensorData::new(vec![10i16, 20, 30, 40], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![false, true, false, true], [4]));
        let result = Flex::int_mask_fill(t, mask, Scalar::from(-1i64));
        let data: Vec<i16> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, -1, 30, -1]);
    }

    #[test]
    fn test_int_mask_fill_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, true, false, false], [4]));
        let result = Flex::int_mask_fill(t, mask, Scalar::from(255i64));
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![255, 255, 3, 4]);
    }

    #[test]
    fn test_int_mask_fill_u32() {
        let t = FlexTensor::from_data(TensorData::new(vec![100u32, 200, 300], [3]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true], [3]));
        let result = Flex::int_mask_fill(t, mask, Scalar::from(0i64));
        let data: Vec<u32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![0, 200, 0]);
    }

    #[test]
    fn test_int_mask_where_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i32, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
        let v = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30, 40], [4]));
        let result = Flex::int_mask_where(t, mask, v);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, 2, 30, 4]);
    }

    #[test]
    fn test_int_mask_where_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![false, true, false, true], [4]));
        let v = FlexTensor::from_data(TensorData::new(vec![10u8, 20, 30, 40], [4]));
        let result = Flex::int_mask_where(t, mask, v);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, 20, 3, 40]);
    }

    // -- gather / select / scatter --------------------------------------

    #[test]
    fn test_int_gather_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30, 40, 50, 60], [2, 3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![2i64, 0, 1, 2], [2, 2]));
        let result = Flex::int_gather(1, t, indices);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![30, 10, 50, 60]);
    }

    #[test]
    fn test_int_select_u16() {
        let t = FlexTensor::from_data(TensorData::new(vec![10u16, 20, 30, 40, 50, 60], [2, 3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![0i64, 1], [2]));
        let result = Flex::int_select(t, 1, indices);
        let data: Vec<u16> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, 20, 40, 50]);
    }

    #[test]
    fn test_int_scatter_add_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![0i32, 0, 0], [1, 3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![0i64, 2, 1], [1, 3]));
        let values = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30], [1, 3]));
        let result = Flex::int_scatter_add(1, t, indices, values);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, 30, 20]);
    }

    #[test]
    fn test_int_select_add_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3], [3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![0i64, 2], [2]));
        let values = FlexTensor::from_data(TensorData::new(vec![10u8, 20], [2]));
        let result = Flex::int_select_add(t, 0, indices, values);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![11, 2, 23]);
    }

    // -- cumulative scans -----------------------------------------------

    #[test]
    fn test_int_cumsum_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i32, 2, 3, 4], [4]));
        let result = Flex::int_cumsum(t, 0);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, 3, 6, 10]);
    }

    #[test]
    fn test_int_cumprod_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3, 4], [4]));
        let result = Flex::int_cumprod(t, 0);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, 2, 6, 24]);
    }

    #[test]
    fn test_int_cummin_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![3i32, 1, 4, 1, 5], [5]));
        let result = Flex::int_cummin(t, 0);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![3, 1, 1, 1, 1]);
    }

    #[test]
    fn test_int_cummax_u16() {
        let t = FlexTensor::from_data(TensorData::new(vec![3u16, 1, 4, 1, 5], [5]));
        let result = Flex::int_cummax(t, 0);
        let data: Vec<u16> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![3, 3, 4, 4, 5]);
    }

    // -- random and mean ------------------------------------------------

    #[test]
    fn test_int_random_i32() {
        let shape = Shape::from(vec![100]);
        let dist = Distribution::Uniform(0.0, 10.0);
        let device = crate::FlexDevice;
        let t = Flex::int_random(shape, dist, &device, IntDType::I32);
        assert_eq!(t.dtype(), DType::I32);
        let data: Vec<i32> = t.into_data().to_vec().unwrap();
        assert!(data.iter().all(|&v| (0..=10).contains(&v)));
    }

    #[test]
    fn test_int_random_u8() {
        let shape = Shape::from(vec![50]);
        let dist = Distribution::Uniform(0.0, 100.0);
        let device = crate::FlexDevice;
        let t = Flex::int_random(shape, dist, &device, IntDType::U8);
        assert_eq!(t.dtype(), DType::U8);
    }

    #[test]
    fn test_int_mean_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30], [3]));
        let result = Flex::int_mean(t);
        assert_eq!(result.dtype(), DType::I32);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![20]);
    }
}