1use alloc::vec::Vec;
3use burn_backend::backend::ExecutionError;
4use burn_backend::ops::GridSampleOptions;
5use burn_backend::tensor::FloatTensor;
6use burn_backend::{TensorMetadata, element::cast::ToElement};
7
8use super::{
10 NdArrayMathOps, NdArrayOps,
11 matmul::{cross, matmul},
12};
13use crate::{
14 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
15};
16use crate::{NdArrayDevice, SEED, slice};
17use crate::{
18 SharedArray,
19 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
20};
21use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
22
23use crate::rand::get_seeded_rng;
25use burn_backend::{Distribution, FloatDType, Scalar};
26use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
27
28#[cfg(not(feature = "std"))]
29#[allow(unused_imports)]
30use num_traits::Float;
31
32use libm::erf;
33
/// Rounds `x` to the nearest integer, resolving exact `.5` ties toward the even
/// neighbour (round-half-to-even). With `std` enabled this simply forwards to the
/// standard library implementation.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, resolving exact `.5` ties toward the even
/// neighbour (round-half-to-even).
///
/// Fallback used when the `std` feature is off; `floor` and `round` are expected to
/// resolve through `num_traits::Float` in that configuration.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let is_exact_half = x - x.floor() == 0.5;
    if !is_exact_half {
        // Not a tie: ordinary rounding already selects the unique nearest integer.
        return x.round();
    }
    // Tie: x / 2 has a fractional part of .25 or .75, so rounding it is unambiguous,
    // and doubling that integer gives the even integer adjacent to x.
    (x * 0.5).round() * 2.0
}
49
50impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
51 for NdArray<E, I, Q>
52where
53 NdArrayTensor: From<SharedArray<E>>,
54 NdArrayTensor: From<SharedArray<I>>,
55{
56 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
57 NdArrayTensor::from_data(data)
58 }
59
60 fn float_random(
61 shape: Shape,
62 distribution: Distribution,
63 device: &NdArrayDevice,
64 ) -> FloatTensor<Self> {
65 let mut seed = SEED.lock().unwrap();
66 let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
67 let tensor = Self::float_from_data(
68 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
69 device,
70 );
71 *seed = Some(rng);
72 tensor
73 }
74
75 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
76 Ok(tensor.into_data())
77 }
78
79 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
80 NdArrayDevice::Cpu
81 }
82
83 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
84 tensor
85 }
86
87 fn float_empty(
88 shape: Shape,
89 device: &<NdArray<E> as Backend>::Device,
90 dtype: FloatDType,
91 ) -> FloatTensor<Self> {
92 Self::float_zeros(shape, device, dtype)
93 }
94
95 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
96 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
97 }
98
99 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
100 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
101 NdArrayMathOps::add_scalar(array, rhs.elem())
102 })
103 }
104
105 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
106 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
107 }
108
109 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
110 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
111 NdArrayMathOps::sub_scalar(array, rhs.elem())
112 })
113 }
114
115 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
116 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
117 }
118
119 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
120 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
121 NdArrayMathOps::mul_scalar(array, rhs.elem())
122 })
123 }
124
125 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
126 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
127 }
128
129 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
130 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
131 NdArrayMathOps::div_scalar(array, rhs.elem())
132 })
133 }
134
135 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
136 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
137 }
138
139 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
140 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
141 NdArrayMathOps::remainder_scalar(array, rhs.elem())
142 })
143 }
144
145 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
146 execute_with_float_dtype!((lhs, rhs), matmul)
147 }
148
149 fn float_cross(
150 lhs: FloatTensor<Self>,
151 rhs: FloatTensor<Self>,
152 dim: usize,
153 ) -> FloatTensor<Self> {
154 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
155 }
156
157 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
158 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
159 NdArrayMathOps::recip(array)
160 })
161 }
162
163 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
164 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
165 NdArrayOps::swap_dims(array, dim1, dim2)
166 })
167 }
168
169 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
170 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
171 NdArrayOps::reshape(array, shape)
172 })
173 }
174
175 fn float_gather(
176 dim: usize,
177 tensor: FloatTensor<Self>,
178 indices: NdArrayTensor,
179 ) -> FloatTensor<Self> {
180 execute_with_int_dtype!(
181 indices,
182 IntElem,
183 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
184 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
185 NdArrayOps::gather(dim, array, idx_array)
186 })
187 }
188 )
189 }
190
191 fn float_scatter_add(
192 dim: usize,
193 tensor: FloatTensor<Self>,
194 indices: NdArrayTensor,
195 value: FloatTensor<Self>,
196 ) -> FloatTensor<Self> {
197 execute_with_int_dtype!(
198 indices,
199 IntElem,
200 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
201 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
202 dim, tensor, idx_array, value
203 ))
204 }
205 )
206 }
207
208 fn float_select(
209 tensor: FloatTensor<Self>,
210 dim: usize,
211 indices: NdArrayTensor,
212 ) -> FloatTensor<Self> {
213 execute_with_int_dtype!(
214 indices,
215 IntElem,
216 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
217 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
218 NdArrayMathOps::select(array, dim, idx_array)
219 })
220 }
221 )
222 }
223
224 fn float_select_add(
225 tensor: FloatTensor<Self>,
226 dim: usize,
227 indices: NdArrayTensor,
228 value: FloatTensor<Self>,
229 ) -> FloatTensor<Self> {
230 execute_with_int_dtype!(
231 indices,
232 IntElem,
233 |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
234 execute_with_float_dtype!((tensor, value), |tensor, value| {
235 NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
236 })
237 }
238 )
239 }
240
241 fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
242 slice!(tensor, slices)
243 }
244
245 fn float_slice_assign(
246 tensor: FloatTensor<Self>,
247 slices: &[burn_backend::Slice],
248 value: FloatTensor<Self>,
249 ) -> FloatTensor<Self> {
250 execute_with_float_dtype!((tensor, value), |tensor, value| {
251 NdArrayOps::slice_assign(tensor, slices, value)
252 })
253 }
254
255 fn float_mask_where(
256 tensor: FloatTensor<Self>,
257 mask: NdArrayTensor,
258 value: FloatTensor<Self>,
259 ) -> FloatTensor<Self> {
260 execute_with_float_dtype!((tensor, value), |tensor, value| {
261 NdArrayOps::mask_where(tensor, mask.bool(), value)
262 })
263 }
264
265 fn float_mask_fill(
266 tensor: FloatTensor<Self>,
267 mask: NdArrayTensor,
268 value: Scalar,
269 ) -> FloatTensor<Self> {
270 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
271 NdArrayOps::mask_fill(array, mask.bool(), value.elem())
272 })
273 }
274
275 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
276 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
277 }
278
279 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
280 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
281 NdArrayMathOps::equal_elem(array, rhs.elem())
282 })
283 }
284
285 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
286 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
287 }
288
289 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
290 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
291 NdArrayMathOps::greater_elem(array, rhs.elem())
292 })
293 }
294
295 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
296 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
297 NdArrayMathOps::greater_equal(lhs, rhs)
298 })
299 }
300
301 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
302 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
303 NdArrayMathOps::greater_equal_elem(array, rhs.elem())
304 })
305 }
306
307 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
308 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
309 }
310
311 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
312 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
313 NdArrayMathOps::lower_elem(array, rhs.elem())
314 })
315 }
316
317 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
318 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
319 NdArrayMathOps::lower_equal(lhs, rhs)
320 })
321 }
322
323 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {
324 execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
325 NdArrayMathOps::lower_equal_elem(array, rhs.elem())
326 })
327 }
328
329 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
330 tensor
331 }
332
333 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
334 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
336 NdArrayMathOps::mean_view(array.view())
337 })
338 }
339
340 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
341 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
343 NdArrayMathOps::sum_view(array.view())
344 })
345 }
346
347 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
348 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
349 NdArrayMathOps::mean_dim(array, dim)
350 })
351 }
352
353 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
354 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
355 NdArrayMathOps::cumsum(array, dim)
356 })
357 }
358
359 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
360 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
361 NdArrayMathOps::cumprod(array, dim)
362 })
363 }
364
365 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
366 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
367 NdArrayMathOps::cummin(array, dim)
368 })
369 }
370
371 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
372 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
373 NdArrayMathOps::cummax(array, dim)
374 })
375 }
376
377 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
378 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
379 NdArrayMathOps::sum_dim(array, dim)
380 })
381 }
382
383 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
384 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
386 NdArrayMathOps::argmax_view::<I>(array.view(), dim)
387 })
388 }
389
390 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
391 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
393 NdArrayMathOps::argmin_view::<I>(array.view(), dim)
394 })
395 }
396
397 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
398 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
399 array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
400 })
401 }
402
403 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
404 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
405 array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
406 })
407 }
408
409 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
410 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
412 NdArrayMathOps::prod_view(array.view())
413 })
414 }
415
416 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
417 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
418 NdArrayMathOps::prod_dim(array, dim)
419 })
420 }
421
422 fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
423 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
425 NdArrayMathOps::max_view(array.view())
426 })
427 }
428
429 fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
430 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
432 NdArrayMathOps::min_view(array.view())
433 })
434 }
435
436 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
437 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
438 array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
439 })
440 }
441
442 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
443 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
444 array
445 .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
446 .into_shared()
447 })
448 }
449
450 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
451 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
452 array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
453 })
454 }
455
456 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
457 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
458 NdArrayMathOps::abs(array)
459 })
460 }
461
462 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
463 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
464 array
465 .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
466 .into_shared()
467 })
468 }
469
470 fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
471 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
472 array
473 .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
474 .into_shared()
475 })
476 }
477
478 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
479 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
480 array
481 .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
482 .into_shared()
483 })
484 }
485
486 fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
487 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
488 array
489 .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
490 .into_shared()
491 })
492 }
493
494 fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
495 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
496 array
497 .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
498 .into_shared()
499 })
500 }
501
502 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
503 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
504 array
505 .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
506 .into_shared()
507 })
508 }
509
510 fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
511 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
512 array
513 .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
514 .into_shared()
515 })
516 }
517
518 fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
519 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
520 array
521 .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
522 .into_shared()
523 })
524 }
525
526 fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
527 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
528 array
529 .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
530 .into_shared()
531 })
532 }
533
534 fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
535 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
536 array
537 .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
538 .into_shared()
539 })
540 }
541
542 fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
543 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
544 array
545 .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
546 .into_shared()
547 })
548 }
549
550 fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
551 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
552 array
553 .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
554 .into_shared()
555 })
556 }
557
558 fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
559 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
560 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
561 })
562 }
563
564 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
565 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
566 array
567 .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
568 .into_shared()
569 })
570 }
571
572 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
573 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
574 array
575 .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
576 .into_shared()
577 })
578 }
579
580 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
581 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
582 array
583 .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
584 .into_shared()
585 })
586 }
587
588 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
589 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
590 array
591 .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
592 .into_shared()
593 })
594 }
595
596 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
597 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
598 array
599 .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
600 .into_shared()
601 })
602 }
603
604 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
605 cat_with_dtype!(tensors, dim, [F64, F32])
606 }
607
608 fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
609 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
610 NdArrayMathOps::clamp_min(array, min.elem())
611 })
612 }
613
614 fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
615 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
616 NdArrayMathOps::clamp_max(array, max.elem())
617 })
618 }
619
620 fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
621 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
622 NdArrayMathOps::clamp(array, min.elem(), max.elem())
623 })
624 }
625
626 fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
627 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
628 array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
629 })
630 }
631
632 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
633 execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
634 NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
635 })
636 }
637
638 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
639 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
640 NdArrayOps::permute(array, axes)
641 })
642 }
643
644 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
645 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
646 NdArrayOps::flip(array, axes)
647 })
648 }
649
650 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
651 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
652 NdArrayMathOps::sign_op(array)
653 })
654 }
655
656 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
657 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
658 NdArrayOps::expand(array, shape)
659 })
660 }
661
662 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
663 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
664 cast_to_dtype(array, dtype.into())
665 })
666 }
667
668 fn float_grid_sample_2d(
669 tensor: FloatTensor<Self>,
670 grid: FloatTensor<Self>,
671 options: GridSampleOptions,
672 ) -> FloatTensor<Self> {
673 execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
674 tensor, grid, options
675 ))
676 }
677
678 fn float_unfold(
679 tensor: FloatTensor<Self>,
680 dim: usize,
681 size: usize,
682 step: usize,
683 ) -> FloatTensor<Self> {
684 execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
685 NdArrayOps::unfold(array, dim, size, step)
686 })
687 }
688}