Skip to main content

burn_tensor/tensor/api/
orderable.rs

1use burn_backend::{
2    Backend, ElementConversion, Scalar,
3    tensor::{Bool, IndexingUpdateOp, Int, Ordered},
4};
5use burn_std::AsIndex;
6
7use crate::check;
8use crate::{Tensor, check::TensorCheck};
9
10impl<B, const D: usize, K> Tensor<B, D, K>
11where
12    B: Backend,
13    K: Ordered<B>,
14{
15    /// Sort the elements by value in ascending order along a given dimension.
16    ///
17    /// This sort is unstable (i.e., may reorder equal elements).
18    ///
19    /// # Arguments
20    ///
21    /// * `dim` - The dimension to sort along.
22    ///
23    /// # Returns
24    ///
25    /// A new tensor with the elements sorted in ascending order along the given dimension.
26    ///
27    /// # Example
28    ///
29    /// ```rust
30    /// use burn_tensor::backend::Backend;
31    /// use burn_tensor::{Tensor, Shape};
32    ///
33    /// fn example<B: Backend>() {
34    ///   let device = B::Device::default();
35    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
36    ///   let tensor = tensor.sort(0);
37    ///   println!("{tensor}");
38    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
39    ///   let tensor = tensor.sort(1);
40    ///   println!("{tensor}");
41    ///   // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
42    /// }
43    /// ```
44    pub fn sort(self, dim: usize) -> Self {
45        check!(TensorCheck::sort_dim::<D>("Sort", dim));
46        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
47    }
48
49    /// Sort the elements by value in descending order along a given dimension.
50    ///
51    /// This sort is unstable (i.e., may reorder equal elements).
52    ///
53    /// # Arguments
54    ///
55    /// * `dim` - The dimension to sort along.
56    ///
57    /// # Returns
58    ///
59    /// A new tensor with the elements sorted in descending order along the given dimension.
60    ///
61    /// # Example
62    ///
63    /// ```rust
64    /// use burn_tensor::backend::Backend;
65    /// use burn_tensor::{Tensor, Shape};
66    ///
67    /// fn example<B: Backend>() {
68    ///    let device = B::Device::default();
69    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
70    ///    let tensor = tensor.sort_descending(0);
71    ///    println!("{tensor}");
72    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
73    ///    let tensor = tensor.sort_descending(1);
74    ///    println!("{tensor}");
75    ///    // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
76    /// }
77    /// ```
78    pub fn sort_descending(self, dim: usize) -> Self {
79        check!(TensorCheck::sort_dim::<D>("Sort", dim));
80        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
81    }
82
83    /// Sort the elements by value in ascending order along a given dimension.
84    /// Also returns the indices.
85    ///
86    /// This sort is unstable (i.e., may reorder equal elements).
87    ///
88    /// # Arguments
89    ///
90    /// * `dim` - The dimension to sort along.
91    ///
92    /// # Returns
93    ///
94    /// A tuple containing the sorted tensor and the indices tensor.
95    ///
96    /// # Example
97    ///
98    /// ```rust
99    /// use burn_tensor::backend::Backend;
100    /// use burn_tensor::{Tensor, Shape};
101    ///
102    /// fn example<B: Backend>() {
103    ///   let device = B::Device::default();
104    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
105    ///   let (tensor, indices) = tensor.sort_with_indices(0);
106    ///   println!("{tensor}");
107    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
108    ///   println!("{}", indices);
109    ///   // [[1, 0, 0], [0, 1, 1]]
110    /// }
111    /// ```
112    pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
113        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
114        let (values, indices) =
115            K::sort_with_indices(self.primitive, dim, /*descending*/ false);
116        (Tensor::new(values), Tensor::new(indices))
117    }
118
119    /// Sort the elements by value in descending order along a given dimension.
120    /// Also returns the indices.
121    ///
122    /// This sort is unstable (i.e., may reorder equal elements).
123    ///
124    /// # Arguments
125    ///
126    /// * `dim` - The dimension to sort along.
127    ///
128    /// # Example
129    ///
130    /// ```rust
131    /// use burn_tensor::backend::Backend;
132    /// use burn_tensor::{Tensor, Shape};
133    ///
134    /// fn example<B: Backend>() {
135    ///    let device = B::Device::default();
136    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
137    ///    let (tensor, indices) = tensor.sort_descending_with_indices(0);
138    ///    println!("{tensor}");
139    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
140    ///    println!("{}", indices);
141    ///    // [[0, 1, 1], [1, 0, 0]]
142    /// }
143    /// ```
144    pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
145        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
146        let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
147        (Tensor::new(values), Tensor::new(indices))
148    }
149
150    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
151    ///
152    /// This sort is unstable (i.e., may reorder equal elements).
153    ///
154    /// # Arguments
155    ///
156    /// * `dim` - The dimension to sort along.
157    ///
158    /// # Example
159    ///
160    /// ```rust
161    /// use burn_tensor::backend::Backend;
162    /// use burn_tensor::{Tensor, Shape};
163    ///
164    /// fn example<B: Backend>() {
165    ///    let device = B::Device::default();
166    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
167    ///    let tensor = tensor.argsort(0);
168    ///    println!("{tensor}");
169    ///    // [[1, 0, 0], [0, 1, 1]]
170    /// }
171    /// ```
172    pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
173        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
174        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
175    }
176
177    /// Returns the indices that sort the elements by value in descending order along a given dimension.
178    ///
179    /// This sort is unstable (i.e., may reorder equal elements).
180    ///
181    /// # Arguments
182    ///
183    /// * `dim` - The dimension to sort along.
184    ///
185    /// # Example
186    ///
187    /// ```rust
188    /// use burn_tensor::backend::Backend;
189    /// use burn_tensor::{Tensor, Shape};
190    ///
191    /// fn example<B: Backend>() {
192    ///    let device = B::Device::default();
193    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
194    ///    let tensor = tensor.argsort_descending(0);
195    ///    println!("{tensor}");
196    ///    // [[0, 1, 1], [1, 0, 0]]
197    ///    let tensor = tensor.argsort_descending(1);
198    ///    println!("{tensor}");
199    ///    // [[0, 2, 1], [2, 0, 1]]
200    /// }
201    /// ```
202    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
203        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
204        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
205    }
206
207    /// Returns the `k` largest elements of the given input tensor along a given dimension.
208    ///
209    /// # Arguments
210    ///
211    /// * `k` - The number of elements to return.
212    ///
213    /// # Returns
214    ///
215    /// A new tensor with the `k` largest elements along the given dimension.
216    ///
217    /// # Example
218    ///
219    /// ```rust
220    /// use burn_tensor::backend::Backend;
221    /// use burn_tensor::{Tensor, Shape};
222    ///
223    /// fn example<B: Backend>() {
224    ///   let device = B::Device::default();
225    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
226    ///   let tensor = tensor.topk(2, 0);
227    ///   println!("{tensor}");
228    ///   // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
229    ///   let tensor = tensor.topk(1, 1);
230    ///   println!("{tensor}");
231    ///   // [[12.0], [6.0]]
232    /// }
233    /// ```
234    pub fn topk(self, k: usize, dim: usize) -> Self {
235        assert!(self.shape()[dim] > k);
236        Tensor::new(K::topk(self.primitive, dim, k))
237    }
238
239    /// Returns the `k` largest elements of the given input tensor along a given dimension.
240    /// Also returns the indices.
241    ///
242    /// # Arguments
243    ///
244    /// * `k` - The number of elements to return.
245    /// * `dim` - The dimension to sort along.
246    ///
247    /// # Example
248    ///
249    /// ```rust
250    /// use burn_tensor::backend::Backend;
251    /// use burn_tensor::{Tensor, Shape};
252    ///
253    /// fn example<B: Backend>() {
254    ///    let device = B::Device::default();
255    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
256    ///    let (tensor, indices) = tensor.topk_with_indices(2, 0);
257    ///    println!("{tensor}");
258    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
259    ///    println!("{}", indices);
260    ///    // [[0, 1, 1], [1, 0, 0]]
261    ///    let (tensor, indices) = tensor.topk_with_indices(1, 1);
262    ///    println!("{tensor}");
263    ///    // [[12.0], [6.0]]
264    ///    println!("{indices}");
265    ///    // [[0], [2]]
266    /// }
267    /// ```
268    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
269        let k_indices = Tensor::arange(0..k as i64, &self.device());
270        let (values, indices) = self.sort_descending_with_indices(dim);
271        (
272            values.select(dim, k_indices.clone()),
273            indices.select(dim, k_indices),
274        )
275    }
276
277    /// Create a one hot tensor.
278    ///
279    /// # Example
280    ///
281    /// ```rust
282    /// use burn_tensor::backend::Backend;
283    /// use burn_tensor::Tensor;
284    ///
285    /// fn example<B: Backend>(){
286    ///     let device = Default::default();
287    ///     let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
288    ///     let one_hot: Tensor<B, 2> = indices.one_hot(4);
289    ///     println!("{}", one_hot.to_data());
290    ///     // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
291    /// }
292    /// ```
293    pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
294        check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
295        self.one_hot_fill(num_classes, 1.0, 0.0, -1)
296    }
297
    /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.
    ///
    /// # Arguments
    ///
    /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
    /// * `on_value`: The value to assign for active positions (corresponding to indices).
    /// * `off_value`: The value to assign for inactive positions.
    /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
    ///
    /// # Returns
    ///
    /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
    ///
    /// # Panics
    ///
    /// If `axis` is outside `[-r-1, r]` where `r = rank(indices)`.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Float};
    /// fn example<B: Backend<FloatElem: From<f32>>>() {
    ///     let device = B::Device::default();
    ///     let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
    ///     // One-hot encoding
    ///     let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
    ///     println!("{tensor}");
    ///     // [[[5.0, 0.0, 0.0],
    ///     // [0.0, 0.0, 5.0]],
    ///     // [[0.0, 5.0, 0.0],
    ///     // [0.0, 0.0, 5.0]]]
    /// }
    /// ```
    pub fn one_hot_fill<const D2: usize>(
        self,
        num_classes: usize,
        on_value: f32,
        off_value: f32,
        axis: i64,
    ) -> Tensor<B, D2, K> {
        check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
        // Initialize shape from the current tensor dimensions and prepare for modification
        let mut shape = self.shape();
        let device = self.device();
        let rank = self.dims().len();

        // Adjust negative axis to a positive index. `+ 1` because the valid
        // insertion positions for the new axis span 0..=rank, not 0..rank.
        let axis = if axis < 0 {
            axis + rank as i64 + 1
        } else {
            axis
        };

        // Ensure axis is within valid range
        if axis < 0 || axis > rank as i64 {
            panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
        }
        // Convert the input tensor to integer indices (round-trips through host
        // data; the element values are reinterpreted as i64 class indices)
        let indices: Tensor<B, D, Int> =
            Tensor::from_data(self.to_data().convert::<i64>(), &device);
        // Insert the new dimension for the one-hot representation
        shape.insert(axis as usize, num_classes);
        // Adjust indices to valid range and handle invalid indices.
        // The two mask_fill terms sum to a wrap-around of negative indices:
        //   where self < 0:  num_classes + original (negative) index, e.g. -1 -> num_classes - 1
        //   where self > 0:  original index + 0
        //   where self == 0: 0 + 0
        let adjusted_indices = indices
            .clone()
            .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
            .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
        // Unsqueeze the indices tensor along the specified axis
        let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);

        // Initialize the output tensor with the off_value
        let output = Tensor::full(shape.clone(), off_value, &device);

        // Prepare scatter tensor holding (on_value - off_value): added on top of the
        // off_value-filled output below, active positions end up exactly at on_value
        let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
            - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());

        // Scatter on_value at the appropriate indices to create the one-hot representation
        output.scatter(
            axis as usize,
            indices_unsqueezed,
            scatter_on_values,
            IndexingUpdateOp::Add,
        )
    }
