1use alloc::vec;
4use alloc::vec::Vec;
5use burn_backend::{
6 DType, Distribution, ExecutionError, FloatDType, Scalar, TensorData, TensorMetadata,
7 ops::{FloatTensorOps, GridSampleOptions, IntTensorOps},
8 tensor::{BoolTensor, Device, FloatTensor, IntTensor},
9};
10use burn_std::{Bytes, IntDType, Shape, Slice, bf16, f16};
11#[cfg(not(feature = "std"))]
12#[allow(unused_imports)]
13use num_traits::Float;
14
15use crate::Layout;
16use num_traits::ToPrimitive;
17
18use crate::ops::binary::{BinaryOp, binary_op, scalar_op};
19use crate::ops::matmul;
20use crate::ops::unary;
21use crate::{Flex, FlexTensor};
22
/// Float-tensor operations for the `Flex` backend.
///
/// Most methods either delegate to helpers in `crate::ops::*` or dispatch on the
/// runtime float dtype (F32 / F64 / F16 / BF16). Unsupported dtypes panic with a
/// message naming the operation. The backend is single-device: device parameters
/// are ignored and `float_device` returns the default device.
impl FloatTensorOps<Flex> for Flex {
    fn float_from_data(data: TensorData, _device: &Device<Flex>) -> FloatTensor<Flex> {
        FlexTensor::from_data(data)
    }

