burn_tensor/tensor/api/numeric.rs
1use alloc::vec::Vec;
2
3use crate::alloc::borrow::ToOwned;
4
5use crate::indexing::canonicalize_dim;
6use crate::{
7 AsIndex, BasicOps, Bool, Distribution, Element, ElementConversion, Float, Int, Shape, Tensor,
8 TensorKind,
9 backend::Backend,
10 check,
11 check::TensorCheck,
12 ops::{Device, IntTensor},
13};
14use crate::{DType, TensorPrimitive};
15
// Dispatches a binary operation over two `TensorPrimitive` operands.
//
// - Both `QFloat`: delegate to the quantized op `B::$q_op`; its result is used directly as the
//   match value, so it must already be a `TensorPrimitive`.
// - Both `Float`: run the plain op `B::$op` and rewrap the result as `Float`.
// - Mixed: dequantize the quantized side with `B::dequantize`, run the plain op, and wrap the
//   result as `Float`.
//
// `B` is resolved at the call site, so a backend type named `B` must be in scope there.
macro_rules! q_bin_ops {
    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {
        match ($lhs, $rhs) {
            (TensorPrimitive::QFloat(a), TensorPrimitive::QFloat(b)) => B::$q_op(a, b),
            (TensorPrimitive::Float(a), TensorPrimitive::Float(b)) => {
                TensorPrimitive::Float(B::$op(a, b))
            }
            (TensorPrimitive::Float(a), TensorPrimitive::QFloat(b)) => {
                TensorPrimitive::Float(B::$op(a, B::dequantize(b)))
            }
            (TensorPrimitive::QFloat(a), TensorPrimitive::Float(b)) => {
                TensorPrimitive::Float(B::$op(B::dequantize(a), b))
            }
        }
    };
}
32
33impl<B, const D: usize, K> Tensor<B, D, K>
34where
35 B: Backend,
36 K: Numeric<B>,
37 K::Elem: Element,
38{
39 /// Applies element wise addition operation.
40 ///
41 /// `y = x2 + x1`
42 ///
43 /// # Arguments
44 ///
45 /// * `other` - The tensor to add.
46 ///
47 /// # Example
48 ///
49 /// ```rust
50 /// use burn_tensor::backend::Backend;
51 /// use burn_tensor::{Tensor, Shape};
52 ///
53 /// fn example<B: Backend>() {
54 /// let device = B::Device::default();
55 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
56 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
57 /// let tensor = tensor1 + tensor2;
58 /// println!("{tensor}");
59 /// // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
60 /// }
61 /// ```
62 #[allow(clippy::should_implement_trait)]
63 pub fn add(self, other: Self) -> Self {
64 check!(TensorCheck::binary_ops_ew("Add", &self, &other));
65 Self::new(K::add(self.primitive, other.primitive))
66 }
67
68 /// Applies element wise addition operation with a scalar.
69 ///
70 /// `y = x + s`
71 ///
72 /// # Arguments
73 ///
74 /// * `other` - The scalar to add, element wise.
75 ///
76 /// # Example
77 ///
78 /// ```rust
79 /// use burn_tensor::backend::Backend;
80 /// use burn_tensor::{Tensor, Shape};
81 ///
82 /// fn example<B: Backend>() {
83 /// let device = B::Device::default();
84 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
85 /// let scalar = 2.0;
86 /// let tensor = tensor + scalar;
87 /// println!("{tensor}");
88 /// // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
89 /// }
90 /// ```
91 pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
92 Self::new(K::add_scalar::<E>(self.primitive, other))
93 }
94
95 /// Applies element wise subtraction operation.
96 ///
97 /// `y = x2 - x1`
98 ///
99 /// # Arguments
100 ///
101 /// * `other` - The tensor to subtract.
102 ///
103 /// # Example
104 ///
105 /// ```rust
106 /// use burn_tensor::backend::Backend;
107 /// use burn_tensor::{Tensor, Shape};
108 ///
109 /// fn example<B: Backend>() {
110 /// let device = B::Device::default();
111 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
112 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
113 /// let tensor = tensor1 - tensor2;
114 /// println!("{tensor}");
115 /// // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
116 /// }
117 /// ```
118 #[allow(clippy::should_implement_trait)]
119 pub fn sub(self, other: Self) -> Self {
120 check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
121 Self::new(K::sub(self.primitive, other.primitive))
122 }
123
124 /// Applies element wise subtraction operation with a scalar.
125 ///
126 /// `y = x - s`
127 ///
128 /// # Arguments
129 ///
130 /// * `other` - The scalar to subtract, element wise.
131 ///
132 /// # Example
133 ///
134 /// ```rust
135 /// use burn_tensor::backend::Backend;
136 /// use burn_tensor::{Tensor, Shape};
137 ///
138 /// fn example<B: Backend>() {
139 /// let device = B::Device::default();
140 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
141 /// let scalar = 2.0;
142 /// let tensor = tensor - scalar;
143 /// println!("{tensor}");
144 /// // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
145 /// }
146 /// ```
147 pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
148 Self::new(K::sub_scalar::<E>(self.primitive, other))
149 }
150
151 /// Applies element wise division operation.
152 ///
153 /// `y = x2 / x1`
154 ///
155 /// # Arguments
156 ///
157 /// * `other` - The tensor to divide.
158 ///
159 /// # Example
160 ///
161 /// ```rust
162 /// use burn_tensor::backend::Backend;
163 /// use burn_tensor::{Tensor, Shape};
164 ///
165 /// fn example<B: Backend>() {
166 /// let device = B::Device::default();
167 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
168 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
169 /// let tensor = tensor1 / tensor2;
170 /// println!("{tensor}");
171 /// // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
172 /// }
173 /// ```
174 #[allow(clippy::should_implement_trait)]
175 pub fn div(self, other: Self) -> Self {
176 check!(TensorCheck::binary_ops_ew("Div", &self, &other));
177 Self::new(K::div(self.primitive, other.primitive))
178 }
179
180 /// Applies element wise division operation with a scalar.
181 ///
182 /// `y = x / s`
183 ///
184 /// # Arguments
185 ///
186 /// * `other` - The scalar to divide, element wise.
187 ///
188 /// # Example
189 ///
190 /// ```rust
191 /// use burn_tensor::backend::Backend;
192 /// use burn_tensor::{Tensor, Shape};
193 ///
194 /// fn example<B: Backend>() {
195 /// let device = B::Device::default();
196 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
197 /// let scalar = 2.0;
198 /// let tensor = tensor / scalar;
199 /// println!("{tensor}");
200 /// // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
201 /// }
202 /// ```
203 pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
204 Self::new(K::div_scalar::<E>(self.primitive, other))
205 }
206
207 /// Applies element wise the remainder operation with a scalar.
208 ///
209 /// `y = x2 % x1`
210 pub fn remainder(self, other: Self) -> Self {
211 Self::new(K::remainder(self.primitive, other.primitive))
212 }
213
214 /// Applies element wise the remainder operation with a scalar.
215 ///
216 /// `y = x % s`
217 ///
218 /// # Arguments
219 ///
220 /// * `other` - The scalar to divide, element wise.
221 ///
222 /// # Example
223 ///
224 /// ```rust
225 /// use burn_tensor::backend::Backend;
226 /// use burn_tensor::{Tensor, Shape};
227 ///
228 /// fn example<B: Backend>() {
229 /// let device = B::Device::default();
230 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
231 /// let scalar = 2.0;
232 /// let tensor = tensor1 % scalar;
233 /// println!("{tensor}");
234 /// // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
235 /// }
236 /// ```
237 pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
238 Self::new(K::remainder_scalar::<E>(self.primitive, other))
239 }
240
241 /// Applies element wise multiplication operation.
242 ///
243 /// `y = x2 * x1`
244 ///
245 /// # Arguments
246 ///
247 /// * `other` - The tensor to multiply.
248 ///
249 /// # Example
250 ///
251 /// ```rust
252 /// use burn_tensor::backend::Backend;
253 /// use burn_tensor::{Tensor, Shape};
254 ///
255 /// fn example<B: Backend>() {
256 /// let device = B::Device::default();
257 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
258 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
259 /// let tensor = tensor1 * tensor2;
260 /// println!("{tensor}");
261 /// // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
262 /// }
263 /// ```
264 #[allow(clippy::should_implement_trait)]
265 pub fn mul(self, other: Self) -> Self {
266 check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
267 Self::new(K::mul(self.primitive, other.primitive))
268 }
269
270 /// Applies element wise multiplication operation with a scalar.
271 ///
272 /// `y = x * s`
273 ///
274 /// # Arguments
275 ///
276 /// * `other` - The scalar to multiply, element wise.
277 ///
278 /// # Example
279 ///
280 /// ```rust
281 /// use burn_tensor::backend::Backend;
282 /// use burn_tensor::{Tensor, Shape};
283 ///
284 /// fn example<B: Backend>() {
285 /// let device = B::Device::default();
286 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
287 /// let scalar = 2.0;
288 /// let tensor = tensor * scalar;
289 /// println!("{tensor}");
290 /// // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
291 /// }
292 /// ```
293 pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
294 Self::new(K::mul_scalar::<E>(self.primitive, other))
295 }
296
297 /// Switch sign of each element in the tensor.
298 ///
299 /// `y = -x`
300 ///
301 /// # Example
302 ///
303 /// ```rust
304 /// use burn_tensor::backend::Backend;
305 /// use burn_tensor::{Tensor, Shape};
306 ///
307 /// fn example<B: Backend>() {
308 /// let device = B::Device::default();
309 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
310 /// let tensor = -tensor;
311 /// println!("{tensor}");
312 /// // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
313 /// }
314 /// ```
315 #[allow(clippy::should_implement_trait)]
316 pub fn neg(self) -> Self {
317 Self::new(K::neg(self.primitive))
318 }
319
320 /// Returns the signs of the elements of the input tensor.
321 ///
322 /// # Example
323 ///
324 /// ```rust
325 /// use burn_tensor::backend::Backend;
326 /// use burn_tensor::{Tensor, Shape};
327 ///
328 /// fn example<B: Backend>() {
329 /// let device = B::Device::default();
330 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
331 /// let tensor = tensor.sign();
332 /// println!("{tensor}");
333 /// // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
334 /// }
335 /// ```
336 pub fn sign(self) -> Self {
337 Self::new(K::sign(self.primitive))
338 }
339
340 /// Create a tensor of the given shape where each element is zero.
341 ///
342 /// # Example
343 ///
344 /// ```rust
345 /// use burn_tensor::backend::Backend;
346 /// use burn_tensor::{Tensor, Shape};
347 ///
348 /// fn example<B: Backend>() {
349 /// let device = B::Device::default();
350 /// let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
351 /// println!("{tensor}");
352 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
353 /// }
354 /// ```
355 pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
356 let shape = shape.into();
357 check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
358 Self::new(K::zeros(shape, device, K::Elem::dtype()))
359 }
360
361 /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.
362 ///
363 /// # Example
364 ///
365 /// ```rust
366 /// use burn_tensor::backend::Backend;
367 /// use burn_tensor::{Tensor, Shape};
368 ///
369 /// fn example<B: Backend>() {
370 /// let device = B::Device::default();
371 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
372 /// let tensor = tensor.zeros_like();
373 /// println!("{tensor}");
374 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
375 /// }
376 /// ```
377 pub fn zeros_like(&self) -> Self {
378 Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))
379 }
380
381 /// Create a tensor of the given shape where each element is one.
382 ///
383 /// # Example
384 ///
385 /// ```rust
386 /// use burn_tensor::backend::Backend;
387 /// use burn_tensor::{Tensor, Shape};
388 ///
389 /// fn example<B: Backend>() {
390 /// let device = B::Device::default();
391 /// let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
392 /// println!("{tensor}");
393 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
394 /// }
395 /// ```
396 pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
397 let shape = shape.into();
398 check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
399 Self::new(K::ones(shape, device, K::Elem::dtype()))
400 }
401
402 /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.
403 ///
404 /// # Example
405 ///
406 /// ```rust
407 /// use burn_tensor::backend::Backend;
408 /// use burn_tensor::{Tensor, Shape};
409 ///
410 /// fn example<B: Backend>() {
411 /// let device = B::Device::default();
412 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
413 /// let tensor = tensor.ones_like();
414 /// println!("{tensor}");
415 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
416 /// }
417 /// ```
418 pub fn ones_like(&self) -> Self {
419 Self::new(K::ones(self.shape(), &self.device(), self.dtype()))
420 }
421
422 /// Aggregate all elements in the tensor with the mean operation.
423 ///
424 /// # Example
425 ///
426 /// ```rust
427 /// use burn_tensor::backend::Backend;
428 /// use burn_tensor::{Tensor, Shape};
429 ///
430 /// fn example<B: Backend>() {
431 /// let device = B::Device::default();
432 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
433 /// let tensor = tensor.mean();
434 /// println!("{tensor}");
435 /// // [3.6666667]
436 /// }
437 /// ```
438 pub fn mean(self) -> Tensor<B, 1, K> {
439 Tensor::new(K::mean(self.primitive))
440 }
441
442 /// Aggregate all elements in the tensor with the sum operation.
443 ///
444 /// # Example
445 ///
446 /// ```rust
447 /// use burn_tensor::backend::Backend;
448 /// use burn_tensor::{Tensor, Shape};
449 ///
450 /// fn example<B: Backend>() {
451 /// let device = B::Device::default();
452 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
453 /// let tensor = tensor.sum();
454 /// println!("{tensor}");
455 /// // [22.0]
456 /// }
457 /// ```
458 pub fn sum(self) -> Tensor<B, 1, K> {
459 Tensor::new(K::sum(self.primitive))
460 }
461
462 /// Aggregate all elements along the given *dimension* or *axis*
463 /// in the tensor with the mean operation.
464 ///
465 /// # Arguments
466 ///
467 /// * `dim` - The dimension or axis along which to aggregate the elements;
468 /// supports negative indexing.
469 ///
470 /// # Example
471 ///
472 /// ```rust
473 /// use burn_tensor::backend::Backend;
474 /// use burn_tensor::{Tensor, Shape};
475 ///
476 /// fn example<B: Backend>() {
477 /// let device = B::Device::default();
478 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
479 /// let tensor = tensor.clone().mean_dim(0);
480 /// println!("{tensor}");
481 /// // [[3.0, 3.5, 4.5]]
482 /// let tensor = tensor.clone().mean_dim(1);
483 /// println!("{tensor}");
484 /// // [[0.6666667], [6.6666665]]
485 /// }
486 /// ```
487 pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {
488 let dim = canonicalize_dim(dim, D, false);
489 check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
490 Self::new(K::mean_dim(self.primitive, dim))
491 }
492
493 /// Aggregate all elements along the given *axes*
494 /// in the tensor with the mean operation.
495 ///
496 /// # Arguments
497 ///
498 /// * `dims` - the dimensions to aggregate; supports negative indexing.
499 ///
500 /// # Returns
501 ///
502 /// The returned tensor will have the same rank,
503 /// but the aggregated dimensions will have size 1.
504 ///
505 /// # Example
506 ///
507 /// ```rust
508 /// use burn_tensor::backend::Backend;
509 /// use burn_tensor::{Tensor, Shape};
510 ///
511 /// fn example<B: Backend>() {
512 /// let device = B::Device::default();
513 /// let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);
514 /// let tensor = tensor.clone().mean_dims(&[0, 1]);
515 /// println!("{tensor}");
516 /// // [[2.0]]
517 /// }
518 /// ```
519 pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {
520 dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))
521 }
522
523 /// Aggregate all elements along the given *dimension* or *axis*
524 /// in the tensor with the sum operation.
525 ///
526 /// # Arguments
527 ///
528 /// * `dim` - The dimension or axis along which to aggregate the elements;
529 /// supports negative indexing.
530 ///
531 /// # Example
532 ///
533 /// ```rust
534 /// use burn_tensor::backend::Backend;
535 /// use burn_tensor::{Tensor, Shape};
536 ///
537 /// fn example<B: Backend>() {
538 /// let device = B::Device::default();
539 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
540 /// let tensor = tensor.clone().sum_dim(0);
541 /// println!("{tensor}");
542 /// // [[6.0, 7.0, 9.0]]
543 /// let tensor = tensor.clone().sum_dim(1);
544 /// println!("{tensor}");
545 /// // [[2.0], [20.0]]
546 /// }
547 /// ```
548 pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {
549 let dim = canonicalize_dim(dim, D, false);
550 check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
551 Self::new(K::sum_dim(self.primitive, dim))
552 }
553
554 /// Aggregate all elements along the given *axes*
555 /// in the tensor with the sum operation.
556 ///
557 /// # Arguments
558 ///
559 /// * `dims` - the dimensions to aggregate; supports negative indexing.
560 ///
561 /// # Returns
562 ///
563 /// The returned tensor will have the same rank,
564 /// but the aggregated dimensions will have size 1.
565 ///
566 /// # Example
567 ///
568 /// ```rust
569 /// use burn_tensor::backend::Backend;
570 /// use burn_tensor::{Tensor, Shape};
571 ///
572 /// fn example<B: Backend>() {
573 /// let device = B::Device::default();
574 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
575 /// let tensor = tensor.clone().sum_dims(&[0, 1]);
576 /// println!("{tensor}");
577 /// // [[27]]
578 /// }
579 /// ```
580 pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {
581 dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))
582 }
583
584 /// Aggregate and squeeze along the given dimensions.
585 ///
586 /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)``
587 ///
588 /// # Arguments
589 ///
590 /// * `dims` - the dimensions to aggregate; supports negative indexing.
591 ///
592 /// # Returns
593 ///
594 /// The returned tensor will have the same rank,
595 /// but the aggregated dimensions will have size 1.
596 ///
597 /// # Example
598 ///
599 /// ```rust
600 /// use burn_tensor::backend::Backend;
601 /// use burn_tensor::{Tensor, Shape};
602 ///
603 /// fn example<B: Backend>() {
604 /// let device = B::Device::default();
605 /// let tensor = Tensor::<B, 3>::from_data([
606 /// [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],
607 /// [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],
608 /// ], &device);
609 /// let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);
610 /// println!("{tensor}");
611 /// // [20.0, 16.0, 21.0]
612 /// }
613 /// ```
614 pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {
615 // TODO: remove idims when squeeze_dims uses AsIndex.
616 let idims = dims
617 .iter()
618 .map(|&dim| canonicalize_dim(dim, D, false) as isize)
619 .collect::<Vec<_>>();
620 self.sum_dims(dims).squeeze_dims::<D2>(&idims)
621 }
622
623 /// Aggregate all elements in the tensor with the product operation.
624 ///
625 /// # Example
626 ///
627 /// ```rust
628 /// use burn_tensor::backend::Backend;
629 /// use burn_tensor::{Tensor, Shape};
630 ///
631 /// fn example<B: Backend>() {
632 /// let device = B::Device::default();
633 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
634 /// let tensor = tensor.prod();
635 /// println!("{tensor}");
636 /// // [-1620.0]
637 /// }
638 /// ```
639 pub fn prod(self) -> Tensor<B, 1, K> {
640 Tensor::new(K::prod(self.primitive))
641 }
642
643 /// Aggregate all elements along the given *dimension* or *axis*
644 /// in the tensor with the product operation.
645 ///
646 /// # Arguments
647 ///
648 /// * `dim` - The dimension or axis along which to aggregate the elements,
649 /// supports negative indexing.
650 ///
651 /// # Returns
652 ///
653 /// The returned tensor will have the same rank,
654 /// but the aggregated dimension will have size 1.
655 ///
656 /// # Example
657 ///
658 /// ```rust
659 /// use burn_tensor::backend::Backend;
660 /// use burn_tensor::{Tensor, Shape};
661 ///
662 /// fn example<B: Backend>() {
663 /// let device = B::Device::default();
664 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
665 /// let tensor = tensor.clone().prod_dim(0);
666 /// println!("{tensor}");
667 /// // [[5.0, -18.0, 18.0]]
668 /// let tensor = tensor.clone().prod_dim(1);
669 /// println!("{tensor}");
670 /// // [[-6.0], [270.0]]
671 /// }
672 /// ```
673 pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {
674 let dim = canonicalize_dim(dim, D, false);
675 check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
676 Self::new(K::prod_dim(self.primitive, dim))
677 }
678
679 /// Aggregate all elements along the given *axes*
680 /// in the tensor with the prod operation.
681 ///
682 /// # Arguments
683 ///
684 /// * `dims` - the dimensions to aggregate, supports negative indexing.
685 ///
686 /// # Returns
687 ///
688 /// The returned tensor will have the same rank,
689 /// but the aggregated dimensions will have size 1.
690 ///
691 /// # Example
692 ///
693 /// ```rust
694 /// use burn_tensor::backend::Backend;
695 /// use burn_tensor::{Tensor, Shape};
696 ///
697 /// fn example<B: Backend>() {
698 /// let device = B::Device::default();
699 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
700 /// let tensor = tensor.clone().sum_dims(&[0, 1]);
701 /// println!("{tensor}");
702 /// // [[-1620.0]]
703 /// }
704 /// ```
705 pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {
706 dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))
707 }
708
709 /// Computes the cumulative sum of elements along the given *dimension* or *axis*.
710 ///
711 /// # Arguments
712 ///
713 /// * `dim` - The dimension or axis along which to compute the cumulative sum.
714 ///
715 /// # Example
716 ///
717 /// ```rust
718 /// use burn_tensor::backend::Backend;
719 /// use burn_tensor::{Tensor, Shape};
720 ///
721 /// fn example<B: Backend>() {
722 /// let device = B::Device::default();
723 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
724 /// let result = tensor.clone().cumsum(0);
725 /// println!("{result}");
726 /// // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]
727 /// let result = tensor.cumsum(1);
728 /// println!("{result}");
729 /// // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]
730 /// }
731 /// ```
732 pub fn cumsum(self, dim: usize) -> Self {
733 check!(TensorCheck::aggregate_dim::<D>("CumSum", dim));
734 Self::new(K::cumsum(self.primitive, dim))
735 }
736
737 /// Computes the cumulative product of elements along the given *dimension* or *axis*.
738 ///
739 /// # Arguments
740 ///
741 /// * `dim` - The dimension or axis along which to compute the cumulative product.
742 ///
743 /// # Example
744 ///
745 /// ```rust
746 /// use burn_tensor::backend::Backend;
747 /// use burn_tensor::{Tensor, Shape};
748 ///
749 /// fn example<B: Backend>() {
750 /// let device = B::Device::default();
751 /// let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
752 /// let result = tensor.clone().cumprod(0);
753 /// println!("{result}");
754 /// // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]
755 /// let result = tensor.cumprod(1);
756 /// println!("{result}");
757 /// // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]
758 /// }
759 /// ```
760 pub fn cumprod(self, dim: usize) -> Self {
761 check!(TensorCheck::aggregate_dim::<D>("CumProd", dim));
762 Self::new(K::cumprod(self.primitive, dim))
763 }
764
765 /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
766 ///
767 /// # Arguments
768 ///
769 /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
770 ///
771 /// # Example
772 ///
773 /// ```rust
774 /// use burn_tensor::backend::Backend;
775 /// use burn_tensor::{Tensor, Shape};
776 ///
777 /// fn example<B: Backend>() {
778 /// let device = B::Device::default();
779 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
780 /// let result = tensor.clone().cummin(0);
781 /// println!("{result}");
782 /// // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
783 /// let result = tensor.cummin(1);
784 /// println!("{result}");
785 /// // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
786 /// }
787 /// ```
788 pub fn cummin(self, dim: usize) -> Self {
789 check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
790 Self::new(K::cummin(self.primitive, dim))
791 }
792
793 /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.
794 ///
795 /// # Arguments
796 ///
797 /// * `dim` - The dimension or axis along which to compute the cumulative maximum.
798 ///
799 /// # Example
800 ///
801 /// ```rust
802 /// use burn_tensor::backend::Backend;
803 /// use burn_tensor::{Tensor, Shape};
804 ///
805 /// fn example<B: Backend>() {
806 /// let device = B::Device::default();
807 /// let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);
808 /// let result = tensor.clone().cummax(0);
809 /// println!("{result}");
810 /// // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]
811 /// let result = tensor.cummax(1);
812 /// println!("{result}");
813 /// // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]
814 /// }
815 /// ```
816 pub fn cummax(self, dim: usize) -> Self {
817 check!(TensorCheck::aggregate_dim::<D>("CumMax", dim));
818 Self::new(K::cummax(self.primitive, dim))
819 }
820
821 ///
822 /// # Arguments
823 ///
824 /// * `other` - The element to compare.
825 ///
826 /// # Example
827 ///
828 /// ```rust
829 /// use burn_tensor::backend::Backend;
830 /// use burn_tensor::{Tensor, Shape};
831 ///
832 /// fn example<B: Backend>() {
833 /// let device = B::Device::default();
834 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
835 /// let tensor = tensor.equal_elem(3.0);
836 /// println!("{tensor}");
837 /// // [[false, false, true], [false, false, false]]
838 /// }
839 /// ```
840 pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
841 Tensor::new(K::equal_elem(self.primitive, other.elem()))
842 }
843
844 /// Applies element wise non-equality comparison and returns a boolean tensor.
845 ///
846 /// # Arguments
847 ///
848 /// * `other` - The element to compare.
849 ///
850 /// # Example
851 ///
852 /// ```rust
853 /// use burn_tensor::backend::Backend;
854 /// use burn_tensor::{Tensor, Shape};
855 ///
856 /// fn example<B: Backend>() {
857 /// let device = B::Device::default();
858 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
859 /// let tensor = tensor.not_equal_elem(3.0);
860 /// println!("{tensor}");
861 /// // [[true, true, false], [true, true, true]]
862 /// }
863 /// ```
864 pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
865 Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
866 }
867
868 /// Applies element wise greater comparison and returns a boolean tensor.
869 ///
870 /// # Panics
871 ///
872 /// If the two tensors don't have the same shape.
873 ///
874 /// # Example
875 ///
876 /// ```rust
877 /// use burn_tensor::backend::Backend;
878 /// use burn_tensor::{Tensor, Shape};
879 ///
880 /// fn example<B: Backend>() {
881 /// let device = B::Device::default();
882 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
883 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
884 /// let tensor = tensor1.greater(tensor2);
885 /// println!("{tensor}");
886 /// // [[false, false, false], [true, true, true]]
887 /// }
888 /// ```
889 pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
890 check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
891 Tensor::new(K::greater(self.primitive, other.primitive))
892 }
893
894 /// Applies element wise greater-equal comparison and returns a boolean tensor.
895 ///
896 /// # Panics
897 ///
898 /// If the two tensors don't have the same shape.
899 ///
900 /// # Example
901 ///
902 /// ```rust
903 /// use burn_tensor::backend::Backend;
904 /// use burn_tensor::{Tensor, Shape};
905 ///
906 /// fn example<B: Backend>() {
907 /// let device = B::Device::default();
908 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
909 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
910 /// let tensor = tensor1.greater_equal(tensor2);
911 /// println!("{tensor}");
912 /// // [[true, false, false], [true, true, true]]
913 /// }
914 /// ```
915 pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
916 check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
917 Tensor::new(K::greater_equal(self.primitive, other.primitive))
918 }
919
920 /// Applies element wise lower comparison and returns a boolean tensor.
921 ///
922 /// # Panics
923 ///
924 /// If the two tensors don't have the same shape.
925 ///
926 /// # Example
927 ///
928 /// ```rust
929 /// use burn_tensor::backend::Backend;
930 /// use burn_tensor::{Tensor, Shape};
931 ///
932 /// fn example<B: Backend>() {
933 /// let device = B::Device::default();
934 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
935 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
936 /// let tensor = tensor1.lower(tensor2);
937 /// println!("{tensor}");
938 /// // [[false, true, true], [false, false, false]]
939 /// }
940 /// ```
941 pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
942 check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
943 Tensor::new(K::lower(self.primitive, other.primitive))
944 }
945
946 /// Applies element wise lower-equal comparison and returns a boolean tensor.
947 ///
948 /// # Panics
949 ///
950 /// If the two tensors don't have the same shape.
951 ///
952 /// # Example
953 ///
954 /// ```rust
955 /// use burn_tensor::backend::Backend;
956 /// use burn_tensor::{Tensor, Shape};
957 ///
958 /// fn example<B: Backend>() {
959 /// let device = B::Device::default();
960 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
961 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
962 /// let tensor = tensor1.lower_equal(tensor2);
963 /// println!("{tensor}");
964 /// // [[true, true, true], [false, false, false]]
965 /// }
966 /// ```
967 pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
968 check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
969 Tensor::new(K::lower_equal(self.primitive, other.primitive))
970 }
971
972 /// Applies greater than `other` comparison and returns a boolean tensor.
973 ///
974 /// # Arguments
975 ///
976 /// * `other` - The element to compare.
977 ///
978 /// # Example
979 ///
980 /// ```rust
981 /// use burn_tensor::backend::Backend;
982 /// use burn_tensor::{Tensor, Shape};
983 ///
984 /// fn example<B: Backend>() {
985 /// let device = B::Device::default();
986 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
987 /// let tensor = tensor.greater_elem(3.0);
988 /// println!("{tensor}");
989 /// // [[false, false, true], [true, true, true]]
990 /// }
991 /// ```
992 pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
993 Tensor::new(K::greater_elem(self.primitive, other.elem()))
994 }
995
996 /// Applies greater-equal than `other` comparison and returns a boolean tensor.
997 ///
998 /// # Arguments
999 ///
1000 /// * `other` - The element to compare.
1001 ///
1002 /// # Example
1003 ///
1004 /// ```rust
1005 /// use burn_tensor::backend::Backend;
1006 /// use burn_tensor::{Tensor, Shape};
1007 ///
1008 /// fn example<B: Backend>() {
1009 /// let device = B::Device::default();
1010 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1011 /// let tensor = tensor.greater_equal_elem(3.0);
1012 /// println!("{tensor}");
1013 /// // [[false, false, true], [true, true, true]]
1014 /// }
1015 /// ```
1016 pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
1017 Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
1018 }
1019
1020 /// Applies lower than `other` comparison and returns a boolean tensor.
1021 ///
1022 /// # Arguments
1023 ///
1024 /// * `other` - The element to compare.
1025 ///
1026 /// # Example
1027 ///
1028 /// ```rust
1029 /// use burn_tensor::backend::Backend;
1030 /// use burn_tensor::{Tensor, Shape};
1031 ///
1032 /// fn example<B: Backend>() {
1033 /// let device = B::Device::default();
1034 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1035 /// let tensor = tensor.lower_elem(3.0);
1036 /// println!("{tensor}");
1037 /// // [[true, true, false], [false, false, false]]
1038 /// }
1039 /// ```
1040 pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
1041 Tensor::new(K::lower_elem(self.primitive, other.elem()))
1042 }
1043
1044 /// Applies lower-equal than `other` comparison and returns a boolean tensor.
1045 ///
1046 /// # Arguments
1047 ///
1048 /// * `other` - The element to compare.
1049 ///
1050 /// # Example
1051 ///
1052 /// ```rust
1053 /// use burn_tensor::backend::Backend;
1054 /// use burn_tensor::{Tensor, Shape};
1055 ///
1056 /// fn example<B: Backend>() {
1057 /// let device = B::Device::default();
1058 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1059 /// let tensor = tensor.lower_equal_elem(3.0);
1060 /// println!("{tensor}");
1061 /// // [[true, true, true], [false, false, false]]
1062 /// }
1063 /// ```
1064 pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
1065 Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
1066 }
1067
1068 /// Update the given tensor with the value tensor where the mask is true.
1069 ///
1070 /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
1071 /// a scalar.
1072 ///
1073 /// # Example
1074 ///
1075 /// ```rust
1076 /// use burn_tensor::backend::Backend;
1077 /// use burn_tensor::{Tensor, Shape, Bool};
1078 ///
1079 /// fn example<B: Backend>() {
1080 /// let device = B::Device::default();
1081 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1082 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1083 /// let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1084 /// let tensor = tensor.mask_where(mask, value);
1085 /// println!("{tensor}");
1086 /// // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
1087 /// }
1088 /// ```
1089 pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
1090 Self::new(K::mask_where(
1091 self.primitive,
1092 mask.primitive,
1093 value.primitive,
1094 ))
1095 }
1096
1097 /// Update the given tensor with the value where the mask is true.
1098 ///
1099 /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
1100 /// a tensor.
1101 ///
1102 /// # Example
1103 ///
1104 /// ```rust
1105 /// use burn_tensor::backend::Backend;
1106 /// use burn_tensor::{Tensor, Shape, Bool};
1107 ///
1108 /// fn example<B: Backend>() {
1109 /// let device = B::Device::default();
1110 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1111 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
1112 /// let tensor = tensor.mask_fill(mask, 3.0);
1113 /// println!("{tensor}");
1114 /// // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
1115 /// }
1116 /// ```
1117 pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
1118 Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
1119 }
1120
1121 /// Gather tensor elements corresponding to the given indices from the specified dim.
1122 ///
1123 /// Example using a 3D tensor:
1124 ///
1125 /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
1126 /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
1127 /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
1128 ///
1129 /// # Notes
1130 ///
1131 /// The index tensor should have the same shape as the original tensor except for the dim
1132 /// specified.
1133 ///
1134 /// # Warning
1135 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1136 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1137 pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
1138 check!(TensorCheck::gather::<D>(
1139 dim,
1140 &self.shape(),
1141 &indices.shape()
1142 ));
1143
1144 Self::new(K::gather(dim, self.primitive, indices.primitive))
1145 }
1146
1147 /// Assign the gathered elements corresponding to the given indices along the specified dimension
1148 /// from the value tensor to the original tensor using sum reduction.
1149 ///
1150 /// Example using a 3D tensor:
1151 ///
1152 /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
1153 /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
1154 /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
1155 ///
1156 /// # Notes
1157 ///
1158 /// The index tensor should have the same shape as the original tensor except for the specified
1159 /// dimension. The value and index tensors should have the same shape.
1160 ///
1161 /// Other references to the input tensor will not be modified by this operation.
1162 ///
1163 /// # Warning
1164 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1165 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1166 pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
1167 check!(TensorCheck::scatter::<D>(
1168 dim,
1169 &self.shape(),
1170 &indices.shape(),
1171 &values.shape()
1172 ));
1173
1174 Self::new(K::scatter(
1175 dim,
1176 self.primitive,
1177 indices.primitive,
1178 values.primitive,
1179 ))
1180 }
1181
1182 /// Applies the argmax function along the given dimension and returns an integer tensor.
1183 ///
1184 /// # Example
1185 ///
1186 /// ```rust
1187 /// use burn_tensor::backend::Backend;
1188 /// use burn_tensor::{Tensor, Shape};
1189 ///
1190 /// fn example<B: Backend>() {
1191 /// let device = B::Device::default();
1192 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1193 /// let tensor = tensor.argmax(1);
1194 /// println!("{:?}", tensor.shape());
1195 /// // Shape { dims: [2, 1, 3] }
1196 /// }
1197 /// ```
1198 pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
1199 Tensor::new(K::argmax(self.primitive, dim))
1200 }
1201
1202 /// Find the maximum value.
1203 ///
1204 /// # Example
1205 ///
1206 /// ```rust
1207 /// use burn_tensor::backend::Backend;
1208 /// use burn_tensor::{Tensor, Shape};
1209 ///
1210 /// fn example<B: Backend>() {
1211 /// let device = B::Device::default();
1212 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1213 /// let tensor = tensor.max();
1214 /// println!("{tensor}");
1215 /// // [9.0]
1216 /// }
1217 /// ```
1218 pub fn max(self) -> Tensor<B, 1, K> {
1219 Tensor::new(K::max(self.primitive))
1220 }
1221
1222 /// Find the maximum value along the given dimension.
1223 ///
1224 /// # Arguments
1225 ///
1226 /// * `dim` - The dimension or axis along which to aggregate the elements;
1227 /// supports negative indexing.
1228 ///
1229 /// # Returns
1230 ///
1231 /// The returned tensor will have the same rank,
1232 /// but the aggregated dimension will have size 1.
1233 ///
1234 /// # Example
1235 ///
1236 /// ```rust
1237 /// use burn_tensor::backend::Backend;
1238 /// use burn_tensor::{Tensor, Shape};
1239 ///
1240 /// fn example<B: Backend>() {
1241 /// let device = B::Device::default();
1242 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1243 /// let tensor = tensor.max_dim(0);
1244 /// println!("{tensor}");
1245 /// // [[5.0, 9.0, 6.0]]
1246 /// }
1247 /// ```
1248 pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {
1249 let dim = canonicalize_dim(dim, D, false);
1250 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1251 Tensor::new(K::max_dim(self.primitive, dim))
1252 }
1253
1254 /// Find the maximum value along the given dimensions.
1255 ///
1256 /// # Arguments
1257 ///
1258 /// * `dims` - The dimensions or axis along which to aggregate the elements;
1259 /// supports negative indexing.
1260 ///
1261 /// # Returns
1262 ///
1263 /// The returned tensor will have the same rank,
1264 /// but the aggregated dimensions will have size 1.
1265 ///
1266 /// # Example
1267 ///
1268 /// ```rust
1269 /// use burn_tensor::backend::Backend;
1270 /// use burn_tensor::{Tensor, Shape};
1271 ///
1272 /// fn example<B: Backend>() {
1273 /// let device = B::Device::default();
1274 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1275 /// let tensor = tensor.max_dims(&[0, 1]);
1276 /// println!("{tensor}");
1277 /// // [[9.0]]
1278 /// }
1279 /// ```
1280 pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1281 dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))
1282 }
1283
1284 /// Find the maximum value along the given dimension.
1285 ///
1286 /// Also returns the indices.
1287 ///
1288 /// # Example
1289 ///
1290 /// ```rust
1291 /// use burn_tensor::backend::Backend;
1292 /// use burn_tensor::{Tensor, Shape};
1293 ///
1294 /// fn example<B: Backend>() {
1295 /// let device = B::Device::default();
1296 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1297 /// let (tensor, index) = tensor.max_dim_with_indices(0);
1298 /// // [[5.0, 9.0, 6.0]]
1299 /// println!("{tensor}");
1300 /// // [[1, 1, 1]]
1301 /// println!("{index}");
1302 /// }
1303 /// ```
1304 pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
1305 let dim = canonicalize_dim(dim, D, false);
1306 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1307
1308 let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
1309
1310 let tensor = Tensor::new(tensor);
1311 let index = Tensor::new(index);
1312
1313 (tensor, index)
1314 }
1315
1316 /// Finds the maximum pair wise values with another tensor.
1317 ///
1318 /// # Arguments
1319 ///
1320 /// * `other` - Other tensor to find maximum elements with
1321 ///
1322 /// # Returns
1323 ///
1324 /// A tensor with the same shape as the input tensors containing the maximum value found
1325 /// in the input tensors.
1326 ///
1327 /// # Example
1328 ///
1329 /// ```rust
1330 /// use burn_tensor::backend::Backend;
1331 /// use burn_tensor::{Tensor, Shape};
1332 ///
1333 /// fn example<B: Backend>() {
1334 /// let device = B::Device::default();
1335 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1336 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1337 /// let tensor = tensor1.max_pair(tensor2);
1338 /// println!("{tensor}");
1339 /// // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
1340 /// }
1341 /// ```
1342 pub fn max_pair(self, other: Self) -> Self {
1343 let mask = self.clone().lower(other.clone());
1344 self.mask_where(mask, other)
1345 }
1346
1347 /// Find the maximum absolute value.
1348 ///
1349 /// # Example
1350 ///
1351 /// ```rust
1352 /// use burn_tensor::backend::Backend;
1353 /// use burn_tensor::{Tensor, Shape};
1354 ///
1355 /// fn example<B: Backend>() {
1356 /// let device = B::Device::default();
1357 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
1358 /// let tensor = tensor.max_abs();
1359 /// println!("{tensor}");
1360 /// // [7.0]
1361 /// }
1362 /// ```
1363 pub fn max_abs(self) -> Tensor<B, 1, K> {
1364 Tensor::new(K::max_abs(self.primitive))
1365 }
1366
1367 /// Find the maximum absolute value along the given dimension.
1368 ///
1369 /// # Arguments
1370 ///
1371 /// * `dim` - The dimension or axis along which to aggregate the elements,
1372 /// supports negative indexing.
1373 ///
1374 /// # Returns
1375 ///
1376 /// The returned tensor will have the same rank,
1377 /// but the aggregated dimension will have size 1.
1378 ///
1379 /// # Example
1380 ///
1381 /// ```rust
1382 /// use burn_tensor::backend::Backend;
1383 /// use burn_tensor::{Tensor, Shape};
1384 ///
1385 /// fn example<B: Backend>() {
1386 /// let device = B::Device::default();
1387 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1388 /// let tensor = tensor.max_dim(0);
1389 /// println!("{tensor}");
1390 /// // [[5.0, 9.0, 6.0]]
1391 /// }
1392 /// ```
1393 pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {
1394 let dim = canonicalize_dim(dim, D, false);
1395 check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));
1396
1397 Tensor::new(K::max_abs_dim(self.primitive, dim))
1398 }
1399
1400 /// Find the maximum absolute value along the given dimensions.
1401 ///
1402 /// # Arguments
1403 ///
1404 /// * `dims` - The dimensions or axes along which to aggregate the elements,
1405 /// supports negative indexing.
1406 ///
1407 /// # Returns
1408 ///
1409 /// The returned tensor will have the same rank,
1410 /// but the aggregated dimensions will have size 1.
1411 ///
1412 /// # Example
1413 ///
1414 /// ```rust
1415 /// use burn_tensor::backend::Backend;
1416 /// use burn_tensor::{Tensor, Shape};
1417 ///
1418 /// fn example<B: Backend>() {
1419 /// let device = B::Device::default();
1420 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1421 /// let tensor = tensor.max_abs_dims(&[0, 1]);
1422 /// println!("{tensor}");
1423 /// // [[9.0]]
1424 /// }
1425 /// ```
1426 pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1427 dims.iter()
1428 .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))
1429 }
1430
1431 /// Applies the argmin function along the given dimension and returns an integer tensor.
1432 ///
1433 /// # Example
1434 ///
1435 /// ```rust
1436 /// use burn_tensor::backend::Backend;
1437 /// use burn_tensor::{Tensor, Shape};
1438 ///
1439 /// fn example<B: Backend>() {
1440 /// let device = Default::default();
1441 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1442 /// let tensor = tensor.argmin(1);
1443 /// println!("{:?}", tensor.shape());
1444 /// // Shape { dims: [2, 1, 3] }
1445 /// }
1446 /// ```
1447 pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
1448 Tensor::new(K::argmin(self.primitive, dim))
1449 }
1450
1451 /// Find the minimum value.
1452 ///
1453 /// # Example
1454 ///
1455 /// ```rust
1456 /// use burn_tensor::backend::Backend;
1457 /// use burn_tensor::{Tensor, Shape};
1458 ///
1459 /// fn example<B: Backend>() {
1460 /// let device = B::Device::default();
1461 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1462 /// let tensor = tensor.min();
1463 /// println!("{tensor}");
1464 /// // [-2.0]
1465 /// }
1466 /// ```
1467 pub fn min(self) -> Tensor<B, 1, K> {
1468 Tensor::new(K::min(self.primitive))
1469 }
1470
1471 /// Find the minimum value along the given dimension.
1472 ///
1473 /// # Arguments
1474 ///
1475 /// * `dim` - The dimension or axis along which to aggregate the elements;
1476 /// supports negative indexing.
1477 ///
1478 /// # Returns
1479 ///
1480 /// The returned tensor will have the same rank,
1481 /// but the aggregated dimension will have size 1.
1482 ///
1483 /// # Example
1484 ///
1485 /// ```rust
1486 /// use burn_tensor::backend::Backend;
1487 /// use burn_tensor::{Tensor, Shape};
1488 ///
1489 /// fn example<B: Backend>() {
1490 /// let device = B::Device::default();
1491 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1492 /// let tensor = tensor.min_dim(0);
1493 /// println!("{tensor}");
1494 /// // [[1.0, -2.0, 3.0]]
1495 /// }
1496 /// ```
1497 pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {
1498 let dim = canonicalize_dim(dim, D, false);
1499 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1500 Tensor::new(K::min_dim(self.primitive, dim))
1501 }
1502
1503 /// Find the minimum value along the given dimensions.
1504 ///
1505 /// # Arguments
1506 ///
1507 /// * `dims` - The dimensions or axes along which to aggregate the elements;
1508 /// supports negative indexing.
1509 ///
1510 /// # Returns
1511 ///
1512 /// The returned tensor will have the same rank,
1513 /// but the aggregated dimensions will have size 1.
1514 ///
1515 /// # Example
1516 ///
1517 /// ```rust
1518 /// use burn_tensor::backend::Backend;
1519 /// use burn_tensor::{Tensor, Shape};
1520 ///
1521 /// fn example<B: Backend>() {
1522 /// let device = B::Device::default();
1523 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1524 /// let tensor = tensor.min_dims(&[0, 1]);
1525 /// println!("{tensor}");
1526 /// // [[-2.0]]
1527 /// }
1528 /// ```
1529 pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {
1530 dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))
1531 }
1532
1533 /// Find the minimum value along the given dimension.
1534 ///
1535 /// Also returns the indices.
1536 ///
1537 /// # Example
1538 ///
1539 /// ```rust
1540 /// use burn_tensor::backend::Backend;
1541 /// use burn_tensor::{Tensor, Shape};
1542 ///
1543 /// fn example<B: Backend>() {
1544 /// let device = B::Device::default();
1545 /// let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1546 /// let (tensor, index) = tensor.min_dim_with_indices(0);
1547 /// println!("{tensor}");
1548 /// // [[5.0, -2.0, 3.0]]
1549 /// println!("{}", index);
1550 /// // [[1, 0, 0]]
1551 /// }
1552 /// ```
1553 pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {
1554 let dim = canonicalize_dim(dim, D, false);
1555 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1556
1557 let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
1558
1559 let tensor = Tensor::new(tensor);
1560 let index = Tensor::new(index);
1561
1562 (tensor, index)
1563 }
1564
1565 /// Finds the minimum pair wise values with another tensor.
1566 ///
1567 /// # Arguments
1568 ///
1569 /// * `other` - Other tensor to find minimum elements with
1570 ///
1571 /// # Returns
1572 ///
1573 /// A tensor with the same shape as the input tensors containing the minimum value found
1574 /// between each element of the two source tensors.
1575 ///
1576 /// # Example
1577 ///
1578 /// ```rust
1579 /// use burn_tensor::backend::Backend;
1580 /// use burn_tensor::{Tensor, Shape};
1581 ///
1582 /// fn example<B: Backend>() {
1583 /// let device = B::Device::default();
1584 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1585 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1586 /// let tensor = tensor1.min_pair(tensor2);
1587 /// println!("{tensor}");
1588 /// // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
1589 /// }
1590 pub fn min_pair(self, other: Self) -> Self {
1591 let mask = other.clone().lower(self.clone());
1592 self.mask_where(mask, other)
1593 }
1594
1595 /// Clamp element wise between the given min and max values.
1596 ///
1597 /// # Arguments
1598 ///
1599 /// * `min` - The minimum value.
1600 /// * `max` - The maximum value.
1601 ///
1602 /// # Returns
1603 ///
1604 /// A new tensor with the values clamped between the given min and max values.
1605 ///
1606 /// # Example
1607 ///
1608 /// ```rust
1609 /// use burn_tensor::backend::Backend;
1610 /// use burn_tensor::{Int, Tensor};
1611 ///
1612 /// fn example<B: Backend>() {
1613 /// let device = Default::default();
1614 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1615 /// [
1616 /// [1, 2, 3],
1617 /// [4, 5, 6],
1618 /// [7, 8, 9]
1619 /// ],
1620 /// &device);
1621 /// let tensor = tensor.clamp(2, 6);
1622 /// println!("{tensor}");
1623 /// // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
1624 /// }
1625 /// ```
1626 pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
1627 Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
1628 }
1629
1630 /// Clamp element wise under a minimum value.
1631 ///
1632 /// # Arguments
1633 ///
1634 /// * `tensor` - The tensor to clamp.
1635 /// * `min` - The minimum value.
1636 ///
1637 /// # Returns
1638 ///
1639 /// A new tensor with the values clamped under the given min value.
1640 ///
1641 /// # Example
1642 ///
1643 /// ```rust
1644 /// use burn_tensor::backend::Backend;
1645 /// use burn_tensor::{Int, Tensor};
1646 ///
1647 /// fn example<B: Backend>() {
1648 /// let device = Default::default();
1649 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1650 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1651 /// &device);
1652 /// let tensor = tensor.clamp_min(4);
1653 /// println!("{tensor}");
1654 /// // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1655 /// }
1656 /// ```
1657 pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1658 Self::new(K::clamp_min(self.primitive, min.elem()))
1659 }
1660
1661 /// Clamp element wise over a maximum value.
1662 ///
1663 /// # Arguments
1664 ///
1665 /// * `tensor` - The tensor to clamp.
1666 /// * `max` - The maximum value.
1667 ///
1668 /// # Returns
1669 ///
1670 /// A new tensor with the values clamped over the given max value.
1671 ///
1672 /// # Example
1673 ///
1674 /// ```rust
1675 /// use burn_tensor::backend::Backend;
1676 /// use burn_tensor::{Int, Tensor};
1677 ///
1678 /// fn example<B: Backend>() {
1679 /// let device = Default::default();
1680 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1681 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1682 /// &device);
1683 /// let tensor = tensor.clamp_max(5);
1684 /// println!("{tensor}");
1685 /// // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1686 /// }
1687 /// ```
1688 pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1689 Self::new(K::clamp_max(self.primitive, max.elem()))
1690 }
1691
1692 /// Apply element wise absolute value operation.
1693 ///
1694 /// # Example
1695 ///
1696 /// ```rust
1697 /// use burn_tensor::backend::Backend;
1698 /// use burn_tensor::{Int, Tensor};
1699 ///
1700 /// fn example<B: Backend>() {
1701 /// let device = Default::default();
1702 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
1703 /// let tensor = tensor.abs();
1704 /// println!("{tensor}");
1705 /// // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1706 /// }
1707 /// ```
1708 pub fn abs(self) -> Self {
1709 Self::new(K::abs(self.primitive))
1710 }
1711
1712 /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
1713 /// the other elements of the result tensor out are set to 0.
1714 ///
1715 /// See also [`triu_mask`](Tensor::triu_mask).
1716 ///
1717 /// # Arguments
1718 ///
1719 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1720 /// towards the upper triangle.
1721 ///
1722 /// # Example
1723 /// ```rust
1724 /// use burn_tensor::backend::Backend;
1725 /// use burn_tensor::{Int, Tensor};
1726 ///
1727 /// fn example<B: Backend>() {
1728 /// let device = Default::default();
1729 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1730 /// [
1731 /// [1, 2, 3],
1732 /// [4, 5, 6],
1733 /// [7, 8, 9]
1734 /// ],
1735 /// &device
1736 /// );
1737 /// let tensor = tensor.triu(1);
1738 /// println!("{tensor}");
1739 /// // [
1740 /// // [0, 2, 3],
1741 /// // [0, 0, 6],
1742 /// // [0, 0, 0]
1743 /// // ]
1744 /// }
1745 /// ```
1746 pub fn triu(self, diagonal: i64) -> Self {
1747 check!(TensorCheck::tri::<{ D }>());
1748
1749 // last two dimensions
1750 let shape = &self.shape().dims[D - 2..].to_owned();
1751
1752 let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
1753 self.mask_fill(mask, 0)
1754 }
1755
1756 /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
1757 /// the other elements of the result tensor out are set to 0.
1758 ///
1759 /// See also [`tril_mask`](Tensor::tril_mask).
1760 ///
1761 /// # Arguments
1762 ///
1763 /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift
1764 /// towards the upper triangle.
1765 ///
1766 /// # Example
1767 /// ```rust
1768 /// use burn_tensor::backend::Backend;
1769 /// use burn_tensor::{Int, Tensor};
1770 ///
1771 /// fn example<B: Backend>() {
1772 /// let device = Default::default();
1773 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1774 /// [
1775 /// [1, 2, 3],
1776 /// [4, 5, 6],
1777 /// [7, 8, 9]
1778 /// ],
1779 /// &device
1780 /// );
1781 ///
1782 /// let tensor = tensor.tril(-1);
1783 /// println!("{tensor}");
1784 /// // [
1785 /// // [0, 0, 0],
1786 /// // [4, 0, 0],
1787 /// // [7, 8, 0]
1788 /// // ]
1789 /// }
1790 /// ```
1791 pub fn tril(self, diagonal: i64) -> Self {
1792 check!(TensorCheck::tri::<{ D }>());
1793
1794 // last two dimensions
1795 let shape = &self.shape().dims[D - 2..].to_owned();
1796 let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
1797
1798 self.mask_fill(mask, 0)
1799 }
1800
1801 /// Applies element wise power operation with a float Tensor
1802 ///
1803 /// # Arguments
1804 ///
1805 /// * `other` - The tensor to apply the power operation with.
1806 ///
1807 /// # Example
1808 ///
1809 /// ```rust
1810 /// use burn_tensor::backend::Backend;
1811 /// use burn_tensor::{Tensor, Shape};
1812 ///
1813 /// fn example<B: Backend>() {
1814 /// let device = B::Device::default();
1815 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1816 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1817 /// let tensor = tensor1.powf(tensor2);
1818 /// println!("{tensor}");
1819 /// // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]]
1820 /// }
1821 /// ```
1822 pub fn powf(self, other: Self) -> Self {
1823 Self::new(K::powf(self.primitive, other.primitive))
1824 }
1825
1826 /// Applies element wise power operation with a float scalar
1827 ///
1828 /// # Arguments
1829 ///
1830 /// * `other` - The scalar to apply the power operation with.
1831 ///
1832 /// # Example
1833 ///
1834 /// ```rust
1835 /// use burn_tensor::backend::Backend;
1836 /// use burn_tensor::{Tensor, Shape};
1837 ///
1838 /// fn example<B: Backend>() {
1839 /// let device = B::Device::default();
1840 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1841 /// let tensor = tensor.powf_scalar(2.0);
1842 /// println!("{tensor}");
1843 /// // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
1844 /// }
1845 /// ```
1846 pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
1847 Self::new(K::powf_scalar::<E>(self.primitive, other))
1848 }
1849
1850 /// Applies element wise power operation with a integer Tensor
1851 ///
1852 /// # Arguments
1853 ///
1854 /// * `other` - The tensor to apply the power operation with.
1855 ///
1856 /// # Example
1857 ///
1858 /// ```rust
1859 /// use burn_tensor::backend::Backend;
1860 /// use burn_tensor::{Tensor, Shape, Int};
1861 ///
1862 /// fn example<B: Backend>() {
1863 /// let device = B::Device::default();
1864 /// let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1865 /// let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
1866 /// let tensor = tensor1.powi(tensor2);
1867 /// println!("{tensor}");
1868 /// // [[1, -8, 81], [5, 81, 216]]
1869 /// }
1870 /// ```
1871 pub fn powi(self, other: Self) -> Self {
1872 Self::new(K::powi(self.primitive, other.primitive))
1873 }
1874
1875 /// Applies element wise power operation with a integer scalar
1876 ///
1877 /// # Arguments
1878 ///
1879 /// * `other` - The scalar to apply the power operation with.
1880 ///
1881 /// # Example
1882 ///
1883 /// ```rust
1884 /// use burn_tensor::backend::Backend;
1885 /// use burn_tensor::{Tensor, Shape, Int};
1886 ///
1887 /// fn example<B: Backend>() {
1888 /// let device = B::Device::default();
1889 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1890 /// let tensor = tensor.powi_scalar(2);
1891 /// println!("{tensor}");
1892 ///
1893 /// // [[1, 4, 9], [25, 81, 36]]
1894 /// let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
1895 /// let tensor = tensor.powi_scalar(2);
1896 /// println!("{tensor}");
1897 /// // [[2.25, 4., 9.], [25., 81., 36.]]
1898 /// }
1899 /// ```
1900 pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
1901 Self::new(K::powi_scalar::<E>(self.primitive, other))
1902 }
1903
1904 /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
1905 ///
1906 /// # Returns
1907 ///
1908 /// A boolean tensor with the same shape as the input tensor.
1909 ///
1910 /// # Example
1911 ///
1912 /// ```rust
1913 /// use burn_tensor::backend::Backend;
1914 /// use burn_tensor::{Tensor, Shape};
1915 ///
1916 /// fn example<B: Backend>() {
1917 /// let device = B::Device::default();
1918 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
1919 /// let tensor = tensor.bool();
1920 /// println!("{tensor}");
1921 /// // [
1922 /// // [true, true, true],
1923 /// // [false, true, true]
1924 /// // ]
1925 /// }
1926 pub fn bool(self) -> Tensor<B, D, Bool> {
1927 Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
1928 }
1929
1930 /// Create a random tensor of the given shape on the given device where each element is
1931 /// sampled from the given distribution.
1932 ///
1933 /// See also [`random_like`](Tensor::random_like).
1934 ///
1935 /// # Arguments
1936 ///
1937 /// * `shape` - The shape of the tensor.
1938 /// * `distribution` - The distribution to sample from.
1939 /// * `device` - The device to create the tensor on.
1940 ///
1941 /// # Returns
1942 ///
1943 /// A new tensor with the given shape and elements sampled from the given distribution.
1944 ///
1945 /// # Example
1946 ///
1947 /// ```rust
1948 /// use burn_tensor::backend::Backend;
1949 /// use burn_tensor::{Tensor, Shape, Distribution};
1950 ///
1951 /// fn example<B: Backend>() {
1952 /// let device = B::Device::default();
1953 /// let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
1954 /// let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
1955 /// println!("{tensor}");
1956 /// // [
1957 /// // [0.08347523, 0.70498955, 0.60332155],
1958 /// // [0.08173251, 0.18028641, 0.97942924]
1959 /// // ]
1960 /// }
1961 /// ```
1962 pub fn random<S: Into<Shape>>(
1963 shape: S,
1964 distribution: Distribution,
1965 device: &B::Device,
1966 ) -> Self {
1967 Self::new(K::random(shape.into(), distribution, device))
1968 }
1969
1970 /// Sort the elements by value in ascending order along a given dimension.
1971 ///
1972 /// This sort is unstable (i.e., may reorder equal elements).
1973 ///
1974 /// # Arguments
1975 ///
1976 /// * `dim` - The dimension to sort along.
1977 ///
1978 /// # Returns
1979 ///
1980 /// A new tensor with the elements sorted in ascending order along the given dimension.
1981 ///
1982 /// # Example
1983 ///
1984 /// ```rust
1985 /// use burn_tensor::backend::Backend;
1986 /// use burn_tensor::{Tensor, Shape};
1987 ///
1988 /// fn example<B: Backend>() {
1989 /// let device = B::Device::default();
1990 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1991 /// let tensor = tensor.sort(0);
1992 /// println!("{tensor}");
1993 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1994 /// let tensor = tensor.sort(1);
1995 /// println!("{tensor}");
1996 /// // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
1997 /// }
1998 /// ```
1999 pub fn sort(self, dim: usize) -> Self {
2000 check!(TensorCheck::sort_dim::<D>("Sort", dim));
2001 Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
2002 }
2003
2004 /// Sort the elements by value in descending order along a given dimension.
2005 ///
2006 /// This sort is unstable (i.e., may reorder equal elements).
2007 ///
2008 /// # Arguments
2009 ///
2010 /// * `dim` - The dimension to sort along.
2011 ///
2012 /// # Returns
2013 ///
2014 /// A new tensor with the elements sorted in descending order along the given dimension.
2015 ///
2016 /// # Example
2017 ///
2018 /// ```rust
2019 /// use burn_tensor::backend::Backend;
2020 /// use burn_tensor::{Tensor, Shape};
2021 ///
2022 /// fn example<B: Backend>() {
2023 /// let device = B::Device::default();
2024 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2025 /// let tensor = tensor.sort_descending(0);
2026 /// println!("{tensor}");
2027 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2028 /// let tensor = tensor.sort_descending(1);
2029 /// println!("{tensor}");
2030 /// // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
2031 /// }
2032 /// ```
2033 pub fn sort_descending(self, dim: usize) -> Self {
2034 check!(TensorCheck::sort_dim::<D>("Sort", dim));
2035 Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
2036 }
2037
2038 /// Sort the elements by value in ascending order along a given dimension.
2039 /// Also returns the indices.
2040 ///
2041 /// This sort is unstable (i.e., may reorder equal elements).
2042 ///
2043 /// # Arguments
2044 ///
2045 /// * `dim` - The dimension to sort along.
2046 ///
2047 /// # Returns
2048 ///
2049 /// A tuple containing the sorted tensor and the indices tensor.
2050 ///
2051 /// # Example
2052 ///
2053 /// ```rust
2054 /// use burn_tensor::backend::Backend;
2055 /// use burn_tensor::{Tensor, Shape};
2056 ///
2057 /// fn example<B: Backend>() {
2058 /// let device = B::Device::default();
2059 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2060 /// let (tensor, indices) = tensor.sort_with_indices(0);
2061 /// println!("{tensor}");
2062 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
2063 /// println!("{}", indices);
2064 /// // [[1, 0, 0], [0, 1, 1]]
2065 /// }
2066 /// ```
2067 pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
2068 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
2069 let (values, indices) =
2070 K::sort_with_indices(self.primitive, dim, /*descending*/ false);
2071 (Tensor::new(values), Tensor::new(indices))
2072 }
2073
2074 /// Sort the elements by value in descending order along a given dimension.
2075 /// Also returns the indices.
2076 ///
2077 /// This sort is unstable (i.e., may reorder equal elements).
2078 ///
2079 /// # Arguments
2080 ///
2081 /// * `dim` - The dimension to sort along.
2082 ///
2083 /// # Example
2084 ///
2085 /// ```rust
2086 /// use burn_tensor::backend::Backend;
2087 /// use burn_tensor::{Tensor, Shape};
2088 ///
2089 /// fn example<B: Backend>() {
2090 /// let device = B::Device::default();
2091 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2092 /// let (tensor, indices) = tensor.sort_descending_with_indices(0);
2093 /// println!("{tensor}");
2094 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2095 /// println!("{}", indices);
2096 /// // [[0, 1, 1], [1, 0, 0]]
2097 /// }
2098 /// ```
2099 pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {
2100 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
2101 let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
2102 (Tensor::new(values), Tensor::new(indices))
2103 }
2104
2105 /// Returns the indices that sort the elements by value in ascending order along a given dimension.
2106 ///
2107 /// This sort is unstable (i.e., may reorder equal elements).
2108 ///
2109 /// # Arguments
2110 ///
2111 /// * `dim` - The dimension to sort along.
2112 ///
2113 /// # Example
2114 ///
2115 /// ```rust
2116 /// use burn_tensor::backend::Backend;
2117 /// use burn_tensor::{Tensor, Shape};
2118 ///
2119 /// fn example<B: Backend>() {
2120 /// let device = B::Device::default();
2121 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2122 /// let tensor = tensor.argsort(0);
2123 /// println!("{tensor}");
2124 /// // [[1, 0, 0], [0, 1, 1]]
2125 /// }
2126 /// ```
2127 pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
2128 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
2129 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
2130 }
2131
2132 /// Returns the indices that sort the elements by value in descending order along a given dimension.
2133 ///
2134 /// This sort is unstable (i.e., may reorder equal elements).
2135 ///
2136 /// # Arguments
2137 ///
2138 /// * `dim` - The dimension to sort along.
2139 ///
2140 /// # Example
2141 ///
2142 /// ```rust
2143 /// use burn_tensor::backend::Backend;
2144 /// use burn_tensor::{Tensor, Shape};
2145 ///
2146 /// fn example<B: Backend>() {
2147 /// let device = B::Device::default();
2148 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2149 /// let tensor = tensor.argsort_descending(0);
2150 /// println!("{tensor}");
2151 /// // [[0, 1, 1], [1, 0, 0]]
2152 /// let tensor = tensor.argsort_descending(1);
2153 /// println!("{tensor}");
2154 /// // [[0, 2, 1], [2, 0, 1]]
2155 /// }
2156 /// ```
2157 pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
2158 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
2159 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
2160 }
2161
2162 /// Returns the `k` largest elements of the given input tensor along a given dimension.
2163 ///
2164 /// # Arguments
2165 ///
2166 /// * `k` - The number of elements to return.
2167 ///
2168 /// # Returns
2169 ///
2170 /// A new tensor with the `k` largest elements along the given dimension.
2171 ///
2172 /// # Example
2173 ///
2174 /// ```rust
2175 /// use burn_tensor::backend::Backend;
2176 /// use burn_tensor::{Tensor, Shape};
2177 ///
2178 /// fn example<B: Backend>() {
2179 /// let device = B::Device::default();
2180 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2181 /// let tensor = tensor.topk(2, 0);
2182 /// println!("{tensor}");
2183 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2184 /// let tensor = tensor.topk(1, 1);
2185 /// println!("{tensor}");
2186 /// // [[12.0], [6.0]]
2187 /// }
2188 /// ```
2189 pub fn topk(self, k: usize, dim: usize) -> Self {
2190 let k_indices = Tensor::arange(0..k as i64, &self.device());
2191 self.sort_descending(dim).select(dim, k_indices)
2192 }
2193
2194 /// Returns the `k` largest elements of the given input tensor along a given dimension.
2195 /// Also returns the indices.
2196 ///
2197 /// # Arguments
2198 ///
2199 /// * `k` - The number of elements to return.
2200 /// * `dim` - The dimension to sort along.
2201 ///
2202 /// # Example
2203 ///
2204 /// ```rust
2205 /// use burn_tensor::backend::Backend;
2206 /// use burn_tensor::{Tensor, Shape};
2207 ///
2208 /// fn example<B: Backend>() {
2209 /// let device = B::Device::default();
2210 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2211 /// let (tensor, indices) = tensor.topk_with_indices(2, 0);
2212 /// println!("{tensor}");
2213 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
2214 /// println!("{}", indices);
2215 /// // [[0, 1, 1], [1, 0, 0]]
2216 /// let (tensor, indices) = tensor.topk_with_indices(1, 1);
2217 /// println!("{tensor}");
2218 /// // [[12.0], [6.0]]
2219 /// println!("{indices}");
2220 /// // [[0], [2]]
2221 /// }
2222 /// ```
2223 pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {
2224 let k_indices = Tensor::arange(0..k as i64, &self.device());
2225 let (values, indices) = self.sort_descending_with_indices(dim);
2226 (
2227 values.select(dim, k_indices.clone()),
2228 indices.select(dim, k_indices),
2229 )
2230 }
2231
2232 /// Pad the tensor of rank two or higher with the given value on the last two dimensions.
2233 ///
2234 /// # Arguments
2235 ///
2236 /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
2237 /// * `value` - The value to pad the tensor with.
2238 ///
2239 /// # Returns
2240 ///
2241 /// A new tensor with the given padding.
2242 ///
2243 /// # Example
2244 ///
2245 /// ```rust
2246 /// use burn_tensor::backend::Backend;
2247 /// use burn_tensor::{Tensor, Shape};
2248 ///
2249 /// fn example<B: Backend<FloatElem: From<f32>>>() {
2250 /// let device = B::Device::default();
2251 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2252 /// let tensor = tensor.pad((1, 1, 1, 1), 0.0);
2253 /// println!("{tensor}");
2254 /// // [
2255 /// // [0.0, 0.0, 0.0, 0.0, 0.0],
2256 /// // [0.0, 12.0, -2.0, 3.0, 0.0],
2257 /// // [0.0, 5.0, 3.0, 6.0, 0.0],
2258 /// // [0.0, 0.0, 0.0, 0.0, 0.0]
2259 /// // ]
2260 /// }
2261 /// ```
2262 pub fn pad<E: ElementConversion>(
2263 self,
2264 padding: (usize, usize, usize, usize),
2265 value: E,
2266 ) -> Self {
2267 let (left, right, top, bottom) = padding;
2268
2269 let mut padded_dims: [usize; D] = self.dims();
2270
2271 // Update the last two dimensions with padding
2272 padded_dims[D - 2] += top + bottom;
2273 padded_dims[D - 1] += left + right;
2274
2275 // Create the ranges for the padded tensor
2276 let ranges: [core::ops::Range<usize>; D] = padded_dims
2277 .iter()
2278 .enumerate()
2279 .map(|(i, &dim)| {
2280 if i == D - 2 {
2281 top..dim - bottom
2282 } else if i == D - 1 {
2283 left..dim - right
2284 } else {
2285 0..dim
2286 }
2287 })
2288 .collect::<Vec<core::ops::Range<usize>>>()
2289 .try_into()
2290 .unwrap();
2291
2292 // Create the padded tensor
2293 let padded_tensor = Tensor::full(padded_dims, value, &self.device());
2294
2295 // Assign the original tensor data to the appropriate slice of the padded tensor
2296 padded_tensor.slice_assign(ranges, self)
2297 }
2298 /// Create a one hot tensor.
2299 ///
2300 /// # Example
2301 ///
2302 /// ```rust
2303 /// use burn_tensor::backend::Backend;
2304 /// use burn_tensor::Tensor;
2305 ///
2306 /// fn example<B: Backend>(){
2307 /// let device = Default::default();
2308 /// let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
2309 /// let one_hot: Tensor<B, 2> = indices.one_hot(4);
2310 /// println!("{}", one_hot.to_data());
2311 /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
2312 /// }
2313 /// ```
2314 pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
2315 check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
2316 self.one_hot_fill(num_classes, 1.0, 0.0, -1)
2317 }
2318
2319 /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.
2320 ///
2321 /// # Arguments
2322 ///
2323 /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
2324 /// * `on_value`: The value to assign for active positions (corresponding to indices).
2325 /// * `off_value`: The value to assign for inactive positions.
2326 /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
2327 ///
2328 /// # Returns
2329 ///
2330 /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
2331 ///
2332 /// # Example
2333 /// ```rust
2334 /// use burn_tensor::backend::Backend;
2335 /// use burn_tensor::{Tensor, Float};
2336 /// fn example<B: Backend<FloatElem: From<f32>>>() {
2337 /// let device = B::Device::default();
2338 /// let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
2339 /// // One-hot encoding
2340 /// let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);
2341 /// println!("{tensor}");
2342 /// // [[[5.0, 0.0, 0.0],
2343 /// // [0.0, 0.0, 5.0]],
2344 /// // [[0.0, 5.0, 0.0],
2345 /// // [0.0, 0.0, 5.0]]]
2346 /// }
2347 /// ```
2348 pub fn one_hot_fill<const D2: usize>(
2349 self,
2350 num_classes: usize,
2351 on_value: f32,
2352 off_value: f32,
2353 axis: i64,
2354 ) -> Tensor<B, D2, K> {
2355 check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
2356 // Initialize shape from the current tensor dimensions and prepare for modification
2357 let mut shape = self.shape();
2358 let device = self.device();
2359 let rank = self.dims().len();
2360
2361 // Adjust negative axis to a positive index
2362 let axis = if axis < 0 {
2363 axis + rank as i64 + 1
2364 } else {
2365 axis
2366 };
2367
2368 // Ensure axis is within valid range
2369 if axis < 0 || axis > rank as i64 {
2370 panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
2371 }
2372 // Convert the input tensor to integer indices
2373 let indices: Tensor<B, D, Int> =
2374 Tensor::from_data(self.to_data().convert::<i64>(), &device);
2375 // Insert the new dimension for the one-hot representation
2376 shape.insert(axis as usize, num_classes);
2377 // Adjust indices to valid range and handle invalid indices
2378 let adjusted_indices = indices
2379 .clone()
2380 .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices
2381 .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices
2382 // Unsqueeze the indices tensor along the specified axis
2383 let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
2384
2385 // Initialize the output tensor with the off_value
2386 let output = Tensor::full(shape.clone(), off_value, &device);
2387
2388 // Prepare scatter tensor for on_value and off_value adjustments
2389 let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
2390 - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());
2391
2392 // Scatter on_value at the appropriate indices to create the one-hot representation
2393 output.scatter(axis as usize, indices_unsqueezed, scatter_on_values)
2394 }
2395
2396 /// Applies the matrix multiplication operation.
2397 ///
2398 /// ```math
2399 /// C = AB
2400 /// ```
2401 pub fn matmul(self, other: Self) -> Self {
2402 check!(TensorCheck::matmul(&self, &other));
2403 Tensor::new(K::matmul(self.primitive, other.primitive))
2404 }
2405}
2406
2407impl<B, K> Tensor<B, 1, K>
2408where
2409 B: Backend,
2410 K: Numeric<B>,
2411 K::Elem: Element,
2412{
2413 /// Calculates the dot product with another tensor.
2414 ///
2415 /// `y = x2.dot(x1)`
2416 ///
2417 /// # Arguments
2418 ///
2419 /// * `other` - The tensor to compute dot product with.
2420 ///
2421 /// # Notes
2422 ///
2423 /// Both tensors must have the same number of elements.
2424 ///
2425 /// # Example
2426 ///
2427 /// ```rust
2428 /// use burn_tensor::backend::Backend;
2429 /// use burn_tensor::{Tensor, Shape};
2430 ///
2431 /// fn example<B: Backend>() {
2432 /// let device = B::Device::default();
2433 /// let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
2434 /// let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);
2435 /// let tensor = tensor1.dot(tensor2);
2436 /// println!("{tensor}");
2437 /// // [4]
2438 /// }
2439 /// ```
2440 pub fn dot(self, other: Self) -> Self {
2441 self.mul(other).sum()
2442 }
2443}
2444
2445impl<B, K> Tensor<B, 2, K>
2446where
2447 B: Backend,
2448 K: Numeric<B>,
2449 K::Elem: Element,
2450{
2451 /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2452 ///
2453 /// # Arguments
2454 ///
2455 /// * `size` - The size of the square matrix.
2456 pub fn eye(size: usize, device: &B::Device) -> Self {
2457 let dtype = K::Elem::dtype();
2458 let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2459 let ones = K::ones([1, size].into(), device, dtype);
2460 let zeros = K::zeros([size, size].into(), device, dtype);
2461
2462 Self::new(K::scatter(0, zeros, indices.primitive, ones))
2463 }
2464}
2465
2466/// Trait that list all operations that can be applied on all numerical tensors.
2467///
2468/// # Warnings
2469///
2470/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
2471pub trait Numeric<B: Backend>: BasicOps<B>
2472where
2473 Self::Elem: Element,
2474{
2475 /// Adds two tensors together.
2476 ///
2477 /// # Arguments
2478 ///
2479 /// * `lhs` - The left hand side tensor.
2480 /// * `rhs` - The right hand side tensor.
2481 ///
2482 /// # Returns
2483 ///
2484 /// The sum of the two tensors.
2485 ///
2486 /// # Remarks
2487 ///
2488 /// This is a low-level function used internally by the library to call different backend functions
2489 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2490 /// or use this function directly.
2491 ///
2492 /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
2493 /// which is more high-level and designed for public use.
2494 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2495
2496 /// Adds a scalar to a tensor element-wise.
2497 ///
2498 /// # Arguments
2499 ///
2500 /// * `lhs` - The left hand side tensor.
2501 /// * `rhs` - The right hand side scalar.
2502 ///
2503 /// # Returns
2504 ///
2505 /// The sum of the tensor and the scalar.
2506 ///
2507 /// # Remarks
2508 ///
2509 /// This is a low-level function used internally by the library to call different backend functions
2510 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2511 /// or use this function directly.
2512 ///
2513 /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
2514 /// which is more high-level and designed for public use.
2515 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2516
2517 /// Subtracts two tensors.
2518 ///
2519 /// # Arguments
2520 ///
2521 /// * `lhs` - The left hand side tensor.
2522 /// * `rhs` - The right hand side tensor.
2523 ///
2524 /// # Returns
2525 ///
2526 /// The difference of the two tensors.
2527 ///
2528 /// # Remarks
2529 ///
2530 /// This is a low-level function used internally by the library to call different backend functions
2531 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2532 /// or use this function directly.
2533 ///
2534 /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
2535 /// which is more high-level and designed for public use.
2536 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2537
2538 /// Subtracts a scalar from a tensor element-wise.
2539 ///
2540 /// # Arguments
2541 ///
2542 /// * `lhs` - The left hand side tensor.
2543 /// * `rhs` - The right hand side scalar.
2544 ///
2545 /// # Returns
2546 ///
2547 /// The difference of the tensor and the scalar.
2548 ///
2549 /// # Remarks
2550 ///
2551 /// This is a low-level function used internally by the library to call different backend functions
2552 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2553 /// or use this function directly.
2554 ///
2555 /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
2556 /// which is more high-level and designed for public use.
2557 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2558
2559 /// Divides two tensors.
2560 ///
2561 /// # Arguments
2562 ///
2563 /// * `lhs` - The left hand side tensor.
2564 /// * `rhs` - The right hand side tensor.
2565 ///
2566 /// # Returns
2567 ///
2568 /// The quotient of the two tensors.
2569 ///
2570 /// # Remarks
2571 ///
2572 /// This is a low-level function used internally by the library to call different backend functions
2573 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2574 /// or use this function directly.
2575 ///
2576 /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
2577 /// which is more high-level and designed for public use.
2578 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2579
2580 /// Divides a tensor by a scalar element-wise.
2581 ///
2582 /// # Arguments
2583 ///
2584 /// * `lhs` - The left hand side tensor.
2585 /// * `rhs` - The right hand side scalar.
2586 ///
2587 /// # Returns
2588 ///
2589 /// The quotient of the tensor and the scalar.
2590 ///
2591 /// # Remarks
2592 ///
2593 /// This is a low-level function used internally by the library to call different backend functions
2594 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2595 /// or use this function directly.
2596 ///
2597 /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
2598 /// which is more high-level and designed for public use.
2599 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2600
2601 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2602 /// less than that of the divisor.
2603 ///
2604 /// # Arguments
2605 ///
2606 /// * `lhs` - The dividend.
2607 /// * `rhs` - The divisor.
2608 ///
2609 /// # Returns
2610 ///
2611 /// The modulo of the input tensor with the divisor.
2612 ///
2613 /// # Remarks
2614 ///
2615 /// This is a low-level function used internally by the library to call different backend functions
2616 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2617 /// or use this function directly.
2618 ///
2619 /// For performing the modulo operation, users should prefer the [Tensor::remainder](Tensor::remainder) function,
2620 /// which is more high-level and designed for public use.
2621 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2622
2623 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2624 /// less than that of the divisor.
2625 ///
2626 /// # Arguments
2627 ///
2628 /// * `lhs` - The dividend.
2629 /// * `rhs` - The divisor.
2630 ///
2631 /// # Returns
2632 ///
2633 /// The modulo of the input tensor with the divisor.
2634 ///
2635 /// # Remarks
2636 ///
2637 /// This is a low-level function used internally by the library to call different backend functions
2638 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2639 /// or use this function directly.
2640 ///
2641 /// For performing the modulo operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
2642 /// which is more high-level and designed for public use.
2643 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2644
2645 /// Multiplies two tensors.
2646 ///
2647 /// # Arguments
2648 ///
2649 /// * `lhs` - The left hand side tensor.
2650 /// * `rhs` - The right hand side tensor.
2651 ///
2652 /// # Returns
2653 ///
2654 /// The product of the two tensors.
2655 ///
2656 /// # Remarks
2657 ///
2658 /// This is a low-level function used internally by the library to call different backend functions
2659 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2660 /// or use this function directly.
2661 ///
2662 /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
2663 /// which is more high-level and designed for public use.
2664 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2665
2666 /// Multiplies a tensor by a scalar element-wise.
2667 ///
2668 /// # Arguments
2669 ///
2670 /// * `lhs` - The left hand side tensor.
2671 /// * `rhs` - The right hand side scalar.
2672 ///
2673 /// # Returns
2674 ///
2675 /// The product of the tensor and the scalar.
2676 ///
2677 /// # Remarks
2678 ///
2679 /// This is a low-level function used internally by the library to call different backend functions
2680 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2681 /// or use this function directly.
2682 ///
2683 /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
2684 /// which is more high-level and designed for public use.
2685 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2686
2687 /// Negates a tensor.
2688 ///
2689 /// # Arguments
2690 ///
2691 /// * `tensor` - The tensor to negate.
2692 ///
2693 /// # Returns
2694 ///
2695 /// The negated tensor.
2696 ///
2697 /// # Remarks
2698 ///
2699 /// This is a low-level function used internally by the library to call different backend functions
2700 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2701 /// or use this function directly.
2702 ///
2703 /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
2704 /// which is more high-level and designed for public use.
2705 fn neg(tensor: Self::Primitive) -> Self::Primitive;
2706
2707 /// Returns the signs of the elements of a tensor.
2708 ///
2709 /// # Arguments
2710 ///
2711 /// * `tensor` - The tensor.
2712 ///
2713 /// # Returns
2714 ///
2715 /// The signs of the elements of the tensor.
2716 ///
2717 /// # Remarks
2718 ///
2719 /// This is a low-level function used internally by the library to call different backend functions
2720 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2721 /// or use this function directly.
2722 ///
2723 /// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
2724 /// which is more high-level and designed for public use.
2725 fn sign(tensor: Self::Primitive) -> Self::Primitive;
2726
2727 /// Creates a tensor filled with zeros.
2728 ///
2729 /// # Arguments
2730 ///
2731 /// * `shape` - The shape of the tensor.
2732 /// * `device` - The device on which the tensor will be allocated.
2733 /// * `dtype` - The target data type.
2734 ///
2735 /// # Returns
2736 ///
2737 /// The tensor filled with zeros.
2738 ///
2739 /// # Remarks
2740 ///
2741 /// This is a low-level function used internally by the library to call different backend functions
2742 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2743 /// or use this function directly.
2744 ///
2745 /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
2746 /// which is more high-level and designed for public use.
2747 fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
2748
2749 /// Creates a tensor filled with ones.
2750 ///
2751 /// # Arguments
2752 ///
2753 /// * `shape` - The shape of the tensor.
2754 /// * `device` - The device on which the tensor will be allocated.
2755 /// * `dtype` - The target data type.
2756 ///
2757 /// # Returns
2758 ///
2759 /// The tensor filled with ones.
2760 ///
2761 /// # Remarks
2762 ///
2763 /// This is a low-level function used internally by the library to call different backend functions
2764 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2765 /// or use this function directly.
2766 ///
2767 /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
2768 /// which is more high-level and designed for public use.
2769 fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;
2770
2771 /// Sums all the elements of the tensor.
2772 ///
2773 /// # Arguments
2774 ///
2775 /// * `tensor` - The tensor to sum.
2776 ///
2777 /// # Returns
2778 ///
2779 /// The sum of all the elements of the tensor.
2780 ///
2781 /// # Remarks
2782 ///
2783 /// This is a low-level function used internally by the library to call different backend functions
2784 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2785 /// or use this function directly.
2786 ///
2787 /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
2788 /// which is more high-level and designed for public use.
2789 fn sum(tensor: Self::Primitive) -> Self::Primitive;
2790
2791 /// Sums all the elements of the tensor along a dimension.
2792 ///
2793 /// # Arguments
2794 ///
2795 /// * `tensor` - The tensor to sum.
2796 /// * `dim` - The dimension along which to sum.
2797 ///
2798 /// # Returns
2799 ///
2800 /// The sum of all the elements of the tensor along the specified dimension.
2801 ///
2802 /// # Remarks
2803 ///
2804 /// This is a low-level function used internally by the library to call different backend functions
2805 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2806 /// or use this function directly.
2807 ///
2808 /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
2809 /// which is more high-level and designed for public use.
2810 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2811
2812 /// Computes the product of all the elements of the tensor.
2813 ///
2814 /// # Arguments
2815 ///
2816 /// * `tensor` - The tensor to compute the product of.
2817 ///
2818 /// # Returns
2819 ///
2820 /// The product of all the elements of the tensor.
2821 ///
2822 /// # Remarks
2823 ///
2824 /// This is a low-level function used internally by the library to call different backend functions
2825 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2826 /// or use this function directly.
2827 ///
2828 /// For computing the product of all the elements of a tensor, users should prefer the
2829 /// [Tensor::prod](Tensor::prod) function,
2830 /// which is more high-level and designed for public use.
2831 fn prod(tensor: Self::Primitive) -> Self::Primitive;
2832
2833 /// Computes the product of all the elements of the tensor along a dimension.
2834 ///
2835 /// # Arguments
2836 ///
2837 /// * `tensor` - The tensor to compute the product of.
2838 /// * `dim` - The dimension along which to compute the product.
2839 ///
2840 /// # Returns
2841 ///
2842 /// The product of all the elements of the tensor along the specified dimension.
2843 ///
2844 /// # Remarks
2845 ///
2846 /// This is a low-level function used internally by the library to call different backend functions
2847 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2848 /// or use this function directly.
2849 ///
2850 /// For computing the product of all the elements of a tensor along a dimension, users should
2851 /// prefer the [Tensor::prod_dim](Tensor::prod_dim) function,
2852 /// which is more high-level and designed for public use.
2853 ///
2854 ///
2855 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2856
2857 /// Computes the mean of all the elements of the tensor.
2858 ///
2859 /// # Arguments
2860 ///
2861 /// * `tensor` - The tensor to compute the mean of.
2862 ///
2863 /// # Returns
2864 ///
2865 /// The mean of all the elements of the tensor.
2866 ///
2867 /// # Remarks
2868 ///
2869 /// This is a low-level function used internally by the library to call different backend functions
2870 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2871 /// or use this function directly.
2872 ///
2873 /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
2874 /// which is more high-level and designed for public use.
2875 fn mean(tensor: Self::Primitive) -> Self::Primitive;
2876
2877 /// Computes the mean of all the elements of the tensor along a dimension.
2878 ///
2879 /// # Arguments
2880 ///
2881 /// * `tensor` - The tensor to compute the mean of.
2882 /// * `dim` - The dimension along which to compute the mean.
2883 ///
2884 /// # Returns
2885 ///
2886 /// The mean of all the elements of the tensor along the specified dimension.
2887 ///
2888 /// # Remarks
2889 ///
2890 /// This is a low-level function used internally by the library to call different backend functions
2891 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2892 /// or use this function directly.
2893 ///
2894 /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
2895 /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
2896 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2897
2898 /// Computes the cumulative sum of elements along a dimension.
2899 ///
2900 /// # Arguments
2901 ///
2902 /// * `tensor` - The tensor to compute the cumulative sum of.
2903 /// * `dim` - The dimension along which to compute the cumulative sum.
2904 ///
2905 /// # Returns
2906 ///
2907 /// A tensor with the same shape as the input tensor, where each element is the cumulative sum
2908 /// of all elements up to and including that position along the specified dimension.
2909 ///
2910 /// # Remarks
2911 ///
2912 /// This is a low-level function used internally by the library to call different backend functions
2913 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2914 /// or use this function directly.
2915 ///
2916 /// For computing the cumulative sum of elements along a dimension, users should prefer
2917 /// the [Tensor::cumsum](Tensor::cumsum) function, which is more high-level and designed for public use.
2918 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2919
2920 /// Computes the cumulative product of elements along a dimension.
2921 ///
2922 /// # Arguments
2923 ///
2924 /// * `tensor` - The tensor to compute the cumulative product of.
2925 /// * `dim` - The dimension along which to compute the cumulative product.
2926 ///
2927 /// # Returns
2928 ///
2929 /// A tensor with the same shape as the input tensor, where each element is the cumulative product
2930 /// of all elements up to and including that position along the specified dimension.
2931 ///
2932 /// # Remarks
2933 ///
2934 /// This is a low-level function used internally by the library to call different backend functions
2935 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2936 /// or use this function directly.
2937 ///
2938 /// For computing the cumulative product of elements along a dimension, users should prefer
2939 /// the [Tensor::cumprod](Tensor::cumprod) function, which is more high-level and designed for public use.
2940 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2941
2942 /// Computes the cumulative minimum of elements along a dimension.
2943 ///
2944 /// # Arguments
2945 ///
2946 /// * `tensor` - The tensor to compute the cumulative minimum of.
2947 /// * `dim` - The dimension along which to compute the cumulative minimum.
2948 ///
2949 /// # Returns
2950 ///
2951 /// A tensor with the same shape as the input tensor, where each element is the minimum
2952 /// of all elements up to and including that position along the specified dimension.
2953 ///
2954 /// # Remarks
2955 ///
2956 /// This is a low-level function used internally by the library to call different backend functions
2957 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2958 /// or use this function directly.
2959 ///
2960 /// For computing the cumulative minimum of elements along a dimension, users should prefer
2961 /// the [Tensor::cummin](Tensor::cummin) function, which is more high-level and designed for public use.
2962 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2963
2964 /// Computes the cumulative maximum of elements along a dimension.
2965 ///
2966 /// # Arguments
2967 ///
2968 /// * `tensor` - The tensor to compute the cumulative maximum of.
2969 /// * `dim` - The dimension along which to compute the cumulative maximum.
2970 ///
2971 /// # Returns
2972 ///
2973 /// A tensor with the same shape as the input tensor, where each element is the maximum
2974 /// of all elements up to and including that position along the specified dimension.
2975 ///
2976 /// # Remarks
2977 ///
2978 /// This is a low-level function used internally by the library to call different backend functions
2979 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2980 /// or use this function directly.
2981 ///
2982 /// For computing the cumulative maximum of elements along a dimension, users should prefer
2983 /// the [Tensor::cummax](Tensor::cummax) function, which is more high-level and designed for public use.
2984 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2985 /// Element-wise equality between two tensors.
2986 ///
2987 /// # Arguments
2988 ///
2989 /// * `lhs` - The left hand side tensor.
2990 /// * `rhs` - The right hand side tensor.
2991 ///
2992 /// # Returns
2993 ///
2994 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2995 /// corresponding elements of the input tensors are equal, and false otherwise.
2996 ///
2997 /// # Remarks
2998 ///
2999 /// This is a low-level function used internally by the library to call different backend functions
3000 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3001 /// or use this function directly.
3002 ///
3003 /// For element-wise equality between two tensors, users should prefer the [Tensor::equal_elem](Tensor::equal_elem)
3004 /// function, which is more high-level and designed for public use.
3005 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3006
3007 /// Element-wise non-equality between two tensors.
3008 ///
3009 /// # Arguments
3010 ///
3011 /// * `lhs` - The left hand side tensor.
3012 /// * `rhs` - The right hand side tensor.
3013 ///
3014 /// # Returns
3015 ///
3016 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3017 /// corresponding elements of the input tensors are equal, and false otherwise.
3018 ///
3019 /// # Remarks
3020 ///
3021 /// This is a low-level function used internally by the library to call different backend functions
3022 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3023 /// or use this function directly.
3024 ///
3025 /// For element-wise non-equality between two tensors, users should prefer the [Tensor::not_equal_elem](Tensor::not_equal_elem)
3026 /// function, which is more high-level and designed for public use.
3027 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3028
3029 /// Element-wise greater than comparison between two tensors.
3030 ///
3031 /// # Arguments
3032 ///
3033 /// * `lhs` - The left hand side tensor.
3034 /// * `rhs` - The right hand side tensor.
3035 ///
3036 /// # Returns
3037 ///
3038 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3039 /// corresponding element of the left hand side tensor is greater than the corresponding element
3040 /// of the right hand side tensor, and false otherwise.
3041 ///
3042 /// # Remarks
3043 ///
3044 /// This is a low-level function used internally by the library to call different backend functions
3045 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3046 /// or use this function directly.
3047 ///
3048 /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
3049 /// which is more high-level and designed for public use.
3050 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3051
3052 /// Element-wise greater than comparison between a tensor and a scalar.
3053 ///
3054 /// # Arguments
3055 ///
3056 /// * `lhs` - The left hand side tensor.
3057 /// * `rhs` - The right hand side scalar.
3058 ///
3059 /// # Returns
3060 ///
3061 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3062 /// corresponding element of the left hand side tensor is greater than the right hand side
3063 /// scalar, and false otherwise.
3064 ///
3065 /// # Remarks
3066 ///
3067 /// This is a low-level function used internally by the library to call different backend functions
3068 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3069 /// or use this function directly.
3070 ///
3071 /// For element-wise greater than comparison between a tensor and a scalar, users should prefer
3072 /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
3073 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3074
3075 /// Element-wise greater than or equal comparison between two tensors.
3076 ///
3077 /// # Arguments
3078 ///
3079 /// * `lhs` - The left hand side tensor.
3080 /// * `rhs` - The right hand side tensor.
3081 ///
3082 /// # Returns
3083 ///
3084 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3085 /// corresponding element of the left hand side tensor is greater than or equal to the
3086 /// corresponding element of the right hand side tensor, and false otherwise.
3087 ///
3088 /// # Remarks
3089 ///
3090 /// This is a low-level function used internally by the library to call different backend functions
3091 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3092 /// or use this function directly.
3093 ///
3094 /// For element-wise greater than or equal comparison between two tensors, users should prefer
3095 /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
3096 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3097
3098 /// Element-wise greater than or equal comparison between a tensor and a scalar.
3099 ///
3100 /// # Arguments
3101 ///
3102 /// * `lhs` - The left hand side tensor.
3103 /// * `rhs` - The right hand side scalar.
3104 ///
3105 /// # Returns
3106 ///
3107 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3108 /// corresponding element of the left hand side tensor is greater than or equal to the right
3109 /// hand side scalar, and false otherwise.
3110 ///
3111 /// # Remarks
3112 ///
3113 /// This is a low-level function used internally by the library to call different backend functions
3114 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3115 /// or use this function directly.
3116 ///
3117 /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
3118 /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
3119 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3120
3121 /// Element-wise less than comparison between two tensors.
3122 ///
3123 /// # Arguments
3124 ///
3125 /// * `lhs` - The left hand side tensor.
3126 /// * `rhs` - The right hand side tensor.
3127 ///
3128 /// # Returns
3129 ///
3130 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3131 /// corresponding element of the left hand side tensor is less than the corresponding element of
3132 /// the right hand side tensor, and false otherwise.
3133 ///
3134 /// # Remarks
3135 ///
3136 /// This is a low-level function used internally by the library to call different backend functions
3137 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3138 /// or use this function directly.
3139 ///
3140 /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
3141 /// which is more high-level and designed for public use.
3142 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3143
3144 /// Element-wise less than comparison between a tensor and a scalar.
3145 ///
3146 /// # Arguments
3147 ///
3148 /// * `lhs` - The left hand side tensor.
3149 /// * `rhs` - The right hand side scalar.
3150 ///
3151 /// # Returns
3152 ///
3153 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3154 /// corresponding element of the left hand side tensor is less than the right hand side scalar,
3155 /// and false otherwise.
3156 ///
3157 /// # Remarks
3158 ///
3159 /// This is a low-level function used internally by the library to call different backend functions
3160 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3161 /// or use this function directly.
3162 ///
3163 /// For element-wise less than comparison between a tensor and a scalar, users should prefer
3164 /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
3165 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3166
3167 /// Element-wise less than or equal comparison between two tensors.
3168 ///
3169 /// # Arguments
3170 ///
3171 /// * `lhs` - The left hand side tensor.
3172 /// * `rhs` - The right hand side tensor.
3173 ///
3174 /// # Returns
3175 ///
3176 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
3177 /// corresponding element of the left hand side tensor is less than or equal to the corresponding
3178 /// element of the right hand side tensor, and false otherwise.
3179 ///
3180 /// # Remarks
3181 ///
3182 /// This is a low-level function used internally by the library to call different backend functions
3183 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3184 /// or use this function directly.
3185 ///
3186 /// For element-wise less than or equal comparison between two tensors, users should prefer
3187 /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use.
3188 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
3189
3190 /// Element-wise less than or equal comparison between a tensor and a scalar.
3191 ///
3192 /// # Arguments
3193 ///
3194 /// * `lhs` - The left hand side tensor.
3195 /// * `rhs` - The right hand side scalar.
3196 ///
3197 /// # Returns
3198 ///
3199 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
3200 /// corresponding element of the left hand side tensor is less than or equal to the right hand
3201 /// side scalar, and false otherwise.
3202 ///
3203 /// # Remarks
3204 ///
3205 /// This is a low-level function used internally by the library to call different backend functions
3206 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3207 /// or use this function directly.
3208 ///
3209 /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
3210 /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
3211 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
3212
3213 /// Selects elements from a tensor based on a boolean mask.
3214 ///
3215 /// # Arguments
3216 ///
3217 /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
3218 /// * `mask` - The boolean mask to use for selecting elements.
3219 /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
3220 ///
3221 /// # Returns
3222 ///
3223 /// A tensor with the same shape as the input tensors, where each element is taken from the
3224 /// corresponding element of the left hand side tensor if the corresponding element of the mask
3225 /// is true, and from the corresponding element of the right hand side tensor otherwise.
3226 ///
3227 /// # Remarks
3228 ///
3229 /// This is a low-level function used internally by the library to call different backend functions
3230 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3231 /// or use this function directly.
3232 ///
3233 /// For selecting elements from a tensor based on a boolean mask, users should prefer the
3234 /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
3235 fn mask_where(
3236 tensor: Self::Primitive,
3237 mask: B::BoolTensorPrimitive,
3238 source: Self::Primitive,
3239 ) -> Self::Primitive;
3240
3241 /// Fills elements of a tensor based on a boolean mask.
3242 ///
3243 /// # Arguments
3244 ///
3245 /// * `tensor` - The tensor where will be overwritten with the value
3246 /// when the corresponding element of the mask is true.
3247 /// * `mask` - The boolean mask to use for filling elements.
3248 /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
3249 ///
3250 /// # Returns
3251 ///
3252 /// A tensor with the same shape as the input tensors, where each element is taken from the
3253 /// corresponding element unmodified if the corresponding element of the mask is false, and
3254 /// filled with the value otherwise.
3255 ///
3256 /// # Remarks
3257 ///
3258 /// This is a low-level function used internally by the library to call different backend functions
3259 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3260 /// or use this function directly.
3261 ///
3262 /// For filling elements of a tensor based on a boolean mask, users should prefer the
3263 /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
3264 fn mask_fill(
3265 tensor: Self::Primitive,
3266 mask: B::BoolTensorPrimitive,
3267 value: Self::Elem,
3268 ) -> Self::Primitive;
3269
3270 /// Gathers elements from a tensor along an axis.
3271 ///
3272 /// # Arguments
3273 ///
3274 /// * `dim` - The axis along which to gather elements.
3275 /// * `tensor` - The tensor to gather elements from.
3276 /// * `indices` - The indices of the elements to gather.
3277 ///
3278 /// # Returns
3279 ///
3280 /// A tensor with the same shape as the input tensor, where each element is taken from the
3281 /// corresponding element of the input tensor at the corresponding index along the specified axis.
3282 ///
3283 /// # Remarks
3284 ///
3285 /// This is a low-level function used internally by the library to call different backend functions
3286 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3287 /// or use this function directly.
3288 ///
3289 /// For gathering elements from a tensor along an axis, users should prefer the
3290 /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use.
3291 fn gather(
3292 dim: usize,
3293 tensor: Self::Primitive,
3294 indices: B::IntTensorPrimitive,
3295 ) -> Self::Primitive;
3296
3297 /// Scatters elements into a tensor along an axis.
3298 ///
3299 /// # Arguments
3300 ///
3301 /// * `dim` - The axis along which to scatter elements.
3302 /// * `tensor` - The tensor to scatter elements into.
3303 /// * `indices` - The indices of the elements to scatter.
3304 /// * `values` - The values to scatter into the tensor.
3305 ///
3306 /// # Returns
3307 ///
3308 /// A tensor with the same shape as the input tensor, where each element is taken from the
3309 /// corresponding element of the input tensor at the corresponding index along the specified axis,
3310 /// except for the elements at the specified indices, which are taken from the corresponding
3311 /// element of the values tensor.
3312 ///
3313 /// # Remarks
3314 ///
3315 /// This is a low-level function used internally by the library to call different backend functions
3316 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3317 /// or use this function directly.
3318 ///
3319 /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function,
3320 /// which is more high-level and designed for public use.
3321 fn scatter(
3322 dim: usize,
3323 tensor: Self::Primitive,
3324 indices: B::IntTensorPrimitive,
3325 values: Self::Primitive,
3326 ) -> Self::Primitive;
3327
3328 /// Gets the indices of the maximum elements of a tensor along an axis.
3329 ///
3330 /// # Arguments
3331 ///
3332 /// * `dim` - The axis along which to get the indices of the maximum elements.
3333 /// * `tensor` - The tensor to get the indices of the maximum elements from.
3334 ///
3335 /// # Returns
3336 ///
3337 /// A tensor with the same shape as the input tensor, where each element is the index of the
3338 /// maximum element of the input tensor at the corresponding index along the specified axis.
3339 ///
3340 /// # Remarks
3341 ///
3342 /// This is a low-level function used internally by the library to call different backend functions
3343 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3344 /// or use this function directly.
3345 ///
3346 /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
3347 /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
3348 fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3349
3350 /// Gets the indices of the minimum elements of a tensor along an axis.
3351 ///
3352 /// # Arguments
3353 ///
3354 /// * `dim` - The axis along which to get the indices of the minimum elements.
3355 /// * `tensor` - The tensor to get the indices of the minimum elements from.
3356 ///
3357 /// # Returns
3358 ///
3359 /// A tensor with the same shape as the input tensor, where each element is the index of the
3360 /// minimum element of the input tensor at the corresponding index along the specified axis.
3361 ///
3362 /// # Remarks
3363 ///
3364 /// This is a low-level function used internally by the library to call different backend functions
3365 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3366 /// or use this function directly.
3367 ///
3368 /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
3369 /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
3370 fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3371
3372 /// Gets the maximum elements of a tensor along an axis.
3373 ///
3374 /// # Arguments
3375 ///
3376 /// * `dim` - The axis along which to get the maximum elements.
3377 ///
3378 /// # Returns
3379 ///
3380 /// A single-element tensor containing the maximum element of the input tensor.
3381 ///
3382 /// # Remarks
3383 ///
3384 /// This is a low-level function used internally by the library to call different backend functions
3385 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3386 /// or use this function directly.
3387 ///
3388 /// For getting the maximum elements of a tensor along an axis, users should prefer the
3389 /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
3390 fn max(tensor: Self::Primitive) -> Self::Primitive;
3391
3392 /// Gets the maximum elements of a tensor along an axis.
3393 ///
3394 /// # Arguments
3395 ///
3396 /// * `tensor` - The tensor to get the maximum elements from.
3397 /// * `dim` - The axis along which to get the maximum elements.
3398 ///
3399 /// # Returns
3400 ///
3401 /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
3402 /// Each element is the maximum element of the corresponding input dim.
3403 ///
3404 /// # Remarks
3405 ///
3406 /// This is a low-level function used internally by the library to call different backend functions
3407 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3408 /// or use this function directly.
3409 ///
3410 /// For getting the maximum elements of a tensor along an axis, users should prefer the
3411 /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
3412 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3413
3414 /// Gets the maximum elements of a tensor along an axis.
3415 ///
3416 /// # Arguments
3417 ///
3418 /// * `tensor` - The tensor to get the maximum elements from.
3419 /// * `dim` - The axis along which to get the maximum elements.
3420 ///
3421 /// # Returns
3422 ///
3423 /// A tuple containing the maximum element of the input tensor, and a tensor with the same shape
3424 /// as the input tensor, where each element is the index of the maximum element of the input tensor
3425 /// at the corresponding index along the specified axis.
3426 ///
3427 /// # Remarks
3428 ///
3429 /// This is a low-level function used internally by the library to call different backend functions
3430 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3431 /// or use this function directly.
3432 ///
3433 /// For getting the maximum elements of a tensor along an axis, users should prefer the
3434 /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
3435 fn max_dim_with_indices(
3436 tensor: Self::Primitive,
3437 dim: usize,
3438 ) -> (Self::Primitive, B::IntTensorPrimitive);
3439
3440 /// Gets the maximum elements of a tensor along an axis.
3441 ///
3442 /// # Arguments
3443 ///
3444 /// * `dim` - The axis along which to get the maximum elements.
3445 ///
3446 /// # Returns
3447 ///
3448 /// A single-element tensor containing the maximum absolute element of the input tensor.
3449 ///
3450 /// # Remarks
3451 ///
3452 /// This is a low-level function used internally by the library to call different backend functions
3453 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3454 /// or use this function directly.
3455 ///
3456 /// For getting the maximum absolute elements of a tensor, users should prefer the
3457 /// [Tensor::max_abs](Tensor::max_abs) function, which is more high-level and designed for public use.
3458 fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
3459
3460 /// Gets the maximum elements of a tensor along an axis.
3461 ///
3462 /// # Arguments
3463 ///
3464 /// * `tensor` - The tensor to get the maximum elements from.
3465 /// * `dim` - The axis along which to get the maximum elements.
3466 ///
3467 /// # Returns
3468 ///
3469 /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
3470 /// Each element is the maximum absolute element of the corresponding input dim.
3471 ///
3472 /// # Remarks
3473 ///
3474 /// This is a low-level function used internally by the library to call different backend functions
3475 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3476 /// or use this function directly.
3477 ///
3478 /// For getting the maximum elements of a tensor along an axis, users should prefer the
3479 /// [Tensor::max_abs_dim](Tensor::max_abs_dim) function, which is more high-level and designed for public use.
3480 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3481
3482 /// Gets the minimum elements of a tensor along an axis.
3483 ///
3484 /// # Arguments
3485 ///
3486 /// * `tensor` - The tensor to get the minimum elements from.
3487 ///
3488 /// # Returns
3489 ///
3490 /// A single-element tensor containing the minimum element of the input tensor.
3491 ///
3492 /// # Remarks
3493 ///
3494 /// This is a low-level function used internally by the library to call different backend functions
3495 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3496 /// or use this function directly.
3497 ///
3498 /// For getting the minimum elements of a tensor along an axis, users should prefer the
3499 /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
3500 fn min(tensor: Self::Primitive) -> Self::Primitive;
3501
3502 /// Gets the minimum elements of a tensor along an axis.
3503 ///
3504 /// # Arguments
3505 ///
3506 /// * `tensor` - The tensor to get the minimum elements from.
3507 /// * `dim` - The axis along which to get the minimum elements.
3508 ///
3509 /// # Returns
3510 ///
3511 /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.
3512 /// Each element is the minimum element of the corresponding input dim.
3513 ///
3514 /// # Remarks
3515 ///
3516 /// This is a low-level function used internally by the library to call different backend functions
3517 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3518 /// or use this function directly.
3519 ///
3520 /// For getting the minimum elements of a tensor along an axis, users should prefer the
3521 /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use.
3522 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3523
    /// Gets the minimum elements and indices of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tuple `(values, indices)`: `values` holds the minimum elements of the input along
    /// `dim`, and `indices` is an integer tensor holding the position of each minimum along `dim`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum elements of a tensor along an axis, users should prefer the
    /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use.
    // NOTE(review): the previous doc said both outputs have "the same shape as the input tensor";
    // presumably, as with `min_dim`, the reduced `dim` has size 1 — confirm against the backends.
    fn min_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, B::IntTensorPrimitive);
