//! burn_tensor/tensor/api/numeric.rs — numeric operations (element-wise math, reductions, matmul) on tensors.
1use burn_backend::Scalar;
2pub use burn_backend::tensor::Numeric;
3
4use crate::alloc::borrow::ToOwned;
5use alloc::vec::Vec;
6
7use crate::IndexingUpdateOp;
8use crate::{
9 AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend,
10 check, check::TensorCheck,
11};
12
13impl<B, const D: usize, K> Tensor<B, D, K>
14where
15 B: Backend,
16 K: Numeric<B>,
17 K::Elem: Element,
18{
19 /// Applies element wise addition operation.
20 ///
21 /// `y = x2 + x1`
22 ///
23 /// # Arguments
24 ///
25 /// * `other` - The tensor to add.
26 ///
27 /// # Example
28 ///
29 /// ```rust
30 /// use burn_tensor::backend::Backend;
31 /// use burn_tensor::{Tensor, Shape};
32 ///
33 /// fn example<B: Backend>() {
34 /// let device = B::Device::default();
35 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
36 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
37 /// let tensor = tensor1 + tensor2;
38 /// println!("{tensor}");
39 /// // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
40 /// }
41 /// ```
42 #[allow(clippy::should_implement_trait)]
43 pub fn add(self, other: Self) -> Self {
44 check!(TensorCheck::binary_ops_ew("Add", &self, &other));
45 Self::new(K::add(self.primitive, other.primitive))
46 }
47
48 /// Applies element wise addition operation with a scalar.
49 ///
50 /// `y = x + s`
51 ///
52 /// # Arguments
53 ///
54 /// * `other` - The scalar to add, element wise.
55 ///
56 /// # Example
57 ///
58 /// ```rust
59 /// use burn_tensor::backend::Backend;
60 /// use burn_tensor::{Tensor, Shape};
61 ///
62 /// fn example<B: Backend>() {
63 /// let device = B::Device::default();
64 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
65 /// let scalar = 2.0;
66 /// let tensor = tensor + scalar;
67 /// println!("{tensor}");
68 /// // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
69 /// }
70 /// ```
71 pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
72 let other = Scalar::new(other, &self.dtype());
73 Self::new(K::add_scalar(self.primitive, other))
74 }
75
76 /// Applies element wise subtraction operation.
77 ///
78 /// `y = x2 - x1`
79 ///
80 /// # Arguments
81 ///
82 /// * `other` - The tensor to subtract.
83 ///
84 /// # Example
85 ///
86 /// ```rust
87 /// use burn_tensor::backend::Backend;
88 /// use burn_tensor::{Tensor, Shape};
89 ///
90 /// fn example<B: Backend>() {
91 /// let device = B::Device::default();
92 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
93 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
94 /// let tensor = tensor1 - tensor2;
95 /// println!("{tensor}");
96 /// // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
97 /// }
98 /// ```
99 #[allow(clippy::should_implement_trait)]
100 pub fn sub(self, other: Self) -> Self {
101 check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
102 Self::new(K::sub(self.primitive, other.primitive))
103 }
104
105 /// Applies element wise subtraction operation with a scalar.
106 ///
107 /// `y = x - s`
108 ///
109 /// # Arguments
110 ///
111 /// * `other` - The scalar to subtract, element wise.
112 ///
113 /// # Example
114 ///
115 /// ```rust
116 /// use burn_tensor::backend::Backend;
117 /// use burn_tensor::{Tensor, Shape};
118 ///
119 /// fn example<B: Backend>() {
120 /// let device = B::Device::default();
121 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
122 /// let scalar = 2.0;
123 /// let tensor = tensor - scalar;
124 /// println!("{tensor}");
125 /// // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
126 /// }
127 /// ```
128 pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
129 let other = Scalar::new(other, &self.dtype());
130 Self::new(K::sub_scalar(self.primitive, other))
131 }
132
133 /// Applies element wise division operation.
134 ///
135 /// `y = x2 / x1`
136 ///
137 /// # Arguments
138 ///
139 /// * `other` - The tensor to divide.
140 ///
141 /// # Example
142 ///
143 /// ```rust
144 /// use burn_tensor::backend::Backend;
145 /// use burn_tensor::{Tensor, Shape};
146 ///
147 /// fn example<B: Backend>() {
148 /// let device = B::Device::default();
149 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
150 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
151 /// let tensor = tensor1 / tensor2;
152 /// println!("{tensor}");
153 /// // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
154 /// }
155 /// ```
156 #[allow(clippy::should_implement_trait)]
157 pub fn div(self, other: Self) -> Self {
158 check!(TensorCheck::binary_ops_ew("Div", &self, &other));
159 Self::new(K::div(self.primitive, other.primitive))
160 }
161
162 /// Applies element wise division operation with a scalar.
163 ///
164 /// `y = x / s`
165 ///
166 /// # Arguments
167 ///
168 /// * `other` - The scalar to divide, element wise.
169 ///
170 /// # Example
171 ///
172 /// ```rust
173 /// use burn_tensor::backend::Backend;
174 /// use burn_tensor::{Tensor, Shape};
175 ///
176 /// fn example<B: Backend>() {
177 /// let device = B::Device::default();
178 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
179 /// let scalar = 2.0;
180 /// let tensor = tensor / scalar;
181 /// println!("{tensor}");
182 /// // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
183 /// }
184 /// ```
185 pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
186 let other = Scalar::new(other, &self.dtype());
187 Self::new(K::div_scalar(self.primitive, other))
188 }
189
190 /// Applies element wise the remainder operation with a scalar.
191 ///
192 /// `y = x2 % x1`
193 pub fn remainder(self, other: Self) -> Self {
194 Self::new(K::remainder(self.primitive, other.primitive))
195 }
196
197 /// Applies element wise the remainder operation with a scalar.
198 ///
199 /// `y = x % s`
200 ///
201 /// # Arguments
202 ///
203 /// * `other` - The scalar to divide, element wise.
204 ///
205 /// # Example
206 ///
207 /// ```rust
208 /// use burn_tensor::backend::Backend;
209 /// use burn_tensor::{Tensor, Shape};
210 ///
211 /// fn example<B: Backend>() {
212 /// let device = B::Device::default();
213 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
214 /// let scalar = 2.0;
215 /// let tensor = tensor1 % scalar;
216 /// println!("{tensor}");
217 /// // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
218 /// }
219 /// ```
220 pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
221 let other = Scalar::new(other, &self.dtype());
222 Self::new(K::remainder_scalar(self.primitive, other))
223 }
224
225 /// Applies element wise multiplication operation.
226 ///
227 /// `y = x2 * x1`
228 ///
229 /// # Arguments
230 ///
231 /// * `other` - The tensor to multiply.
232 ///
233 /// # Example
234 ///
235 /// ```rust
236 /// use burn_tensor::backend::Backend;
237 /// use burn_tensor::{Tensor, Shape};
238 ///
239 /// fn example<B: Backend>() {
240 /// let device = B::Device::default();
241 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
242 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
243 /// let tensor = tensor1 * tensor2;
244 /// println!("{tensor}");
245 /// // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
246 /// }
247 /// ```
248 #[allow(clippy::should_implement_trait)]
249 pub fn mul(self, other: Self) -> Self {
250 check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
251 Self::new(K::mul(self.primitive, other.primitive))
252 }
253
254 /// Applies element wise multiplication operation with a scalar.
255 ///
256 /// `y = x * s`
257 ///
258 /// # Arguments
259 ///
260 /// * `other` - The scalar to multiply, element wise.
261 ///
262 /// # Example
263 ///
264 /// ```rust
265 /// use burn_tensor::backend::Backend;
266 /// use burn_tensor::{Tensor, Shape};
267 ///
268 /// fn example<B: Backend>() {
269 /// let device = B::Device::default();
270 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
271 /// let scalar = 2.0;
272 /// let tensor = tensor * scalar;
273 /// println!("{tensor}");
274 /// // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
275 /// }
276 /// ```
277 pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
278 let other = Scalar::new(other, &self.dtype());
279 Self::new(K::mul_scalar(self.primitive, other))
280 }
281
282 /// Switch sign of each element in the tensor.
283 ///
284 /// `y = -x`
285 ///
286 /// # Example
287 ///
288 /// ```rust
289 /// use burn_tensor::backend::Backend;
290 /// use burn_tensor::{Tensor, Shape};
291 ///
292 /// fn example<B: Backend>() {
293 /// let device = B::Device::default();
294 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
295 /// let tensor = -tensor;
296 /// println!("{tensor}");
297 /// // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
298 /// }
299 /// ```
300 #[allow(clippy::should_implement_trait)]
301 pub fn neg(self) -> Self {
302 Self::new(K::neg(self.primitive))
303 }
304
305 /// Returns the signs of the elements of the input tensor.
306 ///
307 /// # Example
308 ///
309 /// ```rust
310 /// use burn_tensor::backend::Backend;
311 /// use burn_tensor::{Tensor, Shape};
312 ///
313 /// fn example<B: Backend>() {
314 /// let device = B::Device::default();
315 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
316 /// let tensor = tensor.sign();
317 /// println!("{tensor}");
318 /// // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
319 /// }
320 /// ```
321 pub fn sign(self) -> Self {
322 Self::new(K::sign(self.primitive))
323 }
324
325 /// Aggregate all elements in the tensor with the mean operation.
326 ///
327 /// # Example
328 ///
329 /// ```rust
330 /// use burn_tensor::backend::Backend;
331 /// use burn_tensor::{Tensor, Shape};
332 ///
333 /// fn example<B: Backend>() {
334 /// let device = B::Device::default();
335 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
336 /// let tensor = tensor.mean();
337 /// println!("{tensor}");
338 /// // [3.6666667]
339 /// }
340 /// ```
341 pub fn mean(self) -> Tensor<B, 1, K> {
342 Tensor::new(K::mean(self.primitive))
343 }
344
345 /// Aggregate all elements in the tensor with the sum operation.
346 ///
347 /// # Example
348 ///
349 /// ```rust
350 /// use burn_tensor::backend::Backend;
351 /// use burn_tensor::{Tensor, Shape};
352 ///
353 /// fn example<B: Backend>() {
354 /// let device = B::Device::default();
355 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
356 /// let tensor = tensor.sum();
357 /// println!("{tensor}");
358 /// // [22.0]
359 /// }
360 /// ```
361 pub fn sum(self) -> Tensor<B, 1, K> {
362 Tensor::new(K::sum(self.primitive))
363 }
364
365 /// Aggregate all elements along the given *dimension* or *axis*
366 /// in the tensor with the mean operation.
367 ///
368 /// # Arguments
369 ///
370 /// * `dim` - The dimension or axis along which to aggregate the elements;
371 /// supports negative indexing.
372 ///
373 /// # Example
374 ///
375 /// ```rust
376 /// use burn_tensor::backend::Backend;
377 /// use burn_tensor::{Tensor, Shape};
378 ///
379 /// fn example<B: Backend>() {
380 /// let device = B::Device::default();
381 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
382 /// let tensor = tensor.clone().mean_dim(0);
383 /// println!("{tensor}");
384 /// // [[3.0, 3.5, 4.5]]
385 /// let tensor = tensor.clone().mean_dim(1);
386 /// println!("{tensor}");
387 /// // [[0.6666667], [6.6666665]]
388 /// }
389 /// ```
390 pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
391 let dim = dim.expect_dim_index(D);
392 check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
393 Self::new(K::mean_dim(self.primitive, dim))
394 }
395
396 /// Aggregate all elements along the given *axes*
397 /// in the tensor with the mean operation.
398 ///
399 /// # Arguments
400 ///
401 /// * `dims` - the dimensions to aggregate; supports negative indexing.
402 ///
403 /// # Returns
404 ///
405 /// The returned tensor will have the same rank,
406 /// but the aggregated dimensions will have size 1.
407 ///
408 /// # Example
409 ///
410 /// ```rust
411 /// use burn_tensor::backend::Backend;
412 /// use burn_tensor::{Tensor, Shape};
413 ///
414 /// fn example<B: Backend>() {
415 /// let device = B::Device::default();
416 /// let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
417 /// let tensor = tensor.clone().mean_dims(&[0, 1]);
418 /// println!("{tensor}");
419 /// // [[2.0]]
420 /// }
421 /// ```
422 pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
423 dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
424 }
425
426 /// Aggregate all elements along the given *dimension* or *axis*
427 /// in the tensor with the sum operation.
428 ///
429 /// # Arguments
430 ///
431 /// * `dim` - The dimension or axis along which to aggregate the elements;
432 /// supports negative indexing.
433 ///
434 /// # Example
435 ///
436 /// ```rust
437 /// use burn_tensor::backend::Backend;
438 /// use burn_tensor::{Tensor, Shape};
439 ///
440 /// fn example<B: Backend>() {
441 /// let device = B::Device::default();
442 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
443 /// let tensor = tensor.clone().sum_dim(0);
444 /// println!("{tensor}");
445 /// // [[6.0, 7.0, 9.0]]
446 /// let tensor = tensor.clone().sum_dim(1);
447 /// println!("{tensor}");
448 /// // [[2.0], [20.0]]
449 /// }
450 /// ```
451 pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
452 let dim = dim.expect_dim_index(D);
453 check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
454 Self::new(K::sum_dim(self.primitive, dim))
455 }
456
457 /// Aggregate all elements along the given *axes*
458 /// in the tensor with the sum operation.
459 ///
460 /// # Arguments
461 ///
462 /// * `dims` - the dimensions to aggregate; supports negative indexing.
463 ///
464 /// # Returns
465 ///
466 /// The returned tensor will have the same rank,
467 /// but the aggregated dimensions will have size 1.
468 ///
469 /// # Example
470 ///
471 /// ```rust
472 /// use burn_tensor::backend::Backend;
473 /// use burn_tensor::{Tensor, Shape};
474 ///
475 /// fn example<B: Backend>() {
476 /// let device = B::Device::default();
477 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
478 /// let tensor = tensor.clone().sum_dims(&[0, 1]);
479 /// println!("{tensor}");
    /// // [[22.0]]
