1use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7
8use super::{
10 NdArrayMathOps, NdArrayOps,
11 matmul::{cross, matmul},
12};
13use crate::{
14 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
15};
16use crate::{NdArrayDevice, SEED, slice};
17use crate::{
18 SharedArray,
19 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
20};
21use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
22
23use crate::rand::get_seeded_rng;
25use burn_backend::{Distribution, FloatDType, Scalar};
26use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
27
28#[cfg(not(feature = "std"))]
29#[allow(unused_imports)]
30use num_traits::Float;
31
32use libm::erf;
33
#[cfg(feature = "std")]
#[allow(dead_code)]
/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour
/// ("banker's rounding"). With `std` available this is a thin forward to the inherent
/// `f64::round_ties_even`.
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}
39
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour.
///
/// `f64::round_ties_even` is unavailable without `std`. Plain `round` breaks ties away
/// from zero, which is only wrong when the fractional part is exactly one half, so
/// that case is handled separately.
fn round_ties_even_wrapper(x: f64) -> f64 {
    // `x - floor(x)` is exact for finite doubles, so this comparison reliably detects a tie.
    let is_tie = x - x.floor() == 0.5;
    if !is_tie {
        return x.round();
    }
    // Halving, rounding and doubling lands a tie on the even neighbour.
    2.0 * (0.5 * x).round()
}
49
50impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
51 for NdArray<E, I, Q>
52where
53 NdArrayTensor: From<SharedArray<E>>,
54 NdArrayTensor: From<SharedArray<I>>,
55{
56 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
57 NdArrayTensor::from_data(data)
58 }
59
60 fn float_random(
61 shape: Shape,
62 distribution: Distribution,
63 device: &NdArrayDevice,
64 ) -> FloatTensor<Self> {
65 let mut seed = SEED.lock().unwrap();
66 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
67 rng_seeded.clone()
68 } else {
69 get_seeded_rng()
70 };
71 let tensor = Self::float_from_data(
72 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
73 device,
74 );
75 *seed = Some(rng);
76 tensor
77 }
78
79 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
80 Ok(tensor.into_data())
81 }
82
83 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
84 NdArrayDevice::Cpu
85 }
86
87 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
88 tensor
89 }
90
91 fn float_empty(
92 shape: Shape,
93 device: &<NdArray<E> as Backend>::Device,
94 dtype: FloatDType,
95 ) -> FloatTensor<Self> {
96 Self::float_zeros(shape, device, dtype)
97 }
98
99 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
100 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
101 }
102
103 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
104 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
105 NdArrayMathOps::add_scalar(array, rhs.elem())
106 })
107 }
108
109 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
110 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
111 }
112
113 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
114 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
115 NdArrayMathOps::sub_scalar(array, rhs.elem())
116 })
117 }
118
119 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
120 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
121 }
122
123 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
124 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
125 NdArrayMathOps::mul_scalar(array, rhs.elem())
126 })
127 }
128
129 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
130 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
131 }
132
133 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
134 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
135 NdArrayMathOps::div_scalar(array, rhs.elem())
136 })
137 }
138
139 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
140 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
141 }
142
143 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
144 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
145 NdArrayMathOps::remainder_scalar(array, rhs.elem())
146 })
147 }
148
149 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
150 execute_with_float_dtype!((lhs, rhs), matmul)
151 }
152
153 fn float_cross(
154 lhs: FloatTensor<Self>,
155 rhs: FloatTensor<Self>,
156 dim: usize,
157 ) -> FloatTensor<Self> {
158 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
159 }
160
161 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
162 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
163 NdArrayMathOps::recip(array)
164 })
165 }
166
167 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
168 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
169 NdArrayOps::swap_dims(array, dim1, dim2)
170 })
171 }
172
173 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
174 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
175 NdArrayOps::reshape(array, shape)
176 })
177 }
178
179 fn float_gather(
180 dim: usize,
181 tensor: FloatTensor<Self>,
182 indices: NdArrayTensor,
183 ) -> FloatTensor<Self> {
184 execute_with_int_dtype!(
185 indices,
186 IntElem,
187 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
188 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
189 NdArrayOps::gather(dim, array, idx_array)
190 })
191 }
192 )
193 }
194
195 fn float_scatter_add(
196 dim: usize,
197 tensor: FloatTensor<Self>,
198 indices: NdArrayTensor,
199 value: FloatTensor<Self>,
200 ) -> FloatTensor<Self> {
201 execute_with_int_dtype!(
202 indices,
203 IntElem,
204 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
205 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
206 dim, tensor, idx_array, value
207 ))
208 }
209 )
210 }
211
212 fn float_select(
213 tensor: FloatTensor<Self>,
214 dim: usize,
215 indices: NdArrayTensor,
216 ) -> FloatTensor<Self> {
217 execute_with_int_dtype!(
218 indices,
219 IntElem,
220 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
221 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
222 NdArrayMathOps::select(array, dim, idx_array)
223 })
224 }
225 )
226 }
227
228 fn float_select_add(
229 tensor: FloatTensor<Self>,
230 dim: usize,
231 indices: NdArrayTensor,
232 value: FloatTensor<Self>,
233 ) -> FloatTensor<Self> {
234 execute_with_int_dtype!(
235 indices,
236 IntElem,
237 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
238 execute_with_float_dtype!((tensor, value), |tensor, value| {
239 NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
240 })
241 }
242 )
243 }
244
245 fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
246 slice!(tensor, slices)
247 }
248
249 fn float_slice_assign(
250 tensor: FloatTensor<Self>,
251 slices: &[burn_backend::Slice],
252 value: FloatTensor<Self>,
253 ) -> FloatTensor<Self> {
254 execute_with_float_dtype!((tensor, value), |tensor, value| {
255 NdArrayOps::slice_assign(tensor, slices, value)
256 })
257 }
258
259 fn float_mask_where(
260 tensor: FloatTensor<Self>,
261 mask: NdArrayTensor,
262 value: FloatTensor<Self>,
263 ) -> FloatTensor<Self> {
264 execute_with_float_dtype!((tensor, value), |tensor, value| {
265 NdArrayOps::mask_where(tensor, mask.bool(), value)
266 })
267 }
268
269 fn float_mask_fill(
270 tensor: FloatTensor<Self>,
271 mask: NdArrayTensor,
272 value: Scalar,
273 ) -> FloatTensor<Self> {
274 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
275 NdArrayOps::mask_fill(array, mask.bool(), value.elem())
276 })
277 }
278
279 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
280 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
281 }
282
283 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
284 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
285 NdArrayMathOps::equal_elem(array, rhs.elem())
286 })
287 }
288
289 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
290 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
291 }
292
293 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
294 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
295 NdArrayMathOps::greater_elem(array, rhs.elem())
296 })
297 }
298
299 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
300 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
301 NdArrayMathOps::greater_equal(lhs, rhs)
302 })
303 }
304
305 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
306 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
307 NdArrayMathOps::greater_equal_elem(array, rhs.elem())
308 })
309 }
310
311 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
312 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
313 }
314
315 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
316 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
317 NdArrayMathOps::lower_elem(array, rhs.elem())
318 })
319 }
320
321 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
322 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
323 NdArrayMathOps::lower_equal(lhs, rhs)
324 })
325 }
326
327 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
328 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
329 NdArrayMathOps::lower_equal_elem(array, rhs.elem())
330 })
331 }
332
333 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
334 tensor
335 }
336
337 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
338 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
340 NdArrayMathOps::mean_view(array.view())
341 })
342 }
343
344 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
345 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
347 NdArrayMathOps::sum_view(array.view())
348 })
349 }
350
351 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
352 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
353 NdArrayMathOps::mean_dim(array, dim)
354 })
355 }
356
357 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
358 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
359 NdArrayMathOps::cumsum(array, dim)
360 })
361 }
362
363 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
364 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
365 NdArrayMathOps::cumprod(array, dim)
366 })
367 }
368
369 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
370 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
371 NdArrayMathOps::cummin(array, dim)
372 })
373 }
374
375 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
376 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
377 NdArrayMathOps::cummax(array, dim)
378 })
379 }
380
381 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
382 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
383 NdArrayMathOps::sum_dim(array, dim)
384 })
385 }
386
387 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
388 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
390 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
391 })
392 }
393
394 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
395 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
397 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
398 })
399 }
400
401 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
403 array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
404 })
405 }
406
407 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
408 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
409 array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
410 })
411 }
412
413 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
414 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
416 NdArrayMathOps::prod_view(array.view())
417 })
418 }
419
420 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
421 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
422 NdArrayMathOps::prod_dim(array, dim)
423 })
424 }
425
426 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
427 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
429 NdArrayMathOps::max_view(array.view())
430 })
431 }
432
433 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
434 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
436 NdArrayMathOps::min_view(array.view())
437 })
438 }
439
440 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
441 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
442 array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
443 })
444 }
445
446 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
447 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
448 array
449 .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
450 .into_shared()
451 })
452 }
453
454 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
456 array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
457 })
458 }
459
460 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
461 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
462 NdArrayMathOps::abs(array)
463 })
464 }
465
466 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
467 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
468 array
469 .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
470 .into_shared()
471 })
472 }
473
474 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
475 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
476 array
477 .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
478 .into_shared()
479 })
480 }
481
482 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
483 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
484 array
485 .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
486 .into_shared()
487 })
488 }
489
490 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
491 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
492 array
493 .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
494 .into_shared()
495 })
496 }
497
498 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
500 array
501 .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
502 .into_shared()
503 })
504 }
505
506 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
507 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
508 array
509 .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
510 .into_shared()
511 })
512 }
513
514 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
515 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
516 array
517 .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
518 .into_shared()
519 })
520 }
521
522 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
523 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
524 array
525 .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
526 .into_shared()
527 })
528 }
529
530 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
531 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
532 array
533 .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
534 .into_shared()
535 })
536 }
537
538 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
539 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
540 array
541 .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
542 .into_shared()
543 })
544 }
545
546 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
547 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
548 array
549 .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
550 .into_shared()
551 })
552 }
553
554 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
555 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
556 array
557 .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
558 .into_shared()
559 })
560 }
561
562 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
563 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
564 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
565 })
566 }
567
568 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
569 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
570 array
571 .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
572 .into_shared()
573 })
574 }
575
576 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
577 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
578 array
579 .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
580 .into_shared()
581 })
582 }
583
584 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
585 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
586 array
587 .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
588 .into_shared()
589 })
590 }
591
592 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
593 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
594 array
595 .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
596 .into_shared()
597 })
598 }
599
600 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
601 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
602 array
603 .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
604 .into_shared()
605 })
606 }
607
608 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
609 cat_with_dtype!(tensors, dim, [F64, F32])
610 }
611
612 fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
613 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
614 NdArrayMathOps::clamp_min(array, min.elem())
615 })
616 }
617
618 fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
619 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
620 NdArrayMathOps::clamp_max(array, max.elem())
621 })
622 }
623
624 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
625 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
626 NdArrayMathOps::clamp(array, min.elem(), max.elem())
627 })
628 }
629
630 fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
631 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
632 array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
633 })
634 }
635
636 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
637 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
638 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
639 })
640 }
641
642 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
643 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
644 NdArrayOps::permute(array, axes)
645 })
646 }
647
648 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
649 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
650 NdArrayOps::flip(array, axes)
651 })
652 }
653
654 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
655 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
656 NdArrayMathOps::sign_op(array)
657 })
658 }
659
660 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
661 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
662 NdArrayOps::expand(array, shape)
663 })
664 }
665
666 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
667 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
668 cast_to_dtype(array, dtype.into())
669 })
670 }
671
672 fn float_grid_sample_2d(
673 tensor: FloatTensor<Self>,
674 grid: FloatTensor<Self>,
675 options: GridSampleOptions,
676 ) -> FloatTensor<Self> {
677 execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
678 tensor, grid, options
679 ))
680 }
681
682 fn float_unfold(
683 tensor: FloatTensor<Self>,
684 dim: usize,
685 size: usize,
686 step: usize,
687 ) -> FloatTensor<Self> {
688 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
689 NdArrayOps::unfold(array, dim, size, step)
690 })
691 }
692}