379
380    /// Applies element wise greater comparison and returns a boolean tensor.
381    ///
382    /// # Panics
383    ///
384    /// If the two tensors don't have the same shape.
385    ///
386    /// # Example
387    ///
388    /// ```rust
389    /// use burn_tensor::backend::Backend;
390    /// use burn_tensor::{Tensor, Shape};
391    ///
392    /// fn example<B: Backend>() {
393    ///   let device = B::Device::default();
394    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
395    ///   let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
396    ///   let tensor = tensor1.greater(tensor2);
397    ///   println!("{tensor}");
398    ///   // [[false, false, false], [true, true, true]]
399    /// }
400    /// ```
401    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
402        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
403        Tensor::new(K::greater(self.primitive, other.primitive))
404    }
405
406    /// Applies element wise greater-equal comparison and returns a boolean tensor.
407    ///
408    /// # Panics
409    ///
410    /// If the two tensors don't have the same shape.
411    ///
412    /// # Example
413    ///
414    /// ```rust
415    /// use burn_tensor::backend::Backend;
416    /// use burn_tensor::{Tensor, Shape};
417    ///
418    /// fn example<B: Backend>() {
419    ///    let device = B::Device::default();
420    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
421    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
422    ///    let tensor = tensor1.greater_equal(tensor2);
423    ///    println!("{tensor}");
424    ///    // [[true, false, false], [true, true, true]]
425    /// }
426    /// ```
427    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
428        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
429        Tensor::new(K::greater_equal(self.primitive, other.primitive))
430    }
431
432    /// Applies element wise lower comparison and returns a boolean tensor.
433    ///
434    /// # Panics
435    ///
436    /// If the two tensors don't have the same shape.
437    ///
438    /// # Example
439    ///
440    /// ```rust
441    /// use burn_tensor::backend::Backend;
442    /// use burn_tensor::{Tensor, Shape};
443    ///
444    /// fn example<B: Backend>() {
445    ///    let device = B::Device::default();
446    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
447    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
448    ///    let tensor = tensor1.lower(tensor2);
449    ///    println!("{tensor}");
450    ///    // [[false, true, true], [false, false, false]]
451    /// }
452    /// ```
453    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
454        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
455        Tensor::new(K::lower(self.primitive, other.primitive))
456    }
457
458    /// Applies element wise lower-equal comparison and returns a boolean tensor.
459    ///
460    /// # Panics
461    ///
462    /// If the two tensors don't have the same shape.
463    ///
464    /// # Example
465    ///
466    /// ```rust
467    /// use burn_tensor::backend::Backend;
468    /// use burn_tensor::{Tensor, Shape};
469    ///
470    /// fn example<B: Backend>() {
471    ///    let device = B::Device::default();
472    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
473    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
474    ///    let tensor = tensor1.lower_equal(tensor2);
475    ///    println!("{tensor}");
476    ///    // [[true, true, true], [false, false, false]]
477    /// }
478    /// ```
479    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
480        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
481        Tensor::new(K::lower_equal(self.primitive, other.primitive))
482    }
483
484    /// Applies greater than `other` comparison and returns a boolean tensor.
485    ///
486    /// # Arguments
487    ///
488    /// * `other` - The element to compare.
489    ///
490    /// # Example
491    ///
492    /// ```rust
493    /// use burn_tensor::backend::Backend;
494    /// use burn_tensor::{Tensor, Shape};
495    ///
496    /// fn example<B: Backend>() {
497    ///    let device = B::Device::default();
498    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
499    ///    let tensor = tensor.greater_elem(3.0);
500    ///    println!("{tensor}");
501    ///    // [[false, false, true], [true, true, true]]
502    /// }
503    /// ```
504    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
505        let other = Scalar::new(other, &self.dtype());
506        Tensor::new(K::greater_elem(self.primitive, other))
507    }
508
509    /// Applies greater-equal than `other` comparison and returns a boolean tensor.
510    ///
511    /// # Arguments
512    ///
513    /// * `other` - The element to compare.
514    ///
515    /// # Example
516    ///
517    /// ```rust
518    /// use burn_tensor::backend::Backend;
519    /// use burn_tensor::{Tensor, Shape};
520    ///
521    /// fn example<B: Backend>() {
522    ///    let device = B::Device::default();
523    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
524    ///    let tensor = tensor.greater_equal_elem(3.0);
525    ///    println!("{tensor}");
526    ///    // [[false, false, true], [true, true, true]]
527    /// }
528    /// ```
529    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
530        let other = Scalar::new(other, &self.dtype());
531        Tensor::new(K::greater_equal_elem(self.primitive, other))
532    }
533
534    /// Applies lower than `other` comparison and returns a boolean tensor.
535    ///
536    /// # Arguments
537    ///
538    /// * `other` - The element to compare.
539    ///
540    /// # Example
541    ///
542    /// ```rust
543    /// use burn_tensor::backend::Backend;
544    /// use burn_tensor::{Tensor, Shape};
545    ///
546    /// fn example<B: Backend>() {
547    ///     let device = B::Device::default();
548    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
549    ///     let tensor = tensor.lower_elem(3.0);
550    ///     println!("{tensor}");
551    ///     // [[true, true, false], [false, false, false]]
552    /// }
553    /// ```
554    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
555        let other = Scalar::new(other, &self.dtype());
556        Tensor::new(K::lower_elem(self.primitive, other))
557    }
558
559    /// Applies lower-equal than `other` comparison and returns a boolean tensor.
560    ///
561    /// # Arguments
562    ///
563    /// * `other` - The element to compare.
564    ///
565    /// # Example
566    ///
567    /// ```rust
568    /// use burn_tensor::backend::Backend;
569    /// use burn_tensor::{Tensor, Shape};
570    ///
571    /// fn example<B: Backend>() {
572    ///    let device = B::Device::default();
573    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
574    ///    let tensor = tensor.lower_equal_elem(3.0);
575    ///    println!("{tensor}");
576    ///    // [[true, true, true], [false, false, false]]
577    /// }
578    /// ```
579    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
580        let other = Scalar::new(other, &self.dtype());
581        Tensor::new(K::lower_equal_elem(self.primitive, other))
582    }
583
584    /// Applies the argmax function along the given dimension and returns an integer tensor.
585    ///
586    /// # Example
587    ///
588    /// ```rust
589    /// use burn_tensor::backend::Backend;
590    /// use burn_tensor::{Tensor, Shape};
591    ///
592    /// fn example<B: Backend>() {
593    ///     let device = B::Device::default();
594    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
595    ///     let tensor = tensor.argmax(1);
596    ///     println!("{:?}", tensor.shape());
597    ///     // Shape { dims: [2, 1, 3] }
598    /// }
599    /// ```
600    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
601        Tensor::new(K::argmax(self.primitive, dim))
602    }
603
604    /// Applies the argtopk function along the given dimension and returns an integer tensor.
605    ///
606    /// # Example
607    ///
608    /// ```rust
609    /// use burn_tensor::backend::Backend;
610    /// use burn_tensor::{Tensor, Shape};
611    ///
612    /// fn example<B: Backend>() {
613    ///     let device = B::Device::default();
614    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
615    ///     let tensor = tensor.argtopk(1, 2);
616    ///     println!("{:?}", tensor.shape());
617    /// }
618    /// ```
619    pub fn argtopk(self, k: usize, dim: usize) -> Tensor<B, D, Int> {
620        assert!(self.shape()[dim] > k);
621        Tensor::new(K::argtopk(self.primitive, dim, k))
622    }
623
624    /// Find the maximum value.
625    ///
626    /// # Example
627    ///
628    /// ```rust
629    /// use burn_tensor::backend::Backend;
630    /// use burn_tensor::{Tensor, Shape};
631    ///
632    /// fn example<B: Backend>() {
633    ///   let device = B::Device::default();
634    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
635    ///   let tensor = tensor.max();
636    ///   println!("{tensor}");
637    ///   // [9.0]
638    /// }
639    /// ```
640    pub fn max(self) -> Tensor<B, 1, K> {
641        Tensor::new(K::max(self.primitive))
642    }
643
644    /// Find the maximum value along the given dimension.
645    ///
646    /// Also returns the indices.
647    ///
648    /// # Example
649    ///
650    /// ```rust
651    /// use burn_tensor::backend::Backend;
652    /// use burn_tensor::{Tensor, Shape};
653    ///
654    /// fn example<B: Backend>() {
655    ///    let device = B::Device::default();
656    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
657    ///    let (tensor, index) = tensor.max_dim_with_indices(0);
658    ///    // [[5.0, 9.0, 6.0]]
659    ///    println!("{tensor}");
660    ///    // [[1, 1, 1]]
661    ///    println!("{index}");
662    /// }
663    /// ```
664    pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
665        let dim = dim.expect_dim_index(D);
666        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
667
668        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
669
670        let tensor = Tensor::new(tensor);
671        let index = Tensor::new(index);
672
673        (tensor, index)
674    }
675
676    /// Find the maximum absolute value.
677    ///
678    /// # Example
679    ///
680    /// ```rust
681    /// use burn_tensor::backend::Backend;
682    /// use burn_tensor::{Tensor, Shape};
683    ///
684    /// fn example<B: Backend>() {
685    ///   let device = B::Device::default();
686    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
687    ///   let tensor = tensor.max_abs();
688    ///   println!("{tensor}");
689    ///   // [7.0]
690    /// }
691    /// ```
692    pub fn max_abs(self) -> Tensor<B, 1, K> {
693        Tensor::new(K::max_abs(self.primitive))
694    }
695
696    /// Finds the maximum pair wise values with another tensor.
697    ///
698    /// # Arguments
699    ///
700    /// * `other` - Other tensor to find maximum elements with
701    ///
702    /// # Returns
703    ///
704    /// A tensor with the same shape as the input tensors containing the maximum value found
705    /// in the input tensors.
706    ///
707    /// # Example
708    ///
709    /// ```rust
710    /// use burn_tensor::backend::Backend;
711    /// use burn_tensor::{Tensor, Shape};
712    ///
713    /// fn example<B: Backend>() {
714    ///    let device = B::Device::default();
715    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
716    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
717    ///    let tensor = tensor1.max_pair(tensor2);
718    ///    println!("{tensor}");
719    ///    // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
720    /// }
721    /// ```
722    pub fn max_pair(self, other: Self) -> Self {
723        let mask = self.clone().lower(other.clone());
724        self.mask_where(mask, other)
725    }
726
727    /// Find the maximum absolute value along the given dimension.
728    ///
729    /// # Arguments
730    ///
731    /// * `dim` - The dimension or axis along which to aggregate the elements,
732    ///   supports negative indexing.
733    ///
734    /// # Returns
735    ///
736    /// The returned tensor will have the same rank,
737    /// but the aggregated dimension will have size 1.
738    ///
739    /// # Example
740    ///
741    /// ```rust
742    /// use burn_tensor::backend::Backend;
743    /// use burn_tensor::{Tensor, Shape};
744    ///
745    /// fn example<B: Backend>() {
746    ///   let device = B::Device::default();
747    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
748    ///   let tensor = tensor.max_dim(0);
749    ///   println!("{tensor}");
750    ///   // [[5.0, 9.0, 6.0]]
751    /// }
752    /// ```
753    pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
754        let dim = dim.expect_dim_index(D);
755        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
756
757        Tensor::new(K::max_abs_dim(self.primitive, dim))
758    }
759
760    /// Find the maximum absolute value along the given dimensions.
761    ///
762    /// # Arguments
763    ///
764    /// * `dims` - The dimensions or axes along which to aggregate the elements,
765    ///   supports negative indexing.
766    ///
767    /// # Returns
768    ///
769    /// The returned tensor will have the same rank,
770    /// but the aggregated dimensions will have size 1.
771    ///
772    /// # Example
773    ///
774    /// ```rust
775    /// use burn_tensor::backend::Backend;
776    /// use burn_tensor::{Tensor, Shape};
777    ///
778    /// fn example<B: Backend>() {
779    ///   let device = B::Device::default();
780    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
781    ///   let tensor = tensor.max_abs_dims(&[0, 1]);
782    ///   println!("{tensor}");
783    ///   // [[9.0]]
784    /// }
785    /// ```
786    pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
787        dims.iter()
788            .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
789    }
790
791    /// Applies the argmin function along the given dimension and returns an integer tensor.
792    ///
793    /// # Example
794    ///
795    /// ```rust
796    /// use burn_tensor::backend::Backend;
797    /// use burn_tensor::{Tensor, Shape};
798    ///
799    /// fn example<B: Backend>() {
800    ///     let device = Default::default();
801    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
802    ///     let tensor = tensor.argmin(1);
803    ///     println!("{:?}", tensor.shape());
804    ///     // Shape { dims: [2, 1, 3] }
805    /// }
806    /// ```
807    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
808        Tensor::new(K::argmin(self.primitive, dim))
809    }
810
811    /// Find the minimum value.
812    ///
813    /// # Example
814    ///
815    /// ```rust
816    /// use burn_tensor::backend::Backend;
817    /// use burn_tensor::{Tensor, Shape};
818    ///
819    /// fn example<B: Backend>() {
820    ///    let device = B::Device::default();
821    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
822    ///    let tensor = tensor.min();
823    ///    println!("{tensor}");
824    ///    // [-2.0]
825    /// }
826    /// ```
827    pub fn min(self) -> Tensor<B, 1, K> {
828        Tensor::new(K::min(self.primitive))
829    }
830
831    /// Find the minimum value along the given dimension.
832    ///
833    /// # Arguments
834    ///
835    /// * `dim` - The dimension or axis along which to aggregate the elements;
836    ///   supports negative indexing.
837    ///
838    /// # Returns
839    ///
840    /// The returned tensor will have the same rank,
841    /// but the aggregated dimension will have size 1.
842    ///
843    /// # Example
844    ///
845    /// ```rust
846    /// use burn_tensor::backend::Backend;
847    /// use burn_tensor::{Tensor, Shape};
848    ///
849    /// fn example<B: Backend>() {
850    ///    let device = B::Device::default();
851    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
852    ///    let tensor = tensor.min_dim(0);
853    ///    println!("{tensor}");
854    ///    // [[1.0, -2.0, 3.0]]
855    /// }
856    /// ```
857    pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
858        let dim = dim.expect_dim_index(D);
859        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
860        Tensor::new(K::min_dim(self.primitive, dim))
861    }
862
863    /// Find the minimum value along the given dimensions.
864    ///
865    /// # Arguments
866    ///
867    /// * `dims` - The dimensions or axes along which to aggregate the elements;
868    ///   supports negative indexing.
869    ///
870    /// # Returns
871    ///
872    /// The returned tensor will have the same rank,
873    /// but the aggregated dimensions will have size 1.
874    ///
875    /// # Example
876    ///
877    /// ```rust
878    /// use burn_tensor::backend::Backend;
879    /// use burn_tensor::{Tensor, Shape};
880    ///
881    /// fn example<B: Backend>() {
882    ///   let device = B::Device::default();
883    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
884    ///   let tensor = tensor.min_dims(&[0, 1]);
885    ///   println!("{tensor}");
886    ///   // [[-2.0]]
887    /// }
888    /// ```
889    pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
890        dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
891    }
892
893    /// Find the minimum value along the given dimension.
894    ///
895    /// Also returns the indices.
896    ///
897    /// # Example
898    ///
899    /// ```rust
900    /// use burn_tensor::backend::Backend;
901    /// use burn_tensor::{Tensor, Shape};
902    ///
903    /// fn example<B: Backend>() {
904    ///    let device = B::Device::default();
905    ///    let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
906    ///    let (tensor, index) = tensor.min_dim_with_indices(0);
907    ///    println!("{tensor}");
908    ///    // [[5.0, -2.0, 3.0]]
909    ///    println!("{}", index);
910    ///    // [[1, 0, 0]]
911    /// }
912    /// ```
913    pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
914        let dim = dim.expect_dim_index(D);
915        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
916
917        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
918
919        let tensor = Tensor::new(tensor);
920        let index = Tensor::new(index);
921
922        (tensor, index)
923    }
924
925    /// Finds the minimum pair wise values with another tensor.
926    ///
927    /// # Arguments
928    ///
929    /// * `other` - Other tensor to find minimum elements with
930    ///
931    /// # Returns
932    ///
933    /// A tensor with the same shape as the input tensors containing the minimum value found
934    /// between each element of the two source tensors.
935    ///
936    /// # Example
937    ///
938    /// ```rust
939    /// use burn_tensor::backend::Backend;
940    /// use burn_tensor::{Tensor, Shape};
941    ///
942    /// fn example<B: Backend>() {
943    ///    let device = B::Device::default();
944    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
945    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
946    ///    let tensor = tensor1.min_pair(tensor2);
947    ///    println!("{tensor}");
948    ///    // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
949    /// }
950    pub fn min_pair(self, other: Self) -> Self {
951        let mask = other.clone().lower(self.clone());
952        self.mask_where(mask, other)
953    }
954
955    /// Clamp element wise between the given min and max values.
956    ///
957    /// # Arguments
958    ///
959    /// * `min` - The minimum value.
960    /// * `max` - The maximum value.
961    ///
962    /// # Returns
963    ///
964    /// A new tensor with the values clamped between the given min and max values.
965    ///
966    /// # Example
967    ///
968    /// ```rust
969    /// use burn_tensor::backend::Backend;
970    /// use burn_tensor::{Int, Tensor};
971    ///
972    /// fn example<B: Backend>() {
973    ///   let device = Default::default();
974    ///   let tensor = Tensor::<B, 2, Int>::from_ints(
975    ///    [
976    ///     [1, 2, 3],
977    ///     [4, 5, 6],
978    ///     [7, 8, 9]
979    ///    ],
980    ///    &device);
981    ///    let tensor = tensor.clamp(2, 6);
982    ///    println!("{tensor}");
983    ///    // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
984    /// }
985    /// ```
986    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
987        let dtype = self.dtype();
988        Self::new(K::clamp(
989            self.primitive,
990            Scalar::new(min, &dtype),
991            Scalar::new(max, &dtype),
992        ))
993    }
994
995    /// Clamp element wise under a minimum value.
996    ///
997    /// # Arguments
998    ///
999    /// * `tensor` - The tensor to clamp.
1000    /// * `min` - The minimum value.
1001    ///
1002    /// # Returns
1003    ///
1004    /// A new tensor with the values clamped under the given min value.
1005    ///
1006    /// # Example
1007    ///
1008    /// ```rust
1009    /// use burn_tensor::backend::Backend;
1010    /// use burn_tensor::{Int, Tensor};
1011    ///
1012    /// fn example<B: Backend>() {
1013    ///    let device = Default::default();
1014    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1015    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1016    ///    &device);
1017    ///    let tensor = tensor.clamp_min(4);
1018    ///    println!("{tensor}");
1019    ///    // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1020    /// }
1021    /// ```
1022    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1023        let min = Scalar::new(min, &self.dtype());
1024        Self::new(K::clamp_min(self.primitive, min))
1025    }
1026
1027    /// Clamp element wise over a maximum value.
1028    ///
1029    /// # Arguments
1030    ///
1031    /// * `tensor` - The tensor to clamp.
1032    /// * `max` - The maximum value.
1033    ///
1034    /// # Returns
1035    ///
1036    /// A new tensor with the values clamped over the given max value.
1037    ///
1038    /// # Example
1039    ///
1040    /// ```rust
1041    /// use burn_tensor::backend::Backend;
1042    /// use burn_tensor::{Int, Tensor};
1043    ///
1044    /// fn example<B: Backend>() {
1045    ///    let device = Default::default();
1046    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1047    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1048    ///    &device);
1049    ///    let tensor = tensor.clamp_max(5);
1050    ///    println!("{tensor}");
1051    ///    // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1052    /// }
1053    /// ```
1054    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1055        let max = Scalar::new(max, &self.dtype());
1056        Self::new(K::clamp_max(self.primitive, max))
1057    }
1058
1059    /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
1060    ///
1061    /// # Arguments
1062    ///
1063    /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
1064    ///
1065    /// # Example
1066    ///
1067    /// ```rust
1068    /// use burn_tensor::backend::Backend;
1069    /// use burn_tensor::{Tensor, Shape};
1070    ///
1071    /// fn example<B: Backend>() {
1072    ///    let device = B::Device::default();
1073    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
1074    ///    let result = tensor.clone().cummin(0);
1075    ///    println!("{result}");
1076    ///    // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
1077    ///    let result = tensor.cummin(1);
1078    ///    println!("{result}");
1079    ///    // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
1080    /// }
1081    /// ```
1082    pub fn cummin(self, dim: usize) -> Self {
1083        check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
1084        Self::new(K::cummin(self.primitive, dim))
1085    }
1086
1087    /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
1088    ///
1089    /// # Arguments
1090    ///
1091    /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
1092    ///
1093    /// # Example
1094    ///
1095    /// ```rust
1096    /// use burn_tensor::backend::Backend;
1097    /// use burn_tensor::{Tensor, Shape};
1098    ///
1099    /// fn example<B: Backend>() {
1100    ///    let device = B::Device::default();
1101    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
1102    ///    let result = tensor.clone().cummax(0);
1103    ///    println!("{result}");
1104    ///    // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
1105    ///    let result = tensor.cummax(1);
1106    ///    println!("{result}");
1107    ///    // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
1108    /// }
1109    /// ```
1110    pub fn cummax(self, dim: usize) -> Self {
1111        check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
1112        Self::new(K::cummax(self.primitive, dim))
1113    }
1114    /// Find the maximum value along the given dimension.
1115    ///
1116    /// # Arguments
1117    ///
1118    /// * `dim` - The dimension or axis along which to aggregate the elements;
1119    ///   supports negative indexing.
1120    ///
1121    /// # Returns
1122    ///
1123    /// The returned tensor will have the same rank,
1124    /// but the aggregated dimension will have size 1.
1125    ///
1126    /// # Example
1127    ///
1128    /// ```rust
1129    /// use burn_tensor::backend::Backend;
1130    /// use burn_tensor::{Tensor, Shape};
1131    ///
1132    /// fn example<B: Backend>() {
1133    ///   let device = B::Device::default();
1134    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1135    ///   let tensor = tensor.max_dim(0);
1136    ///   println!("{tensor}");
1137    ///   // [[5.0, 9.0, 6.0]]
1138    /// }
1139    /// ```
1140    pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
1141        let dim = dim.expect_dim_index(D);
1142        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1143        Tensor::new(K::max_dim(self.primitive, dim))
1144    }
1145
1146    /// Find the maximum value along the given dimensions.
1147    ///
1148    /// # Arguments
1149    ///
1150    /// * `dims` - The dimensions or axis along which to aggregate the elements;
1151    ///   supports negative indexing.
1152    ///
1153    /// # Returns
1154    ///
1155    /// The returned tensor will have the same rank,
1156    /// but the aggregated dimensions will have size 1.
1157    ///
1158    /// # Example
1159    ///
1160    /// ```rust
1161    /// use burn_tensor::backend::Backend;
1162    /// use burn_tensor::{Tensor, Shape};
1163    ///
1164    /// fn example<B: Backend>() {
1165    ///   let device = B::Device::default();
1166    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1167    ///   let tensor = tensor.max_dims(&[0, 1]);
1168    ///   println!("{tensor}");
1169    ///   // [[9.0]]
1170    /// }
1171    /// ```
1172    pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1173        dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
1174    }
1175}