Skip to main content

burn_tensor/tensor/api/
numeric.rs

1use burn_backend::Scalar;
2pub use burn_backend::tensor::Numeric;
3
4use crate::alloc::borrow::ToOwned;
5use alloc::vec::Vec;
6
7use crate::IndexingUpdateOp;
8use crate::{
9    AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend,
10    check, check::TensorCheck,
11};
12
13impl<B, const D: usize, K> Tensor<B, D, K>
14where
15    B: Backend,
16    K: Numeric<B>,
17    K::Elem: Element,
18{
19    /// Applies element wise addition operation.
20    ///
21    /// `y = x2 + x1`
22    ///
23    /// # Arguments
24    ///
25    /// * `other` - The tensor to add.
26    ///
27    /// # Example
28    ///
29    /// ```rust
30    /// use burn_tensor::backend::Backend;
31    /// use burn_tensor::{Tensor, Shape};
32    ///
33    /// fn example<B: Backend>() {
34    ///    let device = B::Device::default();
35    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
36    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
37    ///    let tensor = tensor1 + tensor2;
38    ///    println!("{tensor}");
39    ///    // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
40    /// }
41    /// ```
42    #[allow(clippy::should_implement_trait)]
43    pub fn add(self, other: Self) -> Self {
44        check!(TensorCheck::binary_ops_ew("Add", &self, &other));
45        Self::new(K::add(self.primitive, other.primitive))
46    }
47
48    /// Applies element wise addition operation with a scalar.
49    ///
50    /// `y = x + s`
51    ///
52    /// # Arguments
53    ///
54    /// * `other` - The scalar to add, element wise.
55    ///
56    /// # Example
57    ///
58    /// ```rust
59    /// use burn_tensor::backend::Backend;
60    /// use burn_tensor::{Tensor, Shape};
61    ///
62    /// fn example<B: Backend>() {
63    ///   let device = B::Device::default();
64    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
65    ///   let scalar = 2.0;
66    ///   let tensor = tensor + scalar;
67    ///   println!("{tensor}");
68    ///   // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
69    /// }
70    /// ```
71    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
72        let other = Scalar::new(other, &self.dtype());
73        Self::new(K::add_scalar(self.primitive, other))
74    }
75
76    /// Applies element wise subtraction operation.
77    ///
78    /// `y = x2 - x1`
79    ///
80    /// # Arguments
81    ///
82    /// * `other` - The tensor to subtract.
83    ///
84    /// # Example
85    ///
86    /// ```rust
87    /// use burn_tensor::backend::Backend;
88    /// use burn_tensor::{Tensor, Shape};
89    ///
90    /// fn example<B: Backend>() {
91    ///   let device = B::Device::default();
92    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
93    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
94    ///   let tensor = tensor1 - tensor2;
95    ///   println!("{tensor}");
96    ///   // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
97    /// }
98    /// ```
99    #[allow(clippy::should_implement_trait)]
100    pub fn sub(self, other: Self) -> Self {
101        check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
102        Self::new(K::sub(self.primitive, other.primitive))
103    }
104
105    /// Applies element wise subtraction operation with a scalar.
106    ///
107    /// `y = x - s`
108    ///
109    /// # Arguments
110    ///
111    /// * `other` - The scalar to subtract, element wise.
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use burn_tensor::backend::Backend;
117    /// use burn_tensor::{Tensor, Shape};
118    ///
119    /// fn example<B: Backend>() {
120    ///    let device = B::Device::default();
121    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
122    ///    let scalar = 2.0;
123    ///    let tensor = tensor - scalar;
124    ///    println!("{tensor}");
125    ///    // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
126    /// }
127    /// ```
128    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
129        let other = Scalar::new(other, &self.dtype());
130        Self::new(K::sub_scalar(self.primitive, other))
131    }
132
133    /// Applies element wise division operation.
134    ///
135    /// `y = x2 / x1`
136    ///
137    /// # Arguments
138    ///
139    /// * `other` - The tensor to divide.
140    ///
141    /// # Example
142    ///
143    /// ```rust
144    /// use burn_tensor::backend::Backend;
145    /// use burn_tensor::{Tensor, Shape};
146    ///
147    /// fn example<B: Backend>() {
148    ///    let device = B::Device::default();
149    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
150    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
151    ///    let tensor = tensor1 / tensor2;
152    ///    println!("{tensor}");
153    ///    // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
154    /// }
155    /// ```
156    #[allow(clippy::should_implement_trait)]
157    pub fn div(self, other: Self) -> Self {
158        check!(TensorCheck::binary_ops_ew("Div", &self, &other));
159        Self::new(K::div(self.primitive, other.primitive))
160    }
161
162    /// Applies element wise division operation with a scalar.
163    ///
164    /// `y = x / s`
165    ///
166    /// # Arguments
167    ///
168    /// * `other` - The scalar to divide, element wise.
169    ///
170    /// # Example
171    ///
172    /// ```rust
173    /// use burn_tensor::backend::Backend;
174    /// use burn_tensor::{Tensor, Shape};
175    ///
176    /// fn example<B: Backend>() {
177    ///    let device = B::Device::default();
178    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
179    ///    let scalar = 2.0;
180    ///    let tensor = tensor / scalar;
181    ///    println!("{tensor}");
182    ///    // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
183    /// }
184    /// ```
185    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
186        let other = Scalar::new(other, &self.dtype());
187        Self::new(K::div_scalar(self.primitive, other))
188    }
189
190    /// Applies element wise the remainder operation with a scalar.
191    ///
192    /// `y = x2 % x1`
193    pub fn remainder(self, other: Self) -> Self {
194        Self::new(K::remainder(self.primitive, other.primitive))
195    }
196
197    /// Applies element wise the remainder operation with a scalar.
198    ///
199    /// `y = x % s`
200    ///
201    /// # Arguments
202    ///
203    /// * `other` - The scalar to divide, element wise.
204    ///
205    /// # Example
206    ///
207    /// ```rust
208    /// use burn_tensor::backend::Backend;
209    /// use burn_tensor::{Tensor, Shape};
210    ///
211    /// fn example<B: Backend>() {
212    ///    let device = B::Device::default();
213    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
214    ///    let scalar = 2.0;
215    ///    let tensor = tensor1 % scalar;
216    ///    println!("{tensor}");
217    ///    // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
218    /// }
219    /// ```
220    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
221        let other = Scalar::new(other, &self.dtype());
222        Self::new(K::remainder_scalar(self.primitive, other))
223    }
224
225    /// Applies element wise multiplication operation.
226    ///
227    /// `y = x2 * x1`
228    ///
229    /// # Arguments
230    ///
231    /// * `other` - The tensor to multiply.
232    ///
233    /// # Example
234    ///
235    /// ```rust
236    /// use burn_tensor::backend::Backend;
237    /// use burn_tensor::{Tensor, Shape};
238    ///
239    /// fn example<B: Backend>() {
240    ///    let device = B::Device::default();
241    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
242    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
243    ///    let tensor = tensor1 * tensor2;
244    ///    println!("{tensor}");
245    ///    // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
246    /// }
247    /// ```
248    #[allow(clippy::should_implement_trait)]
249    pub fn mul(self, other: Self) -> Self {
250        check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
251        Self::new(K::mul(self.primitive, other.primitive))
252    }
253
254    /// Applies element wise multiplication operation with a scalar.
255    ///
256    /// `y = x * s`
257    ///
258    /// # Arguments
259    ///
260    /// * `other` - The scalar to multiply, element wise.
261    ///
262    /// # Example
263    ///
264    /// ```rust
265    /// use burn_tensor::backend::Backend;
266    /// use burn_tensor::{Tensor, Shape};
267    ///
268    /// fn example<B: Backend>() {
269    ///    let device = B::Device::default();
270    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
271    ///    let scalar = 2.0;
272    ///    let tensor = tensor * scalar;
273    ///    println!("{tensor}");
274    ///    // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
275    /// }
276    /// ```
277    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
278        let other = Scalar::new(other, &self.dtype());
279        Self::new(K::mul_scalar(self.primitive, other))
280    }
281
282    /// Switch sign of each element in the tensor.
283    ///
284    /// `y = -x`
285    ///
286    /// # Example
287    ///
288    /// ```rust
289    /// use burn_tensor::backend::Backend;
290    /// use burn_tensor::{Tensor, Shape};
291    ///
292    /// fn example<B: Backend>() {
293    ///    let device = B::Device::default();
294    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
295    ///    let tensor = -tensor;
296    ///    println!("{tensor}");
297    ///    // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
298    /// }
299    /// ```
300    #[allow(clippy::should_implement_trait)]
301    pub fn neg(self) -> Self {
302        Self::new(K::neg(self.primitive))
303    }
304
305    /// Returns the signs of the elements of the input tensor.
306    ///
307    /// # Example
308    ///
309    /// ```rust
310    /// use burn_tensor::backend::Backend;
311    /// use burn_tensor::{Tensor, Shape};
312    ///
313    /// fn example<B: Backend>() {
314    ///    let device = B::Device::default();
315    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
316    ///    let tensor = tensor.sign();
317    ///    println!("{tensor}");
318    ///    // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
319    /// }
320    /// ```
321    pub fn sign(self) -> Self {
322        Self::new(K::sign(self.primitive))
323    }
324
325    /// Aggregate all elements in the tensor with the mean operation.
326    ///
327    /// # Example
328    ///
329    /// ```rust
330    /// use burn_tensor::backend::Backend;
331    /// use burn_tensor::{Tensor, Shape};
332    ///
333    /// fn example<B: Backend>() {
334    ///    let device = B::Device::default();
335    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
336    ///    let tensor = tensor.mean();
337    ///    println!("{tensor}");
338    ///    // [3.6666667]
339    /// }
340    /// ```
341    pub fn mean(self) -> Tensor<B, 1, K> {
342        Tensor::new(K::mean(self.primitive))
343    }
344
345    /// Aggregate all elements in the tensor with the sum operation.
346    ///
347    /// # Example
348    ///
349    /// ```rust
350    /// use burn_tensor::backend::Backend;
351    /// use burn_tensor::{Tensor, Shape};
352    ///
353    /// fn example<B: Backend>() {
354    ///   let device = B::Device::default();
355    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
356    ///   let tensor = tensor.sum();
357    ///   println!("{tensor}");
358    ///   // [22.0]
359    /// }
360    /// ```
361    pub fn sum(self) -> Tensor<B, 1, K> {
362        Tensor::new(K::sum(self.primitive))
363    }
364
365    /// Aggregate all elements along the given *dimension* or *axis*
366    /// in the tensor with the mean operation.
367    ///
368    /// # Arguments
369    ///
370    /// * `dim` - The dimension or axis along which to aggregate the elements;
371    ///   supports negative indexing.
372    ///
373    /// # Example
374    ///
375    /// ```rust
376    /// use burn_tensor::backend::Backend;
377    /// use burn_tensor::{Tensor, Shape};
378    ///
379    /// fn example<B: Backend>() {
380    ///   let device = B::Device::default();
381    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
382    ///   let tensor = tensor.clone().mean_dim(0);
383    ///   println!("{tensor}");
384    ///   // [[3.0, 3.5, 4.5]]
385    ///   let tensor = tensor.clone().mean_dim(1);
386    ///   println!("{tensor}");
387    ///   // [[0.6666667], [6.6666665]]
388    /// }
389    /// ```
390    pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
391        let dim = dim.expect_dim_index(D);
392        check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
393        Self::new(K::mean_dim(self.primitive, dim))
394    }
395
396    /// Aggregate all elements along the given *axes*
397    /// in the tensor with the mean operation.
398    ///
399    /// # Arguments
400    ///
401    /// * `dims` - the dimensions to aggregate; supports negative indexing.
402    ///
403    /// # Returns
404    ///
405    /// The returned tensor will have the same rank,
406    /// but the aggregated dimensions will have size 1.
407    ///
408    /// # Example
409    ///
410    /// ```rust
411    /// use burn_tensor::backend::Backend;
412    /// use burn_tensor::{Tensor, Shape};
413    ///
414    /// fn example<B: Backend>() {
415    ///    let device = B::Device::default();
416    ///    let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
417    ///    let tensor = tensor.clone().mean_dims(&[0, 1]);
418    ///    println!("{tensor}");
419    ///    // [[2.0]]
420    /// }
421    /// ```
422    pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
423        dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
424    }
425
426    /// Aggregate all elements along the given *dimension* or *axis*
427    /// in the tensor with the sum operation.
428    ///
429    /// # Arguments
430    ///
431    /// * `dim` - The dimension or axis along which to aggregate the elements;
432    ///   supports negative indexing.
433    ///
434    /// # Example
435    ///
436    /// ```rust
437    /// use burn_tensor::backend::Backend;
438    /// use burn_tensor::{Tensor, Shape};
439    ///
440    /// fn example<B: Backend>() {
441    ///    let device = B::Device::default();
442    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
443    ///    let tensor = tensor.clone().sum_dim(0);
444    ///    println!("{tensor}");
445    ///    // [[6.0, 7.0, 9.0]]
446    ///    let tensor = tensor.clone().sum_dim(1);
447    ///    println!("{tensor}");
448    ///    // [[2.0], [20.0]]
449    /// }
450    /// ```
451    pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
452        let dim = dim.expect_dim_index(D);
453        check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
454        Self::new(K::sum_dim(self.primitive, dim))
455    }
456
457    /// Aggregate all elements along the given *axes*
458    /// in the tensor with the sum operation.
459    ///
460    /// # Arguments
461    ///
462    /// * `dims` - the dimensions to aggregate; supports negative indexing.
463    ///
464    /// # Returns
465    ///
466    /// The returned tensor will have the same rank,
467    /// but the aggregated dimensions will have size 1.
468    ///
469    /// # Example
470    ///
471    /// ```rust
472    /// use burn_tensor::backend::Backend;
473    /// use burn_tensor::{Tensor, Shape};
474    ///
475    /// fn example<B: Backend>() {
476    ///    let device = B::Device::default();
477    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
478    ///    let tensor = tensor.clone().sum_dims(&[0, 1]);
479    ///    println!("{tensor}");
480    ///    // [[27]]
481    /// }
482    /// ```
483    pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
484        dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
485    }
486
487    /// Aggregate and squeeze along the given dimensions.
488    ///
489    /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)``
490    ///
491    /// # Arguments
492    ///
493    /// * `dims` - the dimensions to aggregate; supports negative indexing.
494    ///
495    /// # Returns
496    ///
497    /// The returned tensor will have the same rank,
498    /// but the aggregated dimensions will have size 1.
499    ///
500    /// # Example
501    ///
502    /// ```rust
503    /// use burn_tensor::backend::Backend;
504    /// use burn_tensor::{Tensor, Shape};
505    ///
506    /// fn example<B: Backend>() {
507    ///     let device = B::Device::default();
508    ///     let tensor = Tensor::<B, 3>::from_data([
509    ///         [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
510    ///         [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
511    ///     ], &device);
512    ///     let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);
513    ///     println!("{tensor}");
514    ///     // [20.0, 16.0, 21.0]
515    /// }
516    /// ```
517    pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
518        // TODO: remove idims when squeeze_dims uses AsIndex.
519        let idims = dims
520            .iter()
521            .map(|&dim| (dim.expect_dim_index(D)) as isize)
522            .collect::<Vec<_>>();
523        self.sum_dims(dims).squeeze_dims::<D2>(&idims)
524    }
525
526    /// Aggregate all elements in the tensor with the product operation.
527    ///
528    /// # Example
529    ///
530    /// ```rust
531    /// use burn_tensor::backend::Backend;
532    /// use burn_tensor::{Tensor, Shape};
533    ///
534    /// fn example<B: Backend>() {
535    ///    let device = B::Device::default();
536    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
537    ///    let tensor = tensor.prod();
538    ///    println!("{tensor}");
539    ///    // [-1620.0]
540    /// }
541    /// ```
542    pub fn prod(self) -> Tensor<B, 1, K> {
543        Tensor::new(K::prod(self.primitive))
544    }
545
546    /// Aggregate all elements along the given *dimension* or *axis*
547    /// in the tensor with the product operation.
548    ///
549    /// # Arguments
550    ///
551    /// * `dim` - The dimension or axis along which to aggregate the elements,
552    ///   supports negative indexing.
553    ///
554    /// # Returns
555    ///
556    /// The returned tensor will have the same rank,
557    /// but the aggregated dimension will have size 1.
558    ///
559    /// # Example
560    ///
561    /// ```rust
562    /// use burn_tensor::backend::Backend;
563    /// use burn_tensor::{Tensor, Shape};
564    ///
565    /// fn example<B: Backend>() {
566    ///    let device = B::Device::default();
567    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
568    ///    let tensor = tensor.clone().prod_dim(0);
569    ///    println!("{tensor}");
570    ///    // [[5.0, -18.0, 18.0]]
571    ///    let tensor = tensor.clone().prod_dim(1);
572    ///    println!("{tensor}");
573    ///    // [[-6.0], [270.0]]
574    /// }
575    /// ```
576    pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
577        let dim = dim.expect_dim_index(D);
578        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
579        Self::new(K::prod_dim(self.primitive, dim))
580    }
581
582    /// Aggregate all elements along the given *axes*
583    /// in the tensor with the prod operation.
584    ///
585    /// # Arguments
586    ///
587    /// * `dims` - the dimensions to aggregate, supports negative indexing.
588    ///
589    /// # Returns
590    ///
591    /// The returned tensor will have the same rank,
592    /// but the aggregated dimensions will have size 1.
593    ///
594    /// # Example
595    ///
596    /// ```rust
597    /// use burn_tensor::backend::Backend;
598    /// use burn_tensor::{Tensor, Shape};
599    ///
600    /// fn example<B: Backend>() {
601    ///    let device = B::Device::default();
602    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
603    ///    let tensor = tensor.clone().sum_dims(&[0, 1]);
604    ///    println!("{tensor}");
605    ///    // [[-1620.0]]
606    /// }
607    /// ```
608    pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
609        dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
610    }
611
612    /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
613    ///
614    /// # Arguments
615    ///
616    /// * `dim` - The dimension or axis along which to compute the cumulative sum.
617    ///
618    /// # Example
619    ///
620    /// ```rust
621    /// use burn_tensor::backend::Backend;
622    /// use burn_tensor::{Tensor, Shape};
623    ///
624    /// fn example<B: Backend>() {
625    ///    let device = B::Device::default();
626    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
627    ///    let result = tensor.clone().cumsum(0);
628    ///    println!("{result}");
629    ///    // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
630    ///    let result = tensor.cumsum(1);
631    ///    println!("{result}");
632    ///    // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
633    /// }
634    /// ```
635    pub fn cumsum(self, dim: usize) -> Self {
636        check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
637        Self::new(K::cumsum(self.primitive, dim))
638    }
639
640    /// Computes the cumulative product of elements along the given *dimension* or *axis*.
641    ///
642    /// # Arguments
643    ///
644    /// * `dim` - The dimension or axis along which to compute the cumulative product.
645    ///
646    /// # Example
647    ///
648    /// ```rust
649    /// use burn_tensor::backend::Backend;
650    /// use burn_tensor::{Tensor, Shape};
651    ///
652    /// fn example<B: Backend>() {
653    ///    let device = B::Device::default();
654    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
655    ///    let result = tensor.clone().cumprod(0);
656    ///    println!("{result}");
657    ///    // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
658    ///    let result = tensor.cumprod(1);
659    ///    println!("{result}");
660    ///    // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
661    /// }
662    /// ```
663    pub fn cumprod(self, dim: usize) -> Self {
664        check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
665        Self::new(K::cumprod(self.primitive, dim))
666    }
667
668    /// Apply element wise absolute value operation.
669    ///
670    /// # Example
671    ///
672    /// ```rust
673    /// use burn_tensor::backend::Backend;
674    /// use burn_tensor::{Int, Tensor};
675    ///
676    /// fn example<B: Backend>() {
677    ///   let device = Default::default();
678    ///   let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
679    ///   let tensor = tensor.abs();
680    ///   println!("{tensor}");
681    ///   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
682    /// }
683    /// ```
684    pub fn abs(self) -> Self {
685        Self::new(K::abs(self.primitive))
686    }
687
    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
    /// the other elements of the result tensor out are set to 0.
    ///
    /// See also [`triu_mask`](Tensor::triu_mask).
    ///
    /// # Arguments
    ///
    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
    ///   towards the upper triangle.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
    ///        [
    ///          [1, 2, 3],
    ///          [4, 5, 6],
    ///          [7, 8, 9]
    ///        ],
    ///        &device
    ///    );
    ///    let tensor = tensor.triu(1);
    ///    println!("{tensor}");
    ///    // [
    ///    //   [0, 2, 3],
    ///    //   [0, 0, 6],
    ///    //   [0, 0, 0]
    ///    // ]
    /// }
    /// ```
    pub fn triu(self, diagonal: i64) -> Self {
        // Requires rank >= 2 (a matrix or a batch of matrices).
        check!(TensorCheck::tri::<{ D }>());

        // last two dimensions
        let shape = &self.shape().dims[D - 2..].to_owned();

        // Build a 2-D boolean mask over the matrix dimensions, then unsqueeze
        // it so it broadcasts over any leading batch dimensions.
        // NOTE(review): relies on `triu_mask` marking the elements to zero out
        // (not the ones to keep) — confirm against the mask's definition.
        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
        self.mask_fill(mask, 0)
    }
