// burn_ndarray/ops/tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_tensor::cast::ToElement;
4use burn_tensor::ops::FloatTensor;
5use core::ops::Range;
6
7// Current crate
8use super::{NdArrayMathOps, NdArrayOps, matmul::matmul};
9use crate::element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement};
10use crate::{NdArray, tensor::NdArrayTensor};
11use crate::{NdArrayDevice, NdArrayTensorFloat, SEED, execute_with_float_dtype};
12
13// Workspace crates
14use burn_common::rand::get_seeded_rng;
15use burn_tensor::{DType, Distribution, FloatDType};
16use burn_tensor::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
17
18#[cfg(not(feature = "std"))]
19#[allow(unused_imports)]
20use num_traits::Float;
21
22use libm::erf;
23
24#[cfg(feature = "std")]
25#[allow(dead_code)]
26fn round_ties_even_wrapper(x: f64) -> f64 {
27    x.round_ties_even()
28}
29
30#[cfg(not(feature = "std"))]
31#[allow(dead_code)]
32fn round_ties_even_wrapper(x: f64) -> f64 {
33    if (x - x.floor()) == 0.5 {
34        (x * 0.5).round() * 2.0
35    } else {
36        x.round()
37    }
38}
39
// `FloatTensorOps` implementation for the ndarray (CPU) backend.
//
// Float tensors are the `NdArrayTensorFloat` enum with `F32`/`F64` variants;
// the `execute_with_float_dtype!` macro dispatches each op to the concrete
// element type (the `=>` form is used when the result is not a float tensor,
// e.g. bool/int outputs).
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
    for NdArray<E, I, Q>
{
    // Build a tensor from raw data. Only f32/f64 payloads are accepted; any
    // other dtype is a hard error.
    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
        match data.dtype {
            DType::F64 => NdArrayTensorFloat::F64(NdArrayTensor::from_data(data)),
            DType::F32 => NdArrayTensorFloat::F32(NdArrayTensor::from_data(data)),
            _ => unimplemented!("Unsupported dtype for `float_from_data`"),
        }
    }

    // Sample a tensor from `distribution`, using the process-global seeded RNG
    // behind the `SEED` mutex. The advanced RNG state is written back so
    // successive calls continue the same stream.
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &NdArrayDevice,
    ) -> FloatTensor<Self> {
        let mut seed = SEED.lock().unwrap();
        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
            rng_seeded.clone()
        } else {
            get_seeded_rng()
        };
        let tensor = Self::float_from_data(
            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
            device,
        );
        *seed = Some(rng);
        tensor
    }

    // Extract the raw data. Async only to satisfy the trait; the CPU backend
    // resolves immediately.
    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
        match tensor {
            NdArrayTensorFloat::F32(tensor) => NdArrayOps::into_data(tensor),
            NdArrayTensorFloat::F64(tensor) => NdArrayOps::into_data(tensor),
        }
    }

    // This backend only has one device: the CPU.
    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    // Single-device backend, so moving between devices is a no-op.
    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
        tensor
    }

    // "Empty" is implemented as zero-filled; callers must not rely on the
    // contents, but nothing is left uninitialized.
    fn float_empty(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> FloatTensor<Self> {
        NdArray::<E>::float_zeros(shape, device)
    }

    // Elementwise arithmetic between two tensors / a tensor and a scalar.
    // Scalars arrive as the backend float element `E` and are converted with
    // `elem()` to the tensor's concrete dtype inside the macro.
    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
    }

    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
    }

    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
    }

    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
    }

    // Batched matrix multiplication, delegated to the crate's matmul kernel.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), matmul)
    }

    // Negation expressed as multiplication by -1 to reuse the scalar path.
    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
    }

    // Shape / layout manipulation.
    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
    }

    // Index-based gather/scatter/select; `indices` is an int tensor of the
    // backend's int element type `I`.
    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
            dim, tensor, indices
        ))
    }

    fn float_scatter(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
            dim, tensor, indices, value
        ))
    }

    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor<I>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
            tensor, dim, indices
        ))
    }

    fn float_select_assign(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor<I>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayMathOps::select_assign(tensor, dim, indices, value)
        })
    }

    // Range-based slicing; one range per dimension.
    fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, ranges))
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        ranges: &[Range<usize>],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::slice_assign(tensor, ranges, value)
        })
    }

    // Masked selection/fill with a bool mask tensor.
    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor<bool>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayMathOps::mask_where(tensor, mask, value)
        })
    }

    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor<bool>,
        value: E,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
            tensor,
            mask,
            value.elem()
        ))
    }

    // Elementwise comparisons producing bool tensors. The `=>` macro form is
    // used because the output dtype differs from the input's.
    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
            NdArrayMathOps::equal(lhs, rhs)
        })
    }

    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            NdArrayMathOps::equal_elem(tensor, rhs.elem())
        })
    }

    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
            NdArrayMathOps::greater(lhs, rhs)
        })
    }

    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            NdArrayMathOps::greater_elem(tensor, rhs.elem())
        })
    }

    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
            NdArrayMathOps::greater_equal(lhs, rhs)
        })
    }

    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            NdArrayMathOps::greater_equal_elem(tensor, rhs.elem())
        })
    }

    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
            NdArrayMathOps::lower(lhs, rhs)
        })
    }

    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            NdArrayMathOps::lower_elem(tensor, rhs.elem())
        })
    }

    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
            NdArrayMathOps::lower_equal(lhs, rhs)
        })
    }

    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            NdArrayMathOps::lower_equal_elem(tensor, rhs.elem())
        })
    }

    // No autodiff graph on this backend, so detach is a no-op.
    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        tensor
    }

    // Full and per-dimension reductions.
    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
    }

    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
    }

    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
    }

    // Arg-reductions return int tensors (`=>` macro form).
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmax(tensor, dim))
    }

    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmin(tensor, dim))
    }

    // Elementwise transcendental ops. `mapv_into` consumes and reuses the
    // buffer when uniquely owned; `*_elem` comes from the `ExpElement` trait.
    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
    }

    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
    }

    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Scalar power with fast paths: squaring for exponent 2, integer power
    // for whole-number exponents, general powf otherwise.
    // NOTE(review): the `value as i32` cast assumes whole-valued exponents fit
    // in i32 — confirm callers never pass larger whole floats.
    fn float_powf_scalar(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = if value == 2.0 {
                // Happens often and is faster.
                tensor.array.mapv_into(|a| a * a).into_shared()
            } else if value.floor() == value {
                // Is faster than powf
                tensor
                    .array
                    .mapv_into(|a| a.powi_elem(value as i32))
                    .into_shared()
            } else {
                // Default
                tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared()
            };

            NdArrayTensor::new(array)
        })
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Trig and rounding ops are evaluated through f64 (`to_f64()` then `elem()`
    // back), so f32 tensors pay a widening round-trip for accuracy.
    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).cos().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).sin().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).tanh().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Rounds half-way cases to the nearest even integer (banker's rounding),
    // via the std/no_std wrapper defined above.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).floor().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).ceil().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Gauss error function, from libm (works in no_std too).
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| erf(a.to_f64()).elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Concatenate along `dim`. The dtype of the first tensor decides the
    // output dtype; any mixed-dtype input panics. Panics on an empty `tensors`
    // vec as well (indexing `tensors[0]`).
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        match &tensors[0] {
            NdArrayTensorFloat::F32(_) => {
                let tensors = tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensorFloat::F32(tensor) = t {
                            tensor.array.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected f32, got f64)")
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayTensorFloat::F32(NdArrayOps::concatenate(&tensors, dim))
            }
            NdArrayTensorFloat::F64(_) => {
                let tensors = tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensorFloat::F64(tensor) = t {
                            tensor.array.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected f64, got f32)")
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayTensorFloat::F64(NdArrayOps::concatenate(&tensors, dim))
            }
        }
    }

    // Clamping against scalar bounds.
    fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
            tensor,
            min.elem()
        ))
    }

    fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
            tensor,
            max.elem()
        ))
    }

    fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
            tensor,
            min.elem(),
            max.elem()
        ))
    }

    // Cast to the backend's int element type; `mapv` allocates a new array
    // since the element size changes.
    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a.elem()).into_shared();
            NdArrayTensor { array }
        })
    }

    // Elementwise power with a tensor exponent.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
            lhs,
            rhs,
            |a: &E, b: &E| a.powf(*b)
        ))
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
    }

    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
    }

    // Cast between the supported float dtypes; returns the input untouched
    // when the dtype already matches.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        // Elementwise conversion between two float element types.
        fn cast<E1: FloatNdArrayElement, E2: FloatNdArrayElement>(
            tensor: &NdArrayTensor<E1>,
        ) -> NdArrayTensor<E2> {
            let array = tensor.array.mapv(|a| a.elem()).into_shared();
            NdArrayTensor { array }
        }

        match (&tensor, dtype) {
            // No cast
            (NdArrayTensorFloat::F32(_), FloatDType::F32)
            | (NdArrayTensorFloat::F64(_), FloatDType::F64) => tensor,
            // F32 to F64
            (NdArrayTensorFloat::F32(tensor), FloatDType::F64) => {
                NdArrayTensorFloat::F64(cast(tensor))
            }
            // F64 to F32
            (NdArrayTensorFloat::F64(tensor), FloatDType::F32) => {
                NdArrayTensorFloat::F32(cast(tensor))
            }
            _ => panic!("Invalid cast types"),
        }
    }
}