Skip to main content

burn_flex/ops/
float.rs

1//! Float tensor operations for the Flex backend.
2
3use alloc::vec;
4use alloc::vec::Vec;
5use burn_backend::{
6    DType, Distribution, ExecutionError, FloatDType, Scalar, TensorData, TensorMetadata,
7    ops::{FloatTensorOps, GridSampleOptions, IntTensorOps},
8    tensor::{BoolTensor, Device, FloatTensor, IntTensor},
9};
10use burn_std::{Bytes, IntDType, Shape, Slice, bf16, f16};
11#[cfg(not(feature = "std"))]
12#[allow(unused_imports)]
13use num_traits::Float;
14
15use crate::Layout;
16use num_traits::ToPrimitive;
17
18use crate::ops::binary::{BinaryOp, binary_op, scalar_op};
19use crate::ops::matmul;
20use crate::ops::unary;
21use crate::{Flex, FlexTensor};
22
23impl FloatTensorOps<Flex> for Flex {
    // Materialize a tensor from host `TensorData`; the device argument is
    // ignored (CPU backend — see `float_device`).
    fn float_from_data(data: TensorData, _device: &Device<Flex>) -> FloatTensor<Flex> {
        FlexTensor::from_data(data)
    }
27
    // Sample a tensor of `shape` from `distribution` in the requested float dtype.
    //
    // The RNG is taken out of the global SEED slot for the duration of the call
    // and stored back afterwards, so successive calls continue one stream.
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        _device: &Device<Flex>,
        dtype: FloatDType,
    ) -> FloatTensor<Flex> {
        // NOTE(review): `unwrap` panics if the SEED mutex was poisoned by a
        // panicking thread — presumably acceptable for a global RNG; confirm.
        let mut seed = crate::backend::SEED.lock().unwrap();
        // First use (slot empty) falls back to a freshly seeded RNG.
        let mut rng = seed.take().unwrap_or_else(crate::backend::get_seeded_rng);
        let data = match dtype {
            FloatDType::F64 => TensorData::random::<f64, _, _>(shape, distribution, &mut rng),
            // Flex32 is sampled with f32 precision.
            FloatDType::F32 | FloatDType::Flex32 => {
                TensorData::random::<f32, _, _>(shape, distribution, &mut rng)
            }
            FloatDType::F16 => TensorData::random::<f16, _, _>(shape, distribution, &mut rng),
            FloatDType::BF16 => TensorData::random::<bf16, _, _>(shape, distribution, &mut rng),
        };
        // Put the RNG back so the next caller resumes the same sequence.
        *seed = Some(rng);
        FlexTensor::from_data(data)
    }
47
    // Read back the tensor contents; synchronous on CPU, so this never fails.
    async fn float_into_data(tensor: FloatTensor<Flex>) -> Result<TensorData, ExecutionError> {
        Ok(tensor.into_data())
    }

    fn float_device(_tensor: &FloatTensor<Flex>) -> Device<Flex> {
        // CPU backend: all tensors are on the default device
        Default::default()
    }

    fn float_to_device(tensor: FloatTensor<Flex>, _device: &Device<Flex>) -> FloatTensor<Flex> {
        // CPU backend: no-op, tensors are always on CPU
        tensor
    }

    // Detach is the identity for this backend.
    fn float_detach(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        tensor
    }
65
    // Convert a float tensor to the requested integer dtype.
    //
    // Elements are widened to f64 first (lossless for f32/f16/bf16), then cast
    // with `as`, which for float-to-int in Rust saturates at the target type's
    // bounds and maps NaN to 0.
    fn float_into_int(tensor: FloatTensor<Flex>, out_dtype: burn_std::IntDType) -> IntTensor<Flex> {
        // Raw storage is read linearly below, so normalize the layout first.
        let tensor = tensor.to_contiguous();
        let shape = tensor.layout().shape().clone();
        let src = tensor.dtype();
        let out_dt = DType::from(out_dtype);

        // Read source floats as f64 (lossless for f32/f16/bf16).
        // Expands to one storage-typed iteration per source dtype; `$conv`
        // is evaluated with `$x: f64` bound to each element.
        macro_rules! read_floats {
            (|$x:ident| $conv:expr) => {
                match src {
                    DType::F32 => tensor
                        .storage::<f32>()
                        .iter()
                        .map(|v| {
                            let $x = *v as f64;
                            $conv
                        })
                        .collect(),
                    DType::F64 => tensor
                        .storage::<f64>()
                        .iter()
                        .map(|v| {
                            let $x = *v;
                            $conv
                        })
                        .collect(),
                    DType::F16 => tensor
                        .storage::<f16>()
                        .iter()
                        .map(|v| {
                            // f16 -> f32 -> f64 is exact at every step.
                            let $x = f32::from(*v) as f64;
                            $conv
                        })
                        .collect(),
                    DType::BF16 => tensor
                        .storage::<bf16>()
                        .iter()
                        .map(|v| {
                            let $x = f32::from(*v) as f64;
                            $conv
                        })
                        .collect(),
                    _ => panic!("float_into_int: unsupported source dtype {:?}", src),
                }
            };
        }

        // Build a contiguous int tensor of the target element type.
        macro_rules! convert {
            ($int_ty:ty) => {{
                let data: Vec<$int_ty> = read_floats!(|x| x as $int_ty);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }};
        }

        match out_dtype {
            IntDType::I64 => convert!(i64),
            IntDType::I32 => convert!(i32),
            IntDType::I16 => convert!(i16),
            IntDType::I8 => convert!(i8),
            IntDType::U64 => convert!(u64),
            IntDType::U32 => convert!(u32),
            IntDType::U16 => convert!(u16),
            IntDType::U8 => convert!(u8),
        }
    }