731
    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
    /// the other elements of the result tensor out are set to 0.
    ///
    /// See also [`tril_mask`](Tensor::tril_mask).
    ///
    /// # Arguments
    ///
    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
    ///   towards the upper triangle (keeping more elements above the main diagonal).
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
    ///        [
    ///          [1, 2, 3],
    ///          [4, 5, 6],
    ///          [7, 8, 9]
    ///        ],
    ///        &device
    ///    );
    ///
    ///    let tensor = tensor.tril(-1);
    ///    println!("{tensor}");
    ///    // [
    ///    //   [0, 0, 0],
    ///    //   [4, 0, 0],
    ///    //   [7, 8, 0]
    ///    // ]
    /// }
    /// ```
    pub fn tril(self, diagonal: i64) -> Self {
        // Requires rank >= 2 (a matrix or a batch of matrices).
        check!(TensorCheck::tri::<{ D }>());

        // last two dimensions
        let shape = &self.shape().dims[D - 2..].to_owned();
        // Build a 2-D boolean mask over the matrix dimensions, then unsqueeze
        // it so it broadcasts over any leading batch dimensions.
        // NOTE(review): relies on `tril_mask` marking the elements to zero out
        // (not the ones to keep) — confirm against the mask's definition.
        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();

        self.mask_fill(mask, 0)
    }
