Skip to main content

burn_tensor/tensor/api/
orderable.rs

1use burn_backend::{
2    Backend, ElementConversion, Scalar,
3    tensor::{Bool, IndexingUpdateOp, Int, Ordered},
4};
5use burn_std::AsIndex;
6
7use crate::check;
8use crate::{Tensor, check::TensorCheck};
9
10impl<B, const D: usize, K> Tensor<B, D, K>
11where
12    B: Backend,
13    K: Ordered<B>,
14{
15    /// Sort the elements by value in ascending order along a given dimension.
16    ///
17    /// This sort is unstable (i.e., may reorder equal elements).
18    ///
19    /// # Arguments
20    ///
21    /// * `dim` - The dimension to sort along.
22    ///
23    /// # Returns
24    ///
25    /// A new tensor with the elements sorted in ascending order along the given dimension.
26    ///
27    /// # Example
28    ///
29    /// ```rust
30    /// use burn_tensor::backend::Backend;
31    /// use burn_tensor::{Tensor, Shape};
32    ///
33    /// fn example<B: Backend>() {
34    ///   let device = B::Device::default();
35    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
36    ///   let tensor = tensor.sort(0);
37    ///   println!("{tensor}");
38    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
39    ///   let tensor = tensor.sort(1);
40    ///   println!("{tensor}");
41    ///   // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
42    /// }
43    /// ```
44    pub fn sort(self, dim: usize) -> Self {
45        check!(TensorCheck::sort_dim::<D>("Sort", dim));
46        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
47    }
48
49    /// Sort the elements by value in descending order along a given dimension.
50    ///
51    /// This sort is unstable (i.e., may reorder equal elements).
52    ///
53    /// # Arguments
54    ///
55    /// * `dim` - The dimension to sort along.
56    ///
57    /// # Returns
58    ///
59    /// A new tensor with the elements sorted in descending order along the given dimension.
60    ///
61    /// # Example
62    ///
63    /// ```rust
64    /// use burn_tensor::backend::Backend;
65    /// use burn_tensor::{Tensor, Shape};
66    ///
67    /// fn example<B: Backend>() {
68    ///    let device = B::Device::default();
69    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
70    ///    let tensor = tensor.sort_descending(0);
71    ///    println!("{tensor}");
72    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
73    ///    let tensor = tensor.sort_descending(1);
74    ///    println!("{tensor}");
75    ///    // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
76    /// }
77    /// ```
78    pub fn sort_descending(self, dim: usize) -> Self {
79        check!(TensorCheck::sort_dim::<D>("Sort", dim));
80        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
81    }
82
83    /// Sort the elements by value in ascending order along a given dimension.
84    /// Also returns the indices.
85    ///
86    /// This sort is unstable (i.e., may reorder equal elements).
87    ///
88    /// # Arguments
89    ///
90    /// * `dim` - The dimension to sort along.
91    ///
92    /// # Returns
93    ///
94    /// A tuple containing the sorted tensor and the indices tensor.
95    ///
96    /// # Example
97    ///
98    /// ```rust
99    /// use burn_tensor::backend::Backend;
100    /// use burn_tensor::{Tensor, Shape};
101    ///
102    /// fn example<B: Backend>() {
103    ///   let device = B::Device::default();
104    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
105    ///   let (tensor, indices) = tensor.sort_with_indices(0);
106    ///   println!("{tensor}");
107    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
108    ///   println!("{}", indices);
109    ///   // [[1, 0, 0], [0, 1, 1]]
110    /// }
111    /// ```
112    pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
113        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
114        let (values, indices) =
115            K::sort_with_indices(self.primitive, dim, /*descending*/ false);
116        (Tensor::new(values), Tensor::new(indices))
117    }
118
119    /// Sort the elements by value in descending order along a given dimension.
120    /// Also returns the indices.
121    ///
122    /// This sort is unstable (i.e., may reorder equal elements).
123    ///
124    /// # Arguments
125    ///
126    /// * `dim` - The dimension to sort along.
127    ///
128    /// # Example
129    ///
130    /// ```rust
131    /// use burn_tensor::backend::Backend;
132    /// use burn_tensor::{Tensor, Shape};
133    ///
134    /// fn example<B: Backend>() {
135    ///    let device = B::Device::default();
136    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
137    ///    let (tensor, indices) = tensor.sort_descending_with_indices(0);
138    ///    println!("{tensor}");
139    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
140    ///    println!("{}", indices);
141    ///    // [[0, 1, 1], [1, 0, 0]]
142    /// }
143    /// ```
144    pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
145        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
146        let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
147        (Tensor::new(values), Tensor::new(indices))
148    }
149
150    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
151    ///
152    /// This sort is unstable (i.e., may reorder equal elements).
153    ///
154    /// # Arguments
155    ///
156    /// * `dim` - The dimension to sort along.
157    ///
158    /// # Example
159    ///
160    /// ```rust
161    /// use burn_tensor::backend::Backend;
162    /// use burn_tensor::{Tensor, Shape};
163    ///
164    /// fn example<B: Backend>() {
165    ///    let device = B::Device::default();
166    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
167    ///    let tensor = tensor.argsort(0);
168    ///    println!("{tensor}");
169    ///    // [[1, 0, 0], [0, 1, 1]]
170    /// }
171    /// ```
172    pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
173        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
174        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
175    }
176
177    /// Returns the indices that sort the elements by value in descending order along a given dimension.
178    ///
179    /// This sort is unstable (i.e., may reorder equal elements).
180    ///
181    /// # Arguments
182    ///
183    /// * `dim` - The dimension to sort along.
184    ///
185    /// # Example
186    ///
187    /// ```rust
188    /// use burn_tensor::backend::Backend;
189    /// use burn_tensor::{Tensor, Shape};
190    ///
191    /// fn example<B: Backend>() {
192    ///    let device = B::Device::default();
193    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
194    ///    let tensor = tensor.argsort_descending(0);
195    ///    println!("{tensor}");
196    ///    // [[0, 1, 1], [1, 0, 0]]
197    ///    let tensor = tensor.argsort_descending(1);
198    ///    println!("{tensor}");
199    ///    // [[0, 2, 1], [2, 0, 1]]
200    /// }
201    /// ```
202    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
203        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
204        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
205    }
206
207    /// Returns the `k` largest elements of the given input tensor along a given dimension.
208    ///
209    /// # Arguments
210    ///
211    /// * `k` - The number of elements to return.
212    ///
213    /// # Returns
214    ///
215    /// A new tensor with the `k` largest elements along the given dimension.
216    ///
217    /// # Example
218    ///
219    /// ```rust
220    /// use burn_tensor::backend::Backend;
221    /// use burn_tensor::{Tensor, Shape};
222    ///
223    /// fn example<B: Backend>() {
224    ///   let device = B::Device::default();
225    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
226    ///   let tensor = tensor.topk(2, 0);
227    ///   println!("{tensor}");
228    ///   // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
229    ///   let tensor = tensor.topk(1, 1);
230    ///   println!("{tensor}");
231    ///   // [[12.0], [6.0]]
232    /// }
233    /// ```
234    pub fn topk(self, k: usize, dim: usize) -> Self {
235        let k_indices = Tensor::arange(0..k as i64, &self.device());
236        self.sort_descending(dim).select(dim, k_indices)
237    }
238
239    /// Returns the `k` largest elements of the given input tensor along a given dimension.
240    /// Also returns the indices.
241    ///
242    /// # Arguments
243    ///
244    /// * `k` - The number of elements to return.
245    /// * `dim` - The dimension to sort along.
246    ///
247    /// # Example
248    ///
249    /// ```rust
250    /// use burn_tensor::backend::Backend;
251    /// use burn_tensor::{Tensor, Shape};
252    ///
253    /// fn example<B: Backend>() {
254    ///    let device = B::Device::default();
255    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
256    ///    let (tensor, indices) = tensor.topk_with_indices(2, 0);
257    ///    println!("{tensor}");
258    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
259    ///    println!("{}", indices);
260    ///    // [[0, 1, 1], [1, 0, 0]]
261    ///    let (tensor, indices) = tensor.topk_with_indices(1, 1);
262    ///    println!("{tensor}");
263    ///    // [[12.0], [6.0]]
264    ///    println!("{indices}");
265    ///    // [[0], [2]]
266    /// }
267    /// ```
268    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
269        let k_indices = Tensor::arange(0..k as i64, &self.device());
270        let (values, indices) = self.sort_descending_with_indices(dim);
271        (
272            values.select(dim, k_indices.clone()),
273            indices.select(dim, k_indices),
274        )
275    }
276
277    /// Create a one hot tensor.
278    ///
279    /// # Example
280    ///
281    /// ```rust
282    /// use burn_tensor::backend::Backend;
283    /// use burn_tensor::Tensor;
284    ///
285    /// fn example<B: Backend>(){
286    ///     let device = Default::default();
287    ///     let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
288    ///     let one_hot: Tensor<B, 2> = indices.one_hot(4);
289    ///     println!("{}", one_hot.to_data());
290    ///     // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
291    /// }
292    /// ```
293    pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
294        check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
295        self.one_hot_fill(num_classes, 1.0, 0.0, -1)
296    }
297
298    /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.
299    ///
300    /// # Arguments
301    ///
302    /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
303    /// * `on_value`: The value to assign for active positions (corresponding to indices).
304    /// * `off_value`: The value to assign for inactive positions.
305    /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
306    ///
307    /// # Returns
308    ///
309    /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
310    ///
311    /// # Example
312    /// ```rust
313    /// use burn_tensor::backend::Backend;
314    /// use burn_tensor::{Tensor, Float};
315    /// fn example<B: Backend<FloatElem: From<f32>>>() {
316    ///     let device = B::Device::default();
317    ///     let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
318    ///     // One-hot encoding
319    ///     let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
320    ///     println!("{tensor}");
321    ///     // [[[5.0, 0.0, 0.0],
322    ///     // [0.0, 0.0, 5.0]],
323    ///     // [[0.0, 5.0, 0.0],
324    ///     // [0.0, 0.0, 5.0]]]
325    /// }
326    /// ```
327    pub fn one_hot_fill<const D2: usize>(
328        self,
329        num_classes: usize,
330        on_value: f32,
331        off_value: f32,
332        axis: i64,
333    ) -> Tensor<B, D2, K> {
334        check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
335        // Initialize shape from the current tensor dimensions and prepare for modification
336        let mut shape = self.shape();
337        let device = self.device();
338        let rank = self.dims().len();
339
340        // Adjust negative axis to a positive index
341        let axis = if axis < 0 {
342            axis + rank as i64 + 1
343        } else {
344            axis
345        };
346
347        // Ensure axis is within valid range
348        if axis < 0 || axis > rank as i64 {
349            panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
350        }
351        // Convert the input tensor to integer indices
352        let indices: Tensor<B, D, Int> =
353            Tensor::from_data(self.to_data().convert::<i64>(), &device);
354        // Insert the new dimension for the one-hot representation
355        shape.insert(axis as usize, num_classes);
356        // Adjust indices to valid range and handle invalid indices
357        let adjusted_indices = indices
358            .clone()
359            .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
360            .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
361        // Unsqueeze the indices tensor along the specified axis
362        let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
363
364        // Initialize the output tensor with the off_value
365        let output = Tensor::full(shape.clone(), off_value, &device);
366
367        // Prepare scatter tensor for on_value and off_value adjustments
368        let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
369            - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());
370
371        // Scatter on_value at the appropriate indices to create the one-hot representation
372        output.scatter(
373            axis as usize,
374            indices_unsqueezed,
375            scatter_on_values,
376            IndexingUpdateOp::Add,
377        )
378    }
379
380    /// Applies element wise greater comparison and returns a boolean tensor.
381    ///
382    /// # Panics
383    ///
384    /// If the two tensors don't have the same shape.
385    ///
386    /// # Example
387    ///
388    /// ```rust
389    /// use burn_tensor::backend::Backend;
390    /// use burn_tensor::{Tensor, Shape};
391    ///
392    /// fn example<B: Backend>() {
393    ///   let device = B::Device::default();
394    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
395    ///   let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
396    ///   let tensor = tensor1.greater(tensor2);
397    ///   println!("{tensor}");
398    ///   // [[false, false, false], [true, true, true]]
399    /// }
400    /// ```
401    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
402        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
403        Tensor::new(K::greater(self.primitive, other.primitive))
404    }
405
406    /// Applies element wise greater-equal comparison and returns a boolean tensor.
407    ///
408    /// # Panics
409    ///
410    /// If the two tensors don't have the same shape.
411    ///
412    /// # Example
413    ///
414    /// ```rust
415    /// use burn_tensor::backend::Backend;
416    /// use burn_tensor::{Tensor, Shape};
417    ///
418    /// fn example<B: Backend>() {
419    ///    let device = B::Device::default();
420    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
421    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
422    ///    let tensor = tensor1.greater_equal(tensor2);
423    ///    println!("{tensor}");
424    ///    // [[true, false, false], [true, true, true]]
425    /// }
426    /// ```
427    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
428        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
429        Tensor::new(K::greater_equal(self.primitive, other.primitive))
430    }
431
432    /// Applies element wise lower comparison and returns a boolean tensor.
433    ///
434    /// # Panics
435    ///
436    /// If the two tensors don't have the same shape.
437    ///
438    /// # Example
439    ///
440    /// ```rust
441    /// use burn_tensor::backend::Backend;
442    /// use burn_tensor::{Tensor, Shape};
443    ///
444    /// fn example<B: Backend>() {
445    ///    let device = B::Device::default();
446    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
447    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
448    ///    let tensor = tensor1.lower(tensor2);
449    ///    println!("{tensor}");
450    ///    // [[false, true, true], [false, false, false]]
451    /// }
452    /// ```
453    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
454        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
455        Tensor::new(K::lower(self.primitive, other.primitive))
456    }
457
458    /// Applies element wise lower-equal comparison and returns a boolean tensor.
459    ///
460    /// # Panics
461    ///
462    /// If the two tensors don't have the same shape.
463    ///
464    /// # Example
465    ///
466    /// ```rust
467    /// use burn_tensor::backend::Backend;
468    /// use burn_tensor::{Tensor, Shape};
469    ///
470    /// fn example<B: Backend>() {
471    ///    let device = B::Device::default();
472    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
473    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
474    ///    let tensor = tensor1.lower_equal(tensor2);
475    ///    println!("{tensor}");
476    ///    // [[true, true, true], [false, false, false]]
477    /// }
478    /// ```
479    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
480        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
481        Tensor::new(K::lower_equal(self.primitive, other.primitive))
482    }
483
484    /// Applies greater than `other` comparison and returns a boolean tensor.
485    ///
486    /// # Arguments
487    ///
488    /// * `other` - The element to compare.
489    ///
490    /// # Example
491    ///
492    /// ```rust
493    /// use burn_tensor::backend::Backend;
494    /// use burn_tensor::{Tensor, Shape};
495    ///
496    /// fn example<B: Backend>() {
497    ///    let device = B::Device::default();
498    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
499    ///    let tensor = tensor.greater_elem(3.0);
500    ///    println!("{tensor}");
501    ///    // [[false, false, true], [true, true, true]]
502    /// }
503    /// ```
504    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
505        let other = Scalar::new(other, &self.dtype());
506        Tensor::new(K::greater_elem(self.primitive, other))
507    }
508
509    /// Applies greater-equal than `other` comparison and returns a boolean tensor.
510    ///
511    /// # Arguments
512    ///
513    /// * `other` - The element to compare.
514    ///
515    /// # Example
516    ///
517    /// ```rust
518    /// use burn_tensor::backend::Backend;
519    /// use burn_tensor::{Tensor, Shape};
520    ///
521    /// fn example<B: Backend>() {
522    ///    let device = B::Device::default();
523    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
524    ///    let tensor = tensor.greater_equal_elem(3.0);
525    ///    println!("{tensor}");
526    ///    // [[false, false, true], [true, true, true]]
527    /// }
528    /// ```
529    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
530        let other = Scalar::new(other, &self.dtype());
531        Tensor::new(K::greater_equal_elem(self.primitive, other))
532    }
533
534    /// Applies lower than `other` comparison and returns a boolean tensor.
535    ///
536    /// # Arguments
537    ///
538    /// * `other` - The element to compare.
539    ///
540    /// # Example
541    ///
542    /// ```rust
543    /// use burn_tensor::backend::Backend;
544    /// use burn_tensor::{Tensor, Shape};
545    ///
546    /// fn example<B: Backend>() {
547    ///     let device = B::Device::default();
548    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
549    ///     let tensor = tensor.lower_elem(3.0);
550    ///     println!("{tensor}");
551    ///     // [[true, true, false], [false, false, false]]
552    /// }
553    /// ```
554    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
555        let other = Scalar::new(other, &self.dtype());
556        Tensor::new(K::lower_elem(self.primitive, other))
557    }
558
559    /// Applies lower-equal than `other` comparison and returns a boolean tensor.
560    ///
561    /// # Arguments
562    ///
563    /// * `other` - The element to compare.
564    ///
565    /// # Example
566    ///
567    /// ```rust
568    /// use burn_tensor::backend::Backend;
569    /// use burn_tensor::{Tensor, Shape};
570    ///
571    /// fn example<B: Backend>() {
572    ///    let device = B::Device::default();
573    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
574    ///    let tensor = tensor.lower_equal_elem(3.0);
575    ///    println!("{tensor}");
576    ///    // [[true, true, true], [false, false, false]]
577    /// }
578    /// ```
579    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
580        let other = Scalar::new(other, &self.dtype());
581        Tensor::new(K::lower_equal_elem(self.primitive, other))
582    }
583
584    /// Applies the argmax function along the given dimension and returns an integer tensor.
585    ///
586    /// # Example
587    ///
588    /// ```rust
589    /// use burn_tensor::backend::Backend;
590    /// use burn_tensor::{Tensor, Shape};
591    ///
592    /// fn example<B: Backend>() {
593    ///     let device = B::Device::default();
594    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
595    ///     let tensor = tensor.argmax(1);
596    ///     println!("{:?}", tensor.shape());
597    ///     // Shape { dims: [2, 1, 3] }
598    /// }
599    /// ```
600    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
601        Tensor::new(K::argmax(self.primitive, dim))
602    }
603
604    /// Find the maximum value.
605    ///
606    /// # Example
607    ///
608    /// ```rust
609    /// use burn_tensor::backend::Backend;
610    /// use burn_tensor::{Tensor, Shape};
611    ///
612    /// fn example<B: Backend>() {
613    ///   let device = B::Device::default();
614    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
615    ///   let tensor = tensor.max();
616    ///   println!("{tensor}");
617    ///   // [9.0]
618    /// }
619    /// ```
620    pub fn max(self) -> Tensor<B, 1, K> {
621        Tensor::new(K::max(self.primitive))
622    }
623
624    /// Find the maximum value along the given dimension.
625    ///
626    /// Also returns the indices.
627    ///
628    /// # Example
629    ///
630    /// ```rust
631    /// use burn_tensor::backend::Backend;
632    /// use burn_tensor::{Tensor, Shape};
633    ///
634    /// fn example<B: Backend>() {
635    ///    let device = B::Device::default();
636    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
637    ///    let (tensor, index) = tensor.max_dim_with_indices(0);
638    ///    // [[5.0, 9.0, 6.0]]
639    ///    println!("{tensor}");
640    ///    // [[1, 1, 1]]
641    ///    println!("{index}");
642    /// }
643    /// ```
644    pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
645        let dim = dim.expect_dim_index(D);
646        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
647
648        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
649
650        let tensor = Tensor::new(tensor);
651        let index = Tensor::new(index);
652
653        (tensor, index)
654    }
655
656    /// Find the maximum absolute value.
657    ///
658    /// # Example
659    ///
660    /// ```rust
661    /// use burn_tensor::backend::Backend;
662    /// use burn_tensor::{Tensor, Shape};
663    ///
664    /// fn example<B: Backend>() {
665    ///   let device = B::Device::default();
666    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
667    ///   let tensor = tensor.max_abs();
668    ///   println!("{tensor}");
669    ///   // [7.0]
670    /// }
671    /// ```
672    pub fn max_abs(self) -> Tensor<B, 1, K> {
673        Tensor::new(K::max_abs(self.primitive))
674    }
675
676    /// Finds the maximum pair wise values with another tensor.
677    ///
678    /// # Arguments
679    ///
680    /// * `other` - Other tensor to find maximum elements with
681    ///
682    /// # Returns
683    ///
684    /// A tensor with the same shape as the input tensors containing the maximum value found
685    /// in the input tensors.
686    ///
687    /// # Example
688    ///
689    /// ```rust
690    /// use burn_tensor::backend::Backend;
691    /// use burn_tensor::{Tensor, Shape};
692    ///
693    /// fn example<B: Backend>() {
694    ///    let device = B::Device::default();
695    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
696    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
697    ///    let tensor = tensor1.max_pair(tensor2);
698    ///    println!("{tensor}");
699    ///    // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
700    /// }
701    /// ```
702    pub fn max_pair(self, other: Self) -> Self {
703        let mask = self.clone().lower(other.clone());
704        self.mask_where(mask, other)
705    }
706
707    /// Find the maximum absolute value along the given dimension.
708    ///
709    /// # Arguments
710    ///
711    /// * `dim` - The dimension or axis along which to aggregate the elements,
712    ///   supports negative indexing.
713    ///
714    /// # Returns
715    ///
716    /// The returned tensor will have the same rank,
717    /// but the aggregated dimension will have size 1.
718    ///
719    /// # Example
720    ///
721    /// ```rust
722    /// use burn_tensor::backend::Backend;
723    /// use burn_tensor::{Tensor, Shape};
724    ///
725    /// fn example<B: Backend>() {
726    ///   let device = B::Device::default();
727    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
728    ///   let tensor = tensor.max_dim(0);
729    ///   println!("{tensor}");
730    ///   // [[5.0, 9.0, 6.0]]
731    /// }
732    /// ```
733    pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
734        let dim = dim.expect_dim_index(D);
735        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
736
737        Tensor::new(K::max_abs_dim(self.primitive, dim))
738    }
739
740    /// Find the maximum absolute value along the given dimensions.
741    ///
742    /// # Arguments
743    ///
744    /// * `dims` - The dimensions or axes along which to aggregate the elements,
745    ///   supports negative indexing.
746    ///
747    /// # Returns
748    ///
749    /// The returned tensor will have the same rank,
750    /// but the aggregated dimensions will have size 1.
751    ///
752    /// # Example
753    ///
754    /// ```rust
755    /// use burn_tensor::backend::Backend;
756    /// use burn_tensor::{Tensor, Shape};
757    ///
758    /// fn example<B: Backend>() {
759    ///   let device = B::Device::default();
760    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
761    ///   let tensor = tensor.max_abs_dims(&[0, 1]);
762    ///   println!("{tensor}");
763    ///   // [[9.0]]
764    /// }
765    /// ```
766    pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
767        dims.iter()
768            .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
769    }
770
771    /// Applies the argmin function along the given dimension and returns an integer tensor.
772    ///
773    /// # Example
774    ///
775    /// ```rust
776    /// use burn_tensor::backend::Backend;
777    /// use burn_tensor::{Tensor, Shape};
778    ///
779    /// fn example<B: Backend>() {
780    ///     let device = Default::default();
781    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
782    ///     let tensor = tensor.argmin(1);
783    ///     println!("{:?}", tensor.shape());
784    ///     // Shape { dims: [2, 1, 3] }
785    /// }
786    /// ```
787    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
788        Tensor::new(K::argmin(self.primitive, dim))
789    }
790
791    /// Find the minimum value.
792    ///
793    /// # Example
794    ///
795    /// ```rust
796    /// use burn_tensor::backend::Backend;
797    /// use burn_tensor::{Tensor, Shape};
798    ///
799    /// fn example<B: Backend>() {
800    ///    let device = B::Device::default();
801    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
802    ///    let tensor = tensor.min();
803    ///    println!("{tensor}");
804    ///    // [-2.0]
805    /// }
806    /// ```
807    pub fn min(self) -> Tensor<B, 1, K> {
808        Tensor::new(K::min(self.primitive))
809    }
810
811    /// Find the minimum value along the given dimension.
812    ///
813    /// # Arguments
814    ///
815    /// * `dim` - The dimension or axis along which to aggregate the elements;
816    ///   supports negative indexing.
817    ///
818    /// # Returns
819    ///
820    /// The returned tensor will have the same rank,
821    /// but the aggregated dimension will have size 1.
822    ///
823    /// # Example
824    ///
825    /// ```rust
826    /// use burn_tensor::backend::Backend;
827    /// use burn_tensor::{Tensor, Shape};
828    ///
829    /// fn example<B: Backend>() {
830    ///    let device = B::Device::default();
831    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
832    ///    let tensor = tensor.min_dim(0);
833    ///    println!("{tensor}");
834    ///    // [[1.0, -2.0, 3.0]]
835    /// }
836    /// ```
837    pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
838        let dim = dim.expect_dim_index(D);
839        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
840        Tensor::new(K::min_dim(self.primitive, dim))
841    }
842
843    /// Find the minimum value along the given dimensions.
844    ///
845    /// # Arguments
846    ///
847    /// * `dims` - The dimensions or axes along which to aggregate the elements;
848    ///   supports negative indexing.
849    ///
850    /// # Returns
851    ///
852    /// The returned tensor will have the same rank,
853    /// but the aggregated dimensions will have size 1.
854    ///
855    /// # Example
856    ///
857    /// ```rust
858    /// use burn_tensor::backend::Backend;
859    /// use burn_tensor::{Tensor, Shape};
860    ///
861    /// fn example<B: Backend>() {
862    ///   let device = B::Device::default();
863    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
864    ///   let tensor = tensor.min_dims(&[0, 1]);
865    ///   println!("{tensor}");
866    ///   // [[-2.0]]
867    /// }
868    /// ```
869    pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
870        dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
871    }
872
873    /// Find the minimum value along the given dimension.
874    ///
875    /// Also returns the indices.
876    ///
877    /// # Example
878    ///
879    /// ```rust
880    /// use burn_tensor::backend::Backend;
881    /// use burn_tensor::{Tensor, Shape};
882    ///
883    /// fn example<B: Backend>() {
884    ///    let device = B::Device::default();
885    ///    let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
886    ///    let (tensor, index) = tensor.min_dim_with_indices(0);
887    ///    println!("{tensor}");
888    ///    // [[5.0, -2.0, 3.0]]
889    ///    println!("{}", index);
890    ///    // [[1, 0, 0]]
891    /// }
892    /// ```
893    pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
894        let dim = dim.expect_dim_index(D);
895        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
896
897        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
898
899        let tensor = Tensor::new(tensor);
900        let index = Tensor::new(index);
901
902        (tensor, index)
903    }
904
905    /// Finds the minimum pair wise values with another tensor.
906    ///
907    /// # Arguments
908    ///
909    /// * `other` - Other tensor to find minimum elements with
910    ///
911    /// # Returns
912    ///
913    /// A tensor with the same shape as the input tensors containing the minimum value found
914    /// between each element of the two source tensors.
915    ///
916    /// # Example
917    ///
918    /// ```rust
919    /// use burn_tensor::backend::Backend;
920    /// use burn_tensor::{Tensor, Shape};
921    ///
922    /// fn example<B: Backend>() {
923    ///    let device = B::Device::default();
924    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
925    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
926    ///    let tensor = tensor1.min_pair(tensor2);
927    ///    println!("{tensor}");
928    ///    // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
929    /// }
930    pub fn min_pair(self, other: Self) -> Self {
931        let mask = other.clone().lower(self.clone());
932        self.mask_where(mask, other)
933    }
934
935    /// Clamp element wise between the given min and max values.
936    ///
937    /// # Arguments
938    ///
939    /// * `min` - The minimum value.
940    /// * `max` - The maximum value.
941    ///
942    /// # Returns
943    ///
944    /// A new tensor with the values clamped between the given min and max values.
945    ///
946    /// # Example
947    ///
948    /// ```rust
949    /// use burn_tensor::backend::Backend;
950    /// use burn_tensor::{Int, Tensor};
951    ///
952    /// fn example<B: Backend>() {
953    ///   let device = Default::default();
954    ///   let tensor = Tensor::<B, 2, Int>::from_ints(
955    ///    [
956    ///     [1, 2, 3],
957    ///     [4, 5, 6],
958    ///     [7, 8, 9]
959    ///    ],
960    ///    &device);
961    ///    let tensor = tensor.clamp(2, 6);
962    ///    println!("{tensor}");
963    ///    // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
964    /// }
965    /// ```
966    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
967        let dtype = self.dtype();
968        Self::new(K::clamp(
969            self.primitive,
970            Scalar::new(min, &dtype),
971            Scalar::new(max, &dtype),
972        ))
973    }
974
975    /// Clamp element wise under a minimum value.
976    ///
977    /// # Arguments
978    ///
979    /// * `tensor` - The tensor to clamp.
980    /// * `min` - The minimum value.
981    ///
982    /// # Returns
983    ///
984    /// A new tensor with the values clamped under the given min value.
985    ///
986    /// # Example
987    ///
988    /// ```rust
989    /// use burn_tensor::backend::Backend;
990    /// use burn_tensor::{Int, Tensor};
991    ///
992    /// fn example<B: Backend>() {
993    ///    let device = Default::default();
994    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
995    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
996    ///    &device);
997    ///    let tensor = tensor.clamp_min(4);
998    ///    println!("{tensor}");
999    ///    // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1000    /// }
1001    /// ```
1002    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1003        let min = Scalar::new(min, &self.dtype());
1004        Self::new(K::clamp_min(self.primitive, min))
1005    }
1006
1007    /// Clamp element wise over a maximum value.
1008    ///
1009    /// # Arguments
1010    ///
1011    /// * `tensor` - The tensor to clamp.
1012    /// * `max` - The maximum value.
1013    ///
1014    /// # Returns
1015    ///
1016    /// A new tensor with the values clamped over the given max value.
1017    ///
1018    /// # Example
1019    ///
1020    /// ```rust
1021    /// use burn_tensor::backend::Backend;
1022    /// use burn_tensor::{Int, Tensor};
1023    ///
1024    /// fn example<B: Backend>() {
1025    ///    let device = Default::default();
1026    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1027    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1028    ///    &device);
1029    ///    let tensor = tensor.clamp_max(5);
1030    ///    println!("{tensor}");
1031    ///    // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1032    /// }
1033    /// ```
1034    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1035        let max = Scalar::new(max, &self.dtype());
1036        Self::new(K::clamp_max(self.primitive, max))
1037    }
1038
1039    /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
1040    ///
1041    /// # Arguments
1042    ///
1043    /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
1044    ///
1045    /// # Example
1046    ///
1047    /// ```rust
1048    /// use burn_tensor::backend::Backend;
1049    /// use burn_tensor::{Tensor, Shape};
1050    ///
1051    /// fn example<B: Backend>() {
1052    ///    let device = B::Device::default();
1053    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
1054    ///    let result = tensor.clone().cummin(0);
1055    ///    println!("{result}");
1056    ///    // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
1057    ///    let result = tensor.cummin(1);
1058    ///    println!("{result}");
1059    ///    // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
1060    /// }
1061    /// ```
1062    pub fn cummin(self, dim: usize) -> Self {
1063        check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
1064        Self::new(K::cummin(self.primitive, dim))
1065    }
1066
1067    /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
1068    ///
1069    /// # Arguments
1070    ///
1071    /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
1072    ///
1073    /// # Example
1074    ///
1075    /// ```rust
1076    /// use burn_tensor::backend::Backend;
1077    /// use burn_tensor::{Tensor, Shape};
1078    ///
1079    /// fn example<B: Backend>() {
1080    ///    let device = B::Device::default();
1081    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
1082    ///    let result = tensor.clone().cummax(0);
1083    ///    println!("{result}");
1084    ///    // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
1085    ///    let result = tensor.cummax(1);
1086    ///    println!("{result}");
1087    ///    // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
1088    /// }
1089    /// ```
1090    pub fn cummax(self, dim: usize) -> Self {
1091        check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
1092        Self::new(K::cummax(self.primitive, dim))
1093    }
1094    /// Find the maximum value along the given dimension.
1095    ///
1096    /// # Arguments
1097    ///
1098    /// * `dim` - The dimension or axis along which to aggregate the elements;
1099    ///   supports negative indexing.
1100    ///
1101    /// # Returns
1102    ///
1103    /// The returned tensor will have the same rank,
1104    /// but the aggregated dimension will have size 1.
1105    ///
1106    /// # Example
1107    ///
1108    /// ```rust
1109    /// use burn_tensor::backend::Backend;
1110    /// use burn_tensor::{Tensor, Shape};
1111    ///
1112    /// fn example<B: Backend>() {
1113    ///   let device = B::Device::default();
1114    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1115    ///   let tensor = tensor.max_dim(0);
1116    ///   println!("{tensor}");
1117    ///   // [[5.0, 9.0, 6.0]]
1118    /// }
1119    /// ```
1120    pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
1121        let dim = dim.expect_dim_index(D);
1122        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1123        Tensor::new(K::max_dim(self.primitive, dim))
1124    }
1125
1126    /// Find the maximum value along the given dimensions.
1127    ///
1128    /// # Arguments
1129    ///
1130    /// * `dims` - The dimensions or axis along which to aggregate the elements;
1131    ///   supports negative indexing.
1132    ///
1133    /// # Returns
1134    ///
1135    /// The returned tensor will have the same rank,
1136    /// but the aggregated dimensions will have size 1.
1137    ///
1138    /// # Example
1139    ///
1140    /// ```rust
1141    /// use burn_tensor::backend::Backend;
1142    /// use burn_tensor::{Tensor, Shape};
1143    ///
1144    /// fn example<B: Backend>() {
1145    ///   let device = B::Device::default();
1146    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1147    ///   let tensor = tensor.max_dims(&[0, 1]);
1148    ///   println!("{tensor}");
1149    ///   // [[9.0]]
1150    /// }
1151    /// ```
1152    pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1153        dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
1154    }
1155}