burn_ndarray/ops/tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_tensor::cast::ToElement;
4use burn_tensor::ops::FloatTensor;
5use core::ops::Range;
6use ndarray::Zip;
7
8// Current crate
9use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
10use crate::element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement};
11use crate::{execute_with_float_dtype, new_tensor_float, NdArrayDevice, NdArrayTensorFloat, SEED};
12use crate::{tensor::NdArrayTensor, NdArray};
13
14// Workspace crates
15use burn_common::rand::get_seeded_rng;
16use burn_tensor::{backend::Backend, ops::FloatTensorOps, ElementConversion, Shape, TensorData};
17use burn_tensor::{Distribution, FloatDType};
18
19#[cfg(not(feature = "std"))]
20#[allow(unused_imports)]
21use num_traits::Float;
22
23use libm::erf;
24
// std builds: delegate to the standard library's banker's rounding
// (`f64::round_ties_even`), so callers use one name in both feature sets.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    x.round_ties_even()
}
30
// no_std builds: emulate `f64::round_ties_even` (round half to even)
// using only `floor` and `round`.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    // A value is an exact tie when its fractional part is exactly 0.5.
    let is_tie = (x - x.floor()) == 0.5;
    if !is_tie {
        return x.round();
    }
    // Halving turns the .5 tie into a .25/.75 fraction, which plain
    // `round` resolves unambiguously; doubling restores the scale, landing
    // on the even neighbour of the original tie.
    (x * 0.5).round() * 2.0
}
40
// `FloatTensorOps` implementation for the ndarray (CPU) backend.
//
// Float tensors carry their dtype at runtime (`NdArrayTensorFloat::F32`/`F64`);
// most methods dispatch on it through `execute_with_float_dtype!` and delegate
// the actual work to the shared `NdArrayOps` / `NdArrayMathOps` helpers, or to
// inline `mapv`/`Zip` closures for element-wise math.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
    for NdArray<E, I, Q>
{
    // Builds a tensor from raw `TensorData`. The device argument is ignored:
    // this backend only has the single `NdArrayDevice::Cpu` device.
    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
        new_tensor_float!(NdArrayTensor::from_data(data))
    }

    // Samples a tensor from `distribution`, threading the RNG through the
    // global `SEED` mutex so successive calls continue the same stream.
    fn float_random(
        shape: Shape,
        distribution: Distribution,
        device: &NdArrayDevice,
    ) -> FloatTensor<Self> {
        let mut seed = SEED.lock().unwrap();
        // Reuse the stored RNG when present; otherwise start a seeded one.
        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
            rng_seeded.clone()
        } else {
            get_seeded_rng()
        };
        let tensor = Self::float_from_data(
            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
            device,
        );
        // Persist the advanced RNG state for the next call.
        *seed = Some(rng);
        tensor
    }

    // Extracts the tensor's raw data. Async only to satisfy the trait; the
    // CPU conversion itself is synchronous.
    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
        match tensor {
            NdArrayTensorFloat::F32(tensor) => NdArrayOps::into_data(tensor),
            NdArrayTensorFloat::F64(tensor) => NdArrayOps::into_data(tensor),
        }
    }

    // Single-device backend: every tensor lives on the CPU.
    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    // Single-device backend: moving between devices is a no-op.
    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
        tensor
    }

    // "Empty" is implemented as a zero-filled tensor, so the memory is
    // always initialized (at the cost of a memset compared to a raw alloc).
    fn float_empty(shape: Shape, device: &<NdArray<E> as Backend>::Device) -> FloatTensor<Self> {
        NdArray::<E>::float_zeros(shape, device)
    }

    // --- Element-wise arithmetic: dtype dispatch + delegate to NdArrayMathOps.
    // Scalar variants convert the scalar with `.elem()` to the active dtype.

    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
    }

    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
    }

    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
    }

    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
    }

    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
    }

    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
    }

    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
    }

    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
    }

    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
    }

    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
    }

    // Batched matrix multiplication, delegated to the crate's `matmul` kernel.
    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), matmul)
    }

    // Negation expressed as multiplication by -1 to reuse the scalar path.
    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
    }

    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
    }

    // --- Shape / layout ops.

    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
    }

    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
    }

    // --- Indexing ops. Index tensors use the backend's int element `I`.

    fn float_gather(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
            dim, tensor, indices
        ))
    }

    fn float_scatter(
        dim: usize,
        tensor: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
            dim, tensor, indices, value
        ))
    }

    fn float_select(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor<I>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
            tensor, dim, indices
        ))
    }

    fn float_select_assign(
        tensor: FloatTensor<Self>,
        dim: usize,
        indices: NdArrayTensor<I>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayMathOps::select_assign(tensor, dim, indices, value)
        })
    }

    fn float_slice(tensor: FloatTensor<Self>, ranges: &[Range<usize>]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, ranges))
    }

    fn float_slice_assign(
        tensor: FloatTensor<Self>,
        ranges: &[Range<usize>],
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayOps::slice_assign(tensor, ranges, value)
        })
    }

    // --- Masking ops. Masks are boolean ndarray tensors.

    fn float_mask_where(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor<bool>,
        value: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!((tensor, value), |tensor, value| {
            NdArrayMathOps::mask_where(tensor, mask, value)
        })
    }

    fn float_mask_fill(
        tensor: FloatTensor<Self>,
        mask: NdArrayTensor<bool>,
        value: E,
    ) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
            tensor,
            mask,
            value.elem()
        ))
    }

    // --- Comparisons, returning boolean tensors.
    // The `=>` macro form is used because the output type (bool tensor) is
    // not a float tensor.

    // Exact element-wise float equality via a `Zip` over both arrays.
    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        execute_with_float_dtype!((lhs, rhs) => |lhs: NdArrayTensor<_>, rhs: NdArrayTensor<_>| {
            let output = Zip::from(&lhs.array)
                .and(&rhs.array)
                .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
                .into_shared();
            NdArrayTensor::new(output)
        })
    }

    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a == rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Tensor-tensor ordering comparisons are implemented as
    // `(lhs - rhs) <op> 0` so both operands go through a single dtype
    // dispatch and the scalar comparison path is reused.
    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_greater_elem(tensor, zero)
    }

    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a > rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_greater_equal_elem(tensor, zero)
    }

    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a >= rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_lower_elem(tensor, zero)
    }

    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a < rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor<bool> {
        let tensor = NdArray::<E>::float_sub(lhs, rhs);
        let zero = 0.elem();
        Self::float_lower_equal_elem(tensor, zero)
    }

    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor<bool> {
        execute_with_float_dtype!(lhs, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a <= rhs.elem::<E>()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    // No autograd graph on this backend, so detach is an identity.
    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        tensor
    }

    // --- Reductions.

    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
    }

    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
    }

    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
    }

    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
    }

    // Arg-reductions return int index tensors, hence the `=>` macro form.
    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmax(tensor, dim))
    }

    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor => |tensor| NdArrayMathOps::argmin(tensor, dim))
    }

    // --- Element-wise math via the `ExpElement` helper methods
    // (`exp_elem`, `log_elem`, ...), applied in place with `mapv_into`.

    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.exp_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.log_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
    }

    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
    }

    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.log1p_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Scalar power with fast paths for the common exponents.
    fn float_powf_scalar(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = if value == 2.0 {
                // Squaring happens often and a plain multiply is faster.
                tensor.array.mapv_into(|a| a * a).into_shared()
            } else if value.floor() == value {
                // Integral exponent: `powi` is faster than `powf`.
                tensor
                    .array
                    .mapv_into(|a| a.powi_elem(value as i32))
                    .into_shared()
            } else {
                // General case.
                tensor.array.mapv_into(|a| a.powf_elem(value)).into_shared()
            };

            NdArrayTensor::new(array)
        })
    }

    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.sqrt_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();

            NdArrayTensor::new(array)
        })
    }

    // --- Transcendental / rounding ops: each element is widened to f64,
    // computed, then converted back to the active dtype with `.elem()`.

    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).cos().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).sin().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).tanh().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Rounds half-way cases to even (banker's rounding) via the
    // feature-gated wrapper defined at the top of this file.
    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).floor().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| (a.to_f64()).ceil().elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Error function from `libm`, computed in f64.
    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, E, |tensor: NdArrayTensor<E>| {
            let array = tensor
                .array
                .mapv_into(|a| erf(a.to_f64()).elem())
                .into_shared();

            NdArrayTensor::new(array)
        })
    }

    // Concatenates along `dim`. The output dtype follows the first tensor;
    // any tensor with a different dtype triggers a panic.
    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
        match &tensors[0] {
            NdArrayTensorFloat::F32(_) => {
                let tensors = tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensorFloat::F32(tensor) = t {
                            tensor.array.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected f32, got f64)")
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayTensorFloat::F32(NdArrayOps::concatenate(&tensors, dim))
            }
            NdArrayTensorFloat::F64(_) => {
                let tensors = tensors
                    .iter()
                    .map(|t| {
                        if let NdArrayTensorFloat::F64(tensor) = t {
                            tensor.array.view()
                        } else {
                            panic!("Concatenate data type mismatch (expected f64, got f32)")
                        }
                    })
                    .collect::<Vec<_>>();
                NdArrayTensorFloat::F64(NdArrayOps::concatenate(&tensors, dim))
            }
        }
    }

    // --- Clamping, delegated to NdArrayMathOps with converted bounds.

    fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
            tensor,
            min.elem()
        ))
    }

    fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
            tensor,
            max.elem()
        ))
    }

    fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
            tensor,
            min.elem(),
            max.elem()
        ))
    }

    // Element-wise conversion to the backend's int element type `I`.
    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor<I> {
        execute_with_float_dtype!(tensor, E => |tensor: NdArrayTensor<E>| {
            let array = tensor.array.mapv(|a| a.elem()).into_shared();
            NdArrayTensor { array }
        })
    }

    // Element-wise `lhs ^ rhs` between two tensors of the same dtype.
    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
            lhs,
            rhs,
            |a: &E, b: &E| a.powf(*b)
        ))
    }

    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
    }

    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
    }

    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
    }

    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
    }

    // Casts between the two supported float dtypes. A cast to the tensor's
    // current dtype is a no-op that returns the tensor unchanged.
    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        // Element-wise conversion between two float element types.
        fn cast<E1: FloatNdArrayElement, E2: FloatNdArrayElement>(
            tensor: &NdArrayTensor<E1>,
        ) -> NdArrayTensor<E2> {
            let array = tensor.array.mapv(|a| a.elem()).into_shared();
            NdArrayTensor { array }
        }

        match (&tensor, dtype) {
            // No cast
            (NdArrayTensorFloat::F32(_), FloatDType::F32)
            | (NdArrayTensorFloat::F64(_), FloatDType::F64) => tensor,
            // F32 to F64
            (NdArrayTensorFloat::F32(tensor), FloatDType::F64) => {
                NdArrayTensorFloat::F64(cast(tensor))
            }
            // F64 to F32
            (NdArrayTensorFloat::F64(tensor), FloatDType::F32) => {
                NdArrayTensorFloat::F32(cast(tensor))
            }
            // NOTE(review): the arms above look exhaustive for the visible
            // variants; this catch-all guards against future dtype additions.
            _ => panic!("Invalid cast types"),
        }
    }
}