// burn_tensor/tensor/api/numeric.rs
1use burn_backend::Scalar;
2pub use burn_backend::tensor::Numeric;
3
4use crate::alloc::borrow::ToOwned;
5use alloc::vec::Vec;
6
7use crate::IndexingUpdateOp;
8use crate::{
9 AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend,
10 check, check::TensorCheck,
11};
12
13impl<B, const D: usize, K> Tensor<B, D, K>
14where
15 B: Backend,
16 K: Numeric<B>,
17 K::Elem: Element,
18{
19 /// Applies element wise addition operation.
20 ///
21 /// `y = x2 + x1`
22 ///
23 /// # Arguments
24 ///
25 /// * `other` - The tensor to add.
26 ///
27 /// # Example
28 ///
29 /// ```rust
30 /// use burn_tensor::backend::Backend;
31 /// use burn_tensor::{Tensor, Shape};
32 ///
33 /// fn example<B: Backend>() {
34 /// let device = B::Device::default();
35 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
36 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
37 /// let tensor = tensor1 + tensor2;
38 /// println!("{tensor}");
39 /// // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
40 /// }
41 /// ```
42 #[allow(clippy::should_implement_trait)]
43 pub fn add(self, other: Self) -> Self {
44 check!(TensorCheck::binary_ops_ew("Add", &self, &other));
45 Self::new(K::add(self.primitive, other.primitive))
46 }
47
48 /// Applies element wise addition operation with a scalar.
49 ///
50 /// `y = x + s`
51 ///
52 /// # Arguments
53 ///
54 /// * `other` - The scalar to add, element wise.
55 ///
56 /// # Example
57 ///
58 /// ```rust
59 /// use burn_tensor::backend::Backend;
60 /// use burn_tensor::{Tensor, Shape};
61 ///
62 /// fn example<B: Backend>() {
63 /// let device = B::Device::default();
64 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
65 /// let scalar = 2.0;
66 /// let tensor = tensor + scalar;
67 /// println!("{tensor}");
68 /// // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
69 /// }
70 /// ```
71 pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
72 let other = Scalar::new(other, &self.dtype());
73 Self::new(K::add_scalar(self.primitive, other))
74 }
75
76 /// Applies element wise subtraction operation.
77 ///
78 /// `y = x2 - x1`
79 ///
80 /// # Arguments
81 ///
82 /// * `other` - The tensor to subtract.
83 ///
84 /// # Example
85 ///
86 /// ```rust
87 /// use burn_tensor::backend::Backend;
88 /// use burn_tensor::{Tensor, Shape};
89 ///
90 /// fn example<B: Backend>() {
91 /// let device = B::Device::default();
92 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
93 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
94 /// let tensor = tensor1 - tensor2;
95 /// println!("{tensor}");
96 /// // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
97 /// }
98 /// ```
99 #[allow(clippy::should_implement_trait)]
100 pub fn sub(self, other: Self) -> Self {
101 check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
102 Self::new(K::sub(self.primitive, other.primitive))
103 }
104
105 /// Applies element wise subtraction operation with a scalar.
106 ///
107 /// `y = x - s`
108 ///
109 /// # Arguments
110 ///
111 /// * `other` - The scalar to subtract, element wise.
112 ///
113 /// # Example
114 ///
115 /// ```rust
116 /// use burn_tensor::backend::Backend;
117 /// use burn_tensor::{Tensor, Shape};
118 ///
119 /// fn example<B: Backend>() {
120 /// let device = B::Device::default();
121 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
122 /// let scalar = 2.0;
123 /// let tensor = tensor - scalar;
124 /// println!("{tensor}");
125 /// // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
126 /// }
127 /// ```
128 pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
129 let other = Scalar::new(other, &self.dtype());
130 Self::new(K::sub_scalar(self.primitive, other))
131 }
132
133 /// Applies element wise division operation.
134 ///
135 /// `y = x2 / x1`
136 ///
137 /// # Arguments
138 ///
139 /// * `other` - The tensor to divide.
140 ///
141 /// # Example
142 ///
143 /// ```rust
144 /// use burn_tensor::backend::Backend;
145 /// use burn_tensor::{Tensor, Shape};
146 ///
147 /// fn example<B: Backend>() {
148 /// let device = B::Device::default();
149 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
150 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
151 /// let tensor = tensor1 / tensor2;
152 /// println!("{tensor}");
153 /// // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
154 /// }
155 /// ```
156 #[allow(clippy::should_implement_trait)]
157 pub fn div(self, other: Self) -> Self {
158 check!(TensorCheck::binary_ops_ew("Div", &self, &other));
159 Self::new(K::div(self.primitive, other.primitive))
160 }
161
162 /// Applies element wise division operation with a scalar.
163 ///
164 /// `y = x / s`
165 ///
166 /// # Arguments
167 ///
168 /// * `other` - The scalar to divide, element wise.
169 ///
170 /// # Example
171 ///
172 /// ```rust
173 /// use burn_tensor::backend::Backend;
174 /// use burn_tensor::{Tensor, Shape};
175 ///
176 /// fn example<B: Backend>() {
177 /// let device = B::Device::default();
178 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
179 /// let scalar = 2.0;
180 /// let tensor = tensor / scalar;
181 /// println!("{tensor}");
182 /// // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
183 /// }
184 /// ```
185 pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
186 let other = Scalar::new(other, &self.dtype());
187 Self::new(K::div_scalar(self.primitive, other))
188 }
189
190 /// Applies element wise the remainder operation with a scalar.
191 ///
192 /// `y = x2 % x1`
193 pub fn remainder(self, other: Self) -> Self {
194 Self::new(K::remainder(self.primitive, other.primitive))
195 }
196
197 /// Applies element wise the remainder operation with a scalar.
198 ///
199 /// `y = x % s`
200 ///
201 /// # Arguments
202 ///
203 /// * `other` - The scalar to divide, element wise.
204 ///
205 /// # Example
206 ///
207 /// ```rust
208 /// use burn_tensor::backend::Backend;
209 /// use burn_tensor::{Tensor, Shape};
210 ///
211 /// fn example<B: Backend>() {
212 /// let device = B::Device::default();
213 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
214 /// let scalar = 2.0;
215 /// let tensor = tensor1 % scalar;
216 /// println!("{tensor}");
217 /// // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
218 /// }
219 /// ```
220 pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
221 let other = Scalar::new(other, &self.dtype());
222 Self::new(K::remainder_scalar(self.primitive, other))
223 }
224
225 /// Applies element wise multiplication operation.
226 ///
227 /// `y = x2 * x1`
228 ///
229 /// # Arguments
230 ///
231 /// * `other` - The tensor to multiply.
232 ///
233 /// # Example
234 ///
235 /// ```rust
236 /// use burn_tensor::backend::Backend;
237 /// use burn_tensor::{Tensor, Shape};
238 ///
239 /// fn example<B: Backend>() {
240 /// let device = B::Device::default();
241 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
242 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
243 /// let tensor = tensor1 * tensor2;
244 /// println!("{tensor}");
245 /// // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
246 /// }
247 /// ```
248 #[allow(clippy::should_implement_trait)]
249 pub fn mul(self, other: Self) -> Self {
250 check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
251 Self::new(K::mul(self.primitive, other.primitive))
252 }
253
254 /// Applies element wise multiplication operation with a scalar.
255 ///
256 /// `y = x * s`
257 ///
258 /// # Arguments
259 ///
260 /// * `other` - The scalar to multiply, element wise.
261 ///
262 /// # Example
263 ///
264 /// ```rust
265 /// use burn_tensor::backend::Backend;
266 /// use burn_tensor::{Tensor, Shape};
267 ///
268 /// fn example<B: Backend>() {
269 /// let device = B::Device::default();
270 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
271 /// let scalar = 2.0;
272 /// let tensor = tensor * scalar;
273 /// println!("{tensor}");
274 /// // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
275 /// }
276 /// ```
277 pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
278 let other = Scalar::new(other, &self.dtype());
279 Self::new(K::mul_scalar(self.primitive, other))
280 }
281
282 /// Switch sign of each element in the tensor.
283 ///
284 /// `y = -x`
285 ///
286 /// # Example
287 ///
288 /// ```rust
289 /// use burn_tensor::backend::Backend;
290 /// use burn_tensor::{Tensor, Shape};
291 ///
292 /// fn example<B: Backend>() {
293 /// let device = B::Device::default();
294 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
295 /// let tensor = -tensor;
296 /// println!("{tensor}");
297 /// // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
298 /// }
299 /// ```
300 #[allow(clippy::should_implement_trait)]
301 pub fn neg(self) -> Self {
302 Self::new(K::neg(self.primitive))
303 }
304
305 /// Returns the signs of the elements of the input tensor.
306 ///
307 /// # Example
308 ///
309 /// ```rust
310 /// use burn_tensor::backend::Backend;
311 /// use burn_tensor::{Tensor, Shape};
312 ///
313 /// fn example<B: Backend>() {
314 /// let device = B::Device::default();
315 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
316 /// let tensor = tensor.sign();
317 /// println!("{tensor}");
318 /// // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
319 /// }
320 /// ```
321 pub fn sign(self) -> Self {
322 Self::new(K::sign(self.primitive))
323 }
324
325 /// Aggregate all elements in the tensor with the mean operation.
326 ///
327 /// # Example
328 ///
329 /// ```rust
330 /// use burn_tensor::backend::Backend;
331 /// use burn_tensor::{Tensor, Shape};
332 ///
333 /// fn example<B: Backend>() {
334 /// let device = B::Device::default();
335 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
336 /// let tensor = tensor.mean();
337 /// println!("{tensor}");
338 /// // [3.6666667]
339 /// }
340 /// ```
341 pub fn mean(self) -> Tensor<B, 1, K> {
342 Tensor::new(K::mean(self.primitive))
343 }
344
345 /// Aggregate all elements in the tensor with the sum operation.
346 ///
347 /// # Example
348 ///
349 /// ```rust
350 /// use burn_tensor::backend::Backend;
351 /// use burn_tensor::{Tensor, Shape};
352 ///
353 /// fn example<B: Backend>() {
354 /// let device = B::Device::default();
355 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
356 /// let tensor = tensor.sum();
357 /// println!("{tensor}");
358 /// // [22.0]
359 /// }
360 /// ```
361 pub fn sum(self) -> Tensor<B, 1, K> {
362 Tensor::new(K::sum(self.primitive))
363 }
364
365 /// Aggregate all elements along the given *dimension* or *axis*
366 /// in the tensor with the mean operation.
367 ///
368 /// # Arguments
369 ///
370 /// * `dim` - The dimension or axis along which to aggregate the elements;
371 /// supports negative indexing.
372 ///
373 /// # Example
374 ///
375 /// ```rust
376 /// use burn_tensor::backend::Backend;
377 /// use burn_tensor::{Tensor, Shape};
378 ///
379 /// fn example<B: Backend>() {
380 /// let device = B::Device::default();
381 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
382 /// let tensor = tensor.clone().mean_dim(0);
383 /// println!("{tensor}");
384 /// // [[3.0, 3.5, 4.5]]
385 /// let tensor = tensor.clone().mean_dim(1);
386 /// println!("{tensor}");
387 /// // [[0.6666667], [6.6666665]]
388 /// }
389 /// ```
390 pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
391 let dim = dim.expect_dim_index(D);
392 check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
393 Self::new(K::mean_dim(self.primitive, dim))
394 }
395
396 /// Aggregate all elements along the given *axes*
397 /// in the tensor with the mean operation.
398 ///
399 /// # Arguments
400 ///
401 /// * `dims` - the dimensions to aggregate; supports negative indexing.
402 ///
403 /// # Returns
404 ///
405 /// The returned tensor will have the same rank,
406 /// but the aggregated dimensions will have size 1.
407 ///
408 /// # Example
409 ///
410 /// ```rust
411 /// use burn_tensor::backend::Backend;
412 /// use burn_tensor::{Tensor, Shape};
413 ///
414 /// fn example<B: Backend>() {
415 /// let device = B::Device::default();
416 /// let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
417 /// let tensor = tensor.clone().mean_dims(&[0, 1]);
418 /// println!("{tensor}");
419 /// // [[2.0]]
420 /// }
421 /// ```
422 pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
423 dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
424 }
425
426 /// Aggregate all elements along the given *dimension* or *axis*
427 /// in the tensor with the sum operation.
428 ///
429 /// # Arguments
430 ///
431 /// * `dim` - The dimension or axis along which to aggregate the elements;
432 /// supports negative indexing.
433 ///
434 /// # Example
435 ///
436 /// ```rust
437 /// use burn_tensor::backend::Backend;
438 /// use burn_tensor::{Tensor, Shape};
439 ///
440 /// fn example<B: Backend>() {
441 /// let device = B::Device::default();
442 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
443 /// let tensor = tensor.clone().sum_dim(0);
444 /// println!("{tensor}");
445 /// // [[6.0, 7.0, 9.0]]
446 /// let tensor = tensor.clone().sum_dim(1);
447 /// println!("{tensor}");
448 /// // [[2.0], [20.0]]
449 /// }
450 /// ```
451 pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
452 let dim = dim.expect_dim_index(D);
453 check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
454 Self::new(K::sum_dim(self.primitive, dim))
455 }
456
457 /// Aggregate all elements along the given *axes*
458 /// in the tensor with the sum operation.
459 ///
460 /// # Arguments
461 ///
462 /// * `dims` - the dimensions to aggregate; supports negative indexing.
463 ///
464 /// # Returns
465 ///
466 /// The returned tensor will have the same rank,
467 /// but the aggregated dimensions will have size 1.
468 ///
469 /// # Example
470 ///
471 /// ```rust
472 /// use burn_tensor::backend::Backend;
473 /// use burn_tensor::{Tensor, Shape};
474 ///
475 /// fn example<B: Backend>() {
476 /// let device = B::Device::default();
477 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
478 /// let tensor = tensor.clone().sum_dims(&[0, 1]);
479 /// println!("{tensor}");
480 /// // [[27]]
481 /// }
482 /// ```
483 pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
484 dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
485 }
486
487 /// Aggregate and squeeze along the given dimensions.
488 ///
489 /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)``
490 ///
491 /// # Arguments
492 ///
493 /// * `dims` - the dimensions to aggregate; supports negative indexing.
494 ///
495 /// # Returns
496 ///
497 /// The returned tensor will have the same rank,
498 /// but the aggregated dimensions will have size 1.
499 ///
500 /// # Example
501 ///
502 /// ```rust
503 /// use burn_tensor::backend::Backend;
504 /// use burn_tensor::{Tensor, Shape};
505 ///
506 /// fn example<B: Backend>() {
507 /// let device = B::Device::default();
508 /// let tensor = Tensor::<B, 3>::from_data([
509 /// [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
510 /// [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
511 /// ], &device);
512 /// let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);
513 /// println!("{tensor}");
514 /// // [20.0, 16.0, 21.0]
515 /// }
516 /// ```
517 pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
518 // TODO: remove idims when squeeze_dims uses AsIndex.
519 let idims = dims
520 .iter()
521 .map(|&dim| (dim.expect_dim_index(D)) as isize)
522 .collect::<Vec<_>>();
523 self.sum_dims(dims).squeeze_dims::<D2>(&idims)
524 }
525
526 /// Aggregate all elements in the tensor with the product operation.
527 ///
528 /// # Example
529 ///
530 /// ```rust
531 /// use burn_tensor::backend::Backend;
532 /// use burn_tensor::{Tensor, Shape};
533 ///
534 /// fn example<B: Backend>() {
535 /// let device = B::Device::default();
536 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
537 /// let tensor = tensor.prod();
538 /// println!("{tensor}");
539 /// // [-1620.0]
540 /// }
541 /// ```
542 pub fn prod(self) -> Tensor<B, 1, K> {
543 Tensor::new(K::prod(self.primitive))
544 }
545
546 /// Aggregate all elements along the given *dimension* or *axis*
547 /// in the tensor with the product operation.
548 ///
549 /// # Arguments
550 ///
551 /// * `dim` - The dimension or axis along which to aggregate the elements,
552 /// supports negative indexing.
553 ///
554 /// # Returns
555 ///
556 /// The returned tensor will have the same rank,
557 /// but the aggregated dimension will have size 1.
558 ///
559 /// # Example
560 ///
561 /// ```rust
562 /// use burn_tensor::backend::Backend;
563 /// use burn_tensor::{Tensor, Shape};
564 ///
565 /// fn example<B: Backend>() {
566 /// let device = B::Device::default();
567 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
568 /// let tensor = tensor.clone().prod_dim(0);
569 /// println!("{tensor}");
570 /// // [[5.0, -18.0, 18.0]]
571 /// let tensor = tensor.clone().prod_dim(1);
572 /// println!("{tensor}");
573 /// // [[-6.0], [270.0]]
574 /// }
575 /// ```
576 pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
577 let dim = dim.expect_dim_index(D);
578 check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
579 Self::new(K::prod_dim(self.primitive, dim))
580 }
581
582 /// Aggregate all elements along the given *axes*
583 /// in the tensor with the prod operation.
584 ///
585 /// # Arguments
586 ///
587 /// * `dims` - the dimensions to aggregate, supports negative indexing.
588 ///
589 /// # Returns
590 ///
591 /// The returned tensor will have the same rank,
592 /// but the aggregated dimensions will have size 1.
593 ///
594 /// # Example
595 ///
596 /// ```rust
597 /// use burn_tensor::backend::Backend;
598 /// use burn_tensor::{Tensor, Shape};
599 ///
600 /// fn example<B: Backend>() {
601 /// let device = B::Device::default();
602 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
603 /// let tensor = tensor.clone().sum_dims(&[0, 1]);
604 /// println!("{tensor}");
605 /// // [[-1620.0]]
606 /// }
607 /// ```
608 pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
609 dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
610 }
611
612 /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
613 ///
614 /// # Arguments
615 ///
616 /// * `dim` - The dimension or axis along which to compute the cumulative sum.
617 ///
618 /// # Example
619 ///
620 /// ```rust
621 /// use burn_tensor::backend::Backend;
622 /// use burn_tensor::{Tensor, Shape};
623 ///
624 /// fn example<B: Backend>() {
625 /// let device = B::Device::default();
626 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
627 /// let result = tensor.clone().cumsum(0);
628 /// println!("{result}");
629 /// // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
630 /// let result = tensor.cumsum(1);
631 /// println!("{result}");
632 /// // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
633 /// }
634 /// ```
635 pub fn cumsum(self, dim: usize) -> Self {
636 check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
637 Self::new(K::cumsum(self.primitive, dim))
638 }
639
640 /// Computes the cumulative product of elements along the given *dimension* or *axis*.
641 ///
642 /// # Arguments
643 ///
644 /// * `dim` - The dimension or axis along which to compute the cumulative product.
645 ///
646 /// # Example
647 ///
648 /// ```rust
649 /// use burn_tensor::backend::Backend;
650 /// use burn_tensor::{Tensor, Shape};
651 ///
652 /// fn example<B: Backend>() {
653 /// let device = B::Device::default();
654 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
655 /// let result = tensor.clone().cumprod(0);
656 /// println!("{result}");
657 /// // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
658 /// let result = tensor.cumprod(1);
659 /// println!("{result}");
660 /// // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
661 /// }
662 /// ```
663 pub fn cumprod(self, dim: usize) -> Self {
664 check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
665 Self::new(K::cumprod(self.primitive, dim))
666 }
667
668 /// Apply element wise absolute value operation.
669 ///
670 /// # Example
671 ///
672 /// ```rust
673 /// use burn_tensor::backend::Backend;
674 /// use burn_tensor::{Int, Tensor};
675 ///
676 /// fn example<B: Backend>() {
677 /// let device = Default::default();
678 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
679 /// let tensor = tensor.abs();
680 /// println!("{tensor}");
681 /// // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
682 /// }
683 /// ```
684 pub fn abs(self) -> Self {
685 Self::new(K::abs(self.primitive))
686 }
687
688 /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
689 /// the other elements of the result tensor out are set to 0.
690 ///
691 /// See also [`triu_mask`](Tensor::triu_mask).
692 ///
693 /// # Arguments
694 ///
695 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
696 /// towards the upper triangle.
697 ///
698 /// # Example
699 /// ```rust
700 /// use burn_tensor::backend::Backend;
701 /// use burn_tensor::{Int, Tensor};
702 ///
703 /// fn example<B: Backend>() {
704 /// let device = Default::default();
705 /// let tensor = Tensor::<B, 2, Int>::from_ints(
706 /// [
707 /// [1, 2, 3],
708 /// [4, 5, 6],
709 /// [7, 8, 9]
710 /// ],
711 /// &device
712 /// );
713 /// let tensor = tensor.triu(1);
714 /// println!("{tensor}");
715 /// // [
716 /// // [0, 2, 3],
717 /// // [0, 0, 6],
718 /// // [0, 0, 0]
719 /// // ]
720 /// }
721 /// ```
722 pub fn triu(self, diagonal: i64) -> Self {
723 check!(TensorCheck::tri::<{ D }>());
724
725 // last two dimensions
726 let shape = &self.shape().dims[D - 2..].to_owned();
727
728 let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
729 self.mask_fill(mask, 0)
730 }
731
732 /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
733 /// the other elements of the result tensor out are set to 0.
734 ///
735 /// See also [`tril_mask`](Tensor::tril_mask).
736 ///
737 /// # Arguments
738 ///
739 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
740 /// towards the upper triangle.
741 ///
742 /// # Example
743 /// ```rust
744 /// use burn_tensor::backend::Backend;
745 /// use burn_tensor::{Int, Tensor};
746 ///
747 /// fn example<B: Backend>() {
748 /// let device = Default::default();
749 /// let tensor = Tensor::<B, 2, Int>::from_ints(
750 /// [
751 /// [1, 2, 3],
752 /// [4, 5, 6],
753 /// [7, 8, 9]
754 /// ],
755 /// &device
756 /// );
757 ///
758 /// let tensor = tensor.tril(-1);
759 /// println!("{tensor}");
760 /// // [
761 /// // [0, 0, 0],
762 /// // [4, 0, 0],
763 /// // [7, 8, 0]
764 /// // ]
765 /// }
766 /// ```
767 pub fn tril(self, diagonal: i64) -> Self {
768 check!(TensorCheck::tri::<{ D }>());
769
770 // last two dimensions
771 let shape = &self.shape().dims[D - 2..].to_owned();
772 let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
773
774 self.mask_fill(mask, 0)
775 }
776
777 /// Applies element wise power operation with a float Tensor
778 ///
779 /// # Arguments
780 ///
781 /// * `other` - The tensor to apply the power operation with.
782 ///
783 /// # Example
784 ///
785 /// ```rust
786 /// use burn_tensor::backend::Backend;
787 /// use burn_tensor::{Tensor, Shape};
788 ///
789 /// fn example<B: Backend>() {
790 /// let device = B::Device::default();
791 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
792 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
793 /// let tensor = tensor1.powf(tensor2);
794 /// println!("{tensor}");
795 /// // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]]
796 /// }
797 /// ```
798 pub fn powf(self, other: Self) -> Self {
799 Self::new(K::powf(self.primitive, other.primitive))
800 }
801
802 /// Applies element wise power operation with a float scalar
803 ///
804 /// # Arguments
805 ///
806 /// * `other` - The scalar to apply the power operation with.
807 ///
808 /// # Example
809 ///
810 /// ```rust
811 /// use burn_tensor::backend::Backend;
812 /// use burn_tensor::{Tensor, Shape};
813 ///
814 /// fn example<B: Backend>() {
815 /// let device = B::Device::default();
816 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
817 /// let tensor = tensor.powf_scalar(2.0);
818 /// println!("{tensor}");
819 /// // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
820 /// }
821 /// ```
822 pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
823 let other = Scalar::new(other, &self.dtype());
824 Self::new(K::powf_scalar(self.primitive, other))
825 }
826
827 /// Applies element wise power operation with a integer Tensor
828 ///
829 /// # Arguments
830 ///
831 /// * `other` - The tensor to apply the power operation with.
832 ///
833 /// # Example
834 ///
835 /// ```rust
836 /// use burn_tensor::backend::Backend;
837 /// use burn_tensor::{Tensor, Shape, Int};
838 ///
839 /// fn example<B: Backend>() {
840 /// let device = B::Device::default();
841 /// let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
842 /// let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
843 /// let tensor = tensor1.powi(tensor2);
844 /// println!("{tensor}");
845 /// // [[1, -8, 81], [5, 81, 216]]
846 /// }
847 /// ```
848 pub fn powi(self, other: Self) -> Self {
849 Self::new(K::powi(self.primitive, other.primitive))
850 }
851
852 /// Applies element wise power operation with a integer scalar
853 ///
854 /// # Arguments
855 ///
856 /// * `other` - The scalar to apply the power operation with.
857 ///
858 /// # Example
859 ///
860 /// ```rust
861 /// use burn_tensor::backend::Backend;
862 /// use burn_tensor::{Tensor, Shape, Int};
863 ///
864 /// fn example<B: Backend>() {
865 /// let device = B::Device::default();
866 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
867 /// let tensor = tensor.powi_scalar(2);
868 /// println!("{tensor}");
869 ///
870 /// // [[1, 4, 9], [25, 81, 36]]
871 /// let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
872 /// let tensor = tensor.powi_scalar(2);
873 /// println!("{tensor}");
874 /// // [[2.25, 4., 9.], [25., 81., 36.]]
875 /// }
876 /// ```
877 pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
878 let other = Scalar::new(other, &self.dtype());
879 Self::new(K::powi_scalar(self.primitive, other))
880 }
881
882 /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
883 ///
884 /// # Returns
885 ///
886 /// A boolean tensor with the same shape as the input tensor.
887 ///
888 /// # Example
889 ///
890 /// ```rust
891 /// use burn_tensor::backend::Backend;
892 /// use burn_tensor::{Tensor, Shape};
893 ///
894 /// fn example<B: Backend>() {
895 /// let device = B::Device::default();
896 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
897 /// let tensor = tensor.bool();
898 /// println!("{tensor}");
899 /// // [
900 /// // [true, true, true],
901 /// // [false, true, true]
902 /// // ]
903 /// }
904 pub fn bool(self) -> Tensor<B, D, Bool> {
905 self.not_equal_elem(0)
906 }
907
908 /// Create a random tensor of the given shape on the given device where each element is
909 /// sampled from the given distribution.
910 ///
911 /// See also [`random_like`](Tensor::random_like).
912 ///
913 /// # Arguments
914 ///
915 /// * `shape` - The shape of the tensor.
916 /// * `distribution` - The distribution to sample from.
917 /// * `device` - The device to create the tensor on.
918 ///
919 /// # Returns
920 ///
921 /// A new tensor with the given shape and elements sampled from the given distribution.
922 ///
923 /// # Example
924 ///
925 /// ```rust
926 /// use burn_tensor::backend::Backend;
927 /// use burn_tensor::{Tensor, Shape, Distribution};
928 ///
929 /// fn example<B: Backend>() {
930 /// let device = B::Device::default();
931 /// let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
932 /// let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
933 /// println!("{tensor}");
934 /// // [
935 /// // [0.08347523, 0.70498955, 0.60332155],
936 /// // [0.08173251, 0.18028641, 0.97942924]
937 /// // ]
938 /// }
939 /// ```
940 pub fn random<S: Into<Shape>>(
941 shape: S,
942 distribution: Distribution,
943 device: &B::Device,
944 ) -> Self {
945 Self::new(K::random(shape.into(), distribution, device))
946 }
947
948 /// Applies the matrix multiplication operation.
949 ///
950 /// ```math
951 /// C = AB
952 /// ```
953 ///
954 /// Shapes of the form `[..., B, 1, K] @ [..., 1, K, N]` are reinterpreted as
955 /// `[..., 1, B, K] @ [..., 1, K, N]`, turning a batched vec-mat into a general
956 /// matmul, which is often faster.
957 pub fn matmul(self, other: Self) -> Self {
958 check!(TensorCheck::matmul(&self, &other));
959
960 if D >= 3 {
961 let batch_index = D - 3;
962 let vector_index = D - 2;
963 let lhs_dims = &self.shape()[batch_index..D];
964 let rhs_dims = &other.shape()[batch_index..D];
965
966 if let ([_, 1, k1], [1, k2, _]) = (lhs_dims, rhs_dims)
967 && k1 == k2
968 {
969 return Tensor::new(K::matmul(
970 self.swap_dims(batch_index, vector_index).primitive,
971 other.primitive,
972 ))
973 .swap_dims(batch_index, vector_index);
974 }
975 }
976
977 Tensor::new(K::matmul(self.primitive, other.primitive))
978 }
979}
980
981impl<B, K> Tensor<B, 1, K>
982where
983 B: Backend,
984 K: Numeric<B>,
985 K::Elem: Element,
986{
987 /// Calculates the dot product with another tensor.
988 ///
989 /// `y = x2.dot(x1)`
990 ///
991 /// # Arguments
992 ///
993 /// * `other` - The tensor to compute dot product with.
994 ///
995 /// # Notes
996 ///
997 /// Both tensors must have the same number of elements.
998 ///
999 /// # Example
1000 ///
1001 /// ```rust
1002 /// use burn_tensor::backend::Backend;
1003 /// use burn_tensor::{Tensor, Shape};
1004 ///
1005 /// fn example<B: Backend>() {
1006 /// let device = B::Device::default();
1007 /// let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
1008 /// let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
1009 /// let tensor = tensor1.dot(tensor2);
1010 /// println!("{tensor}");
1011 /// // [4]
1012 /// }
1013 /// ```
1014 pub fn dot(self, other: Self) -> Self {
1015 self.mul(other).sum()
1016 }
1017}
1018
1019impl<B, K> Tensor<B, 2, K>
1020where
1021 B: Backend,
1022 K: Numeric<B>,
1023 K::Elem: Element,
1024{
1025 /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
1026 ///
1027 /// # Arguments
1028 ///
1029 /// * `size` - The size of the square matrix.
1030 pub fn eye(size: usize, device: &B::Device) -> Self {
1031 let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
1032 let ones = Self::ones([1, size], device);
1033 let zeros = Self::zeros([size, size], device);
1034
1035 zeros.scatter(0, indices, ones, IndexingUpdateOp::Add)
1036 }
1037}
1038
1039// Tensor + tensor
1040impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
1041where
1042 K::Elem: Element,
1043{
1044 type Output = Self;
1045
1046 fn add(self, rhs: Self) -> Self::Output {
1047 Self::add(self, rhs)
1048 }
1049}
1050
1051// Tensor + scalar
1052impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
1053 for Tensor<B, D, K>
1054where
1055 K::Elem: Element,
1056{
1057 type Output = Self;
1058
1059 fn add(self, other: E) -> Self::Output {
1060 Tensor::add_scalar(self, other)
1061 }
1062}
1063
1064// Scalar + tensor
1065macro_rules! impl_tensor_scalar_add {
1066 ($($t:ty),*) => {
1067 $(
1068 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
1069 where
1070 K::Elem: Element,
1071 {
1072 type Output = Tensor<B, D, K>;
1073
1074 fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
1075 Tensor::add_scalar(tensor, self)
1076 }
1077 }
1078 )*
1079 }
1080}
1081impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);
1082
1083// Tensor - tensor
1084impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
1085where
1086 K::Elem: Element,
1087{
1088 type Output = Self;
1089
1090 fn sub(self, rhs: Self) -> Self::Output {
1091 Tensor::sub(self, rhs)
1092 }
1093}
1094
1095// Tensor - scalar
1096impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
1097 for Tensor<B, D, K>
1098where
1099 K::Elem: Element,
1100{
1101 type Output = Self;
1102
1103 fn sub(self, other: E) -> Self::Output {
1104 Tensor::sub_scalar(self, other)
1105 }
1106}
1107
1108// Scalar - tensor
1109macro_rules! impl_tensor_scalar_sub {
1110 ($($t:ty),*) => {
1111 $(
1112 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
1113 where
1114 K::Elem: Element,
1115 {
1116 type Output = Tensor<B, D, K>;
1117
1118 fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
1119 Tensor::add_scalar(Tensor::neg(tensor), self)
1120 }
1121 }
1122 )*
1123 }
1124}
1125impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);
1126
1127// Tensor / tensor
1128impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
1129where
1130 K::Elem: Element,
1131{
1132 type Output = Self;
1133
1134 fn div(self, rhs: Self) -> Self::Output {
1135 Tensor::div(self, rhs)
1136 }
1137}
1138
1139// Tensor / scalar
1140impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
1141 for Tensor<B, D, K>
1142where
1143 K::Elem: Element,
1144{
1145 type Output = Self;
1146
1147 fn div(self, other: E) -> Self::Output {
1148 Tensor::div_scalar(self, other)
1149 }
1150}
1151
1152// Scalar / tensor (float only)
1153macro_rules! impl_tensor_scalar_div {
1154 ($($t:ty),*) => {
1155 $(
1156 impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t
1157 {
1158 type Output = Tensor<B, D>;
1159
1160 fn div(self, tensor: Tensor<B, D>) -> Self::Output {
1161 tensor.recip().mul_scalar(self)
1162 }
1163 }
1164 )*
1165 }
1166}
1167
1168impl_tensor_scalar_div!(f32, f64);
1169
1170// Tensor % tensor.
1171impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
1172where
1173 K::Elem: Element,
1174{
1175 type Output = Self;
1176
1177 fn rem(self, rhs: Self) -> Self::Output {
1178 Tensor::remainder(self, rhs)
1179 }
1180}
1181
1182// Tensor % scalar.
1183impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
1184 for Tensor<B, D, K>
1185where
1186 K::Elem: Element,
1187{
1188 type Output = Self;
1189
1190 fn rem(self, other: E) -> Self::Output {
1191 Tensor::remainder_scalar(self, other)
1192 }
1193}
1194
1195// Tensor * tensor.
1196impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
1197where
1198 K::Elem: Element,
1199{
1200 type Output = Self;
1201
1202 fn mul(self, rhs: Self) -> Self::Output {
1203 Tensor::mul(self, rhs)
1204 }
1205}
1206
1207// Tensor * scalar.
1208impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
1209 for Tensor<B, D, K>
1210where
1211 K::Elem: Element,
1212{
1213 type Output = Self;
1214
1215 fn mul(self, other: E) -> Self::Output {
1216 Tensor::mul_scalar(self, other)
1217 }
1218}
1219
1220macro_rules! impl_tensor_scalar_mul {
1221 ($($t:ty),*) => {
1222 $(
1223 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
1224 where
1225 K::Elem: Element,
1226 {
1227 type Output = Tensor<B, D, K>;
1228
1229 fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
1230 Tensor::mul_scalar(other, self)
1231 }
1232 }
1233 )*
1234 }
1235}
1236
1237impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);
1238
1239impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
1240where
1241 B: Backend,
1242 K: Numeric<B>,
1243 K::Elem: Element,
1244{
1245 type Output = Self;
1246
1247 fn neg(self) -> Self::Output {
1248 Tensor::neg(self)
1249 }
1250}