burn_tensor/tensor/api/orderable.rs
1use burn_backend::{
2 Backend, ElementConversion, Scalar,
3 tensor::{Bool, IndexingUpdateOp, Int, Ordered},
4};
5use burn_std::AsIndex;
6
7use crate::check;
8use crate::{Tensor, check::TensorCheck};
9
10impl<B, const D: usize, K> Tensor<B, D, K>
11where
12 B: Backend,
13 K: Ordered<B>,
14{
    /// Sort the elements by value in ascending order along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to sort along.
    ///
    /// # Returns
    ///
    /// A new tensor with the elements sorted in ascending order along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///     let sorted = tensor.clone().sort(0);
    ///     println!("{sorted}");
    ///     // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
    ///     let sorted = tensor.sort(1);
    ///     println!("{sorted}");
    ///     // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
    /// }
    /// ```
    pub fn sort(self, dim: usize) -> Self {
        check!(TensorCheck::sort_dim::<D>("Sort", dim));
        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
    }
48
    /// Sort the elements by value in descending order along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to sort along.
    ///
    /// # Returns
    ///
    /// A new tensor with the elements sorted in descending order along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///     let sorted = tensor.clone().sort_descending(0);
    ///     println!("{sorted}");
    ///     // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
    ///     let sorted = tensor.sort_descending(1);
    ///     println!("{sorted}");
    ///     // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
    /// }
    /// ```
    pub fn sort_descending(self, dim: usize) -> Self {
        check!(TensorCheck::sort_dim::<D>("Sort", dim));
        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
    }
82
83 /// Sort the elements by value in ascending order along a given dimension.
84 /// Also returns the indices.
85 ///
86 /// This sort is unstable (i.e., may reorder equal elements).
87 ///
88 /// # Arguments
89 ///
90 /// * `dim` - The dimension to sort along.
91 ///
92 /// # Returns
93 ///
94 /// A tuple containing the sorted tensor and the indices tensor.
95 ///
96 /// # Example
97 ///
98 /// ```rust
99 /// use burn_tensor::backend::Backend;
100 /// use burn_tensor::{Tensor, Shape};
101 ///
102 /// fn example<B: Backend>() {
103 /// let device = B::Device::default();
104 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
105 /// let (tensor, indices) = tensor.sort_with_indices(0);
106 /// println!("{tensor}");
107 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
108 /// println!("{}", indices);
109 /// // [[1, 0, 0], [0, 1, 1]]
110 /// }
111 /// ```
112 pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
113 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
114 let (values, indices) =
115 K::sort_with_indices(self.primitive, dim, /*descending*/ false);
116 (Tensor::new(values), Tensor::new(indices))
117 }
118
119 /// Sort the elements by value in descending order along a given dimension.
120 /// Also returns the indices.
121 ///
122 /// This sort is unstable (i.e., may reorder equal elements).
123 ///
124 /// # Arguments
125 ///
126 /// * `dim` - The dimension to sort along.
127 ///
128 /// # Example
129 ///
130 /// ```rust
131 /// use burn_tensor::backend::Backend;
132 /// use burn_tensor::{Tensor, Shape};
133 ///
134 /// fn example<B: Backend>() {
135 /// let device = B::Device::default();
136 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
137 /// let (tensor, indices) = tensor.sort_descending_with_indices(0);
138 /// println!("{tensor}");
139 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
140 /// println!("{}", indices);
141 /// // [[0, 1, 1], [1, 0, 0]]
142 /// }
143 /// ```
144 pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
145 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
146 let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
147 (Tensor::new(values), Tensor::new(indices))
148 }
149
150 /// Returns the indices that sort the elements by value in ascending order along a given dimension.
151 ///
152 /// This sort is unstable (i.e., may reorder equal elements).
153 ///
154 /// # Arguments
155 ///
156 /// * `dim` - The dimension to sort along.
157 ///
158 /// # Example
159 ///
160 /// ```rust
161 /// use burn_tensor::backend::Backend;
162 /// use burn_tensor::{Tensor, Shape};
163 ///
164 /// fn example<B: Backend>() {
165 /// let device = B::Device::default();
166 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
167 /// let tensor = tensor.argsort(0);
168 /// println!("{tensor}");
169 /// // [[1, 0, 0], [0, 1, 1]]
170 /// }
171 /// ```
172 pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
173 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
174 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
175 }
176
    /// Returns the indices that sort the elements by value in descending order along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to sort along.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///     let indices = tensor.clone().argsort_descending(0);
    ///     println!("{indices}");
    ///     // [[0, 1, 1], [1, 0, 0]]
    ///     let indices = tensor.argsort_descending(1);
    ///     println!("{indices}");
    ///     // [[0, 2, 1], [2, 0, 1]]
    /// }
    /// ```
    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
    }