131
    // Allocate a tensor of the given shape and float dtype without
    // initializing its contents.
    fn float_empty(shape: Shape, _device: &Device<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        FlexTensor::empty(shape, dtype.into())
    }

    // Elementwise arithmetic. Each op supplies matching f32 and f64 closures
    // to `binary_op`/`scalar_op`; the tensor variants also pass a `BinaryOp`
    // tag (the scalar variants pass none).
    fn float_add(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a + b, |a, b| a + b, Some(BinaryOp::Add))
    }

    fn float_add_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        // Scalar operand is always widened to f64 before dispatch.
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a + b, |a, b| a + b)
    }

    fn float_sub(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a - b, |a, b| a - b, Some(BinaryOp::Sub))
    }

    fn float_sub_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a - b, |a, b| a - b)
    }

    fn float_mul(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a * b, |a, b| a * b, Some(BinaryOp::Mul))
    }

    fn float_mul_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a * b, |a, b| a * b)
    }

    fn float_div(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a / b, |a, b| a / b, Some(BinaryOp::Div))
    }

    fn float_div_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a / b, |a, b| a / b)
    }
171
    fn float_remainder(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // Python/PyTorch-style remainder: result has same sign as divisor.
        // Rust's `%` truncates toward zero; `((a % b) + b) % b` folds the
        // result into the floored-remainder convention.
        binary_op(
            lhs,
            rhs,
            |a, b| ((a % b) + b) % b,
            |a, b| ((a % b) + b) % b,
            None,
        )
    }

    fn float_remainder_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        // Python/PyTorch-style remainder: result has same sign as divisor
        scalar_op(
            lhs,
            rhs_val,
            |a, b| ((a % b) + b) % b,
            |a, b| ((a % b) + b) % b,
        )
    }

    // Matrix multiplication, delegated to the matmul kernel module.
    fn float_matmul(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        matmul::matmul(lhs, rhs)
    }
197
198    fn float_cross(
199        lhs: FloatTensor<Flex>,
200        rhs: FloatTensor<Flex>,
201        dim: usize,
202    ) -> FloatTensor<Flex> {
203        let shape = lhs.layout().shape();
204        let ndims = shape.num_dims();
205        assert_eq!(
206            shape[dim], 3,
207            "cross product requires dimension {} to have size 3, got {}",
208            dim, shape[dim]
209        );
210
211        // Helper to create slices that select index `idx` along `dim`
212        let make_slices = |idx: usize| -> alloc::vec::Vec<Slice> {
213            (0..ndims)
214                .map(|d| {
215                    if d == dim {
216                        Slice::new(idx as isize, Some((idx + 1) as isize), 1)
217                    } else {
218                        Slice::new(0, None, 1)
219                    }
220                })
221                .collect()
222        };
223
224        // Extract components along the dimension
225        // a = [a0, a1, a2], b = [b0, b1, b2]
226        let a0 = Self::float_slice(lhs.clone(), &make_slices(0));
227        let a1 = Self::float_slice(lhs.clone(), &make_slices(1));
228        let a2 = Self::float_slice(lhs, &make_slices(2));
229
230        let b0 = Self::float_slice(rhs.clone(), &make_slices(0));
231        let b1 = Self::float_slice(rhs.clone(), &make_slices(1));
232        let b2 = Self::float_slice(rhs, &make_slices(2));
233
234        // Cross product: c = a × b
235        // c0 = a1*b2 - a2*b1
236        // c1 = a2*b0 - a0*b2
237        // c2 = a0*b1 - a1*b0
238        let c0 = Self::float_sub(
239            Self::float_mul(a1.clone(), b2.clone()),
240            Self::float_mul(a2.clone(), b1.clone()),
241        );
242        let c1 = Self::float_sub(
243            Self::float_mul(a2, b0.clone()),
244            Self::float_mul(a0.clone(), b2),
245        );
246        let c2 = Self::float_sub(Self::float_mul(a0, b1), Self::float_mul(a1, b0));
247
248        // Concatenate along the dimension
249        Self::float_cat(vec![c0, c1, c2], dim)
250    }
251
    // Elementwise reciprocal (1/x).
    fn float_recip(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::recip(tensor)
    }

    // Swap two dimensions of the tensor.
    fn float_swap_dims(tensor: FloatTensor<Flex>, dim1: usize, dim2: usize) -> FloatTensor<Flex> {
        tensor.transpose(dim1, dim2)
    }

    // Reorder dimensions according to `axes`.
    fn float_permute(tensor: FloatTensor<Flex>, axes: &[usize]) -> FloatTensor<Flex> {
        tensor.permute(axes)
    }

    // Reverse element order along the given axes.
    fn float_flip(tensor: FloatTensor<Flex>, axes: &[usize]) -> FloatTensor<Flex> {
        crate::ops::flip::flip(tensor, axes)
    }

    // Concatenate tensors along `dim`.
    fn float_cat(tensors: Vec<FloatTensor<Flex>>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::cat::cat(tensors, dim)
    }

    // Reinterpret the tensor with a new shape.
    fn float_reshape(tensor: FloatTensor<Flex>, shape: Shape) -> FloatTensor<Flex> {
        tensor.reshape(shape)
    }
