1use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7
8use super::{
10 NdArrayMathOps, NdArrayOps,
11 matmul::{cross, matmul},
12};
13use crate::{
14 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
15};
16use crate::{NdArrayDevice, SEED};
17use crate::{
18 SharedArray,
19 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
20};
21use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
22
23use crate::rand::get_seeded_rng;
25use burn_backend::{Distribution, FloatDType};
26use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
27
28#[cfg(not(feature = "std"))]
29#[allow(unused_imports)]
30use num_traits::Float;
31
32use libm::erf;
33
/// Rounds `x` to the nearest integer, sending exact `.5` ties to the even neighbour
/// (IEEE 754 "round half to even", a.k.a. banker's rounding).
///
/// This variant is compiled with the `std` feature and defers to `f64::round_ties_even`.
// NOTE(review): `dead_code` is allowed here presumably so build configurations that never
// call this wrapper stay warning-free — confirm whether the allow is still needed.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, sending exact `.5` ties to the even neighbour.
///
/// Fallback used when the `std` feature is disabled.
/// - A value that is not exactly halfway between two integers uses ordinary rounding.
/// - For an exact half, `x / 2` is rounded and then doubled, which lands on the even
///   neighbour (`2.5 -> 2`, `3.5 -> 4`, `-2.5 -> -2`).
///
/// NaN and infinities fall through to ordinary rounding because their fractional part is
/// never exactly `0.5`.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fractional_part = x - x.floor();
    if fractional_part != 0.5 {
        return x.round();
    }
    // Exact tie: halve, round, then double to reach the even neighbour.
    (0.5 * x).round() * 2.0
}
49
// Float tensor operations for the ndarray (CPU) backend.
//
// Type parameters:
// - `E` is the backend's default float element type.
// - `I` is its default int element type.
// - `Q` is its quantized element type.
//
// Almost every method forwards to `NdArrayOps` / `NdArrayMathOps` through
// `execute_with_float_dtype!`, and through `execute_with_int_dtype!` for index tensors.
// NOTE(review): those macros are defined elsewhere. Presumably they match on the tensor's
// runtime dtype and bind `FloatElem` / `IntElem` to the concrete element type of the matched
// arm — confirm against their definitions. That includes what happens when the two operands
// of a `(lhs, rhs)` dispatch have different float dtypes.
//
// Scalar arguments are typed `E` and are converted to the tensor's own element type with
// `.elem()` inside each closure.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// Wraps `data` in an [`NdArrayTensor`].
    ///
    /// The device argument is unused: `float_device` reports every tensor as being on
    /// `NdArrayDevice::Cpu`.
    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
        NdArrayTensor::from_data(data)
    }

    /// Samples a tensor of `shape` from `distribution`, using the element type `E`.
    ///
    /// RNG handling:
    /// - The RNG is a clone of the one stored in the global `SEED` slot, if any.
    /// - Otherwise a new RNG is obtained from `get_seeded_rng()`.
    /// - After sampling, the advanced RNG is stored back into `SEED`, so the next call
    ///   continues the stream instead of replaying the same values.
    ///
    /// The `SEED` guard is held until the function returns, i.e. across the whole sampling.
    ///
    /// # Panics
    ///
    /// Panics if `SEED.lock()` returns an error.
    // NOTE(review): when that can happen (e.g. lock poisoning) depends on the mutex type
    // behind `SEED` — confirm.
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &NdArrayDevice,
    ) -> FloatTensor<Self> {
        let mut seed = SEED.lock().unwrap();
        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
            rng_seeded.clone()
        } else {
            get_seeded_rng()
        };
        let tensor = Self::float_from_data(
            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
            device,
        );
        // Persist the advanced RNG state so the next call does not replay the same values.
        *seed = Some(rng);
        tensor
    }

    /// Converts the tensor into [`TensorData`]. Always returns `Ok` on this backend.
    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
        Ok(tensor.into_data())
    }

    /// Always reports `NdArrayDevice::Cpu`.
    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    /// Identity: the tensor is returned unchanged and the target device is ignored.
    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
        tensor
    }

    /// "Empty" tensors are created zero-filled by delegating to `float_zeros`.
    fn float_empty(
        shape: Shape,
        device: &<NdArray<E> as Backend>::Device,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        Self::float_zeros(shape, device, dtype)
    }

    /// Element-wise `lhs + rhs` via `NdArrayMathOps::add`.
    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
    }

    /// Adds the scalar `rhs` (converted to the tensor's element type) to every element.
    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::add_scalar(array, rhs.elem())
        })
    }

    /// Element-wise `lhs - rhs` via `NdArrayMathOps::sub`.
    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
    }

    /// Subtracts the scalar `rhs` (converted to the tensor's element type) from every element.
    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sub_scalar(array, rhs.elem())
        })
    }

    /// Element-wise `lhs * rhs` via `NdArrayMathOps::mul`.
    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
    }

    /// Multiplies every element by the scalar `rhs` (converted to the tensor's element type).
    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mul_scalar(array, rhs.elem())
        })
    }

    /// Element-wise `lhs / rhs` via `NdArrayMathOps::div`.
    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
    }

    /// Divides every element by the scalar `rhs` (converted to the tensor's element type).
    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::div_scalar(array, rhs.elem())
        })
    }

    /// Element-wise remainder via `NdArrayMathOps::remainder`.
    // NOTE(review): the sign convention (truncated vs floored) is defined by that helper —
    // confirm.
    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
    }

    /// Remainder of every element by the scalar `rhs`, via `NdArrayMathOps::remainder_scalar`.
    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::remainder_scalar(array, rhs.elem())
        })
    }

    /// Matrix multiplication via the `matmul` helper from the sibling `matmul` module.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), matmul)
    }

    /// Cross product of `lhs` and `rhs` along `dim`, via the `cross` helper.
    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
    }

    /// Negation, implemented as multiplication by the scalar `-1`.
    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
    }

    /// Element-wise reciprocal via `NdArrayMathOps::recip`.
    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::recip(array)
        })
    }

    /// Swaps axes `dim1` and `dim2`.
    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::swap_dims(array, dim1, dim2)
        })
    }

    /// Reshapes the tensor to `shape`.
    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::reshape(array, shape)
        })
    }

    /// Gathers values of `tensor` along `dim` at `indices`.
    ///
    /// The dispatch is nested: first on the int dtype of `indices`, then on the float dtype
    /// of `tensor`.
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayOps::gather(dim, array, idx_array)
                })
            }
        )
    }

    /// Scatter-add of `value` into `tensor` at `indices` along `dim`.
    ///
    /// Implemented with `NdArrayOps::scatter`, with the same nested dispatch (int dtype
    /// first, then float dtype) as `float_gather`.
    // NOTE(review): presumably `NdArrayOps::scatter` accumulates rather than overwrites,
    // since it backs scatter-add here — confirm.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
                    dim, tensor, idx_array, value
                ))
            }
        )
    }

    /// Selects entries of `tensor` along `dim` at `indices`, via `NdArrayMathOps::select`.
    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayMathOps::select(array, dim, idx_array)
                })
            }
        )
    }

    /// Adds `value` into `tensor` at `indices` along `dim`, via
    /// `NdArrayMathOps::select_assign`.
    // NOTE(review): presumably `select_assign` accumulates, since it backs select-add
    // here — confirm.
    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| {
                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
                })
            }
        )
    }

    /// Returns the sub-tensor described by `slices`.
    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::slice(array, slices)
        })
    }

    /// Writes `value` into the region of `tensor` described by `slices`.
    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        slices: &[burn_backend::Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::slice_assign(tensor, slices, value)
        })
    }

    /// Combines `tensor` and `value` under `mask` via `NdArrayOps::mask_where`.
    ///
    /// `mask.bool()` is called to obtain the mask in the form that helper expects.
    // NOTE(review): presumably `value` is taken where the mask is true — confirm in
    // `NdArrayOps::mask_where`.
    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::mask_where(tensor, mask.bool(), value)
        })
    }

    /// Fills the positions selected by `mask` with `value`, converted to the tensor's element
    /// type.
    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: E,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
        })
    }

    /// Element-wise `lhs == rhs`; returns a comparison-result tensor (presumably boolean).
    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
    }

    /// Element-wise comparison of every element with the scalar `rhs` for equality.
    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::equal_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs > rhs`.
    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
    }

    /// Element-wise `lhs > rhs` against the scalar `rhs`.
    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs >= rhs`.
    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::greater_equal(lhs, rhs)
        })
    }

    /// Element-wise `lhs >= rhs` against the scalar `rhs`.
    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs < rhs`.
    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
    }

    /// Element-wise `lhs < rhs` against the scalar `rhs`.
    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_elem(array, rhs.elem())
        })
    }

    /// Element-wise `lhs <= rhs`.
    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::lower_equal(lhs, rhs)
        })
    }

    /// Element-wise `lhs <= rhs` against the scalar `rhs`.
    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
        })
    }

    /// Identity: the tensor is returned unchanged.
    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        tensor
    }

    /// Mean over all elements, computed from a borrowed `view()` of the array.
    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_view(array.view())
        })
    }

    /// Sum over all elements, computed from a borrowed `view()` of the array.
    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_view(array.view())
        })
    }

    /// Mean along `dim`.
    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_dim(array, dim)
        })
    }

    /// Cumulative sum along `dim`.
    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumsum(array, dim)
        })
    }

    /// Cumulative product along `dim`.
    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumprod(array, dim)
        })
    }

    /// Cumulative minimum along `dim`.
    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummin(array, dim)
        })
    }

    /// Cumulative maximum along `dim`.
    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummax(array, dim)
        })
    }

    /// Sum along `dim`.
    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_dim(array, dim)
        })
    }

    /// Index of the maximum along `dim`.
    ///
    /// Indices are produced as the backend int element `I`, computed from a borrowed view.
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::argmax_view::<I>(array.view(), dim)
        })
    }

    /// Index of the minimum along `dim`.
    ///
    /// Indices are produced as the backend int element `I`, computed from a borrowed view.
    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::argmin_view::<I>(array.view(), dim)
        })
    }

    /// Element-wise `e^x` via `exp_elem`.
    ///
    /// Uses `mapv_into`, which maps in place when the array storage allows.
    // NOTE(review): `exp_elem` presumably comes from `ExpElement` — confirm.
    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
        })
    }

    /// Element-wise natural logarithm via `log_elem`.
    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
        })
    }

    /// Product of all elements, computed from a borrowed `view()` of the array.
    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_view(array.view())
        })
    }

    /// Product along `dim`.
    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_dim(array, dim)
        })
    }

    /// Maximum over all elements, computed from a borrowed `view()` of the array.
    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::max_view(array.view())
        })
    }

    /// Minimum over all elements, computed from a borrowed `view()` of the array.
    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::min_view(array.view())
        })
    }

    /// Element-wise `ln(1 + x)` via `log1p_elem`.
    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
        })
    }

    /// Raises every element to the `f32` power `value` via `powf_elem`.
    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| a.powf_elem(value))
                .into_shared()
        })
    }

    /// Element-wise square root via `sqrt_elem`.
    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
        })
    }

    /// Element-wise absolute value.
    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::abs(array)
        })
    }

    /// Element-wise cosine, computed in `f64` and converted back to the element type.
    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
                .into_shared()
        })
    }

    /// Element-wise sine, computed in `f64` and converted back to the element type.
    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
                .into_shared()
        })
    }

    /// Element-wise hyperbolic tangent, computed in `f64` and converted back.
    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
                .into_shared()
        })
    }

    /// Element-wise rounding with ties to even (`2.5 -> 2`, `3.5 -> 4`).
    ///
    /// Computed in `f64` through `round_ties_even_wrapper`.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
                .into_shared()
        })
    }

    /// Element-wise floor, computed in `f64` and converted back.
    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
                .into_shared()
        })
    }

    /// Element-wise ceiling, computed in `f64` and converted back.
    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
                .into_shared()
        })
    }

    /// Element-wise truncation toward zero, computed in `f64` and converted back.
    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
                .into_shared()
        })
    }

    /// Element-wise error function via `libm::erf`, computed in `f64` and converted back.
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
                .into_shared()
        })
    }

    /// Concatenates `tensors` along `dim`.
    ///
    /// Only the `F64` and `F32` dtypes are listed for `cat_with_dtype!`.
    // NOTE(review): behavior for other float dtypes is decided by that macro — confirm.
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        cat_with_dtype!(tensors, dim, [F64, F32])
    }

    /// Clamps every element from below at `min`, converted to the element type.
    fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_min(array, min.elem())
        })
    }

    /// Clamps every element from above at `max`, converted to the element type.
    fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_max(array, max.elem())
        })
    }

    /// Clamps every element into `[min, max]`, with both bounds converted to the element type.
    fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp(array, min.elem(), max.elem())
        })
    }

    /// Converts every element to the backend int element `I` via `.elem::<I>()`.
    ///
    /// Uses `mapv` rather than `mapv_into` because the element type changes, so a new
    /// array is built.
    // NOTE(review): rounding/saturation behavior is defined by the `elem` conversion —
    // confirm.
    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
        })
    }

    /// Element-wise `lhs ^ rhs`, using `powf` inside `NdArrayMathOps::elementwise_op`.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
        })
    }

    /// Reorders the axes according to `axes`.
    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::permute(array, axes)
        })
    }

    /// Reverses the tensor along each axis listed in `axes`.
    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::flip(array, axes)
        })
    }

    /// Element-wise sign via `NdArrayMathOps::sign_op`.
    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sign_op(array)
        })
    }

    /// Broadcasts the tensor to `shape`.
    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::expand(array, shape)
        })
    }

    /// Converts the tensor's elements to the float `dtype` via `cast_to_dtype`.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            cast_to_dtype(array, dtype.into())
        })
    }

    /// 2D grid sampling of `tensor` at the coordinates in `grid`, configured by `options`.
    ///
    /// Both tensors are dispatched together.
    fn float_grid_sample_2d(
        tensor: FloatTensor<Self>,
        grid: FloatTensor<Self>,
        options: GridSampleOptions,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
            tensor, grid, options
        ))
    }

    /// Extracts windows along `dim` via `NdArrayOps::unfold`.
    // NOTE(review): presumably windows of length `size` taken every `step` elements —
    // confirm against that helper.
    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::unfold(array, dim, size, step)
        })
    }
}