1use alloc::vec::Vec;
3use burn_tensor::backend::ExecutionError;
4use burn_tensor::ops::FloatTensor;
5use burn_tensor::ops::GridSampleOptions;
6use burn_tensor::{TensorMetadata, cast::ToElement};
7
8use super::{
10 NdArrayMathOps, NdArrayOps,
11 matmul::{cross, matmul},
12};
13use crate::{
14 NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
15};
16use crate::{NdArrayDevice, SEED};
17use crate::{
18 SharedArray,
19 element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
20};
21use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
22
23use crate::rand::get_seeded_rng;
25use burn_tensor::{Distribution, FloatDType};
26use burn_tensor::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
27
28#[cfg(not(feature = "std"))]
29#[allow(unused_imports)]
30use num_traits::Float;
31
32use libm::erf;
33
/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour.
///
/// With the `std` feature enabled this forwards to [`f64::round_ties_even`].
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour.
///
/// Fallback used when the `std` feature is disabled: only values whose fractional part is
/// exactly `0.5` need special treatment, every other value goes through plain `round`.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fraction = x - x.floor();
    if fraction != 0.5 {
        // Not a tie (this also covers NaN and the infinities): ordinary rounding is correct.
        return x.round();
    }
    // Exact tie: halving, rounding and doubling again always lands on an even integer.
    (x * 0.5).round() * 2.0
}
49
50impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
51 for NdArray<E, I, Q>
52where
53 NdArrayTensor: From<SharedArray<E>>,
54 NdArrayTensor: From<SharedArray<I>>,
55{
56 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
57 NdArrayTensor::from_data(data)
58 }
59
60 fn float_random(
61 shape: Shape,
62 distribution: Distribution,
63 device: &NdArrayDevice,
64 ) -> FloatTensor<Self> {
65 let mut seed = SEED.lock().unwrap();
66 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
67 rng_seeded.clone()
68 } else {
69 get_seeded_rng()
70 };
71 let tensor = Self::float_from_data(
72 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
73 device,
74 );
75 *seed = Some(rng);
76 tensor
77 }
78
79 async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
80 Ok(tensor.into_data())
81 }
82
83 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
84 NdArrayDevice::Cpu
85 }
86
87 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
88 tensor
89 }
90
91 fn float_empty(
92 shape: Shape,
93 device: &<NdArray<E> as Backend>::Device,
94 dtype: FloatDType,
95 ) -> FloatTensor<Self> {
96 Self::float_zeros(shape, device, dtype)
97 }
98
99 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
100 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
101 }
102
103 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
104 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
105 }
106
107 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
108 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
109 }
110
111 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
112 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
113 }
114
115 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
116 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
117 }
118
119 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
120 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
121 }
122
123 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
124 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
125 }
126
127 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
128 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
129 }
130
131 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
132 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
133 }
134
135 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
136 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
137 }
138
139 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
140 execute_with_float_dtype!((lhs, rhs), matmul)
141 }
142
143 fn float_cross(
144 lhs: FloatTensor<Self>,
145 rhs: FloatTensor<Self>,
146 dim: usize,
147 ) -> FloatTensor<Self> {
148 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
149 }
150
151 fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
152 Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
153 }
154
155 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
156 execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
157 }
158
159 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
160 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
161 }
162
163 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
164 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
165 }
166
167 fn float_gather(
168 dim: usize,
169 tensor: FloatTensor<Self>,
170 indices: NdArrayTensor,
171 ) -> FloatTensor<Self> {
172 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
173 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
174 dim, tensor, indices
175 ))
176 })
177 }
178
179 fn float_scatter(
180 dim: usize,
181 tensor: FloatTensor<Self>,
182 indices: NdArrayTensor,
183 value: FloatTensor<Self>,
184 ) -> FloatTensor<Self> {
185 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
186 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
187 dim, tensor, indices, value
188 ))
189 })
190 }
191
192 fn float_select(
193 tensor: FloatTensor<Self>,
194 dim: usize,
195 indices: NdArrayTensor,
196 ) -> FloatTensor<Self> {
197 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
198 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
199 tensor, dim, indices
200 ))
201 })
202 }
203
204 fn float_select_assign(
205 tensor: FloatTensor<Self>,
206 dim: usize,
207 indices: NdArrayTensor,
208 value: FloatTensor<Self>,
209 ) -> FloatTensor<Self> {
210 execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
211 execute_with_float_dtype!((tensor, value), |tensor, value| {
212 NdArrayMathOps::select_assign(tensor, dim, indices, value)
213 })
214 })
215 }
216
217 fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_tensor::Slice]) -> FloatTensor<Self> {
218 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, slices))
219 }
220
221 fn float_slice_assign(
222 tensor: FloatTensor<Self>,
223 slices: &[burn_tensor::Slice],
224 value: FloatTensor<Self>,
225 ) -> FloatTensor<Self> {
226 execute_with_float_dtype!((tensor, value), |tensor, value| {
227 NdArrayOps::slice_assign(tensor, slices, value)
228 })
229 }
230
231 fn float_mask_where(
232 tensor: FloatTensor<Self>,
233 mask: NdArrayTensor,
234 value: FloatTensor<Self>,
235 ) -> FloatTensor<Self> {
236 execute_with_float_dtype!((tensor, value), |tensor, value| {
237 NdArrayMathOps::mask_where(tensor, mask.bool(), value)
238 })
239 }
240
241 fn float_mask_fill(
242 tensor: FloatTensor<Self>,
243 mask: NdArrayTensor,
244 value: E,
245 ) -> FloatTensor<Self> {
246 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
247 tensor,
248 mask.bool(),
249 value.elem()
250 ))
251 }
252
253 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
254 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
255 }
256
257 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
258 execute_with_float_dtype!(lhs, |tensor| {
259 NdArrayMathOps::equal_elem(tensor, rhs.elem())
260 })
261 }
262
263 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
264 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
265 }
266
267 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
268 execute_with_float_dtype!(lhs, |tensor| {
269 NdArrayMathOps::greater_elem(tensor, rhs.elem())
270 })
271 }
272
273 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
274 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
275 NdArrayMathOps::greater_equal(lhs, rhs)
276 })
277 }
278
279 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
280 execute_with_float_dtype!(lhs, |tensor| {
281 NdArrayMathOps::greater_equal_elem(tensor, rhs.elem())
282 })
283 }
284
285 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
286 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
287 }
288
289 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
290 execute_with_float_dtype!(lhs, |tensor| {
291 NdArrayMathOps::lower_elem(tensor, rhs.elem())
292 })
293 }
294
295 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
296 execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
297 NdArrayMathOps::lower_equal(lhs, rhs)
298 })
299 }
300
301 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
302 execute_with_float_dtype!(lhs, |tensor| {
303 NdArrayMathOps::lower_equal_elem(tensor, rhs.elem())
304 })
305 }
306
307 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
308 tensor
309 }
310
311 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
312 execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
313 }
314
315 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
316 execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
317 }
318
319 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
320 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
321 }
322
323 fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
324 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
325 }
326
327 fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
328 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumprod(tensor, dim))
329 }
330
331 fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
332 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
333 }
334
335 fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
336 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummax(tensor, dim))
337 }
338
339 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
340 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
341 }
342
343 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
344 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmax::<I>(tensor, dim))
345 }
346
347 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
348 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmin::<I>(tensor, dim))
349 }
350
351 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
352 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
353 tensor.mapv_into(|a| a.exp_elem()).into_shared()
354 })
355 }
356
357 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
358 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
359 tensor.mapv_into(|a| a.log_elem()).into_shared()
360 })
361 }
362
363 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
364 execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
365 }
366
367 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
368 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
369 }
370
371 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
372 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
373 tensor.mapv_into(|a| a.log1p_elem()).into_shared()
374 })
375 }
376
377 fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
378 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
379 tensor.mapv_into(|a| a.powf_elem(value)).into_shared()
380 })
381 }
382
383 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
384 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
385 tensor.mapv_into(|a| a.sqrt_elem()).into_shared()
386 })
387 }
388
389 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
390 execute_with_float_dtype!(tensor, E, NdArrayMathOps::abs)
391 }
392
393 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
394 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
395 tensor
396 .mapv_into(|a| (a.to_f64()).cos().elem())
397 .into_shared()
398 })
399 }
400
401 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
403 tensor
404 .mapv_into(|a| (a.to_f64()).sin().elem())
405 .into_shared()
406 })
407 }
408
409 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
410 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
411 tensor
412 .mapv_into(|a| (a.to_f64()).tanh().elem())
413 .into_shared()
414 })
415 }
416
417 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
418 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
419 tensor
420 .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
421 .into_shared()
422 })
423 }
424
425 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
426 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
427 tensor
428 .mapv_into(|a| (a.to_f64()).floor().elem())
429 .into_shared()
430 })
431 }
432
433 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
434 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
435 tensor
436 .mapv_into(|a| (a.to_f64()).ceil().elem())
437 .into_shared()
438 })
439 }
440
441 fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
442 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
443 tensor
444 .mapv_into(|a| (a.to_f64()).trunc().elem())
445 .into_shared()
446 })
447 }
448
449 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
450 execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
451 tensor.mapv_into(|a| erf(a.to_f64()).elem()).into_shared()
452 })
453 }
454
455 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
456 cat_with_dtype!(tensors, dim, [F64, F32])
457 }
458
459 fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
460 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
461 tensor,
462 min.elem()
463 ))
464 }
465
466 fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
467 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
468 tensor,
469 max.elem()
470 ))
471 }
472
473 fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
474 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
475 tensor,
476 min.elem(),
477 max.elem()
478 ))
479 }
480
481 fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
482 execute_with_float_dtype!(tensor, |tensor: SharedArray<E>| {
483 tensor.mapv(|a| a.elem::<I>()).into_shared()
484 })
485 }
486
487 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
488 execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
489 lhs,
490 rhs,
491 |a: &E, b: &E| a.powf(*b)
492 ))
493 }
494
495 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
496 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
497 }
498
499 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
500 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
501 }
502
503 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
504 execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
505 }
506
507 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
508 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
509 }
510
511 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
512 execute_with_float_dtype!(tensor, |tensor| cast_to_dtype(tensor, dtype.into()))
513 }
514
515 fn float_grid_sample_2d(
516 tensor: FloatTensor<Self>,
517 grid: FloatTensor<Self>,
518 options: GridSampleOptions,
519 ) -> FloatTensor<Self> {
520 execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
521 tensor, grid, options
522 ))
523 }
524
525 fn float_unfold(
526 tensor: FloatTensor<Self>,
527 dim: usize,
528 size: usize,
529 step: usize,
530 ) -> FloatTensor<Self> {
531 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::unfold(tensor, dim, size, step))
532 }
533}