206
    /// Returns the `k` largest elements of the given input tensor along a given dimension.
    ///
    /// # Arguments
    ///
    /// * `k` - The number of elements to return.
    /// * `dim` - The dimension to select along.
    ///
    /// # Returns
    ///
    /// A new tensor with the `k` largest elements along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///     let top = tensor.clone().topk(2, 0);
    ///     println!("{top}");
    ///     // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
    ///     let top = tensor.topk(1, 1);
    ///     println!("{top}");
    ///     // [[12.0], [6.0]]
    /// }
    /// ```
    pub fn topk(self, k: usize, dim: usize) -> Self {
        // Sort descending along `dim`, then keep only the first `k` slices.
        let k_indices = Tensor::arange(0..k as i64, &self.device());
        self.sort_descending(dim).select(dim, k_indices)
    }
238
    /// Returns the `k` largest elements of the given input tensor along a given dimension.
    /// Also returns the indices.
    ///
    /// # Arguments
    ///
    /// * `k` - The number of elements to return.
    /// * `dim` - The dimension to sort along.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///     let (values, indices) = tensor.clone().topk_with_indices(2, 0);
    ///     println!("{values}");
    ///     // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
    ///     println!("{}", indices);
    ///     // [[0, 1, 1], [1, 0, 0]]
    ///     let (values, indices) = tensor.topk_with_indices(1, 1);
    ///     println!("{values}");
    ///     // [[12.0], [6.0]]
    ///     println!("{indices}");
    ///     // [[0], [2]]
    /// }
    /// ```
    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
        // Sort descending along `dim`, then keep the first `k` values and their
        // positions in the unsorted input.
        let k_indices = Tensor::arange(0..k as i64, &self.device());
        let (values, indices) = self.sort_descending_with_indices(dim);
        (
            values.select(dim, k_indices.clone()),
            indices.select(dim, k_indices),
        )
    }
