1use alloc::vec::Vec;
3use burn_tensor::cast::ToElement;
4use burn_tensor::ops::FloatTensor;
5use core::ops::Range;
6
7use super::{NdArrayMathOps, NdArrayOps, matmul::matmul};
9use crate::element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement};
10use crate::{NdArray, tensor::NdArrayTensor};
11use crate::{NdArrayDevice, NdArrayTensorFloat, SEED, execute_with_float_dtype};
12
13use burn_common::rand::get_seeded_rng;
15use burn_tensor::{DType, Distribution, FloatDType};
16use burn_tensor::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
17
18#[cfg(not(feature = "std"))]
19#[allow(unused_imports)]
20use num_traits::Float;
21
22use libm::erf;
23
/// Rounds `x` to the nearest integer, resolving exact halves towards the even
/// neighbour. With `std` available this is just the standard-library method.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, resolving exact halves towards the even
/// neighbour, without relying on `f64::round_ties_even`.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let is_exact_half = (x - x.floor()) == 0.5;
    if !is_exact_half {
        // Not a tie: ordinary rounding already gives the right answer.
        return x.round();
    }
    // For x = k + 0.5, x / 2 has a fractional part of 0.25 (k even) or
    // 0.75 (k odd), so rounding the half and doubling lands on the even
    // neighbour of x.
    (x * 0.5).round() * 2.0
}
39
40impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
41 for NdArray<E, I, Q>
42{
43 fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
44 match data.dtype {
45 DType::F64 => NdArrayTensorFloat::F64(NdArrayTensor::from_data(data)),
46 DType::F32 => NdArrayTensorFloat::F32(NdArrayTensor::from_data(data)),
47 _ => unimplemented!("Unsupported dtype for `float_from_data`"),
48 }
49 }
50
51 fn float_random(
52 shape: Shape,
53 distribution: Distribution,
54 device: &NdArrayDevice,
55 ) -> FloatTensor<Self> {
56 let mut seed = SEED.lock().unwrap();
57 let mut rng = if let Some(rng_seeded) = seed.as_ref() {
58 rng_seeded.clone()
59 } else {
60 get_seeded_rng()
61 };
62 let tensor = Self::float_from_data(
63 TensorData::random::<E, _, _>(shape, distribution, &mut rng),
64 device,
65 );
66 *seed = Some(rng);
67 tensor
68 }
69
70 async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
71 match tensor {
72 NdArrayTensorFloat::F32(tensor) => NdArrayOps::into_data(tensor),
73 NdArrayTensorFloat::F64(tensor) => NdArrayOps::into_data(tensor),
74 }
75 }
76
77 fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
78 NdArrayDevice::Cpu
79 }
80
81 fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
82 tensor
83 }
84
85 fn float_empty(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> FloatTensor<Self> {
86 NdArray::<E>::float_zeros(shape, device)
87 }
88
89 fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
90 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
91 }
92
93 fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
94 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
95 }
96
97 fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
98 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
99 }
100
101 fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
102 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
103 }
104
105 fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
106 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
107 }
108
109 fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
110 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
111 }
112
113 fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
114 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
115 }
116
117 fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
118 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
119 }
120
121 fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
122 execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
123 }
124
125 fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
126 execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
127 }
128
129 fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
130 execute_with_float_dtype!((lhs, rhs), matmul)
131 }
132
133 fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
134 Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
135 }
136
137 fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
138 execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
139 }
140
141 fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
142 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
143 }
144
145 fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
146 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
147 }
148
149 fn float_gather(
150 dim: usize,
151 tensor: FloatTensor<Self>,
152 indices: NdArrayTensor<I>,
153 ) -> FloatTensor<Self> {
154 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
155 dim, tensor, indices
156 ))
157 }
158
159 fn float_scatter(
160 dim: usize,
161 tensor: FloatTensor<Self>,
162 indices: NdArrayTensor<I>,
163 value: FloatTensor<Self>,
164 ) -> FloatTensor<Self> {
165 execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
166 dim, tensor, indices, value
167 ))
168 }
169
170 fn float_select(
171 tensor: FloatTensor<Self>,
172 dim: usize,
173 indices: NdArrayTensor<I>,
174 ) -> FloatTensor<Self> {
175 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
176 tensor, dim, indices
177 ))
178 }
179
180 fn float_select_assign(
181 tensor: FloatTensor<Self>,
182 dim: usize,
183 indices: NdArrayTensor<I>,
184 value: FloatTensor<Self>,
185 ) -> FloatTensor<Self> {
186 execute_with_float_dtype!((tensor, value), |tensor, value| {
187 NdArrayMathOps::select_assign(tensor, dim, indices, value)
188 })
189 }
190
191 fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
192 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, ranges))
193 }
194
195 fn float_slice_assign(
196 tensor: FloatTensor<Self>,
197 ranges: &[Range<usize>],
198 value: FloatTensor<Self>,
199 ) -> FloatTensor<Self> {
200 execute_with_float_dtype!((tensor, value), |tensor, value| {
201 NdArrayOps::slice_assign(tensor, ranges, value)
202 })
203 }
204
205 fn float_mask_where(
206 tensor: FloatTensor<Self>,
207 mask: NdArrayTensor<bool>,
208 value: FloatTensor<Self>,
209 ) -> FloatTensor<Self> {
210 execute_with_float_dtype!((tensor, value), |tensor, value| {
211 NdArrayMathOps::mask_where(tensor, mask, value)
212 })
213 }
214
215 fn float_mask_fill(
216 tensor: FloatTensor<Self>,
217 mask: NdArrayTensor<bool>,
218 value: E,
219 ) -> FloatTensor<Self> {
220 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
221 tensor,
222 mask,
223 value.elem()
224 ))
225 }
226
227 fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
228 execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
229 NdArrayMathOps::equal(lhs, rhs)
230 })
231 }
232
233 fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
234 execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
235 NdArrayMathOps::equal_elem(tensor, rhs.elem())
236 })
237 }
238
239 fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
240 execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
241 NdArrayMathOps::greater(lhs, rhs)
242 })
243 }
244
245 fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
246 execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
247 NdArrayMathOps::greater_elem(tensor, rhs.elem())
248 })
249 }
250
251 fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
252 execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
253 NdArrayMathOps::greater_equal(lhs, rhs)
254 })
255 }
256
257 fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
258 execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
259 NdArrayMathOps::greater_equal_elem(tensor, rhs.elem())
260 })
261 }
262
263 fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
264 execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
265 NdArrayMathOps::lower(lhs, rhs)
266 })
267 }
268
269 fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
270 execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
271 NdArrayMathOps::lower_elem(tensor, rhs.elem())
272 })
273 }
274
275 fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
276 execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
277 NdArrayMathOps::lower_equal(lhs, rhs)
278 })
279 }
280
281 fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
282 execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
283 NdArrayMathOps::lower_equal_elem(tensor, rhs.elem())
284 })
285 }
286
287 fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
288 tensor
289 }
290
291 fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
292 execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
293 }
294
295 fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
296 execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
297 }
298
299 fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
300 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
301 }
302
303 fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
304 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
305 }
306
307 fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
308 execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmax(tensor, dim))
309 }
310
311 fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
312 execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmin(tensor, dim))
313 }
314
315 fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
316 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
317 let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared();
318
319 NdArrayTensor::new(array)
320 })
321 }
322
323 fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
324 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
325 let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared();
326
327 NdArrayTensor::new(array)
328 })
329 }
330
331 fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
332 execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
333 }
334
335 fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
336 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
337 }
338
339 fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
340 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
341 let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared();
342
343 NdArrayTensor::new(array)
344 })
345 }
346
347 fn float_powf_scalar(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
348 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
349 let array = if value == 2.0 {
350 tensor.array.mapv_into(|a| a * a).into_shared()
352 } else if value.floor() == value {
353 tensor
355 .array
356 .mapv_into(|a| a.powi_elem(value as i32))
357 .into_shared()
358 } else {
359 tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared()
361 };
362
363 NdArrayTensor::new(array)
364 })
365 }
366
367 fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
368 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
369 let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared();
370
371 NdArrayTensor::new(array)
372 })
373 }
374
375 fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
376 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
377 let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();
378
379 NdArrayTensor::new(array)
380 })
381 }
382
383 fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
384 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
385 let array = tensor
386 .array
387 .mapv_into(|a| (a.to_f64()).cos().elem())
388 .into_shared();
389
390 NdArrayTensor::new(array)
391 })
392 }
393
394 fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
395 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
396 let array = tensor
397 .array
398 .mapv_into(|a| (a.to_f64()).sin().elem())
399 .into_shared();
400
401 NdArrayTensor::new(array)
402 })
403 }
404
405 fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
406 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
407 let array = tensor
408 .array
409 .mapv_into(|a| (a.to_f64()).tanh().elem())
410 .into_shared();
411
412 NdArrayTensor::new(array)
413 })
414 }
415
416 fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
417 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
418 let array = tensor
419 .array
420 .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
421 .into_shared();
422
423 NdArrayTensor::new(array)
424 })
425 }
426
427 fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
428 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
429 let array = tensor
430 .array
431 .mapv_into(|a| (a.to_f64()).floor().elem())
432 .into_shared();
433
434 NdArrayTensor::new(array)
435 })
436 }
437
438 fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
439 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
440 let array = tensor
441 .array
442 .mapv_into(|a| (a.to_f64()).ceil().elem())
443 .into_shared();
444
445 NdArrayTensor::new(array)
446 })
447 }
448
449 fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
450 execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
451 let array = tensor
452 .array
453 .mapv_into(|a| erf(a.to_f64()).elem())
454 .into_shared();
455
456 NdArrayTensor::new(array)
457 })
458 }
459
460 fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
461 match &tensors[0] {
462 NdArrayTensorFloat::F32(_) => {
463 let tensors = tensors
464 .iter()
465 .map(|t| {
466 if let NdArrayTensorFloat::F32(tensor) = t {
467 tensor.array.view()
468 } else {
469 panic!("Concatenate data type mismatch (expected f32, got f64)")
470 }
471 })
472 .collect::<Vec<_>>();
473 NdArrayTensorFloat::F32(NdArrayOps::concatenate(&tensors, dim))
474 }
475 NdArrayTensorFloat::F64(_) => {
476 let tensors = tensors
477 .iter()
478 .map(|t| {
479 if let NdArrayTensorFloat::F64(tensor) = t {
480 tensor.array.view()
481 } else {
482 panic!("Concatenate data type mismatch (expected f64, got f32)")
483 }
484 })
485 .collect::<Vec<_>>();
486 NdArrayTensorFloat::F64(NdArrayOps::concatenate(&tensors, dim))
487 }
488 }
489 }
490
491 fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
492 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
493 tensor,
494 min.elem()
495 ))
496 }
497
498 fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
499 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
500 tensor,
501 max.elem()
502 ))
503 }
504
505 fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
506 execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
507 tensor,
508 min.elem(),
509 max.elem()
510 ))
511 }
512
513 fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor<I> {
514 execute_with_float_dtype!(tensor, E => |tensor: NdArrayTensor<E>| {
515 let array = tensor.array.mapv(|a| a.elem()).into_shared();
516 NdArrayTensor { array }
517 })
518 }
519
520 fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
521 execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
522 lhs,
523 rhs,
524 |a: &E, b: &E| a.powf(*b)
525 ))
526 }
527
528 fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
529 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
530 }
531
532 fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
533 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
534 }
535
536 fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
537 execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
538 }
539
540 fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
541 execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
542 }
543
544 fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
545 fn cast<E1: FloatNdArrayElement, E2: FloatNdArrayElement>(
546 tensor: &NdArrayTensor<E1>,
547 ) -> NdArrayTensor<E2> {
548 let array = tensor.array.mapv(|a| a.elem()).into_shared();
549 NdArrayTensor { array }
550 }
551
552 match (&tensor, dtype) {
553 (NdArrayTensorFloat::F32(_), FloatDType::F32)
555 | (NdArrayTensorFloat::F64(_), FloatDType::F64) => tensor,
556 (NdArrayTensorFloat::F32(tensor), FloatDType::F64) => {
558 NdArrayTensorFloat::F64(cast(tensor))
559 }
560 (NdArrayTensorFloat::F64(tensor), FloatDType::F32) => {
562 NdArrayTensorFloat::F32(cast(tensor))
563 }
564 _ => panic!("Invalid cast types"),
565 }
566 }
567}