776
777    /// Applies element wise power operation with a float Tensor
778    ///
779    /// # Arguments
780    ///
781    /// * `other` - The tensor to apply the power operation with.
782    ///
783    /// # Example
784    ///
785    /// ```rust
786    /// use burn_tensor::backend::Backend;
787    /// use burn_tensor::{Tensor, Shape};
788    ///
789    /// fn example<B: Backend>() {
790    ///    let device = B::Device::default();
791    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
792    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
793    ///    let tensor = tensor1.powf(tensor2);
794    ///    println!("{tensor}");
795    ///    // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]]
796    /// }
797    /// ```
798    pub fn powf(self, other: Self) -> Self {
799        Self::new(K::powf(self.primitive, other.primitive))
800    }
801
802    /// Applies element wise power operation with a float scalar
803    ///
804    /// # Arguments
805    ///
806    /// * `other` - The scalar to apply the power operation with.
807    ///
808    /// # Example
809    ///
810    /// ```rust
811    /// use burn_tensor::backend::Backend;
812    /// use burn_tensor::{Tensor, Shape};
813    ///
814    /// fn example<B: Backend>() {
815    ///    let device = B::Device::default();
816    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
817    ///    let tensor = tensor.powf_scalar(2.0);
818    ///    println!("{tensor}");
819    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
820    /// }
821    /// ```
822    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
823        let other = Scalar::new(other, &self.dtype());
824        Self::new(K::powf_scalar(self.primitive, other))
825    }
826
827    /// Applies element wise power operation with a integer Tensor
828    ///
829    /// # Arguments
830    ///
831    /// * `other` - The tensor to apply the power operation with.
832    ///
833    /// # Example
834    ///
835    /// ```rust
836    /// use burn_tensor::backend::Backend;
837    /// use burn_tensor::{Tensor, Shape, Int};
838    ///
839    /// fn example<B: Backend>() {
840    ///    let device = B::Device::default();
841    ///    let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
842    ///    let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
843    ///    let tensor = tensor1.powi(tensor2);
844    ///    println!("{tensor}");
845    ///    // [[1, -8, 81], [5, 81, 216]]
846    /// }
847    /// ```
848    pub fn powi(self, other: Self) -> Self {
849        Self::new(K::powi(self.primitive, other.primitive))
850    }
851
852    /// Applies element wise power operation with a integer scalar
853    ///
854    /// # Arguments
855    ///
856    /// * `other` - The scalar to apply the power operation with.
857    ///
858    /// # Example
859    ///
860    /// ```rust
861    /// use burn_tensor::backend::Backend;
862    /// use burn_tensor::{Tensor, Shape, Int};
863    ///
864    /// fn example<B: Backend>() {
865    ///    let device = B::Device::default();
866    ///    let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
867    ///    let tensor = tensor.powi_scalar(2);
868    ///    println!("{tensor}");
869    ///
870    ///    // [[1, 4, 9], [25, 81, 36]]
871    ///    let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
872    ///    let tensor = tensor.powi_scalar(2);
873    ///    println!("{tensor}");
874    ///    // [[2.25, 4., 9.], [25., 81., 36.]]
875    /// }
876    /// ```
877    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
878        let other = Scalar::new(other, &self.dtype());
879        Self::new(K::powi_scalar(self.primitive, other))
880    }
881
    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.bool();
    ///   println!("{tensor}");
    ///   // [
    ///   //   [true, true, true],
    ///   //   [false, true, true]
    ///   // ]
    /// }
    /// ```
    pub fn bool(self) -> Tensor<B, D, Bool> {
        // Element is `true` exactly when it differs from zero.
        self.not_equal_elem(0)
    }