481 /// }
482 /// ```
483 pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
484 dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
485 }
486
487 /// Aggregate and squeeze along the given dimensions.
488 ///
489 /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)``
490 ///
491 /// # Arguments
492 ///
493 /// * `dims` - the dimensions to aggregate; supports negative indexing.
494 ///
495 /// # Returns
496 ///
497 /// The returned tensor will have the same rank,
498 /// but the aggregated dimensions will have size 1.
499 ///
500 /// # Example
501 ///
502 /// ```rust
503 /// use burn_tensor::backend::Backend;
504 /// use burn_tensor::{Tensor, Shape};
505 ///
506 /// fn example<B: Backend>() {
507 /// let device = B::Device::default();
508 /// let tensor = Tensor::<B, 3>::from_data([
509 /// [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
510 /// [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
511 /// ], &device);
512 /// let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);
513 /// println!("{tensor}");
514 /// // [20.0, 16.0, 21.0]
515 /// }
516 /// ```
517 pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
518 // TODO: remove idims when squeeze_dims uses AsIndex.
519 let idims = dims
520 .iter()
521 .map(|&dim| (dim.expect_dim_index(D)) as isize)
522 .collect::<Vec<_>>();
523 self.sum_dims(dims).squeeze_dims::<D2>(&idims)
524 }
525
526 /// Aggregate all elements in the tensor with the product operation.
527 ///
528 /// # Example
529 ///
530 /// ```rust
531 /// use burn_tensor::backend::Backend;
532 /// use burn_tensor::{Tensor, Shape};
533 ///
534 /// fn example<B: Backend>() {
535 /// let device = B::Device::default();
536 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
537 /// let tensor = tensor.prod();
538 /// println!("{tensor}");
539 /// // [-1620.0]
540 /// }
541 /// ```
542 pub fn prod(self) -> Tensor<B, 1, K> {
543 Tensor::new(K::prod(self.primitive))
544 }
545
546 /// Aggregate all elements along the given *dimension* or *axis*
547 /// in the tensor with the product operation.
548 ///
549 /// # Arguments
550 ///
551 /// * `dim` - The dimension or axis along which to aggregate the elements,
552 /// supports negative indexing.
553 ///
554 /// # Returns
555 ///
556 /// The returned tensor will have the same rank,
557 /// but the aggregated dimension will have size 1.
558 ///
559 /// # Example
560 ///
561 /// ```rust
562 /// use burn_tensor::backend::Backend;
563 /// use burn_tensor::{Tensor, Shape};
564 ///
565 /// fn example<B: Backend>() {
566 /// let device = B::Device::default();
567 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
568 /// let tensor = tensor.clone().prod_dim(0);
569 /// println!("{tensor}");
570 /// // [[5.0, -18.0, 18.0]]
571 /// let tensor = tensor.clone().prod_dim(1);
572 /// println!("{tensor}");
573 /// // [[-6.0], [270.0]]
574 /// }
575 /// ```
576 pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
577 let dim = dim.expect_dim_index(D);
578 check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
579 Self::new(K::prod_dim(self.primitive, dim))
580 }
581
582 /// Aggregate all elements along the given *axes*
583 /// in the tensor with the prod operation.
584 ///
585 /// # Arguments
586 ///
587 /// * `dims` - the dimensions to aggregate, supports negative indexing.
588 ///
589 /// # Returns
590 ///
591 /// The returned tensor will have the same rank,
592 /// but the aggregated dimensions will have size 1.
593 ///
594 /// # Example
595 ///
596 /// ```rust
597 /// use burn_tensor::backend::Backend;
598 /// use burn_tensor::{Tensor, Shape};
599 ///
600 /// fn example<B: Backend>() {
601 /// let device = B::Device::default();
602 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    /// let tensor = tensor.clone().prod_dims(&[0, 1]);
604 /// println!("{tensor}");
605 /// // [[-1620.0]]
606 /// }
607 /// ```
608 pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
609 dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
610 }
611
612 /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
613 ///
614 /// # Arguments
615 ///
616 /// * `dim` - The dimension or axis along which to compute the cumulative sum.
617 ///
618 /// # Example
619 ///
620 /// ```rust
621 /// use burn_tensor::backend::Backend;
622 /// use burn_tensor::{Tensor, Shape};
623 ///
624 /// fn example<B: Backend>() {
625 /// let device = B::Device::default();
626 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
627 /// let result = tensor.clone().cumsum(0);
628 /// println!("{result}");
629 /// // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
630 /// let result = tensor.cumsum(1);
631 /// println!("{result}");
632 /// // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
633 /// }
634 /// ```
635 pub fn cumsum(self, dim: usize) -> Self {
636 check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
637 Self::new(K::cumsum(self.primitive, dim))
638 }
639
640 /// Computes the cumulative product of elements along the given *dimension* or *axis*.
641 ///
642 /// # Arguments
643 ///
644 /// * `dim` - The dimension or axis along which to compute the cumulative product.
645 ///
646 /// # Example
647 ///
648 /// ```rust
649 /// use burn_tensor::backend::Backend;
650 /// use burn_tensor::{Tensor, Shape};
651 ///
652 /// fn example<B: Backend>() {
653 /// let device = B::Device::default();
654 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
655 /// let result = tensor.clone().cumprod(0);
656 /// println!("{result}");
657 /// // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
658 /// let result = tensor.cumprod(1);
659 /// println!("{result}");
660 /// // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
661 /// }
662 /// ```
663 pub fn cumprod(self, dim: usize) -> Self {
664 check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
665 Self::new(K::cumprod(self.primitive, dim))
666 }
667
668 /// Apply element wise absolute value operation.
669 ///
670 /// # Example
671 ///
672 /// ```rust
673 /// use burn_tensor::backend::Backend;
674 /// use burn_tensor::{Int, Tensor};
675 ///
676 /// fn example<B: Backend>() {
677 /// let device = Default::default();
678 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
679 /// let tensor = tensor.abs();
680 /// println!("{tensor}");
681 /// // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
682 /// }
683 /// ```
684 ///
685 /// # Notes
686 ///
687 /// For signed integer dtypes, this operation uses two's-complement wraparound semantics, similar to
688 /// `x.wrapping_abs()`. For example, `abs(i64::MIN) == i64::MIN`.
689 pub fn abs(self) -> Self {
690 Self::new(K::abs(self.primitive))
691 }
692
693 /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
694 /// the other elements of the result tensor out are set to 0.
695 ///
696 /// See also [`triu_mask`](Tensor::triu_mask).
697 ///
698 /// # Arguments
699 ///
700 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
701 /// towards the upper triangle.
702 ///
703 /// # Example
704 /// ```rust
705 /// use burn_tensor::backend::Backend;
706 /// use burn_tensor::{Int, Tensor};
707 ///
708 /// fn example<B: Backend>() {
709 /// let device = Default::default();
710 /// let tensor = Tensor::<B, 2, Int>::from_ints(
711 /// [
712 /// [1, 2, 3],
713 /// [4, 5, 6],
714 /// [7, 8, 9]
715 /// ],
716 /// &device
717 /// );
718 /// let tensor = tensor.triu(1);
719 /// println!("{tensor}");
720 /// // [
721 /// // [0, 2, 3],
722 /// // [0, 0, 6],
723 /// // [0, 0, 0]
724 /// // ]
725 /// }
726 /// ```
727 pub fn triu(self, diagonal: i64) -> Self {
728 check!(TensorCheck::tri::<{ D }>());
729
730 // last two dimensions
731 let shape = &self.shape()[D - 2..].to_owned();
732
733 let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
734 self.mask_fill(mask, 0)
735 }
736
737 /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
738 /// the other elements of the result tensor out are set to 0.
739 ///
740 /// See also [`tril_mask`](Tensor::tril_mask).
741 ///
742 /// # Arguments
743 ///
744 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
745 /// towards the upper triangle.
746 ///
747 /// # Example
748 /// ```rust
749 /// use burn_tensor::backend::Backend;
750 /// use burn_tensor::{Int, Tensor};
751 ///
752 /// fn example<B: Backend>() {
753 /// let device = Default::default();
754 /// let tensor = Tensor::<B, 2, Int>::from_ints(
755 /// [
756 /// [1, 2, 3],
757 /// [4, 5, 6],
758 /// [7, 8, 9]
759 /// ],
760 /// &device
761 /// );
762 ///
763 /// let tensor = tensor.tril(-1);
764 /// println!("{tensor}");
765 /// // [
766 /// // [0, 0, 0],
767 /// // [4, 0, 0],
768 /// // [7, 8, 0]
769 /// // ]
770 /// }
771 /// ```
772 pub fn tril(self, diagonal: i64) -> Self {
773 check!(TensorCheck::tri::<{ D }>());
774
775 // last two dimensions
776 let shape = &self.shape()[D - 2..].to_owned();
777 let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
778
779 self.mask_fill(mask, 0)
780 }
781
782 /// Applies element wise power operation with a float Tensor
783 ///
784 /// # Arguments
785 ///
786 /// * `other` - The tensor to apply the power operation with.
787 ///
788 /// # Example
789 ///
790 /// ```rust
791 /// use burn_tensor::backend::Backend;
792 /// use burn_tensor::{Tensor, Shape};
793 ///
794 /// fn example<B: Backend>() {
795 /// let device = B::Device::default();
796 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
797 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
798 /// let tensor = tensor1.powf(tensor2);
799 /// println!("{tensor}");
    /// // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
