// burn_ndarray/ops/tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7use burn_std::{BoolDType, IntDType};
8
9// Current crate
10use super::{
11    NdArrayMathOps, NdArrayOps,
12    matmul::{cross, matmul},
13};
14use crate::{
15    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
16};
17use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
18use crate::{
19    SharedArray,
20    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
21};
22use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
23
24// Workspace crates
25use crate::rand::get_seeded_rng;
26use burn_backend::{Distribution, FloatDType, Scalar};
27use burn_backend::{ElementConversion, Shape, TensorData, ops::FloatTensorOps};
28
29#[cfg(not(feature = "std"))]
30#[allow(unused_imports)]
31use num_traits::Float;
32
33use libm::erf;
34
35#[cfg(feature = "std")]
36#[allow(dead_code)]
/// Rounds `x` to the nearest integer, sending exact halfway cases to the even neighbour.
///
/// Thin wrapper around the standard library's `f64::round_ties_even`.
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}
40
41#[cfg(not(feature = "std"))]
42#[allow(dead_code)]
/// Fallback for builds without `std`: rounds `x` to the nearest integer, sending exact
/// halfway cases to the even neighbour.
///
/// For a value of the form `k + 0.5`, halving it yields a fractional part of `.25` or
/// `.75`, so an ordinary `round` is unambiguous there; doubling the result then lands on
/// the even integer next to `x`. Every other input goes through plain `round`.
fn round_ties_even_wrapper(x: f64) -> f64 {
    let is_exact_tie = x - x.floor() == 0.5;
    if is_exact_tie {
        let half_rounded = (x * 0.5).round();
        half_rounded * 2.0
    } else {
        x.round()
    }
}
50
/// `FloatTensorOps` implementation for the `NdArray` backend.
///
/// Most methods are thin wrappers: they invoke one of the crate's dispatch macros
/// (`execute_with_float_dtype!` and friends, defined elsewhere — presumably selecting the
/// concrete element type from the tensor's dtype; confirm against the macro definitions)
/// and forward to `NdArrayMathOps`, `NdArrayOps`, or a per-element `mapv_into` closure.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// Builds a tensor from `data` via `NdArrayTensor::from_data`; the device argument is unused.
    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
        NdArrayTensor::from_data(data)
    }

    /// Creates a tensor of the given `shape` filled with samples from `distribution`.
    ///
    /// Locks the global `SEED`, takes the stored RNG (or creates one with `get_seeded_rng`
    /// when none is stored), samples through `TensorData::random`, and then stores the RNG
    /// back into `SEED` so later calls continue from its advanced state.
    ///
    /// # Panics
    ///
    /// Panics if the `SEED` lock cannot be acquired (`lock().unwrap()`).
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &NdArrayDevice,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        let mut seed = SEED.lock().unwrap();
        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
        let tensor = execute_with_float_out_dtype!(
            dtype,
            E,
            Self::float_from_data(
                TensorData::random::<E, _, _>(shape, distribution, &mut rng),
                device,
            )
        );

        // Put the RNG back so the next call does not start from a fresh one.
        *seed = Some(rng);
        tensor
    }

    /// Converts the tensor into `TensorData`; this implementation always returns `Ok`.
    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
        Ok(tensor.into_data())
    }

    /// Always reports `NdArrayDevice::Cpu`, regardless of the tensor.
    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    /// Returns the tensor unchanged; the target device is ignored.
    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
        tensor
    }

    /// Delegates to `Self::float_zeros` with the same arguments.
    fn float_empty(shape: Shape, device: &NdArrayDevice, dtype: FloatDType) -> FloatTensor<Self> {
        Self::float_zeros(shape, device, dtype)
    }

    /// Adds two tensors via `NdArrayMathOps::add`.
    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
    }

    /// Adds the scalar `rhs` (converted with `elem()`) via `NdArrayMathOps::add_scalar`.
    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::add_scalar(array, rhs.elem())
        })
    }

    /// Subtracts `rhs` from `lhs` via `NdArrayMathOps::sub`.
    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
    }

    /// Subtracts the scalar `rhs` (converted with `elem()`) via `NdArrayMathOps::sub_scalar`.
    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sub_scalar(array, rhs.elem())
        })
    }

    /// Multiplies two tensors via `NdArrayMathOps::mul`.
    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
    }

    /// Multiplies by the scalar `rhs` (converted with `elem()`) via `NdArrayMathOps::mul_scalar`.
    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mul_scalar(array, rhs.elem())
        })
    }

    /// Divides `lhs` by `rhs` via `NdArrayMathOps::div`.
    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
    }

    /// Divides by the scalar `rhs` (converted with `elem()`) via `NdArrayMathOps::div_scalar`.
    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::div_scalar(array, rhs.elem())
        })
    }

    /// Computes the remainder of `lhs` by `rhs` via `NdArrayMathOps::remainder`.
    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
    }

    /// Computes the remainder by the scalar `rhs` via `NdArrayMathOps::remainder_scalar`.
    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::remainder_scalar(array, rhs.elem())
        })
    }

    /// Matrix-multiplies `lhs` and `rhs` using the sibling `matmul` function.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), matmul)
    }

    /// Computes the cross product of `lhs` and `rhs` along `dim` using the sibling `cross` function.
    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
    }

    /// Forwards to `NdArrayMathOps::recip`.
    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::recip(array)
        })
    }

    /// Swaps dimensions `dim1` and `dim2` via `NdArrayOps::swap_dims`.
    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::swap_dims(array, dim1, dim2)
        })
    }

    /// Reshapes the tensor to `shape` via `NdArrayOps::reshape`.
    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::reshape(array, shape)
        })
    }

    /// Gathers along `dim` using `indices` via `NdArrayOps::gather`.
    ///
    /// The int-dtype macro wraps the float-dtype macro so both the index and the value
    /// arrays are typed before the call.
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayOps::gather(dim, array, idx_array)
                })
            }
        )
    }

    /// Scatters `value` into `tensor` along `dim` at `indices`, forwarding to `NdArrayOps::scatter`.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
                    dim, tensor, idx_array, value
                ))
            }
        )
    }

    /// Selects entries along `dim` at `indices` via `NdArrayMathOps::select`.
    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayMathOps::select(array, dim, idx_array)
                })
            }
        )
    }

    /// Combines `value` into `tensor` along `dim` at `indices`, forwarding to
    /// `NdArrayMathOps::select_assign`.
    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| {
                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
                })
            }
        )
    }

    /// Slices the tensor with `slices` through the crate's `slice!` macro.
    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
        slice!(tensor, slices)
    }

    /// Assigns `value` into the region of `tensor` described by `slices` via
    /// `NdArrayOps::slice_assign`.
    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        slices: &[burn_backend::Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::slice_assign(tensor, slices, value)
        })
    }

    /// Forwards to `NdArrayOps::mask_where` with the mask taken as a bool array (`mask.bool()`).
    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::mask_where(tensor, mask.bool(), value)
        })
    }

    /// Forwards to `NdArrayOps::mask_fill` with the mask taken as a bool array and the
    /// scalar `value` converted with `elem()`.
    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: Scalar,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
        })
    }

    /// Compares two tensors via `NdArrayMathOps::equal`; `_out_dtype` is not used.
    fn float_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
    }

    /// Compares against the scalar `rhs` via `NdArrayMathOps::equal_elem`; `_out_dtype` is not used.
    fn float_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::equal_elem(array, rhs.elem())
        })
    }

    /// Compares two tensors via `NdArrayMathOps::greater`; `_out_dtype` is not used.
    fn float_greater(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
    }

    /// Compares against the scalar `rhs` via `NdArrayMathOps::greater_elem`; `_out_dtype` is not used.
    fn float_greater_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_elem(array, rhs.elem())
        })
    }

    /// Compares two tensors via `NdArrayMathOps::greater_equal`; `_out_dtype` is not used.
    fn float_greater_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::greater_equal(lhs, rhs)
        })
    }

    /// Compares against the scalar `rhs` via `NdArrayMathOps::greater_equal_elem`;
    /// `_out_dtype` is not used.
    fn float_greater_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
        })
    }

    /// Compares two tensors via `NdArrayMathOps::lower`; `_out_dtype` is not used.
    fn float_lower(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
    }

    /// Compares against the scalar `rhs` via `NdArrayMathOps::lower_elem`; `_out_dtype` is not used.
    fn float_lower_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_elem(array, rhs.elem())
        })
    }

    /// Compares two tensors via `NdArrayMathOps::lower_equal`; `_out_dtype` is not used.
    fn float_lower_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::lower_equal(lhs, rhs)
        })
    }

    /// Compares against the scalar `rhs` via `NdArrayMathOps::lower_equal_elem`;
    /// `_out_dtype` is not used.
    fn float_lower_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
        })
    }

    /// Returns the tensor unchanged.
    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        tensor
    }

    /// Reduces the whole tensor with `NdArrayMathOps::mean_view`.
    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_view(array.view())
        })
    }

    /// Reduces the whole tensor with `NdArrayMathOps::sum_view`.
    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_view(array.view())
        })
    }

    /// Reduces along `dim` with `NdArrayMathOps::mean_dim`.
    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_dim(array, dim)
        })
    }

    /// Forwards to `NdArrayMathOps::cumsum` along `dim`.
    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumsum(array, dim)
        })
    }

    /// Forwards to `NdArrayMathOps::cumprod` along `dim`.
    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumprod(array, dim)
        })
    }

    /// Forwards to `NdArrayMathOps::cummin` along `dim`.
    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummin(array, dim)
        })
    }

    /// Forwards to `NdArrayMathOps::cummax` along `dim`.
    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummax(array, dim)
        })
    }

    /// Reduces along `dim` with `NdArrayMathOps::sum_dim`.
    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_dim(array, dim)
        })
    }

    /// Computes arg-max along `dim` via `NdArrayMathOps::argmax_view::<I>`, under
    /// `execute_with_int_out_dtype!` for `out_dtype`.
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
        // Use view() for zero-copy on borrowed storage
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                NdArrayMathOps::argmax_view::<I>(array.view(), dim)
            })
        })
    }

    /// Computes arg-min along `dim` via `NdArrayMathOps::argmin_view::<I>`, under
    /// `execute_with_int_out_dtype!` for `out_dtype`.
    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
        // Use view() for zero-copy on borrowed storage
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                NdArrayMathOps::argmin_view::<I>(array.view(), dim)
            })
        })
    }

    /// Maps every element through `exp_elem`.
    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
        })
    }

    /// Maps every element through `log_elem`.
    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
        })
    }

    /// Reduces the whole tensor with `NdArrayMathOps::prod_view`.
    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_view(array.view())
        })
    }

    /// Reduces along `dim` with `NdArrayMathOps::prod_dim`.
    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_dim(array, dim)
        })
    }

    /// Reduces the whole tensor with `NdArrayMathOps::max_view`.
    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::max_view(array.view())
        })
    }

    /// Reduces the whole tensor with `NdArrayMathOps::min_view`.
    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::min_view(array.view())
        })
    }

    /// Maps every element through `log1p_elem`.
    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
        })
    }

    /// Raises every element to the scalar power `value` (converted with `elem()`) using `powf_elem`.
    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
                .into_shared()
        })
    }

    /// Maps every element through `sqrt_elem`.
    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
        })
    }

    /// Forwards to `NdArrayMathOps::abs`.
    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::abs(array)
        })
    }

    /// Applies `cos` to every element, computing in `f64` and converting back with `elem()`.
    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
                .into_shared()
        })
    }

    /// Applies `cosh` to every element, computing in `f64` and converting back with `elem()`.
    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
                .into_shared()
        })
    }

    /// Applies `sin` to every element, computing in `f64` and converting back with `elem()`.
    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
                .into_shared()
        })
    }

    /// Applies `sinh` to every element, computing in `f64` and converting back with `elem()`.
    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
                .into_shared()
        })
    }

    /// Applies `tan` to every element, computing in `f64` and converting back with `elem()`.
    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
                .into_shared()
        })
    }

    /// Applies `tanh` to every element, computing in `f64` and converting back with `elem()`.
    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
                .into_shared()
        })
    }

    /// Applies `acos` to every element, computing in `f64` and converting back with `elem()`.
    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
                .into_shared()
        })
    }

    /// Applies `acosh` to every element, computing in `f64` and converting back with `elem()`.
    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
                .into_shared()
        })
    }

    /// Applies `asin` to every element, computing in `f64` and converting back with `elem()`.
    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
                .into_shared()
        })
    }

    /// Applies `asinh` to every element, computing in `f64` and converting back with `elem()`.
    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
                .into_shared()
        })
    }

    /// Applies `atan` to every element, computing in `f64` and converting back with `elem()`.
    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
                .into_shared()
        })
    }

    /// Applies `atanh` to every element, computing in `f64` and converting back with `elem()`.
    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
                .into_shared()
        })
    }

    /// Combines `lhs` and `rhs` through `NdArrayMathOps::elementwise_op` with `a.atan2(b)`.
    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
        })
    }

    /// Rounds every element with `round_ties_even_wrapper`, computing in `f64` and
    /// converting back with `elem()`.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
                .into_shared()
        })
    }

    /// Applies `floor` to every element, computing in `f64` and converting back with `elem()`.
    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
                .into_shared()
        })
    }

    /// Applies `ceil` to every element, computing in `f64` and converting back with `elem()`.
    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
                .into_shared()
        })
    }

    /// Applies `trunc` to every element, computing in `f64` and converting back with `elem()`.
    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
                .into_shared()
        })
    }

    /// Applies `libm::erf` to every element, computing in `f64` and converting back with `elem()`.
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
                .into_shared()
        })
    }

    /// Concatenates `tensors` along `dim` through `cat_with_dtype!`, which is given the
    /// `F64` and `F32` dtypes.
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        cat_with_dtype!(tensors, dim, [F64, F32])
    }

    /// Forwards to `NdArrayMathOps::clamp_min` with `min` converted via `elem()`.
    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_min(array, min.elem())
        })
    }

    /// Forwards to `NdArrayMathOps::clamp_max` with `max` converted via `elem()`.
    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_max(array, max.elem())
        })
    }

    /// Forwards to `NdArrayMathOps::clamp` with `min` and `max` converted via `elem()`.
    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp(array, min.elem(), max.elem())
        })
    }

    /// Converts every element to the int element type with `elem::<I>()`, under
    /// `execute_with_int_out_dtype!` for `out_dtype`.
    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
            })
        })
    }

    /// Combines `lhs` and `rhs` through `NdArrayMathOps::elementwise_op` with `a.powf(b)`.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
        })
    }

    /// Permutes the tensor's axes via `NdArrayOps::permute`.
    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::permute(array, axes)
        })
    }

    /// Flips the tensor along `axes` via `NdArrayOps::flip`.
    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::flip(array, axes)
        })
    }

    /// Forwards to `NdArrayMathOps::sign_op`.
    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sign_op(array)
        })
    }

    /// Expands the tensor to `shape` via `NdArrayOps::expand`.
    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::expand(array, shape)
        })
    }

    /// Casts the tensor to `dtype` via `cast_to_dtype`.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            cast_to_dtype(array, dtype.into())
        })
    }

    /// Samples `tensor` at the locations in `grid` via `grid_sample_2d`, passing `options` through.
    fn float_grid_sample_2d(
        tensor: FloatTensor<Self>,
        grid: FloatTensor<Self>,
        options: GridSampleOptions,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
            tensor, grid, options
        ))
    }

    /// Forwards to `NdArrayOps::unfold` with `dim`, `size` and `step`.
    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::unfold(array, dim, size, step)
        })
    }
}