276
277 /// Create a one hot tensor.
278 ///
279 /// # Example
280 ///
281 /// ```rust
282 /// use burn_tensor::backend::Backend;
283 /// use burn_tensor::Tensor;
284 ///
285 /// fn example<B: Backend>(){
286 /// let device = Default::default();
287 /// let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
288 /// let one_hot: Tensor<B, 2> = indices.one_hot(4);
289 /// println!("{}", one_hot.to_data());
290 /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
291 /// }
292 /// ```
293 pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
294 check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
295 self.one_hot_fill(num_classes, 1.0, 0.0, -1)
296 }
297
298 /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.
299 ///
300 /// # Arguments
301 ///
302 /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
303 /// * `on_value`: The value to assign for active positions (corresponding to indices).
304 /// * `off_value`: The value to assign for inactive positions.
305 /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
306 ///
307 /// # Returns
308 ///
309 /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
310 ///
311 /// # Example
312 /// ```rust
313 /// use burn_tensor::backend::Backend;
314 /// use burn_tensor::{Tensor, Float};
315 /// fn example<B: Backend<FloatElem: From<f32>>>() {
316 /// let device = B::Device::default();
317 /// let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
318 /// // One-hot encoding
319 /// let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
320 /// println!("{tensor}");
321 /// // [[[5.0, 0.0, 0.0],
322 /// // [0.0, 0.0, 5.0]],
323 /// // [[0.0, 5.0, 0.0],
324 /// // [0.0, 0.0, 5.0]]]
325 /// }
326 /// ```
327 pub fn one_hot_fill<const D2: usize>(
328 self,
329 num_classes: usize,
330 on_value: f32,
331 off_value: f32,
332 axis: i64,
333 ) -> Tensor<B, D2, K> {
334 check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
335 // Initialize shape from the current tensor dimensions and prepare for modification
336 let mut shape = self.shape();
337 let device = self.device();
338 let rank = self.dims().len();
339
340 // Adjust negative axis to a positive index
341 let axis = if axis < 0 {
342 axis + rank as i64 + 1
343 } else {
344 axis
345 };
346
347 // Ensure axis is within valid range
348 if axis < 0 || axis > rank as i64 {
349 panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
350 }
351 // Convert the input tensor to integer indices
352 let indices: Tensor<B, D, Int> =
353 Tensor::from_data(self.to_data().convert::<i64>(), &device);
354 // Insert the new dimension for the one-hot representation
355 shape.insert(axis as usize, num_classes);
356 // Adjust indices to valid range and handle invalid indices
357 let adjusted_indices = indices
358 .clone()
359 .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
360 .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
361 // Unsqueeze the indices tensor along the specified axis
362 let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
363
364 // Initialize the output tensor with the off_value
365 let output = Tensor::full(shape.clone(), off_value, &device);
366
367 // Prepare scatter tensor for on_value and off_value adjustments
368 let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
369 - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());
370
371 // Scatter on_value at the appropriate indices to create the one-hot representation
372 output.scatter(
373 axis as usize,
374 indices_unsqueezed,
375 scatter_on_values,
376 IndexingUpdateOp::Add,
377 )
378 }
379
380 /// Applies element wise greater comparison and returns a boolean tensor.
381 ///
382 /// # Panics
383 ///
384 /// If the two tensors don't have the same shape.
385 ///
386 /// # Example
387 ///
388 /// ```rust
389 /// use burn_tensor::backend::Backend;
390 /// use burn_tensor::{Tensor, Shape};
391 ///
392 /// fn example<B: Backend>() {
393 /// let device = B::Device::default();
394 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
395 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
396 /// let tensor = tensor1.greater(tensor2);
397 /// println!("{tensor}");
398 /// // [[false, false, false], [true, true, true]]
399 /// }
400 /// ```
401 pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
402 check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
403 Tensor::new(K::greater(self.primitive, other.primitive))
404 }
405
406 /// Applies element wise greater-equal comparison and returns a boolean tensor.
407 ///
408 /// # Panics
409 ///
410 /// If the two tensors don't have the same shape.
411 ///
412 /// # Example
413 ///
414 /// ```rust
415 /// use burn_tensor::backend::Backend;
416 /// use burn_tensor::{Tensor, Shape};
417 ///
418 /// fn example<B: Backend>() {
419 /// let device = B::Device::default();
420 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
421 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
422 /// let tensor = tensor1.greater_equal(tensor2);
423 /// println!("{tensor}");
424 /// // [[true, false, false], [true, true, true]]
425 /// }
426 /// ```
427 pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
428 check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
429 Tensor::new(K::greater_equal(self.primitive, other.primitive))
430 }
431
432 /// Applies element wise lower comparison and returns a boolean tensor.
433 ///
434 /// # Panics
435 ///
436 /// If the two tensors don't have the same shape.
437 ///
438 /// # Example
439 ///
440 /// ```rust
441 /// use burn_tensor::backend::Backend;
442 /// use burn_tensor::{Tensor, Shape};
443 ///
444 /// fn example<B: Backend>() {
445 /// let device = B::Device::default();
446 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
447 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
448 /// let tensor = tensor1.lower(tensor2);
449 /// println!("{tensor}");
450 /// // [[false, true, true], [false, false, false]]
451 /// }
452 /// ```
453 pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
454 check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
455 Tensor::new(K::lower(self.primitive, other.primitive))
456 }
457
458 /// Applies element wise lower-equal comparison and returns a boolean tensor.
459 ///
460 /// # Panics
461 ///
462 /// If the two tensors don't have the same shape.
463 ///
464 /// # Example
465 ///
466 /// ```rust
467 /// use burn_tensor::backend::Backend;
468 /// use burn_tensor::{Tensor, Shape};
469 ///
470 /// fn example<B: Backend>() {
471 /// let device = B::Device::default();
472 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
473 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
474 /// let tensor = tensor1.lower_equal(tensor2);
475 /// println!("{tensor}");
476 /// // [[true, true, true], [false, false, false]]
477 /// }
478 /// ```
479 pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
480 check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
481 Tensor::new(K::lower_equal(self.primitive, other.primitive))
482 }
483
    /// Applies greater than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.greater_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[false, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        // Convert the scalar to the tensor's dtype before comparing.
        let other = Scalar::new(other, &self.dtype());
        Tensor::new(K::greater_elem(self.primitive, other))
    }