275
    // Gather/scatter/select family: each method dispatches on the float dtype
    // and forwards to the generic, element-type-parameterized implementation
    // in `crate::ops::gather_scatter`.

    // Gather values along `dim` at the positions given by `indices`.
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::gather_scatter::gather::<f32>(tensor, dim, indices),
            DType::F64 => crate::ops::gather_scatter::gather::<f64>(tensor, dim, indices),
            DType::F16 => crate::ops::gather_scatter::gather::<f16>(tensor, dim, indices),
            DType::BF16 => crate::ops::gather_scatter::gather::<bf16>(tensor, dim, indices),
            _ => panic!("float_gather: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    // Accumulate `value` into `tensor` along `dim` at `indices`.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => {
                crate::ops::gather_scatter::scatter_add::<f32>(tensor, dim, indices, value)
            }
            DType::F64 => {
                crate::ops::gather_scatter::scatter_add::<f64>(tensor, dim, indices, value)
            }
            DType::F16 => {
                crate::ops::gather_scatter::scatter_add::<f16>(tensor, dim, indices, value)
            }
            DType::BF16 => {
                crate::ops::gather_scatter::scatter_add::<bf16>(tensor, dim, indices, value)
            }
            _ => panic!("float_scatter_add: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    // N-dimensional scatter with the given update reduction.
    fn float_scatter_nd(
        data: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
        values: FloatTensor<Flex>,
        reduction: burn_backend::tensor::IndexingUpdateOp,
    ) -> FloatTensor<Flex> {
        match data.dtype() {
            DType::F32 => {
                crate::ops::gather_scatter::scatter_nd::<f32>(data, indices, values, reduction)
            }
            DType::F64 => {
                crate::ops::gather_scatter::scatter_nd::<f64>(data, indices, values, reduction)
            }
            DType::F16 => {
                crate::ops::gather_scatter::scatter_nd::<f16>(data, indices, values, reduction)
            }
            DType::BF16 => {
                crate::ops::gather_scatter::scatter_nd::<bf16>(data, indices, values, reduction)
            }
            _ => panic!("float_scatter_nd: unsupported dtype {:?}", data.dtype()),
        }
    }

    // N-dimensional gather.
    fn float_gather_nd(data: FloatTensor<Flex>, indices: IntTensor<Flex>) -> FloatTensor<Flex> {
        match data.dtype() {
            DType::F32 => crate::ops::gather_scatter::gather_nd::<f32>(data, indices),
            DType::F64 => crate::ops::gather_scatter::gather_nd::<f64>(data, indices),
            DType::F16 => crate::ops::gather_scatter::gather_nd::<f16>(data, indices),
            DType::BF16 => crate::ops::gather_scatter::gather_nd::<bf16>(data, indices),
            _ => panic!("float_gather_nd: unsupported dtype {:?}", data.dtype()),
        }
    }

    // Select whole slices along `dim` by index.
    fn float_select(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::gather_scatter::select::<f32>(tensor, dim, indices),
            DType::F64 => crate::ops::gather_scatter::select::<f64>(tensor, dim, indices),
            DType::F16 => crate::ops::gather_scatter::select::<f16>(tensor, dim, indices),
            DType::BF16 => crate::ops::gather_scatter::select::<bf16>(tensor, dim, indices),
            _ => panic!("float_select: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    // Accumulate `value` into the selected slices along `dim`.
    fn float_select_add(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => {
                crate::ops::gather_scatter::select_add::<f32>(tensor, dim, indices, value)
            }
            DType::F64 => {
                crate::ops::gather_scatter::select_add::<f64>(tensor, dim, indices, value)
            }
            DType::F16 => {
                crate::ops::gather_scatter::select_add::<f16>(tensor, dim, indices, value)
            }
            DType::BF16 => {
                crate::ops::gather_scatter::select_add::<bf16>(tensor, dim, indices, value)
            }
            _ => panic!("float_select_add: unsupported dtype {:?}", tensor.dtype()),
        }
    }
382
    // Take a sub-view described by one `Slice` per dimension.
    fn float_slice(tensor: FloatTensor<Flex>, slices: &[Slice]) -> FloatTensor<Flex> {
        crate::ops::slice::slice(tensor, slices)
    }

    // Write `value` into the region of `tensor` described by `slices`.
    fn float_slice_assign(
        tensor: FloatTensor<Flex>,
        slices: &[Slice],
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        crate::ops::slice::slice_assign(tensor, slices, value)
    }
394
    // Select elements from `value` where `mask` holds, else from `tensor`,
    // dispatching to a per-dtype kernel.
    fn float_mask_where(
        tensor: FloatTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::mask::mask_where_f32(tensor, mask, value),
            DType::F64 => crate::ops::mask::mask_where_f64(tensor, mask, value),
            DType::F16 => crate::ops::mask::mask_where_f16(tensor, mask, value),
            DType::BF16 => crate::ops::mask::mask_where_bf16(tensor, mask, value),
            dtype => panic!("float_mask_where: unsupported dtype {:?}", dtype),
        }
    }

    // Overwrite masked elements with a scalar, converted to the tensor's
    // element type before dispatch (half types go through f64).
    fn float_mask_fill(
        tensor: FloatTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: Scalar,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::mask::mask_fill_f32(tensor, mask, value.to_f32().unwrap()),
            DType::F64 => crate::ops::mask::mask_fill_f64(tensor, mask, value.to_f64().unwrap()),
            DType::F16 => crate::ops::mask::mask_fill_f16(
                tensor,
                mask,
                f16::from_f64(value.to_f64().unwrap()),
            ),
            DType::BF16 => crate::ops::mask::mask_fill_bf16(
                tensor,
                mask,
                bf16::from_f64(value.to_f64().unwrap()),
            ),
            dtype => panic!("float_mask_fill: unsupported dtype {:?}", dtype),
        }
    }
