burn_tensor/tensor/api/numeric.rs
1pub use burn_backend::tensor::Numeric;
2
3use crate::alloc::borrow::ToOwned;
4use alloc::vec::Vec;
5
6use crate::IndexingUpdateOp;
7use crate::{
8 AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend,
9 check, check::TensorCheck,
10};
11
12impl<B, const D: usize, K> Tensor<B, D, K>
13where
14 B: Backend,
15 K: Numeric<B>,
16 K::Elem: Element,
17{
18 /// Applies element wise addition operation.
19 ///
    /// `y = x1 + x2`
21 ///
22 /// # Arguments
23 ///
24 /// * `other` - The tensor to add.
25 ///
26 /// # Example
27 ///
28 /// ```rust
29 /// use burn_tensor::backend::Backend;
30 /// use burn_tensor::{Tensor, Shape};
31 ///
32 /// fn example<B: Backend>() {
33 /// let device = B::Device::default();
34 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
35 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
36 /// let tensor = tensor1 + tensor2;
37 /// println!("{tensor}");
38 /// // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
39 /// }
40 /// ```
41 #[allow(clippy::should_implement_trait)]
42 pub fn add(self, other: Self) -> Self {
43 check!(TensorCheck::binary_ops_ew("Add", &self, &other));
44 Self::new(K::add(self.primitive, other.primitive))
45 }
46
47 /// Applies element wise addition operation with a scalar.
48 ///
49 /// `y = x + s`
50 ///
51 /// # Arguments
52 ///
53 /// * `other` - The scalar to add, element wise.
54 ///
55 /// # Example
56 ///
57 /// ```rust
58 /// use burn_tensor::backend::Backend;
59 /// use burn_tensor::{Tensor, Shape};
60 ///
61 /// fn example<B: Backend>() {
62 /// let device = B::Device::default();
63 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
64 /// let scalar = 2.0;
65 /// let tensor = tensor + scalar;
66 /// println!("{tensor}");
67 /// // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
68 /// }
69 /// ```
70 pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
71 Self::new(K::add_scalar::<E>(self.primitive, other))
72 }
73
74 /// Applies element wise subtraction operation.
75 ///
    /// `y = x1 - x2`
77 ///
78 /// # Arguments
79 ///
80 /// * `other` - The tensor to subtract.
81 ///
82 /// # Example
83 ///
84 /// ```rust
85 /// use burn_tensor::backend::Backend;
86 /// use burn_tensor::{Tensor, Shape};
87 ///
88 /// fn example<B: Backend>() {
89 /// let device = B::Device::default();
90 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
91 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
92 /// let tensor = tensor1 - tensor2;
93 /// println!("{tensor}");
94 /// // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
95 /// }
96 /// ```
97 #[allow(clippy::should_implement_trait)]
98 pub fn sub(self, other: Self) -> Self {
99 check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
100 Self::new(K::sub(self.primitive, other.primitive))
101 }
102
103 /// Applies element wise subtraction operation with a scalar.
104 ///
105 /// `y = x - s`
106 ///
107 /// # Arguments
108 ///
109 /// * `other` - The scalar to subtract, element wise.
110 ///
111 /// # Example
112 ///
113 /// ```rust
114 /// use burn_tensor::backend::Backend;
115 /// use burn_tensor::{Tensor, Shape};
116 ///
117 /// fn example<B: Backend>() {
118 /// let device = B::Device::default();
119 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
120 /// let scalar = 2.0;
121 /// let tensor = tensor - scalar;
122 /// println!("{tensor}");
123 /// // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
124 /// }
125 /// ```
126 pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
127 Self::new(K::sub_scalar::<E>(self.primitive, other))
128 }
129
130 /// Applies element wise division operation.
131 ///
    /// `y = x1 / x2`
133 ///
134 /// # Arguments
135 ///
    /// * `other` - The tensor to divide by.
137 ///
138 /// # Example
139 ///
140 /// ```rust
141 /// use burn_tensor::backend::Backend;
142 /// use burn_tensor::{Tensor, Shape};
143 ///
144 /// fn example<B: Backend>() {
145 /// let device = B::Device::default();
146 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
147 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
148 /// let tensor = tensor1 / tensor2;
149 /// println!("{tensor}");
150 /// // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
151 /// }
152 /// ```
153 #[allow(clippy::should_implement_trait)]
154 pub fn div(self, other: Self) -> Self {
155 check!(TensorCheck::binary_ops_ew("Div", &self, &other));
156 Self::new(K::div(self.primitive, other.primitive))
157 }
158
159 /// Applies element wise division operation with a scalar.
160 ///
161 /// `y = x / s`
162 ///
163 /// # Arguments
164 ///
    /// * `other` - The scalar to divide by, element wise.
166 ///
167 /// # Example
168 ///
169 /// ```rust
170 /// use burn_tensor::backend::Backend;
171 /// use burn_tensor::{Tensor, Shape};
172 ///
173 /// fn example<B: Backend>() {
174 /// let device = B::Device::default();
175 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
176 /// let scalar = 2.0;
177 /// let tensor = tensor / scalar;
178 /// println!("{tensor}");
179 /// // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
180 /// }
181 /// ```
182 pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
183 Self::new(K::div_scalar::<E>(self.primitive, other))
184 }
185
    /// Applies element wise remainder operation with another tensor.
    ///
    /// `y = x1 % x2`
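    ///
    /// # Example
    ///
    /// A minimal sketch using non-negative operands, so the expected output does not
    /// depend on the backend's sign convention for remainders:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[6.0, 7.0, 8.0], [5.0, 9.0, 10.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[4.0, 2.0, 3.0], [2.0, 4.0, 6.0]], &device);
    ///     let tensor = tensor1.remainder(tensor2);
    ///     println!("{tensor}");
    ///     // [[2.0, 1.0, 2.0], [1.0, 1.0, 4.0]]
    /// }
    /// ```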
189 pub fn remainder(self, other: Self) -> Self {
190 Self::new(K::remainder(self.primitive, other.primitive))
191 }
192
    /// Applies element wise remainder operation with a scalar.
194 ///
195 /// `y = x % s`
196 ///
197 /// # Arguments
198 ///
    /// * `other` - The scalar to compute the remainder with, element wise.
200 ///
201 /// # Example
202 ///
203 /// ```rust
204 /// use burn_tensor::backend::Backend;
205 /// use burn_tensor::{Tensor, Shape};
206 ///
207 /// fn example<B: Backend>() {
208 /// let device = B::Device::default();
209 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
210 /// let scalar = 2.0;
211 /// let tensor = tensor1 % scalar;
212 /// println!("{tensor}");
213 /// // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
214 /// }
215 /// ```
216 pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
217 Self::new(K::remainder_scalar::<E>(self.primitive, other))
218 }
219
220 /// Applies element wise multiplication operation.
221 ///
    /// `y = x1 * x2`
223 ///
224 /// # Arguments
225 ///
226 /// * `other` - The tensor to multiply.
227 ///
228 /// # Example
229 ///
230 /// ```rust
231 /// use burn_tensor::backend::Backend;
232 /// use burn_tensor::{Tensor, Shape};
233 ///
234 /// fn example<B: Backend>() {
235 /// let device = B::Device::default();
236 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
237 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
238 /// let tensor = tensor1 * tensor2;
239 /// println!("{tensor}");
240 /// // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
241 /// }
242 /// ```
243 #[allow(clippy::should_implement_trait)]
244 pub fn mul(self, other: Self) -> Self {
245 check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
246 Self::new(K::mul(self.primitive, other.primitive))
247 }
248
249 /// Applies element wise multiplication operation with a scalar.
250 ///
251 /// `y = x * s`
252 ///
253 /// # Arguments
254 ///
255 /// * `other` - The scalar to multiply, element wise.
256 ///
257 /// # Example
258 ///
259 /// ```rust
260 /// use burn_tensor::backend::Backend;
261 /// use burn_tensor::{Tensor, Shape};
262 ///
263 /// fn example<B: Backend>() {
264 /// let device = B::Device::default();
265 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
266 /// let scalar = 2.0;
267 /// let tensor = tensor * scalar;
268 /// println!("{tensor}");
269 /// // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
270 /// }
271 /// ```
272 pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
273 Self::new(K::mul_scalar::<E>(self.primitive, other))
274 }
275
276 /// Switch sign of each element in the tensor.
277 ///
278 /// `y = -x`
279 ///
280 /// # Example
281 ///
282 /// ```rust
283 /// use burn_tensor::backend::Backend;
284 /// use burn_tensor::{Tensor, Shape};
285 ///
286 /// fn example<B: Backend>() {
287 /// let device = B::Device::default();
288 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
289 /// let tensor = -tensor;
290 /// println!("{tensor}");
291 /// // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
292 /// }
293 /// ```
294 #[allow(clippy::should_implement_trait)]
295 pub fn neg(self) -> Self {
296 Self::new(K::neg(self.primitive))
297 }
298
299 /// Returns the signs of the elements of the input tensor.
300 ///
301 /// # Example
302 ///
303 /// ```rust
304 /// use burn_tensor::backend::Backend;
305 /// use burn_tensor::{Tensor, Shape};
306 ///
307 /// fn example<B: Backend>() {
308 /// let device = B::Device::default();
309 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
310 /// let tensor = tensor.sign();
311 /// println!("{tensor}");
312 /// // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
313 /// }
314 /// ```
315 pub fn sign(self) -> Self {
316 Self::new(K::sign(self.primitive))
317 }
318
319 /// Aggregate all elements in the tensor with the mean operation.
320 ///
321 /// # Example
322 ///
323 /// ```rust
324 /// use burn_tensor::backend::Backend;
325 /// use burn_tensor::{Tensor, Shape};
326 ///
327 /// fn example<B: Backend>() {
328 /// let device = B::Device::default();
329 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
330 /// let tensor = tensor.mean();
331 /// println!("{tensor}");
332 /// // [3.6666667]
333 /// }
334 /// ```
335 pub fn mean(self) -> Tensor<B, 1, K> {
336 Tensor::new(K::mean(self.primitive))
337 }
338
339 /// Aggregate all elements in the tensor with the sum operation.
340 ///
341 /// # Example
342 ///
343 /// ```rust
344 /// use burn_tensor::backend::Backend;
345 /// use burn_tensor::{Tensor, Shape};
346 ///
347 /// fn example<B: Backend>() {
348 /// let device = B::Device::default();
349 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
350 /// let tensor = tensor.sum();
351 /// println!("{tensor}");
352 /// // [22.0]
353 /// }
354 /// ```
355 pub fn sum(self) -> Tensor<B, 1, K> {
356 Tensor::new(K::sum(self.primitive))
357 }
358
359 /// Aggregate all elements along the given *dimension* or *axis*
360 /// in the tensor with the mean operation.
361 ///
362 /// # Arguments
363 ///
364 /// * `dim` - The dimension or axis along which to aggregate the elements;
365 /// supports negative indexing.
366 ///
367 /// # Example
368 ///
369 /// ```rust
370 /// use burn_tensor::backend::Backend;
371 /// use burn_tensor::{Tensor, Shape};
372 ///
373 /// fn example<B: Backend>() {
374 /// let device = B::Device::default();
375 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor.clone().mean_dim(0);
    ///     println!("{result}");
    ///     // [[3.0, 3.5, 4.5]]
    ///     let result = tensor.mean_dim(1);
    ///     println!("{result}");
    ///     // [[0.6666667], [6.6666665]]
382 /// }
383 /// ```
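    ///
    /// Negative indices count back from the last dimension. As an illustrative sketch
    /// (assuming an integer index type that implements `AsIndex`), `mean_dim(-1)` on this
    /// rank-2 tensor aggregates the same axis as `mean_dim(1)`:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.mean_dim(-1);
    ///     println!("{tensor}");
    ///     // [[0.6666667], [6.6666665]]
    /// }
    /// ```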
384 pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
385 let dim = dim.expect_dim_index(D);
386 check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
387 Self::new(K::mean_dim(self.primitive, dim))
388 }
389
390 /// Aggregate all elements along the given *axes*
391 /// in the tensor with the mean operation.
392 ///
393 /// # Arguments
394 ///
395 /// * `dims` - the dimensions to aggregate; supports negative indexing.
396 ///
397 /// # Returns
398 ///
399 /// The returned tensor will have the same rank,
400 /// but the aggregated dimensions will have size 1.
401 ///
402 /// # Example
403 ///
404 /// ```rust
405 /// use burn_tensor::backend::Backend;
406 /// use burn_tensor::{Tensor, Shape};
407 ///
408 /// fn example<B: Backend>() {
409 /// let device = B::Device::default();
410 /// let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
411 /// let tensor = tensor.clone().mean_dims(&[0, 1]);
412 /// println!("{tensor}");
413 /// // [[2.0]]
414 /// }
415 /// ```
416 pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
417 dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
418 }
419
420 /// Aggregate all elements along the given *dimension* or *axis*
421 /// in the tensor with the sum operation.
422 ///
423 /// # Arguments
424 ///
425 /// * `dim` - The dimension or axis along which to aggregate the elements;
426 /// supports negative indexing.
427 ///
428 /// # Example
429 ///
430 /// ```rust
431 /// use burn_tensor::backend::Backend;
432 /// use burn_tensor::{Tensor, Shape};
433 ///
434 /// fn example<B: Backend>() {
435 /// let device = B::Device::default();
436 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor.clone().sum_dim(0);
    ///     println!("{result}");
    ///     // [[6.0, 7.0, 9.0]]
    ///     let result = tensor.sum_dim(1);
    ///     println!("{result}");
    ///     // [[2.0], [20.0]]
443 /// }
444 /// ```
445 pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
446 let dim = dim.expect_dim_index(D);
447 check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
448 Self::new(K::sum_dim(self.primitive, dim))
449 }
450
451 /// Aggregate all elements along the given *axes*
452 /// in the tensor with the sum operation.
453 ///
454 /// # Arguments
455 ///
456 /// * `dims` - the dimensions to aggregate; supports negative indexing.
457 ///
458 /// # Returns
459 ///
460 /// The returned tensor will have the same rank,
461 /// but the aggregated dimensions will have size 1.
462 ///
463 /// # Example
464 ///
465 /// ```rust
466 /// use burn_tensor::backend::Backend;
467 /// use burn_tensor::{Tensor, Shape};
468 ///
469 /// fn example<B: Backend>() {
470 /// let device = B::Device::default();
471 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
472 /// let tensor = tensor.clone().sum_dims(&[0, 1]);
473 /// println!("{tensor}");
    ///     // [[22.0]]
475 /// }
476 /// ```
477 pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
478 dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
479 }
480
481 /// Aggregate and squeeze along the given dimensions.
482 ///
    /// This is equivalent to `tensor.sum_dims(dims).squeeze_dims(dims)`.
484 ///
485 /// # Arguments
486 ///
487 /// * `dims` - the dimensions to aggregate; supports negative indexing.
488 ///
489 /// # Returns
490 ///
    /// The returned tensor has rank `D2`, with the aggregated dimensions squeezed out.
493 ///
494 /// # Example
495 ///
496 /// ```rust
497 /// use burn_tensor::backend::Backend;
498 /// use burn_tensor::{Tensor, Shape};
499 ///
500 /// fn example<B: Backend>() {
501 /// let device = B::Device::default();
502 /// let tensor = Tensor::<B, 3>::from_data([
503 /// [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
504 /// [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
505 /// ], &device);
506 /// let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);
507 /// println!("{tensor}");
508 /// // [20.0, 16.0, 21.0]
509 /// }
510 /// ```
511 pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
512 // TODO: remove idims when squeeze_dims uses AsIndex.
513 let idims = dims
514 .iter()
515 .map(|&dim| (dim.expect_dim_index(D)) as isize)
516 .collect::<Vec<_>>();
517 self.sum_dims(dims).squeeze_dims::<D2>(&idims)
518 }
519
520 /// Aggregate all elements in the tensor with the product operation.
521 ///
522 /// # Example
523 ///
524 /// ```rust
525 /// use burn_tensor::backend::Backend;
526 /// use burn_tensor::{Tensor, Shape};
527 ///
528 /// fn example<B: Backend>() {
529 /// let device = B::Device::default();
530 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
531 /// let tensor = tensor.prod();
532 /// println!("{tensor}");
533 /// // [-1620.0]
534 /// }
535 /// ```
536 pub fn prod(self) -> Tensor<B, 1, K> {
537 Tensor::new(K::prod(self.primitive))
538 }
539
540 /// Aggregate all elements along the given *dimension* or *axis*
541 /// in the tensor with the product operation.
542 ///
543 /// # Arguments
544 ///
545 /// * `dim` - The dimension or axis along which to aggregate the elements,
546 /// supports negative indexing.
547 ///
548 /// # Returns
549 ///
550 /// The returned tensor will have the same rank,
551 /// but the aggregated dimension will have size 1.
552 ///
553 /// # Example
554 ///
555 /// ```rust
556 /// use burn_tensor::backend::Backend;
557 /// use burn_tensor::{Tensor, Shape};
558 ///
559 /// fn example<B: Backend>() {
560 /// let device = B::Device::default();
561 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor.clone().prod_dim(0);
    ///     println!("{result}");
    ///     // [[5.0, -18.0, 18.0]]
    ///     let result = tensor.prod_dim(1);
    ///     println!("{result}");
    ///     // [[-6.0], [270.0]]
568 /// }
569 /// ```
570 pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
571 let dim = dim.expect_dim_index(D);
572 check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
573 Self::new(K::prod_dim(self.primitive, dim))
574 }
575
576 /// Aggregate all elements along the given *axes*
577 /// in the tensor with the prod operation.
578 ///
579 /// # Arguments
580 ///
581 /// * `dims` - the dimensions to aggregate, supports negative indexing.
582 ///
583 /// # Returns
584 ///
585 /// The returned tensor will have the same rank,
586 /// but the aggregated dimensions will have size 1.
587 ///
588 /// # Example
589 ///
590 /// ```rust
591 /// use burn_tensor::backend::Backend;
592 /// use burn_tensor::{Tensor, Shape};
593 ///
594 /// fn example<B: Backend>() {
595 /// let device = B::Device::default();
596 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.clone().prod_dims(&[0, 1]);
598 /// println!("{tensor}");
599 /// // [[-1620.0]]
600 /// }
601 /// ```
602 pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
603 dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
604 }
605
606 /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
607 ///
608 /// # Arguments
609 ///
610 /// * `dim` - The dimension or axis along which to compute the cumulative sum.
611 ///
612 /// # Example
613 ///
614 /// ```rust
615 /// use burn_tensor::backend::Backend;
616 /// use burn_tensor::{Tensor, Shape};
617 ///
618 /// fn example<B: Backend>() {
619 /// let device = B::Device::default();
620 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
621 /// let result = tensor.clone().cumsum(0);
622 /// println!("{result}");
623 /// // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
624 /// let result = tensor.cumsum(1);
625 /// println!("{result}");
626 /// // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
627 /// }
628 /// ```
629 pub fn cumsum(self, dim: usize) -> Self {
630 check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
631 Self::new(K::cumsum(self.primitive, dim))
632 }
633
634 /// Computes the cumulative product of elements along the given *dimension* or *axis*.
635 ///
636 /// # Arguments
637 ///
638 /// * `dim` - The dimension or axis along which to compute the cumulative product.
639 ///
640 /// # Example
641 ///
642 /// ```rust
643 /// use burn_tensor::backend::Backend;
644 /// use burn_tensor::{Tensor, Shape};
645 ///
646 /// fn example<B: Backend>() {
647 /// let device = B::Device::default();
648 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
649 /// let result = tensor.clone().cumprod(0);
650 /// println!("{result}");
651 /// // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
652 /// let result = tensor.cumprod(1);
653 /// println!("{result}");
654 /// // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
655 /// }
656 /// ```
657 pub fn cumprod(self, dim: usize) -> Self {
658 check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
659 Self::new(K::cumprod(self.primitive, dim))
660 }
661
662 /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
663 ///
664 /// # Arguments
665 ///
666 /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
667 ///
668 /// # Example
669 ///
670 /// ```rust
671 /// use burn_tensor::backend::Backend;
672 /// use burn_tensor::{Tensor, Shape};
673 ///
674 /// fn example<B: Backend>() {
675 /// let device = B::Device::default();
676 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
677 /// let result = tensor.clone().cummin(0);
678 /// println!("{result}");
679 /// // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
680 /// let result = tensor.cummin(1);
681 /// println!("{result}");
682 /// // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
683 /// }
684 /// ```
685 pub fn cummin(self, dim: usize) -> Self {
686 check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
687 Self::new(K::cummin(self.primitive, dim))
688 }
689
690 /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
691 ///
692 /// # Arguments
693 ///
694 /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
695 ///
696 /// # Example
697 ///
698 /// ```rust
699 /// use burn_tensor::backend::Backend;
700 /// use burn_tensor::{Tensor, Shape};
701 ///
702 /// fn example<B: Backend>() {
703 /// let device = B::Device::default();
704 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
705 /// let result = tensor.clone().cummax(0);
706 /// println!("{result}");
707 /// // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
708 /// let result = tensor.cummax(1);
709 /// println!("{result}");
710 /// // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
711 /// }
712 /// ```
713 pub fn cummax(self, dim: usize) -> Self {
714 check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
715 Self::new(K::cummax(self.primitive, dim))
    }

    /// Applies element wise greater comparison and returns a boolean tensor.
718 ///
719 /// # Panics
720 ///
721 /// If the two tensors don't have the same shape.
722 ///
723 /// # Example
724 ///
725 /// ```rust
726 /// use burn_tensor::backend::Backend;
727 /// use burn_tensor::{Tensor, Shape};
728 ///
729 /// fn example<B: Backend>() {
730 /// let device = B::Device::default();
731 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
732 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
733 /// let tensor = tensor1.greater(tensor2);
734 /// println!("{tensor}");
735 /// // [[false, false, false], [true, true, true]]
736 /// }
737 /// ```
738 pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
739 check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
740 Tensor::new(K::greater(self.primitive, other.primitive))
741 }
742
743 /// Applies element wise greater-equal comparison and returns a boolean tensor.
744 ///
745 /// # Panics
746 ///
747 /// If the two tensors don't have the same shape.
748 ///
749 /// # Example
750 ///
751 /// ```rust
752 /// use burn_tensor::backend::Backend;
753 /// use burn_tensor::{Tensor, Shape};
754 ///
755 /// fn example<B: Backend>() {
756 /// let device = B::Device::default();
757 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
758 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
759 /// let tensor = tensor1.greater_equal(tensor2);
760 /// println!("{tensor}");
761 /// // [[true, false, false], [true, true, true]]
762 /// }
763 /// ```
764 pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
765 check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
766 Tensor::new(K::greater_equal(self.primitive, other.primitive))
767 }
768
769 /// Applies element wise lower comparison and returns a boolean tensor.
770 ///
771 /// # Panics
772 ///
773 /// If the two tensors don't have the same shape.
774 ///
775 /// # Example
776 ///
777 /// ```rust
778 /// use burn_tensor::backend::Backend;
779 /// use burn_tensor::{Tensor, Shape};
780 ///
781 /// fn example<B: Backend>() {
782 /// let device = B::Device::default();
783 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
784 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
785 /// let tensor = tensor1.lower(tensor2);
786 /// println!("{tensor}");
787 /// // [[false, true, true], [false, false, false]]
788 /// }
789 /// ```
790 pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
791 check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
792 Tensor::new(K::lower(self.primitive, other.primitive))
793 }
794
795 /// Applies element wise lower-equal comparison and returns a boolean tensor.
796 ///
797 /// # Panics
798 ///
799 /// If the two tensors don't have the same shape.
800 ///
801 /// # Example
802 ///
803 /// ```rust
804 /// use burn_tensor::backend::Backend;
805 /// use burn_tensor::{Tensor, Shape};
806 ///
807 /// fn example<B: Backend>() {
808 /// let device = B::Device::default();
809 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
810 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
811 /// let tensor = tensor1.lower_equal(tensor2);
812 /// println!("{tensor}");
813 /// // [[true, true, true], [false, false, false]]
814 /// }
815 /// ```
816 pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
817 check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
818 Tensor::new(K::lower_equal(self.primitive, other.primitive))
819 }
820
821 /// Applies greater than `other` comparison and returns a boolean tensor.
822 ///
823 /// # Arguments
824 ///
825 /// * `other` - The element to compare.
826 ///
827 /// # Example
828 ///
829 /// ```rust
830 /// use burn_tensor::backend::Backend;
831 /// use burn_tensor::{Tensor, Shape};
832 ///
833 /// fn example<B: Backend>() {
834 /// let device = B::Device::default();
835 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
836 /// let tensor = tensor.greater_elem(3.0);
837 /// println!("{tensor}");
    ///     // [[false, false, false], [true, true, true]]
839 /// }
840 /// ```
841 pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
842 Tensor::new(K::greater_elem(self.primitive, other.elem()))
843 }
844
845 /// Applies greater-equal than `other` comparison and returns a boolean tensor.
846 ///
847 /// # Arguments
848 ///
849 /// * `other` - The element to compare.
850 ///
851 /// # Example
852 ///
853 /// ```rust
854 /// use burn_tensor::backend::Backend;
855 /// use burn_tensor::{Tensor, Shape};
856 ///
857 /// fn example<B: Backend>() {
858 /// let device = B::Device::default();
859 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
860 /// let tensor = tensor.greater_equal_elem(3.0);
861 /// println!("{tensor}");
862 /// // [[false, false, true], [true, true, true]]
863 /// }
864 /// ```
865 pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
866 Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
867 }
868
869 /// Applies lower than `other` comparison and returns a boolean tensor.
870 ///
871 /// # Arguments
872 ///
873 /// * `other` - The element to compare.
874 ///
875 /// # Example
876 ///
877 /// ```rust
878 /// use burn_tensor::backend::Backend;
879 /// use burn_tensor::{Tensor, Shape};
880 ///
881 /// fn example<B: Backend>() {
882 /// let device = B::Device::default();
883 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
884 /// let tensor = tensor.lower_elem(3.0);
885 /// println!("{tensor}");
886 /// // [[true, true, false], [false, false, false]]
887 /// }
888 /// ```
889 pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
890 Tensor::new(K::lower_elem(self.primitive, other.elem()))
891 }
892
893 /// Applies lower-equal than `other` comparison and returns a boolean tensor.
894 ///
895 /// # Arguments
896 ///
897 /// * `other` - The element to compare.
898 ///
899 /// # Example
900 ///
901 /// ```rust
902 /// use burn_tensor::backend::Backend;
903 /// use burn_tensor::{Tensor, Shape};
904 ///
905 /// fn example<B: Backend>() {
906 /// let device = B::Device::default();
907 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
908 /// let tensor = tensor.lower_equal_elem(3.0);
909 /// println!("{tensor}");
910 /// // [[true, true, true], [false, false, false]]
911 /// }
912 /// ```
913 pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
914 Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
915 }
916
917 /// Applies the argmax function along the given dimension and returns an integer tensor.
918 ///
919 /// # Example
920 ///
921 /// ```rust
922 /// use burn_tensor::backend::Backend;
923 /// use burn_tensor::{Tensor, Shape};
924 ///
925 /// fn example<B: Backend>() {
926 /// let device = B::Device::default();
927 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
928 /// let tensor = tensor.argmax(1);
929 /// println!("{:?}", tensor.shape());
930 /// // Shape { dims: [2, 1, 3] }
931 /// }
932 /// ```
933 pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
934 Tensor::new(K::argmax(self.primitive, dim))
935 }
936
937 /// Find the maximum value.
938 ///
939 /// # Example
940 ///
941 /// ```rust
942 /// use burn_tensor::backend::Backend;
943 /// use burn_tensor::{Tensor, Shape};
944 ///
945 /// fn example<B: Backend>() {
946 /// let device = B::Device::default();
947 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
948 /// let tensor = tensor.max();
949 /// println!("{tensor}");
950 /// // [9.0]
951 /// }
952 /// ```
953 pub fn max(self) -> Tensor<B, 1, K> {
954 Tensor::new(K::max(self.primitive))
955 }
956
957 /// Find the maximum value along the given dimension.
958 ///
959 /// # Arguments
960 ///
961 /// * `dim` - The dimension or axis along which to aggregate the elements;
962 /// supports negative indexing.
963 ///
964 /// # Returns
965 ///
966 /// The returned tensor will have the same rank,
967 /// but the aggregated dimension will have size 1.
968 ///
969 /// # Example
970 ///
971 /// ```rust
972 /// use burn_tensor::backend::Backend;
973 /// use burn_tensor::{Tensor, Shape};
974 ///
975 /// fn example<B: Backend>() {
976 /// let device = B::Device::default();
977 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
978 /// let tensor = tensor.max_dim(0);
979 /// println!("{tensor}");
980 /// // [[5.0, 9.0, 6.0]]
981 /// }
982 /// ```
983 pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
984 let dim = dim.expect_dim_index(D);
985 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
986 Tensor::new(K::max_dim(self.primitive, dim))
987 }
988
989 /// Find the maximum value along the given dimensions.
990 ///
991 /// # Arguments
992 ///
    /// * `dims` - The dimensions or axes along which to aggregate the elements;
994 /// supports negative indexing.
995 ///
996 /// # Returns
997 ///
998 /// The returned tensor will have the same rank,
999 /// but the aggregated dimensions will have size 1.
1000 ///
1001 /// # Example
1002 ///
1003 /// ```rust
1004 /// use burn_tensor::backend::Backend;
1005 /// use burn_tensor::{Tensor, Shape};
1006 ///
1007 /// fn example<B: Backend>() {
1008 /// let device = B::Device::default();
1009 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1010 /// let tensor = tensor.max_dims(&[0, 1]);
1011 /// println!("{tensor}");
1012 /// // [[9.0]]
1013 /// }
1014 /// ```
1015 pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1016 dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
1017 }
1018
1019 /// Find the maximum value along the given dimension.
1020 ///
1021 /// Also returns the indices.
1022 ///
1023 /// # Example
1024 ///
1025 /// ```rust
1026 /// use burn_tensor::backend::Backend;
1027 /// use burn_tensor::{Tensor, Shape};
1028 ///
1029 /// fn example<B: Backend>() {
1030 /// let device = B::Device::default();
1031 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1032 /// let (tensor, index) = tensor.max_dim_with_indices(0);
    ///     println!("{tensor}");
    ///     // [[5.0, 9.0, 6.0]]
    ///     println!("{index}");
    ///     // [[1, 1, 1]]
1037 /// }
1038 /// ```
1039 pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
1040 let dim = dim.expect_dim_index(D);
1041 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1042
1043 let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
1044
1045 let tensor = Tensor::new(tensor);
1046 let index = Tensor::new(index);
1047
1048 (tensor, index)
1049 }
1050
1051 /// Finds the maximum pair wise values with another tensor.
1052 ///
1053 /// # Arguments
1054 ///
1055 /// * `other` - Other tensor to find maximum elements with
1056 ///
1057 /// # Returns
1058 ///
1059 /// A tensor with the same shape as the input tensors containing the maximum value found
1060 /// in the input tensors.
1061 ///
1062 /// # Example
1063 ///
1064 /// ```rust
1065 /// use burn_tensor::backend::Backend;
1066 /// use burn_tensor::{Tensor, Shape};
1067 ///
1068 /// fn example<B: Backend>() {
1069 /// let device = B::Device::default();
1070 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1071 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1072 /// let tensor = tensor1.max_pair(tensor2);
1073 /// println!("{tensor}");
1074 /// // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
1075 /// }
1076 /// ```
1077 pub fn max_pair(self, other: Self) -> Self {
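        // Where `self` is lower than `other`, take the element from `other`; otherwise keep `self`.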
1078 let mask = self.clone().lower(other.clone());
1079 self.mask_where(mask, other)
1080 }
1081
1082 /// Find the maximum absolute value.
1083 ///
1084 /// # Example
1085 ///
1086 /// ```rust
1087 /// use burn_tensor::backend::Backend;
1088 /// use burn_tensor::{Tensor, Shape};
1089 ///
1090 /// fn example<B: Backend>() {
1091 /// let device = B::Device::default();
1092 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
1093 /// let tensor = tensor.max_abs();
1094 /// println!("{tensor}");
1095 /// // [7.0]
1096 /// }
1097 /// ```
1098 pub fn max_abs(self) -> Tensor<B, 1, K> {
1099 Tensor::new(K::max_abs(self.primitive))
1100 }
1101
1102 /// Find the maximum absolute value along the given dimension.
1103 ///
1104 /// # Arguments
1105 ///
1106 /// * `dim` - The dimension or axis along which to aggregate the elements,
1107 /// supports negative indexing.
1108 ///
1109 /// # Returns
1110 ///
1111 /// The returned tensor will have the same rank,
1112 /// but the aggregated dimension will have size 1.
1113 ///
1114 /// # Example
1115 ///
1116 /// ```rust
1117 /// use burn_tensor::backend::Backend;
1118 /// use burn_tensor::{Tensor, Shape};
1119 ///
1120 /// fn example<B: Backend>() {
1121 /// let device = B::Device::default();
1122 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.max_abs_dim(0);
1124 /// println!("{tensor}");
1125 /// // [[5.0, 9.0, 6.0]]
1126 /// }
1127 /// ```
1128 pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
1129 let dim = dim.expect_dim_index(D);
1130 check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
1131
1132 Tensor::new(K::max_abs_dim(self.primitive, dim))
1133 }
1134
1135 /// Find the maximum absolute value along the given dimensions.
1136 ///
1137 /// # Arguments
1138 ///
1139 /// * `dims` - The dimensions or axes along which to aggregate the elements,
1140 /// supports negative indexing.
1141 ///
1142 /// # Returns
1143 ///
1144 /// The returned tensor will have the same rank,
1145 /// but the aggregated dimensions will have size 1.
1146 ///
1147 /// # Example
1148 ///
1149 /// ```rust
1150 /// use burn_tensor::backend::Backend;
1151 /// use burn_tensor::{Tensor, Shape};
1152 ///
1153 /// fn example<B: Backend>() {
1154 /// let device = B::Device::default();
1155 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1156 /// let tensor = tensor.max_abs_dims(&[0, 1]);
1157 /// println!("{tensor}");
1158 /// // [[9.0]]
1159 /// }
1160 /// ```
1161 pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1162 dims.iter()
1163 .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
1164 }
1165
1166 /// Applies the argmin function along the given dimension and returns an integer tensor.
1167 ///
1168 /// # Example
1169 ///
1170 /// ```rust
1171 /// use burn_tensor::backend::Backend;
1172 /// use burn_tensor::{Tensor, Shape};
1173 ///
1174 /// fn example<B: Backend>() {
1175 /// let device = Default::default();
1176 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1177 /// let tensor = tensor.argmin(1);
1178 /// println!("{:?}", tensor.shape());
1179 /// // Shape { dims: [2, 1, 3] }
1180 /// }
1181 /// ```
1182 pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
1183 Tensor::new(K::argmin(self.primitive, dim))
1184 }
1185
1186 /// Find the minimum value.
1187 ///
1188 /// # Example
1189 ///
1190 /// ```rust
1191 /// use burn_tensor::backend::Backend;
1192 /// use burn_tensor::{Tensor, Shape};
1193 ///
1194 /// fn example<B: Backend>() {
1195 /// let device = B::Device::default();
1196 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1197 /// let tensor = tensor.min();
1198 /// println!("{tensor}");
1199 /// // [-2.0]
1200 /// }
1201 /// ```
1202 pub fn min(self) -> Tensor<B, 1, K> {
1203 Tensor::new(K::min(self.primitive))
1204 }
1205
1206 /// Find the minimum value along the given dimension.
1207 ///
1208 /// # Arguments
1209 ///
1210 /// * `dim` - The dimension or axis along which to aggregate the elements;
1211 /// supports negative indexing.
1212 ///
1213 /// # Returns
1214 ///
1215 /// The returned tensor will have the same rank,
1216 /// but the aggregated dimension will have size 1.
1217 ///
1218 /// # Example
1219 ///
1220 /// ```rust
1221 /// use burn_tensor::backend::Backend;
1222 /// use burn_tensor::{Tensor, Shape};
1223 ///
1224 /// fn example<B: Backend>() {
1225 /// let device = B::Device::default();
1226 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1227 /// let tensor = tensor.min_dim(0);
1228 /// println!("{tensor}");
1229 /// // [[1.0, -2.0, 3.0]]
1230 /// }
1231 /// ```
1232 pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
1233 let dim = dim.expect_dim_index(D);
1234 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1235 Tensor::new(K::min_dim(self.primitive, dim))
1236 }
1237
1238 /// Find the minimum value along the given dimensions.
1239 ///
1240 /// # Arguments
1241 ///
1242 /// * `dims` - The dimensions or axes along which to aggregate the elements;
1243 /// supports negative indexing.
1244 ///
1245 /// # Returns
1246 ///
1247 /// The returned tensor will have the same rank,
1248 /// but the aggregated dimensions will have size 1.
1249 ///
1250 /// # Example
1251 ///
1252 /// ```rust
1253 /// use burn_tensor::backend::Backend;
1254 /// use burn_tensor::{Tensor, Shape};
1255 ///
1256 /// fn example<B: Backend>() {
1257 /// let device = B::Device::default();
1258 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1259 /// let tensor = tensor.min_dims(&[0, 1]);
1260 /// println!("{tensor}");
1261 /// // [[-2.0]]
1262 /// }
1263 /// ```
1264 pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1265 dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
1266 }
1267
1268 /// Find the minimum value along the given dimension.
1269 ///
1270 /// Also returns the indices.
1271 ///
1272 /// # Example
1273 ///
1274 /// ```rust
1275 /// use burn_tensor::backend::Backend;
1276 /// use burn_tensor::{Tensor, Shape};
1277 ///
1278 /// fn example<B: Backend>() {
1279 /// let device = B::Device::default();
1280 /// let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1281 /// let (tensor, index) = tensor.min_dim_with_indices(0);
1282 /// println!("{tensor}");
1283 /// // [[5.0, -2.0, 3.0]]
1284 /// println!("{}", index);
1285 /// // [[1, 0, 0]]
1286 /// }
1287 /// ```
1288 pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
1289 let dim = dim.expect_dim_index(D);
1290 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1291
1292 let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
1293
1294 let tensor = Tensor::new(tensor);
1295 let index = Tensor::new(index);
1296
1297 (tensor, index)
1298 }
1299
1300 /// Finds the minimum pair wise values with another tensor.
1301 ///
1302 /// # Arguments
1303 ///
1304 /// * `other` - Other tensor to find minimum elements with
1305 ///
1306 /// # Returns
1307 ///
1308 /// A tensor with the same shape as the input tensors containing the minimum value found
1309 /// between each element of the two source tensors.
1310 ///
1311 /// # Example
1312 ///
1313 /// ```rust
1314 /// use burn_tensor::backend::Backend;
1315 /// use burn_tensor::{Tensor, Shape};
1316 ///
1317 /// fn example<B: Backend>() {
1318 /// let device = B::Device::default();
1319 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1320 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1321 /// let tensor = tensor1.min_pair(tensor2);
1322 /// println!("{tensor}");
1323 /// // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
    /// }
    /// ```
1325 pub fn min_pair(self, other: Self) -> Self {
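        // Where `other` is lower than `self`, take the element from `other`; otherwise keep `self`.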
1326 let mask = other.clone().lower(self.clone());
1327 self.mask_where(mask, other)
1328 }
1329
1330 /// Clamp element wise between the given min and max values.
1331 ///
1332 /// # Arguments
1333 ///
1334 /// * `min` - The minimum value.
1335 /// * `max` - The maximum value.
1336 ///
1337 /// # Returns
1338 ///
1339 /// A new tensor with the values clamped between the given min and max values.
1340 ///
1341 /// # Example
1342 ///
1343 /// ```rust
1344 /// use burn_tensor::backend::Backend;
1345 /// use burn_tensor::{Int, Tensor};
1346 ///
1347 /// fn example<B: Backend>() {
1348 /// let device = Default::default();
1349 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1350 /// [
1351 /// [1, 2, 3],
1352 /// [4, 5, 6],
1353 /// [7, 8, 9]
1354 /// ],
1355 /// &device);
1356 /// let tensor = tensor.clamp(2, 6);
1357 /// println!("{tensor}");
1358 /// // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
1359 /// }
1360 /// ```
1361 pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
1362 Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
1363 }
1364
1365 /// Clamp element wise under a minimum value.
1366 ///
1367 /// # Arguments
1368 ///
    /// * `min` - The minimum value.
1371 ///
1372 /// # Returns
1373 ///
1374 /// A new tensor with the values clamped under the given min value.
1375 ///
1376 /// # Example
1377 ///
1378 /// ```rust
1379 /// use burn_tensor::backend::Backend;
1380 /// use burn_tensor::{Int, Tensor};
1381 ///
1382 /// fn example<B: Backend>() {
1383 /// let device = Default::default();
1384 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1385 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1386 /// &device);
1387 /// let tensor = tensor.clamp_min(4);
1388 /// println!("{tensor}");
1389 /// // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1390 /// }
1391 /// ```
1392 pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1393 Self::new(K::clamp_min(self.primitive, min.elem()))
1394 }
1395
1396 /// Clamp element wise over a maximum value.
1397 ///
1398 /// # Arguments
1399 ///
    /// * `max` - The maximum value.
1402 ///
1403 /// # Returns
1404 ///
1405 /// A new tensor with the values clamped over the given max value.
1406 ///
1407 /// # Example
1408 ///
1409 /// ```rust
1410 /// use burn_tensor::backend::Backend;
1411 /// use burn_tensor::{Int, Tensor};
1412 ///
1413 /// fn example<B: Backend>() {
1414 /// let device = Default::default();
1415 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1416 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1417 /// &device);
1418 /// let tensor = tensor.clamp_max(5);
1419 /// println!("{tensor}");
1420 /// // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1421 /// }
1422 /// ```
1423 pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1424 Self::new(K::clamp_max(self.primitive, max.elem()))
1425 }
1426
1427 /// Apply element wise absolute value operation.
1428 ///
1429 /// # Example
1430 ///
1431 /// ```rust
1432 /// use burn_tensor::backend::Backend;
1433 /// use burn_tensor::{Int, Tensor};
1434 ///
1435 /// fn example<B: Backend>() {
1436 /// let device = Default::default();
1437 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
1438 /// let tensor = tensor.abs();
1439 /// println!("{tensor}");
1440 /// // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1441 /// }
1442 /// ```
1443 pub fn abs(self) -> Self {
1444 Self::new(K::abs(self.primitive))
1445 }
1446
    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices;
    /// the other elements of the result tensor are set to 0.
1449 ///
1450 /// See also [`triu_mask`](Tensor::triu_mask).
1451 ///
1452 /// # Arguments
1453 ///
1454 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1455 /// towards the upper triangle.
1456 ///
1457 /// # Example
1458 /// ```rust
1459 /// use burn_tensor::backend::Backend;
1460 /// use burn_tensor::{Int, Tensor};
1461 ///
1462 /// fn example<B: Backend>() {
1463 /// let device = Default::default();
1464 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1465 /// [
1466 /// [1, 2, 3],
1467 /// [4, 5, 6],
1468 /// [7, 8, 9]
1469 /// ],
1470 /// &device
1471 /// );
1472 /// let tensor = tensor.triu(1);
1473 /// println!("{tensor}");
1474 /// // [
1475 /// // [0, 2, 3],
1476 /// // [0, 0, 6],
1477 /// // [0, 0, 0]
1478 /// // ]
1479 /// }
1480 /// ```
1481 pub fn triu(self, diagonal: i64) -> Self {
1482 check!(TensorCheck::tri::<{ D }>());
1483
1484 // last two dimensions
1485 let shape = &self.shape().dims[D - 2..].to_owned();
1486
1487 let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
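        // Fill the masked (non-upper-triangular) positions with zero.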
1488 self.mask_fill(mask, 0)
1489 }
1490
    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices;
    /// the other elements of the result tensor are set to 0.
1493 ///
1494 /// See also [`tril_mask`](Tensor::tril_mask).
1495 ///
1496 /// # Arguments
1497 ///
1498 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1499 /// towards the upper triangle.
1500 ///
1501 /// # Example
1502 /// ```rust
1503 /// use burn_tensor::backend::Backend;
1504 /// use burn_tensor::{Int, Tensor};
1505 ///
1506 /// fn example<B: Backend>() {
1507 /// let device = Default::default();
1508 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1509 /// [
1510 /// [1, 2, 3],
1511 /// [4, 5, 6],
1512 /// [7, 8, 9]
1513 /// ],
1514 /// &device
1515 /// );
1516 ///
1517 /// let tensor = tensor.tril(-1);
1518 /// println!("{tensor}");
1519 /// // [
1520 /// // [0, 0, 0],
1521 /// // [4, 0, 0],
1522 /// // [7, 8, 0]
1523 /// // ]
1524 /// }
1525 /// ```
1526 pub fn tril(self, diagonal: i64) -> Self {
1527 check!(TensorCheck::tri::<{ D }>());
1528
1529 // last two dimensions
1530 let shape = &self.shape().dims[D - 2..].to_owned();
1531 let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
1532
1533 self.mask_fill(mask, 0)
1534 }
1535
    /// Applies element wise power operation with a float tensor.
1537 ///
1538 /// # Arguments
1539 ///
1540 /// * `other` - The tensor to apply the power operation with.
1541 ///
1542 /// # Example
1543 ///
1544 /// ```rust
1545 /// use burn_tensor::backend::Backend;
1546 /// use burn_tensor::{Tensor, Shape};
1547 ///
1548 /// fn example<B: Backend>() {
1549 /// let device = B::Device::default();
1550 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1551 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1552 /// let tensor = tensor1.powf(tensor2);
1553 /// println!("{tensor}");
    ///     // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
1555 /// }
1556 /// ```
1557 pub fn powf(self, other: Self) -> Self {
1558 Self::new(K::powf(self.primitive, other.primitive))
1559 }
1560
    /// Applies element wise power operation with a float scalar.
1562 ///
1563 /// # Arguments
1564 ///
1565 /// * `other` - The scalar to apply the power operation with.
1566 ///
1567 /// # Example
1568 ///
1569 /// ```rust
1570 /// use burn_tensor::backend::Backend;
1571 /// use burn_tensor::{Tensor, Shape};
1572 ///
1573 /// fn example<B: Backend>() {
1574 /// let device = B::Device::default();
1575 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1576 /// let tensor = tensor.powf_scalar(2.0);
1577 /// println!("{tensor}");
1578 /// // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
1579 /// }
1580 /// ```
1581 pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
1582 Self::new(K::powf_scalar::<E>(self.primitive, other))
1583 }
1584
    /// Applies element wise power operation with an integer tensor.
1586 ///
1587 /// # Arguments
1588 ///
1589 /// * `other` - The tensor to apply the power operation with.
1590 ///
1591 /// # Example
1592 ///
1593 /// ```rust
1594 /// use burn_tensor::backend::Backend;
1595 /// use burn_tensor::{Tensor, Shape, Int};
1596 ///
1597 /// fn example<B: Backend>() {
1598 /// let device = B::Device::default();
1599 /// let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1600 /// let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
1601 /// let tensor = tensor1.powi(tensor2);
1602 /// println!("{tensor}");
1603 /// // [[1, -8, 81], [5, 81, 216]]
1604 /// }
1605 /// ```
1606 pub fn powi(self, other: Self) -> Self {
1607 Self::new(K::powi(self.primitive, other.primitive))
1608 }
1609
    /// Applies element wise power operation with an integer scalar.
1611 ///
1612 /// # Arguments
1613 ///
1614 /// * `other` - The scalar to apply the power operation with.
1615 ///
1616 /// # Example
1617 ///
1618 /// ```rust
1619 /// use burn_tensor::backend::Backend;
1620 /// use burn_tensor::{Tensor, Shape, Int};
1621 ///
1622 /// fn example<B: Backend>() {
1623 /// let device = B::Device::default();
1624 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1625 /// let tensor = tensor.powi_scalar(2);
    ///     println!("{tensor}");
    ///     // [[1, 4, 9], [25, 81, 36]]
    ///
1629 /// let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
1630 /// let tensor = tensor.powi_scalar(2);
1631 /// println!("{tensor}");
1632 /// // [[2.25, 4., 9.], [25., 81., 36.]]
1633 /// }
1634 /// ```
1635 pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
1636 Self::new(K::powi_scalar::<E>(self.primitive, other))
1637 }
1638
1639 /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
1640 ///
1641 /// # Returns
1642 ///
1643 /// A boolean tensor with the same shape as the input tensor.
1644 ///
1645 /// # Example
1646 ///
1647 /// ```rust
1648 /// use burn_tensor::backend::Backend;
1649 /// use burn_tensor::{Tensor, Shape};
1650 ///
1651 /// fn example<B: Backend>() {
1652 /// let device = B::Device::default();
1653 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
1654 /// let tensor = tensor.bool();
1655 /// println!("{tensor}");
1656 /// // [
1657 /// // [true, true, true],
1658 /// // [false, true, true]
1659 /// // ]
    /// }
    /// ```
1661 pub fn bool(self) -> Tensor<B, D, Bool> {
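        // An element maps to `true` iff it is not equal to zero.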
1662 Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
1663 }
1664
1665 /// Create a random tensor of the given shape on the given device where each element is
1666 /// sampled from the given distribution.
1667 ///
1668 /// See also [`random_like`](Tensor::random_like).
1669 ///
1670 /// # Arguments
1671 ///
1672 /// * `shape` - The shape of the tensor.
1673 /// * `distribution` - The distribution to sample from.
1674 /// * `device` - The device to create the tensor on.
1675 ///
1676 /// # Returns
1677 ///
1678 /// A new tensor with the given shape and elements sampled from the given distribution.
1679 ///
1680 /// # Example
1681 ///
1682 /// ```rust
1683 /// use burn_tensor::backend::Backend;
1684 /// use burn_tensor::{Tensor, Shape, Distribution};
1685 ///
1686 /// fn example<B: Backend>() {
1687 /// let device = B::Device::default();
1688 /// let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
1689 /// let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
1690 /// println!("{tensor}");
1691 /// // [
1692 /// // [0.08347523, 0.70498955, 0.60332155],
1693 /// // [0.08173251, 0.18028641, 0.97942924]
1694 /// // ]
1695 /// }
1696 /// ```
1697 pub fn random<S: Into<Shape>>(
1698 shape: S,
1699 distribution: Distribution,
1700 device: &B::Device,
1701 ) -> Self {
1702 Self::new(K::random(shape.into(), distribution, device))
1703 }
1704
1705 /// Sort the elements by value in ascending order along a given dimension.
1706 ///
1707 /// This sort is unstable (i.e., may reorder equal elements).
1708 ///
1709 /// # Arguments
1710 ///
1711 /// * `dim` - The dimension to sort along.
1712 ///
1713 /// # Returns
1714 ///
1715 /// A new tensor with the elements sorted in ascending order along the given dimension.
1716 ///
1717 /// # Example
1718 ///
1719 /// ```rust
1720 /// use burn_tensor::backend::Backend;
1721 /// use burn_tensor::{Tensor, Shape};
1722 ///
1723 /// fn example<B: Backend>() {
1724 /// let device = B::Device::default();
1725 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///     let sorted = tensor.clone().sort(0);
    ///     println!("{sorted}");
    ///     // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
    ///     let sorted = tensor.sort(1);
    ///     println!("{sorted}");
    ///     // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
1732 /// }
1733 /// ```
1734 pub fn sort(self, dim: usize) -> Self {
1735 check!(TensorCheck::sort_dim::<D>("Sort", dim));
1736 Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
1737 }
1738
1739 /// Sort the elements by value in descending order along a given dimension.
1740 ///
1741 /// This sort is unstable (i.e., may reorder equal elements).
1742 ///
1743 /// # Arguments
1744 ///
1745 /// * `dim` - The dimension to sort along.
1746 ///
1747 /// # Returns
1748 ///
1749 /// A new tensor with the elements sorted in descending order along the given dimension.
1750 ///
1751 /// # Example
1752 ///
1753 /// ```rust
1754 /// use burn_tensor::backend::Backend;
1755 /// use burn_tensor::{Tensor, Shape};
1756 ///
1757 /// fn example<B: Backend>() {
1758 /// let device = B::Device::default();
1759 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///     let sorted = tensor.clone().sort_descending(0);
    ///     println!("{sorted}");
    ///     // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
    ///     let sorted = tensor.sort_descending(1);
    ///     println!("{sorted}");
    ///     // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
1766 /// }
1767 /// ```
1768 pub fn sort_descending(self, dim: usize) -> Self {
1769 check!(TensorCheck::sort_dim::<D>("Sort", dim));
1770 Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
1771 }
1772
1773 /// Sort the elements by value in ascending order along a given dimension.
1774 /// Also returns the indices.
1775 ///
1776 /// This sort is unstable (i.e., may reorder equal elements).
1777 ///
1778 /// # Arguments
1779 ///
1780 /// * `dim` - The dimension to sort along.
1781 ///
1782 /// # Returns
1783 ///
1784 /// A tuple containing the sorted tensor and the indices tensor.
1785 ///
1786 /// # Example
1787 ///
1788 /// ```rust
1789 /// use burn_tensor::backend::Backend;
1790 /// use burn_tensor::{Tensor, Shape};
1791 ///
1792 /// fn example<B: Backend>() {
1793 /// let device = B::Device::default();
1794 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1795 /// let (tensor, indices) = tensor.sort_with_indices(0);
1796 /// println!("{tensor}");
1797 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
/// println!("{indices}");
1799 /// // [[1, 0, 0], [0, 1, 1]]
1800 /// }
1801 /// ```
1802 pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
1803 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1804 let (values, indices) =
1805 K::sort_with_indices(self.primitive, dim, /*descending*/ false);
1806 (Tensor::new(values), Tensor::new(indices))
1807 }
1808
1809 /// Sort the elements by value in descending order along a given dimension.
1810 /// Also returns the indices.
1811 ///
1812 /// This sort is unstable (i.e., may reorder equal elements).
1813 ///
1814 /// # Arguments
1815 ///
1816 /// * `dim` - The dimension to sort along.
1817 ///
1818 /// # Example
1819 ///
1820 /// ```rust
1821 /// use burn_tensor::backend::Backend;
1822 /// use burn_tensor::{Tensor, Shape};
1823 ///
1824 /// fn example<B: Backend>() {
1825 /// let device = B::Device::default();
1826 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1827 /// let (tensor, indices) = tensor.sort_descending_with_indices(0);
1828 /// println!("{tensor}");
1829 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
/// println!("{indices}");
1831 /// // [[0, 1, 1], [1, 0, 0]]
1832 /// }
1833 /// ```
1834 pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
1835 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1836 let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
1837 (Tensor::new(values), Tensor::new(indices))
1838 }
1839
1840 /// Returns the indices that sort the elements by value in ascending order along a given dimension.
1841 ///
1842 /// This sort is unstable (i.e., may reorder equal elements).
1843 ///
1844 /// # Arguments
1845 ///
1846 /// * `dim` - The dimension to sort along.
1847 ///
1848 /// # Example
1849 ///
1850 /// ```rust
1851 /// use burn_tensor::backend::Backend;
1852 /// use burn_tensor::{Tensor, Shape};
1853 ///
1854 /// fn example<B: Backend>() {
1855 /// let device = B::Device::default();
1856 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1857 /// let tensor = tensor.argsort(0);
1858 /// println!("{tensor}");
1859 /// // [[1, 0, 0], [0, 1, 1]]
1860 /// }
1861 /// ```
1862 pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
1863 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1864 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
1865 }
1866
1867 /// Returns the indices that sort the elements by value in descending order along a given dimension.
1868 ///
1869 /// This sort is unstable (i.e., may reorder equal elements).
1870 ///
1871 /// # Arguments
1872 ///
1873 /// * `dim` - The dimension to sort along.
1874 ///
1875 /// # Example
1876 ///
1877 /// ```rust
1878 /// use burn_tensor::backend::Backend;
1879 /// use burn_tensor::{Tensor, Shape};
1880 ///
1881 /// fn example<B: Backend>() {
1882 /// let device = B::Device::default();
1883 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let indices = tensor.clone().argsort_descending(0);
/// println!("{indices}");
/// // [[0, 1, 1], [1, 0, 0]]
/// let indices = tensor.argsort_descending(1);
/// println!("{indices}");
/// // [[0, 2, 1], [2, 0, 1]]
1890 /// }
1891 /// ```
1892 pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
1893 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1894 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
1895 }
1896
1897 /// Returns the `k` largest elements of the given input tensor along a given dimension.
1898 ///
1899 /// # Arguments
1900 ///
/// * `k` - The number of elements to return.
/// * `dim` - The dimension along which to take the `k` largest elements.
1902 ///
1903 /// # Returns
1904 ///
1905 /// A new tensor with the `k` largest elements along the given dimension.
1906 ///
1907 /// # Example
1908 ///
1909 /// ```rust
1910 /// use burn_tensor::backend::Backend;
1911 /// use burn_tensor::{Tensor, Shape};
1912 ///
1913 /// fn example<B: Backend>() {
1914 /// let device = B::Device::default();
1915 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let top = tensor.clone().topk(2, 0);
/// println!("{top}");
/// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
/// let top = tensor.topk(1, 1);
/// println!("{top}");
/// // [[12.0], [6.0]]
1922 /// }
1923 /// ```
1924 pub fn topk(self, k: usize, dim: usize) -> Self {
1925 let k_indices = Tensor::arange(0..k as i64, &self.device());
1926 self.sort_descending(dim).select(dim, k_indices)
1927 }
1928
1929 /// Returns the `k` largest elements of the given input tensor along a given dimension.
1930 /// Also returns the indices.
1931 ///
1932 /// # Arguments
1933 ///
1934 /// * `k` - The number of elements to return.
1935 /// * `dim` - The dimension to sort along.
1936 ///
1937 /// # Example
1938 ///
1939 /// ```rust
1940 /// use burn_tensor::backend::Backend;
1941 /// use burn_tensor::{Tensor, Shape};
1942 ///
1943 /// fn example<B: Backend>() {
1944 /// let device = B::Device::default();
1945 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let (values, indices) = tensor.clone().topk_with_indices(2, 0);
/// println!("{values}");
/// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
/// println!("{indices}");
/// // [[0, 1, 1], [1, 0, 0]]
/// let (values, indices) = tensor.topk_with_indices(1, 1);
/// println!("{values}");
/// // [[12.0], [6.0]]
/// println!("{indices}");
/// // [[0], [2]]
1956 /// }
1957 /// ```
1958 pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
1959 let k_indices = Tensor::arange(0..k as i64, &self.device());
1960 let (values, indices) = self.sort_descending_with_indices(dim);
1961 (
1962 values.select(dim, k_indices.clone()),
1963 indices.select(dim, k_indices),
1964 )
1965 }
1966
/// Creates a one-hot encoded tensor from the index values in `self`.
///
/// Equivalent to [`one_hot_fill`](Self::one_hot_fill) with `on_value = 1.0`, `off_value = 0.0`, and `axis = -1`.
///
/// # Arguments
///
/// * `num_classes` - The number of classes, i.e. the size of the appended one-hot dimension.
///
1969 /// # Example
1970 ///
1971 /// ```rust
1972 /// use burn_tensor::backend::Backend;
1973 /// use burn_tensor::Tensor;
1974 ///
/// fn example<B: Backend>() {
1976 /// let device = Default::default();
1977 /// let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
1978 /// let one_hot: Tensor<B, 2> = indices.one_hot(4);
1979 /// println!("{}", one_hot.to_data());
1980 /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
1981 /// }
1982 /// ```
1983 pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
1984 check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
1985 self.one_hot_fill(num_classes, 1.0, 0.0, -1)
1986 }
1987
/// Creates a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis`, supporting tensors of any rank.
1989 ///
1990 /// # Arguments
1991 ///
1992 /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
1993 /// * `on_value`: The value to assign for active positions (corresponding to indices).
1994 /// * `off_value`: The value to assign for inactive positions.
1995 /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
1996 ///
1997 /// # Returns
1998 ///
1999 /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
2000 ///
2001 /// # Example
2002 /// ```rust
2003 /// use burn_tensor::backend::Backend;
2004 /// use burn_tensor::{Tensor, Float};
2005 /// fn example<B: Backend<FloatElem: From<f32>>>() {
2006 /// let device = B::Device::default();
2007 /// let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
2008 /// // One-hot encoding
/// let tensor: Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
2010 /// println!("{tensor}");
2011 /// // [[[5.0, 0.0, 0.0],
2012 /// // [0.0, 0.0, 5.0]],
2013 /// // [[0.0, 5.0, 0.0],
2014 /// // [0.0, 0.0, 5.0]]]
2015 /// }
2016 /// ```
2017 pub fn one_hot_fill<const D2: usize>(
2018 self,
2019 num_classes: usize,
2020 on_value: f32,
2021 off_value: f32,
2022 axis: i64,
2023 ) -> Tensor<B, D2, K> {
2024 check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
2025 // Initialize shape from the current tensor dimensions and prepare for modification
2026 let mut shape = self.shape();
2027 let device = self.device();
2028 let rank = self.dims().len();
2029
2030 // Adjust negative axis to a positive index
2031 let axis = if axis < 0 {
2032 axis + rank as i64 + 1
2033 } else {
2034 axis
2035 };
2036
2037 // Ensure axis is within valid range
2038 if axis < 0 || axis > rank as i64 {
2039 panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
2040 }
2041 // Convert the input tensor to integer indices
2042 let indices: Tensor<B, D, Int> =
2043 Tensor::from_data(self.to_data().convert::<i64>(), &device);
2044 // Insert the new dimension for the one-hot representation
2045 shape.insert(axis as usize, num_classes);
// Wrap negative indices into the valid range: a negative index `i` becomes `num_classes + i`,
// while non-negative indices are kept unchanged.
let adjusted_indices = indices
    .clone()
    .mask_fill(self.clone().lower_elem(0), num_classes as i64) // negative entries -> `num_classes`
    .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // positive entries -> 0, so the sum wraps negatives
2051 // Unsqueeze the indices tensor along the specified axis
2052 let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
2053
2054 // Initialize the output tensor with the off_value
2055 let output = Tensor::full(shape.clone(), off_value, &device);
2056
2057 // Prepare scatter tensor for on_value and off_value adjustments
let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
    - Tensor::full(indices_unsqueezed.shape(), off_value, &device);
2060
2061 // Scatter on_value at the appropriate indices to create the one-hot representation
2062 output.scatter(
2063 axis as usize,
2064 indices_unsqueezed,
2065 scatter_on_values,
2066 IndexingUpdateOp::Add,
2067 )
2068 }
2069
/// Applies the matrix multiplication operation.
///
/// The last dimension of `self` must match the second-to-last dimension of `other`.
///
2072 /// ```math
2073 /// C = AB
2074 /// ```
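///
/// # Example
///
/// A minimal 2×2 sketch; the expected values below follow from the row-by-column rule:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let a = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
/// let b = Tensor::<B, 2>::from_data([[5.0, 6.0], [7.0, 8.0]], &device);
/// let c = a.matmul(b);
/// println!("{c}");
/// // [[19.0, 22.0], [43.0, 50.0]]
/// }
/// ```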
2075 pub fn matmul(self, other: Self) -> Self {
2076 check!(TensorCheck::matmul(&self, &other));
2077 Tensor::new(K::matmul(self.primitive, other.primitive))
2078 }
2079}
2080
2081impl<B, K> Tensor<B, 1, K>
2082where
2083 B: Backend,
2084 K: Numeric<B>,
2085 K::Elem: Element,
2086{
2087 /// Calculates the dot product with another tensor.
2088 ///
2089 /// `y = x2.dot(x1)`
2090 ///
2091 /// # Arguments
2092 ///
2093 /// * `other` - The tensor to compute dot product with.
2094 ///
2095 /// # Notes
2096 ///
2097 /// Both tensors must have the same number of elements.
2098 ///
2099 /// # Example
2100 ///
2101 /// ```rust
2102 /// use burn_tensor::backend::Backend;
2103 /// use burn_tensor::{Tensor, Shape};
2104 ///
2105 /// fn example<B: Backend>() {
2106 /// let device = B::Device::default();
2107 /// let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
2108 /// let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
2109 /// let tensor = tensor1.dot(tensor2);
2110 /// println!("{tensor}");
/// // [4.0]
2112 /// }
2113 /// ```
2114 pub fn dot(self, other: Self) -> Self {
2115 self.mul(other).sum()
2116 }
2117}
2118
2119impl<B, K> Tensor<B, 2, K>
2120where
2121 B: Backend,
2122 K: Numeric<B>,
2123 K::Elem: Element,
2124{
2125 /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2126 ///
2127 /// # Arguments
2128 ///
/// * `size` - The size of the square matrix.
/// * `device` - The device on which to create the tensor.
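///
/// # Example
///
/// A small sketch of the expected output for a 3×3 identity matrix:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let tensor = Tensor::<B, 2>::eye(3, &device);
/// println!("{tensor}");
/// // [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
/// }
/// ```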
2130 pub fn eye(size: usize, device: &B::Device) -> Self {
2131 let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2132 let ones = Self::ones([1, size], device);
2133 let zeros = Self::zeros([size, size], device);
2134
2135 zeros.scatter(0, indices, ones, IndexingUpdateOp::Add)
2136 }
2137}
2138
2139// Tensor + tensor
2140impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
2141where
2142 K::Elem: Element,
2143{
2144 type Output = Self;
2145
2146 fn add(self, rhs: Self) -> Self::Output {
2147 Self::add(self, rhs)
2148 }
2149}
2150
2151// Tensor + scalar
2152impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
2153 for Tensor<B, D, K>
2154where
2155 K::Elem: Element,
2156{
2157 type Output = Self;
2158
2159 fn add(self, other: E) -> Self::Output {
2160 Tensor::add_scalar(self, other)
2161 }
2162}
2163
2164// Scalar + tensor
2165macro_rules! impl_tensor_scalar_add {
2166 ($($t:ty),*) => {
2167 $(
2168 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
2169 where
2170 K::Elem: Element,
2171 {
2172 type Output = Tensor<B, D, K>;
2173
2174 fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
2175 Tensor::add_scalar(tensor, self)
2176 }
2177 }
2178 )*
2179 }
2180}
2181impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);
2182
2183// Tensor - tensor
2184impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
2185where
2186 K::Elem: Element,
2187{
2188 type Output = Self;
2189
2190 fn sub(self, rhs: Self) -> Self::Output {
2191 Tensor::sub(self, rhs)
2192 }
2193}
2194
2195// Tensor - scalar
2196impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
2197 for Tensor<B, D, K>
2198where
2199 K::Elem: Element,
2200{
2201 type Output = Self;
2202
2203 fn sub(self, other: E) -> Self::Output {
2204 Tensor::sub_scalar(self, other)
2205 }
2206}
2207
2208// Scalar - tensor
2209macro_rules! impl_tensor_scalar_sub {
2210 ($($t:ty),*) => {
2211 $(
2212 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
2213 where
2214 K::Elem: Element,
2215 {
2216 type Output = Tensor<B, D, K>;
2217
2218 fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
2219 Tensor::add_scalar(Tensor::neg(tensor), self)
2220 }
2221 }
2222 )*
2223 }
2224}
2225impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);
2226
2227// Tensor / tensor
2228impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
2229where
2230 K::Elem: Element,
2231{
2232 type Output = Self;
2233
2234 fn div(self, rhs: Self) -> Self::Output {
2235 Tensor::div(self, rhs)
2236 }
2237}
2238
2239// Tensor / scalar
2240impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
2241 for Tensor<B, D, K>
2242where
2243 K::Elem: Element,
2244{
2245 type Output = Self;
2246
2247 fn div(self, other: E) -> Self::Output {
2248 Tensor::div_scalar(self, other)
2249 }
2250}
2251
2252// Scalar / tensor (float only)
2253macro_rules! impl_tensor_scalar_div {
2254 ($($t:ty),*) => {
2255 $(
2256 impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t
2257 {
2258 type Output = Tensor<B, D>;
2259
2260 fn div(self, tensor: Tensor<B, D>) -> Self::Output {
2261 tensor.recip().mul_scalar(self)
2262 }
2263 }
2264 )*
2265 }
2266}
2267
2268impl_tensor_scalar_div!(f32, f64);
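
// Note: combined with the scalar-on-the-right impls, the macros above let a scalar appear on
// either side of `+`, `-`, and `/`: e.g. `2.0 + tensor`, `1.0 - tensor` (computed as
// `(-tensor) + 1.0`), and `2.0 / tensor` (float tensors only, computed as `tensor.recip() * 2.0`).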
2269
2270// Tensor % tensor.
2271impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
2272where
2273 K::Elem: Element,
2274{
2275 type Output = Self;
2276
2277 fn rem(self, rhs: Self) -> Self::Output {
2278 Tensor::remainder(self, rhs)
2279 }
2280}
2281
2282// Tensor % scalar.
2283impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
2284 for Tensor<B, D, K>
2285where
2286 K::Elem: Element,
2287{
2288 type Output = Self;
2289
2290 fn rem(self, other: E) -> Self::Output {
2291 Tensor::remainder_scalar(self, other)
2292 }
2293}
2294
2295// Tensor * tensor.
2296impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
2297where
2298 K::Elem: Element,
2299{
2300 type Output = Self;
2301
2302 fn mul(self, rhs: Self) -> Self::Output {
2303 Tensor::mul(self, rhs)
2304 }
2305}
2306
2307// Tensor * scalar.
2308impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
2309 for Tensor<B, D, K>
2310where
2311 K::Elem: Element,
2312{
2313 type Output = Self;
2314
2315 fn mul(self, other: E) -> Self::Output {
2316 Tensor::mul_scalar(self, other)
2317 }
2318}
2319
2320macro_rules! impl_tensor_scalar_mul {
2321 ($($t:ty),*) => {
2322 $(
2323 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
2324 where
2325 K::Elem: Element,
2326 {
2327 type Output = Tensor<B, D, K>;
2328
2329 fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
2330 Tensor::mul_scalar(other, self)
2331 }
2332 }
2333 )*
2334 }
2335}
2336
2337impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);
2338
2339impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
2340where
2341 B: Backend,
2342 K: Numeric<B>,
2343 K::Elem: Element,
2344{
2345 type Output = Self;
2346
2347 fn neg(self) -> Self::Output {
2348 Tensor::neg(self)
2349 }
2350}