508
509 /// Applies greater-equal than `other` comparison and returns a boolean tensor.
510 ///
511 /// # Arguments
512 ///
513 /// * `other` - The element to compare.
514 ///
515 /// # Example
516 ///
517 /// ```rust
518 /// use burn_tensor::backend::Backend;
519 /// use burn_tensor::{Tensor, Shape};
520 ///
521 /// fn example<B: Backend>() {
522 /// let device = B::Device::default();
523 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
524 /// let tensor = tensor.greater_equal_elem(3.0);
525 /// println!("{tensor}");
526 /// // [[false, false, true], [true, true, true]]
527 /// }
528 /// ```
529 pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
530 let other = Scalar::new(other, &self.dtype());
531 Tensor::new(K::greater_equal_elem(self.primitive, other))
532 }
533
534 /// Applies lower than `other` comparison and returns a boolean tensor.
535 ///
536 /// # Arguments
537 ///
538 /// * `other` - The element to compare.
539 ///
540 /// # Example
541 ///
542 /// ```rust
543 /// use burn_tensor::backend::Backend;
544 /// use burn_tensor::{Tensor, Shape};
545 ///
546 /// fn example<B: Backend>() {
547 /// let device = B::Device::default();
548 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
549 /// let tensor = tensor.lower_elem(3.0);
550 /// println!("{tensor}");
551 /// // [[true, true, false], [false, false, false]]
552 /// }
553 /// ```
554 pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
555 let other = Scalar::new(other, &self.dtype());
556 Tensor::new(K::lower_elem(self.primitive, other))
557 }
558
559 /// Applies lower-equal than `other` comparison and returns a boolean tensor.
560 ///
561 /// # Arguments
562 ///
563 /// * `other` - The element to compare.
564 ///
565 /// # Example
566 ///
567 /// ```rust
568 /// use burn_tensor::backend::Backend;
569 /// use burn_tensor::{Tensor, Shape};
570 ///
571 /// fn example<B: Backend>() {
572 /// let device = B::Device::default();
573 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
574 /// let tensor = tensor.lower_equal_elem(3.0);
575 /// println!("{tensor}");
576 /// // [[true, true, true], [false, false, false]]
577 /// }
578 /// ```
579 pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
580 let other = Scalar::new(other, &self.dtype());
581 Tensor::new(K::lower_equal_elem(self.primitive, other))
582 }
583
584 /// Applies the argmax function along the given dimension and returns an integer tensor.
585 ///
586 /// # Example
587 ///
588 /// ```rust
589 /// use burn_tensor::backend::Backend;
590 /// use burn_tensor::{Tensor, Shape};
591 ///
592 /// fn example<B: Backend>() {
593 /// let device = B::Device::default();
594 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
595 /// let tensor = tensor.argmax(1);
596 /// println!("{:?}", tensor.shape());
597 /// // Shape { dims: [2, 1, 3] }
598 /// }
599 /// ```
600 pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
601 Tensor::new(K::argmax(self.primitive, dim))
602 }
603
604 /// Find the maximum value.
605 ///
606 /// # Example
607 ///
608 /// ```rust
609 /// use burn_tensor::backend::Backend;
610 /// use burn_tensor::{Tensor, Shape};
611 ///
612 /// fn example<B: Backend>() {
613 /// let device = B::Device::default();
614 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
615 /// let tensor = tensor.max();
616 /// println!("{tensor}");
617 /// // [9.0]
618 /// }
619 /// ```
620 pub fn max(self) -> Tensor<B, 1, K> {
621 Tensor::new(K::max(self.primitive))
622 }
623
    /// Find the maximum value along the given dimension.
    ///
    /// Also returns the indices.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements;
    ///   supports negative indexing.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let (tensor, index) = tensor.max_dim_with_indices(0);
    ///     println!("{tensor}");
    ///     // [[5.0, 9.0, 6.0]]
    ///     println!("{index}");
    ///     // [[1, 1, 1]]
    /// }
    /// ```
    pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
        // Resolve a possibly-negative axis into a concrete dimension index.
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }
655
656 /// Find the maximum absolute value.
657 ///
658 /// # Example
659 ///
660 /// ```rust
661 /// use burn_tensor::backend::Backend;
662 /// use burn_tensor::{Tensor, Shape};
663 ///
664 /// fn example<B: Backend>() {
665 /// let device = B::Device::default();
666 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
667 /// let tensor = tensor.max_abs();
668 /// println!("{tensor}");
669 /// // [7.0]
670 /// }
671 /// ```
672 pub fn max_abs(self) -> Tensor<B, 1, K> {
673 Tensor::new(K::max_abs(self.primitive))
674 }
675
676 /// Finds the maximum pair wise values with another tensor.
677 ///
678 /// # Arguments
679 ///
680 /// * `other` - Other tensor to find maximum elements with
681 ///
682 /// # Returns
683 ///
684 /// A tensor with the same shape as the input tensors containing the maximum value found
685 /// in the input tensors.
686 ///
687 /// # Example
688 ///
689 /// ```rust
690 /// use burn_tensor::backend::Backend;
691 /// use burn_tensor::{Tensor, Shape};
692 ///
693 /// fn example<B: Backend>() {
694 /// let device = B::Device::default();
695 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
696 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
697 /// let tensor = tensor1.max_pair(tensor2);
698 /// println!("{tensor}");
699 /// // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
700 /// }
701 /// ```
702 pub fn max_pair(self, other: Self) -> Self {
703 let mask = self.clone().lower(other.clone());
704 self.mask_where(mask, other)
705 }
706
    /// Find the maximum absolute value along the given dimension.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements,
    ///   supports negative indexing.
    ///
    /// # Returns
    ///
    /// The returned tensor will have the same rank,
    /// but the aggregated dimension will have size 1.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.max_abs_dim(0);
    ///     println!("{tensor}");
    ///     // [[5.0, 9.0, 6.0]]
    /// }
    /// ```
    pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
        // Resolve a possibly-negative axis into a concrete dimension index.
        let dim = dim.expect_dim_index(D);
        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));

        Tensor::new(K::max_abs_dim(self.primitive, dim))
    }
