//! Float tensor operations for the ndarray backend (`burn_ndarray/ops/tensor.rs`).
// Language
use alloc::vec::Vec;
use burn_backend::backend::ExecutionError;
use burn_backend::ops::GridSampleOptions;
use burn_backend::tensor::FloatTensor;
use burn_backend::{TensorMetadata, element::cast::ToElement};
use burn_std::{BoolDType, IntDType};

// Current crate
use super::{
    NdArrayMathOps, NdArrayOps,
    matmul::{cross, matmul},
};
use crate::{
    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
};
use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
use crate::{
    SharedArray,
    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
};
use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};

// Workspace crates
use crate::rand::get_seeded_rng;
use burn_backend::{Distribution, FloatDType, Scalar};
use burn_backend::{ElementConversion, Shape, TensorData, ops::FloatTensorOps};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;

use libm::erf;
#[cfg(feature = "std")]
#[allow(dead_code)]
/// Rounds to the nearest integer, breaking ties toward the even integer
/// (banker's rounding). With `std` available this delegates directly to
/// [`f64::round_ties_even`].
fn round_ties_even_wrapper(x: f64) -> f64 {
    x.round_ties_even()
}
40
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
/// Rounds to the nearest integer, breaking ties toward the even integer
/// (banker's rounding). `no_std` fallback for [`f64::round_ties_even`].
///
/// For exact ties (fractional part == 0.5), halving maps the tie to a
/// quarter-fraction whose round-half-away result, once doubled, lands on
/// the even neighbour (e.g. 2.5 -> 1.25 -> 1.0 -> 2.0). Non-ties use plain
/// round-half-away `round`, which agrees with ties-to-even there.
fn round_ties_even_wrapper(x: f64) -> f64 {
    if (x - x.floor()) == 0.5 {
        (x * 0.5).round() * 2.0
    } else {
        x.round()
    }
}
50
51impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
52    for NdArray<E, I, Q>
53where
54    NdArrayTensor: From<SharedArray<E>>,
55    NdArrayTensor: From<SharedArray<I>>,
56{
57    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
58        NdArrayTensor::from_data(data)
59    }
60
61    fn float_random(
62        shape: Shape,
63        distribution: Distribution,
64        device: &NdArrayDevice,
65        dtype: FloatDType,
66    ) -> FloatTensor<Self> {
67        let mut seed = SEED.lock().unwrap();
68        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
69        let tensor = execute_with_float_out_dtype!(
70            dtype,
71            E,
72            Self::float_from_data(
73                TensorData::random::<E, _, _>(shape, distribution, &mut rng),
74                device,
75            )
76        );
77
78        *seed = Some(rng);
79        tensor
80    }
81
82    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
83        Ok(tensor.into_data())
84    }
85
86    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
87        NdArrayDevice::Cpu
88    }
89
90    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
91        tensor
92    }
93
94    fn float_empty(shape: Shape, device: &NdArrayDevice, dtype: FloatDType) -> FloatTensor<Self> {
95        Self::float_zeros(shape, device, dtype)
96    }
97
98    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
99        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
100    }
101
102    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
103        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
104            NdArrayMathOps::add_scalar(array, rhs.elem())
105        })
106    }
107
108    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
109        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
110    }
111
112    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
113        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
114            NdArrayMathOps::sub_scalar(array, rhs.elem())
115        })
116    }
117
118    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
119        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
120    }
121
122    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
123        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
124            NdArrayMathOps::mul_scalar(array, rhs.elem())
125        })
126    }
127
128    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
129        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
130    }
131
132    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
133        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
134            NdArrayMathOps::div_scalar(array, rhs.elem())
135        })
136    }
137
138    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
140    }
141
142    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
143        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
144            NdArrayMathOps::remainder_scalar(array, rhs.elem())
145        })
146    }
147
148    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
149        execute_with_float_dtype!((lhs, rhs), matmul)
150    }
151
152    fn float_cross(
153        lhs: FloatTensor<Self>,
154        rhs: FloatTensor<Self>,
155        dim: usize,
156    ) -> FloatTensor<Self> {
157        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
158    }
159
160    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
161        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
162            NdArrayMathOps::recip(array)
163        })
164    }
165
166    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
167        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
168            NdArrayOps::swap_dims(array, dim1, dim2)
169        })
170    }
171
172    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
173        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
174            NdArrayOps::reshape(array, shape)
175        })
176    }
177
178    fn float_gather(
179        dim: usize,
180        tensor: FloatTensor<Self>,
181        indices: NdArrayTensor,
182    ) -> FloatTensor<Self> {
183        execute_with_int_dtype!(
184            indices,
185            IntElem,
186            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
187                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
188                    NdArrayOps::gather(dim, array, idx_array)
189                })
190            }
191        )
192    }
193
194    fn float_scatter_add(
195        dim: usize,
196        tensor: FloatTensor<Self>,
197        indices: NdArrayTensor,
198        value: FloatTensor<Self>,
199    ) -> FloatTensor<Self> {
200        execute_with_int_dtype!(
201            indices,
202            IntElem,
203            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
204                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
205                    dim, tensor, idx_array, value
206                ))
207            }
208        )
209    }
210
211    fn float_scatter_nd(
212        data: FloatTensor<Self>,
213        indices: NdArrayTensor,
214        values: FloatTensor<Self>,
215        reduction: burn_backend::tensor::IndexingUpdateOp,
216    ) -> FloatTensor<Self> {
217        execute_with_int_dtype!(
218            indices,
219            IntElem,
220            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
221                execute_with_float_dtype!((data, values), |data, values| NdArrayOps::scatter_nd(
222                    data, idx_array, values, reduction
223                ))
224            }
225        )
226    }
227
228    fn float_gather_nd(data: FloatTensor<Self>, indices: NdArrayTensor) -> FloatTensor<Self> {
229        execute_with_int_dtype!(
230            indices,
231            IntElem,
232            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
233                execute_with_float_dtype!(data, FloatElem, |array: SharedArray<FloatElem>| {
234                    NdArrayOps::gather_nd(array, idx_array)
235                })
236            }
237        )
238    }
239
240    fn float_select(
241        tensor: FloatTensor<Self>,
242        dim: usize,
243        indices: NdArrayTensor,
244    ) -> FloatTensor<Self> {
245        execute_with_int_dtype!(
246            indices,
247            IntElem,
248            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
249                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
250                    NdArrayMathOps::select(array, dim, idx_array)
251                })
252            }
253        )
254    }
255
256    fn float_select_add(
257        tensor: FloatTensor<Self>,
258        dim: usize,
259        indices: NdArrayTensor,
260        value: FloatTensor<Self>,
261    ) -> FloatTensor<Self> {
262        execute_with_int_dtype!(
263            indices,
264            IntElem,
265            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
266                execute_with_float_dtype!((tensor, value), |tensor, value| {
267                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
268                })
269            }
270        )
271    }
272
273    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
274        slice!(tensor, slices)
275    }
276
277    fn float_slice_assign(
278        tensor: FloatTensor<Self>,
279        slices: &[burn_backend::Slice],
280        value: FloatTensor<Self>,
281    ) -> FloatTensor<Self> {
282        execute_with_float_dtype!((tensor, value), |tensor, value| {
283            NdArrayOps::slice_assign(tensor, slices, value)
284        })
285    }
286
287    fn float_mask_where(
288        tensor: FloatTensor<Self>,
289        mask: NdArrayTensor,
290        value: FloatTensor<Self>,
291    ) -> FloatTensor<Self> {
292        execute_with_float_dtype!((tensor, value), |tensor, value| {
293            NdArrayOps::mask_where(tensor, mask.bool(), value)
294        })
295    }
296
297    fn float_mask_fill(
298        tensor: FloatTensor<Self>,
299        mask: NdArrayTensor,
300        value: Scalar,
301    ) -> FloatTensor<Self> {
302        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
303            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
304        })
305    }
306
307    fn float_equal(
308        lhs: FloatTensor<Self>,
309        rhs: FloatTensor<Self>,
310        _out_dtype: BoolDType,
311    ) -> NdArrayTensor {
312        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
313    }
314
315    fn float_equal_elem(
316        lhs: FloatTensor<Self>,
317        rhs: Scalar,
318        _out_dtype: BoolDType,
319    ) -> NdArrayTensor {
320        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
321            NdArrayMathOps::equal_elem(array, rhs.elem())
322        })
323    }
324
325    fn float_greater(
326        lhs: FloatTensor<Self>,
327        rhs: FloatTensor<Self>,
328        _out_dtype: BoolDType,
329    ) -> NdArrayTensor {
330        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
331    }
332
333    fn float_greater_elem(
334        lhs: FloatTensor<Self>,
335        rhs: Scalar,
336        _out_dtype: BoolDType,
337    ) -> NdArrayTensor {
338        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
339            NdArrayMathOps::greater_elem(array, rhs.elem())
340        })
341    }
342
343    fn float_greater_equal(
344        lhs: FloatTensor<Self>,
345        rhs: FloatTensor<Self>,
346        _out_dtype: BoolDType,
347    ) -> NdArrayTensor {
348        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
349            NdArrayMathOps::greater_equal(lhs, rhs)
350        })
351    }
352
353    fn float_greater_equal_elem(
354        lhs: FloatTensor<Self>,
355        rhs: Scalar,
356        _out_dtype: BoolDType,
357    ) -> NdArrayTensor {
358        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
359            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
360        })
361    }
362
363    fn float_lower(
364        lhs: FloatTensor<Self>,
365        rhs: FloatTensor<Self>,
366        _out_dtype: BoolDType,
367    ) -> NdArrayTensor {
368        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
369    }
370
371    fn float_lower_elem(
372        lhs: FloatTensor<Self>,
373        rhs: Scalar,
374        _out_dtype: BoolDType,
375    ) -> NdArrayTensor {
376        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
377            NdArrayMathOps::lower_elem(array, rhs.elem())
378        })
379    }
380
381    fn float_lower_equal(
382        lhs: FloatTensor<Self>,
383        rhs: FloatTensor<Self>,
384        _out_dtype: BoolDType,
385    ) -> NdArrayTensor {
386        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
387            NdArrayMathOps::lower_equal(lhs, rhs)
388        })
389    }
390
391    fn float_lower_equal_elem(
392        lhs: FloatTensor<Self>,
393        rhs: Scalar,
394        _out_dtype: BoolDType,
395    ) -> NdArrayTensor {
396        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
397            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
398        })
399    }
400
401    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402        tensor
403    }
404
405    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
406        // Use view() for zero-copy on borrowed storage
407        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
408            NdArrayMathOps::mean_view(array.view())
409        })
410    }
411
412    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
413        // Use view() for zero-copy on borrowed storage
414        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
415            NdArrayMathOps::sum_view(array.view())
416        })
417    }
418
419    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
420        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
421            NdArrayMathOps::mean_dim(array, dim)
422        })
423    }
424
425    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
426        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
427            NdArrayMathOps::cumsum(array, dim)
428        })
429    }
430
431    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
432        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
433            NdArrayMathOps::cumprod(array, dim)
434        })
435    }
436
437    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
438        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
439            NdArrayMathOps::cummin(array, dim)
440        })
441    }
442
443    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
444        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
445            NdArrayMathOps::cummax(array, dim)
446        })
447    }
448
449    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
450        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
451            NdArrayMathOps::sum_dim(array, dim)
452        })
453    }
454
455    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
456        // Use view() for zero-copy on borrowed storage
457        execute_with_int_out_dtype!(out_dtype, I, {
458            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
459                NdArrayMathOps::argmax_view::<I>(array.view(), dim)
460            })
461        })
462    }
463
464    fn float_argtopk(
465        _tensor: FloatTensor<Self>,
466        _dim: usize,
467        _k: usize,
468        _out_dtype: IntDType,
469    ) -> NdArrayTensor {
470        unimplemented!("float_argtopk not implemented for ndarray")
471    }
472
473    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
474        // Use view() for zero-copy on borrowed storage
475        execute_with_int_out_dtype!(out_dtype, I, {
476            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
477                NdArrayMathOps::argmin_view::<I>(array.view(), dim)
478            })
479        })
480    }
481
482    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
483        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
484            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
485        })
486    }
487
488    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
489        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
490            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
491        })
492    }
493
494    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
495        // Use view() for zero-copy on borrowed storage
496        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
497            NdArrayMathOps::prod_view(array.view())
498        })
499    }
500
501    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
502        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
503            NdArrayMathOps::prod_dim(array, dim)
504        })
505    }
506
507    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
508        // Use view() for zero-copy on borrowed storage
509        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
510            NdArrayMathOps::max_view(array.view())
511        })
512    }
513
514    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
515        // Use view() for zero-copy on borrowed storage
516        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
517            NdArrayMathOps::min_view(array.view())
518        })
519    }
520
521    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
522        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
523            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
524        })
525    }
526
527    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
528        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
529            array
530                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
531                .into_shared()
532        })
533    }
534
535    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
536        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
537            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
538        })
539    }
540
541    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
542        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
543            NdArrayMathOps::abs(array)
544        })
545    }
546
547    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
548        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
549            array
550                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
551                .into_shared()
552        })
553    }
554
555    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
556        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
557            array
558                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
559                .into_shared()
560        })
561    }
562
563    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
564        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
565            array
566                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
567                .into_shared()
568        })
569    }
570
571    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
572        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
573            array
574                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
575                .into_shared()
576        })
577    }
578
579    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
580        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
581            array
582                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
583                .into_shared()
584        })
585    }
586
587    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
588        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
589            array
590                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
591                .into_shared()
592        })
593    }
594
595    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
596        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
597            array
598                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
599                .into_shared()
600        })
601    }
602
603    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
604        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
605            array
606                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
607                .into_shared()
608        })
609    }
610
611    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
612        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
613            array
614                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
615                .into_shared()
616        })
617    }
618
619    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
620        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
621            array
622                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
623                .into_shared()
624        })
625    }
626
627    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
628        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
629            array
630                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
631                .into_shared()
632        })
633    }
634
635    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
636        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
637            array
638                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
639                .into_shared()
640        })
641    }
642
643    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
644        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
645            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
646        })
647    }
648
649    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
650        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
651            array
652                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
653                .into_shared()
654        })
655    }
656
657    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
658        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
659            array
660                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
661                .into_shared()
662        })
663    }
664
665    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
666        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
667            array
668                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
669                .into_shared()
670        })
671    }
672
673    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
674        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
675            array
676                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
677                .into_shared()
678        })
679    }
680
681    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
682        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
683            array
684                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
685                .into_shared()
686        })
687    }
688
689    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
690        cat_with_dtype!(tensors, dim, [F64, F32])
691    }
692
693    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
694        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
695            NdArrayMathOps::clamp_min(array, min.elem())
696        })
697    }
698
699    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
700        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
701            NdArrayMathOps::clamp_max(array, max.elem())
702        })
703    }
704
705    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
706        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
707            NdArrayMathOps::clamp(array, min.elem(), max.elem())
708        })
709    }
710
711    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
712        execute_with_int_out_dtype!(out_dtype, I, {
713            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
714                array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
715            })
716        })
717    }
718
719    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
720        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
721            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
722        })
723    }
724
725    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
726        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
727            NdArrayOps::permute(array, axes)
728        })
729    }
730
731    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
732        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
733            NdArrayOps::flip(array, axes)
734        })
735    }
736
737    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
738        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
739            NdArrayMathOps::sign_op(array)
740        })
741    }
742
743    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
744        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
745            NdArrayOps::expand(array, shape)
746        })
747    }
748
749    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
750        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
751            cast_to_dtype(array, dtype.into())
752        })
753    }
754
755    fn float_grid_sample_2d(
756        tensor: FloatTensor<Self>,
757        grid: FloatTensor<Self>,
758        options: GridSampleOptions,
759    ) -> FloatTensor<Self> {
760        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
761            tensor, grid, options
762        ))
763    }
764
765    fn float_unfold(
766        tensor: FloatTensor<Self>,
767        dim: usize,
768        size: usize,
769        step: usize,
770    ) -> FloatTensor<Self> {
771        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
772            NdArrayOps::unfold(array, dim, size, step)
773        })
774    }
775}