Skip to main content

burn_ndarray/ops/
tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7use burn_std::{BoolDType, IntDType};
8
9// Current crate
10use super::{
11    NdArrayMathOps, NdArrayOps,
12    matmul::{cross, matmul},
13};
14use crate::{
15    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
16};
17use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
18use crate::{
19    SharedArray,
20    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
21};
22use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
23
24// Workspace crates
25use crate::rand::get_seeded_rng;
26use burn_backend::{Distribution, FloatDType, Scalar};
27use burn_backend::{ElementConversion, Shape, TensorData, ops::FloatTensorOps};
28
29#[cfg(not(feature = "std"))]
30#[allow(unused_imports)]
31use num_traits::Float;
32
33use libm::erf;
34
35#[cfg(feature = "std")]
36#[allow(dead_code)]
/// Rounds `x` to the nearest integer, sending exact halfway cases to the even neighbour.
///
/// This is the variant selected when the `std` feature is enabled; it simply forwards to
/// [`f64::round_ties_even`]. `float_round` in this file calls it on every element after
/// converting the element to `f64`.
fn round_ties_even_wrapper(x: f64) -> f64 {
    x.round_ties_even()
}
40
/// Rounds `x` to the nearest integer, sending exact halfway cases to the even neighbour.
///
/// Fallback used when the `std` feature is disabled. It is built only from `floor` and
/// `round` (which rounds halfway cases away from zero), so it does not need
/// `f64::round_ties_even`.
///
/// NaN and infinities fall through to the plain `round` branch, because `x - x.floor()` is NaN
/// for them and therefore never equals `0.5`.
#[cfg(not(feature = "std"))]
// NOTE(review): `float_round` in this file calls this function, so the `dead_code` allowance
// may be unnecessary — confirm before removing.
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    // An exact tie is the only case where plain `round` disagrees with ties-to-even.
    let is_tie = (x - x.floor()) == 0.5;
    if is_tie {
        // Halving moves the tie to a `.25` / `.75` fraction, which `round` resolves without
        // ambiguity; doubling the result then lands on the even neighbour of `x`.
        (x * 0.5).round() * 2.0
    } else {
        x.round()
    }
}
50
51impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
52    for NdArray<E, I, Q>
53where
54    NdArrayTensor: From<SharedArray<E>>,
55    NdArrayTensor: From<SharedArray<I>>,
56{
57    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
58        NdArrayTensor::from_data(data)
59    }
60
61    fn float_random(
62        shape: Shape,
63        distribution: Distribution,
64        device: &NdArrayDevice,
65        dtype: FloatDType,
66    ) -> FloatTensor<Self> {
67        let mut seed = SEED.lock().unwrap();
68        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
69        let tensor = execute_with_float_out_dtype!(
70            dtype,
71            E,
72            Self::float_from_data(
73                TensorData::random::<E, _, _>(shape, distribution, &mut rng),
74                device,
75            )
76        );
77
78        *seed = Some(rng);
79        tensor
80    }
81
82    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
83        Ok(tensor.into_data())
84    }
85
86    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
87        NdArrayDevice::Cpu
88    }
89
90    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
91        tensor
92    }
93
94    fn float_empty(shape: Shape, device: &NdArrayDevice, dtype: FloatDType) -> FloatTensor<Self> {
95        Self::float_zeros(shape, device, dtype)
96    }
97
98    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
99        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
100    }
101
102    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
103        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
104            NdArrayMathOps::add_scalar(array, rhs.elem())
105        })
106    }
107
108    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
109        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
110    }
111
112    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
113        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
114            NdArrayMathOps::sub_scalar(array, rhs.elem())
115        })
116    }
117
118    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
119        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
120    }
121
122    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
123        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
124            NdArrayMathOps::mul_scalar(array, rhs.elem())
125        })
126    }
127
128    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
129        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
130    }
131
132    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
133        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
134            NdArrayMathOps::div_scalar(array, rhs.elem())
135        })
136    }
137
138    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
140    }
141
142    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
143        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
144            NdArrayMathOps::remainder_scalar(array, rhs.elem())
145        })
146    }
147
148    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
149        execute_with_float_dtype!((lhs, rhs), matmul)
150    }
151
152    fn float_cross(
153        lhs: FloatTensor<Self>,
154        rhs: FloatTensor<Self>,
155        dim: usize,
156    ) -> FloatTensor<Self> {
157        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
158    }
159
160    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
161        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
162            NdArrayMathOps::recip(array)
163        })
164    }
165
166    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
167        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
168            NdArrayOps::swap_dims(array, dim1, dim2)
169        })
170    }
171
172    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
173        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
174            NdArrayOps::reshape(array, shape)
175        })
176    }
177
178    fn float_gather(
179        dim: usize,
180        tensor: FloatTensor<Self>,
181        indices: NdArrayTensor,
182    ) -> FloatTensor<Self> {
183        execute_with_int_dtype!(
184            indices,
185            IntElem,
186            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
187                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
188                    NdArrayOps::gather(dim, array, idx_array)
189                })
190            }
191        )
192    }
193
194    fn float_scatter_add(
195        dim: usize,
196        tensor: FloatTensor<Self>,
197        indices: NdArrayTensor,
198        value: FloatTensor<Self>,
199    ) -> FloatTensor<Self> {
200        execute_with_int_dtype!(
201            indices,
202            IntElem,
203            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
204                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
205                    dim, tensor, idx_array, value
206                ))
207            }
208        )
209    }
210
211    fn float_scatter_nd(
212        data: FloatTensor<Self>,
213        indices: NdArrayTensor,
214        values: FloatTensor<Self>,
215        reduction: burn_backend::tensor::IndexingUpdateOp,
216    ) -> FloatTensor<Self> {
217        execute_with_int_dtype!(
218            indices,
219            IntElem,
220            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
221                execute_with_float_dtype!((data, values), |data, values| NdArrayOps::scatter_nd(
222                    data, idx_array, values, reduction
223                ))
224            }
225        )
226    }
227
228    fn float_gather_nd(data: FloatTensor<Self>, indices: NdArrayTensor) -> FloatTensor<Self> {
229        execute_with_int_dtype!(
230            indices,
231            IntElem,
232            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
233                execute_with_float_dtype!(data, FloatElem, |array: SharedArray<FloatElem>| {
234                    NdArrayOps::gather_nd(array, idx_array)
235                })
236            }
237        )
238    }
239
240    fn float_select(
241        tensor: FloatTensor<Self>,
242        dim: usize,
243        indices: NdArrayTensor,
244    ) -> FloatTensor<Self> {
245        execute_with_int_dtype!(
246            indices,
247            IntElem,
248            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
249                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
250                    NdArrayMathOps::select(array, dim, idx_array)
251                })
252            }
253        )
254    }
255
256    fn float_select_add(
257        tensor: FloatTensor<Self>,
258        dim: usize,
259        indices: NdArrayTensor,
260        value: FloatTensor<Self>,
261    ) -> FloatTensor<Self> {
262        execute_with_int_dtype!(
263            indices,
264            IntElem,
265            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
266                execute_with_float_dtype!((tensor, value), |tensor, value| {
267                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
268                })
269            }
270        )
271    }
272
273    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
274        slice!(tensor, slices)
275    }
276
277    fn float_slice_assign(
278        tensor: FloatTensor<Self>,
279        slices: &[burn_backend::Slice],
280        value: FloatTensor<Self>,
281    ) -> FloatTensor<Self> {
282        execute_with_float_dtype!((tensor, value), |tensor, value| {
283            NdArrayOps::slice_assign(tensor, slices, value)
284        })
285    }
286
287    fn float_mask_where(
288        tensor: FloatTensor<Self>,
289        mask: NdArrayTensor,
290        value: FloatTensor<Self>,
291    ) -> FloatTensor<Self> {
292        execute_with_float_dtype!((tensor, value), |tensor, value| {
293            NdArrayOps::mask_where(tensor, mask.bool(), value)
294        })
295    }
296
297    fn float_mask_fill(
298        tensor: FloatTensor<Self>,
299        mask: NdArrayTensor,
300        value: Scalar,
301    ) -> FloatTensor<Self> {
302        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
303            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
304        })
305    }
306
307    fn float_equal(
308        lhs: FloatTensor<Self>,
309        rhs: FloatTensor<Self>,
310        _out_dtype: BoolDType,
311    ) -> NdArrayTensor {
312        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
313    }
314
315    fn float_equal_elem(
316        lhs: FloatTensor<Self>,
317        rhs: Scalar,
318        _out_dtype: BoolDType,
319    ) -> NdArrayTensor {
320        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
321            NdArrayMathOps::equal_elem(array, rhs.elem())
322        })
323    }
324
325    fn float_greater(
326        lhs: FloatTensor<Self>,
327        rhs: FloatTensor<Self>,
328        _out_dtype: BoolDType,
329    ) -> NdArrayTensor {
330        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
331    }
332
333    fn float_greater_elem(
334        lhs: FloatTensor<Self>,
335        rhs: Scalar,
336        _out_dtype: BoolDType,
337    ) -> NdArrayTensor {
338        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
339            NdArrayMathOps::greater_elem(array, rhs.elem())
340        })
341    }
342
343    fn float_greater_equal(
344        lhs: FloatTensor<Self>,
345        rhs: FloatTensor<Self>,
346        _out_dtype: BoolDType,
347    ) -> NdArrayTensor {
348        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
349            NdArrayMathOps::greater_equal(lhs, rhs)
350        })
351    }
352
353    fn float_greater_equal_elem(
354        lhs: FloatTensor<Self>,
355        rhs: Scalar,
356        _out_dtype: BoolDType,
357    ) -> NdArrayTensor {
358        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
359            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
360        })
361    }
362
363    fn float_lower(
364        lhs: FloatTensor<Self>,
365        rhs: FloatTensor<Self>,
366        _out_dtype: BoolDType,
367    ) -> NdArrayTensor {
368        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
369    }
370
371    fn float_lower_elem(
372        lhs: FloatTensor<Self>,
373        rhs: Scalar,
374        _out_dtype: BoolDType,
375    ) -> NdArrayTensor {
376        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
377            NdArrayMathOps::lower_elem(array, rhs.elem())
378        })
379    }
380
381    fn float_lower_equal(
382        lhs: FloatTensor<Self>,
383        rhs: FloatTensor<Self>,
384        _out_dtype: BoolDType,
385    ) -> NdArrayTensor {
386        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
387            NdArrayMathOps::lower_equal(lhs, rhs)
388        })
389    }
390
391    fn float_lower_equal_elem(
392        lhs: FloatTensor<Self>,
393        rhs: Scalar,
394        _out_dtype: BoolDType,
395    ) -> NdArrayTensor {
396        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
397            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
398        })
399    }
400
401    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402        tensor
403    }
404
405    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
406        // Use view() for zero-copy on borrowed storage
407        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
408            NdArrayMathOps::mean_view(array.view())
409        })
410    }
411
412    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
413        // Use view() for zero-copy on borrowed storage
414        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
415            NdArrayMathOps::sum_view(array.view())
416        })
417    }
418
419    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
420        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
421            NdArrayMathOps::mean_dim(array, dim)
422        })
423    }
424
425    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
426        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
427            NdArrayMathOps::cumsum(array, dim)
428        })
429    }
430
431    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
432        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
433            NdArrayMathOps::cumprod(array, dim)
434        })
435    }
436
437    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
438        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
439            NdArrayMathOps::cummin(array, dim)
440        })
441    }
442
443    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
444        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
445            NdArrayMathOps::cummax(array, dim)
446        })
447    }
448
449    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
450        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
451            NdArrayMathOps::sum_dim(array, dim)
452        })
453    }
454
455    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
456        // Use view() for zero-copy on borrowed storage
457        execute_with_int_out_dtype!(out_dtype, I, {
458            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
459                NdArrayMathOps::argmax_view::<I>(array.view(), dim)
460            })
461        })
462    }
463
464    fn float_argtopk(
465        _tensor: FloatTensor<Self>,
466        _dim: usize,
467        _k: usize,
468        _out_dtype: IntDType,
469    ) -> NdArrayTensor {
470        unimplemented!("float_argtopk not implemented for ndarray")
471    }
472
473    fn float_topk(_tensor: FloatTensor<Self>, _dim: usize, _k: usize) -> NdArrayTensor {
474        unimplemented!("float_topk not implemented for ndarray")
475    }
476
477    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
478        // Use view() for zero-copy on borrowed storage
479        execute_with_int_out_dtype!(out_dtype, I, {
480            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
481                NdArrayMathOps::argmin_view::<I>(array.view(), dim)
482            })
483        })
484    }
485
486    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
487        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
488            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
489        })
490    }
491
492    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
493        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
494            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
495        })
496    }
497
498    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499        // Use view() for zero-copy on borrowed storage
500        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
501            NdArrayMathOps::prod_view(array.view())
502        })
503    }
504
505    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
506        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
507            NdArrayMathOps::prod_dim(array, dim)
508        })
509    }
510
511    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
512        // Use view() for zero-copy on borrowed storage
513        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
514            NdArrayMathOps::max_view(array.view())
515        })
516    }
517
518    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
519        // Use view() for zero-copy on borrowed storage
520        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
521            NdArrayMathOps::min_view(array.view())
522        })
523    }
524
525    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
526        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
527            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
528        })
529    }
530
531    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
532        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
533            array
534                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
535                .into_shared()
536        })
537    }
538
539    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
540        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
541            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
542        })
543    }
544
545    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
546        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
547            NdArrayMathOps::abs(array)
548        })
549    }
550
551    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
552        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
553            array
554                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
555                .into_shared()
556        })
557    }
558
559    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
560        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
561            array
562                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
563                .into_shared()
564        })
565    }
566
567    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
568        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
569            array
570                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
571                .into_shared()
572        })
573    }
574
575    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
576        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
577            array
578                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
579                .into_shared()
580        })
581    }
582
583    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
584        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
585            array
586                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
587                .into_shared()
588        })
589    }
590
591    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
592        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
593            array
594                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
595                .into_shared()
596        })
597    }
598
599    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
600        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
601            array
602                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
603                .into_shared()
604        })
605    }
606
607    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
608        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
609            array
610                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
611                .into_shared()
612        })
613    }
614
615    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
616        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
617            array
618                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
619                .into_shared()
620        })
621    }
622
623    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
624        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
625            array
626                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
627                .into_shared()
628        })
629    }
630
631    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
632        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
633            array
634                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
635                .into_shared()
636        })
637    }
638
639    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
640        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
641            array
642                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
643                .into_shared()
644        })
645    }
646
647    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
648        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
649            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
650        })
651    }
652
653    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
654        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
655            array
656                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
657                .into_shared()
658        })
659    }
660
661    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
662        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
663            array
664                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
665                .into_shared()
666        })
667    }
668
669    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
670        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
671            array
672                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
673                .into_shared()
674        })
675    }
676
677    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
678        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
679            array
680                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
681                .into_shared()
682        })
683    }
684
685    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
686        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
687            array
688                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
689                .into_shared()
690        })
691    }
692
693    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
694        cat_with_dtype!(tensors, dim, [F64, F32])
695    }
696
697    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
698        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
699            NdArrayMathOps::clamp_min(array, min.elem())
700        })
701    }
702
703    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
704        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
705            NdArrayMathOps::clamp_max(array, max.elem())
706        })
707    }
708
709    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
710        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
711            NdArrayMathOps::clamp(array, min.elem(), max.elem())
712        })
713    }
714
715    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
716        execute_with_int_out_dtype!(out_dtype, I, {
717            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
718                array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
719            })
720        })
721    }
722
723    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
724        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
725            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
726        })
727    }
728
729    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
730        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
731            NdArrayOps::permute(array, axes)
732        })
733    }
734
735    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
736        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
737            NdArrayOps::flip(array, axes)
738        })
739    }
740
741    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
742        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
743            NdArrayMathOps::sign_op(array)
744        })
745    }
746
747    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
748        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
749            NdArrayOps::expand(array, shape)
750        })
751    }
752
753    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
754        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
755            cast_to_dtype(array, dtype.into())
756        })
757    }
758
759    fn float_grid_sample_2d(
760        tensor: FloatTensor<Self>,
761        grid: FloatTensor<Self>,
762        options: GridSampleOptions,
763    ) -> FloatTensor<Self> {
764        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
765            tensor, grid, options
766        ))
767    }
768
769    fn float_unfold(
770        tensor: FloatTensor<Self>,
771        dim: usize,
772        size: usize,
773        step: usize,
774    ) -> FloatTensor<Self> {
775        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
776            NdArrayOps::unfold(array, dim, size, step)
777        })
778    }
779}