1use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7use burn_std::{BoolDType, IntDType};
8
9use super::{
11 NdArrayMathOps, NdArrayOps,
12 matmul::{cross, matmul},
13};
14use crate::{
15 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
16};
17use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
18use crate::{
19 SharedArray,
20 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
21};
22use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
23
24use crate::rand::get_seeded_rng;
26use burn_backend::{Distribution, FloatDType, Scalar};
27use burn_backend::{ElementConversion, Shape, TensorData, ops::FloatTensorOps};
28
29#[cfg(not(feature = "std"))]
30#[allow(unused_imports)]
31use num_traits::Float;
32
33use libm::erf;
34
35#[cfg(feature = "std")]
36#[allow(dead_code)]
37fn round_ties_even_wrapper(x: f64) -> f64 {
38 x.round_ties_even()
39}
40
41#[cfg(not(feature = "std"))]
42#[allow(dead_code)]
43fn round_ties_even_wrapper(x: f64) -> f64 {
44 if (x - x.floor()) == 0.5 {
45 (x * 0.5).round() * 2.0
46 } else {
47 x.round()
48 }
49}
50
/// Float tensor operations for the ndarray backend.
///
/// Almost every method dispatches on the runtime float dtype of its input(s) through crate
/// macros (`execute_with_float_dtype!`, `execute_with_int_dtype!`, `execute_with_*_out_dtype!`,
/// `cat_with_dtype!`, `slice!`) that are defined elsewhere in the crate. In the single-tensor
/// form, the identifier passed as the macro's second argument (e.g. `FloatElem`) is the
/// element type the closure is written against; the `(lhs, rhs)` tuple form dispatches two
/// tensors together. How mismatched operand dtypes are handled is decided by those macros,
/// not by the code in this impl.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    // Builds the tensor directly from `data`; the device argument is ignored.
    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
        NdArrayTensor::from_data(data)
    }

    // Samples a tensor of `shape` from `distribution`, with the element type selected by
    // `dtype`. The RNG is taken out of the global `SEED` slot (or created with
    // `get_seeded_rng` when the slot is empty) and written back afterwards, so consecutive
    // calls continue the same random stream. The `SEED` lock guard lives until the end of
    // the function.
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &NdArrayDevice,
        dtype: FloatDType,
    ) -> FloatTensor<Self> {
        let mut seed = SEED.lock().unwrap();
        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
        // `E` is the name the macro binds to the element type chosen by `dtype`
        // (presumably shadowing the impl's `E` generic — confirm against the macro).
        let tensor = execute_with_float_out_dtype!(
            dtype,
            E,
            Self::float_from_data(
                TensorData::random::<E, _, _>(shape, distribution, &mut rng),
                device,
            )
        );

        // Store the advanced RNG back so the next call does not repeat this sequence.
        *seed = Some(rng);
        tensor
    }

    // This implementation always returns `Ok`.
    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
        Ok(tensor.into_data())
    }

    // Every tensor reports the CPU device.
    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    // No-op: the tensor is returned unchanged and `_device` is ignored.
    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
        tensor
    }

    // "Empty" tensors are zero-initialized here (delegates to `float_zeros`).
    fn float_empty(shape: Shape, device: &NdArrayDevice, dtype: FloatDType) -> FloatTensor<Self> {
        Self::float_zeros(shape, device, dtype)
    }

    // --- Arithmetic ---------------------------------------------------------------------
    // Tensor-tensor variants dispatch both operands together and delegate to
    // `NdArrayMathOps`; `*_scalar` variants convert the `Scalar` to the tensor's element
    // type with `elem()` before delegating.

    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
    }

    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::add_scalar(array, rhs.elem())
        })
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sub_scalar(array, rhs.elem())
        })
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mul_scalar(array, rhs.elem())
        })
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::div_scalar(array, rhs.elem())
        })
    }

    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
    }

    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::remainder_scalar(array, rhs.elem())
        })
    }

    // Delegates to the crate's `matmul` helper for the shared float dtype.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), matmul)
    }

    // Cross product along `dim`, delegated to the crate's `cross` helper.
    fn float_cross(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        dim: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::recip(array)
        })
    }

    // --- Layout -------------------------------------------------------------------------

    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::swap_dims(array, dim1, dim2)
        })
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::reshape(array, shape)
        })
    }

    // --- Indexing -----------------------------------------------------------------------
    // Index-based ops dispatch twice: first on the integer dtype of `indices` (binding
    // `IntElem`), then on the float dtype of the data tensor(s).

    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayOps::gather(dim, array, idx_array)
                })
            }
        )
    }

    // NOTE(review): this "add" op delegates to `NdArrayOps::scatter`; presumably that helper
    // accumulates `value` into `tensor` rather than overwriting — confirm against its
    // definition.
    fn float_scatter_add(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
                    dim, tensor, idx_array, value
                ))
            }
        )
    }

    // `reduction` is passed through unchanged to `NdArrayOps::scatter_nd`.
    fn float_scatter_nd(
        data: FloatTensor<Self>,
        indices: NdArrayTensor,
        values: FloatTensor<Self>,
        reduction: burn_backend::tensor::IndexingUpdateOp,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((data, values), |data, values| NdArrayOps::scatter_nd(
                    data, idx_array, values, reduction
                ))
            }
        )
    }

    fn float_gather_nd(data: FloatTensor<Self>, indices: NdArrayTensor) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(data, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayOps::gather_nd(array, idx_array)
                })
            }
        )
    }

    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                    NdArrayMathOps::select(array, dim, idx_array)
                })
            }
        )
    }

    // NOTE(review): delegates to `NdArrayMathOps::select_assign`; presumably that helper
    // adds `value` at the selected indices — confirm against its definition.
    fn float_select_add(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_int_dtype!(
            indices,
            IntElem,
            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
                execute_with_float_dtype!((tensor, value), |tensor, value| {
                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
                })
            }
        )
    }

    // Slicing is handled entirely by the crate's `slice!` macro.
    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
        slice!(tensor, slices)
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        slices: &[burn_backend::Slice],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::slice_assign(tensor, slices, value)
        })
    }

    // --- Masking ------------------------------------------------------------------------
    // The mask tensor is converted with `mask.bool()` before being handed to `NdArrayOps`.

    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::mask_where(tensor, mask.bool(), value)
        })
    }

    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor,
        value: Scalar,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
        })
    }

    // --- Comparisons --------------------------------------------------------------------
    // `_out_dtype` is ignored by every comparison below; the result is whatever tensor the
    // corresponding `NdArrayMathOps` comparison produces. `*_elem` variants convert the
    // scalar with `elem()` first.

    fn float_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
    }

    fn float_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::equal_elem(array, rhs.elem())
        })
    }

    fn float_greater(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
    }

    fn float_greater_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_elem(array, rhs.elem())
        })
    }

    fn float_greater_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::greater_equal(lhs, rhs)
        })
    }

    fn float_greater_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
        })
    }

    fn float_lower(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
    }

    fn float_lower_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_elem(array, rhs.elem())
        })
    }

    fn float_lower_equal(
        lhs: FloatTensor<Self>,
        rhs: FloatTensor<Self>,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
            NdArrayMathOps::lower_equal(lhs, rhs)
        })
    }

    fn float_lower_equal_elem(
        lhs: FloatTensor<Self>,
        rhs: Scalar,
        _out_dtype: BoolDType,
    ) -> NdArrayTensor {
        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
        })
    }

    // No-op: the tensor is returned unchanged.
    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        tensor
    }

    // --- Reductions ---------------------------------------------------------------------
    // Full reductions (`mean`, `sum`, `prod`, `max`, `min`) operate on a borrowed `view()`
    // of the array; `*_dim` and cumulative variants take the array itself plus `dim`.

    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_view(array.view())
        })
    }

    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_view(array.view())
        })
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::mean_dim(array, dim)
        })
    }

    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumsum(array, dim)
        })
    }

    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cumprod(array, dim)
        })
    }

    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummin(array, dim)
        })
    }

    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::cummax(array, dim)
        })
    }

    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sum_dim(array, dim)
        })
    }

    // The outer macro selects the index type from `out_dtype` and binds it as `I`
    // (presumably shadowing the impl's `I` generic — confirm against the macro).
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                NdArrayMathOps::argmax_view::<I>(array.view(), dim)
            })
        })
    }

    // Not supported by this backend: always panics via `unimplemented!`.
    fn float_argtopk(
        _tensor: FloatTensor<Self>,
        _dim: usize,
        _k: usize,
        _out_dtype: IntDType,
    ) -> NdArrayTensor {
        unimplemented!("float_argtopk not implemented for ndarray")
    }

    // Same dispatch scheme as `float_argmax`.
    fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                NdArrayMathOps::argmin_view::<I>(array.view(), dim)
            })
        })
    }

    // --- Element-wise math --------------------------------------------------------------
    // `exp`, `log`, `log1p`, `sqrt` and `powf_scalar` use the `ExpElement` `*_elem` methods
    // directly on the element type, mapped with `mapv_into`.

    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
        })
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
        })
    }

    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_view(array.view())
        })
    }

    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::prod_dim(array, dim)
        })
    }

    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::max_view(array.view())
        })
    }

    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::min_view(array.view())
        })
    }

    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
        })
    }

    // The exponent is converted with `value.elem()` inside the per-element closure.
    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
                .into_shared()
        })
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
        })
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::abs(array)
        })
    }

    // --- Trigonometric / hyperbolic -----------------------------------------------------
    // Each function below is evaluated in `f64` (via `to_f64()`) and the result converted
    // back to the tensor's element type with `elem()`.

    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
                .into_shared()
        })
    }

    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
                .into_shared()
        })
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
                .into_shared()
        })
    }

    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
                .into_shared()
        })
    }

    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
                .into_shared()
        })
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
                .into_shared()
        })
    }

    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
                .into_shared()
        })
    }

    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
                .into_shared()
        })
    }

    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
                .into_shared()
        })
    }

    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
                .into_shared()
        })
    }

    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
                .into_shared()
        })
    }

    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
                .into_shared()
        })
    }

    // Two-tensor op: uses the tuple form with an element-type binding (`FloatElem`) so the
    // element-wise closure can name its argument types; evaluated on the element type
    // itself rather than via `f64`.
    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
        })
    }

    // --- Rounding -----------------------------------------------------------------------

    // Rounds half to even via `round_ties_even_wrapper`, computed in `f64`.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
                .into_shared()
        })
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
                .into_shared()
        })
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
                .into_shared()
        })
    }

    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
                .into_shared()
        })
    }

    // Error function from `libm::erf`, evaluated in `f64`.
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            array
                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
                .into_shared()
        })
    }

    // Only F64 and F32 are listed for `cat_with_dtype!`; behaviour for any other dtype is
    // decided by that macro.
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        cat_with_dtype!(tensors, dim, [F64, F32])
    }

    // --- Clamping -----------------------------------------------------------------------
    // Bounds are converted to the element type with `elem()`.

    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_min(array, min.elem())
        })
    }

    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp_max(array, max.elem())
        })
    }

    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::clamp(array, min.elem(), max.elem())
        })
    }

    // Converts every element with `elem::<I>()`, where `I` is the integer type chosen by
    // `out_dtype`. NOTE(review): truncation/saturation behaviour comes from the
    // `ElementConversion` impl, which is not visible here — confirm if it matters.
    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
        execute_with_int_out_dtype!(out_dtype, I, {
            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
                array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
            })
        })
    }

    // Element-wise `lhs^rhs` using the element type's `powf`.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
        })
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::permute(array, axes)
        })
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::flip(array, axes)
        })
    }

    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayMathOps::sign_op(array)
        })
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::expand(array, shape)
        })
    }

    // Converts the stored array to the requested float dtype via the crate's
    // `cast_to_dtype` helper.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            cast_to_dtype(array, dtype.into())
        })
    }

    // `tensor` and `grid` are dispatched together (tuple form), so they presumably need to
    // share a float dtype — confirm against `execute_with_float_dtype!`.
    fn float_grid_sample_2d(
        tensor: FloatTensor<Self>,
        grid: FloatTensor<Self>,
        options: GridSampleOptions,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
            tensor, grid, options
        ))
    }

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
            NdArrayOps::unfold(array, dim, size, step)
        })
    }
}