// burn_ndarray/ops/tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7use burn_std::{BoolDType, IntDType};
8
9// Current crate
10use super::{
11    NdArrayMathOps, NdArrayOps,
12    matmul::{cross, matmul},
13};
14use crate::{
15    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
16};
17use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
18use crate::{
19    SharedArray,
20    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
21};
22use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
23
24// Workspace crates
25use crate::rand::get_seeded_rng;
26use burn_backend::{Distribution, FloatDType, Scalar};
27use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
28
29#[cfg(not(feature = "std"))]
30#[allow(unused_imports)]
31use num_traits::Float;
32
33use libm::erf;
34
/// Rounds `x` to the nearest integer, breaking exact `.5` ties toward the even
/// neighbour ("banker's rounding"): `2.5 -> 2.0`, `3.5 -> 4.0`, `-2.5 -> -2.0`.
///
/// With `std` available this is simply the inherent `f64::round_ties_even`.
/// NOTE(review): `float_round` in this file calls this function, so the
/// `dead_code` allow may be unnecessary; it is kept as-is.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// `no_std` fallback with the same ties-to-even contract as the `std` variant.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    // Values that are not exactly half-way between two integers round normally.
    let fraction = x - x.floor();
    if fraction != 0.5 {
        return x.round();
    }
    // Exact tie x = k + 0.5: x / 2 is never itself a tie, so rounding the half
    // and doubling it lands on whichever of k and k + 1 is even.
    2.0 * (0.5 * x).round()
}
50
/// Float tensor operations for the ndarray (CPU) backend.
///
/// Nearly every method follows one pattern: an `execute_with_*_dtype!` macro
/// dispatches on the runtime dtype of the incoming `NdArrayTensor`, binds the
/// concrete element type to the identifier given as its second argument (e.g.
/// `FloatElem`), and runs the closure on the underlying `SharedArray`.
/// NOTE(review): the macro definitions are not visible here; this description
/// of their expansion is inferred from how they are invoked — confirm.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// Builds a float tensor from raw `TensorData`; the device argument is ignored.
    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
        NdArrayTensor::from_data(data)
    }

    /// Samples a tensor of `shape` from `distribution` with element type `dtype`.
    ///
    /// The RNG lives in the shared `SEED` slot: it is taken out (or created via
    /// `get_seeded_rng` when the slot is empty), used for sampling, then stored
    /// back so consecutive calls continue the same random stream. The lock guard
    /// is held for the whole call, so concurrent callers are serialized.
    /// NOTE(review): if sampling panics, the RNG is not written back; with a
    /// poisoning mutex the later `lock().unwrap()` calls would panic too — confirm
    /// which mutex type `SEED` uses.
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &NdArrayDevice,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        let mut seed = SEED.lock().unwrap();
        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
        // `E` is presumably rebound by the macro to the element type selected by
        // `dtype` (shadowing the impl's generic `E`) — TODO confirm.
        let tensor = execute_with_float_out_dtype!(
            dtype,
            E,
            Self::float_from_data(
                TensorData::random::<E, _, _>(shape, distribution, &mut rng),
                device,
            )
        );

        // Put the advanced RNG back so the next call continues from here.
        *seed = Some(rng);
        tensor
    }

    /// Converts the tensor into `TensorData`; this backend never returns an error.
    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
        Ok(tensor.into_data())
    }

    /// Always reports the CPU device.
    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    /// No-op: the tensor is returned unchanged whatever device is requested.
    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
        tensor
    }

    /// "Empty" tensors are allocated zero-filled by delegating to `float_zeros`.
    fn float_empty(
        shape: Shape,
        device: &<NdArray<E> as Backend>::Device,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        Self::float_zeros(shape, device, dtype)
    }

    // --- Element-wise arithmetic ------------------------------------------------
    // Tensor-tensor variants dispatch on both operands; scalar variants convert
    // the `Scalar` to the tensor's element type once (`rhs.elem()`) before the op.

    /// Element-wise `lhs + rhs`.
    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
    }

    /// Adds a scalar to every element.
    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::add_scalar(array, rhs.elem())
        })
    }

    /// Element-wise `lhs - rhs`.
    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
    }

    /// Subtracts a scalar from every element.
    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sub_scalar(array, rhs.elem())
        })
    }

    /// Element-wise `lhs * rhs`.
    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
    }

    /// Multiplies every element by a scalar.
    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mul_scalar(array, rhs.elem())
        })
    }

    /// Element-wise `lhs / rhs`.
    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
    }

    /// Divides every element by a scalar.
    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::div_scalar(array, rhs.elem())
        })
    }

    /// Element-wise remainder of `lhs` by `rhs` (semantics defined by `NdArrayMathOps::remainder`).
    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
    }

    /// Remainder of every element by a scalar.
    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::remainder_scalar(array, rhs.elem())
        })
    }

    /// Matrix multiplication, delegated to `matmul`.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), matmul)
    }

    /// Cross product of `lhs` and `rhs` along `dim`, delegated to `cross`.
    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
    }

    /// Element-wise reciprocal `1 / x`.
    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::recip(array)
        })
    }

    // --- Shape / layout ---------------------------------------------------------

    /// Swaps axes `dim1` and `dim2`.
    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::swap_dims(array, dim1, dim2)
        })
    }

    /// Reshapes the tensor to `shape`.
    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::reshape(array, shape)
        })
    }

    // --- Indexing ---------------------------------------------------------------
    // Index-taking ops dispatch on the index tensor's int dtype first, then on the
    // float dtype of the data inside that closure.

    /// Gathers values of `tensor` along `dim` at positions given by `indices`.
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayOps::gather(dim, array, idx_array)
                })
            }
        )
    }

    /// Scatters `value` into `tensor` along `dim` at `indices`.
    /// NOTE(review): the "add" semantics rely on `NdArrayOps::scatter` accumulating
    /// rather than overwriting — confirm in its definition.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
                    dim, tensor, idx_array, value
                ))
            }
        )
    }

    /// Selects the slices of `tensor` along `dim` listed in `indices`.
    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayMathOps::select(array, dim, idx_array)
                })
            }
        )
    }

    /// Writes `value` into the slices of `tensor` along `dim` listed in `indices`.
    /// NOTE(review): the "add" semantics rely on `NdArrayMathOps::select_assign`
    /// accumulating — confirm in its definition.
    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| {
                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
                })
            }
        )
    }

    /// Extracts the sub-tensor described by `slices` (via the `slice!` macro).
    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
        slice!(tensor, slices)
    }

    /// Overwrites the region of `tensor` described by `slices` with `value`.
    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        slices: &[burn_backend::Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::slice_assign(tensor, slices, value)
        })
    }

    /// Takes elements from `value` where `mask` is set, otherwise from `tensor`.
    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::mask_where(tensor, mask.bool(), value)
        })
    }

    /// Replaces elements where `mask` is set with the scalar `value`.
    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: Scalar,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
        })
    }

    // --- Comparisons --------------------------------------------------------------
    // All comparisons return a bool tensor built by `NdArrayMathOps`; the requested
    // `_out_dtype` is not used by this backend.

    /// Element-wise `lhs == rhs`.
    fn float_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
    }

    /// Element-wise `lhs == rhs` against a scalar.
    fn float_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::equal_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs > rhs`.
    fn float_greater(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
    }

    /// Element-wise `lhs > rhs` against a scalar.
    fn float_greater_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs >= rhs`.
    fn float_greater_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::greater_equal(lhs, rhs)
        })
    }

    /// Element-wise `lhs >= rhs` against a scalar.
    fn float_greater_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs < rhs`.
    fn float_lower(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
    }

    /// Element-wise `lhs < rhs` against a scalar.
    fn float_lower_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs <= rhs`.
    fn float_lower_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::lower_equal(lhs, rhs)
        })
    }

    /// Element-wise `lhs <= rhs` against a scalar.
    fn float_lower_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
        })
    }

    /// Returns the tensor unchanged.
    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        tensor
    }

    // --- Reductions and scans -------------------------------------------------------

    /// Mean of all elements.
    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_view(array.view())
        })
    }

    /// Sum of all elements.
    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_view(array.view())
        })
    }

    /// Mean along `dim`.
    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_dim(array, dim)
        })
    }

    /// Cumulative sum along `dim`.
    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumsum(array, dim)
        })
    }

    /// Cumulative product along `dim`.
    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumprod(array, dim)
        })
    }

    /// Running minimum along `dim`.
    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummin(array, dim)
        })
    }

    /// Running maximum along `dim`.
    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummax(array, dim)
        })
    }

    /// Sum along `dim`.
    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_dim(array, dim)
        })
    }

    /// Index of the maximum along `dim`, as an int tensor of `out_dtype`.
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
        // Use view() for zero-copy on borrowed storage
        // `I` is presumably rebound by the outer macro to the int type for
        // `out_dtype` — TODO confirm.
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                NdArrayMathOps::argmax_view::<I>(array.view(), dim)
            })
        })
    }

    /// Index of the minimum along `dim`, as an int tensor of `out_dtype`.
    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
        // Use view() for zero-copy on borrowed storage
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                NdArrayMathOps::argmin_view::<I>(array.view(), dim)
            })
        })
    }

    // --- Element-wise math ------------------------------------------------------------

    /// Element-wise `e^x` via `ExpElement::exp_elem`.
    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
        })
    }

    /// Element-wise natural logarithm via `log_elem`.
    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
        })
    }

    /// Product of all elements.
    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_view(array.view())
        })
    }

    /// Product along `dim`.
    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_dim(array, dim)
        })
    }

    /// Maximum over all elements.
    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::max_view(array.view())
        })
    }

    /// Minimum over all elements.
    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        // Use view() for zero-copy on borrowed storage
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::min_view(array.view())
        })
    }

    /// Element-wise `ln(1 + x)` via `log1p_elem`.
    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
        })
    }

    /// Raises every element to the scalar power `value` via `powf_elem`.
    /// NOTE(review): `value.elem()` sits inside the per-element closure, so the
    /// scalar is converted once per element, unlike the other scalar ops here
    /// which convert once up front; consider hoisting it.
    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
                .into_shared()
        })
    }

    /// Element-wise square root via `sqrt_elem`.
    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
        })
    }

    /// Element-wise absolute value.
    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::abs(array)
        })
    }

    // Trigonometric / hyperbolic ops below all widen each element to `f64`,
    // apply the `f64` function, and convert the result back with `.elem()`.

    /// Element-wise cosine.
    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
                .into_shared()
        })
    }

    /// Element-wise hyperbolic cosine.
    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
                .into_shared()
        })
    }

    /// Element-wise sine.
    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
                .into_shared()
        })
    }

    /// Element-wise hyperbolic sine.
    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
                .into_shared()
        })
    }

    /// Element-wise tangent.
    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
                .into_shared()
        })
    }

    /// Element-wise hyperbolic tangent.
    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
                .into_shared()
        })
    }

    /// Element-wise arc cosine.
    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
                .into_shared()
        })
    }

    /// Element-wise inverse hyperbolic cosine.
    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
                .into_shared()
        })
    }

    /// Element-wise arc sine.
    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
                .into_shared()
        })
    }

    /// Element-wise inverse hyperbolic sine.
    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
                .into_shared()
        })
    }

    /// Element-wise arc tangent.
    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
                .into_shared()
        })
    }

    /// Element-wise inverse hyperbolic tangent.
    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
                .into_shared()
        })
    }

    /// Element-wise `atan2(lhs, rhs)`, computed directly in the element type
    /// (no `f64` round-trip, unlike the unary trig ops above).
    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
        })
    }

    /// Rounds to the nearest integer with ties going to the even neighbour
    /// (`round_ties_even_wrapper`), computed in `f64`.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
                .into_shared()
        })
    }

    /// Element-wise floor, computed in `f64`.
    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
                .into_shared()
        })
    }

    /// Element-wise ceiling, computed in `f64`.
    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
                .into_shared()
        })
    }

    /// Element-wise truncation toward zero, computed in `f64`.
    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
                .into_shared()
        })
    }

    /// Element-wise error function using `libm::erf` on `f64`.
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
                .into_shared()
        })
    }

    /// Concatenates `tensors` along `dim`.
    /// NOTE(review): only the `F64` and `F32` dtypes are listed for the macro;
    /// presumably these are the only float storage types here — confirm.
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        cat_with_dtype!(tensors, dim, [F64, F32])
    }

    /// Clamps every element to be at least `min`.
    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_min(array, min.elem())
        })
    }

    /// Clamps every element to be at most `max`.
    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_max(array, max.elem())
        })
    }

    /// Clamps every element into `[min, max]`.
    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp(array, min.elem(), max.elem())
        })
    }

    /// Converts every element to the int type selected by `out_dtype`.
    /// Uses `mapv` (a new array) because the element type changes.
    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
            })
        })
    }

    /// Element-wise `lhs ^ rhs` using the element type's `powf`.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
        })
    }

    /// Reorders axes according to `axes`.
    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::permute(array, axes)
        })
    }

    /// Reverses the tensor along each axis in `axes`.
    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::flip(array, axes)
        })
    }

    /// Element-wise sign.
    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sign_op(array)
        })
    }

    /// Broadcasts the tensor to `shape`.
    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::expand(array, shape)
        })
    }

    /// Converts the tensor to the float `dtype` via `cast_to_dtype`.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            cast_to_dtype(array, dtype.into())
        })
    }

    /// 2-D grid sampling of `tensor` at the coordinates in `grid`, delegated to
    /// `grid_sample_2d`; both inputs are dispatched on a shared float dtype.
    fn float_grid_sample_2d(
        tensor: FloatTensor<Self>,
        grid: FloatTensor<Self>,
        options: GridSampleOptions,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
            tensor, grid, options
        ))
    }

    /// Extracts windows of `size` elements every `step` along `dim`.
    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::unfold(array, dim, size, step)
        })
    }
}