1use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7use burn_std::{BoolDType, IntDType};
8
9use super::{
11 NdArrayMathOps, NdArrayOps,
12 matmul::{cross, matmul},
13};
14use crate::{
15 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
16};
17use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
18use crate::{
19 SharedArray,
20 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
21};
22use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
23
24use crate::rand::get_seeded_rng;
26use burn_backend::{Distribution, FloatDType, Scalar};
27use burn_backend::{ElementConversion, Shape, TensorData, ops::FloatTensorOps};
28
29#[cfg(not(feature = "std"))]
30#[allow(unused_imports)]
31use num_traits::Float;
32
33use libm::erf;
34
/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour.
///
/// `std` variant: delegates to [`f64::round_ties_even`].
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour.
///
/// Fallback variant used when the `std` feature is disabled; it only relies on `floor` and
/// `round`.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    // Distance from the next lower integer; exactly 0.5 only for a genuine tie.
    let fractional = x - x.floor();

    // Anything that is not a tie (including NaN, for which the comparison is true) takes the
    // plain `round` path.
    if fractional != 0.5 {
        return x.round();
    }

    // Halve, round, then double: the result is always an even integer, and for a tie it is
    // the even neighbour of `x`.
    (x * 0.5).round() * 2.0
}
50
51impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
52 for NdArray<E, I, Q>
53where
54 NdArrayTensor: From<SharedArray<E>>,
55 NdArrayTensor: From<SharedArray<I>>,
56{
57 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
58 NdArrayTensor::from_data(data)
59 }
60
61 fn float_random(
62 shape: Shape,
63 distribution: Distribution,
64 device: &NdArrayDevice,
65 dtype: FloatDType,
66 ) -> FloatTensor<Self> {
67 let mut seed = SEED.lock().unwrap();
68 let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
69 let tensor = execute_with_float_out_dtype!(
70 dtype,
71 E,
72 Self::float_from_data(
73 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
74 device,
75 )
76 );
77
78 *seed = Some(rng);
79 tensor
80 }
81
82 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
83 Ok(tensor.into_data())
84 }
85
86 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
87 NdArrayDevice::Cpu
88 }
89
90 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
91 tensor
92 }
93
94 fn float_empty(shape: Shape, device: &NdArrayDevice, dtype: FloatDType) -> FloatTensor<Self> {
95 Self::float_zeros(shape, device, dtype)
96 }
97
98 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
99 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
100 }
101
102 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
103 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
104 NdArrayMathOps::add_scalar(array, rhs.elem())
105 })
106 }
107
108 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
109 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
110 }
111
112 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
113 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
114 NdArrayMathOps::sub_scalar(array, rhs.elem())
115 })
116 }
117
118 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
119 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
120 }
121
122 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
123 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
124 NdArrayMathOps::mul_scalar(array, rhs.elem())
125 })
126 }
127
128 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
129 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
130 }
131
132 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
133 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
134 NdArrayMathOps::div_scalar(array, rhs.elem())
135 })
136 }
137
138 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
140 }
141
142 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
143 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
144 NdArrayMathOps::remainder_scalar(array, rhs.elem())
145 })
146 }
147
148 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
149 execute_with_float_dtype!((lhs, rhs), matmul)
150 }
151
152 fn float_cross(
153 lhs: FloatTensor<Self>,
154 rhs: FloatTensor<Self>,
155 dim: usize,
156 ) -> FloatTensor<Self> {
157 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
158 }
159
160 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
161 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
162 NdArrayMathOps::recip(array)
163 })
164 }
165
166 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
167 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
168 NdArrayOps::swap_dims(array, dim1, dim2)
169 })
170 }
171
172 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
173 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
174 NdArrayOps::reshape(array, shape)
175 })
176 }
177
178 fn float_gather(
179 dim: usize,
180 tensor: FloatTensor<Self>,
181 indices: NdArrayTensor,
182 ) -> FloatTensor<Self> {
183 execute_with_int_dtype!(
184 indices,
185 IntElem,
186 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
187 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
188 NdArrayOps::gather(dim, array, idx_array)
189 })
190 }
191 )
192 }
193
194 fn float_scatter_add(
195 dim: usize,
196 tensor: FloatTensor<Self>,
197 indices: NdArrayTensor,
198 value: FloatTensor<Self>,
199 ) -> FloatTensor<Self> {
200 execute_with_int_dtype!(
201 indices,
202 IntElem,
203 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
204 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
205 dim, tensor, idx_array, value
206 ))
207 }
208 )
209 }
210
211 fn float_scatter_nd(
212 data: FloatTensor<Self>,
213 indices: NdArrayTensor,
214 values: FloatTensor<Self>,
215 reduction: burn_backend::tensor::IndexingUpdateOp,
216 ) -> FloatTensor<Self> {
217 execute_with_int_dtype!(
218 indices,
219 IntElem,
220 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
221 execute_with_float_dtype!((data, values), |data, values| NdArrayOps::scatter_nd(
222 data, idx_array, values, reduction
223 ))
224 }
225 )
226 }
227
228 fn float_gather_nd(data: FloatTensor<Self>, indices: NdArrayTensor) -> FloatTensor<Self> {
229 execute_with_int_dtype!(
230 indices,
231 IntElem,
232 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
233 execute_with_float_dtype!(data, FloatElem, |array: SharedArray<FloatElem>| {
234 NdArrayOps::gather_nd(array, idx_array)
235 })
236 }
237 )
238 }
239
240 fn float_select(
241 tensor: FloatTensor<Self>,
242 dim: usize,
243 indices: NdArrayTensor,
244 ) -> FloatTensor<Self> {
245 execute_with_int_dtype!(
246 indices,
247 IntElem,
248 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
249 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
250 NdArrayMathOps::select(array, dim, idx_array)
251 })
252 }
253 )
254 }
255
256 fn float_select_add(
257 tensor: FloatTensor<Self>,
258 dim: usize,
259 indices: NdArrayTensor,
260 value: FloatTensor<Self>,
261 ) -> FloatTensor<Self> {
262 execute_with_int_dtype!(
263 indices,
264 IntElem,
265 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
266 execute_with_float_dtype!((tensor, value), |tensor, value| {
267 NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
268 })
269 }
270 )
271 }
272
273 fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
274 slice!(tensor, slices)
275 }
276
277 fn float_slice_assign(
278 tensor: FloatTensor<Self>,
279 slices: &[burn_backend::Slice],
280 value: FloatTensor<Self>,
281 ) -> FloatTensor<Self> {
282 execute_with_float_dtype!((tensor, value), |tensor, value| {
283 NdArrayOps::slice_assign(tensor, slices, value)
284 })
285 }
286
287 fn float_mask_where(
288 tensor: FloatTensor<Self>,
289 mask: NdArrayTensor,
290 value: FloatTensor<Self>,
291 ) -> FloatTensor<Self> {
292 execute_with_float_dtype!((tensor, value), |tensor, value| {
293 NdArrayOps::mask_where(tensor, mask.bool(), value)
294 })
295 }
296
297 fn float_mask_fill(
298 tensor: FloatTensor<Self>,
299 mask: NdArrayTensor,
300 value: Scalar,
301 ) -> FloatTensor<Self> {
302 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
303 NdArrayOps::mask_fill(array, mask.bool(), value.elem())
304 })
305 }
306
307 fn float_equal(
308 lhs: FloatTensor<Self>,
309 rhs: FloatTensor<Self>,
310 _out_dtype: BoolDType,
311 ) -> NdArrayTensor {
312 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
313 }
314
315 fn float_equal_elem(
316 lhs: FloatTensor<Self>,
317 rhs: Scalar,
318 _out_dtype: BoolDType,
319 ) -> NdArrayTensor {
320 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
321 NdArrayMathOps::equal_elem(array, rhs.elem())
322 })
323 }
324
325 fn float_greater(
326 lhs: FloatTensor<Self>,
327 rhs: FloatTensor<Self>,
328 _out_dtype: BoolDType,
329 ) -> NdArrayTensor {
330 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
331 }
332
333 fn float_greater_elem(
334 lhs: FloatTensor<Self>,
335 rhs: Scalar,
336 _out_dtype: BoolDType,
337 ) -> NdArrayTensor {
338 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
339 NdArrayMathOps::greater_elem(array, rhs.elem())
340 })
341 }
342
343 fn float_greater_equal(
344 lhs: FloatTensor<Self>,
345 rhs: FloatTensor<Self>,
346 _out_dtype: BoolDType,
347 ) -> NdArrayTensor {
348 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
349 NdArrayMathOps::greater_equal(lhs, rhs)
350 })
351 }
352
353 fn float_greater_equal_elem(
354 lhs: FloatTensor<Self>,
355 rhs: Scalar,
356 _out_dtype: BoolDType,
357 ) -> NdArrayTensor {
358 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
359 NdArrayMathOps::greater_equal_elem(array, rhs.elem())
360 })
361 }
362
363 fn float_lower(
364 lhs: FloatTensor<Self>,
365 rhs: FloatTensor<Self>,
366 _out_dtype: BoolDType,
367 ) -> NdArrayTensor {
368 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
369 }
370
371 fn float_lower_elem(
372 lhs: FloatTensor<Self>,
373 rhs: Scalar,
374 _out_dtype: BoolDType,
375 ) -> NdArrayTensor {
376 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
377 NdArrayMathOps::lower_elem(array, rhs.elem())
378 })
379 }
380
381 fn float_lower_equal(
382 lhs: FloatTensor<Self>,
383 rhs: FloatTensor<Self>,
384 _out_dtype: BoolDType,
385 ) -> NdArrayTensor {
386 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
387 NdArrayMathOps::lower_equal(lhs, rhs)
388 })
389 }
390
391 fn float_lower_equal_elem(
392 lhs: FloatTensor<Self>,
393 rhs: Scalar,
394 _out_dtype: BoolDType,
395 ) -> NdArrayTensor {
396 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
397 NdArrayMathOps::lower_equal_elem(array, rhs.elem())
398 })
399 }
400
401 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402 tensor
403 }
404
405 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
406 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
408 NdArrayMathOps::mean_view(array.view())
409 })
410 }
411
412 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
413 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
415 NdArrayMathOps::sum_view(array.view())
416 })
417 }
418
419 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
420 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
421 NdArrayMathOps::mean_dim(array, dim)
422 })
423 }
424
425 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
426 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
427 NdArrayMathOps::cumsum(array, dim)
428 })
429 }
430
431 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
432 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
433 NdArrayMathOps::cumprod(array, dim)
434 })
435 }
436
437 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
438 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
439 NdArrayMathOps::cummin(array, dim)
440 })
441 }
442
443 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
444 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
445 NdArrayMathOps::cummax(array, dim)
446 })
447 }
448
449 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
450 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
451 NdArrayMathOps::sum_dim(array, dim)
452 })
453 }
454
455 fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
456 execute_with_int_out_dtype!(out_dtype, I, {
458 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
459 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
460 })
461 })
462 }
463
464 fn float_argtopk(
465 _tensor: FloatTensor<Self>,
466 _dim: usize,
467 _k: usize,
468 _out_dtype: IntDType,
469 ) -> NdArrayTensor {
470 unimplemented!("float_argtopk not implemented for ndarray")
471 }
472
473 fn float_topk(_tensor: FloatTensor<Self>, _dim: usize, _k: usize) -> NdArrayTensor {
474 unimplemented!("float_topk not implemented for ndarray")
475 }
476
477 fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
478 execute_with_int_out_dtype!(out_dtype, I, {
480 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
481 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
482 })
483 })
484 }
485
486 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
487 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
488 array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
489 })
490 }
491
492 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
493 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
494 array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
495 })
496 }
497
498 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
499 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
501 NdArrayMathOps::prod_view(array.view())
502 })
503 }
504
505 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
506 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
507 NdArrayMathOps::prod_dim(array, dim)
508 })
509 }
510
511 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
512 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
514 NdArrayMathOps::max_view(array.view())
515 })
516 }
517
518 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
519 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
521 NdArrayMathOps::min_view(array.view())
522 })
523 }
524
525 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
526 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
527 array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
528 })
529 }
530
531 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
532 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
533 array
534 .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
535 .into_shared()
536 })
537 }
538
539 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
540 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
541 array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
542 })
543 }
544
545 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
546 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
547 NdArrayMathOps::abs(array)
548 })
549 }
550
551 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
552 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
553 array
554 .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
555 .into_shared()
556 })
557 }
558
559 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
560 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
561 array
562 .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
563 .into_shared()
564 })
565 }
566
567 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
568 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
569 array
570 .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
571 .into_shared()
572 })
573 }
574
575 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
576 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
577 array
578 .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
579 .into_shared()
580 })
581 }
582
583 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
584 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
585 array
586 .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
587 .into_shared()
588 })
589 }
590
591 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
592 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
593 array
594 .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
595 .into_shared()
596 })
597 }
598
599 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
600 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
601 array
602 .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
603 .into_shared()
604 })
605 }
606
607 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
608 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
609 array
610 .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
611 .into_shared()
612 })
613 }
614
615 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
616 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
617 array
618 .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
619 .into_shared()
620 })
621 }
622
623 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
624 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
625 array
626 .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
627 .into_shared()
628 })
629 }
630
631 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
632 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
633 array
634 .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
635 .into_shared()
636 })
637 }
638
639 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
640 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
641 array
642 .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
643 .into_shared()
644 })
645 }
646
647 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
648 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
649 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
650 })
651 }
652
653 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
654 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
655 array
656 .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
657 .into_shared()
658 })
659 }
660
661 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
662 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
663 array
664 .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
665 .into_shared()
666 })
667 }
668
669 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
670 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
671 array
672 .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
673 .into_shared()
674 })
675 }
676
677 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
678 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
679 array
680 .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
681 .into_shared()
682 })
683 }
684
685 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
686 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
687 array
688 .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
689 .into_shared()
690 })
691 }
692
693 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
694 cat_with_dtype!(tensors, dim, [F64, F32])
695 }
696
697 fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
698 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
699 NdArrayMathOps::clamp_min(array, min.elem())
700 })
701 }
702
703 fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
704 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
705 NdArrayMathOps::clamp_max(array, max.elem())
706 })
707 }
708
709 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
710 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
711 NdArrayMathOps::clamp(array, min.elem(), max.elem())
712 })
713 }
714
715 fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
716 execute_with_int_out_dtype!(out_dtype, I, {
717 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
718 array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
719 })
720 })
721 }
722
723 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
724 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
725 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
726 })
727 }
728
729 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
730 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
731 NdArrayOps::permute(array, axes)
732 })
733 }
734
735 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
736 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
737 NdArrayOps::flip(array, axes)
738 })
739 }
740
741 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
742 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
743 NdArrayMathOps::sign_op(array)
744 })
745 }
746
747 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
748 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
749 NdArrayOps::expand(array, shape)
750 })
751 }
752
753 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
754 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
755 cast_to_dtype(array, dtype.into())
756 })
757 }
758
759 fn float_grid_sample_2d(
760 tensor: FloatTensor<Self>,
761 grid: FloatTensor<Self>,
762 options: GridSampleOptions,
763 ) -> FloatTensor<Self> {
764 execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
765 tensor, grid, options
766 ))
767 }
768
769 fn float_unfold(
770 tensor: FloatTensor<Self>,
771 dim: usize,
772 size: usize,
773 step: usize,
774 ) -> FloatTensor<Self> {
775 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
776 NdArrayOps::unfold(array, dim, size, step)
777 })
778 }
779}