3548
    /// Clamp the tensor between the given min and max values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped between the given min and max values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor between the given min and max values, users should prefer the
    /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use.
    fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
3568
    /// Clamps a tensor from below by a minimum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value (lower bound).
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped from below by the given `min` value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor under a minimum value, users should prefer the
    /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use.
    fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
3588
    /// Clamps a tensor from above by a maximum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `max` - The maximum value (upper bound).
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped from above by the given `max` value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor over a maximum value, users should prefer the
    /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use.
    fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
3608
    /// Calculate the absolute value of every element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to apply abs to.
    ///
    /// # Returns
    ///
    /// A tensor with absolute values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function,
    /// which is more high-level and designed for public use.
    fn abs(tensor: Self::Primitive) -> Self::Primitive;
3628
    /// Element-wise power of a tensor, using floating-point exponentiation with a tensor of
    /// exponents.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to (the base).
    /// * `rhs` - The tensor holding the exponents.
    ///
    /// # Returns
    ///
    /// A new tensor with `lhs` raised element-wise to the power given by `rhs`.
    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3635
    /// Element-wise power of a tensor, with exponents taken from another tensor of the same kind.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to (the base).
    /// * `rhs` - The tensor holding the exponents.
    ///
    /// # Returns
    ///
    /// A new tensor with `lhs` raised element-wise to the power given by `rhs`.
    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3642
    /// Element-wise power of a tensor to a scalar float exponent.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to (the base).
    /// * `rhs` - The scalar exponent; any value implementing `ElementConversion`.
    ///
    /// # Returns
    ///
    /// A new tensor with every element of `lhs` raised to the power `rhs`.
    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3649
    /// Element-wise power of a tensor to a scalar int exponent.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply power to (the base).
    /// * `rhs` - The scalar exponent; any value implementing `ElementConversion`.
    ///
    /// # Returns
    ///
    /// A new tensor with every element of `lhs` raised to the power `rhs`.
    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3656
    /// Create a tensor filled with values sampled from a distribution.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the output tensor.
    /// * `distribution` - The distribution used to sample.
    /// * `device` - The device to use.
    ///
    /// # Returns
    ///
    /// A new tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::random](Tensor::random) function,
    /// which is more high-level and designed for public use.
    fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
