burn_ndarray/ops/
tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_tensor::ops::FloatTensor;
4use burn_tensor::ops::InterpolateMode;
5use burn_tensor::{TensorMetadata, cast::ToElement};
6
7// Current crate
8use super::{
9    NdArrayMathOps, NdArrayOps,
10    matmul::{cross, matmul},
11};
12use crate::{
13    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
14};
15use crate::{NdArrayDevice, SEED};
16use crate::{
17    SharedArray,
18    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
19};
20use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
21
22// Workspace crates
23use crate::rand::get_seeded_rng;
24use burn_tensor::{Distribution, FloatDType};
25use burn_tensor::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
26
27#[cfg(not(feature = "std"))]
28#[allow(unused_imports)]
29use num_traits::Float;
30
31use libm::erf;
32
/// Rounds `x` to the nearest integer; exact halfway values go to the even
/// neighbour (e.g. `2.5 -> 2.0`, `3.5 -> 4.0`, `-2.5 -> -2.0`).
///
/// With `std` available this is just the inherent `f64::round_ties_even`.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// `no_std` fallback for round-half-to-even.
///
/// Values whose fractional part is not exactly one half use ordinary
/// rounding. For an exact `.5` fraction, halving, rounding and doubling
/// again lands on the even neighbour.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fraction = x - x.floor();
    if fraction != 0.5 {
        return x.round();
    }
    2.0 * (0.5 * x).round()
}
48
49impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
50    for NdArray<E, I, Q>
51where
52    NdArrayTensor: From<SharedArray<E>>,
53    NdArrayTensor: From<SharedArray<I>>,
54{
55    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
56        NdArrayTensor::from_data(data)
57    }
58
59    fn float_random(
60        shape: Shape,
61        distribution: Distribution,
62        device: &NdArrayDevice,
63    ) -> FloatTensor<Self> {
64        let mut seed = SEED.lock().unwrap();
65        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
66            rng_seeded.clone()
67        } else {
68            get_seeded_rng()
69        };
70        let tensor = Self::float_from_data(
71            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
72            device,
73        );
74        *seed = Some(rng);
75        tensor
76    }
77
78    async fn float_into_data(tensor: FloatTensor<Self>) -> TensorData {
79        tensor.into_data()
80    }
81
82    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
83        NdArrayDevice::Cpu
84    }
85
86    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
87        tensor
88    }
89
90    fn float_empty(
91        shape: Shape,
92        device: &<NdArray<E> as Backend>::Device,
93        dtype: FloatDType,
94    ) -> FloatTensor<Self> {
95        Self::float_zeros(shape, device, dtype)
96    }
97
98    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
99        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
100    }
101
102    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
103        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
104    }
105
106    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
107        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
108    }
109
110    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
111        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
112    }
113
114    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
115        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
116    }
117
118    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
119        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
120    }
121
122    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
123        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
124    }
125
126    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
127        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
128    }
129
130    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
131        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
132    }
133
134    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
135        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
136    }
137
138    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
139        execute_with_float_dtype!((lhs, rhs), matmul)
140    }
141
142    fn float_cross(
143        lhs: FloatTensor<Self>,
144        rhs: FloatTensor<Self>,
145        dim: usize,
146    ) -> FloatTensor<Self> {
147        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
148    }
149
150    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
151        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
152    }
153
154    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
155        execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
156    }
157
158    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
159        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
160    }
161
162    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
163        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
164    }
165
166    fn float_gather(
167        dim: usize,
168        tensor: FloatTensor<Self>,
169        indices: NdArrayTensor,
170    ) -> FloatTensor<Self> {
171        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
172            execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
173                dim, tensor, indices
174            ))
175        })
176    }
177
178    fn float_scatter(
179        dim: usize,
180        tensor: FloatTensor<Self>,
181        indices: NdArrayTensor,
182        value: FloatTensor<Self>,
183    ) -> FloatTensor<Self> {
184        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
185            execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
186                dim, tensor, indices, value
187            ))
188        })
189    }
190
191    fn float_select(
192        tensor: FloatTensor<Self>,
193        dim: usize,
194        indices: NdArrayTensor,
195    ) -> FloatTensor<Self> {
196        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
197            execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
198                tensor, dim, indices
199            ))
200        })
201    }
202
203    fn float_select_assign(
204        tensor: FloatTensor<Self>,
205        dim: usize,
206        indices: NdArrayTensor,
207        value: FloatTensor<Self>,
208    ) -> FloatTensor<Self> {
209        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
210            execute_with_float_dtype!((tensor, value), |tensor, value| {
211                NdArrayMathOps::select_assign(tensor, dim, indices, value)
212            })
213        })
214    }
215
216    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_tensor::Slice]) -> FloatTensor<Self> {
217        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, slices))
218    }
219
220    fn float_slice_assign(
221        tensor: FloatTensor<Self>,
222        slices: &[burn_tensor::Slice],
223        value: FloatTensor<Self>,
224    ) -> FloatTensor<Self> {
225        execute_with_float_dtype!((tensor, value), |tensor, value| {
226            NdArrayOps::slice_assign(tensor, slices, value)
227        })
228    }
229
230    fn float_mask_where(
231        tensor: FloatTensor<Self>,
232        mask: NdArrayTensor,
233        value: FloatTensor<Self>,
234    ) -> FloatTensor<Self> {
235        execute_with_float_dtype!((tensor, value), |tensor, value| {
236            NdArrayMathOps::mask_where(tensor, mask.bool(), value)
237        })
238    }
239
240    fn float_mask_fill(
241        tensor: FloatTensor<Self>,
242        mask: NdArrayTensor,
243        value: E,
244    ) -> FloatTensor<Self> {
245        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
246            tensor,
247            mask.bool(),
248            value.elem()
249        ))
250    }
251
252    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
253        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
254    }
255
256    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
257        execute_with_float_dtype!(lhs, |tensor| {
258            NdArrayMathOps::equal_elem(tensor, rhs.elem())
259        })
260    }
261
262    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
263        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
264    }
265
266    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
267        execute_with_float_dtype!(lhs, |tensor| {
268            NdArrayMathOps::greater_elem(tensor, rhs.elem())
269        })
270    }
271
272    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
273        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
274            NdArrayMathOps::greater_equal(lhs, rhs)
275        })
276    }
277
278    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
279        execute_with_float_dtype!(lhs, |tensor| {
280            NdArrayMathOps::greater_equal_elem(tensor, rhs.elem())
281        })
282    }
283
284    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
285        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
286    }
287
288    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
289        execute_with_float_dtype!(lhs, |tensor| {
290            NdArrayMathOps::lower_elem(tensor, rhs.elem())
291        })
292    }
293
294    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
295        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
296            NdArrayMathOps::lower_equal(lhs, rhs)
297        })
298    }
299
300    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
301        execute_with_float_dtype!(lhs, |tensor| {
302            NdArrayMathOps::lower_equal_elem(tensor, rhs.elem())
303        })
304    }
305
306    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
307        tensor
308    }
309
310    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
311        execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
312    }
313
314    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
315        execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
316    }
317
318    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
319        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
320    }
321
322    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
323        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
324    }
325
326    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
327        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumprod(tensor, dim))
328    }
329
330    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
331        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
332    }
333
334    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
335        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummax(tensor, dim))
336    }
337
338    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
339        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
340    }
341
342    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
343        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmax::<I>(tensor, dim))
344    }
345
346    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
347        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmin::<I>(tensor, dim))
348    }
349
350    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
351        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
352            tensor.mapv_into(|a| a.exp_elem()).into_shared()
353        })
354    }
355
356    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
357        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
358            tensor.mapv_into(|a| a.log_elem()).into_shared()
359        })
360    }
361
362    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
363        execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
364    }
365
366    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
367        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
368    }
369
370    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
371        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
372            tensor.mapv_into(|a| a.log1p_elem()).into_shared()
373        })
374    }
375
376    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
377        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
378            tensor.mapv_into(|a| a.powf_elem(value)).into_shared()
379        })
380    }
381
382    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
383        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
384            tensor.mapv_into(|a| a.sqrt_elem()).into_shared()
385        })
386    }
387
388    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
389        execute_with_float_dtype!(tensor, E, NdArrayMathOps::abs)
390    }
391
392    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
393        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
394            tensor
395                .mapv_into(|a| (a.to_f64()).cos().elem())
396                .into_shared()
397        })
398    }
399
400    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
401        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
402            tensor
403                .mapv_into(|a| (a.to_f64()).sin().elem())
404                .into_shared()
405        })
406    }
407
408    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
409        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
410            tensor
411                .mapv_into(|a| (a.to_f64()).tanh().elem())
412                .into_shared()
413        })
414    }
415
416    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
417        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
418            tensor
419                .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
420                .into_shared()
421        })
422    }
423
424    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
425        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
426            tensor
427                .mapv_into(|a| (a.to_f64()).floor().elem())
428                .into_shared()
429        })
430    }
431
432    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
433        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
434            tensor
435                .mapv_into(|a| (a.to_f64()).ceil().elem())
436                .into_shared()
437        })
438    }
439
440    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
441        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
442            tensor
443                .mapv_into(|a| (a.to_f64()).trunc().elem())
444                .into_shared()
445        })
446    }
447
448    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
449        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
450            tensor.mapv_into(|a| erf(a.to_f64()).elem()).into_shared()
451        })
452    }
453
454    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
455        cat_with_dtype!(tensors, dim, [F64, F32])
456    }
457
458    fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
459        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
460            tensor,
461            min.elem()
462        ))
463    }
464
465    fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
466        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
467            tensor,
468            max.elem()
469        ))
470    }
471
472    fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
473        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
474            tensor,
475            min.elem(),
476            max.elem()
477        ))
478    }
479
480    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
481        execute_with_float_dtype!(tensor, |tensor: SharedArray<E>| {
482            tensor.mapv(|a| a.elem::<I>()).into_shared()
483        })
484    }
485
486    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
487        execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
488            lhs,
489            rhs,
490            |a: &E, b: &E| a.powf(*b)
491        ))
492    }
493
494    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
495        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
496    }
497
498    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
499        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
500    }
501
502    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
503        execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
504    }
505
506    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
507        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
508    }
509
510    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
511        execute_with_float_dtype!(tensor, |tensor| cast_to_dtype(tensor, dtype.into()))
512    }
513
514    fn float_grid_sample_2d(
515        tensor: FloatTensor<Self>,
516        grid: FloatTensor<Self>,
517        method: InterpolateMode,
518    ) -> FloatTensor<Self> {
519        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
520            tensor, grid, method
521        ))
522    }
523
524    fn float_unfold(
525        tensor: FloatTensor<Self>,
526        dim: usize,
527        size: usize,
528        step: usize,
529    ) -> FloatTensor<Self> {
530        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::unfold(tensor, dim, size, step))
531    }
532}