430
    // Comparison family: tensor-vs-tensor variants forward both operands to
    // `crate::ops::comparison`; the `_elem` variants widen the scalar operand
    // to f64 first. All produce a bool tensor of `out_dtype`.

    fn float_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::equal(lhs, rhs, out_dtype)
    }

    fn float_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_greater(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater(lhs, rhs, out_dtype)
    }

    fn float_greater_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_greater_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater_equal(lhs, rhs, out_dtype)
    }

    fn float_greater_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater_equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_lower(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower(lhs, rhs, out_dtype)
    }

    fn float_lower_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_lower_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower_equal(lhs, rhs, out_dtype)
    }

    fn float_lower_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower_equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_not_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::not_equal(lhs, rhs, out_dtype)
    }

    fn float_not_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::not_equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }
526
    // Elementwise negation.
    fn float_neg(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::unary_op(tensor, |x: f32| -x, |x: f64| -x)
    }

    // Clamp every element into [min, max]. Bounds are captured once, in both
    // f32 and f64 form, so each dtype kernel uses the matching precision.
    fn float_clamp(tensor: FloatTensor<Flex>, min: Scalar, max: Scalar) -> FloatTensor<Flex> {
        let min32 = min.to_f32().unwrap();
        let max32 = max.to_f32().unwrap();
        let min64 = min.to_f64().unwrap();
        let max64 = max.to_f64().unwrap();
        unary::unary_op(
            tensor,
            move |x: f32| x.clamp(min32, max32),
            move |x: f64| x.clamp(min64, max64),
        )
    }

    // Lower bound only: y = max(x, min).
    fn float_clamp_min(tensor: FloatTensor<Flex>, min: Scalar) -> FloatTensor<Flex> {
        let min32 = min.to_f32().unwrap();
        let min64 = min.to_f64().unwrap();
        unary::unary_op(
            tensor,
            move |x: f32| x.max(min32),
            move |x: f64| x.max(min64),
        )
    }

    // Upper bound only: y = min(x, max).
    fn float_clamp_max(tensor: FloatTensor<Flex>, max: Scalar) -> FloatTensor<Flex> {
        let max32 = max.to_f32().unwrap();
        let max64 = max.to_f64().unwrap();
        unary::unary_op(
            tensor,
            move |x: f32| x.min(max32),
            move |x: f64| x.min(max64),
        )
    }
562
563    fn float_sign(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
564        unary::unary_op(
565            tensor,
566            |x: f32| {
567                if x.is_nan() {
568                    x
569                } else if x > 0.0 {
570                    1.0
571                } else if x < 0.0 {
572                    -1.0
573                } else {
574                    0.0
575                }
576            },
577            |x: f64| {
578                if x.is_nan() {
579                    x
580                } else if x > 0.0 {
581                    1.0
582                } else if x < 0.0 {
583                    -1.0
584                } else {
585                    0.0
586                }
587            },
588        )
589    }
590
    // Reductions: all delegate to `crate::ops::reduce` (full-tensor variants)
    // or to the `_dim` helpers (reduce along one dimension).

    fn float_mean(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::mean(tensor)
    }

    fn float_max(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::max(tensor)
    }

    fn float_max_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::max_dim(tensor, dim)
    }

    fn float_min(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::min(tensor)
    }

    fn float_min_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::min_dim(tensor, dim)
    }

    // Max along `dim` plus argmax indices; the index tensor is cast only when
    // the reduce helper's native index dtype differs from the requested one.
    fn float_max_dim_with_indices(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices_dtype: burn_std::IntDType,
    ) -> (FloatTensor<Flex>, IntTensor<Flex>) {
        let (values, indices) = crate::ops::reduce::max_dim_with_indices(tensor, dim);
        if indices.dtype() != DType::from(indices_dtype) {
            (values, Flex::int_cast(indices, indices_dtype))
        } else {
            (values, indices)
        }
    }

    // Min along `dim` plus argmin indices; same casting policy as above.
    fn float_min_dim_with_indices(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices_dtype: burn_std::IntDType,
    ) -> (FloatTensor<Flex>, IntTensor<Flex>) {
        let (values, indices) = crate::ops::reduce::min_dim_with_indices(tensor, dim);
        if indices.dtype() != DType::from(indices_dtype) {
            (values, Flex::int_cast(indices, indices_dtype))
        } else {
            (values, indices)
        }
    }

    // Boolean reductions over float input, delegated to the comparison module.
    fn float_any(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::any_float(tensor, out_dtype)
    }

    fn float_any_dim(
        tensor: FloatTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::any_float_dim(tensor, dim, out_dtype)
    }

    fn float_all(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::all_float(tensor, out_dtype)
    }

    fn float_all_dim(
        tensor: FloatTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::all_float_dim(tensor, dim, out_dtype)
    }

    fn float_sum(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::sum(tensor)
    }

    fn float_sum_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::sum_dim(tensor, dim)
    }

    fn float_mean_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::mean_dim(tensor, dim)
    }

    fn float_prod(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::prod(tensor)
    }

    fn float_prod_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::prod_dim(tensor, dim)
    }
