burn_tensor/tensor/api/orderable.rs
1use burn_backend::{
2 Backend, ElementConversion, Scalar,
3 tensor::{Bool, IndexingUpdateOp, Int, Ordered},
4};
5use burn_std::AsIndex;
6
7use crate::check;
8use crate::{Tensor, check::TensorCheck};
9
10impl<B, const D: usize, K> Tensor<B, D, K>
11where
12 B: Backend,
13 K: Ordered<B>,
14{
15 /// Sort the elements by value in ascending order along a given dimension.
16 ///
17 /// This sort is unstable (i.e., may reorder equal elements).
18 ///
19 /// # Arguments
20 ///
21 /// * `dim` - The dimension to sort along.
22 ///
23 /// # Returns
24 ///
25 /// A new tensor with the elements sorted in ascending order along the given dimension.
26 ///
27 /// # Example
28 ///
29 /// ```rust
30 /// use burn_tensor::backend::Backend;
31 /// use burn_tensor::{Tensor, Shape};
32 ///
33 /// fn example<B: Backend>() {
34 /// let device = B::Device::default();
35 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
36 /// let tensor = tensor.sort(0);
37 /// println!("{tensor}");
38 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
39 /// let tensor = tensor.sort(1);
40 /// println!("{tensor}");
41 /// // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
42 /// }
43 /// ```
44 pub fn sort(self, dim: usize) -> Self {
45 check!(TensorCheck::sort_dim::<D>("Sort", dim));
46 Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
47 }
48
49 /// Sort the elements by value in descending order along a given dimension.
50 ///
51 /// This sort is unstable (i.e., may reorder equal elements).
52 ///
53 /// # Arguments
54 ///
55 /// * `dim` - The dimension to sort along.
56 ///
57 /// # Returns
58 ///
59 /// A new tensor with the elements sorted in descending order along the given dimension.
60 ///
61 /// # Example
62 ///
63 /// ```rust
64 /// use burn_tensor::backend::Backend;
65 /// use burn_tensor::{Tensor, Shape};
66 ///
67 /// fn example<B: Backend>() {
68 /// let device = B::Device::default();
69 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
70 /// let tensor = tensor.sort_descending(0);
71 /// println!("{tensor}");
72 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
73 /// let tensor = tensor.sort_descending(1);
74 /// println!("{tensor}");
75 /// // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
76 /// }
77 /// ```
78 pub fn sort_descending(self, dim: usize) -> Self {
79 check!(TensorCheck::sort_dim::<D>("Sort", dim));
80 Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
81 }
82
83 /// Sort the elements by value in ascending order along a given dimension.
84 /// Also returns the indices.
85 ///
86 /// This sort is unstable (i.e., may reorder equal elements).
87 ///
88 /// # Arguments
89 ///
90 /// * `dim` - The dimension to sort along.
91 ///
92 /// # Returns
93 ///
94 /// A tuple containing the sorted tensor and the indices tensor.
95 ///
96 /// # Example
97 ///
98 /// ```rust
99 /// use burn_tensor::backend::Backend;
100 /// use burn_tensor::{Tensor, Shape};
101 ///
102 /// fn example<B: Backend>() {
103 /// let device = B::Device::default();
104 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
105 /// let (tensor, indices) = tensor.sort_with_indices(0);
106 /// println!("{tensor}");
107 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
108 /// println!("{}", indices);
109 /// // [[1, 0, 0], [0, 1, 1]]
110 /// }
111 /// ```
112 pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
113 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
114 let (values, indices) =
115 K::sort_with_indices(self.primitive, dim, /*descending*/ false);
116 (Tensor::new(values), Tensor::new(indices))
117 }
118
119 /// Sort the elements by value in descending order along a given dimension.
120 /// Also returns the indices.
121 ///
122 /// This sort is unstable (i.e., may reorder equal elements).
123 ///
124 /// # Arguments
125 ///
126 /// * `dim` - The dimension to sort along.
127 ///
128 /// # Example
129 ///
130 /// ```rust
131 /// use burn_tensor::backend::Backend;
132 /// use burn_tensor::{Tensor, Shape};
133 ///
134 /// fn example<B: Backend>() {
135 /// let device = B::Device::default();
136 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
137 /// let (tensor, indices) = tensor.sort_descending_with_indices(0);
138 /// println!("{tensor}");
139 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
140 /// println!("{}", indices);
141 /// // [[0, 1, 1], [1, 0, 0]]
142 /// }
143 /// ```
144 pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
145 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
146 let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
147 (Tensor::new(values), Tensor::new(indices))
148 }
149
150 /// Returns the indices that sort the elements by value in ascending order along a given dimension.
151 ///
152 /// This sort is unstable (i.e., may reorder equal elements).
153 ///
154 /// # Arguments
155 ///
156 /// * `dim` - The dimension to sort along.
157 ///
158 /// # Example
159 ///
160 /// ```rust
161 /// use burn_tensor::backend::Backend;
162 /// use burn_tensor::{Tensor, Shape};
163 ///
164 /// fn example<B: Backend>() {
165 /// let device = B::Device::default();
166 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
167 /// let tensor = tensor.argsort(0);
168 /// println!("{tensor}");
169 /// // [[1, 0, 0], [0, 1, 1]]
170 /// }
171 /// ```
172 pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
173 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
174 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
175 }
176
177 /// Returns the indices that sort the elements by value in descending order along a given dimension.
178 ///
179 /// This sort is unstable (i.e., may reorder equal elements).
180 ///
181 /// # Arguments
182 ///
183 /// * `dim` - The dimension to sort along.
184 ///
185 /// # Example
186 ///
187 /// ```rust
188 /// use burn_tensor::backend::Backend;
189 /// use burn_tensor::{Tensor, Shape};
190 ///
191 /// fn example<B: Backend>() {
192 /// let device = B::Device::default();
193 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
194 /// let tensor = tensor.argsort_descending(0);
195 /// println!("{tensor}");
196 /// // [[0, 1, 1], [1, 0, 0]]
197 /// let tensor = tensor.argsort_descending(1);
198 /// println!("{tensor}");
199 /// // [[0, 2, 1], [2, 0, 1]]
200 /// }
201 /// ```
202 pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
203 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
204 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
205 }
206
207 /// Returns the `k` largest elements of the given input tensor along a given dimension.
208 ///
209 /// # Arguments
210 ///
211 /// * `k` - The number of elements to return.
212 ///
213 /// # Returns
214 ///
215 /// A new tensor with the `k` largest elements along the given dimension.
216 ///
217 /// # Example
218 ///
219 /// ```rust
220 /// use burn_tensor::backend::Backend;
221 /// use burn_tensor::{Tensor, Shape};
222 ///
223 /// fn example<B: Backend>() {
224 /// let device = B::Device::default();
225 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
226 /// let tensor = tensor.topk(2, 0);
227 /// println!("{tensor}");
228 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
229 /// let tensor = tensor.topk(1, 1);
230 /// println!("{tensor}");
231 /// // [[12.0], [6.0]]
232 /// }
233 /// ```
234 pub fn topk(self, k: usize, dim: usize) -> Self {
235 assert!(self.shape()[dim] > k);
236 Tensor::new(K::topk(self.primitive, dim, k))
237 }
238
239 /// Returns the `k` largest elements of the given input tensor along a given dimension.
240 /// Also returns the indices.
241 ///
242 /// # Arguments
243 ///
244 /// * `k` - The number of elements to return.
245 /// * `dim` - The dimension to sort along.
246 ///
247 /// # Example
248 ///
249 /// ```rust
250 /// use burn_tensor::backend::Backend;
251 /// use burn_tensor::{Tensor, Shape};
252 ///
253 /// fn example<B: Backend>() {
254 /// let device = B::Device::default();
255 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
256 /// let (tensor, indices) = tensor.topk_with_indices(2, 0);
257 /// println!("{tensor}");
258 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
259 /// println!("{}", indices);
260 /// // [[0, 1, 1], [1, 0, 0]]
261 /// let (tensor, indices) = tensor.topk_with_indices(1, 1);
262 /// println!("{tensor}");
263 /// // [[12.0], [6.0]]
264 /// println!("{indices}");
265 /// // [[0], [2]]
266 /// }
267 /// ```
268 pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
269 let k_indices = Tensor::arange(0..k as i64, &self.device());
270 let (values, indices) = self.sort_descending_with_indices(dim);
271 (
272 values.select(dim, k_indices.clone()),
273 indices.select(dim, k_indices),
274 )
275 }
276
277 /// Create a one hot tensor.
278 ///
279 /// # Example
280 ///
281 /// ```rust
282 /// use burn_tensor::backend::Backend;
283 /// use burn_tensor::Tensor;
284 ///
285 /// fn example<B: Backend>(){
286 /// let device = Default::default();
287 /// let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
288 /// let one_hot: Tensor<B, 2> = indices.one_hot(4);
289 /// println!("{}", one_hot.to_data());
290 /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
291 /// }
292 /// ```
293 pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
294 check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
295 self.one_hot_fill(num_classes, 1.0, 0.0, -1)
296 }
297
298 /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.
299 ///
300 /// # Arguments
301 ///
302 /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
303 /// * `on_value`: The value to assign for active positions (corresponding to indices).
304 /// * `off_value`: The value to assign for inactive positions.
305 /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
306 ///
307 /// # Returns
308 ///
309 /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
310 ///
311 /// # Example
312 /// ```rust
313 /// use burn_tensor::backend::Backend;
314 /// use burn_tensor::{Tensor, Float};
315 /// fn example<B: Backend<FloatElem: From<f32>>>() {
316 /// let device = B::Device::default();
317 /// let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
318 /// // One-hot encoding
319 /// let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
320 /// println!("{tensor}");
321 /// // [[[5.0, 0.0, 0.0],
322 /// // [0.0, 0.0, 5.0]],
323 /// // [[0.0, 5.0, 0.0],
324 /// // [0.0, 0.0, 5.0]]]
325 /// }
326 /// ```
327 pub fn one_hot_fill<const D2: usize>(
328 self,
329 num_classes: usize,
330 on_value: f32,
331 off_value: f32,
332 axis: i64,
333 ) -> Tensor<B, D2, K> {
334 check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
335 // Initialize shape from the current tensor dimensions and prepare for modification
336 let mut shape = self.shape();
337 let device = self.device();
338 let rank = self.dims().len();
339
340 // Adjust negative axis to a positive index
341 let axis = if axis < 0 {
342 axis + rank as i64 + 1
343 } else {
344 axis
345 };
346
347 // Ensure axis is within valid range
348 if axis < 0 || axis > rank as i64 {
349 panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
350 }
351 // Convert the input tensor to integer indices
352 let indices: Tensor<B, D, Int> =
353 Tensor::from_data(self.to_data().convert::<i64>(), &device);
354 // Insert the new dimension for the one-hot representation
355 shape.insert(axis as usize, num_classes);
356 // Adjust indices to valid range and handle invalid indices
357 let adjusted_indices = indices
358 .clone()
359 .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
360 .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
361 // Unsqueeze the indices tensor along the specified axis
362 let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
363
364 // Initialize the output tensor with the off_value
365 let output = Tensor::full(shape.clone(), off_value, &device);
366
367 // Prepare scatter tensor for on_value and off_value adjustments
368 let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
369 - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());
370
371 // Scatter on_value at the appropriate indices to create the one-hot representation
372 output.scatter(
373 axis as usize,
374 indices_unsqueezed,
375 scatter_on_values,
376 IndexingUpdateOp::Add,
377 )
378 }
379
380 /// Applies element wise greater comparison and returns a boolean tensor.
381 ///
382 /// # Panics
383 ///
384 /// If the two tensors don't have the same shape.
385 ///
386 /// # Example
387 ///
388 /// ```rust
389 /// use burn_tensor::backend::Backend;
390 /// use burn_tensor::{Tensor, Shape};
391 ///
392 /// fn example<B: Backend>() {
393 /// let device = B::Device::default();
394 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
395 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
396 /// let tensor = tensor1.greater(tensor2);
397 /// println!("{tensor}");
398 /// // [[false, false, false], [true, true, true]]
399 /// }
400 /// ```
401 pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
402 check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
403 Tensor::new(K::greater(self.primitive, other.primitive))
404 }
405
406 /// Applies element wise greater-equal comparison and returns a boolean tensor.
407 ///
408 /// # Panics
409 ///
410 /// If the two tensors don't have the same shape.
411 ///
412 /// # Example
413 ///
414 /// ```rust
415 /// use burn_tensor::backend::Backend;
416 /// use burn_tensor::{Tensor, Shape};
417 ///
418 /// fn example<B: Backend>() {
419 /// let device = B::Device::default();
420 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
421 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
422 /// let tensor = tensor1.greater_equal(tensor2);
423 /// println!("{tensor}");
424 /// // [[true, false, false], [true, true, true]]
425 /// }
426 /// ```
427 pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
428 check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
429 Tensor::new(K::greater_equal(self.primitive, other.primitive))
430 }
431
432 /// Applies element wise lower comparison and returns a boolean tensor.
433 ///
434 /// # Panics
435 ///
436 /// If the two tensors don't have the same shape.
437 ///
438 /// # Example
439 ///
440 /// ```rust
441 /// use burn_tensor::backend::Backend;
442 /// use burn_tensor::{Tensor, Shape};
443 ///
444 /// fn example<B: Backend>() {
445 /// let device = B::Device::default();
446 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
447 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
448 /// let tensor = tensor1.lower(tensor2);
449 /// println!("{tensor}");
450 /// // [[false, true, true], [false, false, false]]
451 /// }
452 /// ```
453 pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
454 check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
455 Tensor::new(K::lower(self.primitive, other.primitive))
456 }
457
458 /// Applies element wise lower-equal comparison and returns a boolean tensor.
459 ///
460 /// # Panics
461 ///
462 /// If the two tensors don't have the same shape.
463 ///
464 /// # Example
465 ///
466 /// ```rust
467 /// use burn_tensor::backend::Backend;
468 /// use burn_tensor::{Tensor, Shape};
469 ///
470 /// fn example<B: Backend>() {
471 /// let device = B::Device::default();
472 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
473 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
474 /// let tensor = tensor1.lower_equal(tensor2);
475 /// println!("{tensor}");
476 /// // [[true, true, true], [false, false, false]]
477 /// }
478 /// ```
479 pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
480 check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
481 Tensor::new(K::lower_equal(self.primitive, other.primitive))
482 }
483
484 /// Applies greater than `other` comparison and returns a boolean tensor.
485 ///
486 /// # Arguments
487 ///
488 /// * `other` - The element to compare.
489 ///
490 /// # Example
491 ///
492 /// ```rust
493 /// use burn_tensor::backend::Backend;
494 /// use burn_tensor::{Tensor, Shape};
495 ///
496 /// fn example<B: Backend>() {
497 /// let device = B::Device::default();
498 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
499 /// let tensor = tensor.greater_elem(3.0);
500 /// println!("{tensor}");
501 /// // [[false, false, true], [true, true, true]]
502 /// }
503 /// ```
504 pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
505 let other = Scalar::new(other, &self.dtype());
506 Tensor::new(K::greater_elem(self.primitive, other))
507 }
508
509 /// Applies greater-equal than `other` comparison and returns a boolean tensor.
510 ///
511 /// # Arguments
512 ///
513 /// * `other` - The element to compare.
514 ///
515 /// # Example
516 ///
517 /// ```rust
518 /// use burn_tensor::backend::Backend;
519 /// use burn_tensor::{Tensor, Shape};
520 ///
521 /// fn example<B: Backend>() {
522 /// let device = B::Device::default();
523 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
524 /// let tensor = tensor.greater_equal_elem(3.0);
525 /// println!("{tensor}");
526 /// // [[false, false, true], [true, true, true]]
527 /// }
528 /// ```
529 pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
530 let other = Scalar::new(other, &self.dtype());
531 Tensor::new(K::greater_equal_elem(self.primitive, other))
532 }
533
534 /// Applies lower than `other` comparison and returns a boolean tensor.
535 ///
536 /// # Arguments
537 ///
538 /// * `other` - The element to compare.
539 ///
540 /// # Example
541 ///
542 /// ```rust
543 /// use burn_tensor::backend::Backend;
544 /// use burn_tensor::{Tensor, Shape};
545 ///
546 /// fn example<B: Backend>() {
547 /// let device = B::Device::default();
548 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
549 /// let tensor = tensor.lower_elem(3.0);
550 /// println!("{tensor}");
551 /// // [[true, true, false], [false, false, false]]
552 /// }
553 /// ```
554 pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
555 let other = Scalar::new(other, &self.dtype());
556 Tensor::new(K::lower_elem(self.primitive, other))
557 }
558
559 /// Applies lower-equal than `other` comparison and returns a boolean tensor.
560 ///
561 /// # Arguments
562 ///
563 /// * `other` - The element to compare.
564 ///
565 /// # Example
566 ///
567 /// ```rust
568 /// use burn_tensor::backend::Backend;
569 /// use burn_tensor::{Tensor, Shape};
570 ///
571 /// fn example<B: Backend>() {
572 /// let device = B::Device::default();
573 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
574 /// let tensor = tensor.lower_equal_elem(3.0);
575 /// println!("{tensor}");
576 /// // [[true, true, true], [false, false, false]]
577 /// }
578 /// ```
579 pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
580 let other = Scalar::new(other, &self.dtype());
581 Tensor::new(K::lower_equal_elem(self.primitive, other))
582 }
583
584 /// Applies the argmax function along the given dimension and returns an integer tensor.
585 ///
586 /// # Example
587 ///
588 /// ```rust
589 /// use burn_tensor::backend::Backend;
590 /// use burn_tensor::{Tensor, Shape};
591 ///
592 /// fn example<B: Backend>() {
593 /// let device = B::Device::default();
594 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
595 /// let tensor = tensor.argmax(1);
596 /// println!("{:?}", tensor.shape());
597 /// // Shape { dims: [2, 1, 3] }
598 /// }
599 /// ```
600 pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
601 Tensor::new(K::argmax(self.primitive, dim))
602 }
603
604 /// Applies the argtopk function along the given dimension and returns an integer tensor.
605 ///
606 /// # Example
607 ///
608 /// ```rust
609 /// use burn_tensor::backend::Backend;
610 /// use burn_tensor::{Tensor, Shape};
611 ///
612 /// fn example<B: Backend>() {
613 /// let device = B::Device::default();
614 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
615 /// let tensor = tensor.argtopk(1, 2);
616 /// println!("{:?}", tensor.shape());
617 /// }
618 /// ```
619 pub fn argtopk(self, k: usize, dim: usize) -> Tensor<B, D, Int> {
620 assert!(self.shape()[dim] > k);
621 Tensor::new(K::argtopk(self.primitive, dim, k))
622 }
623
624 /// Find the maximum value.
625 ///
626 /// # Example
627 ///
628 /// ```rust
629 /// use burn_tensor::backend::Backend;
630 /// use burn_tensor::{Tensor, Shape};
631 ///
632 /// fn example<B: Backend>() {
633 /// let device = B::Device::default();
634 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
635 /// let tensor = tensor.max();
636 /// println!("{tensor}");
637 /// // [9.0]
638 /// }
639 /// ```
640 pub fn max(self) -> Tensor<B, 1, K> {
641 Tensor::new(K::max(self.primitive))
642 }
643
644 /// Find the maximum value along the given dimension.
645 ///
646 /// Also returns the indices.
647 ///
648 /// # Example
649 ///
650 /// ```rust
651 /// use burn_tensor::backend::Backend;
652 /// use burn_tensor::{Tensor, Shape};
653 ///
654 /// fn example<B: Backend>() {
655 /// let device = B::Device::default();
656 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
657 /// let (tensor, index) = tensor.max_dim_with_indices(0);
658 /// // [[5.0, 9.0, 6.0]]
659 /// println!("{tensor}");
660 /// // [[1, 1, 1]]
661 /// println!("{index}");
662 /// }
663 /// ```
664 pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
665 let dim = dim.expect_dim_index(D);
666 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
667
668 let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
669
670 let tensor = Tensor::new(tensor);
671 let index = Tensor::new(index);
672
673 (tensor, index)
674 }
675
676 /// Find the maximum absolute value.
677 ///
678 /// # Example
679 ///
680 /// ```rust
681 /// use burn_tensor::backend::Backend;
682 /// use burn_tensor::{Tensor, Shape};
683 ///
684 /// fn example<B: Backend>() {
685 /// let device = B::Device::default();
686 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
687 /// let tensor = tensor.max_abs();
688 /// println!("{tensor}");
689 /// // [7.0]
690 /// }
691 /// ```
692 pub fn max_abs(self) -> Tensor<B, 1, K> {
693 Tensor::new(K::max_abs(self.primitive))
694 }
695
696 /// Finds the maximum pair wise values with another tensor.
697 ///
698 /// # Arguments
699 ///
700 /// * `other` - Other tensor to find maximum elements with
701 ///
702 /// # Returns
703 ///
704 /// A tensor with the same shape as the input tensors containing the maximum value found
705 /// in the input tensors.
706 ///
707 /// # Example
708 ///
709 /// ```rust
710 /// use burn_tensor::backend::Backend;
711 /// use burn_tensor::{Tensor, Shape};
712 ///
713 /// fn example<B: Backend>() {
714 /// let device = B::Device::default();
715 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
716 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
717 /// let tensor = tensor1.max_pair(tensor2);
718 /// println!("{tensor}");
719 /// // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
720 /// }
721 /// ```
722 pub fn max_pair(self, other: Self) -> Self {
723 let mask = self.clone().lower(other.clone());
724 self.mask_where(mask, other)
725 }
726
727 /// Find the maximum absolute value along the given dimension.
728 ///
729 /// # Arguments
730 ///
731 /// * `dim` - The dimension or axis along which to aggregate the elements,
732 /// supports negative indexing.
733 ///
734 /// # Returns
735 ///
736 /// The returned tensor will have the same rank,
737 /// but the aggregated dimension will have size 1.
738 ///
739 /// # Example
740 ///
741 /// ```rust
742 /// use burn_tensor::backend::Backend;
743 /// use burn_tensor::{Tensor, Shape};
744 ///
745 /// fn example<B: Backend>() {
746 /// let device = B::Device::default();
747 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
748 /// let tensor = tensor.max_dim(0);
749 /// println!("{tensor}");
750 /// // [[5.0, 9.0, 6.0]]
751 /// }
752 /// ```
753 pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
754 let dim = dim.expect_dim_index(D);
755 check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
756
757 Tensor::new(K::max_abs_dim(self.primitive, dim))
758 }
759
760 /// Find the maximum absolute value along the given dimensions.
761 ///
762 /// # Arguments
763 ///
764 /// * `dims` - The dimensions or axes along which to aggregate the elements,
765 /// supports negative indexing.
766 ///
767 /// # Returns
768 ///
769 /// The returned tensor will have the same rank,
770 /// but the aggregated dimensions will have size 1.
771 ///
772 /// # Example
773 ///
774 /// ```rust
775 /// use burn_tensor::backend::Backend;
776 /// use burn_tensor::{Tensor, Shape};
777 ///
778 /// fn example<B: Backend>() {
779 /// let device = B::Device::default();
780 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
781 /// let tensor = tensor.max_abs_dims(&[0, 1]);
782 /// println!("{tensor}");
783 /// // [[9.0]]
784 /// }
785 /// ```
786 pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
787 dims.iter()
788 .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
789 }
790
791 /// Applies the argmin function along the given dimension and returns an integer tensor.
792 ///
793 /// # Example
794 ///
795 /// ```rust
796 /// use burn_tensor::backend::Backend;
797 /// use burn_tensor::{Tensor, Shape};
798 ///
799 /// fn example<B: Backend>() {
800 /// let device = Default::default();
801 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
802 /// let tensor = tensor.argmin(1);
803 /// println!("{:?}", tensor.shape());
804 /// // Shape { dims: [2, 1, 3] }
805 /// }
806 /// ```
807 pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
808 Tensor::new(K::argmin(self.primitive, dim))
809 }
810
811 /// Find the minimum value.
812 ///
813 /// # Example
814 ///
815 /// ```rust
816 /// use burn_tensor::backend::Backend;
817 /// use burn_tensor::{Tensor, Shape};
818 ///
819 /// fn example<B: Backend>() {
820 /// let device = B::Device::default();
821 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
822 /// let tensor = tensor.min();
823 /// println!("{tensor}");
824 /// // [-2.0]
825 /// }
826 /// ```
827 pub fn min(self) -> Tensor<B, 1, K> {
828 Tensor::new(K::min(self.primitive))
829 }
830
831 /// Find the minimum value along the given dimension.
832 ///
833 /// # Arguments
834 ///
835 /// * `dim` - The dimension or axis along which to aggregate the elements;
836 /// supports negative indexing.
837 ///
838 /// # Returns
839 ///
840 /// The returned tensor will have the same rank,
841 /// but the aggregated dimension will have size 1.
842 ///
843 /// # Example
844 ///
845 /// ```rust
846 /// use burn_tensor::backend::Backend;
847 /// use burn_tensor::{Tensor, Shape};
848 ///
849 /// fn example<B: Backend>() {
850 /// let device = B::Device::default();
851 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
852 /// let tensor = tensor.min_dim(0);
853 /// println!("{tensor}");
854 /// // [[1.0, -2.0, 3.0]]
855 /// }
856 /// ```
857 pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
858 let dim = dim.expect_dim_index(D);
859 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
860 Tensor::new(K::min_dim(self.primitive, dim))
861 }
862
863 /// Find the minimum value along the given dimensions.
864 ///
865 /// # Arguments
866 ///
867 /// * `dims` - The dimensions or axes along which to aggregate the elements;
868 /// supports negative indexing.
869 ///
870 /// # Returns
871 ///
872 /// The returned tensor will have the same rank,
873 /// but the aggregated dimensions will have size 1.
874 ///
875 /// # Example
876 ///
877 /// ```rust
878 /// use burn_tensor::backend::Backend;
879 /// use burn_tensor::{Tensor, Shape};
880 ///
881 /// fn example<B: Backend>() {
882 /// let device = B::Device::default();
883 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
884 /// let tensor = tensor.min_dims(&[0, 1]);
885 /// println!("{tensor}");
886 /// // [[-2.0]]
887 /// }
888 /// ```
889 pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
890 dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
891 }
892
893 /// Find the minimum value along the given dimension.
894 ///
895 /// Also returns the indices.
896 ///
897 /// # Example
898 ///
899 /// ```rust
900 /// use burn_tensor::backend::Backend;
901 /// use burn_tensor::{Tensor, Shape};
902 ///
903 /// fn example<B: Backend>() {
904 /// let device = B::Device::default();
905 /// let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
906 /// let (tensor, index) = tensor.min_dim_with_indices(0);
907 /// println!("{tensor}");
908 /// // [[5.0, -2.0, 3.0]]
909 /// println!("{}", index);
910 /// // [[1, 0, 0]]
911 /// }
912 /// ```
913 pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
914 let dim = dim.expect_dim_index(D);
915 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
916
917 let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
918
919 let tensor = Tensor::new(tensor);
920 let index = Tensor::new(index);
921
922 (tensor, index)
923 }
924
925 /// Finds the minimum pair wise values with another tensor.
926 ///
927 /// # Arguments
928 ///
929 /// * `other` - Other tensor to find minimum elements with
930 ///
931 /// # Returns
932 ///
933 /// A tensor with the same shape as the input tensors containing the minimum value found
934 /// between each element of the two source tensors.
935 ///
936 /// # Example
937 ///
938 /// ```rust
939 /// use burn_tensor::backend::Backend;
940 /// use burn_tensor::{Tensor, Shape};
941 ///
942 /// fn example<B: Backend>() {
943 /// let device = B::Device::default();
944 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
945 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
946 /// let tensor = tensor1.min_pair(tensor2);
947 /// println!("{tensor}");
948 /// // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
949 /// }
950 pub fn min_pair(self, other: Self) -> Self {
951 let mask = other.clone().lower(self.clone());
952 self.mask_where(mask, other)
953 }
954
955 /// Clamp element wise between the given min and max values.
956 ///
957 /// # Arguments
958 ///
959 /// * `min` - The minimum value.
960 /// * `max` - The maximum value.
961 ///
962 /// # Returns
963 ///
964 /// A new tensor with the values clamped between the given min and max values.
965 ///
966 /// # Example
967 ///
968 /// ```rust
969 /// use burn_tensor::backend::Backend;
970 /// use burn_tensor::{Int, Tensor};
971 ///
972 /// fn example<B: Backend>() {
973 /// let device = Default::default();
974 /// let tensor = Tensor::<B, 2, Int>::from_ints(
975 /// [
976 /// [1, 2, 3],
977 /// [4, 5, 6],
978 /// [7, 8, 9]
979 /// ],
980 /// &device);
981 /// let tensor = tensor.clamp(2, 6);
982 /// println!("{tensor}");
983 /// // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
984 /// }
985 /// ```
986 pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
987 let dtype = self.dtype();
988 Self::new(K::clamp(
989 self.primitive,
990 Scalar::new(min, &dtype),
991 Scalar::new(max, &dtype),
992 ))
993 }
994
995 /// Clamp element wise under a minimum value.
996 ///
997 /// # Arguments
998 ///
999 /// * `tensor` - The tensor to clamp.
1000 /// * `min` - The minimum value.
1001 ///
1002 /// # Returns
1003 ///
1004 /// A new tensor with the values clamped under the given min value.
1005 ///
1006 /// # Example
1007 ///
1008 /// ```rust
1009 /// use burn_tensor::backend::Backend;
1010 /// use burn_tensor::{Int, Tensor};
1011 ///
1012 /// fn example<B: Backend>() {
1013 /// let device = Default::default();
1014 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1015 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1016 /// &device);
1017 /// let tensor = tensor.clamp_min(4);
1018 /// println!("{tensor}");
1019 /// // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1020 /// }
1021 /// ```
1022 pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1023 let min = Scalar::new(min, &self.dtype());
1024 Self::new(K::clamp_min(self.primitive, min))
1025 }
1026
1027 /// Clamp element wise over a maximum value.
1028 ///
1029 /// # Arguments
1030 ///
1031 /// * `tensor` - The tensor to clamp.
1032 /// * `max` - The maximum value.
1033 ///
1034 /// # Returns
1035 ///
1036 /// A new tensor with the values clamped over the given max value.
1037 ///
1038 /// # Example
1039 ///
1040 /// ```rust
1041 /// use burn_tensor::backend::Backend;
1042 /// use burn_tensor::{Int, Tensor};
1043 ///
1044 /// fn example<B: Backend>() {
1045 /// let device = Default::default();
1046 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1047 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1048 /// &device);
1049 /// let tensor = tensor.clamp_max(5);
1050 /// println!("{tensor}");
1051 /// // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1052 /// }
1053 /// ```
1054 pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1055 let max = Scalar::new(max, &self.dtype());
1056 Self::new(K::clamp_max(self.primitive, max))
1057 }
1058
1059 /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
1060 ///
1061 /// # Arguments
1062 ///
1063 /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
1064 ///
1065 /// # Example
1066 ///
1067 /// ```rust
1068 /// use burn_tensor::backend::Backend;
1069 /// use burn_tensor::{Tensor, Shape};
1070 ///
1071 /// fn example<B: Backend>() {
1072 /// let device = B::Device::default();
1073 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
1074 /// let result = tensor.clone().cummin(0);
1075 /// println!("{result}");
1076 /// // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
1077 /// let result = tensor.cummin(1);
1078 /// println!("{result}");
1079 /// // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
1080 /// }
1081 /// ```
1082 pub fn cummin(self, dim: usize) -> Self {
1083 check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
1084 Self::new(K::cummin(self.primitive, dim))
1085 }
1086
1087 /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
1088 ///
1089 /// # Arguments
1090 ///
1091 /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
1092 ///
1093 /// # Example
1094 ///
1095 /// ```rust
1096 /// use burn_tensor::backend::Backend;
1097 /// use burn_tensor::{Tensor, Shape};
1098 ///
1099 /// fn example<B: Backend>() {
1100 /// let device = B::Device::default();
1101 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
1102 /// let result = tensor.clone().cummax(0);
1103 /// println!("{result}");
1104 /// // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
1105 /// let result = tensor.cummax(1);
1106 /// println!("{result}");
1107 /// // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
1108 /// }
1109 /// ```
1110 pub fn cummax(self, dim: usize) -> Self {
1111 check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
1112 Self::new(K::cummax(self.primitive, dim))
1113 }
1114 /// Find the maximum value along the given dimension.
1115 ///
1116 /// # Arguments
1117 ///
1118 /// * `dim` - The dimension or axis along which to aggregate the elements;
1119 /// supports negative indexing.
1120 ///
1121 /// # Returns
1122 ///
1123 /// The returned tensor will have the same rank,
1124 /// but the aggregated dimension will have size 1.
1125 ///
1126 /// # Example
1127 ///
1128 /// ```rust
1129 /// use burn_tensor::backend::Backend;
1130 /// use burn_tensor::{Tensor, Shape};
1131 ///
1132 /// fn example<B: Backend>() {
1133 /// let device = B::Device::default();
1134 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1135 /// let tensor = tensor.max_dim(0);
1136 /// println!("{tensor}");
1137 /// // [[5.0, 9.0, 6.0]]
1138 /// }
1139 /// ```
1140 pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
1141 let dim = dim.expect_dim_index(D);
1142 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1143 Tensor::new(K::max_dim(self.primitive, dim))
1144 }
1145
1146 /// Find the maximum value along the given dimensions.
1147 ///
1148 /// # Arguments
1149 ///
1150 /// * `dims` - The dimensions or axis along which to aggregate the elements;
1151 /// supports negative indexing.
1152 ///
1153 /// # Returns
1154 ///
1155 /// The returned tensor will have the same rank,
1156 /// but the aggregated dimensions will have size 1.
1157 ///
1158 /// # Example
1159 ///
1160 /// ```rust
1161 /// use burn_tensor::backend::Backend;
1162 /// use burn_tensor::{Tensor, Shape};
1163 ///
1164 /// fn example<B: Backend>() {
1165 /// let device = B::Device::default();
1166 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1167 /// let tensor = tensor.max_dims(&[0, 1]);
1168 /// println!("{tensor}");
1169 /// // [[9.0]]
1170 /// }
1171 /// ```
1172 pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1173 dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
1174 }
1175}