burn_ndarray/ops/tensor.rs

1// Language
2use alloc::vec::Vec;
3use burn_tensor::backend::ExecutionError;
4use burn_tensor::ops::FloatTensor;
5use burn_tensor::ops::GridSampleOptions;
6use burn_tensor::{TensorMetadata, cast::ToElement};
7
8// Current crate
9use super::{
10    NdArrayMathOps, NdArrayOps,
11    matmul::{cross, matmul},
12};
13use crate::{
14    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,
15};
16use crate::{NdArrayDevice, SEED};
17use crate::{
18    SharedArray,
19    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},
20};
21use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};
22
23// Workspace crates
24use crate::rand::get_seeded_rng;
25use burn_tensor::{Distribution, FloatDType};
26use burn_tensor::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};
27
28#[cfg(not(feature = "std"))]
29#[allow(unused_imports)]
30use num_traits::Float;
31
32use libm::erf;
33
/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour.
///
/// With `std` available this simply defers to `f64::round_ties_even`.
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    f64::round_ties_even(x)
}

/// Rounds `x` to the nearest integer, resolving exact halves toward the even neighbour.
///
/// Without `std` the rounding is emulated. When the fractional part is exactly one half,
/// `x / 2` is rounded and then doubled, which lands on the even neighbour. Every other
/// input (including NaN and infinities) goes through ordinary `round`.
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
    let fraction = x - x.floor();
    if fraction != 0.5 {
        return x.round();
    }
    2.0 * (0.5 * x).round()
}
49
50impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
51    for NdArray<E, I, Q>
52where
53    NdArrayTensor: From<SharedArray<E>>,
54    NdArrayTensor: From<SharedArray<I>>,
55{
56    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
57        NdArrayTensor::from_data(data)
58    }
59
60    fn float_random(
61        shape: Shape,
62        distribution: Distribution,
63        device: &NdArrayDevice,
64    ) -> FloatTensor<Self> {
65        let mut seed = SEED.lock().unwrap();
66        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
67            rng_seeded.clone()
68        } else {
69            get_seeded_rng()
70        };
71        let tensor = Self::float_from_data(
72            TensorData::random::<E, _, _>(shape, distribution, &mut rng),
73            device,
74        );
75        *seed = Some(rng);
76        tensor
77    }
78
79    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {
80        Ok(tensor.into_data())
81    }
82
83    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
84        NdArrayDevice::Cpu
85    }
86
87    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
88        tensor
89    }
90
91    fn float_empty(
92        shape: Shape,
93        device: &<NdArray<E> as Backend>::Device,
94        dtype: FloatDType,
95    ) -> FloatTensor<Self> {
96        Self::float_zeros(shape, device, dtype)
97    }
98
99    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
100        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
101    }
102
103    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
104        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::add_scalar(lhs, rhs.elem()))
105    }
106
107    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
108        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
109    }
110
111    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
112        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::sub_scalar(lhs, rhs.elem()))
113    }
114
115    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
116        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
117    }
118
119    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
120        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::mul_scalar(lhs, rhs.elem()))
121    }
122
123    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
124        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
125    }
126
127    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
128        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::div_scalar(lhs, rhs.elem()))
129    }
130
131    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
132        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
133    }
134
135    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: E) -> FloatTensor<Self> {
136        execute_with_float_dtype!(lhs, |lhs| NdArrayMathOps::remainder_scalar(lhs, rhs.elem()))
137    }
138
139    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
140        execute_with_float_dtype!((lhs, rhs), matmul)
141    }
142
143    fn float_cross(
144        lhs: FloatTensor<Self>,
145        rhs: FloatTensor<Self>,
146        dim: usize,
147    ) -> FloatTensor<Self> {
148        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
149    }
150
151    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
152        Self::float_mul_scalar(tensor, (-1f32).elem::<E>())
153    }
154
155    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
156        execute_with_float_dtype!(tensor, NdArrayMathOps::recip)
157    }
158
159    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
160        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::swap_dims(tensor, dim1, dim2))
161    }
162
163    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
164        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::reshape(tensor, shape))
165    }
166
167    fn float_gather(
168        dim: usize,
169        tensor: FloatTensor<Self>,
170        indices: NdArrayTensor,
171    ) -> FloatTensor<Self> {
172        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
173            execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::gather(
174                dim, tensor, indices
175            ))
176        })
177    }
178
179    fn float_scatter_add(
180        dim: usize,
181        tensor: FloatTensor<Self>,
182        indices: NdArrayTensor,
183        value: FloatTensor<Self>,
184    ) -> FloatTensor<Self> {
185        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
186            execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayMathOps::scatter(
187                dim, tensor, indices, value
188            ))
189        })
190    }
191
192    fn float_select(
193        tensor: FloatTensor<Self>,
194        dim: usize,
195        indices: NdArrayTensor,
196    ) -> FloatTensor<Self> {
197        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
198            execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::select(
199                tensor, dim, indices
200            ))
201        })
202    }
203
204    fn float_select_add(
205        tensor: FloatTensor<Self>,
206        dim: usize,
207        indices: NdArrayTensor,
208        value: FloatTensor<Self>,
209    ) -> FloatTensor<Self> {
210        execute_with_int_dtype!(indices, I, |indices| -> NdArrayTensor {
211            execute_with_float_dtype!((tensor, value), |tensor, value| {
212                NdArrayMathOps::select_assign(tensor, dim, indices, value)
213            })
214        })
215    }
216
217    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_tensor::Slice]) -> FloatTensor<Self> {
218        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::slice(tensor, slices))
219    }
220
221    fn float_slice_assign(
222        tensor: FloatTensor<Self>,
223        slices: &[burn_tensor::Slice],
224        value: FloatTensor<Self>,
225    ) -> FloatTensor<Self> {
226        execute_with_float_dtype!((tensor, value), |tensor, value| {
227            NdArrayOps::slice_assign(tensor, slices, value)
228        })
229    }
230
231    fn float_mask_where(
232        tensor: FloatTensor<Self>,
233        mask: NdArrayTensor,
234        value: FloatTensor<Self>,
235    ) -> FloatTensor<Self> {
236        execute_with_float_dtype!((tensor, value), |tensor, value| {
237            NdArrayMathOps::mask_where(tensor, mask.bool(), value)
238        })
239    }
240
241    fn float_mask_fill(
242        tensor: FloatTensor<Self>,
243        mask: NdArrayTensor,
244        value: E,
245    ) -> FloatTensor<Self> {
246        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mask_fill(
247            tensor,
248            mask.bool(),
249            value.elem()
250        ))
251    }
252
253    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
254        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
255    }
256
257    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
258        execute_with_float_dtype!(lhs, |tensor| {
259            NdArrayMathOps::equal_elem(tensor, rhs.elem())
260        })
261    }
262
263    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
264        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
265    }
266
267    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
268        execute_with_float_dtype!(lhs, |tensor| {
269            NdArrayMathOps::greater_elem(tensor, rhs.elem())
270        })
271    }
272
273    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
274        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
275            NdArrayMathOps::greater_equal(lhs, rhs)
276        })
277    }
278
279    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
280        execute_with_float_dtype!(lhs, |tensor| {
281            NdArrayMathOps::greater_equal_elem(tensor, rhs.elem())
282        })
283    }
284
285    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
286        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
287    }
288
289    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
290        execute_with_float_dtype!(lhs, |tensor| {
291            NdArrayMathOps::lower_elem(tensor, rhs.elem())
292        })
293    }
294
295    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {
296        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
297            NdArrayMathOps::lower_equal(lhs, rhs)
298        })
299    }
300
301    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: E) -> NdArrayTensor {
302        execute_with_float_dtype!(lhs, |tensor| {
303            NdArrayMathOps::lower_equal_elem(tensor, rhs.elem())
304        })
305    }
306
307    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
308        tensor
309    }
310
311    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
312        execute_with_float_dtype!(tensor, NdArrayMathOps::mean)
313    }
314
315    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
316        execute_with_float_dtype!(tensor, NdArrayMathOps::sum)
317    }
318
319    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
320        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::mean_dim(tensor, dim))
321    }
322
323    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
324        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
325    }
326
327    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
328        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumprod(tensor, dim))
329    }
330
331    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
332        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
333    }
334
335    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
336        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummax(tensor, dim))
337    }
338
339    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
340        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
341    }
342
343    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
344        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmax::<I>(tensor, dim))
345    }
346
347    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {
348        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::argmin::<I>(tensor, dim))
349    }
350
351    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
352        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
353            tensor.mapv_into(|a| a.exp_elem()).into_shared()
354        })
355    }
356
357    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
358        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
359            tensor.mapv_into(|a| a.log_elem()).into_shared()
360        })
361    }
362
363    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
364        execute_with_float_dtype!(tensor, NdArrayMathOps::prod)
365    }
366
367    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
368        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::prod_dim(tensor, dim))
369    }
370
371    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
372        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
373            tensor.mapv_into(|a| a.log1p_elem()).into_shared()
374        })
375    }
376
377    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: f32) -> FloatTensor<Self> {
378        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
379            tensor.mapv_into(|a| a.powf_elem(value)).into_shared()
380        })
381    }
382
383    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
384        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
385            tensor.mapv_into(|a| a.sqrt_elem()).into_shared()
386        })
387    }
388
389    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
390        execute_with_float_dtype!(tensor, E, NdArrayMathOps::abs)
391    }
392
393    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
394        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
395            tensor
396                .mapv_into(|a| (a.to_f64()).cos().elem())
397                .into_shared()
398        })
399    }
400
401    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
402        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
403            tensor
404                .mapv_into(|a| (a.to_f64()).sin().elem())
405                .into_shared()
406        })
407    }
408
409    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
410        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
411            tensor
412                .mapv_into(|a| (a.to_f64()).tanh().elem())
413                .into_shared()
414        })
415    }
416
417    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
418        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
419            tensor
420                .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
421                .into_shared()
422        })
423    }
424
425    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
426        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
427            tensor
428                .mapv_into(|a| (a.to_f64()).floor().elem())
429                .into_shared()
430        })
431    }
432
433    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
434        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
435            tensor
436                .mapv_into(|a| (a.to_f64()).ceil().elem())
437                .into_shared()
438        })
439    }
440
441    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
442        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
443            tensor
444                .mapv_into(|a| (a.to_f64()).trunc().elem())
445                .into_shared()
446        })
447    }
448
449    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
450        execute_with_float_dtype!(tensor, E, |tensor: SharedArray<E>| {
451            tensor.mapv_into(|a| erf(a.to_f64()).elem()).into_shared()
452        })
453    }
454
455    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
456        cat_with_dtype!(tensors, dim, [F64, F32])
457    }
458
459    fn float_clamp_min(tensor: FloatTensor<Self>, min: E) -> FloatTensor<Self> {
460        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_min(
461            tensor,
462            min.elem()
463        ))
464    }
465
466    fn float_clamp_max(tensor: FloatTensor<Self>, max: E) -> FloatTensor<Self> {
467        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp_max(
468            tensor,
469            max.elem()
470        ))
471    }
472
473    fn float_clamp(tensor: FloatTensor<Self>, min: E, max: E) -> FloatTensor<Self> {
474        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::clamp(
475            tensor,
476            min.elem(),
477            max.elem()
478        ))
479    }
480
481    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {
482        execute_with_float_dtype!(tensor, |tensor: SharedArray<E>| {
483            tensor.mapv(|a| a.elem::<I>()).into_shared()
484        })
485    }
486
487    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
488        execute_with_float_dtype!((lhs, rhs), E, |lhs, rhs| NdArrayMathOps::elementwise_op(
489            lhs,
490            rhs,
491            |a: &E, b: &E| a.powf(*b)
492        ))
493    }
494
495    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
496        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::permute(tensor, axes))
497    }
498
499    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
500        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::flip(tensor, axes))
501    }
502
503    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
504        execute_with_float_dtype!(tensor, NdArrayMathOps::sign_op)
505    }
506
507    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
508        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::expand(tensor, shape))
509    }
510
511    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
512        execute_with_float_dtype!(tensor, |tensor| cast_to_dtype(tensor, dtype.into()))
513    }
514
515    fn float_grid_sample_2d(
516        tensor: FloatTensor<Self>,
517        grid: FloatTensor<Self>,
518        options: GridSampleOptions,
519    ) -> FloatTensor<Self> {
520        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
521            tensor, grid, options
522        ))
523    }
524
525    fn float_unfold(
526        tensor: FloatTensor<Self>,
527        dim: usize,
528        size: usize,
529        step: usize,
530    ) -> FloatTensor<Self> {
531        execute_with_float_dtype!(tensor, |tensor| NdArrayOps::unfold(tensor, dim, size, step))
532    }
533}