680
    // Cumulative ops along `dim`. f32/f64 have dedicated kernels; half types
    // (f16/bf16) go through the generic `*_half` kernel, which is handed
    // to-f32/from-f32 conversion functions so the accumulation runs in f32.

    fn float_cumsum(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cumsum_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cumsum_f64(tensor, dim),
            DType::F16 => {
                crate::ops::cumulative::cumsum_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cumsum_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cumsum: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_cumprod(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cumprod_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cumprod_f64(tensor, dim),
            DType::F16 => {
                crate::ops::cumulative::cumprod_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cumprod_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cumprod: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_cummin(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cummin_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cummin_f64(tensor, dim),
            DType::F16 => {
                crate::ops::cumulative::cummin_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cummin_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cummin: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_cummax(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cummax_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cummax_f64(tensor, dim),
            DType::F16 => {
                crate::ops::cumulative::cummax_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cummax_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cummax: unsupported dtype {:?}", tensor.dtype()),
        }
    }
736
737    fn float_cast(tensor: FloatTensor<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
738        use crate::Layout;
739        use burn_std::{Bytes, bf16, f16};
740
741        let src_dtype = tensor.dtype();
742        let target_dtype = DType::from(dtype);
743
744        // No-op if already the same dtype
745        if src_dtype == target_dtype {
746            return tensor;
747        }
748
749        let tensor = tensor.to_contiguous();
750        let shape = tensor.layout().shape().clone();
751
752        // Convert to f64 intermediate, then to target
753        let f64_values: Vec<f64> = match src_dtype {
754            DType::F32 => {
755                let src: &[f32] = tensor.storage();
756                src.iter().map(|&v| v as f64).collect()
757            }
758            DType::F64 => {
759                let src: &[f64] = tensor.storage();
760                src.to_vec()
761            }
762            DType::F16 => {
763                let src: &[f16] = tensor.storage();
764                src.iter().map(|&v| v.to_f32() as f64).collect()
765            }
766            DType::BF16 => {
767                let src: &[bf16] = tensor.storage();
768                src.iter().map(|&v| v.to_f32() as f64).collect()
769            }
770            _ => panic!("float_cast: unsupported source dtype {:?}", src_dtype),
771        };
772
773        // Convert from f64 to target dtype
774        match target_dtype {
775            DType::F32 => {
776                let result: Vec<f32> = f64_values.iter().map(|&v| v as f32).collect();
777                let bytes = Bytes::from_elems(result);
778                FlexTensor::new(bytes, Layout::contiguous(shape), DType::F32)
779            }
780            DType::F64 => {
781                let bytes = Bytes::from_elems(f64_values);
782                FlexTensor::new(bytes, Layout::contiguous(shape), DType::F64)
783            }
784            DType::F16 => {
785                let result: Vec<f16> = f64_values.iter().map(|&v| f16::from_f64(v)).collect();
786                let bytes = Bytes::from_elems(result);
787                FlexTensor::new(bytes, Layout::contiguous(shape), DType::F16)
788            }
789            DType::BF16 => {
790                let result: Vec<bf16> = f64_values.iter().map(|&v| bf16::from_f64(v)).collect();
791                let bytes = Bytes::from_elems(result);
792                FlexTensor::new(bytes, Layout::contiguous(shape), DType::BF16)
793            }
794            _ => panic!("float_cast: unsupported target dtype {:?}", target_dtype),
795        }
796    }
797
798    fn float_exp(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
799        unary::exp(tensor)
800    }
801
802    fn float_log(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
803        unary::log(tensor)
804    }
805
806    fn float_log1p(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
807        unary::log1p(tensor)
808    }
809
810    fn float_powf(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
811        binary_op(lhs, rhs, |a: f32, b| a.powf(b), |a: f64, b| a.powf(b), None)
812    }
813
814    fn float_powf_scalar_impl(tensor: FloatTensor<Flex>, value: Scalar) -> FloatTensor<Flex> {
815        let exp = value.to_f64().unwrap();
816        scalar_op(tensor, exp, |a: f32, b| a.powf(b), |a: f64, b| a.powf(b))
817    }
818
819    fn float_sqrt(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
820        unary::sqrt(tensor)
821    }
822
823    fn float_abs(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
824        unary::abs(tensor)
825    }
826
    /// Element-wise cosine; delegates to [`unary::cos`].
    fn float_cos(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::cos(tensor)
    }
830
    /// Element-wise sine; delegates to [`unary::sin`].
    fn float_sin(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::sin(tensor)
    }
834
    /// Element-wise tangent; delegates to [`unary::tan`].
    fn float_tan(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::tan(tensor)
    }
838
    /// Element-wise hyperbolic cosine; delegates to [`unary::cosh`].
    fn float_cosh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::cosh(tensor)
    }
842
    /// Element-wise hyperbolic sine; delegates to [`unary::sinh`].
    fn float_sinh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::sinh(tensor)
    }
846
    /// Element-wise hyperbolic tangent; delegates to [`unary::tanh`].
    fn float_tanh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::tanh(tensor)
    }
850
    /// Element-wise arccosine; delegates to [`unary::acos`].
    fn float_acos(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::acos(tensor)
    }
854
    /// Element-wise inverse hyperbolic cosine; delegates to [`unary::acosh`].
    fn float_acosh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::acosh(tensor)
    }
858
    /// Element-wise arcsine; delegates to [`unary::asin`].
    fn float_asin(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::asin(tensor)
    }
862
    /// Element-wise inverse hyperbolic sine; delegates to [`unary::asinh`].
    fn float_asinh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::asinh(tensor)
    }
866
    /// Element-wise arctangent; delegates to [`unary::atan`].
    fn float_atan(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::atan(tensor)
    }
870
    /// Element-wise inverse hyperbolic tangent; delegates to [`unary::atanh`].
    fn float_atanh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::atanh(tensor)
    }