3678
    /// Sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order; `true` sorts from largest to smallest.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::sort](Tensor::sort) function,
    /// which is more high-level and designed for public use.
    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
3701
    /// Sort the elements of the input `tensor` by value along a given dimension, also returning
    /// the permutation indices.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order; `true` sorts from largest to smallest.
    ///
    /// # Returns
    ///
    /// A tuple `(values, indices)`. Both have the same shape as the input tensor: `values` holds
    /// the elements sorted by value, and the integer `indices` map each sorted element back to
    /// its position in the original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For sorting the elements of a tensor, users should prefer the
    /// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level
    /// and designed for public use.
    fn sort_with_indices(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive);
3730
    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order; `true` sorts from largest to smallest.
    ///
    /// # Returns
    ///
    /// An integer tensor with the same shape as the input tensor, whose values are the indices
    /// that map each sorted position back to the original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::argsort](Tensor::argsort) function,
    /// which is more high-level and designed for public use.
    fn argsort(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> <Int as TensorKind<B>>::Primitive;
3757
    /// Applies the matrix multiplication operation.
    ///
    /// ```math
    /// C = AB
    /// ```
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left-hand operand `A`.
    /// * `rhs` - The right-hand operand `B`.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3764}
3765
3766impl<B: Backend> Numeric<B> for Int {
3767 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3768 B::int_add(lhs, rhs)
3769 }
3770 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3771 B::int_add_scalar(lhs, rhs.elem())
3772 }
3773 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3774 B::int_sub(lhs, rhs)
3775 }
3776 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3777 B::int_sub_scalar(lhs, rhs.elem())
3778 }
3779 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3780 B::int_div(lhs, rhs)
3781 }
3782 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3783 B::int_div_scalar(lhs, rhs.elem())
3784 }
3785 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3786 B::int_remainder(lhs, rhs)
3787 }
3788 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3789 B::int_remainder_scalar(lhs, rhs.elem())
3790 }
3791 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3792 B::int_mul(lhs, rhs)
3793 }
3794 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3795 B::int_mul_scalar(lhs, rhs.elem())
3796 }
3797 fn neg(tensor: Self::Primitive) -> Self::Primitive {
3798 B::int_neg(tensor)
3799 }
3800 fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3801 B::int_zeros(shape, device, dtype.into())
3802 }
3803 fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
3804 B::int_ones(shape, device, dtype.into())
3805 }
3806
3807 fn sum(tensor: Self::Primitive) -> Self::Primitive {
3808 B::int_sum(tensor)
3809 }
3810
3811 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3812 B::int_sum_dim(tensor, dim)
3813 }
3814
3815 fn prod(tensor: Self::Primitive) -> Self::Primitive {
3816 B::int_prod(tensor)
3817 }
3818
3819 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3820 B::int_prod_dim(tensor, dim)
3821 }
3822
3823 fn mean(tensor: Self::Primitive) -> Self::Primitive {
3824 B::int_mean(tensor)
3825 }
3826 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3827 B::int_mean_dim(tensor, dim)
3828 }
3829 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3830 B::int_cumsum(tensor, dim)
3831 }
3832 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3833 B::int_cumprod(tensor, dim)
3834 }
3835
3836 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3837 B::int_cummin(tensor, dim)
3838 }
3839
3840 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3841 B::int_cummax(tensor, dim)
3842 }
3843
3844 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3845 B::int_equal_elem(lhs, rhs)
3846 }
3847 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3848 B::int_not_equal_elem(lhs, rhs)
3849 }
3850 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3851 B::int_greater(lhs, rhs)
3852 }
3853
3854 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3855 B::int_greater_elem(lhs, rhs)
3856 }
3857
3858 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3859 B::int_greater_equal(lhs, rhs)
3860 }
3861
3862 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3863 B::int_greater_equal_elem(lhs, rhs)
3864 }
3865
3866 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3867 B::int_lower(lhs, rhs)
3868 }
3869
3870 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3871 B::int_lower_elem(lhs, rhs)
3872 }
3873
3874 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3875 B::int_lower_equal(lhs, rhs)
3876 }
3877
3878 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3879 B::int_lower_equal_elem(lhs, rhs)
3880 }
3881
3882 fn mask_where(
3883 tensor: Self::Primitive,
3884 mask: B::BoolTensorPrimitive,
3885 source: Self::Primitive,
3886 ) -> Self::Primitive {
3887 B::int_mask_where(tensor, mask, source)
3888 }
3889
3890 fn mask_fill(
3891 tensor: Self::Primitive,
3892 mask: B::BoolTensorPrimitive,
3893 value: Self::Elem,
3894 ) -> Self::Primitive {
3895 B::int_mask_fill(tensor, mask, value)
3896 }
3897
3898 fn gather(
3899 dim: usize,
3900 tensor: Self::Primitive,
3901 indices: B::IntTensorPrimitive,
3902 ) -> Self::Primitive {
3903 B::int_gather(dim, tensor, indices)
3904 }
3905
3906 fn scatter(
3907 dim: usize,
3908 tensor: Self::Primitive,
3909 indices: B::IntTensorPrimitive,
3910 values: Self::Primitive,
3911 ) -> Self::Primitive {
3912 B::int_scatter(dim, tensor, indices, values)
3913 }
3914
3915 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3916 B::int_argmax(tensor, dim)
3917 }
3918
3919 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3920 B::int_argmin(tensor, dim)
3921 }
3922
3923 fn max(tensor: Self::Primitive) -> Self::Primitive {
3924 B::int_max(tensor)
3925 }
3926
3927 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3928 B::int_max_dim(tensor, dim)
3929 }
3930
3931 fn max_dim_with_indices(
3932 tensor: Self::Primitive,
3933 dim: usize,
3934 ) -> (Self::Primitive, IntTensor<B>) {
3935 B::int_max_dim_with_indices(tensor, dim)
3936 }
3937
3938 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
3939 B::int_max_abs(tensor)
3940 }
3941
3942 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3943 B::int_max_abs_dim(tensor, dim)
3944 }
3945
3946 fn min(tensor: Self::Primitive) -> Self::Primitive {
3947 B::int_min(tensor)
3948 }
3949
3950 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3951 B::int_min_dim(tensor, dim)
3952 }
3953
3954 fn min_dim_with_indices(
3955 tensor: Self::Primitive,
3956 dim: usize,
3957 ) -> (Self::Primitive, IntTensor<B>) {
3958 B::int_min_dim_with_indices(tensor, dim)
3959 }
3960
3961 fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
3962 B::int_clamp(tensor, min, max)
3963 }
3964
3965 fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
3966 B::int_clamp_min(tensor, min)
3967 }
3968
3969 fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
3970 B::int_clamp_max(tensor, max)
3971 }
3972
3973 fn abs(tensor: Self::Primitive) -> Self::Primitive {
3974 B::int_abs(tensor)
3975 }
3976
3977 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3978 B::int_powf(lhs, B::int_into_float(rhs))
3979 }
3980
3981 fn powf_scalar<E: ElementConversion>(
3982 lhs: Self::Primitive,
3983 rhs: E,
3984 ) -> <Int as TensorKind<B>>::Primitive {
3985 B::int_powf_scalar(lhs, rhs.elem())
3986 }
3987
3988 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3989 B::int_powi(lhs, rhs)
3990 }
3991
3992 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3993 B::int_powi_scalar(lhs, rhs.elem())
3994 }
3995
3996 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
3997 B::int_random(shape, distribution, device)
3998 }
3999
4000 fn sign(tensor: Self::Primitive) -> Self::Primitive {
4001 B::int_sign(tensor)
4002 }
4003
4004 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
4005 B::int_sort(tensor, dim, descending)
4006 }
4007
4008 fn sort_with_indices(
4009 tensor: Self::Primitive,
4010 dim: usize,
4011 descending: bool,
4012 ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
4013 B::int_sort_with_indices(tensor, dim, descending)
4014 }
4015
4016 fn argsort(
4017 tensor: Self::Primitive,
4018 dim: usize,
4019 descending: bool,
4020 ) -> <Int as TensorKind<B>>::Primitive {
4021 B::int_argsort(tensor, dim, descending)
4022 }
4023
4024 /// Applies the matrix multiplication operation.
4025 ///
4026 /// `C = AB`
4027 ///
4028 /// # Panics
4029 ///
4030 /// If the two tensors don't have a compatible shape.
4031 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4032 B::int_matmul(lhs, rhs)
4033 }
4034}
4035
4036impl<B: Backend> Numeric<B> for Float {
4037 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4038 q_bin_ops!(lhs, rhs, float_add, q_add)
4039 }
4040
4041 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4042 match lhs {
4043 TensorPrimitive::Float(lhs) => {
4044 TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
4045 }
4046 TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs.elem()),
4047 }
4048 }
4049
4050 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4051 q_bin_ops!(lhs, rhs, float_sub, q_sub)
4052 }
4053
4054 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4055 match lhs {
4056 TensorPrimitive::Float(lhs) => {
4057 TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
4058 }
4059 TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs.elem()),
4060 }
4061 }
4062
4063 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4064 q_bin_ops!(lhs, rhs, float_div, q_div)
4065 }
4066
4067 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4068 match lhs {
4069 TensorPrimitive::Float(lhs) => {
4070 TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
4071 }
4072 TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs.elem()),
4073 }
4074 }
4075 fn remainder(
4076 lhs: Self::Primitive,
4077 rhs: Self::Primitive,
4078 ) -> <Float as TensorKind<B>>::Primitive {
4079 TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))
4080 }
4081
4082 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4083 TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs.elem()))
4084 }
4085
4086 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
4087 q_bin_ops!(lhs, rhs, float_mul, q_mul)
4088 }
4089
4090 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4091 match lhs {
4092 TensorPrimitive::Float(lhs) => {
4093 TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
4094 }
4095 TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs.elem()),
4096 }
4097 }
4098 fn neg(tensor: Self::Primitive) -> Self::Primitive {
4099 match tensor {
4100 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
4101 TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),
4102 }
4103 }
4104 fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
4105 TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))
4106 }
4107 fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive {
4108 TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))
4109 }
4110
4111 fn sum(tensor: Self::Primitive) -> Self::Primitive {
4112 match tensor {
4113 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
4114 TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),
4115 }
4116 }
4117
4118 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4119 match tensor {
4120 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
4121 TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),
4122 }
4123 }
4124
4125 fn prod(tensor: Self::Primitive) -> Self::Primitive {
4126 match tensor {
4127 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
4128 TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),
4129 }
4130 }
4131
4132 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4133 match tensor {
4134 TensorPrimitive::Float(tensor) => {
4135 TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
4136 }
4137 TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),
4138 }
4139 }
4140
4141 fn mean(tensor: Self::Primitive) -> Self::Primitive {
4142 match tensor {
4143 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
4144 TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),
4145 }
4146 }
4147
4148 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4149 match tensor {
4150 TensorPrimitive::Float(tensor) => {
4151 TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
4152 }
4153 TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),
4154 }
4155 }
4156
4157 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4158 match tensor {
4159 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),
4160 TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),
4161 }
4162 }
4163
4164 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4165 match tensor {
4166 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),
4167 TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),
4168 }
4169 }
4170
4171 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4172 match tensor {
4173 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
4174 TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
4175 }
4176 }
4177
4178 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4179 match tensor {
4180 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),
4181 TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),
4182 }
4183 }
4184
4185 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4186 B::float_equal_elem(lhs.tensor(), rhs)
4187 }
4188 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4189 B::float_not_equal_elem(lhs.tensor(), rhs)
4190 }
4191 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4192 B::float_greater(lhs.tensor(), rhs.tensor())
4193 }
4194
4195 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4196 B::float_greater_elem(lhs.tensor(), rhs)
4197 }
4198
4199 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4200 B::float_greater_equal(lhs.tensor(), rhs.tensor())
4201 }
4202
4203 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4204 B::float_greater_equal_elem(lhs.tensor(), rhs)
4205 }
4206
4207 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4208 B::float_lower(lhs.tensor(), rhs.tensor())
4209 }
4210
4211 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4212 B::float_lower_elem(lhs.tensor(), rhs)
4213 }
4214
4215 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4216 B::float_lower_equal(lhs.tensor(), rhs.tensor())
4217 }
4218
4219 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4220 B::float_lower_equal_elem(lhs.tensor(), rhs)
4221 }
4222
4223 fn mask_where(
4224 tensor: Self::Primitive,
4225 mask: B::BoolTensorPrimitive,
4226 source: Self::Primitive,
4227 ) -> Self::Primitive {
4228 TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))
4229 }
4230
4231 fn mask_fill(
4232 tensor: Self::Primitive,
4233 mask: B::BoolTensorPrimitive,
4234 value: Self::Elem,
4235 ) -> Self::Primitive {
4236 TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))
4237 }
4238
4239 fn gather(
4240 dim: usize,
4241 tensor: Self::Primitive,
4242 indices: B::IntTensorPrimitive,
4243 ) -> Self::Primitive {
4244 match tensor {
4245 TensorPrimitive::Float(tensor) => {
4246 TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
4247 }
4248 TensorPrimitive::QFloat(tensor) => {
4249 TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
4250 }
4251 }
4252 }
4253
4254 fn scatter(
4255 dim: usize,
4256 tensor: Self::Primitive,
4257 indices: B::IntTensorPrimitive,
4258 values: Self::Primitive,
4259 ) -> Self::Primitive {
4260 TensorPrimitive::Float(B::float_scatter(
4261 dim,
4262 tensor.tensor(),
4263 indices,
4264 values.tensor(),
4265 ))
4266 }
4267
4268 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
4269 match tensor {
4270 TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
4271 TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
4272 }
4273 }
4274
4275 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
4276 match tensor {
4277 TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
4278 TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
4279 }
4280 }
4281
4282 fn max(tensor: Self::Primitive) -> Self::Primitive {
4283 match tensor {
4284 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
4285 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
4286 }
4287 }
4288
4289 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4290 match tensor {
4291 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
4292 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
4293 }
4294 }
4295
4296 fn max_dim_with_indices(
4297 tensor: Self::Primitive,
4298 dim: usize,
4299 ) -> (Self::Primitive, IntTensor<B>) {
4300 match tensor {
4301 TensorPrimitive::Float(tensor) => {
4302 let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
4303 (TensorPrimitive::Float(values), indices)
4304 }
4305 TensorPrimitive::QFloat(tensor) => {
4306 let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
4307 (TensorPrimitive::QFloat(values), indices)
4308 }
4309 }
4310 }
4311
4312 fn min(tensor: Self::Primitive) -> Self::Primitive {
4313 match tensor {
4314 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
4315 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
4316 }
4317 }
4318
4319 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4320 match tensor {
4321 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
4322 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
4323 }
4324 }
4325
4326 fn min_dim_with_indices(
4327 tensor: Self::Primitive,
4328 dim: usize,
4329 ) -> (Self::Primitive, IntTensor<B>) {
4330 match tensor {
4331 TensorPrimitive::Float(tensor) => {
4332 let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
4333 (TensorPrimitive::Float(values), indices)
4334 }
4335 TensorPrimitive::QFloat(tensor) => {
4336 let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
4337 (TensorPrimitive::QFloat(values), indices)
4338 }
4339 }
4340 }
4341
4342 fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
4343 match tensor {
4344 TensorPrimitive::Float(tensor) => {
4345 TensorPrimitive::Float(B::float_clamp(tensor, min, max))
4346 }
4347 TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),
4348 }
4349 }
4350
4351 fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
4352 match tensor {
4353 TensorPrimitive::Float(tensor) => {
4354 TensorPrimitive::Float(B::float_clamp_min(tensor, min))
4355 }
4356 TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),
4357 }
4358 }
4359
4360 fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
4361 match tensor {
4362 TensorPrimitive::Float(tensor) => {
4363 TensorPrimitive::Float(B::float_clamp_max(tensor, max))
4364 }
4365 TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),
4366 }
4367 }
4368
4369 fn abs(tensor: Self::Primitive) -> Self::Primitive {
4370 match tensor {
4371 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
4372 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
4373 }
4374 }
4375
4376 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4377 q_bin_ops!(lhs, rhs, float_powf, q_powf)
4378 }
4379
4380 fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4381 match lhs {
4382 TensorPrimitive::Float(lhs) => {
4383 TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4384 }
4385 TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs.elem()),
4386 }
4387 }
4388
4389 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4390 q_bin_ops!(lhs, rhs, float_powf, q_powf)
4391 }
4392
4393 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4394 match lhs {
4395 TensorPrimitive::Float(lhs) => {
4396 TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs.elem()))
4397 }
4398 TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs.elem()),
4399 }
4400 }
4401
4402 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
4403 TensorPrimitive::Float(B::float_random(shape, distribution, device))
4404 }
4405
4406 fn sign(tensor: Self::Primitive) -> Self::Primitive {
4407 TensorPrimitive::Float(B::float_sign(tensor.tensor()))
4408 }
4409
4410 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
4411 match tensor {
4412 TensorPrimitive::Float(tensor) => {
4413 TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
4414 }
4415 TensorPrimitive::QFloat(tensor) => {
4416 TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
4417 }
4418 }
4419 }
4420
4421 fn sort_with_indices(
4422 tensor: Self::Primitive,
4423 dim: usize,
4424 descending: bool,
4425 ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
4426 match tensor {
4427 TensorPrimitive::Float(tensor) => {
4428 let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
4429 (TensorPrimitive::Float(values), indices)
4430 }
4431 TensorPrimitive::QFloat(tensor) => {
4432 let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
4433 (TensorPrimitive::QFloat(values), indices)
4434 }
4435 }
4436 }
4437
4438 fn argsort(
4439 tensor: Self::Primitive,
4440 dim: usize,
4441 descending: bool,
4442 ) -> <Int as TensorKind<B>>::Primitive {
4443 match tensor {
4444 TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
4445 TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
4446 }
4447 }
4448
4449 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
4450 match tensor {
4451 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
4452 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
4453 }
4454 }
4455
4456 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4457 match tensor {
4458 TensorPrimitive::Float(tensor) => {
4459 TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
4460 }
4461 TensorPrimitive::QFloat(tensor) => {
4462 TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
4463 }
4464 }
4465 }
4466
4467 /// Applies the matrix multiplication operation.
4468 ///
4469 /// `C = AB`
4470 ///
4471 /// # Panics
4472 ///
4473 /// If the two tensors don't have a compatible shape.
4474 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4475 match (lhs, rhs) {
4476 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4477 TensorPrimitive::Float(B::float_matmul(lhs, rhs))
4478 }
4479 (lhs, rhs) => B::q_matmul(lhs, rhs),
4480 }
4481 }
4482}
4483
4484// Tensor + tensor
4485impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>
4486where
4487 K::Elem: Element,
4488{
4489 type Output = Self;
4490
4491 fn add(self, rhs: Self) -> Self::Output {
4492 Self::add(self, rhs)
4493 }
4494}
4495
4496// Tensor + scalar
4497impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>
4498 for Tensor<B, D, K>
4499where
4500 K::Elem: Element,
4501{
4502 type Output = Self;
4503
4504 fn add(self, other: E) -> Self::Output {
4505 Tensor::add_scalar(self, other)
4506 }
4507}
4508
4509// Scalar + tensor
4510macro_rules! impl_tensor_scalar_add {
4511 ($($t:ty),*) => {
4512 $(
4513 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t
4514 where
4515 K::Elem: Element,
4516 {
4517 type Output = Tensor<B, D, K>;
4518
4519 fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {
4520 Tensor::add_scalar(tensor, self)
4521 }
4522 }
4523 )*
4524 }
4525}
4526impl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);
4527
4528// Tensor - tensor
4529impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>
4530where
4531 K::Elem: Element,
4532{
4533 type Output = Self;
4534
4535 fn sub(self, rhs: Self) -> Self::Output {
4536 Tensor::sub(self, rhs)
4537 }
4538}
4539
4540// Tensor - scalar
4541impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>
4542 for Tensor<B, D, K>
4543where
4544 K::Elem: Element,
4545{
4546 type Output = Self;
4547
4548 fn sub(self, other: E) -> Self::Output {
4549 Tensor::sub_scalar(self, other)
4550 }
4551}
4552
4553// Scalar - tensor
4554macro_rules! impl_tensor_scalar_sub {
4555 ($($t:ty),*) => {
4556 $(
4557 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t
4558 where
4559 K::Elem: Element,
4560 {
4561 type Output = Tensor<B, D, K>;
4562
4563 fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {
4564 Tensor::add_scalar(Tensor::neg(tensor), self)
4565 }
4566 }
4567 )*
4568 }
4569}
4570impl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);
4571
4572// Tensor / tensor
4573impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>
4574where
4575 K::Elem: Element,
4576{
4577 type Output = Self;
4578
4579 fn div(self, rhs: Self) -> Self::Output {
4580 Tensor::div(self, rhs)
4581 }
4582}
4583
4584// Tensor / scalar
4585impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>
4586 for Tensor<B, D, K>
4587where
4588 K::Elem: Element,
4589{
4590 type Output = Self;
4591
4592 fn div(self, other: E) -> Self::Output {
4593 Tensor::div_scalar(self, other)
4594 }
4595}
4596
4597// Scalar / tensor (float only)
4598macro_rules! impl_tensor_scalar_div {
4599 ($($t:ty),*) => {
4600 $(
4601 impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t
4602 {
4603 type Output = Tensor<B, D>;
4604
4605 fn div(self, tensor: Tensor<B, D>) -> Self::Output {
4606 tensor.recip().mul_scalar(self)
4607 }
4608 }
4609 )*
4610 }
4611}
4612
4613impl_tensor_scalar_div!(f32, f64);
4614
4615// Tensor % tensor.
4616impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>
4617where
4618 K::Elem: Element,
4619{
4620 type Output = Self;
4621
4622 fn rem(self, rhs: Self) -> Self::Output {
4623 Tensor::remainder(self, rhs)
4624 }
4625}
4626
4627// Tensor % scalar.
4628impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>
4629 for Tensor<B, D, K>
4630where
4631 K::Elem: Element,
4632{
4633 type Output = Self;
4634
4635 fn rem(self, other: E) -> Self::Output {
4636 Tensor::remainder_scalar(self, other)
4637 }
4638}
4639
4640// Tensor * tensor.
4641impl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>
4642where
4643 K::Elem: Element,
4644{
4645 type Output = Self;
4646
4647 fn mul(self, rhs: Self) -> Self::Output {
4648 Tensor::mul(self, rhs)
4649 }
4650}
4651
4652// Tensor * scalar.
4653impl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>
4654 for Tensor<B, D, K>
4655where
4656 K::Elem: Element,
4657{
4658 type Output = Self;
4659
4660 fn mul(self, other: E) -> Self::Output {
4661 Tensor::mul_scalar(self, other)
4662 }
4663}
4664
4665macro_rules! impl_tensor_scalar_mul {
4666 ($($t:ty),*) => {
4667 $(
4668 impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t
4669 where
4670 K::Elem: Element,
4671 {
4672 type Output = Tensor<B, D, K>;
4673
4674 fn mul(self, other: Tensor<B, D, K>) -> Self::Output {
4675 Tensor::mul_scalar(other, self)
4676 }
4677 }
4678 )*
4679 }
4680}
4681
4682impl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);
4683
4684impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
4685where
4686 B: Backend,
4687 K: Numeric<B>,
4688 K::Elem: Element,
4689{
4690 type Output = Self;
4691
4692 fn neg(self) -> Self::Output {
4693 Tensor::neg(self)
4694 }
4695}