    /// Samples a tensor from `distribution`, materialized natively in `dtype`.
    ///
    /// The global RNG is taken out of the `SEED` mutex for the duration of the
    /// call and put back afterwards so that repeated calls continue one
    /// deterministic stream (seeded lazily via `get_seeded_rng` on first use).
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        _device: &Device<Flex>,
        dtype: FloatDType,
    ) -> FloatTensor<Flex> {
        let mut seed = crate::backend::SEED.lock().unwrap();
        let mut rng = seed.take().unwrap_or_else(crate::backend::get_seeded_rng);
        let data = match dtype {
            FloatDType::F64 => TensorData::random::<f64, _, _>(shape, distribution, &mut rng),
            // Flex32 is stored as plain f32 in this backend.
            FloatDType::F32 | FloatDType::Flex32 => {
                TensorData::random::<f32, _, _>(shape, distribution, &mut rng)
            }
            FloatDType::F16 => TensorData::random::<f16, _, _>(shape, distribution, &mut rng),
            FloatDType::BF16 => TensorData::random::<bf16, _, _>(shape, distribution, &mut rng),
        };
        // Return the RNG so the next call resumes the same stream.
        *seed = Some(rng);
        FlexTensor::from_data(data)
    }

    async fn float_into_data(tensor: FloatTensor<Flex>) -> Result<TensorData, ExecutionError> {
        // CPU-resident storage: no transfer can fail.
        Ok(tensor.into_data())
    }

    fn float_device(_tensor: &FloatTensor<Flex>) -> Device<Flex> {
        Default::default()
    }

    fn float_to_device(tensor: FloatTensor<Flex>, _device: &Device<Flex>) -> FloatTensor<Flex> {
        // Single-device backend: moving is a no-op.
        tensor
    }

    fn float_detach(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // No autodiff state is tracked here; detach is the identity.
        tensor
    }

    /// Converts a float tensor to the requested integer dtype.
    ///
    /// Values are widened to `f64` first, then cast with `as`, which truncates
    /// toward zero and saturates at the target type's bounds (Rust float->int
    /// `as` semantics); NaN maps to 0.
    fn float_into_int(tensor: FloatTensor<Flex>, out_dtype: burn_std::IntDType) -> IntTensor<Flex> {
        let tensor = tensor.to_contiguous();
        let shape = tensor.layout().shape().clone();
        let src = tensor.dtype();
        let out_dt = DType::from(out_dtype);

        // Reads the source storage as f64 values and applies `$conv` to each,
        // dispatching once on the source dtype.
        macro_rules! read_floats {
            (|$x:ident| $conv:expr) => {
                match src {
                    DType::F32 => tensor
                        .storage::<f32>()
                        .iter()
                        .map(|v| {
                            let $x = *v as f64;
                            $conv
                        })
                        .collect(),
                    DType::F64 => tensor
                        .storage::<f64>()
                        .iter()
                        .map(|v| {
                            let $x = *v;
                            $conv
                        })
                        .collect(),
                    DType::F16 => tensor
                        .storage::<f16>()
                        .iter()
                        .map(|v| {
                            let $x = f32::from(*v) as f64;
                            $conv
                        })
                        .collect(),
                    DType::BF16 => tensor
                        .storage::<bf16>()
                        .iter()
                        .map(|v| {
                            let $x = f32::from(*v) as f64;
                            $conv
                        })
                        .collect(),
                    _ => panic!("float_into_int: unsupported source dtype {:?}", src),
                }
            };
        }

        // Builds the output tensor for one concrete integer element type.
        macro_rules! convert {
            ($int_ty:ty) => {{
                let data: Vec<$int_ty> = read_floats!(|x| x as $int_ty);
                FlexTensor::new(Bytes::from_elems(data), Layout::contiguous(shape), out_dt)
            }};
        }

        match out_dtype {
            IntDType::I64 => convert!(i64),
            IntDType::I32 => convert!(i32),
            IntDType::I16 => convert!(i16),
            IntDType::I8 => convert!(i8),
            IntDType::U64 => convert!(u64),
            IntDType::U32 => convert!(u32),
            IntDType::U16 => convert!(u16),
            IntDType::U8 => convert!(u8),
        }
    }

    fn float_empty(shape: Shape, _device: &Device<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        FlexTensor::empty(shape, dtype.into())
    }

    // --- Element-wise arithmetic -------------------------------------------
    // Each binary op passes an f32 and an f64 closure; `binary_op`/`scalar_op`
    // pick the one matching the tensor dtype. The optional `BinaryOp` tag
    // presumably enables a fused/specialized path — confirm in ops::binary.

    fn float_add(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a + b, |a, b| a + b, Some(BinaryOp::Add))
    }

    fn float_add_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a + b, |a, b| a + b)
    }

    fn float_sub(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a - b, |a, b| a - b, Some(BinaryOp::Sub))
    }

    fn float_sub_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a - b, |a, b| a - b)
    }

    fn float_mul(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a * b, |a, b| a * b, Some(BinaryOp::Mul))
    }

    fn float_mul_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a * b, |a, b| a * b)
    }

    fn float_div(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a, b| a / b, |a, b| a / b, Some(BinaryOp::Div))
    }

    fn float_div_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(lhs, rhs_val, |a, b| a / b, |a, b| a / b)
    }

    fn float_remainder(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(
            lhs,
            rhs,
            // ((a % b) + b) % b: result takes the sign of the divisor
            // (Python-style modulo), unlike Rust's native `%`.
            |a, b| ((a % b) + b) % b,
            |a, b| ((a % b) + b) % b,
            None,
        )
    }

    fn float_remainder_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        let rhs_val = rhs.to_f64().unwrap();
        scalar_op(
            lhs,
            rhs_val,
            // Same divisor-signed modulo as `float_remainder`.
            |a, b| ((a % b) + b) % b,
            |a, b| ((a % b) + b) % b,
        )
    }

    fn float_matmul(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        matmul::matmul(lhs, rhs)
    }

    /// 3-element cross product along `dim`, built from slice/mul/sub/cat:
    /// c = (a1*b2 - a2*b1, a2*b0 - a0*b2, a0*b1 - a1*b0).
    fn float_cross(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        dim: usize,
    ) -> FloatTensor<Flex> {
        let shape = lhs.layout().shape();
        let ndims = shape.num_dims();
        assert_eq!(
            shape[dim], 3,
            "cross product requires dimension {} to have size 3, got {}",
            dim, shape[dim]
        );

        // Slice spec selecting component `idx` along `dim`, full range elsewhere.
        let make_slices = |idx: usize| -> alloc::vec::Vec<Slice> {
            (0..ndims)
                .map(|d| {
                    if d == dim {
                        Slice::new(idx as isize, Some((idx + 1) as isize), 1)
                    } else {
                        Slice::new(0, None, 1)
                    }
                })
                .collect()
        };

        let a0 = Self::float_slice(lhs.clone(), &make_slices(0));
        let a1 = Self::float_slice(lhs.clone(), &make_slices(1));
        let a2 = Self::float_slice(lhs, &make_slices(2));

        let b0 = Self::float_slice(rhs.clone(), &make_slices(0));
        let b1 = Self::float_slice(rhs.clone(), &make_slices(1));
        let b2 = Self::float_slice(rhs, &make_slices(2));

        let c0 = Self::float_sub(
            Self::float_mul(a1.clone(), b2.clone()),
            Self::float_mul(a2.clone(), b1.clone()),
        );
        let c1 = Self::float_sub(
            Self::float_mul(a2, b0.clone()),
            Self::float_mul(a0.clone(), b2),
        );
        let c2 = Self::float_sub(Self::float_mul(a0, b1), Self::float_mul(a1, b0));

        // Each ci kept size-1 along `dim`, so concatenation restores size 3.
        Self::float_cat(vec![c0, c1, c2], dim)
    }

    fn float_recip(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::recip(tensor)
    }

    // --- Shape / layout manipulation ---------------------------------------

    fn float_swap_dims(tensor: FloatTensor<Flex>, dim1: usize, dim2: usize) -> FloatTensor<Flex> {
        tensor.transpose(dim1, dim2)
    }

    fn float_permute(tensor: FloatTensor<Flex>, axes: &[usize]) -> FloatTensor<Flex> {
        tensor.permute(axes)
    }

    fn float_flip(tensor: FloatTensor<Flex>, axes: &[usize]) -> FloatTensor<Flex> {
        crate::ops::flip::flip(tensor, axes)
    }

    fn float_cat(tensors: Vec<FloatTensor<Flex>>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::cat::cat(tensors, dim)
    }

    fn float_reshape(tensor: FloatTensor<Flex>, shape: Shape) -> FloatTensor<Flex> {
        tensor.reshape(shape)
    }

    // --- Gather / scatter / select: dispatch on element type ----------------

    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::gather_scatter::gather::<f32>(tensor, dim, indices),
            DType::F64 => crate::ops::gather_scatter::gather::<f64>(tensor, dim, indices),
            DType::F16 => crate::ops::gather_scatter::gather::<f16>(tensor, dim, indices),
            DType::BF16 => crate::ops::gather_scatter::gather::<bf16>(tensor, dim, indices),
            _ => panic!("float_gather: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => {
                crate::ops::gather_scatter::scatter_add::<f32>(tensor, dim, indices, value)
            }
            DType::F64 => {
                crate::ops::gather_scatter::scatter_add::<f64>(tensor, dim, indices, value)
            }
            DType::F16 => {
                crate::ops::gather_scatter::scatter_add::<f16>(tensor, dim, indices, value)
            }
            DType::BF16 => {
                crate::ops::gather_scatter::scatter_add::<bf16>(tensor, dim, indices, value)
            }
            _ => panic!("float_scatter_add: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_scatter_nd(
        data: FloatTensor<Flex>,
        indices: IntTensor<Flex>,
        values: FloatTensor<Flex>,
        reduction: burn_backend::tensor::IndexingUpdateOp,
    ) -> FloatTensor<Flex> {
        match data.dtype() {
            DType::F32 => {
                crate::ops::gather_scatter::scatter_nd::<f32>(data, indices, values, reduction)
            }
            DType::F64 => {
                crate::ops::gather_scatter::scatter_nd::<f64>(data, indices, values, reduction)
            }
            DType::F16 => {
                crate::ops::gather_scatter::scatter_nd::<f16>(data, indices, values, reduction)
            }
            DType::BF16 => {
                crate::ops::gather_scatter::scatter_nd::<bf16>(data, indices, values, reduction)
            }
            _ => panic!("float_scatter_nd: unsupported dtype {:?}", data.dtype()),
        }
    }

    fn float_gather_nd(data: FloatTensor<Flex>, indices: IntTensor<Flex>) -> FloatTensor<Flex> {
        match data.dtype() {
            DType::F32 => crate::ops::gather_scatter::gather_nd::<f32>(data, indices),
            DType::F64 => crate::ops::gather_scatter::gather_nd::<f64>(data, indices),
            DType::F16 => crate::ops::gather_scatter::gather_nd::<f16>(data, indices),
            DType::BF16 => crate::ops::gather_scatter::gather_nd::<bf16>(data, indices),
            _ => panic!("float_gather_nd: unsupported dtype {:?}", data.dtype()),
        }
    }

    fn float_select(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::gather_scatter::select::<f32>(tensor, dim, indices),
            DType::F64 => crate::ops::gather_scatter::select::<f64>(tensor, dim, indices),
            DType::F16 => crate::ops::gather_scatter::select::<f16>(tensor, dim, indices),
            DType::BF16 => crate::ops::gather_scatter::select::<bf16>(tensor, dim, indices),
            _ => panic!("float_select: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_select_add(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices: IntTensor<Flex>,
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => {
                crate::ops::gather_scatter::select_add::<f32>(tensor, dim, indices, value)
            }
            DType::F64 => {
                crate::ops::gather_scatter::select_add::<f64>(tensor, dim, indices, value)
            }
            DType::F16 => {
                crate::ops::gather_scatter::select_add::<f16>(tensor, dim, indices, value)
            }
            DType::BF16 => {
                crate::ops::gather_scatter::select_add::<bf16>(tensor, dim, indices, value)
            }
            _ => panic!("float_select_add: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_slice(tensor: FloatTensor<Flex>, slices: &[Slice]) -> FloatTensor<Flex> {
        crate::ops::slice::slice(tensor, slices)
    }

    fn float_slice_assign(
        tensor: FloatTensor<Flex>,
        slices: &[Slice],
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        crate::ops::slice::slice_assign(tensor, slices, value)
    }

    // --- Masking ------------------------------------------------------------

    fn float_mask_where(
        tensor: FloatTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: FloatTensor<Flex>,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::mask::mask_where_f32(tensor, mask, value),
            DType::F64 => crate::ops::mask::mask_where_f64(tensor, mask, value),
            DType::F16 => crate::ops::mask::mask_where_f16(tensor, mask, value),
            DType::BF16 => crate::ops::mask::mask_where_bf16(tensor, mask, value),
            dtype => panic!("float_mask_where: unsupported dtype {:?}", dtype),
        }
    }

    fn float_mask_fill(
        tensor: FloatTensor<Flex>,
        mask: BoolTensor<Flex>,
        value: Scalar,
    ) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::mask::mask_fill_f32(tensor, mask, value.to_f32().unwrap()),
            DType::F64 => crate::ops::mask::mask_fill_f64(tensor, mask, value.to_f64().unwrap()),
            // Half types: convert the scalar through f64 to the storage type.
            DType::F16 => crate::ops::mask::mask_fill_f16(
                tensor,
                mask,
                f16::from_f64(value.to_f64().unwrap()),
            ),
            DType::BF16 => crate::ops::mask::mask_fill_bf16(
                tensor,
                mask,
                bf16::from_f64(value.to_f64().unwrap()),
            ),
            dtype => panic!("float_mask_fill: unsupported dtype {:?}", dtype),
        }
    }

    // --- Comparisons: thin wrappers over ops::comparison --------------------
    // Scalar variants widen the scalar to f64 before comparing.

    fn float_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::equal(lhs, rhs, out_dtype)
    }

    fn float_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_greater(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater(lhs, rhs, out_dtype)
    }

    fn float_greater_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_greater_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater_equal(lhs, rhs, out_dtype)
    }

    fn float_greater_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::greater_equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_lower(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower(lhs, rhs, out_dtype)
    }

    fn float_lower_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_lower_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower_equal(lhs, rhs, out_dtype)
    }

    fn float_lower_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::lower_equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    fn float_not_equal(
        lhs: FloatTensor<Flex>,
        rhs: FloatTensor<Flex>,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::not_equal(lhs, rhs, out_dtype)
    }

    fn float_not_equal_elem(
        lhs: FloatTensor<Flex>,
        rhs: Scalar,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::not_equal_elem(lhs, rhs.to_f64().unwrap(), out_dtype)
    }

    // --- Element-wise unary / clamping --------------------------------------

    fn float_neg(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::unary_op(tensor, |x: f32| -x, |x: f64| -x)
    }

    fn float_clamp(tensor: FloatTensor<Flex>, min: Scalar, max: Scalar) -> FloatTensor<Flex> {
        // Bounds converted once, outside the per-element closures.
        let min32 = min.to_f32().unwrap();
        let max32 = max.to_f32().unwrap();
        let min64 = min.to_f64().unwrap();
        let max64 = max.to_f64().unwrap();
        unary::unary_op(
            tensor,
            move |x: f32| x.clamp(min32, max32),
            move |x: f64| x.clamp(min64, max64),
        )
    }

    fn float_clamp_min(tensor: FloatTensor<Flex>, min: Scalar) -> FloatTensor<Flex> {
        let min32 = min.to_f32().unwrap();
        let min64 = min.to_f64().unwrap();
        unary::unary_op(
            tensor,
            move |x: f32| x.max(min32),
            move |x: f64| x.max(min64),
        )
    }

    fn float_clamp_max(tensor: FloatTensor<Flex>, max: Scalar) -> FloatTensor<Flex> {
        let max32 = max.to_f32().unwrap();
        let max64 = max.to_f64().unwrap();
        unary::unary_op(
            tensor,
            move |x: f32| x.min(max32),
            move |x: f64| x.min(max64),
        )
    }

    /// Element-wise sign: -1 / 0 / +1, with NaN propagated unchanged
    /// (unlike `f32::signum`, which maps NaN to NaN but +/-0 to +/-1).
    fn float_sign(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::unary_op(
            tensor,
            |x: f32| {
                if x.is_nan() {
                    x
                } else if x > 0.0 {
                    1.0
                } else if x < 0.0 {
                    -1.0
                } else {
                    0.0
                }
            },
            |x: f64| {
                if x.is_nan() {
                    x
                } else if x > 0.0 {
                    1.0
                } else if x < 0.0 {
                    -1.0
                } else {
                    0.0
                }
            },
        )
    }

    // --- Reductions ---------------------------------------------------------

    fn float_mean(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::mean(tensor)
    }

    fn float_max(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::max(tensor)
    }

    fn float_max_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::max_dim(tensor, dim)
    }

    fn float_min(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::min(tensor)
    }

    fn float_min_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::min_dim(tensor, dim)
    }

    fn float_max_dim_with_indices(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices_dtype: burn_std::IntDType,
    ) -> (FloatTensor<Flex>, IntTensor<Flex>) {
        let (values, indices) = crate::ops::reduce::max_dim_with_indices(tensor, dim);
        // The reducer picks its own index dtype; cast only if the caller asked
        // for something different.
        if indices.dtype() != DType::from(indices_dtype) {
            (values, Flex::int_cast(indices, indices_dtype))
        } else {
            (values, indices)
        }
    }

    fn float_min_dim_with_indices(
        tensor: FloatTensor<Flex>,
        dim: usize,
        indices_dtype: burn_std::IntDType,
    ) -> (FloatTensor<Flex>, IntTensor<Flex>) {
        let (values, indices) = crate::ops::reduce::min_dim_with_indices(tensor, dim);
        // Same conditional index-cast as `float_max_dim_with_indices`.
        if indices.dtype() != DType::from(indices_dtype) {
            (values, Flex::int_cast(indices, indices_dtype))
        } else {
            (values, indices)
        }
    }

    fn float_any(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::any_float(tensor, out_dtype)
    }

    fn float_any_dim(
        tensor: FloatTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::any_float_dim(tensor, dim, out_dtype)
    }

    fn float_all(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        crate::ops::comparison::all_float(tensor, out_dtype)
    }

    fn float_all_dim(
        tensor: FloatTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::BoolDType,
    ) -> BoolTensor<Flex> {
        crate::ops::comparison::all_float_dim(tensor, dim, out_dtype)
    }

    fn float_sum(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::sum(tensor)
    }

    fn float_sum_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::sum_dim(tensor, dim)
    }

    fn float_mean_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::mean_dim(tensor, dim)
    }

    fn float_prod(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        crate::ops::reduce::prod(tensor)
    }

    fn float_prod_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        crate::ops::reduce::prod_dim(tensor, dim)
    }

    // --- Cumulative ops: half types go through the f32 conversion pair ------

    fn float_cumsum(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cumsum_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cumsum_f64(tensor, dim),
            // Half variants accumulate via f32 using the given to/from converters.
            DType::F16 => {
                crate::ops::cumulative::cumsum_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cumsum_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cumsum: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_cumprod(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cumprod_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cumprod_f64(tensor, dim),
            DType::F16 => {
                crate::ops::cumulative::cumprod_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cumprod_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cumprod: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_cummin(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cummin_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cummin_f64(tensor, dim),
            DType::F16 => {
                crate::ops::cumulative::cummin_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cummin_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cummin: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    fn float_cummax(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        match tensor.dtype() {
            DType::F32 => crate::ops::cumulative::cummax_f32(tensor, dim),
            DType::F64 => crate::ops::cumulative::cummax_f64(tensor, dim),
            DType::F16 => {
                crate::ops::cumulative::cummax_half(tensor, dim, f16::to_f32, f16::from_f32)
            }
            DType::BF16 => {
                crate::ops::cumulative::cummax_half(tensor, dim, bf16::to_f32, bf16::from_f32)
            }
            _ => panic!("float_cummax: unsupported dtype {:?}", tensor.dtype()),
        }
    }

    /// Casts between float dtypes by round-tripping every element through f64.
    ///
    /// f64 can represent all f32/f16/bf16 values exactly, so the intermediate
    /// widening is lossless; precision is only lost in the final narrowing.
    /// Same-dtype casts return the input untouched (no copy).
    fn float_cast(tensor: FloatTensor<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        use crate::Layout;
        use burn_std::{Bytes, bf16, f16};

        let src_dtype = tensor.dtype();
        let target_dtype = DType::from(dtype);

        if src_dtype == target_dtype {
            return tensor;
        }

        let tensor = tensor.to_contiguous();
        let shape = tensor.layout().shape().clone();

        // Widen source elements to f64.
        let f64_values: Vec<f64> = match src_dtype {
            DType::F32 => {
                let src: &[f32] = tensor.storage();
                src.iter().map(|&v| v as f64).collect()
            }
            DType::F64 => {
                let src: &[f64] = tensor.storage();
                src.to_vec()
            }
            DType::F16 => {
                let src: &[f16] = tensor.storage();
                src.iter().map(|&v| v.to_f32() as f64).collect()
            }
            DType::BF16 => {
                let src: &[bf16] = tensor.storage();
                src.iter().map(|&v| v.to_f32() as f64).collect()
            }
            _ => panic!("float_cast: unsupported source dtype {:?}", src_dtype),
        };

        // Narrow to the target dtype and rebuild a contiguous tensor.
        match target_dtype {
            DType::F32 => {
                let result: Vec<f32> = f64_values.iter().map(|&v| v as f32).collect();
                let bytes = Bytes::from_elems(result);
                FlexTensor::new(bytes, Layout::contiguous(shape), DType::F32)
            }
            DType::F64 => {
                let bytes = Bytes::from_elems(f64_values);
                FlexTensor::new(bytes, Layout::contiguous(shape), DType::F64)
            }
            DType::F16 => {
                let result: Vec<f16> = f64_values.iter().map(|&v| f16::from_f64(v)).collect();
                let bytes = Bytes::from_elems(result);
                FlexTensor::new(bytes, Layout::contiguous(shape), DType::F16)
            }
            DType::BF16 => {
                let result: Vec<bf16> = f64_values.iter().map(|&v| bf16::from_f64(v)).collect();
                let bytes = Bytes::from_elems(result);
                FlexTensor::new(bytes, Layout::contiguous(shape), DType::BF16)
            }
            _ => panic!("float_cast: unsupported target dtype {:?}", target_dtype),
        }
    }

    // --- Transcendental / rounding ops: delegated to ops::unary --------------

    fn float_exp(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::exp(tensor)
    }

    fn float_log(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::log(tensor)
    }

    fn float_log1p(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::log1p(tensor)
    }

    fn float_powf(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        binary_op(lhs, rhs, |a: f32, b| a.powf(b), |a: f64, b| a.powf(b), None)
    }

    // Shared scalar-power kernel used by both powf_scalar and powi_scalar.
    fn float_powf_scalar_impl(tensor: FloatTensor<Flex>, value: Scalar) -> FloatTensor<Flex> {
        let exp = value.to_f64().unwrap();
        scalar_op(tensor, exp, |a: f32, b| a.powf(b), |a: f64, b| a.powf(b))
    }

    fn float_sqrt(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::sqrt(tensor)
    }

    fn float_abs(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::abs(tensor)
    }

    fn float_cos(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::cos(tensor)
    }

    fn float_sin(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::sin(tensor)
    }

    fn float_tan(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::tan(tensor)
    }

    fn float_cosh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::cosh(tensor)
    }

    fn float_sinh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::sinh(tensor)
    }

    fn float_tanh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::tanh(tensor)
    }

    fn float_acos(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::acos(tensor)
    }

    fn float_acosh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::acosh(tensor)
    }

    fn float_asin(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::asin(tensor)
    }

    fn float_asinh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::asinh(tensor)
    }

    fn float_atan(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::atan(tensor)
    }

    fn float_atanh(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::atanh(tensor)
    }

    fn float_atan2(lhs: FloatTensor<Flex>, rhs: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // lhs is y, rhs is x, matching `f64::atan2(self=y, other=x)`.
        binary_op(
            lhs,
            rhs,
            |a: f32, b| a.atan2(b),
            |a: f64, b| a.atan2(b),
            None,
        )
    }

    fn float_round(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::round(tensor)
    }

    fn float_floor(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::floor(tensor)
    }

    fn float_ceil(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::ceil(tensor)
    }

    fn float_trunc(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::trunc(tensor)
    }

    fn float_erf(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        unary::erf(tensor)
    }

    // --- Arg-reductions: cast indices to the requested int dtype if needed ---

    fn float_argmax(
        tensor: FloatTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        let result = crate::ops::reduce::argmax(tensor, dim);
        if result.dtype() != DType::from(out_dtype) {
            Flex::int_cast(result, out_dtype)
        } else {
            result
        }
    }

    fn float_argtopk(
        _tensor: FloatTensor<Flex>,
        _dim: usize,
        _k: usize,
        _out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        // Deliberately unsupported on this backend for now.
        unimplemented!("float_argtopk not implemented for flex")
    }

    fn float_argmin(
        tensor: FloatTensor<Flex>,
        dim: usize,
        out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        let result = crate::ops::reduce::argmin(tensor, dim);
        if result.dtype() != DType::from(out_dtype) {
            Flex::int_cast(result, out_dtype)
        } else {
            result
        }
    }

    fn float_expand(tensor: FloatTensor<Flex>, shape: Shape) -> FloatTensor<Flex> {
        crate::ops::expand::expand(tensor, shape)
    }

    fn float_unfold(
        tensor: FloatTensor<Flex>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Flex> {
        crate::ops::unfold::unfold(tensor, dim, size, step)
    }

    fn float_grid_sample_2d(
        tensor: FloatTensor<Flex>,
        grid: FloatTensor<Flex>,
        options: GridSampleOptions,
    ) -> FloatTensor<Flex> {
        crate::ops::grid_sample::grid_sample_2d(tensor, grid, options)
    }

    // --- Constructors --------------------------------------------------------

    fn float_zeros(shape: Shape, _device: &Device<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        FlexTensor::zeros(shape, dtype.into())
    }

    fn float_ones(shape: Shape, _device: &Device<Flex>, dtype: FloatDType) -> FloatTensor<Flex> {
        let dt: burn_backend::DType = dtype.into();
        match dt {
            DType::F32 => FlexTensor::filled_typed(shape, dt, 1.0f32),
            DType::F64 => FlexTensor::filled_typed(shape, dt, 1.0f64),
            DType::F16 => FlexTensor::filled_typed(shape, dt, f16::ONE),
            DType::BF16 => FlexTensor::filled_typed(shape, dt, bf16::ONE),
            // FloatDType -> DType can only produce the four arms above.
            _ => unreachable!(),
        }
    }

    fn float_full(
        shape: Shape,
        fill_value: Scalar,
        _device: &Device<Flex>,
        dtype: FloatDType,
    ) -> FloatTensor<Flex> {
        let dt: burn_backend::DType = dtype.into();
        match dt {
            DType::F32 => FlexTensor::filled_typed(shape, dt, fill_value.to_f32().unwrap()),
            DType::F64 => FlexTensor::filled_typed(shape, dt, fill_value.to_f64().unwrap()),
            DType::F16 => {
                FlexTensor::filled_typed(shape, dt, f16::from_f32(fill_value.to_f32().unwrap()))
            }
            DType::BF16 => {
                FlexTensor::filled_typed(shape, dt, bf16::from_f32(fill_value.to_f32().unwrap()))
            }
            _ => unreachable!(),
        }
    }

    /// Transposes the last two dimensions; 0-d/1-d tensors pass through.
    fn float_transpose(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        let ndims = tensor.layout().num_dims();
        if ndims < 2 {
            return tensor;
        }
        tensor.transpose(ndims - 2, ndims - 1)
    }

    fn float_repeat_dim(tensor: FloatTensor<Flex>, dim: usize, times: usize) -> FloatTensor<Flex> {
        crate::ops::repeat_dim::repeat_dim(tensor, dim, times)
    }

    fn float_sort(tensor: FloatTensor<Flex>, dim: usize, descending: bool) -> FloatTensor<Flex> {
        crate::ops::sort::sort(tensor, dim, descending)
    }

    fn float_sort_with_indices(
        tensor: FloatTensor<Flex>,
        dim: usize,
        descending: bool,
        indices_dtype: burn_std::IntDType,
    ) -> (FloatTensor<Flex>, IntTensor<Flex>) {
        let (values, indices) = crate::ops::sort::sort_with_indices(tensor, dim, descending);
        // Cast indices only when the caller's dtype differs from the sorter's.
        let indices = if indices.dtype() != DType::from(indices_dtype) {
            Flex::int_cast(indices, indices_dtype)
        } else {
            indices
        };
        (values, indices)
    }

    fn float_argsort(
        tensor: FloatTensor<Flex>,
        dim: usize,
        descending: bool,
        out_dtype: burn_std::IntDType,
    ) -> IntTensor<Flex> {
        let indices = crate::ops::sort::argsort(tensor, dim, descending);
        if indices.dtype() != DType::from(out_dtype) {
            Flex::int_cast(indices, out_dtype)
        } else {
            indices
        }
    }

    fn float_powi(lhs: FloatTensor<Flex>, rhs: IntTensor<Flex>) -> FloatTensor<Flex> {
        // Promote the integer exponent tensor to lhs's float dtype and reuse powf.
        let dtype = lhs.dtype();
        Self::float_powf(lhs, Flex::int_into_float(rhs, dtype.into()))
    }

    /// Integer-scalar power with fast paths for the common exponents
    /// 0, 1, 2, -1 and -2; anything else falls back to the powf kernel.
    fn float_powi_scalar(lhs: FloatTensor<Flex>, rhs: Scalar) -> FloatTensor<Flex> {
        match rhs.to_i64().unwrap() {
            0 => Self::float_ones(lhs.shape(), &Default::default(), lhs.dtype().into()),
            1 => lhs,
            2 => Self::float_mul(lhs.clone(), lhs),
            -1 => Self::float_recip(lhs),
            -2 => Self::float_recip(Self::float_mul(lhs.clone(), lhs)),
            _ => Self::float_powf_scalar_impl(lhs, rhs),
        }
    }

    fn float_powf_scalar(tensor: FloatTensor<Flex>, value: Scalar) -> FloatTensor<Flex> {
        // Route exactly-integral exponents through the powi fast paths.
        if let Some(exp) = value.try_as_integer() {
            Self::float_powi_scalar(tensor, exp)
        } else {
            Self::float_powf_scalar_impl(tensor, value)
        }
    }

    fn float_max_abs(tensor: FloatTensor<Flex>) -> FloatTensor<Flex> {
        // max(|x|) composed from abs + global max.
        let abs = unary::abs(tensor);
        crate::ops::reduce::max(abs)
    }

    fn float_max_abs_dim(tensor: FloatTensor<Flex>, dim: usize) -> FloatTensor<Flex> {
        let abs = unary::abs(tensor);
        crate::ops::reduce::max_dim(abs, dim)
    }

    fn float_is_nan(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        unary::float_predicate(tensor, out_dtype, |x: f32| x.is_nan(), |x: f64| x.is_nan())
    }

    fn float_is_inf(tensor: FloatTensor<Flex>, out_dtype: burn_std::BoolDType) -> BoolTensor<Flex> {
        unary::float_predicate(
            tensor,
            out_dtype,
            |x: f32| x.is_infinite(),
            |x: f64| x.is_infinite(),
        )
    }
}
1089
#[cfg(test)]
mod tests {
    use burn_backend::ops::FloatTensorOps;
    use burn_backend::{DType, Distribution, FloatDType, TensorData};
    use burn_std::IntDType;

    use crate::{Flex, FlexDevice, FlexTensor};

    #[test]
    fn test_float_into_int_i32() {
        let tensor = FlexTensor::from_data(TensorData::from([1.5f32, -2.7, 0.0, 255.9]));
        let result = Flex::float_into_int(tensor, IntDType::I32);
        assert_eq!(result.dtype(), DType::I32);
        // Fractional parts are dropped (-2.7 -> -2, i.e. truncation toward zero).
        let values: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(values, vec![1, -2, 0, 255]);
    }

    #[test]
    fn test_float_into_int_u8() {
        let tensor = FlexTensor::from_data(TensorData::from([0.0f32, 1.9, 127.5, 255.0]));
        let result = Flex::float_into_int(tensor, IntDType::U8);
        assert_eq!(result.dtype(), DType::U8);
        let values: Vec<u8> = result.into_data().to_vec().unwrap();
        assert_eq!(values, vec![0, 1, 127, 255]);
    }

    #[test]
    fn test_float_argmax_i32_out_dtype() {
        let tensor = FlexTensor::from_data(TensorData::from([[1.0f32, 3.0, 2.0]]));
        let result = Flex::float_argmax(tensor, 1, IntDType::I32);
        // The requested index dtype must be honoured.
        assert_eq!(result.dtype(), DType::I32);
        let indices: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(indices, vec![1]);
    }

    #[test]
    fn test_float_argmin_i32_out_dtype() {
        let tensor = FlexTensor::from_data(TensorData::from([[3.0f32, 1.0, 2.0]]));
        let result = Flex::float_argmin(tensor, 1, IntDType::I32);
        assert_eq!(result.dtype(), DType::I32);
        let indices: Vec<i32> = result.into_data().to_vec().unwrap();
        assert_eq!(indices, vec![1]);
    }

    #[test]
    fn test_float_argmax_i64_out_dtype() {
        let tensor = FlexTensor::from_data(TensorData::from([[1.0f32, 3.0, 2.0]]));
        let result = Flex::float_argmax(tensor, 1, IntDType::I64);
        assert_eq!(result.dtype(), DType::I64);
        let indices: Vec<i64> = result.into_data().to_vec().unwrap();
        assert_eq!(indices, vec![1]);
    }

    #[test]
    fn test_float_max_dim_with_indices_i32() {
        let tensor = FlexTensor::from_data(TensorData::from([[1.0f32, 5.0], [3.0, 2.0]]));
        let (values, indices) = Flex::float_max_dim_with_indices(tensor, 1, IntDType::I32);
        assert_eq!(indices.dtype(), DType::I32);
        // Row 0 peaks at column 1, row 1 at column 0.
        let idx: Vec<i32> = indices.into_data().to_vec().unwrap();
        assert_eq!(idx, vec![1, 0]);
        let vals: Vec<f32> = values.into_data().to_vec().unwrap();
        assert_eq!(vals, vec![5.0, 3.0]);
    }

    #[test]
    fn test_float_min_dim_with_indices_i32() {
        let tensor = FlexTensor::from_data(TensorData::from([[1.0f32, 5.0], [3.0, 2.0]]));
        let (values, indices) = Flex::float_min_dim_with_indices(tensor, 1, IntDType::I32);
        assert_eq!(indices.dtype(), DType::I32);
        let idx: Vec<i32> = indices.into_data().to_vec().unwrap();
        assert_eq!(idx, vec![0, 1]);
        let vals: Vec<f32> = values.into_data().to_vec().unwrap();
        assert_eq!(vals, vec![1.0, 2.0]);
    }

    #[test]
    fn test_float_random_f64() {
        let shape = burn_std::Shape::from(vec![100]);
        let distribution = Distribution::Uniform(0.0, 1.0);
        let tensor = Flex::float_random(shape, distribution, &FlexDevice, FloatDType::F64);
        assert_eq!(tensor.dtype(), DType::F64);
        // Uniform(0, 1) samples must stay inside the closed unit interval.
        let samples: Vec<f64> = tensor.into_data().to_vec().unwrap();
        assert!(samples.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }

    #[test]
    fn test_float_random_f16() {
        let shape = burn_std::Shape::from(vec![100]);
        let distribution = Distribution::Uniform(0.0, 1.0);
        let tensor = Flex::float_random(shape, distribution, &FlexDevice, FloatDType::F16);
        // Only the dtype is checked here; value ranges are covered by the
        // f64 variant above.
        assert_eq!(tensor.dtype(), DType::F16);
    }
}