874
875    fn float_atan2(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
876        binary_op(
877            lhs,
878            rhs,
879            |a: f32, b| a.atan2(b),
880            |a: f64, b| a.atan2(b),
881            None,
882        )
883    }
884
    /// Element-wise rounding to an integral value; tie behavior is whatever
    /// [`unary::round`] implements.
    fn float_round(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::round(tensor)
    }
888
    /// Element-wise floor; delegates to [`unary::floor`].
    fn float_floor(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::floor(tensor)
    }
892
    /// Element-wise ceiling; delegates to [`unary::ceil`].
    fn float_ceil(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::ceil(tensor)
    }
896
    /// Element-wise truncation toward zero; delegates to [`unary::trunc`].
    fn float_trunc(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::trunc(tensor)
    }
900
    /// Element-wise error function; delegates to [`unary::erf`].
    fn float_erf(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::erf(tensor)
    }
904
905    fn float_argmax(
906        tensor: FloatTensor<Flex>,
907        dim: usize,
908        out_dtype: burn_std::IntDType,
909    ) -> IntTensor<Flex> {
910        let result = crate::ops::reduce::argmax(tensor, dim);
911        if result.dtype() != DType::from(out_dtype) {
912            Flex::int_cast(result, out_dtype)
913        } else {
914            result
915        }
916    }
917
    /// Indices of the top-k values along `dim` — not yet supported by the
    /// flex backend; always panics via `unimplemented!`.
    fn float_argtopk(
        _tensor: FloatTensor<Flex>,
        _dim: usize,
        _k: usize,
        _out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        unimplemented!("float_argtopk not implemented for flex")
    }
926
927    fn float_argmin(
928        tensor: FloatTensor<Flex>,
929        dim: usize,
930        out_dtype: burn_std::IntDType,
931    ) -> IntTensor<Flex> {
932        let result = crate::ops::reduce::argmin(tensor, dim);
933        if result.dtype() != DType::from(out_dtype) {
934            Flex::int_cast(result, out_dtype)
935        } else {
936            result
937        }
938    }
939
    /// Expands (broadcasts) `tensor` to `shape`; delegates to
    /// [`crate::ops::expand::expand`].
    fn float_expand(tensor: FloatTensor<Flex>, shape: Shape) -> FloatTensor<Flex> {
        crate::ops::expand::expand(tensor, shape)
    }
