burn_ndarray/ops/
tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7
8// Current crate
9use super::{
10    NdArrayMathOps, NdArrayOps,
11    matmul::{cross, matmul},
12};
13use crate::{
14    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
15};
16use crate::{NdArrayDevice, SEED};
17use crate::{
18    SharedArray,
19    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
20};
21use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
22
23// Workspace crates
24use crate::rand::get_seeded_rng;
25use burn_backend::{Distribution, FloatDType};
26use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
27
28#[cfg(not(feature = "std"))]
29#[allow(unused_imports)]
30use num_traits::Float;
31
32use libm::erf;
33
34#[cfg(feature = "std")]
35#[allow(dead_code)]
/// Rounds `x` to the nearest integer, sending exact halves to the even
/// neighbour, by forwarding to the standard library's `f64::round_ties_even`.
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}
39
/// Rounds `x` to the nearest integer, sending exact halves to the even
/// neighbour. This variant is compiled when the `std` feature is disabled.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fraction = x - x.floor();
    if fraction != 0.5 {
        // Not an exact tie (this also covers NaN and infinities): plain rounding applies.
        return x.round();
    }
    // Exact tie `n + 0.5`: halving, rounding and doubling lands on the even neighbour.
    (x * 0.5).round() * 2.0
}
49
50impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
51    for NdArray<E, I, Q>
52where
53    NdArrayTensor: From<SharedArray<E>>,
54    NdArrayTensor: From<SharedArray<I>>,
55{
56    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
57        NdArrayTensor::from_data(data)
58    }
59
60    fn float_random(
61        shape: Shape,
62        distribution: Distribution,
63        device: &NdArrayDevice,
64    ) -> FloatTensor<Self> {
65        let mut seed = SEED.lock().unwrap();
66        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
67            rng_seeded.clone()
68        } else {
69            get_seeded_rng()
70        };
71        let tensor = Self::float_from_data(
72            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
73            device,
74        );
75        *seed = Some(rng);
76        tensor
77    }
78
79    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
80        Ok(tensor.into_data())
81    }
82
83    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
84        NdArrayDevice::Cpu
85    }
86
87    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
88        tensor
89    }
90
91    fn float_empty(
92        shape: Shape,
93        device: &<NdArray<E> as Backend>::Device,
94        dtype: FloatDType,
95    ) -> FloatTensor<Self> {
96        Self::float_zeros(shape, device, dtype)
97    }
98
99    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
100        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
101    }
102
103    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
104        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
105            NdArrayMathOps::add_scalar(array, rhs.elem())
106        })
107    }
108
109    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
110        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
111    }
112
113    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
114        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
115            NdArrayMathOps::sub_scalar(array, rhs.elem())
116        })
117    }
118
119    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
120        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
121    }
122
123    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
124        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
125            NdArrayMathOps::mul_scalar(array, rhs.elem())
126        })
127    }
128
129    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
130        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
131    }
132
133    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
134        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
135            NdArrayMathOps::div_scalar(array, rhs.elem())
136        })
137    }
138
139    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
140        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
141    }
142
143    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
144        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
145            NdArrayMathOps::remainder_scalar(array, rhs.elem())
146        })
147    }
148
149    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
150        execute_with_float_dtype!((lhs, rhs), matmul)
151    }
152
153    fn float_cross(
154        lhs: FloatTensor<Self>,
155        rhs: FloatTensor<Self>,
156        dim: usize,
157    ) -> FloatTensor<Self> {
158        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
159    }
160
161    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
162        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
163    }
164
165    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
166        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
167            NdArrayMathOps::recip(array)
168        })
169    }
170
171    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
172        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
173            NdArrayOps::swap_dims(array, dim1, dim2)
174        })
175    }
176
177    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
178        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
179            NdArrayOps::reshape(array, shape)
180        })
181    }
182
183    fn float_gather(
184        dim: usize,
185        tensor: FloatTensor<Self>,
186        indices: NdArrayTensor,
187    ) -> FloatTensor<Self> {
188        execute_with_int_dtype!(
189            indices,
190            IntElem,
191            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
192                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
193                    NdArrayOps::gather(dim, array, idx_array)
194                })
195            }
196        )
197    }
198
199    fn float_scatter_add(
200        dim: usize,
201        tensor: FloatTensor<Self>,
202        indices: NdArrayTensor,
203        value: FloatTensor<Self>,
204    ) -> FloatTensor<Self> {
205        execute_with_int_dtype!(
206            indices,
207            IntElem,
208            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
209                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
210                    dim, tensor, idx_array, value
211                ))
212            }
213        )
214    }
215
216    fn float_select(
217        tensor: FloatTensor<Self>,
218        dim: usize,
219        indices: NdArrayTensor,
220    ) -> FloatTensor<Self> {
221        execute_with_int_dtype!(
222            indices,
223            IntElem,
224            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
225                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
226                    NdArrayMathOps::select(array, dim, idx_array)
227                })
228            }
229        )
230    }
231
232    fn float_select_add(
233        tensor: FloatTensor<Self>,
234        dim: usize,
235        indices: NdArrayTensor,
236        value: FloatTensor<Self>,
237    ) -> FloatTensor<Self> {
238        execute_with_int_dtype!(
239            indices,
240            IntElem,
241            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
242                execute_with_float_dtype!((tensor, value), |tensor, value| {
243                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
244                })
245            }
246        )
247    }
248
249    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
250        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
251            NdArrayOps::slice(array, slices)
252        })
253    }
254
255    fn float_slice_assign(
256        tensor: FloatTensor<Self>,
257        slices: &[burn_backend::Slice],
258        value: FloatTensor<Self>,
259    ) -> FloatTensor<Self> {
260        execute_with_float_dtype!((tensor, value), |tensor, value| {
261            NdArrayOps::slice_assign(tensor, slices, value)
262        })
263    }
264
265    fn float_mask_where(
266        tensor: FloatTensor<Self>,
267        mask: NdArrayTensor,
268        value: FloatTensor<Self>,
269    ) -> FloatTensor<Self> {
270        execute_with_float_dtype!((tensor, value), |tensor, value| {
271            NdArrayOps::mask_where(tensor, mask.bool(), value)
272        })
273    }
274
275    fn float_mask_fill(
276        tensor: FloatTensor<Self>,
277        mask: NdArrayTensor,
278        value: E,
279    ) -> FloatTensor<Self> {
280        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
281            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
282        })
283    }
284
285    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
286        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
287    }
288
289    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
290        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
291            NdArrayMathOps::equal_elem(array, rhs.elem())
292        })
293    }
294
295    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
296        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
297    }
298
299    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
300        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
301            NdArrayMathOps::greater_elem(array, rhs.elem())
302        })
303    }
304
305    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
306        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
307            NdArrayMathOps::greater_equal(lhs, rhs)
308        })
309    }
310
311    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
312        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
313            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
314        })
315    }
316
317    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
318        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
319    }
320
321    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
322        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
323            NdArrayMathOps::lower_elem(array, rhs.elem())
324        })
325    }
326
327    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
328        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
329            NdArrayMathOps::lower_equal(lhs, rhs)
330        })
331    }
332
333    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
334        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
335            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
336        })
337    }
338
339    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
340        tensor
341    }
342
343    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
344        // Use view() for zero-copy on borrowed storage
345        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
346            NdArrayMathOps::mean_view(array.view())
347        })
348    }
349
350    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
351        // Use view() for zero-copy on borrowed storage
352        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
353            NdArrayMathOps::sum_view(array.view())
354        })
355    }
356
357    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
358        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
359            NdArrayMathOps::mean_dim(array, dim)
360        })
361    }
362
363    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
364        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
365            NdArrayMathOps::cumsum(array, dim)
366        })
367    }
368
369    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
370        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
371            NdArrayMathOps::cumprod(array, dim)
372        })
373    }
374
375    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
376        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
377            NdArrayMathOps::cummin(array, dim)
378        })
379    }
380
381    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
382        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
383            NdArrayMathOps::cummax(array, dim)
384        })
385    }
386
387    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
388        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
389            NdArrayMathOps::sum_dim(array, dim)
390        })
391    }
392
393    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
394        // Use view() for zero-copy on borrowed storage
395        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
396            NdArrayMathOps::argmax_view::<I>(array.view(), dim)
397        })
398    }
399
400    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
401        // Use view() for zero-copy on borrowed storage
402        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
403            NdArrayMathOps::argmin_view::<I>(array.view(), dim)
404        })
405    }
406
407    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
408        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
409            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
410        })
411    }
412
413    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
414        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
415            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
416        })
417    }
418
419    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
420        // Use view() for zero-copy on borrowed storage
421        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
422            NdArrayMathOps::prod_view(array.view())
423        })
424    }
425
426    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
427        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
428            NdArrayMathOps::prod_dim(array, dim)
429        })
430    }
431
432    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
433        // Use view() for zero-copy on borrowed storage
434        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
435            NdArrayMathOps::max_view(array.view())
436        })
437    }
438
439    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
440        // Use view() for zero-copy on borrowed storage
441        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
442            NdArrayMathOps::min_view(array.view())
443        })
444    }
445
446    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
447        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
448            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
449        })
450    }
451
452    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
453        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
454            array
455                .mapv_into(|a: FloatElem| a.powf_elem(value))
456                .into_shared()
457        })
458    }
459
460    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
461        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
462            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
463        })
464    }
465
466    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
467        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
468            NdArrayMathOps::abs(array)
469        })
470    }
471
472    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
473        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
474            array
475                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
476                .into_shared()
477        })
478    }
479
480    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
481        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
482            array
483                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
484                .into_shared()
485        })
486    }
487
488    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
489        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
490            array
491                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
492                .into_shared()
493        })
494    }
495
496    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
497        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
498            array
499                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
500                .into_shared()
501        })
502    }
503
504    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
505        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
506            array
507                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
508                .into_shared()
509        })
510    }
511
512    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
513        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
514            array
515                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
516                .into_shared()
517        })
518    }
519
520    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
521        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
522            array
523                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
524                .into_shared()
525        })
526    }
527
528    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
529        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
530            array
531                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
532                .into_shared()
533        })
534    }
535
536    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
537        cat_with_dtype!(tensors, dim, [F64, F32])
538    }
539
540    fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
541        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
542            NdArrayMathOps::clamp_min(array, min.elem())
543        })
544    }
545
546    fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
547        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
548            NdArrayMathOps::clamp_max(array, max.elem())
549        })
550    }
551
552    fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
553        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
554            NdArrayMathOps::clamp(array, min.elem(), max.elem())
555        })
556    }
557
558    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
559        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
560            array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
561        })
562    }
563
564    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
565        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
566            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
567        })
568    }
569
570    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
571        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
572            NdArrayOps::permute(array, axes)
573        })
574    }
575
576    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
577        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
578            NdArrayOps::flip(array, axes)
579        })
580    }
581
582    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
583        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
584            NdArrayMathOps::sign_op(array)
585        })
586    }
587
588    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
589        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
590            NdArrayOps::expand(array, shape)
591        })
592    }
593
594    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
595        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
596            cast_to_dtype(array, dtype.into())
597        })
598    }
599
600    fn float_grid_sample_2d(
601        tensor: FloatTensor<Self>,
602        grid: FloatTensor<Self>,
603        options: GridSampleOptions,
604    ) -> FloatTensor<Self> {
605        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
606            tensor, grid, options
607        ))
608    }
609
610    fn float_unfold(
611        tensor: FloatTensor<Self>,
612        dim: usize,
613        size: usize,
614        step: usize,
615    ) -> FloatTensor<Self> {
616        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
617            NdArrayOps::unfold(array, dim, size, step)
618        })
619    }
620}