739
740 /// Find the maximum absolute value along the given dimensions.
741 ///
742 /// # Arguments
743 ///
744 /// * `dims` - The dimensions or axes along which to aggregate the elements,
745 /// supports negative indexing.
746 ///
747 /// # Returns
748 ///
749 /// The returned tensor will have the same rank,
750 /// but the aggregated dimensions will have size 1.
751 ///
752 /// # Example
753 ///
754 /// ```rust
755 /// use burn_tensor::backend::Backend;
756 /// use burn_tensor::{Tensor, Shape};
757 ///
758 /// fn example<B: Backend>() {
759 /// let device = B::Device::default();
760 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
761 /// let tensor = tensor.max_abs_dims(&[0, 1]);
762 /// println!("{tensor}");
763 /// // [[9.0]]
764 /// }
765 /// ```
766 pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
767 dims.iter()
768 .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
769 }
770
771 /// Applies the argmin function along the given dimension and returns an integer tensor.
772 ///
773 /// # Example
774 ///
775 /// ```rust
776 /// use burn_tensor::backend::Backend;
777 /// use burn_tensor::{Tensor, Shape};
778 ///
779 /// fn example<B: Backend>() {
780 /// let device = Default::default();
781 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
782 /// let tensor = tensor.argmin(1);
783 /// println!("{:?}", tensor.shape());
784 /// // Shape { dims: [2, 1, 3] }
785 /// }
786 /// ```
787 pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
788 Tensor::new(K::argmin(self.primitive, dim))
789 }
790
791 /// Find the minimum value.
792 ///
793 /// # Example
794 ///
795 /// ```rust
796 /// use burn_tensor::backend::Backend;
797 /// use burn_tensor::{Tensor, Shape};
798 ///
799 /// fn example<B: Backend>() {
800 /// let device = B::Device::default();
801 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
802 /// let tensor = tensor.min();
803 /// println!("{tensor}");
804 /// // [-2.0]
805 /// }
806 /// ```
807 pub fn min(self) -> Tensor<B, 1, K> {
808 Tensor::new(K::min(self.primitive))
809 }
810
811 /// Find the minimum value along the given dimension.
812 ///
813 /// # Arguments
814 ///
815 /// * `dim` - The dimension or axis along which to aggregate the elements;
816 /// supports negative indexing.
817 ///
818 /// # Returns
819 ///
820 /// The returned tensor will have the same rank,
821 /// but the aggregated dimension will have size 1.
822 ///
823 /// # Example
824 ///
825 /// ```rust
826 /// use burn_tensor::backend::Backend;
827 /// use burn_tensor::{Tensor, Shape};
828 ///
829 /// fn example<B: Backend>() {
830 /// let device = B::Device::default();
831 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
832 /// let tensor = tensor.min_dim(0);
833 /// println!("{tensor}");
834 /// // [[1.0, -2.0, 3.0]]
835 /// }
836 /// ```
837 pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
838 let dim = dim.expect_dim_index(D);
839 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
840 Tensor::new(K::min_dim(self.primitive, dim))
841 }
842
843 /// Find the minimum value along the given dimensions.
844 ///
845 /// # Arguments
846 ///
847 /// * `dims` - The dimensions or axes along which to aggregate the elements;
848 /// supports negative indexing.
849 ///
850 /// # Returns
851 ///
852 /// The returned tensor will have the same rank,
853 /// but the aggregated dimensions will have size 1.
854 ///
855 /// # Example
856 ///
857 /// ```rust
858 /// use burn_tensor::backend::Backend;
859 /// use burn_tensor::{Tensor, Shape};
860 ///
861 /// fn example<B: Backend>() {
862 /// let device = B::Device::default();
863 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
864 /// let tensor = tensor.min_dims(&[0, 1]);
865 /// println!("{tensor}");
866 /// // [[-2.0]]
867 /// }
868 /// ```
869 pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
870 dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
871 }
872
873 /// Find the minimum value along the given dimension.
874 ///
875 /// Also returns the indices.
876 ///
877 /// # Example
878 ///
879 /// ```rust
880 /// use burn_tensor::backend::Backend;
881 /// use burn_tensor::{Tensor, Shape};
882 ///
883 /// fn example<B: Backend>() {
884 /// let device = B::Device::default();
885 /// let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
886 /// let (tensor, index) = tensor.min_dim_with_indices(0);
887 /// println!("{tensor}");
888 /// // [[5.0, -2.0, 3.0]]
889 /// println!("{}", index);
890 /// // [[1, 0, 0]]
891 /// }
892 /// ```
893 pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
894 let dim = dim.expect_dim_index(D);
895 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
896
897 let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
898
899 let tensor = Tensor::new(tensor);
900 let index = Tensor::new(index);
901
902 (tensor, index)
903 }
904
    /// Finds the minimum pair wise values with another tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - Other tensor to find minimum elements with
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors containing the minimum value found
    /// between each element of the two source tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.min_pair(tensor2);
    ///     println!("{tensor}");
    ///     // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
    /// }
    /// ```
    pub fn min_pair(self, other: Self) -> Self {
        // Wherever `other` is strictly smaller, take it; elsewhere keep `self`.
        let mask = other.clone().lower(self.clone());
        self.mask_where(mask, other)
    }