943
    /// Extracts sliding windows of length `size`, advancing `step` elements
    /// at a time along `dim`.
    fn float_unfold(
        tensor: FloatTensor<Flex>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Flex> {
        // unfold is type-agnostic: it only builds a zero-copy strided view,
        // so the same implementation serves all element types.
        crate::ops::unfold::unfold(tensor, dim, size, step)
    }
953
    /// Samples `tensor` at the 2D coordinates given by `grid`, honoring
    /// `options`; delegates to [`crate::ops::grid_sample::grid_sample_2d`].
    fn float_grid_sample_2d(
        tensor: FloatTensor<Flex>,
        grid: FloatTensor<Flex>,
        options: GridSampleOptions,
    ) -> FloatTensor<Flex> {
        crate::ops::grid_sample::grid_sample_2d(tensor, grid, options)
    }
961
    /// Zero-filled tensor of the requested float dtype.
    fn float_zeros(shape: Shape, _device: &Device<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        FlexTensor::zeros(shape, dtype.into())
    }
965
    /// All-ones tensor of the requested float dtype.
    ///
    /// NOTE(review): assumes every `FloatDType` converts to one of the four
    /// `DType`s matched below (`Flex32` presumably maps to `F32`, as the
    /// analogous match in `float_random` suggests) — otherwise this hits
    /// `unreachable!()`; confirm against the `From<FloatDType>` impl.
    fn float_ones(shape: Shape, _device: &Device<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        let dt: burn_backend::DType = dtype.into();
        match dt {
            DType::F32 => FlexTensor::filled_typed(shape, dt, 1.0f32),
            DType::F64 => FlexTensor::filled_typed(shape, dt, 1.0f64),
            DType::F16 => FlexTensor::filled_typed(shape, dt, f16::ONE),
            DType::BF16 => FlexTensor::filled_typed(shape, dt, bf16::ONE),
            _ => unreachable!(),
        }
    }
976
    /// Tensor filled with `fill_value`, stored in the requested float dtype.
    /// Half-precision targets (`F16`/`BF16`) convert the scalar through f32
    /// before narrowing.
    ///
    /// NOTE(review): like `float_ones`, this relies on every `FloatDType`
    /// converting to one of the matched `DType`s; confirm `Flex32` maps to
    /// `F32` so the `unreachable!()` arm really is unreachable.
    fn float_full(
        shape: Shape,
        fill_value: Scalar,
        _device: &Device<Flex>,
        dtype: FloatDType,
    ) -> FloatTensor<Flex> {
        let dt: burn_backend::DType = dtype.into();
        match dt {
            DType::F32 => FlexTensor::filled_typed(shape, dt, fill_value.to_f32().unwrap()),
            DType::F64 => FlexTensor::filled_typed(shape, dt, fill_value.to_f64().unwrap()),
            DType::F16 => {
                FlexTensor::filled_typed(shape, dt, f16::from_f32(fill_value.to_f32().unwrap()))
            }
            DType::BF16 => {
                FlexTensor::filled_typed(shape, dt, bf16::from_f32(fill_value.to_f32().unwrap()))
            }
            _ => unreachable!(),
        }
    }
996
997    fn float_transpose(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
998        let ndims = tensor.layout().num_dims();
999        if ndims < 2 {
1000            return tensor;
1001        }
1002        tensor.transpose(ndims - 2, ndims - 1)
1003    }
1004
    /// Repeats the tensor `times` times along `dim`; delegates to
    /// [`crate::ops::repeat_dim::repeat_dim`].
    fn float_repeat_dim(tensor: FloatTensor<Flex>, dim: usize, times: usize) -> FloatTensor<Flex> {
        crate::ops::repeat_dim::repeat_dim(tensor, dim, times)
    }
1008
    /// Sorts elements along `dim`, descending when `descending` is true.
    fn float_sort(tensor: FloatTensor<Flex>, dim: usize, descending: bool) -> FloatTensor<Flex> {
        crate::ops::sort::sort(tensor, dim, descending)
    }
1012
1013    fn float_sort_with_indices(
1014        tensor: FloatTensor<Flex>,
1015        dim: usize,
1016        descending: bool,
1017        indices_dtype: burn_std::IntDType,
1018    ) -> (FloatTensor<Flex>, IntTensor<Flex>) {
1019        let (values, indices) = crate::ops::sort::sort_with_indices(tensor, dim, descending);
1020        let indices = if indices.dtype() != DType::from(indices_dtype) {
1021            Flex::int_cast(indices, indices_dtype)
1022        } else {
1023            indices
1024        };
1025        (values, indices)
1026    }
1027
1028    fn float_argsort(
1029        tensor: FloatTensor<Flex>,
1030        dim: usize,
1031        descending: bool,
1032        out_dtype: burn_std::IntDType,
1033    ) -> IntTensor<Flex> {
1034        let indices = crate::ops::sort::argsort(tensor, dim, descending);
1035        if indices.dtype() != DType::from(out_dtype) {
1036            Flex::int_cast(indices, out_dtype)
1037        } else {
1038            indices
1039        }
1040    }
1041
    /// Element-wise power with an integer-tensor exponent: `rhs` is promoted
    /// to the float dtype of `lhs`, then `float_powf` does the work.
    fn float_powi(lhs: FloatTensor<Flex>, rhs: IntTensor<Flex>) -> FloatTensor<Flex> {
        let dtype = lhs.dtype();
        Self::float_powf(lhs, Flex::int_into_float(rhs, dtype.into()))
    }
1046
    /// Scalar integer power with fast paths for common small exponents;
    /// every other exponent falls back to the generic `powf` path.
    fn float_powi_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        match rhs.to_i64().unwrap() {
            0 => Self::float_ones(lhs.shape(), &Default::default(), lhs.dtype().into()), // x^0 == 1
            1 => lhs, // x^1 == x
            2 => Self::float_mul(lhs.clone(), lhs), // x^2 == x * x
            -1 => Self::float_recip(lhs), // x^-1 == 1 / x
            -2 => Self::float_recip(Self::float_mul(lhs.clone(), lhs)), // x^-2 == 1 / (x * x)
            _ => Self::float_powf_scalar_impl(lhs, rhs),
        }
    }
1057
    /// Scalar power entry point: exponents that are exactly integers take the
    /// `float_powi_scalar` fast paths; everything else goes through the
    /// generic `powf` implementation.
    fn float_powf_scalar(tensor: FloatTensor<Flex>, value: Scalar) -> FloatTensor<Flex> {
        if let Some(exp) = value.try_as_integer() {
            Self::float_powi_scalar(tensor, exp)
        } else {
            Self::float_powf_scalar_impl(tensor, value)
        }
    }