801 /// }
802 /// ```
803 pub fn powf(self, other: Self) -> Self {
804 Self::new(K::powf(self.primitive, other.primitive))
805 }
806
807 /// Applies element wise power operation with a float scalar
808 ///
809 /// # Arguments
810 ///
811 /// * `other` - The scalar to apply the power operation with.
812 ///
813 /// # Example
814 ///
815 /// ```rust
816 /// use burn_tensor::backend::Backend;
817 /// use burn_tensor::{Tensor, Shape};
818 ///
819 /// fn example<B: Backend>() {
820 /// let device = B::Device::default();
821 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
822 /// let tensor = tensor.powf_scalar(2.0);
823 /// println!("{tensor}");
824 /// // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
825 /// }
826 /// ```
827 pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
828 let other = Scalar::new(other, &self.dtype());
829 Self::new(K::powf_scalar(self.primitive, other))
830 }
831
832 /// Applies element wise power operation with a integer Tensor
833 ///
834 /// # Arguments
835 ///
836 /// * `other` - The tensor to apply the power operation with.
837 ///
838 /// # Example
839 ///
840 /// ```rust
841 /// use burn_tensor::backend::Backend;
842 /// use burn_tensor::{Tensor, Shape, Int};
843 ///
844 /// fn example<B: Backend>() {
845 /// let device = B::Device::default();
846 /// let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
847 /// let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
848 /// let tensor = tensor1.powi(tensor2);
849 /// println!("{tensor}");
850 /// // [[1, -8, 81], [5, 81, 216]]
851 /// }
852 /// ```
853 pub fn powi(self, other: Self) -> Self {
854 Self::new(K::powi(self.primitive, other.primitive))
855 }
856
857 /// Applies element wise power operation with a integer scalar
858 ///
859 /// # Arguments
860 ///
861 /// * `other` - The scalar to apply the power operation with.
862 ///
863 /// # Example
864 ///
865 /// ```rust
866 /// use burn_tensor::backend::Backend;
867 /// use burn_tensor::{Tensor, Shape, Int};
868 ///
869 /// fn example<B: Backend>() {
870 /// let device = B::Device::default();
871 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
872 /// let tensor = tensor.powi_scalar(2);
873 /// println!("{tensor}");
874 ///
875 /// // [[1, 4, 9], [25, 81, 36]]
876 /// let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
877 /// let tensor = tensor.powi_scalar(2);
878 /// println!("{tensor}");
879 /// // [[2.25, 4., 9.], [25., 81., 36.]]
880 /// }
881 /// ```
882 pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
883 let other = Scalar::new(other, &self.dtype());
884 Self::new(K::powi_scalar(self.primitive, other))
885 }
886
887 /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
888 ///
889 /// # Returns
890 ///
891 /// A boolean tensor with the same shape as the input tensor.
892 ///
893 /// # Example
894 ///
895 /// ```rust
896 /// use burn_tensor::backend::Backend;
897 /// use burn_tensor::{Tensor, Shape};
898 ///
899 /// fn example<B: Backend>() {
900 /// let device = B::Device::default();
901 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
902 /// let tensor = tensor.bool();
903 /// println!("{tensor}");
904 /// // [
905 /// // [true, true, true],
906 /// // [false, true, true]
907 /// // ]
    /// }
    /// ```
    pub fn bool(self) -> Tensor<B, D, Bool> {
        // `true` exactly where the element differs from zero.
        self.not_equal_elem(0)
    }