934
935 /// Clamp element wise between the given min and max values.
936 ///
937 /// # Arguments
938 ///
939 /// * `min` - The minimum value.
940 /// * `max` - The maximum value.
941 ///
942 /// # Returns
943 ///
944 /// A new tensor with the values clamped between the given min and max values.
945 ///
946 /// # Example
947 ///
948 /// ```rust
949 /// use burn_tensor::backend::Backend;
950 /// use burn_tensor::{Int, Tensor};
951 ///
952 /// fn example<B: Backend>() {
953 /// let device = Default::default();
954 /// let tensor = Tensor::<B, 2, Int>::from_ints(
955 /// [
956 /// [1, 2, 3],
957 /// [4, 5, 6],
958 /// [7, 8, 9]
959 /// ],
960 /// &device);
961 /// let tensor = tensor.clamp(2, 6);
962 /// println!("{tensor}");
963 /// // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
964 /// }
965 /// ```
966 pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
967 let dtype = self.dtype();
968 Self::new(K::clamp(
969 self.primitive,
970 Scalar::new(min, &dtype),
971 Scalar::new(max, &dtype),
972 ))
973 }
974
975 /// Clamp element wise under a minimum value.
976 ///
977 /// # Arguments
978 ///
979 /// * `tensor` - The tensor to clamp.
980 /// * `min` - The minimum value.
981 ///
982 /// # Returns
983 ///
984 /// A new tensor with the values clamped under the given min value.
985 ///
986 /// # Example
987 ///
988 /// ```rust
989 /// use burn_tensor::backend::Backend;
990 /// use burn_tensor::{Int, Tensor};
991 ///
992 /// fn example<B: Backend>() {
993 /// let device = Default::default();
994 /// let tensor = Tensor::<B, 2, Int>::from_ints(
995 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
996 /// &device);
997 /// let tensor = tensor.clamp_min(4);
998 /// println!("{tensor}");
999 /// // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1000 /// }
1001 /// ```
1002 pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1003 let min = Scalar::new(min, &self.dtype());
1004 Self::new(K::clamp_min(self.primitive, min))
1005 }
1006
1007 /// Clamp element wise over a maximum value.
1008 ///
1009 /// # Arguments
1010 ///
1011 /// * `tensor` - The tensor to clamp.
1012 /// * `max` - The maximum value.
1013 ///
1014 /// # Returns
1015 ///
1016 /// A new tensor with the values clamped over the given max value.
1017 ///
1018 /// # Example
1019 ///
1020 /// ```rust
1021 /// use burn_tensor::backend::Backend;
1022 /// use burn_tensor::{Int, Tensor};
1023 ///
1024 /// fn example<B: Backend>() {
1025 /// let device = Default::default();
1026 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1027 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1028 /// &device);
1029 /// let tensor = tensor.clamp_max(5);
1030 /// println!("{tensor}");
1031 /// // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1032 /// }
1033 /// ```
1034 pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1035 let max = Scalar::new(max, &self.dtype());
1036 Self::new(K::clamp_max(self.primitive, max))
1037 }
1038
1039 /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
1040 ///
1041 /// # Arguments
1042 ///
1043 /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
1044 ///
1045 /// # Example
1046 ///
1047 /// ```rust
1048 /// use burn_tensor::backend::Backend;
1049 /// use burn_tensor::{Tensor, Shape};
1050 ///
1051 /// fn example<B: Backend>() {
1052 /// let device = B::Device::default();
1053 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
1054 /// let result = tensor.clone().cummin(0);
1055 /// println!("{result}");
1056 /// // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
1057 /// let result = tensor.cummin(1);
1058 /// println!("{result}");
1059 /// // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
1060 /// }
1061 /// ```
1062 pub fn cummin(self, dim: usize) -> Self {
1063 check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
1064 Self::new(K::cummin(self.primitive, dim))
1065 }
1066
1067 /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
1068 ///
1069 /// # Arguments
1070 ///
1071 /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
1072 ///
1073 /// # Example
1074 ///
1075 /// ```rust
1076 /// use burn_tensor::backend::Backend;
1077 /// use burn_tensor::{Tensor, Shape};
1078 ///
1079 /// fn example<B: Backend>() {
1080 /// let device = B::Device::default();
1081 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
1082 /// let result = tensor.clone().cummax(0);
1083 /// println!("{result}");
1084 /// // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
1085 /// let result = tensor.cummax(1);
1086 /// println!("{result}");
1087 /// // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
1088 /// }
1089 /// ```
1090 pub fn cummax(self, dim: usize) -> Self {
1091 check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
1092 Self::new(K::cummax(self.primitive, dim))
1093 }
1094 /// Find the maximum value along the given dimension.
1095 ///
1096 /// # Arguments
1097 ///
1098 /// * `dim` - The dimension or axis along which to aggregate the elements;
1099 /// supports negative indexing.
1100 ///
1101 /// # Returns
1102 ///
1103 /// The returned tensor will have the same rank,
1104 /// but the aggregated dimension will have size 1.
1105 ///
1106 /// # Example
1107 ///
1108 /// ```rust
1109 /// use burn_tensor::backend::Backend;
1110 /// use burn_tensor::{Tensor, Shape};
1111 ///
1112 /// fn example<B: Backend>() {
1113 /// let device = B::Device::default();
1114 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1115 /// let tensor = tensor.max_dim(0);
1116 /// println!("{tensor}");
1117 /// // [[5.0, 9.0, 6.0]]
1118 /// }
1119 /// ```
1120 pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
1121 let dim = dim.expect_dim_index(D);
1122 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1123 Tensor::new(K::max_dim(self.primitive, dim))
1124 }
1125
1126 /// Find the maximum value along the given dimensions.
1127 ///
1128 /// # Arguments
1129 ///
1130 /// * `dims` - The dimensions or axis along which to aggregate the elements;
1131 /// supports negative indexing.
1132 ///
1133 /// # Returns
1134 ///
1135 /// The returned tensor will have the same rank,
1136 /// but the aggregated dimensions will have size 1.
1137 ///
1138 /// # Example
1139 ///
1140 /// ```rust
1141 /// use burn_tensor::backend::Backend;
1142 /// use burn_tensor::{Tensor, Shape};
1143 ///
1144 /// fn example<B: Backend>() {
1145 /// let device = B::Device::default();
1146 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1147 /// let tensor = tensor.max_dims(&[0, 1]);
1148 /// println!("{tensor}");
1149 /// // [[9.0]]
1150 /// }
1151 /// ```
1152 pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1153 dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
1154 }
1155}