1065
1066    fn float_max_abs(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
1067        let abs = unary::abs(tensor);
1068        crate::ops::reduce::max(abs)
1069    }
1070
1071    fn float_max_abs_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
1072        let abs = unary::abs(tensor);
1073        crate::ops::reduce::max_dim(abs, dim)
1074    }
1075
    /// Element-wise NaN test; result is a bool tensor stored as `out_dtype`.
    fn float_is_nan(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        unary::float_predicate(tensor, out_dtype, |x: f32| x.is_nan(), |x: f64| x.is_nan())
    }
1079
    /// Element-wise infinity test (matches both +inf and -inf); result is a
    /// bool tensor stored as `out_dtype`.
    fn float_is_inf(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        unary::float_predicate(
            tensor,
            out_dtype,
            |x: f32| x.is_infinite(),
            |x: f64| x.is_infinite(),
        )
    }
1088}
1089
1090// Tests kept here exercise flex-specific behavior: direct `Flex::`
1091// backend-op calls with explicit IntDType/FloatDType to pin dtype storage
1092// selection (U8/I32/I64, F16/F64). Plain arithmetic, math, cast, cross,
1093// unfold, and random smoke tests have been dropped in favor of the
1094// equivalent coverage in burn-backend-tests, which exercises every backend.
1095// When adding new tests, keep them here only if they probe flex dtype
1096// storage or flex internals; otherwise add them to
1097// crates/burn-backend-tests/tests/tensor/float/ops/.
#[cfg(test)]
mod tests {
    // Shared imports hoisted to module level; each test previously repeated
    // its own `use` statements and fully-qualified paths.
    use burn_backend::ops::FloatTensorOps;
    use burn_backend::{DType, Distribution, FloatDType, TensorData};
    use burn_std::{IntDType, Shape};

    use crate::{Flex, FlexDevice, FlexTensor};

    #[test]
    fn test_float_into_int_i32() {
        let t = FlexTensor::from_data(TensorData::from([1.5f32, -2.7, 0.0, 255.9]));
        let result = Flex::float_into_int(t, IntDType::I32);
        assert_eq!(result.dtype(), DType::I32);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1, -2, 0, 255]);
    }

    #[test]
    fn test_float_into_int_u8() {
        let t = FlexTensor::from_data(TensorData::from([0.0f32, 1.9, 127.5, 255.0]));
        let result = Flex::float_into_int(t, IntDType::U8);
        assert_eq!(result.dtype(), DType::U8);
        let data: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![0, 1, 127, 255]);
    }

    #[test]
    fn test_float_argmax_i32_out_dtype() {
        let t = FlexTensor::from_data(TensorData::from([[1.0f32, 3.0, 2.0]]));
        let result = Flex::float_argmax(t, 1, IntDType::I32);
        assert_eq!(result.dtype(), DType::I32);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1]);
    }

    #[test]
    fn test_float_argmin_i32_out_dtype() {
        let t = FlexTensor::from_data(TensorData::from([[3.0f32, 1.0, 2.0]]));
        let result = Flex::float_argmin(t, 1, IntDType::I32);
        assert_eq!(result.dtype(), DType::I32);
        let data: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1]);
    }

    #[test]
    fn test_float_argmax_i64_out_dtype() {
        let t = FlexTensor::from_data(TensorData::from([[1.0f32, 3.0, 2.0]]));
        let result = Flex::float_argmax(t, 1, IntDType::I64);
        assert_eq!(result.dtype(), DType::I64);
        let data: Vec<i64> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1]);
    }

    #[test]
    fn test_float_max_dim_with_indices_i32() {
        let t = FlexTensor::from_data(TensorData::from([[1.0f32, 5.0], [3.0, 2.0]]));
        let (values, indices) = Flex::float_max_dim_with_indices(t, 1, IntDType::I32);
        assert_eq!(indices.dtype(), DType::I32);
        let idx: Vec<i32> = indices.into_data().to_vec().unwrap();
        assert_eq!(idx, vec![1, 0]);
        let vals: Vec<f32> = values.into_data().to_vec().unwrap();
        assert_eq!(vals, vec![5.0, 3.0]);
    }

    #[test]
    fn test_float_min_dim_with_indices_i32() {
        let t = FlexTensor::from_data(TensorData::from([[1.0f32, 5.0], [3.0, 2.0]]));
        let (values, indices) = Flex::float_min_dim_with_indices(t, 1, IntDType::I32);
        assert_eq!(indices.dtype(), DType::I32);
        let idx: Vec<i32> = indices.into_data().to_vec().unwrap();
        assert_eq!(idx, vec![0, 1]);
        let vals: Vec<f32> = values.into_data().to_vec().unwrap();
        assert_eq!(vals, vec![1.0, 2.0]);
    }

    #[test]
    fn test_float_random_f64() {
        let shape = Shape::from(vec![100]);
        let dist = Distribution::Uniform(0.0, 1.0);
        let device = FlexDevice;
        let t = Flex::float_random(shape, dist, &device, FloatDType::F64);
        assert_eq!(t.dtype(), DType::F64);
        let data: Vec<f64> = t.into_data().to_vec().unwrap();
        assert!(data.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }

    #[test]
    fn test_float_random_f16() {
        let shape = Shape::from(vec![100]);
        let dist = Distribution::Uniform(0.0, 1.0);
        let device = FlexDevice;
        let t = Flex::float_random(shape, dist, &device, FloatDType::F16);
        assert_eq!(t.dtype(), DType::F16);
    }
}