Skip to main content

burn_tensor/tensor/api/
numeric.rs

1use burn_backend::Scalar;
2pub use burn_backend::tensor::Numeric;
3
4use crate::alloc::borrow::ToOwned;
5use alloc::vec::Vec;
6
7use crate::IndexingUpdateOp;
8use crate::{
9    AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend,
10    check, check::TensorCheck,
11};
12
13impl<B, const D: usize, K> Tensor<B, D, K>
14where
15    B: Backend,
16    K: Numeric<B>,
17    K::Elem: Element,
18{
19    /// Applies element wise addition operation.
20    ///
21    /// `y = x2 + x1`
22    ///
23    /// # Arguments
24    ///
25    /// * `other` - The tensor to add.
26    ///
27    /// # Example
28    ///
29    /// ```rust
30    /// use burn_tensor::backend::Backend;
31    /// use burn_tensor::{Tensor, Shape};
32    ///
33    /// fn example<B: Backend>() {
34    ///    let device = B::Device::default();
35    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
36    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
37    ///    let tensor = tensor1 + tensor2;
38    ///    println!("{tensor}");
39    ///    // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
40    /// }
41    /// ```
42    #[allow(clippy::should_implement_trait)]
43    pub fn add(self, other: Self) -> Self {
44        check!(TensorCheck::binary_ops_ew("Add", &self, &other));
45        Self::new(K::add(self.primitive, other.primitive))
46    }
47
48    /// Applies element wise addition operation with a scalar.
49    ///
50    /// `y = x + s`
51    ///
52    /// # Arguments
53    ///
54    /// * `other` - The scalar to add, element wise.
55    ///
56    /// # Example
57    ///
58    /// ```rust
59    /// use burn_tensor::backend::Backend;
60    /// use burn_tensor::{Tensor, Shape};
61    ///
62    /// fn example<B: Backend>() {
63    ///   let device = B::Device::default();
64    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
65    ///   let scalar = 2.0;
66    ///   let tensor = tensor + scalar;
67    ///   println!("{tensor}");
68    ///   // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
69    /// }
70    /// ```
71    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
72        let other = Scalar::new(other, &self.dtype());
73        Self::new(K::add_scalar(self.primitive, other))
74    }
75
76    /// Applies element wise subtraction operation.
77    ///
78    /// `y = x2 - x1`
79    ///
80    /// # Arguments
81    ///
82    /// * `other` - The tensor to subtract.
83    ///
84    /// # Example
85    ///
86    /// ```rust
87    /// use burn_tensor::backend::Backend;
88    /// use burn_tensor::{Tensor, Shape};
89    ///
90    /// fn example<B: Backend>() {
91    ///   let device = B::Device::default();
92    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
93    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
94    ///   let tensor = tensor1 - tensor2;
95    ///   println!("{tensor}");
96    ///   // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
97    /// }
98    /// ```
99    #[allow(clippy::should_implement_trait)]
100    pub fn sub(self, other: Self) -> Self {
101        check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
102        Self::new(K::sub(self.primitive, other.primitive))
103    }
104
105    /// Applies element wise subtraction operation with a scalar.
106    ///
107    /// `y = x - s`
108    ///
109    /// # Arguments
110    ///
111    /// * `other` - The scalar to subtract, element wise.
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use burn_tensor::backend::Backend;
117    /// use burn_tensor::{Tensor, Shape};
118    ///
119    /// fn example<B: Backend>() {
120    ///    let device = B::Device::default();
121    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
122    ///    let scalar = 2.0;
123    ///    let tensor = tensor - scalar;
124    ///    println!("{tensor}");
125    ///    // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
126    /// }
127    /// ```
128    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
129        let other = Scalar::new(other, &self.dtype());
130        Self::new(K::sub_scalar(self.primitive, other))
131    }
132
133    /// Applies element wise division operation.
134    ///
135    /// `y = x2 / x1`
136    ///
137    /// # Arguments
138    ///
139    /// * `other` - The tensor to divide.
140    ///
141    /// # Example
142    ///
143    /// ```rust
144    /// use burn_tensor::backend::Backend;
145    /// use burn_tensor::{Tensor, Shape};
146    ///
147    /// fn example<B: Backend>() {
148    ///    let device = B::Device::default();
149    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
150    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
151    ///    let tensor = tensor1 / tensor2;
152    ///    println!("{tensor}");
153    ///    // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
154    /// }
155    /// ```
156    #[allow(clippy::should_implement_trait)]
157    pub fn div(self, other: Self) -> Self {
158        check!(TensorCheck::binary_ops_ew("Div", &self, &other));
159        Self::new(K::div(self.primitive, other.primitive))
160    }
161
162    /// Applies element wise division operation with a scalar.
163    ///
164    /// `y = x / s`
165    ///
166    /// # Arguments
167    ///
168    /// * `other` - The scalar to divide, element wise.
169    ///
170    /// # Example
171    ///
172    /// ```rust
173    /// use burn_tensor::backend::Backend;
174    /// use burn_tensor::{Tensor, Shape};
175    ///
176    /// fn example<B: Backend>() {
177    ///    let device = B::Device::default();
178    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
179    ///    let scalar = 2.0;
180    ///    let tensor = tensor / scalar;
181    ///    println!("{tensor}");
182    ///    // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
183    /// }
184    /// ```
185    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
186        let other = Scalar::new(other, &self.dtype());
187        Self::new(K::div_scalar(self.primitive, other))
188    }
189
190    /// Applies element wise the remainder operation with a scalar.
191    ///
192    /// `y = x2 % x1`
193    pub fn remainder(self, other: Self) -> Self {
194        Self::new(K::remainder(self.primitive, other.primitive))
195    }
196
197    /// Applies element wise the remainder operation with a scalar.
198    ///
199    /// `y = x % s`
200    ///
201    /// # Arguments
202    ///
203    /// * `other` - The scalar to divide, element wise.
204    ///
205    /// # Example
206    ///
207    /// ```rust
208    /// use burn_tensor::backend::Backend;
209    /// use burn_tensor::{Tensor, Shape};
210    ///
211    /// fn example<B: Backend>() {
212    ///    let device = B::Device::default();
213    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
214    ///    let scalar = 2.0;
215    ///    let tensor = tensor1 % scalar;
216    ///    println!("{tensor}");
217    ///    // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
218    /// }
219    /// ```
220    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
221        let other = Scalar::new(other, &self.dtype());
222        Self::new(K::remainder_scalar(self.primitive, other))
223    }
224
225    /// Applies element wise multiplication operation.
226    ///
227    /// `y = x2 * x1`
228    ///
229    /// # Arguments
230    ///
231    /// * `other` - The tensor to multiply.
232    ///
233    /// # Example
234    ///
235    /// ```rust
236    /// use burn_tensor::backend::Backend;
237    /// use burn_tensor::{Tensor, Shape};
238    ///
239    /// fn example<B: Backend>() {
240    ///    let device = B::Device::default();
241    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
242    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
243    ///    let tensor = tensor1 * tensor2;
244    ///    println!("{tensor}");
245    ///    // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
246    /// }
247    /// ```
248    #[allow(clippy::should_implement_trait)]
249    pub fn mul(self, other: Self) -> Self {
250        check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
251        Self::new(K::mul(self.primitive, other.primitive))
252    }
253
254    /// Applies element wise multiplication operation with a scalar.
255    ///
256    /// `y = x * s`
257    ///
258    /// # Arguments
259    ///
260    /// * `other` - The scalar to multiply, element wise.
261    ///
262    /// # Example
263    ///
264    /// ```rust
265    /// use burn_tensor::backend::Backend;
266    /// use burn_tensor::{Tensor, Shape};
267    ///
268    /// fn example<B: Backend>() {
269    ///    let device = B::Device::default();
270    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
271    ///    let scalar = 2.0;
272    ///    let tensor = tensor * scalar;
273    ///    println!("{tensor}");
274    ///    // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
275    /// }
276    /// ```
277    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
278        let other = Scalar::new(other, &self.dtype());
279        Self::new(K::mul_scalar(self.primitive, other))
280    }
281
282    /// Switch sign of each element in the tensor.
283    ///
284    /// `y = -x`
285    ///
286    /// # Example
287    ///
288    /// ```rust
289    /// use burn_tensor::backend::Backend;
290    /// use burn_tensor::{Tensor, Shape};
291    ///
292    /// fn example<B: Backend>() {
293    ///    let device = B::Device::default();
294    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
295    ///    let tensor = -tensor;
296    ///    println!("{tensor}");
297    ///    // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
298    /// }
299    /// ```
300    #[allow(clippy::should_implement_trait)]
301    pub fn neg(self) -> Self {
302        Self::new(K::neg(self.primitive))
303    }
304
305    /// Returns the signs of the elements of the input tensor.
306    ///
307    /// # Example
308    ///
309    /// ```rust
310    /// use burn_tensor::backend::Backend;
311    /// use burn_tensor::{Tensor, Shape};
312    ///
313    /// fn example<B: Backend>() {
314    ///    let device = B::Device::default();
315    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
316    ///    let tensor = tensor.sign();
317    ///    println!("{tensor}");
318    ///    // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
319    /// }
320    /// ```
321    pub fn sign(self) -> Self {
322        Self::new(K::sign(self.primitive))
323    }
324
325    /// Aggregate all elements in the tensor with the mean operation.
326    ///
327    /// # Example
328    ///
329    /// ```rust
330    /// use burn_tensor::backend::Backend;
331    /// use burn_tensor::{Tensor, Shape};
332    ///
333    /// fn example<B: Backend>() {
334    ///    let device = B::Device::default();
335    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
336    ///    let tensor = tensor.mean();
337    ///    println!("{tensor}");
338    ///    // [3.6666667]
339    /// }
340    /// ```
341    pub fn mean(self) -> Tensor<B, 1, K> {
342        Tensor::new(K::mean(self.primitive))
343    }
344
345    /// Aggregate all elements in the tensor with the sum operation.
346    ///
347    /// # Example
348    ///
349    /// ```rust
350    /// use burn_tensor::backend::Backend;
351    /// use burn_tensor::{Tensor, Shape};
352    ///
353    /// fn example<B: Backend>() {
354    ///   let device = B::Device::default();
355    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
356    ///   let tensor = tensor.sum();
357    ///   println!("{tensor}");
358    ///   // [22.0]
359    /// }
360    /// ```
361    pub fn sum(self) -> Tensor<B, 1, K> {
362        Tensor::new(K::sum(self.primitive))
363    }
364
365    /// Aggregate all elements along the given *dimension* or *axis*
366    /// in the tensor with the mean operation.
367    ///
368    /// # Arguments
369    ///
370    /// * `dim` - The dimension or axis along which to aggregate the elements;
371    ///   supports negative indexing.
372    ///
373    /// # Example
374    ///
375    /// ```rust
376    /// use burn_tensor::backend::Backend;
377    /// use burn_tensor::{Tensor, Shape};
378    ///
379    /// fn example<B: Backend>() {
380    ///   let device = B::Device::default();
381    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
382    ///   let tensor = tensor.clone().mean_dim(0);
383    ///   println!("{tensor}");
384    ///   // [[3.0, 3.5, 4.5]]
385    ///   let tensor = tensor.clone().mean_dim(1);
386    ///   println!("{tensor}");
387    ///   // [[0.6666667], [6.6666665]]
388    /// }
389    /// ```
390    pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
391        let dim = dim.expect_dim_index(D);
392        check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
393        Self::new(K::mean_dim(self.primitive, dim))
394    }
395
396    /// Aggregate all elements along the given *axes*
397    /// in the tensor with the mean operation.
398    ///
399    /// # Arguments
400    ///
401    /// * `dims` - the dimensions to aggregate; supports negative indexing.
402    ///
403    /// # Returns
404    ///
405    /// The returned tensor will have the same rank,
406    /// but the aggregated dimensions will have size 1.
407    ///
408    /// # Example
409    ///
410    /// ```rust
411    /// use burn_tensor::backend::Backend;
412    /// use burn_tensor::{Tensor, Shape};
413    ///
414    /// fn example<B: Backend>() {
415    ///    let device = B::Device::default();
416    ///    let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
417    ///    let tensor = tensor.clone().mean_dims(&[0, 1]);
418    ///    println!("{tensor}");
419    ///    // [[2.0]]
420    /// }
421    /// ```
422    pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
423        dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
424    }
425
426    /// Aggregate all elements along the given *dimension* or *axis*
427    /// in the tensor with the sum operation.
428    ///
429    /// # Arguments
430    ///
431    /// * `dim` - The dimension or axis along which to aggregate the elements;
432    ///   supports negative indexing.
433    ///
434    /// # Example
435    ///
436    /// ```rust
437    /// use burn_tensor::backend::Backend;
438    /// use burn_tensor::{Tensor, Shape};
439    ///
440    /// fn example<B: Backend>() {
441    ///    let device = B::Device::default();
442    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
443    ///    let tensor = tensor.clone().sum_dim(0);
444    ///    println!("{tensor}");
445    ///    // [[6.0, 7.0, 9.0]]
446    ///    let tensor = tensor.clone().sum_dim(1);
447    ///    println!("{tensor}");
448    ///    // [[2.0], [20.0]]
449    /// }
450    /// ```
451    pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
452        let dim = dim.expect_dim_index(D);
453        check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
454        Self::new(K::sum_dim(self.primitive, dim))
455    }
456
457    /// Aggregate all elements along the given *axes*
458    /// in the tensor with the sum operation.
459    ///
460    /// # Arguments
461    ///
462    /// * `dims` - the dimensions to aggregate; supports negative indexing.
463    ///
464    /// # Returns
465    ///
466    /// The returned tensor will have the same rank,
467    /// but the aggregated dimensions will have size 1.
468    ///
469    /// # Example
470    ///
471    /// ```rust
472    /// use burn_tensor::backend::Backend;
473    /// use burn_tensor::{Tensor, Shape};
474    ///
475    /// fn example<B: Backend>() {
476    ///    let device = B::Device::default();
477    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
478    ///    let tensor = tensor.clone().sum_dims(&[0, 1]);
479    ///    println!("{tensor}");
480    ///    // [[27]]
481    /// }
482    /// ```
483    pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
484        dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
485    }
486
487    /// Aggregate and squeeze along the given dimensions.
488    ///
489    /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)``
490    ///
491    /// # Arguments
492    ///
493    /// * `dims` - the dimensions to aggregate; supports negative indexing.
494    ///
495    /// # Returns
496    ///
497    /// The returned tensor will have the same rank,
498    /// but the aggregated dimensions will have size 1.
499    ///
500    /// # Example
501    ///
502    /// ```rust
503    /// use burn_tensor::backend::Backend;
504    /// use burn_tensor::{Tensor, Shape};
505    ///
506    /// fn example<B: Backend>() {
507    ///     let device = B::Device::default();
508    ///     let tensor = Tensor::<B, 3>::from_data([
509    ///         [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
510    ///         [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
511    ///     ], &device);
512    ///     let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);
513    ///     println!("{tensor}");
514    ///     // [20.0, 16.0, 21.0]
515    /// }
516    /// ```
517    pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
518        // TODO: remove idims when squeeze_dims uses AsIndex.
519        let idims = dims
520            .iter()
521            .map(|&dim| (dim.expect_dim_index(D)) as isize)
522            .collect::<Vec<_>>();
523        self.sum_dims(dims).squeeze_dims::<D2>(&idims)
524    }
525
526    /// Aggregate all elements in the tensor with the product operation.
527    ///
528    /// # Example
529    ///
530    /// ```rust
531    /// use burn_tensor::backend::Backend;
532    /// use burn_tensor::{Tensor, Shape};
533    ///
534    /// fn example<B: Backend>() {
535    ///    let device = B::Device::default();
536    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
537    ///    let tensor = tensor.prod();
538    ///    println!("{tensor}");
539    ///    // [-1620.0]
540    /// }
541    /// ```
542    pub fn prod(self) -> Tensor<B, 1, K> {
543        Tensor::new(K::prod(self.primitive))
544    }
545
546    /// Aggregate all elements along the given *dimension* or *axis*
547    /// in the tensor with the product operation.
548    ///
549    /// # Arguments
550    ///
551    /// * `dim` - The dimension or axis along which to aggregate the elements,
552    ///   supports negative indexing.
553    ///
554    /// # Returns
555    ///
556    /// The returned tensor will have the same rank,
557    /// but the aggregated dimension will have size 1.
558    ///
559    /// # Example
560    ///
561    /// ```rust
562    /// use burn_tensor::backend::Backend;
563    /// use burn_tensor::{Tensor, Shape};
564    ///
565    /// fn example<B: Backend>() {
566    ///    let device = B::Device::default();
567    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
568    ///    let tensor = tensor.clone().prod_dim(0);
569    ///    println!("{tensor}");
570    ///    // [[5.0, -18.0, 18.0]]
571    ///    let tensor = tensor.clone().prod_dim(1);
572    ///    println!("{tensor}");
573    ///    // [[-6.0], [270.0]]
574    /// }
575    /// ```
576    pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
577        let dim = dim.expect_dim_index(D);
578        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
579        Self::new(K::prod_dim(self.primitive, dim))
580    }
581
582    /// Aggregate all elements along the given *axes*
583    /// in the tensor with the prod operation.
584    ///
585    /// # Arguments
586    ///
587    /// * `dims` - the dimensions to aggregate, supports negative indexing.
588    ///
589    /// # Returns
590    ///
591    /// The returned tensor will have the same rank,
592    /// but the aggregated dimensions will have size 1.
593    ///
594    /// # Example
595    ///
596    /// ```rust
597    /// use burn_tensor::backend::Backend;
598    /// use burn_tensor::{Tensor, Shape};
599    ///
600    /// fn example<B: Backend>() {
601    ///    let device = B::Device::default();
602    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
603    ///    let tensor = tensor.clone().sum_dims(&[0, 1]);
604    ///    println!("{tensor}");
605    ///    // [[-1620.0]]
606    /// }
607    /// ```
608    pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
609        dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
610    }
611
612    /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
613    ///
614    /// # Arguments
615    ///
616    /// * `dim` - The dimension or axis along which to compute the cumulative sum.
617    ///
618    /// # Example
619    ///
620    /// ```rust
621    /// use burn_tensor::backend::Backend;
622    /// use burn_tensor::{Tensor, Shape};
623    ///
624    /// fn example<B: Backend>() {
625    ///    let device = B::Device::default();
626    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
627    ///    let result = tensor.clone().cumsum(0);
628    ///    println!("{result}");
629    ///    // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
630    ///    let result = tensor.cumsum(1);
631    ///    println!("{result}");
632    ///    // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
633    /// }
634    /// ```
635    pub fn cumsum(self, dim: usize) -> Self {
636        check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
637        Self::new(K::cumsum(self.primitive, dim))
638    }
639
640    /// Computes the cumulative product of elements along the given *dimension* or *axis*.
641    ///
642    /// # Arguments
643    ///
644    /// * `dim` - The dimension or axis along which to compute the cumulative product.
645    ///
646    /// # Example
647    ///
648    /// ```rust
649    /// use burn_tensor::backend::Backend;
650    /// use burn_tensor::{Tensor, Shape};
651    ///
652    /// fn example<B: Backend>() {
653    ///    let device = B::Device::default();
654    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
655    ///    let result = tensor.clone().cumprod(0);
656    ///    println!("{result}");
657    ///    // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
658    ///    let result = tensor.cumprod(1);
659    ///    println!("{result}");
660    ///    // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
661    /// }
662    /// ```
663    pub fn cumprod(self, dim: usize) -> Self {
664        check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
665        Self::new(K::cumprod(self.primitive, dim))
666    }
667
668    /// Apply element wise absolute value operation.
669    ///
670    /// # Example
671    ///
672    /// ```rust
673    /// use burn_tensor::backend::Backend;
674    /// use burn_tensor::{Int, Tensor};
675    ///
676    /// fn example<B: Backend>() {
677    ///   let device = Default::default();
678    ///   let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
679    ///   let tensor = tensor.abs();
680    ///   println!("{tensor}");
681    ///   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
682    /// }
683    /// ```
684    ///
685    /// # Notes
686    ///
687    /// For signed integer dtypes, this operation uses two's-complement wraparound semantics, similar to
688    /// `x.wrapping_abs()`. For example, `abs(i64::MIN) == i64::MIN`.
689    pub fn abs(self) -> Self {
690        Self::new(K::abs(self.primitive))
691    }
692
693    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
694    /// the other elements of the result tensor out are set to 0.
695    ///
696    /// See also [`triu_mask`](Tensor::triu_mask).
697    ///
698    /// # Arguments
699    ///
700    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
701    ///   towards the upper triangle.
702    ///
703    /// # Example
704    /// ```rust
705    /// use burn_tensor::backend::Backend;
706    /// use burn_tensor::{Int, Tensor};
707    ///
708    /// fn example<B: Backend>() {
709    ///    let device = Default::default();
710    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
711    ///        [
712    ///          [1, 2, 3],
713    ///          [4, 5, 6],
714    ///          [7, 8, 9]
715    ///        ],
716    ///        &device
717    ///    );
718    ///    let tensor = tensor.triu(1);
719    ///    println!("{tensor}");
720    ///    // [
721    ///    //   [0, 2, 3],
722    ///    //   [0, 0, 6],
723    ///    //   [0, 0, 0]
724    ///    // ]
725    /// }
726    /// ```
727    pub fn triu(self, diagonal: i64) -> Self {
728        check!(TensorCheck::tri::<{ D }>());
729
730        // last two dimensions
731        let shape = &self.shape()[D - 2..].to_owned();
732
733        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
734        self.mask_fill(mask, 0)
735    }
736
737    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
738    /// the other elements of the result tensor out are set to 0.
739    ///
740    /// See also [`tril_mask`](Tensor::tril_mask).
741    ///
742    /// # Arguments
743    ///
744    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
745    ///   towards the upper triangle.
746    ///
747    /// # Example
748    /// ```rust
749    /// use burn_tensor::backend::Backend;
750    /// use burn_tensor::{Int, Tensor};
751    ///
752    /// fn example<B: Backend>() {
753    ///    let device = Default::default();
754    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
755    ///        [
756    ///          [1, 2, 3],
757    ///          [4, 5, 6],
758    ///          [7, 8, 9]
759    ///        ],
760    ///        &device
761    ///    );
762    ///
763    ///    let tensor = tensor.tril(-1);
764    ///    println!("{tensor}");
765    ///    // [
766    ///    //   [0, 0, 0],
767    ///    //   [4, 0, 0],
768    ///    //   [7, 8, 0]
769    ///    // ]
770    /// }
771    /// ```
772    pub fn tril(self, diagonal: i64) -> Self {
773        check!(TensorCheck::tri::<{ D }>());
774
775        // last two dimensions
776        let shape = &self.shape()[D - 2..].to_owned();
777        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
778
779        self.mask_fill(mask, 0)
780    }
781
782    /// Applies element wise power operation with a float Tensor
783    ///
784    /// # Arguments
785    ///
786    /// * `other` - The tensor to apply the power operation with.
787    ///
788    /// # Example
789    ///
790    /// ```rust
791    /// use burn_tensor::backend::Backend;
792    /// use burn_tensor::{Tensor, Shape};
793    ///
794    /// fn example<B: Backend>() {
795    ///    let device = B::Device::default();
796    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
797    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
798    ///    let tensor = tensor1.powf(tensor2);
799    ///    println!("{tensor}");
800    ///    // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]]
801    /// }
802    /// ```
803    pub fn powf(self, other: Self) -> Self {
804        Self::new(K::powf(self.primitive, other.primitive))
805    }
806
807    /// Applies element wise power operation with a float scalar
808    ///
809    /// # Arguments
810    ///
811    /// * `other` - The scalar to apply the power operation with.
812    ///
813    /// # Example
814    ///
815    /// ```rust
816    /// use burn_tensor::backend::Backend;
817    /// use burn_tensor::{Tensor, Shape};
818    ///
819    /// fn example<B: Backend>() {
820    ///    let device = B::Device::default();
821    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
822    ///    let tensor = tensor.powf_scalar(2.0);
823    ///    println!("{tensor}");
824    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
825    /// }
826    /// ```
827    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
828        let other = Scalar::new(other, &self.dtype());
829        Self::new(K::powf_scalar(self.primitive, other))
830    }
831
832    /// Applies element wise power operation with a integer Tensor
833    ///
834    /// # Arguments
835    ///
836    /// * `other` - The tensor to apply the power operation with.
837    ///
838    /// # Example
839    ///
840    /// ```rust
841    /// use burn_tensor::backend::Backend;
842    /// use burn_tensor::{Tensor, Shape, Int};
843    ///
844    /// fn example<B: Backend>() {
845    ///    let device = B::Device::default();
846    ///    let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
847    ///    let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
848    ///    let tensor = tensor1.powi(tensor2);
849    ///    println!("{tensor}");
850    ///    // [[1, -8, 81], [5, 81, 216]]
851    /// }
852    /// ```
853    pub fn powi(self, other: Self) -> Self {
854        Self::new(K::powi(self.primitive, other.primitive))
855    }
856
857    /// Applies element wise power operation with a integer scalar
858    ///
859    /// # Arguments
860    ///
861    /// * `other` - The scalar to apply the power operation with.
862    ///
863    /// # Example
864    ///
865    /// ```rust
866    /// use burn_tensor::backend::Backend;
867    /// use burn_tensor::{Tensor, Shape, Int};
868    ///
869    /// fn example<B: Backend>() {
870    ///    let device = B::Device::default();
871    ///    let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
872    ///    let tensor = tensor.powi_scalar(2);
873    ///    println!("{tensor}");
874    ///
875    ///    // [[1, 4, 9], [25, 81, 36]]
876    ///    let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
877    ///    let tensor = tensor.powi_scalar(2);
878    ///    println!("{tensor}");
879    ///    // [[2.25, 4., 9.], [25., 81., 36.]]
880    /// }
881    /// ```
882    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
883        let other = Scalar::new(other, &self.dtype());
884        Self::new(K::powi_scalar(self.primitive, other))
885    }
886
887    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
888    ///
889    /// # Returns
890    ///
891    /// A boolean tensor with the same shape as the input tensor.
892    ///
893    /// # Example
894    ///
895    /// ```rust
896    /// use burn_tensor::backend::Backend;
897    /// use burn_tensor::{Tensor, Shape};
898    ///
899    /// fn example<B: Backend>() {
900    ///   let device = B::Device::default();
901    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
902    ///   let tensor = tensor.bool();
903    ///   println!("{tensor}");
904    ///   // [
905    ///   //   [true, true, true],
906    ///   //   [false, true, true]
907    ///   // ]
908    /// }
909    pub fn bool(self) -> Tensor<B, D, Bool> {
910        self.not_equal_elem(0)
911    }
912
913    /// Create a random tensor of the given shape on the given device where each element is
914    /// sampled from the given distribution.
915    ///
916    /// See also [`random_like`](Tensor::random_like).
917    ///
918    /// # Arguments
919    ///
920    /// * `shape` - The shape of the tensor.
921    /// * `distribution` - The distribution to sample from.
922    /// * `device` - The device to create the tensor on.
923    ///
924    /// # Returns
925    ///
926    /// A new tensor with the given shape and elements sampled from the given distribution.
927    ///
928    /// # Example
929    ///
930    /// ```rust
931    /// use burn_tensor::backend::Backend;
932    /// use burn_tensor::{Tensor, Shape, Distribution};
933    ///
934    /// fn example<B: Backend>() {
935    ///   let device = B::Device::default();
936    ///   let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
937    ///   let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
938    ///   println!("{tensor}");
939    ///   // [
940    ///   //   [0.08347523, 0.70498955, 0.60332155],
941    ///   //   [0.08173251, 0.18028641, 0.97942924]
942    ///   // ]
943    /// }
944    /// ```
945    pub fn random<S: Into<Shape>>(
946        shape: S,
947        distribution: Distribution,
948        device: &B::Device,
949    ) -> Self {
950        Self::new(K::random(shape.into(), distribution, device))
951    }
952
953    /// Applies the matrix multiplication operation.
954    ///
955    /// ```math
956    /// C = AB
957    /// ```
958    ///
959    /// Shapes of the form `[..., B, 1, K] @ [..., 1, K, N]` are reinterpreted as
960    /// `[..., 1, B, K] @ [..., 1, K, N]`, turning a batched vec-mat into a general
961    /// matmul, which is often faster.
962    pub fn matmul(self, other: Self) -> Self {
963        check!(TensorCheck::matmul(&self, &other));
964
965        if D >= 3 {
966            let batch_index = D - 3;
967            let vector_index = D - 2;
968            let lhs_dims = &self.shape()[batch_index..D];
969            let rhs_dims = &other.shape()[batch_index..D];
970
971            if let ([_, 1, k1], [1, k2, _]) = (lhs_dims, rhs_dims)
972                && k1 == k2
973            {
974                return Tensor::new(K::matmul(
975                    self.swap_dims(batch_index, vector_index).primitive,
976                    other.primitive,
977                ))
978                .swap_dims(batch_index, vector_index);
979            }
980        }
981
982        Tensor::new(K::matmul(self.primitive, other.primitive))
983    }
984}
985
986impl<B, K> Tensor<B, 1, K>
987where
988    B: Backend,
989    K: Numeric<B>,
990    K::Elem: Element,
991{
992    /// Calculates the dot product with another tensor.
993    ///
994    /// `y = x2.dot(x1)`
995    ///
996    /// # Arguments
997    ///
998    /// * `other` - The tensor to compute dot product with.
999    ///
1000    /// # Notes
1001    ///
1002    /// Both tensors must have the same number of elements.
1003    ///
1004    /// # Example
1005    ///
1006    /// ```rust
1007    /// use burn_tensor::backend::Backend;
1008    /// use burn_tensor::{Tensor, Shape};
1009    ///
1010    /// fn example<B: Backend>() {
1011    ///    let device = B::Device::default();
1012    ///    let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
1013    ///    let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
1014    ///    let tensor = tensor1.dot(tensor2);
1015    ///    println!("{tensor}");
1016    ///    // [4]
1017    /// }
1018    /// ```
1019    pub fn dot(self, other: Self) -> Self {
1020        self.mul(other).sum()
1021    }
1022}
1023
1024impl<B, K> Tensor<B, 2, K>
1025where
1026    B: Backend,
1027    K: Numeric<B>,
1028    K::Elem: Element,
1029{
1030    /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
1031    ///
1032    /// # Arguments
1033    ///
1034    /// * `size` - The size of the square matrix.
1035    pub fn eye(size: usize, device: &B::Device) -> Self {
1036        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
1037        let ones = Self::ones([1, size], device);
1038        let zeros = Self::zeros([size, size], device);
1039
1040        zeros.scatter(0, indices, ones, IndexingUpdateOp::Add)
1041    }
1042}
1043
1044// Tensor + tensor
1045impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
1046where
1047    K::Elem: Element,
1048{
1049    type Output = Self;
1050
1051    fn add(self, rhs: Self) -> Self::Output {
1052        Self::add(self, rhs)
1053    }
1054}
1055
1056// Tensor + scalar
1057impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
1058    for Tensor<B, D, K>
1059where
1060    K::Elem: Element,
1061{
1062    type Output = Self;
1063
1064    fn add(self, other: E) -> Self::Output {
1065        Tensor::add_scalar(self, other)
1066    }
1067}
1068
1069// Scalar + tensor
1070macro_rules! impl_tensor_scalar_add {
1071    ($($t:ty),*) => {
1072        $(
1073            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
1074            where
1075                K::Elem: Element,
1076            {
1077                type Output = Tensor<B, D, K>;
1078
1079                fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
1080                    Tensor::add_scalar(tensor, self)
1081                }
1082            }
1083        )*
1084    }
1085}
1086impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);
1087
1088// Tensor - tensor
1089impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
1090where
1091    K::Elem: Element,
1092{
1093    type Output = Self;
1094
1095    fn sub(self, rhs: Self) -> Self::Output {
1096        Tensor::sub(self, rhs)
1097    }
1098}
1099
1100// Tensor - scalar
1101impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
1102    for Tensor<B, D, K>
1103where
1104    K::Elem: Element,
1105{
1106    type Output = Self;
1107
1108    fn sub(self, other: E) -> Self::Output {
1109        Tensor::sub_scalar(self, other)
1110    }
1111}
1112
1113// Scalar - tensor
1114macro_rules! impl_tensor_scalar_sub {
1115    ($($t:ty),*) => {
1116        $(
1117            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
1118            where
1119                K::Elem: Element,
1120            {
1121                type Output = Tensor<B, D, K>;
1122
1123                fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
1124                    Tensor::add_scalar(Tensor::neg(tensor), self)
1125                }
1126            }
1127        )*
1128    }
1129}
1130impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);
1131
1132// Tensor / tensor
1133impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
1134where
1135    K::Elem: Element,
1136{
1137    type Output = Self;
1138
1139    fn div(self, rhs: Self) -> Self::Output {
1140        Tensor::div(self, rhs)
1141    }
1142}
1143
1144// Tensor / scalar
1145impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
1146    for Tensor<B, D, K>
1147where
1148    K::Elem: Element,
1149{
1150    type Output = Self;
1151
1152    fn div(self, other: E) -> Self::Output {
1153        Tensor::div_scalar(self, other)
1154    }
1155}
1156
1157// Scalar / tensor (float only)
1158macro_rules! impl_tensor_scalar_div {
1159    ($($t:ty),*) => {
1160        $(
1161            impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t
1162            {
1163                type Output = Tensor<B, D>;
1164
1165                fn div(self, tensor: Tensor<B, D>) -> Self::Output {
1166                    tensor.recip().mul_scalar(self)
1167                }
1168            }
1169        )*
1170    }
1171}
1172
1173impl_tensor_scalar_div!(f32, f64);
1174
1175// Tensor % tensor.
1176impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
1177where
1178    K::Elem: Element,
1179{
1180    type Output = Self;
1181
1182    fn rem(self, rhs: Self) -> Self::Output {
1183        Tensor::remainder(self, rhs)
1184    }
1185}
1186
1187// Tensor % scalar.
1188impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
1189    for Tensor<B, D, K>
1190where
1191    K::Elem: Element,
1192{
1193    type Output = Self;
1194
1195    fn rem(self, other: E) -> Self::Output {
1196        Tensor::remainder_scalar(self, other)
1197    }
1198}
1199
1200// Tensor * tensor.
1201impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
1202where
1203    K::Elem: Element,
1204{
1205    type Output = Self;
1206
1207    fn mul(self, rhs: Self) -> Self::Output {
1208        Tensor::mul(self, rhs)
1209    }
1210}
1211
1212// Tensor * scalar.
1213impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
1214    for Tensor<B, D, K>
1215where
1216    K::Elem: Element,
1217{
1218    type Output = Self;
1219
1220    fn mul(self, other: E) -> Self::Output {
1221        Tensor::mul_scalar(self, other)
1222    }
1223}
1224
1225macro_rules! impl_tensor_scalar_mul {
1226    ($($t:ty),*) => {
1227        $(
1228            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
1229            where
1230                K::Elem: Element,
1231            {
1232                type Output = Tensor<B, D, K>;
1233
1234                fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
1235                    Tensor::mul_scalar(other, self)
1236                }
1237            }
1238        )*
1239    }
1240}
1241
1242impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);
1243
1244impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
1245where
1246    B: Backend,
1247    K: Numeric<B>,
1248    K::Elem: Element,
1249{
1250    type Output = Self;
1251
1252    fn neg(self) -> Self::Output {
1253        Tensor::neg(self)
1254    }
1255}