// burn_ndarray/ops/tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7
8// Current crate
9use super::{
10    NdArrayMathOps, NdArrayOps,
11    matmul::{cross, matmul},
12};
13use crate::{
14    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
15};
16use crate::{NdArrayDevice, SEED, slice};
17use crate::{
18    SharedArray,
19    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
20};
21use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
22
23// Workspace crates
24use crate::rand::get_seeded_rng;
25use burn_backend::{Distribution, FloatDType, Scalar};
26use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
27
28#[cfg(not(feature = "std"))]
29#[allow(unused_imports)]
30use num_traits::Float;
31
32use libm::erf;
33
34#[cfg(feature = "std")]
35#[allow(dead_code)]
/// Rounds `x` to the nearest integer, sending exact halfway cases to the even neighbour.
///
/// This variant simply forwards to the standard library's `f64::round_ties_even`.
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}
39
40#[cfg(not(feature = "std"))]
41#[allow(dead_code)]
/// Rounds `x` to the nearest integer, sending exact halfway cases to the even neighbour.
///
/// Hand-written equivalent built only from `floor` and `round`, for builds where
/// `f64::round_ties_even` is not used.
fn round_ties_even_wrapper(x: f64) -> f64 {
    // A fractional part of exactly 0.5 is the only situation where the tie rule matters.
    let is_tie = (x - x.floor()) == 0.5;
    if !is_tie {
        return x.round();
    }

    // For x = k + 0.5, x / 2 sits a quarter away from k / 2, so rounding it and doubling
    // the result can only land on the even one of the two neighbouring integers.
    (x * 0.5).round() * 2.0
}
49
50impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
51    for NdArray<E, I, Q>
52where
53    NdArrayTensor: From<SharedArray<E>>,
54    NdArrayTensor: From<SharedArray<I>>,
55{
56    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
57        NdArrayTensor::from_data(data)
58    }
59
60    fn float_random(
61        shape: Shape,
62        distribution: Distribution,
63        device: &NdArrayDevice,
64    ) -> FloatTensor<Self> {
65        let mut seed = SEED.lock().unwrap();
66        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
67            rng_seeded.clone()
68        } else {
69            get_seeded_rng()
70        };
71        let tensor = Self::float_from_data(
72            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
73            device,
74        );
75        *seed = Some(rng);
76        tensor
77    }
78
79    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
80        Ok(tensor.into_data())
81    }
82
83    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
84        NdArrayDevice::Cpu
85    }
86
87    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
88        tensor
89    }
90
91    fn float_empty(
92        shape: Shape,
93        device: &<NdArray<E> as Backend>::Device,
94        dtype: FloatDType,
95    ) -> FloatTensor<Self> {
96        Self::float_zeros(shape, device, dtype)
97    }
98
99    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
100        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
101    }
102
103    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
104        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
105            NdArrayMathOps::add_scalar(array, rhs.elem())
106        })
107    }
108
109    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
110        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
111    }
112
113    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
114        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
115            NdArrayMathOps::sub_scalar(array, rhs.elem())
116        })
117    }
118
119    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
120        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
121    }
122
123    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
124        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
125            NdArrayMathOps::mul_scalar(array, rhs.elem())
126        })
127    }
128
129    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
130        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
131    }
132
133    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
134        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
135            NdArrayMathOps::div_scalar(array, rhs.elem())
136        })
137    }
138
139    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
140        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
141    }
142
143    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
144        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
145            NdArrayMathOps::remainder_scalar(array, rhs.elem())
146        })
147    }
148
149    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
150        execute_with_float_dtype!((lhs, rhs), matmul)
151    }
152
153    fn float_cross(
154        lhs: FloatTensor<Self>,
155        rhs: FloatTensor<Self>,
156        dim: usize,
157    ) -> FloatTensor<Self> {
158        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
159    }
160
161    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
162        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
163            NdArrayMathOps::recip(array)
164        })
165    }
166
167    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
168        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
169            NdArrayOps::swap_dims(array, dim1, dim2)
170        })
171    }
172
173    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
174        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
175            NdArrayOps::reshape(array, shape)
176        })
177    }
178
179    fn float_gather(
180        dim: usize,
181        tensor: FloatTensor<Self>,
182        indices: NdArrayTensor,
183    ) -> FloatTensor<Self> {
184        execute_with_int_dtype!(
185            indices,
186            IntElem,
187            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
188                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
189                    NdArrayOps::gather(dim, array, idx_array)
190                })
191            }
192        )
193    }
194
195    fn float_scatter_add(
196        dim: usize,
197        tensor: FloatTensor<Self>,
198        indices: NdArrayTensor,
199        value: FloatTensor<Self>,
200    ) -> FloatTensor<Self> {
201        execute_with_int_dtype!(
202            indices,
203            IntElem,
204            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
205                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
206                    dim, tensor, idx_array, value
207                ))
208            }
209        )
210    }
211
212    fn float_select(
213        tensor: FloatTensor<Self>,
214        dim: usize,
215        indices: NdArrayTensor,
216    ) -> FloatTensor<Self> {
217        execute_with_int_dtype!(
218            indices,
219            IntElem,
220            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
221                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
222                    NdArrayMathOps::select(array, dim, idx_array)
223                })
224            }
225        )
226    }
227
228    fn float_select_add(
229        tensor: FloatTensor<Self>,
230        dim: usize,
231        indices: NdArrayTensor,
232        value: FloatTensor<Self>,
233    ) -> FloatTensor<Self> {
234        execute_with_int_dtype!(
235            indices,
236            IntElem,
237            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
238                execute_with_float_dtype!((tensor, value), |tensor, value| {
239                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
240                })
241            }
242        )
243    }
244
245    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
246        slice!(tensor, slices)
247    }
248
249    fn float_slice_assign(
250        tensor: FloatTensor<Self>,
251        slices: &[burn_backend::Slice],
252        value: FloatTensor<Self>,
253    ) -> FloatTensor<Self> {
254        execute_with_float_dtype!((tensor, value), |tensor, value| {
255            NdArrayOps::slice_assign(tensor, slices, value)
256        })
257    }
258
259    fn float_mask_where(
260        tensor: FloatTensor<Self>,
261        mask: NdArrayTensor,
262        value: FloatTensor<Self>,
263    ) -> FloatTensor<Self> {
264        execute_with_float_dtype!((tensor, value), |tensor, value| {
265            NdArrayOps::mask_where(tensor, mask.bool(), value)
266        })
267    }
268
269    fn float_mask_fill(
270        tensor: FloatTensor<Self>,
271        mask: NdArrayTensor,
272        value: Scalar,
273    ) -> FloatTensor<Self> {
274        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
275            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
276        })
277    }
278
279    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
280        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
281    }
282
283    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
284        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
285            NdArrayMathOps::equal_elem(array, rhs.elem())
286        })
287    }
288
289    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
290        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
291    }
292
293    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
294        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
295            NdArrayMathOps::greater_elem(array, rhs.elem())
296        })
297    }
298
299    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
300        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
301            NdArrayMathOps::greater_equal(lhs, rhs)
302        })
303    }
304
305    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
306        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
307            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
308        })
309    }
310
311    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
312        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
313    }
314
315    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
316        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
317            NdArrayMathOps::lower_elem(array, rhs.elem())
318        })
319    }
320
321    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
322        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
323            NdArrayMathOps::lower_equal(lhs, rhs)
324        })
325    }
326
327    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
328        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
329            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
330        })
331    }
332
333    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
334        tensor
335    }
336
337    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
338        // Use view() for zero-copy on borrowed storage
339        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
340            NdArrayMathOps::mean_view(array.view())
341        })
342    }
343
344    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
345        // Use view() for zero-copy on borrowed storage
346        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
347            NdArrayMathOps::sum_view(array.view())
348        })
349    }
350
351    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
352        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
353            NdArrayMathOps::mean_dim(array, dim)
354        })
355    }
356
357    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
358        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
359            NdArrayMathOps::cumsum(array, dim)
360        })
361    }
362
363    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
364        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
365            NdArrayMathOps::cumprod(array, dim)
366        })
367    }
368
369    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
370        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
371            NdArrayMathOps::cummin(array, dim)
372        })
373    }
374
375    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
376        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
377            NdArrayMathOps::cummax(array, dim)
378        })
379    }
380
381    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
382        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
383            NdArrayMathOps::sum_dim(array, dim)
384        })
385    }
386
387    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
388        // Use view() for zero-copy on borrowed storage
389        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
390            NdArrayMathOps::argmax_view::<I>(array.view(), dim)
391        })
392    }
393
394    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
395        // Use view() for zero-copy on borrowed storage
396        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
397            NdArrayMathOps::argmin_view::<I>(array.view(), dim)
398        })
399    }
400
401    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
403            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
404        })
405    }
406
407    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
408        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
409            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
410        })
411    }
412
413    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
414        // Use view() for zero-copy on borrowed storage
415        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
416            NdArrayMathOps::prod_view(array.view())
417        })
418    }
419
420    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
421        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
422            NdArrayMathOps::prod_dim(array, dim)
423        })
424    }
425
426    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
427        // Use view() for zero-copy on borrowed storage
428        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
429            NdArrayMathOps::max_view(array.view())
430        })
431    }
432
433    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
434        // Use view() for zero-copy on borrowed storage
435        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
436            NdArrayMathOps::min_view(array.view())
437        })
438    }
439
440    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
441        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
442            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
443        })
444    }
445
446    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
447        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
448            array
449                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
450                .into_shared()
451        })
452    }
453
454    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
456            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
457        })
458    }
459
460    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
461        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
462            NdArrayMathOps::abs(array)
463        })
464    }
465
466    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
467        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
468            array
469                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
470                .into_shared()
471        })
472    }
473
474    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
475        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
476            array
477                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
478                .into_shared()
479        })
480    }
481
482    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
483        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
484            array
485                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
486                .into_shared()
487        })
488    }
489
490    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
491        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
492            array
493                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
494                .into_shared()
495        })
496    }
497
498    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
500            array
501                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
502                .into_shared()
503        })
504    }
505
506    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
507        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
508            array
509                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
510                .into_shared()
511        })
512    }
513
514    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
515        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
516            array
517                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
518                .into_shared()
519        })
520    }
521
522    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
523        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
524            array
525                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
526                .into_shared()
527        })
528    }
529
530    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
531        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
532            array
533                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
534                .into_shared()
535        })
536    }
537
538    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
539        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
540            array
541                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
542                .into_shared()
543        })
544    }
545
546    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
547        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
548            array
549                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
550                .into_shared()
551        })
552    }
553
554    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
555        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
556            array
557                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
558                .into_shared()
559        })
560    }
561
562    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
563        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
564            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
565        })
566    }
567
568    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
569        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
570            array
571                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
572                .into_shared()
573        })
574    }
575
576    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
577        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
578            array
579                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
580                .into_shared()
581        })
582    }
583
584    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
585        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
586            array
587                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
588                .into_shared()
589        })
590    }
591
592    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
593        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
594            array
595                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
596                .into_shared()
597        })
598    }
599
600    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
601        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
602            array
603                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
604                .into_shared()
605        })
606    }
607
608    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
609        cat_with_dtype!(tensors, dim, [F64, F32])
610    }
611
612    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
613        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
614            NdArrayMathOps::clamp_min(array, min.elem())
615        })
616    }
617
618    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
619        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
620            NdArrayMathOps::clamp_max(array, max.elem())
621        })
622    }
623
624    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
625        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
626            NdArrayMathOps::clamp(array, min.elem(), max.elem())
627        })
628    }
629
630    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
631        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
632            array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
633        })
634    }
635
636    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
637        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
638            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
639        })
640    }
641
642    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
643        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
644            NdArrayOps::permute(array, axes)
645        })
646    }
647
648    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
649        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
650            NdArrayOps::flip(array, axes)
651        })
652    }
653
654    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
655        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
656            NdArrayMathOps::sign_op(array)
657        })
658    }
659
660    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
661        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
662            NdArrayOps::expand(array, shape)
663        })
664    }
665
666    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
667        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
668            cast_to_dtype(array, dtype.into())
669        })
670    }
671
672    fn float_grid_sample_2d(
673        tensor: FloatTensor<Self>,
674        grid: FloatTensor<Self>,
675        options: GridSampleOptions,
676    ) -> FloatTensor<Self> {
677        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
678            tensor, grid, options
679        ))
680    }
681
682    fn float_unfold(
683        tensor: FloatTensor<Self>,
684        dim: usize,
685        size: usize,
686        step: usize,
687    ) -> FloatTensor<Self> {
688        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
689            NdArrayOps::unfold(array, dim, size, step)
690        })
691    }
692}