907
908    /// Create a random tensor of the given shape on the given device where each element is
909    /// sampled from the given distribution.
910    ///
911    /// See also [`random_like`](Tensor::random_like).
912    ///
913    /// # Arguments
914    ///
915    /// * `shape` - The shape of the tensor.
916    /// * `distribution` - The distribution to sample from.
917    /// * `device` - The device to create the tensor on.
918    ///
919    /// # Returns
920    ///
921    /// A new tensor with the given shape and elements sampled from the given distribution.
922    ///
923    /// # Example
924    ///
925    /// ```rust
926    /// use burn_tensor::backend::Backend;
927    /// use burn_tensor::{Tensor, Shape, Distribution};
928    ///
929    /// fn example<B: Backend>() {
930    ///   let device = B::Device::default();
931    ///   let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
932    ///   let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
933    ///   println!("{tensor}");
934    ///   // [
935    ///   //   [0.08347523, 0.70498955, 0.60332155],
936    ///   //   [0.08173251, 0.18028641, 0.97942924]
937    ///   // ]
938    /// }
939    /// ```
940    pub fn random<S: Into<Shape>>(
941        shape: S,
942        distribution: Distribution,
943        device: &B::Device,
944    ) -> Self {
945        Self::new(K::random(shape.into(), distribution, device))
946    }
947
948    /// Applies the matrix multiplication operation.
949    ///
950    /// ```math
951    /// C = AB
952    /// ```
953    ///
954    /// Shapes of the form `[..., B, 1, K] @ [..., 1, K, N]` are reinterpreted as
955    /// `[..., 1, B, K] @ [..., 1, K, N]`, turning a batched vec-mat into a general
956    /// matmul, which is often faster.
957    pub fn matmul(self, other: Self) -> Self {
958        check!(TensorCheck::matmul(&self, &other));
959
960        if D >= 3 {
961            let batch_index = D - 3;
962            let vector_index = D - 2;
963            let lhs_dims = &self.shape()[batch_index..D];
964            let rhs_dims = &other.shape()[batch_index..D];
965
966            if let ([_, 1, k1], [1, k2, _]) = (lhs_dims, rhs_dims)
967                && k1 == k2
968            {
969                return Tensor::new(K::matmul(
970                    self.swap_dims(batch_index, vector_index).primitive,
971                    other.primitive,
972                ))
973                .swap_dims(batch_index, vector_index);
974            }
975        }
976
977        Tensor::new(K::matmul(self.primitive, other.primitive))
978    }
979}
980
981impl<B, K> Tensor<B, 1, K>
982where
983    B: Backend,
984    K: Numeric<B>,
985    K::Elem: Element,
986{
987    /// Calculates the dot product with another tensor.
988    ///
989    /// `y = x2.dot(x1)`
990    ///
991    /// # Arguments
992    ///
993    /// * `other` - The tensor to compute dot product with.
994    ///
995    /// # Notes
996    ///
997    /// Both tensors must have the same number of elements.
998    ///
999    /// # Example
1000    ///
1001    /// ```rust
1002    /// use burn_tensor::backend::Backend;
1003    /// use burn_tensor::{Tensor, Shape};
1004    ///
1005    /// fn example<B: Backend>() {
1006    ///    let device = B::Device::default();
1007    ///    let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
1008    ///    let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
1009    ///    let tensor = tensor1.dot(tensor2);
1010    ///    println!("{tensor}");
1011    ///    // [4]
1012    /// }
1013    /// ```
1014    pub fn dot(self, other: Self) -> Self {
1015        self.mul(other).sum()
1016    }
1017}
1018
1019impl<B, K> Tensor<B, 2, K>
1020where
1021    B: Backend,
1022    K: Numeric<B>,
1023    K::Elem: Element,
1024{
1025    /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
1026    ///
1027    /// # Arguments
1028    ///
1029    /// * `size` - The size of the square matrix.
1030    pub fn eye(size: usize, device: &B::Device) -> Self {
1031        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
1032        let ones = Self::ones([1, size], device);
1033        let zeros = Self::zeros([size, size], device);
1034
1035        zeros.scatter(0, indices, ones, IndexingUpdateOp::Add)
1036    }
1037}
1038
1039// Tensor + tensor
1040impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
1041where
1042    K::Elem: Element,
1043{
1044    type Output = Self;
1045
1046    fn add(self, rhs: Self) -> Self::Output {
1047        Self::add(self, rhs)
1048    }
1049}
1050
1051// Tensor + scalar
1052impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
1053    for Tensor<B, D, K>
1054where
1055    K::Elem: Element,
1056{
1057    type Output = Self;
1058
1059    fn add(self, other: E) -> Self::Output {
1060        Tensor::add_scalar(self, other)
1061    }
1062}
1063
1064// Scalar + tensor
1065macro_rules! impl_tensor_scalar_add {
1066    ($($t:ty),*) => {
1067        $(
1068            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
1069            where
1070                K::Elem: Element,
1071            {
1072                type Output = Tensor<B, D, K>;
1073
1074                fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
1075                    Tensor::add_scalar(tensor, self)
1076                }
1077            }
1078        )*
1079    }
1080}
1081impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);
1082
1083// Tensor - tensor
1084impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
1085where
1086    K::Elem: Element,
1087{
1088    type Output = Self;
1089
1090    fn sub(self, rhs: Self) -> Self::Output {
1091        Tensor::sub(self, rhs)
1092    }
1093}
1094
1095// Tensor - scalar
1096impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
1097    for Tensor<B, D, K>
1098where
1099    K::Elem: Element,
1100{
1101    type Output = Self;
1102
1103    fn sub(self, other: E) -> Self::Output {
1104        Tensor::sub_scalar(self, other)
1105    }
1106}
1107
1108// Scalar - tensor
1109macro_rules! impl_tensor_scalar_sub {
1110    ($($t:ty),*) => {
1111        $(
1112            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
1113            where
1114                K::Elem: Element,
1115            {
1116                type Output = Tensor<B, D, K>;
1117
1118                fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
1119                    Tensor::add_scalar(Tensor::neg(tensor), self)
1120                }
1121            }
1122        )*
1123    }
1124}
1125impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);
1126
1127// Tensor / tensor
1128impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
1129where
1130    K::Elem: Element,
1131{
1132    type Output = Self;
1133
1134    fn div(self, rhs: Self) -> Self::Output {
1135        Tensor::div(self, rhs)
1136    }
1137}
1138
1139// Tensor / scalar
1140impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
1141    for Tensor<B, D, K>
1142where
1143    K::Elem: Element,
1144{
1145    type Output = Self;
1146
1147    fn div(self, other: E) -> Self::Output {
1148        Tensor::div_scalar(self, other)
1149    }
1150}
1151
1152// Scalar / tensor (float only)
1153macro_rules! impl_tensor_scalar_div {
1154    ($($t:ty),*) => {
1155        $(
1156            impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t
1157            {
1158                type Output = Tensor<B, D>;
1159
1160                fn div(self, tensor: Tensor<B, D>) -> Self::Output {
1161                    tensor.recip().mul_scalar(self)
1162                }
1163            }
1164        )*
1165    }
1166}
1167
1168impl_tensor_scalar_div!(f32, f64);
1169
1170// Tensor % tensor.
1171impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
1172where
1173    K::Elem: Element,
1174{
1175    type Output = Self;
1176
1177    fn rem(self, rhs: Self) -> Self::Output {
1178        Tensor::remainder(self, rhs)
1179    }
1180}
1181
1182// Tensor % scalar.
1183impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
1184    for Tensor<B, D, K>
1185where
1186    K::Elem: Element,
1187{
1188    type Output = Self;
1189
1190    fn rem(self, other: E) -> Self::Output {
1191        Tensor::remainder_scalar(self, other)
1192    }
1193}
1194
1195// Tensor * tensor.
1196impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
1197where
1198    K::Elem: Element,
1199{
1200    type Output = Self;
1201
1202    fn mul(self, rhs: Self) -> Self::Output {
1203        Tensor::mul(self, rhs)
1204    }
1205}
1206
1207// Tensor * scalar.
1208impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
1209    for Tensor<B, D, K>
1210where
1211    K::Elem: Element,
1212{
1213    type Output = Self;
1214
1215    fn mul(self, other: E) -> Self::Output {
1216        Tensor::mul_scalar(self, other)
1217    }
1218}
1219
1220macro_rules! impl_tensor_scalar_mul {
1221    ($($t:ty),*) => {
1222        $(
1223            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
1224            where
1225                K::Elem: Element,
1226            {
1227                type Output = Tensor<B, D, K>;
1228
1229                fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
1230                    Tensor::mul_scalar(other, self)
1231                }
1232            }
1233        )*
1234    }
1235}
1236
1237impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);
1238
1239impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
1240where
1241    B: Backend,
1242    K: Numeric<B>,
1243    K::Elem: Element,
1244{
1245    type Output = Self;
1246
1247    fn neg(self) -> Self::Output {
1248        Tensor::neg(self)
1249    }
1250}