use alloc::vec::Vec;
use burn_tensor::cast::ToElement;
use burn_tensor::ops::FloatTensor;
use core::ops::Range;
use ndarray::Zip;

use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
use crate::element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement};
use crate::{execute_with_float_dtype, new_tensor_float, NdArrayDevice, NdArrayTensorFloat, SEED};
use crate::{tensor::NdArrayTensor, NdArray};

use burn_common::rand::get_seeded_rng;
use burn_tensor::{backend::Backend, ops::FloatTensorOps, ElementConversion, Shape, TensorData};
use burn_tensor::{Distribution, FloatDType};

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;

use libm::erf;

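// `f64::round_ties_even` is a standard-library method, so a small wrapper selects between
// the `std` implementation and a manual fallback for no-std builds.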
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    x.round_ties_even()
}

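// No-std fallback: an exact tie (fractional part of 0.5) is halved, rounded, and doubled
// again, which lands on the nearest even integer; all other values use plain `round`.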
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    if (x - x.floor()) == 0.5 {
        (x * 0.5).round() * 2.0
    } else {
        x.round()
    }
}

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
    for NdArray<E, I, Q>
{
    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
        new_tensor_float!(NdArrayTensor::from_data(data))
    }

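    // Sampling uses the backend's global `SEED` mutex: an existing RNG is cloned and its
    // advanced state is stored back afterwards, so successive calls continue the same stream.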
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &NdArrayDevice,
    ) -> FloatTensor<Self> {
        let mut seed = SEED.lock().unwrap();
        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
            rng_seeded.clone()
        } else {
            get_seeded_rng()
        };
        let tensor = Self::float_from_data(
            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
            device,
        );
        *seed = Some(rng);
        tensor
    }

    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
        match tensor {
            NdArrayTensorFloat::F32(tensor) => NdArrayOps::into_data(tensor),
            NdArrayTensorFloat::F64(tensor) => NdArrayOps::into_data(tensor),
        }
    }

    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
        tensor
    }

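    // "Empty" tensors are not left uninitialized: allocation simply delegates to `float_zeros`.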
    fn float_empty(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> FloatTensor<Self> {
        NdArray::<E>::float_zeros(shape, device)
    }

    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
    }

    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
    }

    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
    }

    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
    }

    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), matmul)
    }

    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
    }

    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
    }

    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
            dim, tensor, indices
        ))
    }

    fn float_scatter(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
            dim, tensor, indices, value
        ))
    }

    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor<I>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
            tensor, dim, indices
        ))
    }

    fn float_select_assign(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor<I>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayMathOps::select_assign(tensor, dim, indices, value)
        })
    }

    fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, ranges))
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        ranges: &[Range<usize>],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::slice_assign(tensor, ranges, value)
        })
    }

    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor<bool>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayMathOps::mask_where(tensor, mask, value)
        })
    }

    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor<bool>,
        value: E,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
            tensor,
            mask,
            value.elem()
        ))
    }

    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
            let output = Zip::from(&lhs.array)
                .and(&rhs.array)
                .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
                .into_shared();
            NdArrayTensor::new(output)
        })
    }

    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a == rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

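    // The ordered comparisons below are computed as `lhs - rhs` checked element-wise against
    // zero, presumably so the dtype dispatch is reused from `float_sub` instead of being
    // duplicated for each comparison.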
    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_greater_elem(tensor, zero)
    }

    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a > rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_greater_equal_elem(tensor, zero)
    }

    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a >= rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_lower_elem(tensor, zero)
    }

    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a < rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_lower_equal_elem(tensor, zero)
    }

    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a <= rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        tensor
    }

    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
    }

    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
    }

    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
    }

    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmax(tensor, dim))
    }

    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmin(tensor, dim))
    }

    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
    }

    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
    }

    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

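    // Scalar powers special-case an exponent of 2 (plain squaring) and other integer
    // exponents (`powi_elem`) before falling back to the general `powf_elem`.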
    fn float_powf_scalar(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = if value == 2.0 {
                tensor.array.mapv_into(|a| a * a).into_shared()
            } else if value.floor() == value {
                tensor
                    .array
                    .mapv_into(|a| a.powi_elem(value as i32))
                    .into_shared()
            } else {
                tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared()
            };

            NdArrayTensor::new(array)
        })
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

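    // The trigonometric and rounding ops below are evaluated in f64 via `to_f64()` and then
    // converted back to the element type, regardless of the tensor's dtype.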
    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).cos().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).sin().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).tanh().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).floor().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).ceil().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

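    // `erf` comes from the `libm` crate, which also works without `std`.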
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| erf(a.to_f64()).elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

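    // Concatenation dispatches on the dtype of the first tensor; every other tensor must
    // match it, otherwise the corresponding branch panics with a dtype-mismatch message.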
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        match &tensors[0] {
            NdArrayTensorFloat::F32(_) => {
                let tensors = tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensorFloat::F32(tensor) = t {
                            tensor.array.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected f32, got f64)")
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayTensorFloat::F32(NdArrayOps::concatenate(&tensors, dim))
            }
            NdArrayTensorFloat::F64(_) => {
                let tensors = tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensorFloat::F64(tensor) = t {
                            tensor.array.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected f64, got f32)")
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayTensorFloat::F64(NdArrayOps::concatenate(&tensors, dim))
            }
        }
    }

    fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
            tensor,
            min.elem()
        ))
    }

    fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
            tensor,
            max.elem()
        ))
    }

    fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
            tensor,
            min.elem(),
            max.elem()
        ))
    }

    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a.elem()).into_shared();
            NdArrayTensor { array }
        })
    }

    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
            lhs,
            rhs,
            |a: &E, b: &E| a.powf(*b)
        ))
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
    }

    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
    }

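    // Casting to the tensor's current dtype returns it unchanged; otherwise each element is
    // converted through `elem()` into a newly allocated array.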
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        fn cast<E1: FloatNdArrayElement, E2: FloatNdArrayElement>(
            tensor: &NdArrayTensor<E1>,
        ) -> NdArrayTensor<E2> {
            let array = tensor.array.mapv(|a| a.elem()).into_shared();
            NdArrayTensor { array }
        }

        match (&tensor, dtype) {
            (NdArrayTensorFloat::F32(_), FloatDType::F32)
            | (NdArrayTensorFloat::F64(_), FloatDType::F64) => tensor,
            (NdArrayTensorFloat::F32(tensor), FloatDType::F64) => {
                NdArrayTensorFloat::F64(cast(tensor))
            }
            (NdArrayTensorFloat::F64(tensor), FloatDType::F32) => {
                NdArrayTensorFloat::F32(cast(tensor))
            }
            _ => panic!("Invalid cast types"),
        }
    }
}