1use alloc::vec::Vec;
3use burn_tensor::ops::FloatTensor;
4use burn_tensor::ops::InterpolateMode;
5use burn_tensor::{TensorMetadata, cast::ToElement};
6
7use super::{
9 NdArrayMathOps, NdArrayOps,
10 matmul::{cross, matmul},
11};
12use crate::{
13 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
14};
15use crate::{NdArrayDevice, SEED};
16use crate::{
17 SharedArray,
18 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
19};
20use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
21
22use crate::rand::get_seeded_rng;
24use burn_tensor::{Distribution, FloatDType};
25use burn_tensor::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
26
27#[cfg(not(feature = "std"))]
28#[allow(unused_imports)]
29use num_traits::Float;
30
31use libm::erf;
32
/// Rounds `x` to the nearest integer, resolving exact `.5` ties toward the even neighbour.
///
/// With the `std` feature this simply defers to the standard library's
/// `f64::round_ties_even`.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, resolving exact `.5` ties toward the even neighbour.
///
/// `no_std` fallback: plain `round` breaks ties away from zero, so an exact tie is
/// handled separately by halving, rounding and doubling, which lands on the even
/// neighbour (e.g. `2.5 -> 2`, `3.5 -> 4`, `-2.5 -> -2`). Non-tie values, NaN and
/// infinities fall through to `round`.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fraction = x - x.floor();
    if fraction != 0.5 {
        return x.round();
    }
    (x * 0.5).round() * 2.0
}
48
49impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
50 for NdArray<E, I, Q>
51where
52 NdArrayTensor: From<SharedArray<E>>,
53 NdArrayTensor: From<SharedArray<I>>,
54{
55 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
56 NdArrayTensor::from_data(data)
57 }
58
59 fn float_random(
60 shape: Shape,
61 distribution: Distribution,
62 device: &NdArrayDevice,
63 ) -> FloatTensor<Self> {
64 let mut seed = SEED.lock().unwrap();
65 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
66 rng_seeded.clone()
67 } else {
68 get_seeded_rng()
69 };
70 let tensor = Self::float_from_data(
71 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
72 device,
73 );
74 *seed = Some(rng);
75 tensor
76 }
77
78 async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
79 tensor.into_data()
80 }
81
82 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
83 NdArrayDevice::Cpu
84 }
85
86 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
87 tensor
88 }
89
90 fn float_empty(
91 shape: Shape,
92 device: &<NdArray<E> as Backend>::Device,
93 dtype: FloatDType,
94 ) -> FloatTensor<Self> {
95 Self::float_zeros(shape, device, dtype)
96 }
97
98 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
99 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
100 }
101
102 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
103 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
104 }
105
106 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
107 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
108 }
109
110 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
111 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
112 }
113
114 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
115 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
116 }
117
118 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
119 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
120 }
121
122 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
123 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
124 }
125
126 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
127 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
128 }
129
130 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
131 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
132 }
133
134 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
135 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
136 }
137
138 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139 execute_with_float_dtype!((lhs, rhs), matmul)
140 }
141
142 fn float_cross(
143 lhs: FloatTensor<Self>,
144 rhs: FloatTensor<Self>,
145 dim: usize,
146 ) -> FloatTensor<Self> {
147 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
148 }
149
150 fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
151 Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
152 }
153
154 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
155 execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
156 }
157
158 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
159 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
160 }
161
162 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
163 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
164 }
165
166 fn float_gather(
167 dim: usize,
168 tensor: FloatTensor<Self>,
169 indices: NdArrayTensor,
170 ) -> FloatTensor<Self> {
171 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
172 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
173 dim, tensor, indices
174 ))
175 })
176 }
177
178 fn float_scatter(
179 dim: usize,
180 tensor: FloatTensor<Self>,
181 indices: NdArrayTensor,
182 value: FloatTensor<Self>,
183 ) -> FloatTensor<Self> {
184 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
185 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
186 dim, tensor, indices, value
187 ))
188 })
189 }
190
191 fn float_select(
192 tensor: FloatTensor<Self>,
193 dim: usize,
194 indices: NdArrayTensor,
195 ) -> FloatTensor<Self> {
196 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
197 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
198 tensor, dim, indices
199 ))
200 })
201 }
202
203 fn float_select_assign(
204 tensor: FloatTensor<Self>,
205 dim: usize,
206 indices: NdArrayTensor,
207 value: FloatTensor<Self>,
208 ) -> FloatTensor<Self> {
209 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
210 execute_with_float_dtype!((tensor, value), |tensor, value| {
211 NdArrayMathOps::select_assign(tensor, dim, indices, value)
212 })
213 })
214 }
215
216 fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_tensor::Slice]) -> FloatTensor<Self> {
217 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, slices))
218 }
219
220 fn float_slice_assign(
221 tensor: FloatTensor<Self>,
222 slices: &[burn_tensor::Slice],
223 value: FloatTensor<Self>,
224 ) -> FloatTensor<Self> {
225 execute_with_float_dtype!((tensor, value), |tensor, value| {
226 NdArrayOps::slice_assign(tensor, slices, value)
227 })
228 }
229
230 fn float_mask_where(
231 tensor: FloatTensor<Self>,
232 mask: NdArrayTensor,
233 value: FloatTensor<Self>,
234 ) -> FloatTensor<Self> {
235 execute_with_float_dtype!((tensor, value), |tensor, value| {
236 NdArrayMathOps::mask_where(tensor, mask.bool(), value)
237 })
238 }
239
240 fn float_mask_fill(
241 tensor: FloatTensor<Self>,
242 mask: NdArrayTensor,
243 value: E,
244 ) -> FloatTensor<Self> {
245 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
246 tensor,
247 mask.bool(),
248 value.elem()
249 ))
250 }
251
252 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
253 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
254 }
255
256 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
257 execute_with_float_dtype!(lhs, |tensor| {
258 NdArrayMathOps::equal_elem(tensor, rhs.elem())
259 })
260 }
261
262 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
263 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
264 }
265
266 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
267 execute_with_float_dtype!(lhs, |tensor| {
268 NdArrayMathOps::greater_elem(tensor, rhs.elem())
269 })
270 }
271
272 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
273 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
274 NdArrayMathOps::greater_equal(lhs, rhs)
275 })
276 }
277
278 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
279 execute_with_float_dtype!(lhs, |tensor| {
280 NdArrayMathOps::greater_equal_elem(tensor, rhs.elem())
281 })
282 }
283
284 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
285 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
286 }
287
288 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
289 execute_with_float_dtype!(lhs, |tensor| {
290 NdArrayMathOps::lower_elem(tensor, rhs.elem())
291 })
292 }
293
294 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
295 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
296 NdArrayMathOps::lower_equal(lhs, rhs)
297 })
298 }
299
300 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
301 execute_with_float_dtype!(lhs, |tensor| {
302 NdArrayMathOps::lower_equal_elem(tensor, rhs.elem())
303 })
304 }
305
306 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
307 tensor
308 }
309
310 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
311 execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
312 }
313
314 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
315 execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
316 }
317
318 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
319 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
320 }
321
322 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
323 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
324 }
325
326 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
327 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumprod(tensor, dim))
328 }
329
330 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
331 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
332 }
333
334 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
335 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummax(tensor, dim))
336 }
337
338 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
339 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
340 }
341
342 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
343 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmax::<I>(tensor, dim))
344 }
345
346 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
347 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmin::<I>(tensor, dim))
348 }
349
350 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
351 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
352 tensor.mapv_into(|a| a.exp_elem()).into_shared()
353 })
354 }
355
356 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
357 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
358 tensor.mapv_into(|a| a.log_elem()).into_shared()
359 })
360 }
361
362 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
363 execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
364 }
365
366 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
367 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
368 }
369
370 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
371 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
372 tensor.mapv_into(|a| a.log1p_elem()).into_shared()
373 })
374 }
375
376 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
377 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
378 tensor.mapv_into(|a| a.powf_elem(value)).into_shared()
379 })
380 }
381
382 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
383 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
384 tensor.mapv_into(|a| a.sqrt_elem()).into_shared()
385 })
386 }
387
388 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
389 execute_with_float_dtype!(tensor, E, NdArrayMathOps::abs)
390 }
391
392 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
393 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
394 tensor
395 .mapv_into(|a| (a.to_f64()).cos().elem())
396 .into_shared()
397 })
398 }
399
400 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
401 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
402 tensor
403 .mapv_into(|a| (a.to_f64()).sin().elem())
404 .into_shared()
405 })
406 }
407
408 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
409 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
410 tensor
411 .mapv_into(|a| (a.to_f64()).tanh().elem())
412 .into_shared()
413 })
414 }
415
416 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
417 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
418 tensor
419 .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
420 .into_shared()
421 })
422 }
423
424 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
425 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
426 tensor
427 .mapv_into(|a| (a.to_f64()).floor().elem())
428 .into_shared()
429 })
430 }
431
432 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
433 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
434 tensor
435 .mapv_into(|a| (a.to_f64()).ceil().elem())
436 .into_shared()
437 })
438 }
439
440 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
441 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
442 tensor
443 .mapv_into(|a| (a.to_f64()).trunc().elem())
444 .into_shared()
445 })
446 }
447
448 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
449 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
450 tensor.mapv_into(|a| erf(a.to_f64()).elem()).into_shared()
451 })
452 }
453
454 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
455 cat_with_dtype!(tensors, dim, [F64, F32])
456 }
457
458 fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
459 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
460 tensor,
461 min.elem()
462 ))
463 }
464
465 fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
466 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
467 tensor,
468 max.elem()
469 ))
470 }
471
472 fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
473 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
474 tensor,
475 min.elem(),
476 max.elem()
477 ))
478 }
479
480 fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
481 execute_with_float_dtype!(tensor, |tensor: SharedArray<E>| {
482 tensor.mapv(|a| a.elem::<I>()).into_shared()
483 })
484 }
485
486 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
487 execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
488 lhs,
489 rhs,
490 |a: &E, b: &E| a.powf(*b)
491 ))
492 }
493
494 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
495 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
496 }
497
498 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
499 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
500 }
501
502 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
503 execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
504 }
505
506 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
507 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
508 }
509
510 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
511 execute_with_float_dtype!(tensor, |tensor| cast_to_dtype(tensor, dtype.into()))
512 }
513
514 fn float_grid_sample_2d(
515 tensor: FloatTensor<Self>,
516 grid: FloatTensor<Self>,
517 method: InterpolateMode,
518 ) -> FloatTensor<Self> {
519 execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
520 tensor, grid, method
521 ))
522 }
523
524 fn float_unfold(
525 tensor: FloatTensor<Self>,
526 dim: usize,
527 size: usize,
528 step: usize,
529 ) -> FloatTensor<Self> {
530 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::unfold(tensor, dim, size, step))
531 }
532}