burn_tensor/tensor/api/numeric.rs

1pub use burn_backend::tensor::Numeric;
2
3use crate::alloc::borrow::ToOwned;
4use alloc::vec::Vec;
5
6use crate::IndexingUpdateOp;
7use crate::{
8    AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend,
9    check, check::TensorCheck,
10};
11
12impl<B, const D: usize, K> Tensor<B, D, K>
13where
14    B: Backend,
15    K: Numeric<B>,
16    K::Elem: Element,
17{
18    /// Applies element wise addition operation.
19    ///
20    /// `y = x1 + x2`
21    ///
22    /// # Arguments
23    ///
24    /// * `other` - The tensor to add.
25    ///
26    /// # Example
27    ///
28    /// ```rust
29    /// use burn_tensor::backend::Backend;
30    /// use burn_tensor::{Tensor, Shape};
31    ///
32    /// fn example<B: Backend>() {
33    ///    let device = B::Device::default();
34    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
35    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
36    ///    let tensor = tensor1 + tensor2;
37    ///    println!("{tensor}");
38    ///    // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
39    /// }
40    /// ```
41    #[allow(clippy::should_implement_trait)]
42    pub fn add(self, other: Self) -> Self {
43        check!(TensorCheck::binary_ops_ew("Add", &self, &other));
44        Self::new(K::add(self.primitive, other.primitive))
45    }
46
47    /// Applies element wise addition operation with a scalar.
48    ///
49    /// `y = x + s`
50    ///
51    /// # Arguments
52    ///
53    /// * `other` - The scalar to add, element wise.
54    ///
55    /// # Example
56    ///
57    /// ```rust
58    /// use burn_tensor::backend::Backend;
59    /// use burn_tensor::{Tensor, Shape};
60    ///
61    /// fn example<B: Backend>() {
62    ///   let device = B::Device::default();
63    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
64    ///   let scalar = 2.0;
65    ///   let tensor = tensor + scalar;
66    ///   println!("{tensor}");
67    ///   // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
68    /// }
69    /// ```
70    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
71        Self::new(K::add_scalar::<E>(self.primitive, other))
72    }
73
74    /// Applies element wise subtraction operation.
75    ///
76    /// `y = x1 - x2`
77    ///
78    /// # Arguments
79    ///
80    /// * `other` - The tensor to subtract.
81    ///
82    /// # Example
83    ///
84    /// ```rust
85    /// use burn_tensor::backend::Backend;
86    /// use burn_tensor::{Tensor, Shape};
87    ///
88    /// fn example<B: Backend>() {
89    ///   let device = B::Device::default();
90    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
91    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
92    ///   let tensor = tensor1 - tensor2;
93    ///   println!("{tensor}");
94    ///   // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
95    /// }
96    /// ```
97    #[allow(clippy::should_implement_trait)]
98    pub fn sub(self, other: Self) -> Self {
99        check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
100        Self::new(K::sub(self.primitive, other.primitive))
101    }
102
103    /// Applies element wise subtraction operation with a scalar.
104    ///
105    /// `y = x - s`
106    ///
107    /// # Arguments
108    ///
109    /// * `other` - The scalar to subtract, element wise.
110    ///
111    /// # Example
112    ///
113    /// ```rust
114    /// use burn_tensor::backend::Backend;
115    /// use burn_tensor::{Tensor, Shape};
116    ///
117    /// fn example<B: Backend>() {
118    ///    let device = B::Device::default();
119    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
120    ///    let scalar = 2.0;
121    ///    let tensor = tensor - scalar;
122    ///    println!("{tensor}");
123    ///    // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
124    /// }
125    /// ```
126    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
127        Self::new(K::sub_scalar::<E>(self.primitive, other))
128    }
129
130    /// Applies element wise division operation.
131    ///
132    /// `y = x1 / x2`
133    ///
134    /// # Arguments
135    ///
136    /// * `other` - The tensor to divide.
137    ///
138    /// # Example
139    ///
140    /// ```rust
141    /// use burn_tensor::backend::Backend;
142    /// use burn_tensor::{Tensor, Shape};
143    ///
144    /// fn example<B: Backend>() {
145    ///    let device = B::Device::default();
146    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
147    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
148    ///    let tensor = tensor1 / tensor2;
149    ///    println!("{tensor}");
150    ///    // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
151    /// }
152    /// ```
153    #[allow(clippy::should_implement_trait)]
154    pub fn div(self, other: Self) -> Self {
155        check!(TensorCheck::binary_ops_ew("Div", &self, &other));
156        Self::new(K::div(self.primitive, other.primitive))
157    }
158
159    /// Applies element wise division operation with a scalar.
160    ///
161    /// `y = x / s`
162    ///
163    /// # Arguments
164    ///
165    /// * `other` - The scalar to divide by, element wise.
166    ///
167    /// # Example
168    ///
169    /// ```rust
170    /// use burn_tensor::backend::Backend;
171    /// use burn_tensor::{Tensor, Shape};
172    ///
173    /// fn example<B: Backend>() {
174    ///    let device = B::Device::default();
175    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
176    ///    let scalar = 2.0;
177    ///    let tensor = tensor / scalar;
178    ///    println!("{tensor}");
179    ///    // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
180    /// }
181    /// ```
182    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
183        Self::new(K::div_scalar::<E>(self.primitive, other))
184    }
185
186    /// Applies element wise remainder operation.
187    ///
188    /// `y = x1 % x2`
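    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compute the remainder with, element wise.
    ///
    /// # Example
    ///
    /// A minimal illustrative sketch; the operands are kept positive so the
    /// result does not depend on the backend's remainder sign convention.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[3.0, 4.5, 8.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.remainder(tensor2);
    ///    println!("{tensor}");
    ///    // [[1.0, 1.5, 0.0], [0.0, 1.0, 0.0]]
    /// }
    /// ```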
189    pub fn remainder(self, other: Self) -> Self {
190        Self::new(K::remainder(self.primitive, other.primitive))
191    }
192
193    /// Applies element wise remainder operation with a scalar.
194    ///
195    /// `y = x % s`
196    ///
197    /// # Arguments
198    ///
199    /// * `other` - The scalar to compute the remainder with, element wise.
200    ///
201    /// # Example
202    ///
203    /// ```rust
204    /// use burn_tensor::backend::Backend;
205    /// use burn_tensor::{Tensor, Shape};
206    ///
207    /// fn example<B: Backend>() {
208    ///    let device = B::Device::default();
209    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
210    ///    let scalar = 2.0;
211    ///    let tensor = tensor1 % scalar;
212    ///    println!("{tensor}");
213    ///    // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
214    /// }
215    /// ```
216    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
217        Self::new(K::remainder_scalar::<E>(self.primitive, other))
218    }
219
220    /// Applies element wise multiplication operation.
221    ///
222    /// `y = x1 * x2`
223    ///
224    /// # Arguments
225    ///
226    /// * `other` - The tensor to multiply.
227    ///
228    /// # Example
229    ///
230    /// ```rust
231    /// use burn_tensor::backend::Backend;
232    /// use burn_tensor::{Tensor, Shape};
233    ///
234    /// fn example<B: Backend>() {
235    ///    let device = B::Device::default();
236    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
237    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
238    ///    let tensor = tensor1 * tensor2;
239    ///    println!("{tensor}");
240    ///    // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
241    /// }
242    /// ```
243    #[allow(clippy::should_implement_trait)]
244    pub fn mul(self, other: Self) -> Self {
245        check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
246        Self::new(K::mul(self.primitive, other.primitive))
247    }
248
249    /// Applies element wise multiplication operation with a scalar.
250    ///
251    /// `y = x * s`
252    ///
253    /// # Arguments
254    ///
255    /// * `other` - The scalar to multiply, element wise.
256    ///
257    /// # Example
258    ///
259    /// ```rust
260    /// use burn_tensor::backend::Backend;
261    /// use burn_tensor::{Tensor, Shape};
262    ///
263    /// fn example<B: Backend>() {
264    ///    let device = B::Device::default();
265    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
266    ///    let scalar = 2.0;
267    ///    let tensor = tensor * scalar;
268    ///    println!("{tensor}");
269    ///    // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
270    /// }
271    /// ```
272    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
273        Self::new(K::mul_scalar::<E>(self.primitive, other))
274    }
275
276    /// Switch sign of each element in the tensor.
277    ///
278    /// `y = -x`
279    ///
280    /// # Example
281    ///
282    /// ```rust
283    /// use burn_tensor::backend::Backend;
284    /// use burn_tensor::{Tensor, Shape};
285    ///
286    /// fn example<B: Backend>() {
287    ///    let device = B::Device::default();
288    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
289    ///    let tensor = -tensor;
290    ///    println!("{tensor}");
291    ///    // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
292    /// }
293    /// ```
294    #[allow(clippy::should_implement_trait)]
295    pub fn neg(self) -> Self {
296        Self::new(K::neg(self.primitive))
297    }
298
299    /// Returns the signs of the elements of the input tensor.
300    ///
301    /// # Example
302    ///
303    /// ```rust
304    /// use burn_tensor::backend::Backend;
305    /// use burn_tensor::{Tensor, Shape};
306    ///
307    /// fn example<B: Backend>() {
308    ///    let device = B::Device::default();
309    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
310    ///    let tensor = tensor.sign();
311    ///    println!("{tensor}");
312    ///    // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
313    /// }
314    /// ```
315    pub fn sign(self) -> Self {
316        Self::new(K::sign(self.primitive))
317    }
318
319    /// Aggregate all elements in the tensor with the mean operation.
320    ///
321    /// # Example
322    ///
323    /// ```rust
324    /// use burn_tensor::backend::Backend;
325    /// use burn_tensor::{Tensor, Shape};
326    ///
327    /// fn example<B: Backend>() {
328    ///    let device = B::Device::default();
329    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
330    ///    let tensor = tensor.mean();
331    ///    println!("{tensor}");
332    ///    // [3.6666667]
333    /// }
334    /// ```
335    pub fn mean(self) -> Tensor<B, 1, K> {
336        Tensor::new(K::mean(self.primitive))
337    }
338
339    /// Aggregate all elements in the tensor with the sum operation.
340    ///
341    /// # Example
342    ///
343    /// ```rust
344    /// use burn_tensor::backend::Backend;
345    /// use burn_tensor::{Tensor, Shape};
346    ///
347    /// fn example<B: Backend>() {
348    ///   let device = B::Device::default();
349    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
350    ///   let tensor = tensor.sum();
351    ///   println!("{tensor}");
352    ///   // [22.0]
353    /// }
354    /// ```
355    pub fn sum(self) -> Tensor<B, 1, K> {
356        Tensor::new(K::sum(self.primitive))
357    }
358
359    /// Aggregate all elements along the given *dimension* or *axis*
360    /// in the tensor with the mean operation.
361    ///
362    /// # Arguments
363    ///
364    /// * `dim` - The dimension or axis along which to aggregate the elements;
365    ///   supports negative indexing.
366    ///
367    /// # Example
368    ///
369    /// ```rust
370    /// use burn_tensor::backend::Backend;
371    /// use burn_tensor::{Tensor, Shape};
372    ///
373    /// fn example<B: Backend>() {
374    ///   let device = B::Device::default();
375    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
376    ///   let mean_0 = tensor.clone().mean_dim(0);
377    ///   println!("{mean_0}");
378    ///   // [[3.0, 3.5, 4.5]]
379    ///   let mean_1 = tensor.mean_dim(1);
380    ///   println!("{mean_1}");
381    ///   // [[0.6666667], [6.6666665]]
382    /// }
383    /// ```
384    pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
385        let dim = dim.expect_dim_index(D);
386        check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
387        Self::new(K::mean_dim(self.primitive, dim))
388    }
389
390    /// Aggregate all elements along the given *axes*
391    /// in the tensor with the mean operation.
392    ///
393    /// # Arguments
394    ///
395    /// * `dims` - the dimensions to aggregate; supports negative indexing.
396    ///
397    /// # Returns
398    ///
399    /// The returned tensor will have the same rank,
400    /// but the aggregated dimensions will have size 1.
401    ///
402    /// # Example
403    ///
404    /// ```rust
405    /// use burn_tensor::backend::Backend;
406    /// use burn_tensor::{Tensor, Shape};
407    ///
408    /// fn example<B: Backend>() {
409    ///    let device = B::Device::default();
410    ///    let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
411    ///    let tensor = tensor.clone().mean_dims(&[0, 1]);
412    ///    println!("{tensor}");
413    ///    // [[2.0]]
414    /// }
415    /// ```
416    pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
417        dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
418    }
419
420    /// Aggregate all elements along the given *dimension* or *axis*
421    /// in the tensor with the sum operation.
422    ///
423    /// # Arguments
424    ///
425    /// * `dim` - The dimension or axis along which to aggregate the elements;
426    ///   supports negative indexing.
427    ///
428    /// # Example
429    ///
430    /// ```rust
431    /// use burn_tensor::backend::Backend;
432    /// use burn_tensor::{Tensor, Shape};
433    ///
434    /// fn example<B: Backend>() {
435    ///    let device = B::Device::default();
436    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
437    ///    let sum_0 = tensor.clone().sum_dim(0);
438    ///    println!("{sum_0}");
439    ///    // [[6.0, 7.0, 9.0]]
440    ///    let sum_1 = tensor.sum_dim(1);
441    ///    println!("{sum_1}");
442    ///    // [[2.0], [20.0]]
443    /// }
444    /// ```
445    pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
446        let dim = dim.expect_dim_index(D);
447        check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
448        Self::new(K::sum_dim(self.primitive, dim))
449    }
450
451    /// Aggregate all elements along the given *axes*
452    /// in the tensor with the sum operation.
453    ///
454    /// # Arguments
455    ///
456    /// * `dims` - the dimensions to aggregate; supports negative indexing.
457    ///
458    /// # Returns
459    ///
460    /// The returned tensor will have the same rank,
461    /// but the aggregated dimensions will have size 1.
462    ///
463    /// # Example
464    ///
465    /// ```rust
466    /// use burn_tensor::backend::Backend;
467    /// use burn_tensor::{Tensor, Shape};
468    ///
469    /// fn example<B: Backend>() {
470    ///    let device = B::Device::default();
471    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
472    ///    let tensor = tensor.clone().sum_dims(&[0, 1]);
473    ///    println!("{tensor}");
474    ///    // [[22.0]]
475    /// }
476    /// ```
477    pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
478        dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
479    }
480
481    /// Aggregate and squeeze along the given dimensions.
482    ///
483    /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)``
484    ///
485    /// # Arguments
486    ///
487    /// * `dims` - the dimensions to aggregate; supports negative indexing.
488    ///
489    /// # Returns
490    ///
492    /// The aggregated dimensions are removed (squeezed),
493    /// so the returned tensor has rank `D2`.
493    ///
494    /// # Example
495    ///
496    /// ```rust
497    /// use burn_tensor::backend::Backend;
498    /// use burn_tensor::{Tensor, Shape};
499    ///
500    /// fn example<B: Backend>() {
501    ///     let device = B::Device::default();
502    ///     let tensor = Tensor::<B, 3>::from_data([
503    ///         [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
504    ///         [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
505    ///     ], &device);
506    ///     let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);
507    ///     println!("{tensor}");
508    ///     // [20.0, 16.0, 21.0]
509    /// }
510    /// ```
511    pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
512        // TODO: remove idims when squeeze_dims uses AsIndex.
513        let idims = dims
514            .iter()
515            .map(|&dim| (dim.expect_dim_index(D)) as isize)
516            .collect::<Vec<_>>();
517        self.sum_dims(dims).squeeze_dims::<D2>(&idims)
518    }
519
520    /// Aggregate all elements in the tensor with the product operation.
521    ///
522    /// # Example
523    ///
524    /// ```rust
525    /// use burn_tensor::backend::Backend;
526    /// use burn_tensor::{Tensor, Shape};
527    ///
528    /// fn example<B: Backend>() {
529    ///    let device = B::Device::default();
530    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
531    ///    let tensor = tensor.prod();
532    ///    println!("{tensor}");
533    ///    // [-1620.0]
534    /// }
535    /// ```
536    pub fn prod(self) -> Tensor<B, 1, K> {
537        Tensor::new(K::prod(self.primitive))
538    }
539
540    /// Aggregate all elements along the given *dimension* or *axis*
541    /// in the tensor with the product operation.
542    ///
543    /// # Arguments
544    ///
545    /// * `dim` - The dimension or axis along which to aggregate the elements,
546    ///   supports negative indexing.
547    ///
548    /// # Returns
549    ///
550    /// The returned tensor will have the same rank,
551    /// but the aggregated dimension will have size 1.
552    ///
553    /// # Example
554    ///
555    /// ```rust
556    /// use burn_tensor::backend::Backend;
557    /// use burn_tensor::{Tensor, Shape};
558    ///
559    /// fn example<B: Backend>() {
560    ///    let device = B::Device::default();
561    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
562    ///    let prod_0 = tensor.clone().prod_dim(0);
563    ///    println!("{prod_0}");
564    ///    // [[5.0, -18.0, 18.0]]
565    ///    let prod_1 = tensor.prod_dim(1);
566    ///    println!("{prod_1}");
567    ///    // [[-6.0], [270.0]]
568    /// }
569    /// ```
570    pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
571        let dim = dim.expect_dim_index(D);
572        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
573        Self::new(K::prod_dim(self.primitive, dim))
574    }
575
576    /// Aggregate all elements along the given *axes*
577    /// in the tensor with the product operation.
578    ///
579    /// # Arguments
580    ///
581    /// * `dims` - the dimensions to aggregate, supports negative indexing.
582    ///
583    /// # Returns
584    ///
585    /// The returned tensor will have the same rank,
586    /// but the aggregated dimensions will have size 1.
587    ///
588    /// # Example
589    ///
590    /// ```rust
591    /// use burn_tensor::backend::Backend;
592    /// use burn_tensor::{Tensor, Shape};
593    ///
594    /// fn example<B: Backend>() {
595    ///    let device = B::Device::default();
596    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
597    ///    let tensor = tensor.clone().prod_dims(&[0, 1]);
598    ///    println!("{tensor}");
599    ///    // [[-1620.0]]
600    /// }
601    /// ```
602    pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
603        dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
604    }
605
606    /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
607    ///
608    /// # Arguments
609    ///
610    /// * `dim` - The dimension or axis along which to compute the cumulative sum.
611    ///
612    /// # Example
613    ///
614    /// ```rust
615    /// use burn_tensor::backend::Backend;
616    /// use burn_tensor::{Tensor, Shape};
617    ///
618    /// fn example<B: Backend>() {
619    ///    let device = B::Device::default();
620    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
621    ///    let result = tensor.clone().cumsum(0);
622    ///    println!("{result}");
623    ///    // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
624    ///    let result = tensor.cumsum(1);
625    ///    println!("{result}");
626    ///    // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
627    /// }
628    /// ```
629    pub fn cumsum(self, dim: usize) -> Self {
630        check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
631        Self::new(K::cumsum(self.primitive, dim))
632    }
633
634    /// Computes the cumulative product of elements along the given *dimension* or *axis*.
635    ///
636    /// # Arguments
637    ///
638    /// * `dim` - The dimension or axis along which to compute the cumulative product.
639    ///
640    /// # Example
641    ///
642    /// ```rust
643    /// use burn_tensor::backend::Backend;
644    /// use burn_tensor::{Tensor, Shape};
645    ///
646    /// fn example<B: Backend>() {
647    ///    let device = B::Device::default();
648    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
649    ///    let result = tensor.clone().cumprod(0);
650    ///    println!("{result}");
651    ///    // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
652    ///    let result = tensor.cumprod(1);
653    ///    println!("{result}");
654    ///    // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
655    /// }
656    /// ```
657    pub fn cumprod(self, dim: usize) -> Self {
658        check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
659        Self::new(K::cumprod(self.primitive, dim))
660    }
661
662    /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
663    ///
664    /// # Arguments
665    ///
666    /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
667    ///
668    /// # Example
669    ///
670    /// ```rust
671    /// use burn_tensor::backend::Backend;
672    /// use burn_tensor::{Tensor, Shape};
673    ///
674    /// fn example<B: Backend>() {
675    ///    let device = B::Device::default();
676    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
677    ///    let result = tensor.clone().cummin(0);
678    ///    println!("{result}");
679    ///    // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
680    ///    let result = tensor.cummin(1);
681    ///    println!("{result}");
682    ///    // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
683    /// }
684    /// ```
685    pub fn cummin(self, dim: usize) -> Self {
686        check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
687        Self::new(K::cummin(self.primitive, dim))
688    }
689
690    /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
691    ///
692    /// # Arguments
693    ///
694    /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
695    ///
696    /// # Example
697    ///
698    /// ```rust
699    /// use burn_tensor::backend::Backend;
700    /// use burn_tensor::{Tensor, Shape};
701    ///
702    /// fn example<B: Backend>() {
703    ///    let device = B::Device::default();
704    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
705    ///    let result = tensor.clone().cummax(0);
706    ///    println!("{result}");
707    ///    // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
708    ///    let result = tensor.cummax(1);
709    ///    println!("{result}");
710    ///    // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
711    /// }
712    /// ```
713    pub fn cummax(self, dim: usize) -> Self {
714        check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
715        Self::new(K::cummax(self.primitive, dim))
716    }

717    /// Applies element wise greater comparison and returns a boolean tensor.
718    ///
719    /// # Panics
720    ///
721    /// If the two tensors don't have the same shape.
722    ///
723    /// # Example
724    ///
725    /// ```rust
726    /// use burn_tensor::backend::Backend;
727    /// use burn_tensor::{Tensor, Shape};
728    ///
729    /// fn example<B: Backend>() {
730    ///   let device = B::Device::default();
731    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
732    ///   let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
733    ///   let tensor = tensor1.greater(tensor2);
734    ///   println!("{tensor}");
735    ///   // [[false, false, false], [true, true, true]]
736    /// }
737    /// ```
738    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
739        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
740        Tensor::new(K::greater(self.primitive, other.primitive))
741    }
742
743    /// Applies element wise greater-equal comparison and returns a boolean tensor.
744    ///
745    /// # Panics
746    ///
747    /// If the two tensors don't have the same shape.
748    ///
749    /// # Example
750    ///
751    /// ```rust
752    /// use burn_tensor::backend::Backend;
753    /// use burn_tensor::{Tensor, Shape};
754    ///
755    /// fn example<B: Backend>() {
756    ///    let device = B::Device::default();
757    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
758    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
759    ///    let tensor = tensor1.greater_equal(tensor2);
760    ///    println!("{tensor}");
761    ///    // [[true, false, false], [true, true, true]]
762    /// }
763    /// ```
764    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
765        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
766        Tensor::new(K::greater_equal(self.primitive, other.primitive))
767    }
768
769    /// Applies element wise lower comparison and returns a boolean tensor.
770    ///
771    /// # Panics
772    ///
773    /// If the two tensors don't have the same shape.
774    ///
775    /// # Example
776    ///
777    /// ```rust
778    /// use burn_tensor::backend::Backend;
779    /// use burn_tensor::{Tensor, Shape};
780    ///
781    /// fn example<B: Backend>() {
782    ///    let device = B::Device::default();
783    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
784    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
785    ///    let tensor = tensor1.lower(tensor2);
786    ///    println!("{tensor}");
787    ///    // [[false, true, true], [false, false, false]]
788    /// }
789    /// ```
790    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
791        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
792        Tensor::new(K::lower(self.primitive, other.primitive))
793    }
794
795    /// Applies element wise lower-equal comparison and returns a boolean tensor.
796    ///
797    /// # Panics
798    ///
799    /// If the two tensors don't have the same shape.
800    ///
801    /// # Example
802    ///
803    /// ```rust
804    /// use burn_tensor::backend::Backend;
805    /// use burn_tensor::{Tensor, Shape};
806    ///
807    /// fn example<B: Backend>() {
808    ///    let device = B::Device::default();
809    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
810    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
811    ///    let tensor = tensor1.lower_equal(tensor2);
812    ///    println!("{tensor}");
813    ///    // [[true, true, true], [false, false, false]]
814    /// }
815    /// ```
816    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
817        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
818        Tensor::new(K::lower_equal(self.primitive, other.primitive))
819    }
820
821    /// Applies greater than `other` comparison and returns a boolean tensor.
822    ///
823    /// # Arguments
824    ///
825    /// * `other` - The element to compare.
826    ///
827    /// # Example
828    ///
829    /// ```rust
830    /// use burn_tensor::backend::Backend;
831    /// use burn_tensor::{Tensor, Shape};
832    ///
833    /// fn example<B: Backend>() {
834    ///    let device = B::Device::default();
835    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
836    ///    let tensor = tensor.greater_elem(3.0);
837    ///    println!("{tensor}");
838    ///    // [[false, false, false], [true, true, true]]
839    /// }
840    /// ```
841    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
842        Tensor::new(K::greater_elem(self.primitive, other.elem()))
843    }
844
845    /// Applies greater-equal than `other` comparison and returns a boolean tensor.
846    ///
847    /// # Arguments
848    ///
849    /// * `other` - The element to compare.
850    ///
851    /// # Example
852    ///
853    /// ```rust
854    /// use burn_tensor::backend::Backend;
855    /// use burn_tensor::{Tensor, Shape};
856    ///
857    /// fn example<B: Backend>() {
858    ///    let device = B::Device::default();
859    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
860    ///    let tensor = tensor.greater_equal_elem(3.0);
861    ///    println!("{tensor}");
862    ///    // [[false, false, true], [true, true, true]]
863    /// }
864    /// ```
865    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
866        Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
867    }
868
869    /// Applies lower than `other` comparison and returns a boolean tensor.
870    ///
871    /// # Arguments
872    ///
873    /// * `other` - The element to compare.
874    ///
875    /// # Example
876    ///
877    /// ```rust
878    /// use burn_tensor::backend::Backend;
879    /// use burn_tensor::{Tensor, Shape};
880    ///
881    /// fn example<B: Backend>() {
882    ///     let device = B::Device::default();
883    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
884    ///     let tensor = tensor.lower_elem(3.0);
885    ///     println!("{tensor}");
886    ///     // [[true, true, false], [false, false, false]]
887    /// }
888    /// ```
889    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
890        Tensor::new(K::lower_elem(self.primitive, other.elem()))
891    }
892
893    /// Applies lower-equal than `other` comparison and returns a boolean tensor.
894    ///
895    /// # Arguments
896    ///
897    /// * `other` - The element to compare.
898    ///
899    /// # Example
900    ///
901    /// ```rust
902    /// use burn_tensor::backend::Backend;
903    /// use burn_tensor::{Tensor, Shape};
904    ///
905    /// fn example<B: Backend>() {
906    ///    let device = B::Device::default();
907    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
908    ///    let tensor = tensor.lower_equal_elem(3.0);
909    ///    println!("{tensor}");
910    ///    // [[true, true, true], [false, false, false]]
911    /// }
912    /// ```
913    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
914        Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
915    }
916
917    /// Applies the argmax function along the given dimension and returns an integer tensor.
918    ///
919    /// # Example
920    ///
921    /// ```rust
922    /// use burn_tensor::backend::Backend;
923    /// use burn_tensor::{Tensor, Shape};
924    ///
925    /// fn example<B: Backend>() {
926    ///     let device = B::Device::default();
927    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
928    ///     let tensor = tensor.argmax(1);
929    ///     println!("{:?}", tensor.shape());
930    ///     // Shape { dims: [2, 1, 3] }
931    /// }
932    /// ```
933    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
934        Tensor::new(K::argmax(self.primitive, dim))
935    }
936
937    /// Find the maximum value.
938    ///
939    /// # Example
940    ///
941    /// ```rust
942    /// use burn_tensor::backend::Backend;
943    /// use burn_tensor::{Tensor, Shape};
944    ///
945    /// fn example<B: Backend>() {
946    ///   let device = B::Device::default();
947    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
948    ///   let tensor = tensor.max();
949    ///   println!("{tensor}");
950    ///   // [9.0]
951    /// }
952    /// ```
953    pub fn max(self) -> Tensor<B, 1, K> {
954        Tensor::new(K::max(self.primitive))
955    }
956
957    /// Find the maximum value along the given dimension.
958    ///
959    /// # Arguments
960    ///
961    /// * `dim` - The dimension or axis along which to aggregate the elements;
962    ///   supports negative indexing.
963    ///
964    /// # Returns
965    ///
966    /// The returned tensor will have the same rank,
967    /// but the aggregated dimension will have size 1.
968    ///
969    /// # Example
970    ///
971    /// ```rust
972    /// use burn_tensor::backend::Backend;
973    /// use burn_tensor::{Tensor, Shape};
974    ///
975    /// fn example<B: Backend>() {
976    ///   let device = B::Device::default();
977    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
978    ///   let tensor = tensor.max_dim(0);
979    ///   println!("{tensor}");
980    ///   // [[5.0, 9.0, 6.0]]
981    /// }
982    /// ```
983    pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
984        let dim = dim.expect_dim_index(D);
985        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
986        Tensor::new(K::max_dim(self.primitive, dim))
987    }
988
989    /// Find the maximum value along the given dimensions.
990    ///
991    /// # Arguments
992    ///
993    /// * `dims` - The dimensions or axis along which to aggregate the elements;
994    ///   supports negative indexing.
995    ///
996    /// # Returns
997    ///
998    /// The returned tensor will have the same rank,
999    /// but the aggregated dimensions will have size 1.
1000    ///
1001    /// # Example
1002    ///
1003    /// ```rust
1004    /// use burn_tensor::backend::Backend;
1005    /// use burn_tensor::{Tensor, Shape};
1006    ///
1007    /// fn example<B: Backend>() {
1008    ///   let device = B::Device::default();
1009    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1010    ///   let tensor = tensor.max_dims(&[0, 1]);
1011    ///   println!("{tensor}");
1012    ///   // [[9.0]]
1013    /// }
1014    /// ```
1015    pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1016        dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
1017    }
1018
1019    /// Find the maximum value along the given dimension.
1020    ///
1021    /// Also returns the indices.
1022    ///
1023    /// # Example
1024    ///
1025    /// ```rust
1026    /// use burn_tensor::backend::Backend;
1027    /// use burn_tensor::{Tensor, Shape};
1028    ///
1029    /// fn example<B: Backend>() {
1030    ///    let device = B::Device::default();
1031    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1032    ///    let (tensor, index) = tensor.max_dim_with_indices(0);
1033    ///    println!("{tensor}");
1034    ///    // [[5.0, 9.0, 6.0]]
1035    ///    println!("{index}");
1036    ///    // [[1, 1, 1]]
1037    /// }
1038    /// ```
1039    pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
1040        let dim = dim.expect_dim_index(D);
1041        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1042
1043        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
1044
1045        let tensor = Tensor::new(tensor);
1046        let index = Tensor::new(index);
1047
1048        (tensor, index)
1049    }
1050
1051    /// Finds the maximum pair wise values with another tensor.
1052    ///
1053    /// # Arguments
1054    ///
1055    /// * `other` - Other tensor to find maximum elements with
1056    ///
1057    /// # Returns
1058    ///
1059    /// A tensor with the same shape as the input tensors containing the maximum value found
1060    /// between each element of the two input tensors.
1061    ///
1062    /// # Example
1063    ///
1064    /// ```rust
1065    /// use burn_tensor::backend::Backend;
1066    /// use burn_tensor::{Tensor, Shape};
1067    ///
1068    /// fn example<B: Backend>() {
1069    ///    let device = B::Device::default();
1070    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1071    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1072    ///    let tensor = tensor1.max_pair(tensor2);
1073    ///    println!("{tensor}");
1074    ///    // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
1075    /// }
1076    /// ```
1077    pub fn max_pair(self, other: Self) -> Self {
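        // Where `self` is lower than `other`, take the value from `other` (the larger of the pair).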
1078        let mask = self.clone().lower(other.clone());
1079        self.mask_where(mask, other)
1080    }
1081
1082    /// Find the maximum absolute value.
1083    ///
1084    /// # Example
1085    ///
1086    /// ```rust
1087    /// use burn_tensor::backend::Backend;
1088    /// use burn_tensor::{Tensor, Shape};
1089    ///
1090    /// fn example<B: Backend>() {
1091    ///   let device = B::Device::default();
1092    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
1093    ///   let tensor = tensor.max_abs();
1094    ///   println!("{tensor}");
1095    ///   // [7.0]
1096    /// }
1097    /// ```
1098    pub fn max_abs(self) -> Tensor<B, 1, K> {
1099        Tensor::new(K::max_abs(self.primitive))
1100    }
1101
1102    /// Find the maximum absolute value along the given dimension.
1103    ///
1104    /// # Arguments
1105    ///
1106    /// * `dim` - The dimension or axis along which to aggregate the elements,
1107    ///   supports negative indexing.
1108    ///
1109    /// # Returns
1110    ///
1111    /// The returned tensor will have the same rank,
1112    /// but the aggregated dimension will have size 1.
1113    ///
1114    /// # Example
1115    ///
1116    /// ```rust
1117    /// use burn_tensor::backend::Backend;
1118    /// use burn_tensor::{Tensor, Shape};
1119    ///
1120    /// fn example<B: Backend>() {
1121    ///   let device = B::Device::default();
1122    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1123    ///   let tensor = tensor.max_abs_dim(0);
1124    ///   println!("{tensor}");
1125    ///   // [[5.0, 9.0, 6.0]]
1126    /// }
1127    /// ```
1128    pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
1129        let dim = dim.expect_dim_index(D);
1130        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
1131
1132        Tensor::new(K::max_abs_dim(self.primitive, dim))
1133    }
1134
1135    /// Find the maximum absolute value along the given dimensions.
1136    ///
1137    /// # Arguments
1138    ///
1139    /// * `dims` - The dimensions or axes along which to aggregate the elements,
1140    ///   supports negative indexing.
1141    ///
1142    /// # Returns
1143    ///
1144    /// The returned tensor will have the same rank,
1145    /// but the aggregated dimensions will have size 1.
1146    ///
1147    /// # Example
1148    ///
1149    /// ```rust
1150    /// use burn_tensor::backend::Backend;
1151    /// use burn_tensor::{Tensor, Shape};
1152    ///
1153    /// fn example<B: Backend>() {
1154    ///   let device = B::Device::default();
1155    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1156    ///   let tensor = tensor.max_abs_dims(&[0, 1]);
1157    ///   println!("{tensor}");
1158    ///   // [[9.0]]
1159    /// }
1160    /// ```
1161    pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1162        dims.iter()
1163            .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
1164    }
1165
1166    /// Applies the argmin function along the given dimension and returns an integer tensor.
1167    ///
1168    /// # Example
1169    ///
1170    /// ```rust
1171    /// use burn_tensor::backend::Backend;
1172    /// use burn_tensor::{Tensor, Shape};
1173    ///
1174    /// fn example<B: Backend>() {
1175    ///     let device = Default::default();
1176    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1177    ///     let tensor = tensor.argmin(1);
1178    ///     println!("{:?}", tensor.shape());
1179    ///     // Shape { dims: [2, 1, 3] }
1180    /// }
1181    /// ```
1182    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
1183        Tensor::new(K::argmin(self.primitive, dim))
1184    }
1185
1186    /// Find the minimum value.
1187    ///
1188    /// # Example
1189    ///
1190    /// ```rust
1191    /// use burn_tensor::backend::Backend;
1192    /// use burn_tensor::{Tensor, Shape};
1193    ///
1194    /// fn example<B: Backend>() {
1195    ///    let device = B::Device::default();
1196    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1197    ///    let tensor = tensor.min();
1198    ///    println!("{tensor}");
1199    ///    // [-2.0]
1200    /// }
1201    /// ```
1202    pub fn min(self) -> Tensor<B, 1, K> {
1203        Tensor::new(K::min(self.primitive))
1204    }
1205
1206    /// Find the minimum value along the given dimension.
1207    ///
1208    /// # Arguments
1209    ///
1210    /// * `dim` - The dimension or axis along which to aggregate the elements;
1211    ///   supports negative indexing.
1212    ///
1213    /// # Returns
1214    ///
1215    /// The returned tensor will have the same rank,
1216    /// but the aggregated dimension will have size 1.
1217    ///
1218    /// # Example
1219    ///
1220    /// ```rust
1221    /// use burn_tensor::backend::Backend;
1222    /// use burn_tensor::{Tensor, Shape};
1223    ///
1224    /// fn example<B: Backend>() {
1225    ///    let device = B::Device::default();
1226    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1227    ///    let tensor = tensor.min_dim(0);
1228    ///    println!("{tensor}");
1229    ///    // [[1.0, -2.0, 3.0]]
1230    /// }
1231    /// ```
1232    pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
1233        let dim = dim.expect_dim_index(D);
1234        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1235        Tensor::new(K::min_dim(self.primitive, dim))
1236    }
1237
1238    /// Find the minimum value along the given dimensions.
1239    ///
1240    /// # Arguments
1241    ///
1242    /// * `dims` - The dimensions or axes along which to aggregate the elements;
1243    ///   supports negative indexing.
1244    ///
1245    /// # Returns
1246    ///
1247    /// The returned tensor will have the same rank,
1248    /// but the aggregated dimensions will have size 1.
1249    ///
1250    /// # Example
1251    ///
1252    /// ```rust
1253    /// use burn_tensor::backend::Backend;
1254    /// use burn_tensor::{Tensor, Shape};
1255    ///
1256    /// fn example<B: Backend>() {
1257    ///   let device = B::Device::default();
1258    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1259    ///   let tensor = tensor.min_dims(&[0, 1]);
1260    ///   println!("{tensor}");
1261    ///   // [[-2.0]]
1262    /// }
1263    /// ```
1264    pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1265        dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
1266    }
1267
1268    /// Find the minimum value along the given dimension.
1269    ///
1270    /// Also returns the indices.
1271    ///
1272    /// # Example
1273    ///
1274    /// ```rust
1275    /// use burn_tensor::backend::Backend;
1276    /// use burn_tensor::{Tensor, Shape};
1277    ///
1278    /// fn example<B: Backend>() {
1279    ///    let device = B::Device::default();
1280    ///    let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1281    ///    let (tensor, index) = tensor.min_dim_with_indices(0);
1282    ///    println!("{tensor}");
1283    ///    // [[5.0, -2.0, 3.0]]
1284    ///    println!("{}", index);
1285    ///    // [[1, 0, 0]]
1286    /// }
1287    /// ```
1288    pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
1289        let dim = dim.expect_dim_index(D);
1290        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1291
1292        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
1293
1294        let tensor = Tensor::new(tensor);
1295        let index = Tensor::new(index);
1296
1297        (tensor, index)
1298    }
1299
1300    /// Finds the minimum pair wise values with another tensor.
1301    ///
1302    /// # Arguments
1303    ///
1304    /// * `other` - Other tensor to find minimum elements with
1305    ///
1306    /// # Returns
1307    ///
1308    /// A tensor with the same shape as the input tensors containing the minimum value found
1309    /// between each element of the two source tensors.
1310    ///
1311    /// # Example
1312    ///
1313    /// ```rust
1314    /// use burn_tensor::backend::Backend;
1315    /// use burn_tensor::{Tensor, Shape};
1316    ///
1317    /// fn example<B: Backend>() {
1318    ///    let device = B::Device::default();
1319    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1320    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1321    ///    let tensor = tensor1.min_pair(tensor2);
1322    ///    println!("{tensor}");
1323    ///    // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
1324    /// }
    /// ```
1325    pub fn min_pair(self, other: Self) -> Self {
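        // Where `other` is lower than `self`, take the value from `other` (the smaller of the pair).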
1326        let mask = other.clone().lower(self.clone());
1327        self.mask_where(mask, other)
1328    }
1329
1330    /// Clamp element wise between the given min and max values.
1331    ///
1332    /// # Arguments
1333    ///
1334    /// * `min` - The minimum value.
1335    /// * `max` - The maximum value.
1336    ///
1337    /// # Returns
1338    ///
1339    /// A new tensor with the values clamped between the given min and max values.
1340    ///
1341    /// # Example
1342    ///
1343    /// ```rust
1344    /// use burn_tensor::backend::Backend;
1345    /// use burn_tensor::{Int, Tensor};
1346    ///
1347    /// fn example<B: Backend>() {
1348    ///   let device = Default::default();
1349    ///   let tensor = Tensor::<B, 2, Int>::from_ints(
1350    ///    [
1351    ///     [1, 2, 3],
1352    ///     [4, 5, 6],
1353    ///     [7, 8, 9]
1354    ///    ],
1355    ///    &device);
1356    ///    let tensor = tensor.clamp(2, 6);
1357    ///    println!("{tensor}");
1358    ///    // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
1359    /// }
1360    /// ```
1361    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
1362        Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
1363    }
1364
1365    /// Clamp element wise under a minimum value.
1366    ///
1367    /// # Arguments
1368    ///
1369    /// * `tensor` - The tensor to clamp.
1370    /// * `min` - The minimum value.
1371    ///
1372    /// # Returns
1373    ///
1374    /// A new tensor with the values clamped under the given min value.
1375    ///
1376    /// # Example
1377    ///
1378    /// ```rust
1379    /// use burn_tensor::backend::Backend;
1380    /// use burn_tensor::{Int, Tensor};
1381    ///
1382    /// fn example<B: Backend>() {
1383    ///    let device = Default::default();
1384    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1385    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1386    ///    &device);
1387    ///    let tensor = tensor.clamp_min(4);
1388    ///    println!("{tensor}");
1389    ///    // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1390    /// }
1391    /// ```
1392    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1393        Self::new(K::clamp_min(self.primitive, min.elem()))
1394    }
1395
1396    /// Clamp element wise over a maximum value.
1397    ///
1398    /// # Arguments
1399    ///
1400    /// * `tensor` - The tensor to clamp.
1401    /// * `max` - The maximum value.
1402    ///
1403    /// # Returns
1404    ///
1405    /// A new tensor with the values clamped over the given max value.
1406    ///
1407    /// # Example
1408    ///
1409    /// ```rust
1410    /// use burn_tensor::backend::Backend;
1411    /// use burn_tensor::{Int, Tensor};
1412    ///
1413    /// fn example<B: Backend>() {
1414    ///    let device = Default::default();
1415    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1416    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1417    ///    &device);
1418    ///    let tensor = tensor.clamp_max(5);
1419    ///    println!("{tensor}");
1420    ///    // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1421    /// }
1422    /// ```
1423    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1424        Self::new(K::clamp_max(self.primitive, max.elem()))
1425    }
1426
1427    /// Apply element wise absolute value operation.
1428    ///
1429    /// # Example
1430    ///
1431    /// ```rust
1432    /// use burn_tensor::backend::Backend;
1433    /// use burn_tensor::{Int, Tensor};
1434    ///
1435    /// fn example<B: Backend>() {
1436    ///   let device = Default::default();
1437    ///   let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
1438    ///   let tensor = tensor.abs();
1439    ///   println!("{tensor}");
1440    ///   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1441    /// }
1442    /// ```
1443    pub fn abs(self) -> Self {
1444        Self::new(K::abs(self.primitive))
1445    }
1446
1447    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices;
1448    /// the other elements of the result tensor are set to 0.
1449    ///
1450    /// See also [`triu_mask`](Tensor::triu_mask).
1451    ///
1452    /// # Arguments
1453    ///
1454    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1455    ///   towards the upper triangle.
1456    ///
1457    /// # Example
1458    /// ```rust
1459    /// use burn_tensor::backend::Backend;
1460    /// use burn_tensor::{Int, Tensor};
1461    ///
1462    /// fn example<B: Backend>() {
1463    ///    let device = Default::default();
1464    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1465    ///        [
1466    ///          [1, 2, 3],
1467    ///          [4, 5, 6],
1468    ///          [7, 8, 9]
1469    ///        ],
1470    ///        &device
1471    ///    );
1472    ///    let tensor = tensor.triu(1);
1473    ///    println!("{tensor}");
1474    ///    // [
1475    ///    //   [0, 2, 3],
1476    ///    //   [0, 0, 6],
1477    ///    //   [0, 0, 0]
1478    ///    // ]
1479    /// }
1480    /// ```
1481    pub fn triu(self, diagonal: i64) -> Self {
1482        check!(TensorCheck::tri::<{ D }>());
1483
1484        // last two dimensions
1485        let shape = &self.shape().dims[D - 2..].to_owned();
1486
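        // The mask marks the elements to zero out; `unsqueeze` lets it broadcast over any leading batch dims.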
1487        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
1488        self.mask_fill(mask, 0)
1489    }
1490
1491    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices;
1492    /// the other elements of the result tensor are set to 0.
1493    ///
1494    /// See also [`tril_mask`](Tensor::tril_mask).
1495    ///
1496    /// # Arguments
1497    ///
1498    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1499    ///   towards the upper triangle.
1500    ///
1501    /// # Example
1502    /// ```rust
1503    /// use burn_tensor::backend::Backend;
1504    /// use burn_tensor::{Int, Tensor};
1505    ///
1506    /// fn example<B: Backend>() {
1507    ///    let device = Default::default();
1508    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1509    ///        [
1510    ///          [1, 2, 3],
1511    ///          [4, 5, 6],
1512    ///          [7, 8, 9]
1513    ///        ],
1514    ///        &device
1515    ///    );
1516    ///
1517    ///    let tensor = tensor.tril(-1);
1518    ///    println!("{tensor}");
1519    ///    // [
1520    ///    //   [0, 0, 0],
1521    ///    //   [4, 0, 0],
1522    ///    //   [7, 8, 0]
1523    ///    // ]
1524    /// }
1525    /// ```
1526    pub fn tril(self, diagonal: i64) -> Self {
1527        check!(TensorCheck::tri::<{ D }>());
1528
1529        // last two dimensions
1530        let shape = &self.shape().dims[D - 2..].to_owned();
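        // The mask marks the elements to zero out; `unsqueeze` lets it broadcast over any leading batch dims.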
1531        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
1532
1533        self.mask_fill(mask, 0)
1534    }
1535
1536    /// Applies element wise power operation with a float tensor.
1537    ///
1538    /// # Arguments
1539    ///
1540    /// * `other` - The tensor to apply the power operation with.
1541    ///
1542    /// # Example
1543    ///
1544    /// ```rust
1545    /// use burn_tensor::backend::Backend;
1546    /// use burn_tensor::{Tensor, Shape};
1547    ///
1548    /// fn example<B: Backend>() {
1549    ///    let device = B::Device::default();
1550    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1551    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1552    ///    let tensor = tensor1.powf(tensor2);
1553    ///    println!("{tensor}");
1554    ///    // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
1555    /// }
1556    /// ```
1557    pub fn powf(self, other: Self) -> Self {
1558        Self::new(K::powf(self.primitive, other.primitive))
1559    }
1560
1561    /// Applies element wise power operation with a float scalar.
1562    ///
1563    /// # Arguments
1564    ///
1565    /// * `other` - The scalar to apply the power operation with.
1566    ///
1567    /// # Example
1568    ///
1569    /// ```rust
1570    /// use burn_tensor::backend::Backend;
1571    /// use burn_tensor::{Tensor, Shape};
1572    ///
1573    /// fn example<B: Backend>() {
1574    ///    let device = B::Device::default();
1575    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1576    ///    let tensor = tensor.powf_scalar(2.0);
1577    ///    println!("{tensor}");
1578    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
1579    /// }
1580    /// ```
1581    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
1582        Self::new(K::powf_scalar::<E>(self.primitive, other))
1583    }
1584
1585    /// Applies element wise power operation with an integer tensor.
1586    ///
1587    /// # Arguments
1588    ///
1589    /// * `other` - The tensor to apply the power operation with.
1590    ///
1591    /// # Example
1592    ///
1593    /// ```rust
1594    /// use burn_tensor::backend::Backend;
1595    /// use burn_tensor::{Tensor, Shape, Int};
1596    ///
1597    /// fn example<B: Backend>() {
1598    ///    let device = B::Device::default();
1599    ///    let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1600    ///    let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
1601    ///    let tensor = tensor1.powi(tensor2);
1602    ///    println!("{tensor}");
1603    ///    // [[1, -8, 81], [5, 81, 216]]
1604    /// }
1605    /// ```
1606    pub fn powi(self, other: Self) -> Self {
1607        Self::new(K::powi(self.primitive, other.primitive))
1608    }
1609
1610    /// Applies element wise power operation with an integer scalar.
1611    ///
1612    /// # Arguments
1613    ///
1614    /// * `other` - The scalar to apply the power operation with.
1615    ///
1616    /// # Example
1617    ///
1618    /// ```rust
1619    /// use burn_tensor::backend::Backend;
1620    /// use burn_tensor::{Tensor, Shape, Int};
1621    ///
1622    /// fn example<B: Backend>() {
1623    ///    let device = B::Device::default();
1624    ///    let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1625    ///    let tensor = tensor.powi_scalar(2);
1626    ///    println!("{tensor}");
1627    ///    // [[1, 4, 9], [25, 81, 36]]
1628    ///
1629    ///    let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
1630    ///    let tensor = tensor.powi_scalar(2);
1631    ///    println!("{tensor}");
1632    ///    // [[2.25, 4., 9.], [25., 81., 36.]]
1633    /// }
1634    /// ```
1635    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
1636        Self::new(K::powi_scalar::<E>(self.primitive, other))
1637    }
1638
1639    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
1640    ///
1641    /// # Returns
1642    ///
1643    /// A boolean tensor with the same shape as the input tensor.
1644    ///
1645    /// # Example
1646    ///
1647    /// ```rust
1648    /// use burn_tensor::backend::Backend;
1649    /// use burn_tensor::{Tensor, Shape};
1650    ///
1651    /// fn example<B: Backend>() {
1652    ///   let device = B::Device::default();
1653    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
1654    ///   let tensor = tensor.bool();
1655    ///   println!("{tensor}");
1656    ///   // [
1657    ///   //   [true, true, true],
1658    ///   //   [false, true, true]
1659    ///   // ]
1660    /// }
    /// ```
1661    pub fn bool(self) -> Tensor<B, D, Bool> {
1662        Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
1663    }
1664
1665    /// Create a random tensor of the given shape on the given device where each element is
1666    /// sampled from the given distribution.
1667    ///
1668    /// See also [`random_like`](Tensor::random_like).
1669    ///
1670    /// # Arguments
1671    ///
1672    /// * `shape` - The shape of the tensor.
1673    /// * `distribution` - The distribution to sample from.
1674    /// * `device` - The device to create the tensor on.
1675    ///
1676    /// # Returns
1677    ///
1678    /// A new tensor with the given shape and elements sampled from the given distribution.
1679    ///
1680    /// # Example
1681    ///
1682    /// ```rust
1683    /// use burn_tensor::backend::Backend;
1684    /// use burn_tensor::{Tensor, Shape, Distribution};
1685    ///
1686    /// fn example<B: Backend>() {
1687    ///   let device = B::Device::default();
1688    ///   let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
1689    ///   let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
1690    ///   println!("{tensor}");
1691    ///   // [
1692    ///   //   [0.08347523, 0.70498955, 0.60332155],
1693    ///   //   [0.08173251, 0.18028641, 0.97942924]
1694    ///   // ]
1695    /// }
1696    /// ```
1697    pub fn random<S: Into<Shape>>(
1698        shape: S,
1699        distribution: Distribution,
1700        device: &B::Device,
1701    ) -> Self {
1702        Self::new(K::random(shape.into(), distribution, device))
1703    }
1704
1705    /// Sort the elements by value in ascending order along a given dimension.
1706    ///
1707    /// This sort is unstable (i.e., may reorder equal elements).
1708    ///
1709    /// # Arguments
1710    ///
1711    /// * `dim` - The dimension to sort along.
1712    ///
1713    /// # Returns
1714    ///
1715    /// A new tensor with the elements sorted in ascending order along the given dimension.
1716    ///
1717    /// # Example
1718    ///
1719    /// ```rust
1720    /// use burn_tensor::backend::Backend;
1721    /// use burn_tensor::{Tensor, Shape};
1722    ///
1723    /// fn example<B: Backend>() {
1724    ///   let device = B::Device::default();
1725    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1726    ///   let sorted = tensor.clone().sort(0);
1727    ///   println!("{sorted}");
1728    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1729    ///   let sorted = tensor.sort(1);
1730    ///   println!("{sorted}");
1731    ///   // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
1732    /// }
1733    /// ```
1734    pub fn sort(self, dim: usize) -> Self {
1735        check!(TensorCheck::sort_dim::<D>("Sort", dim));
1736        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
1737    }
1738
1739    /// Sort the elements by value in descending order along a given dimension.
1740    ///
1741    /// This sort is unstable (i.e., may reorder equal elements).
1742    ///
1743    /// # Arguments
1744    ///
1745    /// * `dim` - The dimension to sort along.
1746    ///
1747    /// # Returns
1748    ///
1749    /// A new tensor with the elements sorted in descending order along the given dimension.
1750    ///
1751    /// # Example
1752    ///
1753    /// ```rust
1754    /// use burn_tensor::backend::Backend;
1755    /// use burn_tensor::{Tensor, Shape};
1756    ///
1757    /// fn example<B: Backend>() {
1758    ///    let device = B::Device::default();
1759    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1760    ///    let sorted = tensor.clone().sort_descending(0);
1761    ///    println!("{sorted}");
1762    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1763    ///    let sorted = tensor.sort_descending(1);
1764    ///    println!("{sorted}");
1765    ///    // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
1766    /// }
1767    /// ```
1768    pub fn sort_descending(self, dim: usize) -> Self {
1769        check!(TensorCheck::sort_dim::<D>("Sort", dim));
1770        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
1771    }
1772
1773    /// Sort the elements by value in ascending order along a given dimension.
1774    /// Also returns the indices.
1775    ///
1776    /// This sort is unstable (i.e., may reorder equal elements).
1777    ///
1778    /// # Arguments
1779    ///
1780    /// * `dim` - The dimension to sort along.
1781    ///
1782    /// # Returns
1783    ///
1784    /// A tuple containing the sorted tensor and the indices tensor.
1785    ///
1786    /// # Example
1787    ///
1788    /// ```rust
1789    /// use burn_tensor::backend::Backend;
1790    /// use burn_tensor::{Tensor, Shape};
1791    ///
1792    /// fn example<B: Backend>() {
1793    ///   let device = B::Device::default();
1794    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1795    ///   let (tensor, indices) = tensor.sort_with_indices(0);
1796    ///   println!("{tensor}");
1797    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1798    ///   println!("{}", indices);
1799    ///   // [[1, 0, 0], [0, 1, 1]]
1800    /// }
1801    /// ```
1802    pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
1803        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1804        let (values, indices) =
1805            K::sort_with_indices(self.primitive, dim, /*descending*/ false);
1806        (Tensor::new(values), Tensor::new(indices))
1807    }
1808
1809    /// Sort the elements by value in descending order along a given dimension.
1810    /// Also returns the indices.
1811    ///
1812    /// This sort is unstable (i.e., may reorder equal elements).
1813    ///
1814    /// # Arguments
1815    ///
1816    /// * `dim` - The dimension to sort along.
1817    ///
1818    /// # Example
1819    ///
1820    /// ```rust
1821    /// use burn_tensor::backend::Backend;
1822    /// use burn_tensor::{Tensor, Shape};
1823    ///
1824    /// fn example<B: Backend>() {
1825    ///    let device = B::Device::default();
1826    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1827    ///    let (tensor, indices) = tensor.sort_descending_with_indices(0);
1828    ///    println!("{tensor}");
1829    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1830    ///    println!("{}", indices);
1831    ///    // [[0, 1, 1], [1, 0, 0]]
1832    /// }
1833    /// ```
1834    pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
1835        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1836        let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
1837        (Tensor::new(values), Tensor::new(indices))
1838    }
1839
1840    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
1841    ///
1842    /// This sort is unstable (i.e., may reorder equal elements).
1843    ///
1844    /// # Arguments
1845    ///
1846    /// * `dim` - The dimension to sort along.
1847    ///
1848    /// # Example
1849    ///
1850    /// ```rust
1851    /// use burn_tensor::backend::Backend;
1852    /// use burn_tensor::{Tensor, Shape};
1853    ///
1854    /// fn example<B: Backend>() {
1855    ///    let device = B::Device::default();
1856    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1857    ///    let tensor = tensor.argsort(0);
1858    ///    println!("{tensor}");
1859    ///    // [[1, 0, 0], [0, 1, 1]]
1860    /// }
1861    /// ```
1862    pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
1863        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1864        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
1865    }
1866
1867    /// Returns the indices that sort the elements by value in descending order along a given dimension.
1868    ///
1869    /// This sort is unstable (i.e., may reorder equal elements).
1870    ///
1871    /// # Arguments
1872    ///
1873    /// * `dim` - The dimension to sort along.
1874    ///
1875    /// # Example
1876    ///
1877    /// ```rust
1878    /// use burn_tensor::backend::Backend;
1879    /// use burn_tensor::{Tensor, Shape};
1880    ///
1881    /// fn example<B: Backend>() {
1882    ///    let device = B::Device::default();
1883    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1884    ///    let indices = tensor.clone().argsort_descending(0);
1885    ///    println!("{indices}");
1886    ///    // [[0, 1, 1], [1, 0, 0]]
1887    ///    let indices = tensor.argsort_descending(1);
1888    ///    println!("{indices}");
1889    ///    // [[0, 2, 1], [2, 0, 1]]
1890    /// }
1891    /// ```
1892    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
1893        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1894        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
1895    }
1896
1897    /// Returns the `k` largest elements of the given input tensor along a given dimension.
1898    ///
1899    /// # Arguments
1900    ///
1901    /// * `k` - The number of elements to return.
    /// * `dim` - The dimension to select the largest elements along.
1902    ///
1903    /// # Returns
1904    ///
1905    /// A new tensor with the `k` largest elements along the given dimension.
1906    ///
1907    /// # Example
1908    ///
1909    /// ```rust
1910    /// use burn_tensor::backend::Backend;
1911    /// use burn_tensor::{Tensor, Shape};
1912    ///
1913    /// fn example<B: Backend>() {
1914    ///   let device = B::Device::default();
1915    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1916    ///   let values = tensor.clone().topk(2, 0);
1917    ///   println!("{values}");
1918    ///   // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1919    ///   let values = tensor.topk(1, 1);
1920    ///   println!("{values}");
1921    ///   // [[12.0], [6.0]]
1922    /// }
1923    /// ```
1924    pub fn topk(self, k: usize, dim: usize) -> Self {
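        // Top-k: sort in descending order along `dim`, then select the first `k` entries.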
1925        let k_indices = Tensor::arange(0..k as i64, &self.device());
1926        self.sort_descending(dim).select(dim, k_indices)
1927    }
1928
1929    /// Returns the `k` largest elements of the given input tensor along a given dimension.
1930    /// Also returns the indices.
1931    ///
1932    /// # Arguments
1933    ///
1934    /// * `k` - The number of elements to return.
1935    /// * `dim` - The dimension to sort along.
1936    ///
1937    /// # Example
1938    ///
1939    /// ```rust
1940    /// use burn_tensor::backend::Backend;
1941    /// use burn_tensor::{Tensor, Shape};
1942    ///
1943    /// fn example<B: Backend>() {
1944    ///    let device = B::Device::default();
1945    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1946    ///    let (values, indices) = tensor.clone().topk_with_indices(2, 0);
1947    ///    println!("{values}");
1948    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1949    ///    println!("{indices}");
1950    ///    // [[0, 1, 1], [1, 0, 0]]
1951    ///    let (values, indices) = tensor.topk_with_indices(1, 1);
1952    ///    println!("{values}");
1953    ///    // [[12.0], [6.0]]
1954    ///    println!("{indices}");
1955    ///    // [[0], [2]]
1956    /// }
1957    /// ```
1958    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
1959        let k_indices = Tensor::arange(0..k as i64, &self.device());
1960        let (values, indices) = self.sort_descending_with_indices(dim);
1961        (
1962            values.select(dim, k_indices.clone()),
1963            indices.select(dim, k_indices),
1964        )
1965    }
1966
1967    /// Create a one-hot tensor: a trailing dimension of size `num_classes` is added, with a `1`
    /// at the position given by each index value and `0` everywhere else.
    ///
    /// # Arguments
    ///
    /// * `num_classes` - The number of classes, i.e. the size of the added one-hot dimension.
1968    ///
1969    /// # Example
1970    ///
1971    /// ```rust
1972    /// use burn_tensor::backend::Backend;
1973    /// use burn_tensor::Tensor;
1974    ///
1975    /// fn example<B: Backend>(){
1976    ///     let device = Default::default();
1977    ///     let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
1978    ///     let one_hot: Tensor<B, 2> = indices.one_hot(4);
1979    ///     println!("{}", one_hot.to_data());
1980    ///     // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
1981    /// }
1982    /// ```
1983    pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
1984        check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
1985        self.one_hot_fill(num_classes, 1.0, 0.0, -1)
1986    }
1987
1988    /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis`, supporting tensors of any rank.
1989    ///
1990    /// # Arguments
1991    ///
1992    /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
1993    /// * `on_value`: The value to assign for active positions (corresponding to indices).
1994    /// * `off_value`: The value to assign for inactive positions.
1995    /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
1996    ///
1997    /// # Returns
1998    ///
1999    /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
2000    ///
2001    /// # Example
    ///
2002    /// ```rust
2003    /// use burn_tensor::backend::Backend;
2004    /// use burn_tensor::{Tensor, Float};
    ///
2005    /// fn example<B: Backend>() {
2006    ///     let device = B::Device::default();
2007    ///     let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
2008    ///     // One-hot encoding
2009    ///     let tensor: Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0, 0.0, -1);
2010    ///     println!("{tensor}");
2011    ///     // [[[5.0, 0.0, 0.0],
2012    ///     //   [0.0, 0.0, 5.0]],
2013    ///     //  [[0.0, 5.0, 0.0],
2014    ///     //   [0.0, 0.0, 5.0]]]
2015    /// }
2016    /// ```
2017    pub fn one_hot_fill<const D2: usize>(
2018        self,
2019        num_classes: usize,
2020        on_value: f32,
2021        off_value: f32,
2022        axis: i64,
2023    ) -> Tensor<B, D2, K> {
2024        check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
2025        // Initialize shape from the current tensor dimensions and prepare for modification
2026        let mut shape = self.shape();
2027        let device = self.device();
2028        let rank = self.dims().len();
2029
2030        // Adjust negative axis to a positive index
2031        let axis = if axis < 0 {
2032            axis + rank as i64 + 1
2033        } else {
2034            axis
2035        };
2036
2037        // Ensure axis is within valid range
2038        if axis < 0 || axis > rank as i64 {
2039            panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
2040        }
2041        // Convert the input tensor to integer indices
2042        let indices: Tensor<B, D, Int> =
2043            Tensor::from_data(self.to_data().convert::<i64>(), &device);
2044        // Insert the new dimension for the one-hot representation
2045        shape.insert(axis as usize, num_classes);
2046        // Adjust indices to valid range and handle invalid indices
2047        let adjusted_indices = indices
2048            .clone()
2049            .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
2050            .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
2051        // Unsqueeze the indices tensor along the specified axis
2052        let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
2053
2054        // Initialize the output tensor with the off_value
2055        let output = Tensor::full(shape.clone(), off_value, &device);
2056
2057        // Prepare scatter tensor for on_value and off_value adjustments
2058        let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
2059            - Tensor::full(indices_unsqueezed.shape(), off_value, &device);
2060
2061        // Scatter on_value at the appropriate indices to create the one-hot representation
2062        output.scatter(
2063            axis as usize,
2064            indices_unsqueezed,
2065            scatter_on_values,
2066            IndexingUpdateOp::Add,
2067        )
2068    }
2069
2070    /// Applies the matrix multiplication operation.
2071    ///
2072    /// ```math
2073    /// C = AB
2074    /// ```
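    ///
    /// # Arguments
    ///
    /// * `other` - The right-hand tensor; its shape must be compatible with `self` for matrix multiplication.
    ///
    /// # Example
    ///
    /// A minimal usage sketch; the expected product is shown as a comment.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);
    ///    let tensor = tensor1.matmul(tensor2);
    ///    println!("{tensor}");
    ///    // [[19.0, 22.0], [43.0, 50.0]]
    /// }
    /// ```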
2075    pub fn matmul(self, other: Self) -> Self {
2076        check!(TensorCheck::matmul(&self, &other));
2077        Tensor::new(K::matmul(self.primitive, other.primitive))
2078    }
2079}
2080
2081impl<B, K> Tensor<B, 1, K>
2082where
2083    B: Backend,
2084    K: Numeric<B>,
2085    K::Elem: Element,
2086{
2087    /// Calculates the dot product with another tensor.
2088    ///
2089    /// `y = x2.dot(x1)`
2090    ///
2091    /// # Arguments
2092    ///
2093    /// * `other` - The tensor to compute dot product with.
2094    ///
2095    /// # Notes
2096    ///
2097    /// Both tensors must have the same number of elements.
2098    ///
2099    /// # Example
2100    ///
2101    /// ```rust
2102    /// use burn_tensor::backend::Backend;
2103    /// use burn_tensor::{Tensor, Shape};
2104    ///
2105    /// fn example<B: Backend>() {
2106    ///    let device = B::Device::default();
2107    ///    let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
2108    ///    let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
2109    ///    let tensor = tensor1.dot(tensor2);
2110    ///    println!("{tensor}");
2111    ///    // [4.0]
2112    /// }
2113    /// ```
2114    pub fn dot(self, other: Self) -> Self {
2115        self.mul(other).sum()
2116    }
2117}
2118
2119impl<B, K> Tensor<B, 2, K>
2120where
2121    B: Backend,
2122    K: Numeric<B>,
2123    K::Elem: Element,
2124{
2125    /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2126    ///
2127    /// # Arguments
2128    ///
2129    /// * `size` - The size of the square matrix.
    /// * `device` - The device to create the tensor on.
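    ///
    /// # Example
    ///
    /// A minimal usage sketch; the expected identity matrix is shown as a comment.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::eye(3, &device);
    ///    println!("{tensor}");
    ///    // [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    /// }
    /// ```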
2130    pub fn eye(size: usize, device: &B::Device) -> Self {
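        // Build the identity by scattering a `[1, size]` row of ones onto the diagonal
        // positions of a `[size, size]` zero matrix.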
2131        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2132        let ones = Self::ones([1, size], device);
2133        let zeros = Self::zeros([size, size], device);
2134
2135        zeros.scatter(0, indices, ones, IndexingUpdateOp::Add)
2136    }
2137}
2138
2139// Tensor + tensor
2140impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
2141where
2142    K::Elem: Element,
2143{
2144    type Output = Self;
2145
2146    fn add(self, rhs: Self) -> Self::Output {
2147        Self::add(self, rhs)
2148    }
2149}
2150
2151// Tensor + scalar
2152impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
2153    for Tensor<B, D, K>
2154where
2155    K::Elem: Element,
2156{
2157    type Output = Self;
2158
2159    fn add(self, other: E) -> Self::Output {
2160        Tensor::add_scalar(self, other)
2161    }
2162}
2163
2164// Scalar + tensor
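// For each primitive scalar type, `scalar + tensor` delegates to `Tensor::add_scalar`,
// so it is equivalent to `tensor + scalar`.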
2165macro_rules! impl_tensor_scalar_add {
2166    ($($t:ty),*) => {
2167        $(
2168            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
2169            where
2170                K::Elem: Element,
2171            {
2172                type Output = Tensor<B, D, K>;
2173
2174                fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
2175                    Tensor::add_scalar(tensor, self)
2176                }
2177            }
2178        )*
2179    }
2180}
2181impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);
2182
2183// Tensor - tensor
2184impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
2185where
2186    K::Elem: Element,
2187{
2188    type Output = Self;
2189
2190    fn sub(self, rhs: Self) -> Self::Output {
2191        Tensor::sub(self, rhs)
2192    }
2193}
2194
2195// Tensor - scalar
2196impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
2197    for Tensor<B, D, K>
2198where
2199    K::Elem: Element,
2200{
2201    type Output = Self;
2202
2203    fn sub(self, other: E) -> Self::Output {
2204        Tensor::sub_scalar(self, other)
2205    }
2206}
2207
2208// Scalar - tensor
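// `scalar - tensor` is computed as `(-tensor) + scalar` via `neg` and `add_scalar`.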
2209macro_rules! impl_tensor_scalar_sub {
2210    ($($t:ty),*) => {
2211        $(
2212            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
2213            where
2214                K::Elem: Element,
2215            {
2216                type Output = Tensor<B, D, K>;
2217
2218                fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
2219                    Tensor::add_scalar(Tensor::neg(tensor), self)
2220                }
2221            }
2222        )*
2223    }
2224}
2225impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);
2226
2227// Tensor / tensor
2228impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
2229where
2230    K::Elem: Element,
2231{
2232    type Output = Self;
2233
2234    fn div(self, rhs: Self) -> Self::Output {
2235        Tensor::div(self, rhs)
2236    }
2237}
2238
2239// Tensor / scalar
2240impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
2241    for Tensor<B, D, K>
2242where
2243    K::Elem: Element,
2244{
2245    type Output = Self;
2246
2247    fn div(self, other: E) -> Self::Output {
2248        Tensor::div_scalar(self, other)
2249    }
2250}
2251
2252// Scalar / tensor (float only)
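// `scalar / tensor` is computed as `tensor.recip() * scalar`, which is why it is only
// implemented for float tensors.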
2253macro_rules! impl_tensor_scalar_div {
2254    ($($t:ty),*) => {
2255        $(
2256            impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t
2257            {
2258                type Output = Tensor<B, D>;
2259
2260                fn div(self, tensor: Tensor<B, D>) -> Self::Output {
2261                    tensor.recip().mul_scalar(self)
2262                }
2263            }
2264        )*
2265    }
2266}
2267
2268impl_tensor_scalar_div!(f32, f64);
2269
2270// Tensor % tensor.
2271impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
2272where
2273    K::Elem: Element,
2274{
2275    type Output = Self;
2276
2277    fn rem(self, rhs: Self) -> Self::Output {
2278        Tensor::remainder(self, rhs)
2279    }
2280}
2281
2282// Tensor % scalar.
2283impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
2284    for Tensor<B, D, K>
2285where
2286    K::Elem: Element,
2287{
2288    type Output = Self;
2289
2290    fn rem(self, other: E) -> Self::Output {
2291        Tensor::remainder_scalar(self, other)
2292    }
2293}
2294
2295// Tensor * tensor.
2296impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
2297where
2298    K::Elem: Element,
2299{
2300    type Output = Self;
2301
2302    fn mul(self, rhs: Self) -> Self::Output {
2303        Tensor::mul(self, rhs)
2304    }
2305}
2306
2307// Tensor * scalar.
2308impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
2309    for Tensor<B, D, K>
2310where
2311    K::Elem: Element,
2312{
2313    type Output = Self;
2314
2315    fn mul(self, other: E) -> Self::Output {
2316        Tensor::mul_scalar(self, other)
2317    }
2318}
2319
2320macro_rules! impl_tensor_scalar_mul {
2321    ($($t:ty),*) => {
2322        $(
2323            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
2324            where
2325                K::Elem: Element,
2326            {
2327                type Output = Tensor<B, D, K>;
2328
2329                fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
2330                    Tensor::mul_scalar(other, self)
2331                }
2332            }
2333        )*
2334    }
2335}
2336
2337impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);
2338
2339impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
2340where
2341    B: Backend,
2342    K: Numeric<B>,
2343    K::Elem: Element,
2344{
2345    type Output = Self;
2346
2347    fn neg(self) -> Self::Output {
2348        Tensor::neg(self)
2349    }
2350}