912
913 /// Create a random tensor of the given shape on the given device where each element is
914 /// sampled from the given distribution.
915 ///
916 /// See also [`random_like`](Tensor::random_like).
917 ///
918 /// # Arguments
919 ///
920 /// * `shape` - The shape of the tensor.
921 /// * `distribution` - The distribution to sample from.
922 /// * `device` - The device to create the tensor on.
923 ///
924 /// # Returns
925 ///
926 /// A new tensor with the given shape and elements sampled from the given distribution.
927 ///
928 /// # Example
929 ///
930 /// ```rust
931 /// use burn_tensor::backend::Backend;
932 /// use burn_tensor::{Tensor, Shape, Distribution};
933 ///
934 /// fn example<B: Backend>() {
935 /// let device = B::Device::default();
936 /// let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
937 /// let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
938 /// println!("{tensor}");
939 /// // [
940 /// // [0.08347523, 0.70498955, 0.60332155],
941 /// // [0.08173251, 0.18028641, 0.97942924]
942 /// // ]
943 /// }
944 /// ```
945 pub fn random<S: Into<Shape>>(
946 shape: S,
947 distribution: Distribution,
948 device: &B::Device,
949 ) -> Self {
950 Self::new(K::random(shape.into(), distribution, device))
951 }
952
953 /// Applies the matrix multiplication operation.
954 ///
955 /// ```math
956 /// C = AB
957 /// ```
958 ///
959 /// Shapes of the form `[..., B, 1, K] @ [..., 1, K, N]` are reinterpreted as
960 /// `[..., 1, B, K] @ [..., 1, K, N]`, turning a batched vec-mat into a general
961 /// matmul, which is often faster.
    pub fn matmul(self, other: Self) -> Self {
        check!(TensorCheck::matmul(&self, &other));

        if D >= 3 {
            // Indices of the trailing three axes: [..., batch, rows, cols].
            let batch_index = D - 3;
            let vector_index = D - 2;
            let lhs_dims = &self.shape()[batch_index..D];
            let rhs_dims = &other.shape()[batch_index..D];

            // Batched vec-mat special case `[..., B, 1, K] @ [..., 1, K, N]`:
            // swapping the lhs batch and row axes reinterprets it as one
            // `[..., 1, B, K] @ [..., 1, K, N]` matmul (see the doc above),
            // then the same swap on the result restores the expected layout.
            if let ([_, 1, k1], [1, k2, _]) = (lhs_dims, rhs_dims)
                && k1 == k2
            {
                return Tensor::new(K::matmul(
                    self.swap_dims(batch_index, vector_index).primitive,
                    other.primitive,
                ))
                .swap_dims(batch_index, vector_index);
            }
        }

        // General case: delegate directly to the backend matmul.
        Tensor::new(K::matmul(self.primitive, other.primitive))
    }
