burn_flex/ops/int.rs

//! Int tensor operations for the Flex backend.

use alloc::vec::Vec;
use burn_backend::{
    DType, Distribution, ExecutionError, FloatDType, Scalar, TensorData, TensorMetadata,
    ops::IntTensorOps,
    tensor::{BoolTensor, Device, FloatTensor, IntTensor},
};
use burn_std::{Bytes, IntDType, Shape, Slice, bf16, f16};
use num_traits::ToPrimitive;

use crate::Layout;
use crate::ops::binary::{binary_op_typed, int_binary_op, int_scalar_op, scalar_op_typed};
use crate::{Flex, FlexTensor, ops::matmul};

/// Convert a `Scalar` into an `(i64, u64)` pair for the given dtype.
/// Only the matching type's conversion is validated; the other slot gets a dummy 0.
fn scalar_to_int_pair(dtype: DType, rhs: &Scalar) -> (i64, u64) {
    if dtype == DType::U64 {
        (0, rhs.to_u64().unwrap())
    } else {
        (rhs.to_i64().unwrap(), 0)
    }
}

impl IntTensorOps<Flex> for Flex {
    fn int_from_data(data: TensorData, _device: &Device<Flex>) -> IntTensor<Flex> {
        FlexTensor::from_data(data)
    }

    async fn int_into_data(tensor: IntTensor<Flex>) -> Result<TensorData, ExecutionError> {
        Ok(tensor.into_data())
    }

    fn int_device(_tensor: &IntTensor<Flex>) -> Device<Flex> {
        Default::default()
    }

    fn int_to_device(tensor: IntTensor<Flex>, _device: &Device<Flex>) -> IntTensor<Flex> {
        tensor
    }

    fn int_cat(tensors: Vec<IntTensor<Flex>>, dim: usize) -> IntTensor<Flex> {
        crate::ops::cat::cat(tensors, dim)
    }

    fn int_reshape(tensor: IntTensor<Flex>, shape: Shape) -> IntTensor<Flex> {
        tensor.reshape(shape)
    }

    fn int_slice(tensor: IntTensor<Flex>, slices: &[Slice]) -> IntTensor<Flex> {
        crate::ops::slice::slice(tensor, slices)
    }

    fn int_empty(shape: Shape, _device: &Device<Flex>, dtype: IntDType) -> IntTensor<Flex> {
        FlexTensor::empty(shape, dtype.into())
    }

    fn int_mask_where(
        tensor: IntTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        debug_assert_eq!(
            tensor.dtype(),
            value.dtype(),
            "int_mask_where: dtype mismatch"
        );
        match tensor.dtype() {
            DType::I64 => crate::ops::mask::mask_where::<i64>(tensor, mask, value),
            DType::I32 => crate::ops::mask::mask_where::<i32>(tensor, mask, value),
            DType::I16 => crate::ops::mask::mask_where::<i16>(tensor, mask, value),
            DType::I8 => crate::ops::mask::mask_where::<i8>(tensor, mask, value),
            DType::U64 => crate::ops::mask::mask_where::<u64>(tensor, mask, value),
            DType::U32 => crate::ops::mask::mask_where::<u32>(tensor, mask, value),
            DType::U16 => crate::ops::mask::mask_where::<u16>(tensor, mask, value),
            DType::U8 => crate::ops::mask::mask_where::<u8>(tensor, mask, value),
            dt => panic!("int_mask_where: unsupported dtype {:?}", dt),
        }
    }

    fn int_mask_fill(
        tensor: IntTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: Scalar,
    ) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap()),
            DType::I32 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap() as i32),
            DType::I16 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap() as i16),
            DType::I8 => crate::ops::mask::mask_fill(tensor, mask, value.to_i64().unwrap() as i8),
            DType::U64 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap()),
            DType::U32 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap() as u32),
            DType::U16 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap() as u16),
            DType::U8 => crate::ops::mask::mask_fill(tensor, mask, value.to_u64().unwrap() as u8),
            dt => panic!("int_mask_fill: unsupported dtype {:?}", dt),
        }
    }

    fn int_slice_assign(
        tensor: IntTensor<Flex>,
        slices: &[Slice],
        value: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        crate::ops::slice::slice_assign(tensor, slices, value)
    }

    /// Gather ints along `dim` at the given indices.
    ///
    /// The `tensor` dispatches on its own int dtype (I8/I16/I32/I64 signed or
    /// U8/U16/U32/U64 unsigned). The `indices` tensor may be any of those
    /// widths too - it's normalised to `isize` by the shared `read_indices`
    /// helper in `ops::gather_scatter` before the kernel runs, so callers are
    /// not required to pre-convert to I64.
    fn int_gather(
        dim: usize,
        tensor: IntTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::gather_scatter::gather::<i64>(tensor, dim, indices),
            DType::I32 => crate::ops::gather_scatter::gather::<i32>(tensor, dim, indices),
            DType::I16 => crate::ops::gather_scatter::gather::<i16>(tensor, dim, indices),
            DType::I8 => crate::ops::gather_scatter::gather::<i8>(tensor, dim, indices),
            DType::U64 => crate::ops::gather_scatter::gather::<u64>(tensor, dim, indices),
            DType::U32 => crate::ops::gather_scatter::gather::<u32>(tensor, dim, indices),
            DType::U16 => crate::ops::gather_scatter::gather::<u16>(tensor, dim, indices),
            DType::U8 => crate::ops::gather_scatter::gather::<u8>(tensor, dim, indices),
            dt => panic!("int_gather: unsupported dtype {:?}", dt),
        }
    }

    /// Scatter-add int values at the given indices along `dim`.
    ///
    /// `tensor` and `value` must share the same int dtype; `indices` may be
    /// any supported int width. See [`int_gather`](Self::int_gather) for the
    /// full index-width policy.
    fn int_scatter_add(
        dim: usize,
        tensor: IntTensor<Flex>,
        indices: IntTensor<Flex>,
        value: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        debug_assert_eq!(
            tensor.dtype(),
            value.dtype(),
            "int_scatter_add: dtype mismatch"
        );
        match tensor.dtype() {
            DType::I64 => {
                crate::ops::gather_scatter::scatter_add::<i64>(tensor, dim, indices, value)
            }
            DType::I32 => {
                crate::ops::gather_scatter::scatter_add::<i32>(tensor, dim, indices, value)
            }
            DType::I16 => {
                crate::ops::gather_scatter::scatter_add::<i16>(tensor, dim, indices, value)
            }
            DType::I8 => crate::ops::gather_scatter::scatter_add::<i8>(tensor, dim, indices, value),
            DType::U64 => {
                crate::ops::gather_scatter::scatter_add::<u64>(tensor, dim, indices, value)
            }
            DType::U32 => {
                crate::ops::gather_scatter::scatter_add::<u32>(tensor, dim, indices, value)
            }
            DType::U16 => {
                crate::ops::gather_scatter::scatter_add::<u16>(tensor, dim, indices, value)
            }
            DType::U8 => crate::ops::gather_scatter::scatter_add::<u8>(tensor, dim, indices, value),
            dt => panic!("int_scatter_add: unsupported dtype {:?}", dt),
        }
    }

    fn int_scatter_nd(
        data: IntTensor<Flex>,
        indices: IntTensor<Flex>,
        values: IntTensor<Flex>,
        reduction: burn_backend::tensor::IndexingUpdateOp,
    ) -> IntTensor<Flex> {
        match data.dtype() {
            DType::I64 => {
                crate::ops::gather_scatter::scatter_nd::<i64>(data, indices, values, reduction)
            }
            DType::I32 => {
                crate::ops::gather_scatter::scatter_nd::<i32>(data, indices, values, reduction)
            }
            DType::I16 => {
                crate::ops::gather_scatter::scatter_nd::<i16>(data, indices, values, reduction)
            }
            DType::I8 => {
                crate::ops::gather_scatter::scatter_nd::<i8>(data, indices, values, reduction)
            }
            DType::U64 => {
                crate::ops::gather_scatter::scatter_nd::<u64>(data, indices, values, reduction)
            }
            DType::U32 => {
                crate::ops::gather_scatter::scatter_nd::<u32>(data, indices, values, reduction)
            }
            DType::U16 => {
                crate::ops::gather_scatter::scatter_nd::<u16>(data, indices, values, reduction)
            }
            DType::U8 => {
                crate::ops::gather_scatter::scatter_nd::<u8>(data, indices, values, reduction)
            }
            dt => panic!("int_scatter_nd: unsupported dtype {:?}", dt),
        }
    }

    fn int_gather_nd(data: IntTensor<Flex>, indices: IntTensor<Flex>) -> IntTensor<Flex> {
        match data.dtype() {
            DType::I64 => crate::ops::gather_scatter::gather_nd::<i64>(data, indices),
            DType::I32 => crate::ops::gather_scatter::gather_nd::<i32>(data, indices),
            DType::I16 => crate::ops::gather_scatter::gather_nd::<i16>(data, indices),
            DType::I8 => crate::ops::gather_scatter::gather_nd::<i8>(data, indices),
            DType::U64 => crate::ops::gather_scatter::gather_nd::<u64>(data, indices),
            DType::U32 => crate::ops::gather_scatter::gather_nd::<u32>(data, indices),
            DType::U16 => crate::ops::gather_scatter::gather_nd::<u16>(data, indices),
            DType::U8 => crate::ops::gather_scatter::gather_nd::<u8>(data, indices),
            dt => panic!("int_gather_nd: unsupported dtype {:?}", dt),
        }
    }

    /// Select ints along `dim` by a 1D index tensor.
    ///
    /// The `indices` tensor may be any supported int width. See
    /// [`int_gather`](Self::int_gather) for the full index-width policy.
    fn int_select(
        tensor: IntTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::gather_scatter::select::<i64>(tensor, dim, indices),
            DType::I32 => crate::ops::gather_scatter::select::<i32>(tensor, dim, indices),
            DType::I16 => crate::ops::gather_scatter::select::<i16>(tensor, dim, indices),
            DType::I8 => crate::ops::gather_scatter::select::<i8>(tensor, dim, indices),
            DType::U64 => crate::ops::gather_scatter::select::<u64>(tensor, dim, indices),
            DType::U32 => crate::ops::gather_scatter::select::<u32>(tensor, dim, indices),
            DType::U16 => crate::ops::gather_scatter::select::<u16>(tensor, dim, indices),
            DType::U8 => crate::ops::gather_scatter::select::<u8>(tensor, dim, indices),
            dt => panic!("int_select: unsupported dtype {:?}", dt),
        }
    }

    /// Select-add int values at a 1D index tensor along `dim`.
    ///
    /// `tensor` and `value` must share the same int dtype; `indices` may be
    /// any supported int width. See [`int_gather`](Self::int_gather) for the
    /// full index-width policy.
    fn int_select_add(
        tensor: IntTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
        value: IntTensor<Flex>,
    ) -> IntTensor<Flex> {
        debug_assert_eq!(
            tensor.dtype(),
            value.dtype(),
            "int_select_add: dtype mismatch"
        );
        match tensor.dtype() {
            DType::I64 => {
                crate::ops::gather_scatter::select_add::<i64>(tensor, dim, indices, value)
            }
            DType::I32 => {
                crate::ops::gather_scatter::select_add::<i32>(tensor, dim, indices, value)
            }
            DType::I16 => {
                crate::ops::gather_scatter::select_add::<i16>(tensor, dim, indices, value)
            }
            DType::I8 => crate::ops::gather_scatter::select_add::<i8>(tensor, dim, indices, value),
            DType::U64 => {
                crate::ops::gather_scatter::select_add::<u64>(tensor, dim, indices, value)
            }
            DType::U32 => {
                crate::ops::gather_scatter::select_add::<u32>(tensor, dim, indices, value)
            }
            DType::U16 => {
                crate::ops::gather_scatter::select_add::<u16>(tensor, dim, indices, value)
            }
            DType::U8 => crate::ops::gather_scatter::select_add::<u8>(tensor, dim, indices, value),
            dt => panic!("int_select_add: unsupported dtype {:?}", dt),
        }
    }

    fn int_equal(
        lhs: IntTensor<Flex>,
        rhs: IntTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::int_equal(lhs, rhs, out_dtype)
    }

    fn int_equal_elem(
        lhs: IntTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
        crate::ops::comparison::int_equal_elem(lhs, i, u, out_dtype)
    }

    fn int_greater(
        lhs: IntTensor<Flex>,
        rhs: IntTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::int_greater(lhs, rhs, out_dtype)
    }

    fn int_greater_elem(
        lhs: IntTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
        crate::ops::comparison::int_greater_elem(lhs, i, u, out_dtype)
    }

    fn int_greater_equal(
        lhs: IntTensor<Flex>,
        rhs: IntTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::int_greater_equal(lhs, rhs, out_dtype)
    }

    fn int_greater_equal_elem(
        lhs: IntTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
        crate::ops::comparison::int_greater_equal_elem(lhs, i, u, out_dtype)
    }

    fn int_lower(
        lhs: IntTensor<Flex>,
        rhs: IntTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::int_lower(lhs, rhs, out_dtype)
    }

    fn int_lower_elem(
        lhs: IntTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
        crate::ops::comparison::int_lower_elem(lhs, i, u, out_dtype)
    }

    fn int_lower_equal(
        lhs: IntTensor<Flex>,
        rhs: IntTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::int_lower_equal(lhs, rhs, out_dtype)
    }

    fn int_lower_equal_elem(
        lhs: IntTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
        crate::ops::comparison::int_lower_equal_elem(lhs, i, u, out_dtype)
    }

    fn int_add(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a + b)
    }

    fn int_add_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| {
                a.wrapping_add(b)
            });
        }
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a + b)
    }

    fn int_sub(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a - b)
    }

    fn int_sub_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| {
                a.wrapping_sub(b)
            });
        }
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a - b)
    }

    fn int_mul(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a * b)
    }

    fn int_mul_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| {
                a.wrapping_mul(b)
            });
        }
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a * b)
    }

    fn int_div(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        // U64 values > i64::MAX produce wrong results through i64 cast
        if lhs.dtype() == DType::U64 {
            let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
            return binary_op_typed(lhs, &rhs, |a: u64, b: u64| a / b);
        }
        int_binary_op(lhs, rhs, |a, b| a / b)
    }

    fn int_div_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a / b);
        }
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a / b)
    }

    fn int_remainder(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        // U64 values > i64::MAX produce wrong results through i64 cast
        if lhs.dtype() == DType::U64 {
            let (lhs, rhs) = crate::ops::expand::broadcast_binary(lhs, rhs);
            return binary_op_typed(lhs, &rhs, |a: u64, b: u64| a % b);
        }
        // Python/PyTorch-style remainder: result has same sign as divisor
        int_binary_op(lhs, rhs, |a, b| ((a % b) + b) % b)
    }

    fn int_remainder_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a % b);
        }
        // Python/PyTorch-style remainder: result has same sign as divisor
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| ((a % b) + b) % b)
    }

    // Precision limits: integer magnitudes above 2^24 lose exactness in f32
    // (f16/bf16 lose it far sooner), above 2^53 in f64.
    fn int_into_float(
        tensor: IntTensor<Flex>,
        out_dtype: burn_std::FloatDType,
    ) -> FloatTensor<Flex> {
        let tensor = tensor.to_contiguous();
        let shape = tensor.layout().shape().clone();
        let src = tensor.dtype();
        let out_dt = DType::from(out_dtype);

        // Read source ints, applying conversion per-element.
        // Each arm binds `$x` to the native int value; `$conv` must work for all int types.
        macro_rules! read_ints {
            (|$x:ident| $conv:expr) => {
                match src {
                    DType::I64 => tensor.storage::<i64>().iter().map(|&$x| $conv).collect(),
                    DType::I32 => tensor.storage::<i32>().iter().map(|&$x| $conv).collect(),
                    DType::I16 => tensor.storage::<i16>().iter().map(|&$x| $conv).collect(),
                    DType::I8 => tensor.storage::<i8>().iter().map(|&$x| $conv).collect(),
                    DType::U64 => tensor.storage::<u64>().iter().map(|&$x| $conv).collect(),
                    DType::U32 => tensor.storage::<u32>().iter().map(|&$x| $conv).collect(),
                    DType::U16 => tensor.storage::<u16>().iter().map(|&$x| $conv).collect(),
                    DType::U8 => tensor.storage::<u8>().iter().map(|&$x| $conv).collect(),
                    _ => panic!("int_into_float: unsupported source dtype {:?}", src),
                }
            };
        }

        match out_dtype {
            FloatDType::F64 => {
                let data: Vec<f64> = read_ints!(|x| x as f64);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
            FloatDType::F32 | FloatDType::Flex32 => {
                let data: Vec<f32> = read_ints!(|x| x as f32);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
            FloatDType::F16 => {
                let data: Vec<f16> = read_ints!(|x| f16::from_f32(x as f32));
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
            FloatDType::BF16 => {
                let data: Vec<bf16> = read_ints!(|x| bf16::from_f32(x as f32));
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }
        }
    }

    fn int_swap_dims(tensor: IntTensor<Flex>, dim1: usize, dim2: usize) -> IntTensor<Flex> {
        tensor.transpose(dim1, dim2)
    }

    fn int_permute(tensor: IntTensor<Flex>, axes: &[usize]) -> IntTensor<Flex> {
        tensor.permute(axes)
    }

    fn int_flip(tensor: IntTensor<Flex>, axes: &[usize]) -> IntTensor<Flex> {
        crate::ops::flip::flip(tensor, axes)
    }

    fn int_random(
        shape: Shape,
        distribution: Distribution,
        _device: &Device<Flex>,
        dtype: IntDType,
    ) -> IntTensor<Flex> {
        let mut seed = crate::backend::SEED.lock().unwrap();
        let mut rng = seed.take().unwrap_or_else(crate::backend::get_seeded_rng);
        let data = match dtype {
            IntDType::I64 => TensorData::random::<i64, _, _>(shape, distribution, &mut rng),
            IntDType::I32 => TensorData::random::<i32, _, _>(shape, distribution, &mut rng),
            IntDType::I16 => TensorData::random::<i16, _, _>(shape, distribution, &mut rng),
            IntDType::I8 => TensorData::random::<i8, _, _>(shape, distribution, &mut rng),
            IntDType::U64 => TensorData::random::<u64, _, _>(shape, distribution, &mut rng),
            IntDType::U32 => TensorData::random::<u32, _, _>(shape, distribution, &mut rng),
            IntDType::U16 => TensorData::random::<u16, _, _>(shape, distribution, &mut rng),
            IntDType::U8 => TensorData::random::<u8, _, _>(shape, distribution, &mut rng),
        };
        *seed = Some(rng);
        FlexTensor::from_data(data)
    }

    fn int_expand(tensor: IntTensor<Flex>, shape: Shape) -> IntTensor<Flex> {
        crate::ops::expand::expand(tensor, shape)
    }

    fn int_matmul(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        matmul::int_matmul(lhs, rhs)
    }

    fn int_sum(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        crate::ops::reduce::sum(tensor)
    }

    fn int_sum_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        crate::ops::reduce::sum_dim(tensor, dim)
    }

    fn int_prod(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        crate::ops::reduce::prod(tensor)
    }

    fn int_prod_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        crate::ops::reduce::prod_dim(tensor, dim)
    }

    fn int_mean_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        crate::ops::reduce::mean_dim(tensor, dim)
    }

    fn int_cumsum(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cumsum::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cumsum::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cumsum::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cumsum::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cumsum::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cumsum::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cumsum::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cumsum::<u8>(tensor, dim),
            dt => panic!("int_cumsum: unsupported dtype {:?}", dt),
        }
    }

    fn int_cumprod(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cumprod::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cumprod::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cumprod::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cumprod::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cumprod::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cumprod::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cumprod::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cumprod::<u8>(tensor, dim),
            dt => panic!("int_cumprod: unsupported dtype {:?}", dt),
        }
    }

    fn int_cummin(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cummin::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cummin::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cummin::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cummin::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cummin::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cummin::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cummin::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cummin::<u8>(tensor, dim),
            dt => panic!("int_cummin: unsupported dtype {:?}", dt),
        }
    }

    fn int_cummax(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        match tensor.dtype() {
            DType::I64 => crate::ops::cumulative::cummax::<i64>(tensor, dim),
            DType::I32 => crate::ops::cumulative::cummax::<i32>(tensor, dim),
            DType::I16 => crate::ops::cumulative::cummax::<i16>(tensor, dim),
            DType::I8 => crate::ops::cumulative::cummax::<i8>(tensor, dim),
            DType::U64 => crate::ops::cumulative::cummax::<u64>(tensor, dim),
            DType::U32 => crate::ops::cumulative::cummax::<u32>(tensor, dim),
            DType::U16 => crate::ops::cumulative::cummax::<u16>(tensor, dim),
            DType::U8 => crate::ops::cumulative::cummax::<u8>(tensor, dim),
            dt => panic!("int_cummax: unsupported dtype {:?}", dt),
        }
    }

    fn int_argmax(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        crate::ops::reduce::argmax(tensor, dim)
    }

    fn int_argtopk(_tensor: IntTensor<Flex>, _dim: usize, _k: usize) -> IntTensor<Flex> {
        panic!("argtopk not implemented for flex")
    }

    fn int_argmin(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        crate::ops::reduce::argmin(tensor, dim)
    }

    fn int_abs(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        crate::ops::unary::int_abs(tensor)
    }

    fn bitwise_and(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a & b)
    }

    fn bitwise_and_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a & b);
        }
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a & b)
    }

    fn bitwise_or(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a | b)
    }

    fn bitwise_or_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a | b);
        }
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a | b)
    }

    fn bitwise_xor(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a ^ b)
    }

    fn bitwise_xor_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, rhs.to_u64().unwrap(), |a: u64, b: u64| a ^ b);
        }
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a ^ b)
    }

    fn bitwise_not(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        // Use scalar op with dummy value, only applying NOT to lhs
        int_scalar_op(tensor, 0, |a, _| !a)
    }

    // Shift amounts masked to type width via wrapping_shl/wrapping_shr.
    fn bitwise_left_shift(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a.wrapping_shl(b as u32))
    }

    fn bitwise_left_shift_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a.wrapping_shl(b as u32))
    }

    fn bitwise_right_shift(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a.wrapping_shr(b as u32))
    }

    fn bitwise_right_shift_scalar(lhs: IntTensor<Flex>, rhs: Scalar) -> IntTensor<Flex> {
        int_scalar_op(lhs, rhs.to_i64().unwrap(), |a, b| a.wrapping_shr(b as u32))
    }

    fn int_cast(tensor: IntTensor<Flex>, dtype: IntDType) -> IntTensor<Flex> {
        let target_dtype: DType = dtype.into();

        // If already the target dtype, return as-is
        if tensor.dtype() == target_dtype {
            return tensor;
        }

        // Make contiguous for easier iteration
        let tensor = tensor.to_contiguous();
        let shape = tensor.layout().shape().clone();

        // Helper macro to convert between types
        macro_rules! cast_impl {
            ($src_type:ty, $dst_type:ty, $dst_dtype:expr) => {{
                let src: &[$src_type] = tensor.storage();
                let dst: Vec<$dst_type> = src.iter().map(|&x| x as $dst_type).collect();
                FlexTensor::new(
                    Bytes::from_elems(dst),
                    Layout::contiguous(shape),
                    $dst_dtype,
                )
            }};
        }

        // Match source dtype to target dtype
        match (tensor.dtype(), target_dtype) {
            // From I64
            (DType::I64, DType::I32) => cast_impl!(i64, i32, DType::I32),
            (DType::I64, DType::I16) => cast_impl!(i64, i16, DType::I16),
            (DType::I64, DType::I8) => cast_impl!(i64, i8, DType::I8),
            (DType::I64, DType::U64) => cast_impl!(i64, u64, DType::U64),
            (DType::I64, DType::U32) => cast_impl!(i64, u32, DType::U32),
            (DType::I64, DType::U16) => cast_impl!(i64, u16, DType::U16),
            (DType::I64, DType::U8) => cast_impl!(i64, u8, DType::U8),

            // From I32
            (DType::I32, DType::I64) => cast_impl!(i32, i64, DType::I64),
            (DType::I32, DType::I16) => cast_impl!(i32, i16, DType::I16),
            (DType::I32, DType::I8) => cast_impl!(i32, i8, DType::I8),
            (DType::I32, DType::U64) => cast_impl!(i32, u64, DType::U64),
            (DType::I32, DType::U32) => cast_impl!(i32, u32, DType::U32),
            (DType::I32, DType::U16) => cast_impl!(i32, u16, DType::U16),
            (DType::I32, DType::U8) => cast_impl!(i32, u8, DType::U8),

            // From I16
            (DType::I16, DType::I64) => cast_impl!(i16, i64, DType::I64),
            (DType::I16, DType::I32) => cast_impl!(i16, i32, DType::I32),
            (DType::I16, DType::I8) => cast_impl!(i16, i8, DType::I8),
            (DType::I16, DType::U64) => cast_impl!(i16, u64, DType::U64),
            (DType::I16, DType::U32) => cast_impl!(i16, u32, DType::U32),
            (DType::I16, DType::U16) => cast_impl!(i16, u16, DType::U16),
            (DType::I16, DType::U8) => cast_impl!(i16, u8, DType::U8),

            // From I8
            (DType::I8, DType::I64) => cast_impl!(i8, i64, DType::I64),
            (DType::I8, DType::I32) => cast_impl!(i8, i32, DType::I32),
            (DType::I8, DType::I16) => cast_impl!(i8, i16, DType::I16),
            (DType::I8, DType::U64) => cast_impl!(i8, u64, DType::U64),
            (DType::I8, DType::U32) => cast_impl!(i8, u32, DType::U32),
            (DType::I8, DType::U16) => cast_impl!(i8, u16, DType::U16),
            (DType::I8, DType::U8) => cast_impl!(i8, u8, DType::U8),

            // From U64
            (DType::U64, DType::I64) => cast_impl!(u64, i64, DType::I64),
            (DType::U64, DType::I32) => cast_impl!(u64, i32, DType::I32),
            (DType::U64, DType::I16) => cast_impl!(u64, i16, DType::I16),
            (DType::U64, DType::I8) => cast_impl!(u64, i8, DType::I8),
            (DType::U64, DType::U32) => cast_impl!(u64, u32, DType::U32),
            (DType::U64, DType::U16) => cast_impl!(u64, u16, DType::U16),
            (DType::U64, DType::U8) => cast_impl!(u64, u8, DType::U8),

            // From U32
            (DType::U32, DType::I64) => cast_impl!(u32, i64, DType::I64),
            (DType::U32, DType::I32) => cast_impl!(u32, i32, DType::I32),
            (DType::U32, DType::I16) => cast_impl!(u32, i16, DType::I16),
            (DType::U32, DType::I8) => cast_impl!(u32, i8, DType::I8),
            (DType::U32, DType::U64) => cast_impl!(u32, u64, DType::U64),
            (DType::U32, DType::U16) => cast_impl!(u32, u16, DType::U16),
            (DType::U32, DType::U8) => cast_impl!(u32, u8, DType::U8),

            // From U16
            (DType::U16, DType::I64) => cast_impl!(u16, i64, DType::I64),
            (DType::U16, DType::I32) => cast_impl!(u16, i32, DType::I32),
            (DType::U16, DType::I16) => cast_impl!(u16, i16, DType::I16),
            (DType::U16, DType::I8) => cast_impl!(u16, i8, DType::I8),
            (DType::U16, DType::U64) => cast_impl!(u16, u64, DType::U64),
            (DType::U16, DType::U32) => cast_impl!(u16, u32, DType::U32),
            (DType::U16, DType::U8) => cast_impl!(u16, u8, DType::U8),

            // From U8
            (DType::U8, DType::I64) => cast_impl!(u8, i64, DType::I64),
            (DType::U8, DType::I32) => cast_impl!(u8, i32, DType::I32),
            (DType::U8, DType::I16) => cast_impl!(u8, i16, DType::I16),
            (DType::U8, DType::I8) => cast_impl!(u8, i8, DType::I8),
            (DType::U8, DType::U64) => cast_impl!(u8, u64, DType::U64),
            (DType::U8, DType::U32) => cast_impl!(u8, u32, DType::U32),
            (DType::U8, DType::U16) => cast_impl!(u8, u16, DType::U16),

            _ => panic!(
                "int_cast: unsupported conversion from {:?} to {:?}",
                tensor.dtype(),
                target_dtype
            ),
        }
    }

    fn int_unfold(
        tensor: IntTensor<Flex>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Flex> {
        crate::ops::unfold::unfold_int(tensor, dim, size, step)
    }

    fn int_neg(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        int_scalar_op(tensor, 0i64, |a, _| a.wrapping_neg())
    }

    fn int_clamp(tensor: IntTensor<Flex>, min: Scalar, max: Scalar) -> IntTensor<Flex> {
        if tensor.dtype() == DType::U64 {
            let min_val = min.to_u64().unwrap();
            let max_val = max.to_u64().unwrap();
            return scalar_op_typed(tensor, 0u64, move |x: u64, _| x.clamp(min_val, max_val));
        }
        let min_val = min.to_i64().unwrap();
        let max_val = max.to_i64().unwrap();
        int_scalar_op(tensor, 0i64, move |x, _| x.clamp(min_val, max_val))
    }

    fn int_clamp_min(tensor: IntTensor<Flex>, min: Scalar) -> IntTensor<Flex> {
        if tensor.dtype() == DType::U64 {
            let min_val = min.to_u64().unwrap();
            return scalar_op_typed(tensor, 0u64, move |x: u64, _| x.max(min_val));
        }
        let min_val = min.to_i64().unwrap();
        int_scalar_op(tensor, 0i64, move |x, _| x.max(min_val))
    }

    fn int_clamp_max(tensor: IntTensor<Flex>, max: Scalar) -> IntTensor<Flex> {
        if tensor.dtype() == DType::U64 {
            let max_val = max.to_u64().unwrap();
            return scalar_op_typed(tensor, 0u64, move |x: u64, _| x.min(max_val));
        }
        let max_val = max.to_i64().unwrap();
        int_scalar_op(tensor, 0i64, move |x, _| x.min(max_val))
    }

    fn int_sign(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        if tensor.dtype() == DType::U64 {
            return scalar_op_typed(tensor, 0u64, |x: u64, _| if x > 0 { 1 } else { 0 });
        }
        int_scalar_op(tensor, 0i64, |x, _| {
            if x > 0 {
                1
            } else if x < 0 {
                -1
            } else {
                0
            }
        })
    }

    fn int_mean(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        let n = tensor.layout().num_elements();
        assert!(n > 0, "int_mean: cannot take mean of empty tensor");
        let dtype = tensor.dtype();
        let sum_result = crate::ops::reduce::sum(tensor);
        // Compute in i64 to avoid truncation of n for small int types
        macro_rules! compute_mean {
            ($ty:ty) => {{
                let data: &[$ty] = sum_result.storage();
                let mean_val = (data[0] as i64 / n as i64) as $ty;
                FlexTensor::new(
                    Bytes::from_elems(alloc::vec![mean_val]),
                    Layout::contiguous(Shape::from(alloc::vec![1])),
                    dtype,
                )
            }};
        }
        match dtype {
            DType::I64 => compute_mean!(i64),
            DType::I32 => compute_mean!(i32),
            DType::I16 => compute_mean!(i16),
            DType::I8 => compute_mean!(i8),
            other => panic!("int_mean: unsupported dtype {:?}", other),
        }
    }

    fn int_max(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        crate::ops::reduce::max(tensor)
    }

    fn int_max_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        crate::ops::reduce::max_dim(tensor, dim)
    }

    fn int_min(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        crate::ops::reduce::min(tensor)
    }

    fn int_min_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        crate::ops::reduce::min_dim(tensor, dim)
    }

    fn int_max_dim_with_indices(
        tensor: IntTensor<Flex>,
        dim: usize,
    ) -> (IntTensor<Flex>, IntTensor<Flex>) {
        crate::ops::reduce::max_dim_with_indices(tensor, dim)
    }

    fn int_min_dim_with_indices(
        tensor: IntTensor<Flex>,
        dim: usize,
    ) -> (IntTensor<Flex>, IntTensor<Flex>) {
        crate::ops::reduce::min_dim_with_indices(tensor, dim)
    }

    fn int_any(tensor: IntTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::any_int(tensor, out_dtype)
    }

    fn int_any_dim(
        tensor: IntTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::any_int_dim(tensor, dim, out_dtype)
    }

    fn int_all(tensor: IntTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::all_int(tensor, out_dtype)
    }

    fn int_all_dim(
        tensor: IntTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::all_int_dim(tensor, dim, out_dtype)
    }

    fn int_powi(lhs: IntTensor<Flex>, rhs: IntTensor<Flex>) -> IntTensor<Flex> {
        int_binary_op(lhs, rhs, |a, b| a.wrapping_pow(b as u32))
    }

    fn int_zeros(shape: Shape, _device: &Device<Flex>, dtype: IntDType) -> IntTensor<Flex> {
        FlexTensor::zeros(shape, dtype.into())
    }

    fn int_ones(shape: Shape, _device: &Device<Flex>, dtype: IntDType) -> IntTensor<Flex> {
        let dt: DType = dtype.into();
        match dt {
            DType::I64 => FlexTensor::filled_typed(shape, dt, 1i64),
            DType::I32 => FlexTensor::filled_typed(shape, dt, 1i32),
            DType::I16 => FlexTensor::filled_typed(shape, dt, 1i16),
            DType::I8 => FlexTensor::filled_typed(shape, dt, 1i8),
            DType::U64 => FlexTensor::filled_typed(shape, dt, 1u64),
            DType::U32 => FlexTensor::filled_typed(shape, dt, 1u32),
            DType::U16 => FlexTensor::filled_typed(shape, dt, 1u16),
            DType::U8 => FlexTensor::filled_typed(shape, dt, 1u8),
            _ => unreachable!(),
        }
    }

    fn int_full(
        shape: Shape,
        fill_value: burn_backend::Scalar,
        _device: &Device<Flex>,
        dtype: IntDType,
    ) -> IntTensor<Flex> {
        let dt: DType = dtype.into();
        let v = fill_value.to_i64().unwrap();
        match dt {
            DType::I64 => FlexTensor::filled_typed(shape, dt, v),
            DType::I32 => FlexTensor::filled_typed(shape, dt, v as i32),
            DType::I16 => FlexTensor::filled_typed(shape, dt, v as i16),
            DType::I8 => FlexTensor::filled_typed(shape, dt, v as i8),
            DType::U64 => FlexTensor::filled_typed(shape, dt, v as u64),
            DType::U32 => FlexTensor::filled_typed(shape, dt, v as u32),
            DType::U16 => FlexTensor::filled_typed(shape, dt, v as u16),
            DType::U8 => FlexTensor::filled_typed(shape, dt, v as u8),
            _ => unreachable!(),
        }
    }

    fn int_transpose(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        let ndims = tensor.layout().num_dims();
        if ndims < 2 {
            return tensor;
        }
        tensor.transpose(ndims - 2, ndims - 1)
    }

    fn int_repeat_dim(tensor: IntTensor<Flex>, dim: usize, times: usize) -> IntTensor<Flex> {
        crate::ops::repeat_dim::repeat_dim(tensor, dim, times)
    }

    fn int_not_equal(
        lhs: IntTensor<Flex>,
        rhs: IntTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::int_not_equal(lhs, rhs, out_dtype)
    }

    fn int_not_equal_elem(
        lhs: IntTensor<Flex>,
        rhs: burn_backend::Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        let (i, u) = scalar_to_int_pair(lhs.dtype(), &rhs);
        crate::ops::comparison::int_not_equal_elem(lhs, i, u, out_dtype)
    }

    fn int_sort(tensor: IntTensor<Flex>, dim: usize, descending: bool) -> IntTensor<Flex> {
        crate::ops::sort::sort(tensor, dim, descending)
    }

    fn int_sort_with_indices(
        tensor: IntTensor<Flex>,
        dim: usize,
        descending: bool,
    ) -> (IntTensor<Flex>, IntTensor<Flex>) {
        crate::ops::sort::sort_with_indices(tensor, dim, descending)
    }

    fn int_argsort(tensor: IntTensor<Flex>, dim: usize, descending: bool) -> IntTensor<Flex> {
        crate::ops::sort::argsort(tensor, dim, descending)
    }

    fn int_powi_scalar(lhs: IntTensor<Flex>, rhs: burn_backend::Scalar) -> IntTensor<Flex> {
        // `ToPrimitive` is already imported at module level.
        match rhs.to_i64().unwrap() {
            0 => Self::int_ones(lhs.shape(), &Default::default(), lhs.dtype().into()),
            1 => lhs,
            2 => Self::int_mul(lhs.clone(), lhs),
            _ => Self::int_powi_scalar_impl(lhs, rhs),
        }
    }

    fn int_powi_scalar_impl(lhs: IntTensor<Flex>, rhs: burn_backend::Scalar) -> IntTensor<Flex> {
        let exp = rhs.to_i64().unwrap() as u32;
        if lhs.dtype() == DType::U64 {
            return scalar_op_typed(lhs, exp as u64, move |x: u64, _| x.wrapping_pow(exp));
        }
        int_scalar_op(lhs, exp as i64, move |x, _| x.wrapping_pow(exp))
    }

    fn int_max_abs(tensor: IntTensor<Flex>) -> IntTensor<Flex> {
        let abs = Self::int_abs(tensor);
        crate::ops::reduce::max(abs)
    }

    fn int_max_abs_dim(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
        let abs = Self::int_abs(tensor);
        crate::ops::reduce::max_dim(abs, dim)
    }

    fn int_arange(
        range: core::ops::Range<i64>,
        _device: &Device<Flex>,
        dtype: IntDType,
    ) -> IntTensor<Flex> {
        Self::int_arange_step(range, 1, &Default::default(), dtype)
    }

    fn int_arange_step(
        range: core::ops::Range<i64>,
        step: usize,
        _device: &Device<Flex>,
        dtype: IntDType,
    ) -> IntTensor<Flex> {
        let dt: DType = dtype.into();

        macro_rules! arange_typed {
            ($ty:ty) => {{
                let data: Vec<$ty> = range.step_by(step).map(|v| v as $ty).collect();
                let shape = Shape::from(alloc::vec![data.len()]);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), dt)
            }};
        }

        match dt {
            DType::I64 => arange_typed!(i64),
            DType::I32 => arange_typed!(i32),
            DType::I16 => arange_typed!(i16),
            DType::I8 => arange_typed!(i8),
            DType::U64 => arange_typed!(u64),
            DType::U32 => arange_typed!(u32),
            DType::U16 => arange_typed!(u16),
            DType::U8 => arange_typed!(u8),
            _ => unreachable!(),
        }
    }
}

// Tests kept here exercise flex-specific behavior: dtype storage
// selection for every int width (I16/I32/U8/U16/U32/I64/U64), and edge
// cases of the dtype-specific kernels (u64 wrap, i64::MIN abs/neg, bit
// shift at width). Plain int arithmetic, scalar ops, bool->int cast
// smokes, and negative-stride (flipped/transposed) variants have been
// migrated to burn-backend-tests so they run against every backend.
// When adding new tests, keep them here only if they probe flex dtype
// storage; otherwise add them to
// crates/burn-backend-tests/tests/tensor/int/ops/.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use burn_backend::TensorData;
    use burn_backend::ops::IntTensorOps;

    use crate::Flex;
    use crate::FlexTensor;

    #[test]
    fn test_u64_div_large_values() {
        let a = FlexTensor::from_data(TensorData::new(vec![u64::MAX], [1]));
        let b = FlexTensor::from_data(TensorData::new(vec![2u64], [1]));
        let result = Flex::int_div(a, b);
        let values: Vec<u64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], u64::MAX / 2);
    }

    #[test]
    fn test_u64_remainder_large_values() {
        let a = FlexTensor::from_data(TensorData::new(vec![u64::MAX], [1]));
        let b = FlexTensor::from_data(TensorData::new(vec![2u64], [1]));
        let result = Flex::int_remainder(a, b);
        let values: Vec<u64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], u64::MAX % 2);
    }
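
    // The doc comments on `int_remainder` promise Python/PyTorch-style
    // semantics (the result takes the sign of the divisor). A minimal
    // sketch of that contract; the values below are illustrative.
    #[test]
    fn test_int_remainder_sign_follows_divisor() {
        let a = FlexTensor::from_data(TensorData::new(vec![-7i64, 7], [2]));
        let b = FlexTensor::from_data(TensorData::new(vec![3i64, -3], [2]));
        let result = Flex::int_remainder(a, b);
        let values: Vec<i64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        // ((a % b) + b) % b: -7 rem 3 == 2 and 7 rem -3 == -2, unlike
        // Rust's truncating `%`, which would give -1 and 1.
        assert_eq!(values, vec![2, -2]);
    }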

    #[test]
    fn test_int_abs_min_value() {
        // i64::MIN.abs() panics in debug; wrapping_abs returns MIN (matches PyTorch)
        let a = FlexTensor::from_data(TensorData::new(vec![i64::MIN], [1]));
        let result = Flex::int_abs(a);
        let values: Vec<i64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], i64::MIN.wrapping_abs());
    }

    #[test]
    fn test_int_neg_min_value() {
        // i64::MIN negation panics in debug; wrapping_neg returns MIN (matches PyTorch)
        let a = FlexTensor::from_data(TensorData::new(vec![i64::MIN], [1]));
        let result = Flex::int_neg(a);
        let values: Vec<i64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values[0], i64::MIN.wrapping_neg());
    }
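
    // `bitwise_not` rides the scalar-op path with a dummy rhs; a small
    // sketch confirming two's-complement NOT on i64.
    #[test]
    fn test_bitwise_not_i64() {
        let t = FlexTensor::from_data(TensorData::new(vec![0i64, -1], [2]));
        let result = Flex::bitwise_not(t);
        let values: Vec<i64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
        assert_eq!(values, vec![-1, 0]);
    }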

    #[test]
    fn test_int_shift_large_amount() {
        // Shift by >= bit width panics without wrapping; should not crash
        let a = FlexTensor::from_data(TensorData::new(vec![1i64], [1]));
        let b = FlexTensor::from_data(TensorData::new(vec![64i64], [1]));
        let _left = Flex::bitwise_left_shift(a.clone(), b.clone());
        let _right = Flex::bitwise_right_shift(a, b);
    }
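
    // `int_powi_scalar` special-cases exponents 0, 1, and 2 before falling
    // back to `wrapping_pow` in `int_powi_scalar_impl`; a sketch checking
    // that the fast paths and the fallback agree on small inputs.
    #[test]
    fn test_int_powi_scalar_fast_paths() {
        let cases: [(i64, [i64; 2]); 4] = [(0, [1, 1]), (1, [3, -2]), (2, [9, 4]), (3, [27, -8])];
        for (exp, want) in cases {
            let base = FlexTensor::from_data(TensorData::new(vec![3i64, -2], [2]));
            let result = Flex::int_powi_scalar(base, burn_backend::Scalar::from(exp));
            let values: Vec<i64> = bytemuck::cast_slice(&result.into_data().bytes).to_vec();
            assert_eq!(values, want);
        }
    }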

    #[test]
    fn test_int_into_float_f64() {
        use burn_std::FloatDType;

        let t = FlexTensor::from_data(TensorData::new(vec![1i64, 2, -3], [3]));
        let result = Flex::int_into_float(t, FloatDType::F64);
        assert_eq!(result.dtype(), burn_backend::DType::F64);
        let data: Vec<f64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1.0f64, 2.0, -3.0]);
    }
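
    // `int_cast` converts element-wise with a plain `as` cast, so
    // narrowing wraps (two's-complement truncation) rather than
    // saturating; a sketch of that behaviour.
    #[test]
    fn test_int_cast_narrowing_wraps() {
        use burn_std::IntDType;

        let t = FlexTensor::from_data(TensorData::new(vec![300i64, -1], [2]));
        let result = Flex::int_cast(t, IntDType::U8);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![44, 255]); // 300 % 256 == 44; -1 as u8 == 255
    }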

    #[test]
    fn test_u64_add_scalar_large() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u64, 2, 3], [3]));
        let big: u64 = (i64::MAX as u64) + 100;
        let result = Flex::int_add_scalar(t, burn_backend::Scalar::from(big));
        let data: Vec<u64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![big + 1, big + 2, big + 3]);
    }
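
    // `int_clamp` takes a dedicated u64 path, so bounds above i64::MAX are
    // not mangled by an i64 round-trip; a sketch mirroring the large-value
    // tests above.
    #[test]
    fn test_u64_clamp_large_bounds() {
        let big: u64 = (i64::MAX as u64) + 100;
        let t = FlexTensor::from_data(TensorData::new(vec![0u64, big, u64::MAX], [3]));
        let result = Flex::int_clamp(
            t,
            burn_backend::Scalar::from(1u64),
            burn_backend::Scalar::from(big),
        );
        let data: Vec<u64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, big, big]);
    }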

    #[test]
    fn test_u64_greater_elem_large() {
        let big: u64 = (i64::MAX as u64) + 100;
        let t = FlexTensor::from_data(TensorData::new(vec![big, big + 1, big - 1], [3]));
        let result = Flex::int_greater_elem(
            t,
            burn_backend::Scalar::from(big),
            burn_std::BoolDType::Native,
        );
        let data: Vec<bool> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![false, true, false]);
    }
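
    // `int_sign` for u64 has no negative branch (results are only 0 or 1);
    // a sketch of that path.
    #[test]
    fn test_int_sign_u64() {
        let t = FlexTensor::from_data(TensorData::new(vec![0u64, 5, u64::MAX], [3]));
        let result = Flex::int_sign(t);
        let data: Vec<u64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![0, 1, 1]);
    }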

    #[test]
    fn test_int_mask_fill_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i32, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
        let result = Flex::int_mask_fill(t, mask, burn_backend::Scalar::from(0i64));
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![0, 2, 0, 4]);
    }

    #[test]
    fn test_int_mask_fill_i16() {
        let t = FlexTensor::from_data(TensorData::new(vec![10i16, 20, 30, 40], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![false, true, false, true], [4]));
        let result = Flex::int_mask_fill(t, mask, burn_backend::Scalar::from(-1i64));
        let data: Vec<i16> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, -1, 30, -1]);
    }

    #[test]
    fn test_int_mask_fill_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, true, false, false], [4]));
        let result = Flex::int_mask_fill(t, mask, burn_backend::Scalar::from(255i64));
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![255, 255, 3, 4]);
    }

    #[test]
    fn test_int_mask_fill_u32() {
        let t = FlexTensor::from_data(TensorData::new(vec![100u32, 200, 300], [3]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true], [3]));
        let result = Flex::int_mask_fill(t, mask, burn_backend::Scalar::from(0i64));
        let data: Vec<u32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![0, 200, 0]);
    }

    #[test]
    fn test_int_mask_where_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i32, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![true, false, true, false], [4]));
        let v = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30, 40], [4]));
        let result = Flex::int_mask_where(t, mask, v);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, 2, 30, 4]);
    }

    #[test]
    fn test_int_mask_where_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3, 4], [4]));
        let mask = FlexTensor::from_data(TensorData::new(vec![false, true, false, true], [4]));
        let v = FlexTensor::from_data(TensorData::new(vec![10u8, 20, 30, 40], [4]));
        let result = Flex::int_mask_where(t, mask, v);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, 20, 3, 40]);
    }

    #[test]
    fn test_int_gather_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30, 40, 50, 60], [2, 3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![2i64, 0, 1, 2], [2, 2]));
        let result = Flex::int_gather(1, t, indices);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![30, 10, 50, 60]);
    }
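
    // The `int_gather` docs state that indices of any int width are
    // normalised by the shared `read_indices` helper; a sketch using i32
    // indices (the neighbouring tests all use i64), assuming the helper
    // behaves as documented.
    #[test]
    fn test_int_gather_i32_indices() {
        let t = FlexTensor::from_data(TensorData::new(vec![10i64, 20, 30], [3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![2i32, 0], [2]));
        let result = Flex::int_gather(0, t, indices);
        let data: Vec<i64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![30, 10]);
    }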

    #[test]
    fn test_int_select_u16() {
        let t = FlexTensor::from_data(TensorData::new(vec![10u16, 20, 30, 40, 50, 60], [2, 3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![0i64, 1], [2]));
        let result = Flex::int_select(t, 1, indices);
        let data: Vec<u16> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, 20, 40, 50]);
    }

    #[test]
    fn test_int_cumsum_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i32, 2, 3, 4], [4]));
        let result = Flex::int_cumsum(t, 0);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, 3, 6, 10]);
    }

    #[test]
    fn test_int_cumprod_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3, 4], [4]));
        let result = Flex::int_cumprod(t, 0);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, 2, 6, 24]);
    }

    #[test]
    fn test_int_cummin_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![3i32, 1, 4, 1, 5], [5]));
        let result = Flex::int_cummin(t, 0);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![3, 1, 1, 1, 1]);
    }

    #[test]
    fn test_int_cummax_u16() {
        let t = FlexTensor::from_data(TensorData::new(vec![3u16, 1, 4, 1, 5], [5]));
        let result = Flex::int_cummax(t, 0);
        let data: Vec<u16> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![3, 3, 4, 4, 5]);
    }

    #[test]
    fn test_int_scatter_add_i32() {
        let t = FlexTensor::from_data(TensorData::new(vec![0i32, 0, 0], [1, 3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![0i64, 2, 1], [1, 3]));
        let values = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30], [1, 3]));
        let result = Flex::int_scatter_add(1, t, indices, values);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![10, 30, 20]);
    }

    #[test]
    fn test_int_select_add_u8() {
        let t = FlexTensor::from_data(TensorData::new(vec![1u8, 2, 3], [3]));
        let indices = FlexTensor::from_data(TensorData::new(vec![0i64, 2], [2]));
        let values = FlexTensor::from_data(TensorData::new(vec![10u8, 20], [2]));
        let result = Flex::int_select_add(t, 0, indices, values);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![11, 2, 23]);
    }
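
    // `scalar_to_int_pair` only validates the conversion matching the
    // dtype and fills the other slot with 0; a direct check of both arms
    // (the helper is private to the parent module).
    #[test]
    fn test_scalar_to_int_pair() {
        use burn_backend::{DType, Scalar};

        let (i, u) = super::scalar_to_int_pair(DType::U64, &Scalar::from(u64::MAX));
        assert_eq!((i, u), (0, u64::MAX));
        let (i, u) = super::scalar_to_int_pair(DType::I32, &Scalar::from(-5i64));
        assert_eq!((i, u), (-5, 0));
    }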

    #[test]
    fn test_int_random_i32() {
        use burn_backend::{DType, Distribution};
        use burn_std::{IntDType, Shape};

        let shape = Shape::from(vec![100]);
        let dist = Distribution::Uniform(0.0, 10.0);
        let device = crate::FlexDevice;
        let t = Flex::int_random(shape, dist, &device, IntDType::I32);
        assert_eq!(t.dtype(), DType::I32);
        let data: Vec<i32> = t.into_data().to_vec().unwrap();
        assert!(data.iter().all(|&v| (0..=10).contains(&v)));
    }

    #[test]
    fn test_int_random_u8() {
        use burn_backend::{DType, Distribution};
        use burn_std::{IntDType, Shape};

        let shape = Shape::from(vec![50]);
        let dist = Distribution::Uniform(0.0, 100.0);
        let device = crate::FlexDevice;
        let t = Flex::int_random(shape, dist, &device, IntDType::U8);
        assert_eq!(t.dtype(), DType::U8);
    }
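
    // `int_arange_step` materialises `range.step_by(step)` at the target
    // width; a sketch at I16, reusing `FlexDevice` as in the random tests
    // above.
    #[test]
    fn test_int_arange_step_i16() {
        use burn_std::IntDType;

        let t = Flex::int_arange_step(0..10, 3, &crate::FlexDevice, IntDType::I16);
        let data: Vec<i16> = t.into_data().to_vec().unwrap();
        assert_eq!(data, vec![0, 3, 6, 9]);
    }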

    #[test]
    fn test_int_mean_i32() {
        use burn_backend::DType;

        let t = FlexTensor::from_data(TensorData::new(vec![10i32, 20, 30], [3]));
        let result = Flex::int_mean(t);
        assert_eq!(result.dtype(), DType::I32);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![20]); // (10 + 20 + 30) / 3 = 20
    }
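
    // The sum-then-divide in `int_mean` uses i64 division, which truncates
    // toward zero; a sketch of a non-exact mean.
    #[test]
    fn test_int_mean_truncates() {
        let t = FlexTensor::from_data(TensorData::new(vec![1i32, 2], [2]));
        let result = Flex::int_mean(t);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1]); // 3 / 2 truncates to 1
    }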
}