1use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7use burn_std::{BoolDType, IntDType};
8
9use super::{
11 NdArrayMathOps, NdArrayOps,
12 matmul::{cross, matmul},
13};
14use crate::{
15 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
16};
17use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
18use crate::{
19 SharedArray,
20 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
21};
22use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
23
24use crate::rand::get_seeded_rng;
26use burn_backend::{Distribution, FloatDType, Scalar};
27use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
28
29#[cfg(not(feature = "std"))]
30#[allow(unused_imports)]
31use num_traits::Float;
32
33use libm::erf;
34
35#[cfg(feature = "std")]
36#[allow(dead_code)]
/// Rounds `x` to the nearest integer, resolving exact halves towards the even neighbour.
///
/// Thin wrapper over the standard library's `f64::round_ties_even`.
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}
40
/// Fallback for builds without the `std` feature: rounds `x` to the nearest integer, sending
/// exact halves to the even neighbour.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fractional = x - x.floor();
    if fractional != 0.5 {
        // Not an exact tie (this also covers NaN and infinities): plain rounding is enough.
        return x.round();
    }
    // Exact tie: rounding half of the value and doubling it always lands on an even integer.
    (x * 0.5).round() * 2.0
}
50
51impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
52 for NdArray<E, I, Q>
53where
54 NdArrayTensor: From<SharedArray<E>>,
55 NdArrayTensor: From<SharedArray<I>>,
56{
57 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
58 NdArrayTensor::from_data(data)
59 }
60
61 fn float_random(
62 shape: Shape,
63 distribution: Distribution,
64 device: &NdArrayDevice,
65 dtype: FloatDType,
66 ) -> FloatTensor<Self> {
67 let mut seed = SEED.lock().unwrap();
68 let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
69 let tensor = execute_with_float_out_dtype!(
70 dtype,
71 E,
72 Self::float_from_data(
73 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
74 device,
75 )
76 );
77
78 *seed = Some(rng);
79 tensor
80 }
81
82 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
83 Ok(tensor.into_data())
84 }
85
86 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
87 NdArrayDevice::Cpu
88 }
89
90 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
91 tensor
92 }
93
94 fn float_empty(
95 shape: Shape,
96 device: &<NdArray<E> as Backend>::Device,
97 dtype: FloatDType,
98 ) -> FloatTensor<Self> {
99 Self::float_zeros(shape, device, dtype)
100 }
101
102 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
103 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
104 }
105
106 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
107 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
108 NdArrayMathOps::add_scalar(array, rhs.elem())
109 })
110 }
111
112 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
113 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
114 }
115
116 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
117 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
118 NdArrayMathOps::sub_scalar(array, rhs.elem())
119 })
120 }
121
122 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
123 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
124 }
125
126 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
127 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
128 NdArrayMathOps::mul_scalar(array, rhs.elem())
129 })
130 }
131
132 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
133 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
134 }
135
136 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
137 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
138 NdArrayMathOps::div_scalar(array, rhs.elem())
139 })
140 }
141
142 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
143 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
144 }
145
146 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
147 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
148 NdArrayMathOps::remainder_scalar(array, rhs.elem())
149 })
150 }
151
152 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
153 execute_with_float_dtype!((lhs, rhs), matmul)
154 }
155
156 fn float_cross(
157 lhs: FloatTensor<Self>,
158 rhs: FloatTensor<Self>,
159 dim: usize,
160 ) -> FloatTensor<Self> {
161 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
162 }
163
164 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
165 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
166 NdArrayMathOps::recip(array)
167 })
168 }
169
170 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
171 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
172 NdArrayOps::swap_dims(array, dim1, dim2)
173 })
174 }
175
176 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
177 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
178 NdArrayOps::reshape(array, shape)
179 })
180 }
181
182 fn float_gather(
183 dim: usize,
184 tensor: FloatTensor<Self>,
185 indices: NdArrayTensor,
186 ) -> FloatTensor<Self> {
187 execute_with_int_dtype!(
188 indices,
189 IntElem,
190 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
191 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
192 NdArrayOps::gather(dim, array, idx_array)
193 })
194 }
195 )
196 }
197
198 fn float_scatter_add(
199 dim: usize,
200 tensor: FloatTensor<Self>,
201 indices: NdArrayTensor,
202 value: FloatTensor<Self>,
203 ) -> FloatTensor<Self> {
204 execute_with_int_dtype!(
205 indices,
206 IntElem,
207 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
208 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
209 dim, tensor, idx_array, value
210 ))
211 }
212 )
213 }
214
215 fn float_select(
216 tensor: FloatTensor<Self>,
217 dim: usize,
218 indices: NdArrayTensor,
219 ) -> FloatTensor<Self> {
220 execute_with_int_dtype!(
221 indices,
222 IntElem,
223 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
224 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
225 NdArrayMathOps::select(array, dim, idx_array)
226 })
227 }
228 )
229 }
230
231 fn float_select_add(
232 tensor: FloatTensor<Self>,
233 dim: usize,
234 indices: NdArrayTensor,
235 value: FloatTensor<Self>,
236 ) -> FloatTensor<Self> {
237 execute_with_int_dtype!(
238 indices,
239 IntElem,
240 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
241 execute_with_float_dtype!((tensor, value), |tensor, value| {
242 NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
243 })
244 }
245 )
246 }
247
248 fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
249 slice!(tensor, slices)
250 }
251
252 fn float_slice_assign(
253 tensor: FloatTensor<Self>,
254 slices: &[burn_backend::Slice],
255 value: FloatTensor<Self>,
256 ) -> FloatTensor<Self> {
257 execute_with_float_dtype!((tensor, value), |tensor, value| {
258 NdArrayOps::slice_assign(tensor, slices, value)
259 })
260 }
261
262 fn float_mask_where(
263 tensor: FloatTensor<Self>,
264 mask: NdArrayTensor,
265 value: FloatTensor<Self>,
266 ) -> FloatTensor<Self> {
267 execute_with_float_dtype!((tensor, value), |tensor, value| {
268 NdArrayOps::mask_where(tensor, mask.bool(), value)
269 })
270 }
271
272 fn float_mask_fill(
273 tensor: FloatTensor<Self>,
274 mask: NdArrayTensor,
275 value: Scalar,
276 ) -> FloatTensor<Self> {
277 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
278 NdArrayOps::mask_fill(array, mask.bool(), value.elem())
279 })
280 }
281
282 fn float_equal(
283 lhs: FloatTensor<Self>,
284 rhs: FloatTensor<Self>,
285 _out_dtype: BoolDType,
286 ) -> NdArrayTensor {
287 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
288 }
289
290 fn float_equal_elem(
291 lhs: FloatTensor<Self>,
292 rhs: Scalar,
293 _out_dtype: BoolDType,
294 ) -> NdArrayTensor {
295 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
296 NdArrayMathOps::equal_elem(array, rhs.elem())
297 })
298 }
299
300 fn float_greater(
301 lhs: FloatTensor<Self>,
302 rhs: FloatTensor<Self>,
303 _out_dtype: BoolDType,
304 ) -> NdArrayTensor {
305 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
306 }
307
308 fn float_greater_elem(
309 lhs: FloatTensor<Self>,
310 rhs: Scalar,
311 _out_dtype: BoolDType,
312 ) -> NdArrayTensor {
313 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
314 NdArrayMathOps::greater_elem(array, rhs.elem())
315 })
316 }
317
318 fn float_greater_equal(
319 lhs: FloatTensor<Self>,
320 rhs: FloatTensor<Self>,
321 _out_dtype: BoolDType,
322 ) -> NdArrayTensor {
323 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
324 NdArrayMathOps::greater_equal(lhs, rhs)
325 })
326 }
327
328 fn float_greater_equal_elem(
329 lhs: FloatTensor<Self>,
330 rhs: Scalar,
331 _out_dtype: BoolDType,
332 ) -> NdArrayTensor {
333 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
334 NdArrayMathOps::greater_equal_elem(array, rhs.elem())
335 })
336 }
337
338 fn float_lower(
339 lhs: FloatTensor<Self>,
340 rhs: FloatTensor<Self>,
341 _out_dtype: BoolDType,
342 ) -> NdArrayTensor {
343 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
344 }
345
346 fn float_lower_elem(
347 lhs: FloatTensor<Self>,
348 rhs: Scalar,
349 _out_dtype: BoolDType,
350 ) -> NdArrayTensor {
351 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
352 NdArrayMathOps::lower_elem(array, rhs.elem())
353 })
354 }
355
356 fn float_lower_equal(
357 lhs: FloatTensor<Self>,
358 rhs: FloatTensor<Self>,
359 _out_dtype: BoolDType,
360 ) -> NdArrayTensor {
361 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
362 NdArrayMathOps::lower_equal(lhs, rhs)
363 })
364 }
365
366 fn float_lower_equal_elem(
367 lhs: FloatTensor<Self>,
368 rhs: Scalar,
369 _out_dtype: BoolDType,
370 ) -> NdArrayTensor {
371 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
372 NdArrayMathOps::lower_equal_elem(array, rhs.elem())
373 })
374 }
375
376 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
377 tensor
378 }
379
380 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
381 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
383 NdArrayMathOps::mean_view(array.view())
384 })
385 }
386
387 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
388 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
390 NdArrayMathOps::sum_view(array.view())
391 })
392 }
393
394 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
395 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
396 NdArrayMathOps::mean_dim(array, dim)
397 })
398 }
399
400 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
401 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
402 NdArrayMathOps::cumsum(array, dim)
403 })
404 }
405
406 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
407 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
408 NdArrayMathOps::cumprod(array, dim)
409 })
410 }
411
412 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
413 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
414 NdArrayMathOps::cummin(array, dim)
415 })
416 }
417
418 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
419 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
420 NdArrayMathOps::cummax(array, dim)
421 })
422 }
423
424 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
425 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
426 NdArrayMathOps::sum_dim(array, dim)
427 })
428 }
429
430 fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
431 execute_with_int_out_dtype!(out_dtype, I, {
433 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
434 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
435 })
436 })
437 }
438
439 fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
440 execute_with_int_out_dtype!(out_dtype, I, {
442 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
443 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
444 })
445 })
446 }
447
448 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
449 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
450 array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
451 })
452 }
453
454 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
455 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
456 array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
457 })
458 }
459
460 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
461 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
463 NdArrayMathOps::prod_view(array.view())
464 })
465 }
466
467 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
468 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
469 NdArrayMathOps::prod_dim(array, dim)
470 })
471 }
472
473 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
474 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
476 NdArrayMathOps::max_view(array.view())
477 })
478 }
479
480 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
481 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
483 NdArrayMathOps::min_view(array.view())
484 })
485 }
486
487 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
488 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
489 array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
490 })
491 }
492
493 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
494 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
495 array
496 .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
497 .into_shared()
498 })
499 }
500
501 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
502 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
503 array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
504 })
505 }
506
507 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
508 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
509 NdArrayMathOps::abs(array)
510 })
511 }
512
513 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
514 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
515 array
516 .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
517 .into_shared()
518 })
519 }
520
521 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
522 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
523 array
524 .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
525 .into_shared()
526 })
527 }
528
529 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
530 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
531 array
532 .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
533 .into_shared()
534 })
535 }
536
537 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
538 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
539 array
540 .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
541 .into_shared()
542 })
543 }
544
545 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
546 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
547 array
548 .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
549 .into_shared()
550 })
551 }
552
553 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
554 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
555 array
556 .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
557 .into_shared()
558 })
559 }
560
561 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
562 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
563 array
564 .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
565 .into_shared()
566 })
567 }
568
569 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
570 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
571 array
572 .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
573 .into_shared()
574 })
575 }
576
577 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
578 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
579 array
580 .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
581 .into_shared()
582 })
583 }
584
585 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
586 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
587 array
588 .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
589 .into_shared()
590 })
591 }
592
593 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
594 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
595 array
596 .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
597 .into_shared()
598 })
599 }
600
601 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
602 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
603 array
604 .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
605 .into_shared()
606 })
607 }
608
609 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
610 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
611 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
612 })
613 }
614
615 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
616 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
617 array
618 .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
619 .into_shared()
620 })
621 }
622
623 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
624 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
625 array
626 .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
627 .into_shared()
628 })
629 }
630
631 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
632 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
633 array
634 .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
635 .into_shared()
636 })
637 }
638
639 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
640 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
641 array
642 .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
643 .into_shared()
644 })
645 }
646
647 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
648 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
649 array
650 .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
651 .into_shared()
652 })
653 }
654
655 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
656 cat_with_dtype!(tensors, dim, [F64, F32])
657 }
658
659 fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
660 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
661 NdArrayMathOps::clamp_min(array, min.elem())
662 })
663 }
664
665 fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
666 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
667 NdArrayMathOps::clamp_max(array, max.elem())
668 })
669 }
670
671 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
672 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
673 NdArrayMathOps::clamp(array, min.elem(), max.elem())
674 })
675 }
676
677 fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
678 execute_with_int_out_dtype!(out_dtype, I, {
679 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
680 array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
681 })
682 })
683 }
684
685 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
686 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
687 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
688 })
689 }
690
691 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
692 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
693 NdArrayOps::permute(array, axes)
694 })
695 }
696
697 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
698 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
699 NdArrayOps::flip(array, axes)
700 })
701 }
702
703 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
704 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
705 NdArrayMathOps::sign_op(array)
706 })
707 }
708
709 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
710 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
711 NdArrayOps::expand(array, shape)
712 })
713 }
714
715 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
716 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
717 cast_to_dtype(array, dtype.into())
718 })
719 }
720
721 fn float_grid_sample_2d(
722 tensor: FloatTensor<Self>,
723 grid: FloatTensor<Self>,
724 options: GridSampleOptions,
725 ) -> FloatTensor<Self> {
726 execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
727 tensor, grid, options
728 ))
729 }
730
731 fn float_unfold(
732 tensor: FloatTensor<Self>,
733 dim: usize,
734 size: usize,
735 step: usize,
736 ) -> FloatTensor<Self> {
737 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
738 NdArrayOps::unfold(array, dim, size, step)
739 })
740 }
741}