1use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7use burn_std::{BoolDType, IntDType};
8
9use super::{
11 NdArrayMathOps, NdArrayOps,
12 matmul::{cross, matmul},
13};
14use crate::{
15 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
16};
17use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice};
18use crate::{
19 SharedArray,
20 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
21};
22use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
23
24use crate::rand::get_seeded_rng;
26use burn_backend::{Distribution, FloatDType, Scalar};
27use burn_backend::{ElementConversion, Shape, TensorData, ops::FloatTensorOps};
28
29#[cfg(not(feature = "std"))]
30#[allow(unused_imports)]
31use num_traits::Float;
32
33use libm::erf;
34
/// Rounds `x` to the nearest integer, resolving exact half-way cases toward the
/// even neighbour (`2.5 -> 2.0`, `3.5 -> 4.0`, `-2.5 -> -2.0`).
///
/// With the `std` feature this simply forwards to `f64::round_ties_even`.
// `dead_code` is allowed because not every build configuration calls this helper
// (NOTE(review): confirm which configuration leaves it unused).
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, resolving exact half-way cases toward the
/// even neighbour (`2.5 -> 2.0`, `3.5 -> 4.0`, `-2.5 -> -2.0`).
///
/// Manual fallback used when the `std` feature is off (`floor`/`round` then come
/// from the `num_traits::Float` import at the top of the file). An exact `.5`
/// fractional part is handled by halving, rounding and doubling, which lands on
/// the even integer.
// `dead_code` is allowed because not every build configuration calls this helper.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fraction = x - x.floor();
    if fraction != 0.5 {
        // Not a tie (this also covers NaN/inf): ordinary rounding is already correct.
        return x.round();
    }
    (x * 0.5).round() * 2.0
}
50
51impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
52 for NdArray<E, I, Q>
53where
54 NdArrayTensor: From<SharedArray<E>>,
55 NdArrayTensor: From<SharedArray<I>>,
56{
57 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
58 NdArrayTensor::from_data(data)
59 }
60
61 fn float_random(
62 shape: Shape,
63 distribution: Distribution,
64 device: &NdArrayDevice,
65 dtype: FloatDType,
66 ) -> FloatTensor<Self> {
67 let mut seed = SEED.lock().unwrap();
68 let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
69 let tensor = execute_with_float_out_dtype!(
70 dtype,
71 E,
72 Self::float_from_data(
73 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
74 device,
75 )
76 );
77
78 *seed = Some(rng);
79 tensor
80 }
81
82 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
83 Ok(tensor.into_data())
84 }
85
86 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
87 NdArrayDevice::Cpu
88 }
89
90 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
91 tensor
92 }
93
94 fn float_empty(shape: Shape, device: &NdArrayDevice, dtype: FloatDType) -> FloatTensor<Self> {
95 Self::float_zeros(shape, device, dtype)
96 }
97
98 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
99 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
100 }
101
102 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
103 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
104 NdArrayMathOps::add_scalar(array, rhs.elem())
105 })
106 }
107
108 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
109 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
110 }
111
112 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
113 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
114 NdArrayMathOps::sub_scalar(array, rhs.elem())
115 })
116 }
117
118 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
119 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
120 }
121
122 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
123 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
124 NdArrayMathOps::mul_scalar(array, rhs.elem())
125 })
126 }
127
128 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
129 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
130 }
131
132 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
133 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
134 NdArrayMathOps::div_scalar(array, rhs.elem())
135 })
136 }
137
138 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
140 }
141
142 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
143 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
144 NdArrayMathOps::remainder_scalar(array, rhs.elem())
145 })
146 }
147
148 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
149 execute_with_float_dtype!((lhs, rhs), matmul)
150 }
151
152 fn float_cross(
153 lhs: FloatTensor<Self>,
154 rhs: FloatTensor<Self>,
155 dim: usize,
156 ) -> FloatTensor<Self> {
157 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
158 }
159
160 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
161 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
162 NdArrayMathOps::recip(array)
163 })
164 }
165
166 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
167 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
168 NdArrayOps::swap_dims(array, dim1, dim2)
169 })
170 }
171
172 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
173 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
174 NdArrayOps::reshape(array, shape)
175 })
176 }
177
178 fn float_gather(
179 dim: usize,
180 tensor: FloatTensor<Self>,
181 indices: NdArrayTensor,
182 ) -> FloatTensor<Self> {
183 execute_with_int_dtype!(
184 indices,
185 IntElem,
186 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
187 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
188 NdArrayOps::gather(dim, array, idx_array)
189 })
190 }
191 )
192 }
193
194 fn float_scatter_add(
195 dim: usize,
196 tensor: FloatTensor<Self>,
197 indices: NdArrayTensor,
198 value: FloatTensor<Self>,
199 ) -> FloatTensor<Self> {
200 execute_with_int_dtype!(
201 indices,
202 IntElem,
203 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
204 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
205 dim, tensor, idx_array, value
206 ))
207 }
208 )
209 }
210
211 fn float_select(
212 tensor: FloatTensor<Self>,
213 dim: usize,
214 indices: NdArrayTensor,
215 ) -> FloatTensor<Self> {
216 execute_with_int_dtype!(
217 indices,
218 IntElem,
219 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
220 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
221 NdArrayMathOps::select(array, dim, idx_array)
222 })
223 }
224 )
225 }
226
227 fn float_select_add(
228 tensor: FloatTensor<Self>,
229 dim: usize,
230 indices: NdArrayTensor,
231 value: FloatTensor<Self>,
232 ) -> FloatTensor<Self> {
233 execute_with_int_dtype!(
234 indices,
235 IntElem,
236 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
237 execute_with_float_dtype!((tensor, value), |tensor, value| {
238 NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
239 })
240 }
241 )
242 }
243
244 fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
245 slice!(tensor, slices)
246 }
247
248 fn float_slice_assign(
249 tensor: FloatTensor<Self>,
250 slices: &[burn_backend::Slice],
251 value: FloatTensor<Self>,
252 ) -> FloatTensor<Self> {
253 execute_with_float_dtype!((tensor, value), |tensor, value| {
254 NdArrayOps::slice_assign(tensor, slices, value)
255 })
256 }
257
258 fn float_mask_where(
259 tensor: FloatTensor<Self>,
260 mask: NdArrayTensor,
261 value: FloatTensor<Self>,
262 ) -> FloatTensor<Self> {
263 execute_with_float_dtype!((tensor, value), |tensor, value| {
264 NdArrayOps::mask_where(tensor, mask.bool(), value)
265 })
266 }
267
268 fn float_mask_fill(
269 tensor: FloatTensor<Self>,
270 mask: NdArrayTensor,
271 value: Scalar,
272 ) -> FloatTensor<Self> {
273 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
274 NdArrayOps::mask_fill(array, mask.bool(), value.elem())
275 })
276 }
277
278 fn float_equal(
279 lhs: FloatTensor<Self>,
280 rhs: FloatTensor<Self>,
281 _out_dtype: BoolDType,
282 ) -> NdArrayTensor {
283 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
284 }
285
286 fn float_equal_elem(
287 lhs: FloatTensor<Self>,
288 rhs: Scalar,
289 _out_dtype: BoolDType,
290 ) -> NdArrayTensor {
291 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
292 NdArrayMathOps::equal_elem(array, rhs.elem())
293 })
294 }
295
296 fn float_greater(
297 lhs: FloatTensor<Self>,
298 rhs: FloatTensor<Self>,
299 _out_dtype: BoolDType,
300 ) -> NdArrayTensor {
301 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
302 }
303
304 fn float_greater_elem(
305 lhs: FloatTensor<Self>,
306 rhs: Scalar,
307 _out_dtype: BoolDType,
308 ) -> NdArrayTensor {
309 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
310 NdArrayMathOps::greater_elem(array, rhs.elem())
311 })
312 }
313
314 fn float_greater_equal(
315 lhs: FloatTensor<Self>,
316 rhs: FloatTensor<Self>,
317 _out_dtype: BoolDType,
318 ) -> NdArrayTensor {
319 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
320 NdArrayMathOps::greater_equal(lhs, rhs)
321 })
322 }
323
324 fn float_greater_equal_elem(
325 lhs: FloatTensor<Self>,
326 rhs: Scalar,
327 _out_dtype: BoolDType,
328 ) -> NdArrayTensor {
329 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
330 NdArrayMathOps::greater_equal_elem(array, rhs.elem())
331 })
332 }
333
334 fn float_lower(
335 lhs: FloatTensor<Self>,
336 rhs: FloatTensor<Self>,
337 _out_dtype: BoolDType,
338 ) -> NdArrayTensor {
339 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
340 }
341
342 fn float_lower_elem(
343 lhs: FloatTensor<Self>,
344 rhs: Scalar,
345 _out_dtype: BoolDType,
346 ) -> NdArrayTensor {
347 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
348 NdArrayMathOps::lower_elem(array, rhs.elem())
349 })
350 }
351
352 fn float_lower_equal(
353 lhs: FloatTensor<Self>,
354 rhs: FloatTensor<Self>,
355 _out_dtype: BoolDType,
356 ) -> NdArrayTensor {
357 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
358 NdArrayMathOps::lower_equal(lhs, rhs)
359 })
360 }
361
362 fn float_lower_equal_elem(
363 lhs: FloatTensor<Self>,
364 rhs: Scalar,
365 _out_dtype: BoolDType,
366 ) -> NdArrayTensor {
367 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
368 NdArrayMathOps::lower_equal_elem(array, rhs.elem())
369 })
370 }
371
372 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
373 tensor
374 }
375
376 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
377 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
379 NdArrayMathOps::mean_view(array.view())
380 })
381 }
382
383 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
384 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
386 NdArrayMathOps::sum_view(array.view())
387 })
388 }
389
390 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
391 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
392 NdArrayMathOps::mean_dim(array, dim)
393 })
394 }
395
396 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
397 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
398 NdArrayMathOps::cumsum(array, dim)
399 })
400 }
401
402 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
403 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
404 NdArrayMathOps::cumprod(array, dim)
405 })
406 }
407
408 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
409 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
410 NdArrayMathOps::cummin(array, dim)
411 })
412 }
413
414 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
415 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
416 NdArrayMathOps::cummax(array, dim)
417 })
418 }
419
420 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
421 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
422 NdArrayMathOps::sum_dim(array, dim)
423 })
424 }
425
426 fn float_argmax(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
427 execute_with_int_out_dtype!(out_dtype, I, {
429 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
430 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
431 })
432 })
433 }
434
435 fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
436 execute_with_int_out_dtype!(out_dtype, I, {
438 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
439 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
440 })
441 })
442 }
443
444 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
445 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
446 array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
447 })
448 }
449
450 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
451 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
452 array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
453 })
454 }
455
456 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
457 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
459 NdArrayMathOps::prod_view(array.view())
460 })
461 }
462
463 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
464 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
465 NdArrayMathOps::prod_dim(array, dim)
466 })
467 }
468
469 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
470 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
472 NdArrayMathOps::max_view(array.view())
473 })
474 }
475
476 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
477 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
479 NdArrayMathOps::min_view(array.view())
480 })
481 }
482
483 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
484 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
485 array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
486 })
487 }
488
489 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
490 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
491 array
492 .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
493 .into_shared()
494 })
495 }
496
497 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
498 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
499 array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
500 })
501 }
502
503 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
504 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
505 NdArrayMathOps::abs(array)
506 })
507 }
508
509 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
510 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
511 array
512 .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
513 .into_shared()
514 })
515 }
516
517 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
518 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
519 array
520 .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
521 .into_shared()
522 })
523 }
524
525 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
526 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
527 array
528 .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
529 .into_shared()
530 })
531 }
532
533 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
534 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
535 array
536 .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
537 .into_shared()
538 })
539 }
540
541 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
542 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
543 array
544 .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
545 .into_shared()
546 })
547 }
548
549 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
550 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
551 array
552 .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
553 .into_shared()
554 })
555 }
556
557 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
558 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
559 array
560 .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
561 .into_shared()
562 })
563 }
564
565 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
566 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
567 array
568 .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
569 .into_shared()
570 })
571 }
572
573 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
574 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
575 array
576 .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
577 .into_shared()
578 })
579 }
580
581 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
582 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
583 array
584 .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
585 .into_shared()
586 })
587 }
588
589 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
590 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
591 array
592 .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
593 .into_shared()
594 })
595 }
596
597 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
598 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
599 array
600 .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
601 .into_shared()
602 })
603 }
604
605 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
606 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
607 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
608 })
609 }
610
611 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
612 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
613 array
614 .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
615 .into_shared()
616 })
617 }
618
619 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
620 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
621 array
622 .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
623 .into_shared()
624 })
625 }
626
627 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
628 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
629 array
630 .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
631 .into_shared()
632 })
633 }
634
635 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
636 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
637 array
638 .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
639 .into_shared()
640 })
641 }
642
643 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
644 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
645 array
646 .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
647 .into_shared()
648 })
649 }
650
651 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
652 cat_with_dtype!(tensors, dim, [F64, F32])
653 }
654
655 fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
656 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
657 NdArrayMathOps::clamp_min(array, min.elem())
658 })
659 }
660
661 fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
662 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
663 NdArrayMathOps::clamp_max(array, max.elem())
664 })
665 }
666
667 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
668 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
669 NdArrayMathOps::clamp(array, min.elem(), max.elem())
670 })
671 }
672
673 fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
674 execute_with_int_out_dtype!(out_dtype, I, {
675 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
676 array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
677 })
678 })
679 }
680
681 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
682 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
683 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
684 })
685 }
686
687 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
688 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
689 NdArrayOps::permute(array, axes)
690 })
691 }
692
693 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
694 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
695 NdArrayOps::flip(array, axes)
696 })
697 }
698
699 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
700 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
701 NdArrayMathOps::sign_op(array)
702 })
703 }
704
705 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
706 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
707 NdArrayOps::expand(array, shape)
708 })
709 }
710
711 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
712 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
713 cast_to_dtype(array, dtype.into())
714 })
715 }
716
717 fn float_grid_sample_2d(
718 tensor: FloatTensor<Self>,
719 grid: FloatTensor<Self>,
720 options: GridSampleOptions,
721 ) -> FloatTensor<Self> {
722 execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
723 tensor, grid, options
724 ))
725 }
726
727 fn float_unfold(
728 tensor: FloatTensor<Self>,
729 dim: usize,
730 size: usize,
731 step: usize,
732 ) -> FloatTensor<Self> {
733 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
734 NdArrayOps::unfold(array, dim, size, step)
735 })
736 }
737}