984}
985
986impl<B, K> Tensor<B, 1, K>
987where
988 B: Backend,
989 K: Numeric<B>,
990 K::Elem: Element,
991{
992 /// Calculates the dot product with another tensor.
993 ///
994 /// `y = x2.dot(x1)`
995 ///
996 /// # Arguments
997 ///
998 /// * `other` - The tensor to compute dot product with.
999 ///
1000 /// # Notes
1001 ///
1002 /// Both tensors must have the same number of elements.
1003 ///
1004 /// # Example
1005 ///
1006 /// ```rust
1007 /// use burn_tensor::backend::Backend;
1008 /// use burn_tensor::{Tensor, Shape};
1009 ///
1010 /// fn example<B: Backend>() {
1011 /// let device = B::Device::default();
1012 /// let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
1013 /// let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
1014 /// let tensor = tensor1.dot(tensor2);
1015 /// println!("{tensor}");
    /// // [4.0]
1017 /// }
1018 /// ```
1019 pub fn dot(self, other: Self) -> Self {
1020 self.mul(other).sum()
1021 }
1022}
1023
1024impl<B, K> Tensor<B, 2, K>
1025where
1026 B: Backend,
1027 K: Numeric<B>,
1028 K::Elem: Element,
1029{
1030 /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
1031 ///
1032 /// # Arguments
1033 ///
1034 /// * `size` - The size of the square matrix.
    pub fn eye(size: usize, device: &B::Device) -> Self {
        // Indices [0, 1, ..., size-1] as a [1, size] int tensor: for column j,
        // the scatter target row is j, i.e. the diagonal position (j, j).
        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
        let ones = Self::ones([1, size], device);
        let zeros = Self::zeros([size, size], device);

        // Scatter-add the row of ones into the zero matrix along dim 0,
        // placing a single 1 on each diagonal element.
        zeros.scatter(0, indices, ones, IndexingUpdateOp::Add)
    }
1042}
1043
1044// Tensor + tensor
1045impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
1046where
1047 K::Elem: Element,
1048{
1049 type Output = Self;
1050
1051 fn add(self, rhs: Self) -> Self::Output {
1052 Self::add(self, rhs)
1053 }
1054}
1055
1056// Tensor + scalar
1057impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
1058 for Tensor<B, D, K>
1059where
1060 K::Elem: Element,
1061{
1062 type Output = Self;
1063
1064 fn add(self, other: E) -> Self::Output {
1065 Tensor::add_scalar(self, other)
1066 }
1067}
1068
1069// Scalar + tensor
1070macro_rules! impl_tensor_scalar_add {
1071 ($($t:ty),*) => {
1072 $(
1073 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
1074 where
1075 K::Elem: Element,
1076 {
1077 type Output = Tensor<B, D, K>;
1078
1079 fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
1080 Tensor::add_scalar(tensor, self)
1081 }
1082 }
1083 )*
1084 }
1085}
1086impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);
1087
1088// Tensor - tensor
1089impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
1090where
1091 K::Elem: Element,
1092{
1093 type Output = Self;
1094
1095 fn sub(self, rhs: Self) -> Self::Output {
1096 Tensor::sub(self, rhs)
1097 }
1098}
1099
1100// Tensor - scalar
1101impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
1102 for Tensor<B, D, K>
1103where
1104 K::Elem: Element,
1105{
1106 type Output = Self;
1107
1108 fn sub(self, other: E) -> Self::Output {
1109 Tensor::sub_scalar(self, other)
1110 }
1111}
1112
1113// Scalar - tensor
1114macro_rules! impl_tensor_scalar_sub {
1115 ($($t:ty),*) => {
1116 $(
1117 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
1118 where
1119 K::Elem: Element,
1120 {
1121 type Output = Tensor<B, D, K>;
1122
1123 fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
1124 Tensor::add_scalar(Tensor::neg(tensor), self)
1125 }
1126 }
1127 )*
1128 }
1129}
1130impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);
1131
1132// Tensor / tensor
1133impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
1134where
1135 K::Elem: Element,
1136{
1137 type Output = Self;
1138
1139 fn div(self, rhs: Self) -> Self::Output {
1140 Tensor::div(self, rhs)
1141 }
1142}
1143
1144// Tensor / scalar
1145impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
1146 for Tensor<B, D, K>
1147where
1148 K::Elem: Element,
1149{
1150 type Output = Self;
1151
1152 fn div(self, other: E) -> Self::Output {
1153 Tensor::div_scalar(self, other)
1154 }
1155}
1156
1157// Scalar / tensor (float only)
1158macro_rules! impl_tensor_scalar_div {
1159 ($($t:ty),*) => {
1160 $(
1161 impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t
1162 {
1163 type Output = Tensor<B, D>;
1164
1165 fn div(self, tensor: Tensor<B, D>) -> Self::Output {
1166 tensor.recip().mul_scalar(self)
1167 }
1168 }
1169 )*
1170 }
1171}
1172
1173impl_tensor_scalar_div!(f32, f64);
1174
1175// Tensor % tensor.
1176impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
1177where
1178 K::Elem: Element,
1179{
1180 type Output = Self;
1181
1182 fn rem(self, rhs: Self) -> Self::Output {
1183 Tensor::remainder(self, rhs)
1184 }
1185}
1186
1187// Tensor % scalar.
1188impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
1189 for Tensor<B, D, K>
1190where
1191 K::Elem: Element,
1192{
1193 type Output = Self;
1194
1195 fn rem(self, other: E) -> Self::Output {
1196 Tensor::remainder_scalar(self, other)
1197 }
1198}
1199
1200// Tensor * tensor.
1201impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
1202where
1203 K::Elem: Element,
1204{
1205 type Output = Self;
1206
1207 fn mul(self, rhs: Self) -> Self::Output {
1208 Tensor::mul(self, rhs)
1209 }
1210}
1211
1212// Tensor * scalar.
1213impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
1214 for Tensor<B, D, K>
1215where
1216 K::Elem: Element,
1217{
1218 type Output = Self;
1219
1220 fn mul(self, other: E) -> Self::Output {
1221 Tensor::mul_scalar(self, other)
1222 }
1223}
1224
1225macro_rules! impl_tensor_scalar_mul {
1226 ($($t:ty),*) => {
1227 $(
1228 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
1229 where
1230 K::Elem: Element,
1231 {
1232 type Output = Tensor<B, D, K>;
1233
1234 fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
1235 Tensor::mul_scalar(other, self)
1236 }
1237 }
1238 )*
1239 }
1240}
1241
1242impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);
1243
1244impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
1245where
1246 B: Backend,
1247 K: Numeric<B>,
1248 K::Elem: Element,
1249{
1250 type Output = Self;
1251
1252 fn neg(self) -> Self::Output {
1253 Tensor::neg(self)
1254 }
1255}