burn_tensor/tensor/api/numeric.rs
1use alloc::vec::Vec;
2
3use crate::alloc::borrow::ToOwned;
4
5use crate::TensorPrimitive;
6use crate::{
7 backend::Backend,
8 check,
9 check::TensorCheck,
10 ops::{Device, IntTensor},
11 BasicOps, Bool, Distribution, Element, ElementConversion, Float, Int, Shape, Tensor,
12 TensorKind,
13};
14
15impl<B, const D: usize, K> Tensor<B, D, K>
16where
17 B: Backend,
18 K: Numeric<B>,
19 K::Elem: Element,
20{
21 /// Applies element wise addition operation.
22 ///
23 /// `y = x2 + x1`
24 ///
25 /// # Arguments
26 ///
27 /// * `other` - The tensor to add.
28 ///
29 /// # Example
30 ///
31 /// ```rust
32 /// use burn_tensor::backend::Backend;
33 /// use burn_tensor::{Tensor, Shape};
34 ///
35 /// fn example<B: Backend>() {
36 /// let device = B::Device::default();
37 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
38 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
39 /// let tensor = tensor1 + tensor2;
40 /// println!("{tensor}");
41 /// // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
42 /// }
43 /// ```
44 #[allow(clippy::should_implement_trait)]
45 pub fn add(self, other: Self) -> Self {
46 check!(TensorCheck::binary_ops_ew("Add", &self, &other));
47 Self::new(K::add(self.primitive, other.primitive))
48 }
49
50 /// Applies element wise addition operation with a scalar.
51 ///
52 /// `y = x + s`
53 ///
54 /// # Arguments
55 ///
56 /// * `other` - The scalar to add, element wise.
57 ///
58 /// # Example
59 ///
60 /// ```rust
61 /// use burn_tensor::backend::Backend;
62 /// use burn_tensor::{Tensor, Shape};
63 ///
64 /// fn example<B: Backend>() {
65 /// let device = B::Device::default();
66 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
67 /// let scalar = 2.0;
68 /// let tensor = tensor + scalar;
69 /// println!("{tensor}");
70 /// // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
71 /// }
72 /// ```
73 pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
74 Self::new(K::add_scalar::<E>(self.primitive, other))
75 }
76
77 /// Applies element wise subtraction operation.
78 ///
79 /// `y = x2 - x1`
80 ///
81 /// # Arguments
82 ///
83 /// * `other` - The tensor to subtract.
84 ///
85 /// # Example
86 ///
87 /// ```rust
88 /// use burn_tensor::backend::Backend;
89 /// use burn_tensor::{Tensor, Shape};
90 ///
91 /// fn example<B: Backend>() {
92 /// let device = B::Device::default();
93 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
94 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
95 /// let tensor = tensor1 - tensor2;
96 /// println!("{tensor}");
97 /// // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
98 /// }
99 /// ```
100 #[allow(clippy::should_implement_trait)]
101 pub fn sub(self, other: Self) -> Self {
102 check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
103 Self::new(K::sub(self.primitive, other.primitive))
104 }
105
106 /// Applies element wise subtraction operation with a scalar.
107 ///
108 /// `y = x - s`
109 ///
110 /// # Arguments
111 ///
112 /// * `other` - The scalar to subtract, element wise.
113 ///
114 /// # Example
115 ///
116 /// ```rust
117 /// use burn_tensor::backend::Backend;
118 /// use burn_tensor::{Tensor, Shape};
119 ///
120 /// fn example<B: Backend>() {
121 /// let device = B::Device::default();
122 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
123 /// let scalar = 2.0;
124 /// let tensor = tensor - scalar;
125 /// println!("{tensor}");
126 /// // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
127 /// }
128 /// ```
129 pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
130 Self::new(K::sub_scalar::<E>(self.primitive, other))
131 }
132
133 /// Applies element wise division operation.
134 ///
135 /// `y = x2 / x1`
136 ///
137 /// # Arguments
138 ///
139 /// * `other` - The tensor to divide.
140 ///
141 /// # Example
142 ///
143 /// ```rust
144 /// use burn_tensor::backend::Backend;
145 /// use burn_tensor::{Tensor, Shape};
146 ///
147 /// fn example<B: Backend>() {
148 /// let device = B::Device::default();
149 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
150 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
151 /// let tensor = tensor1 / tensor2;
152 /// println!("{tensor}");
153 /// // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
154 /// }
155 /// ```
156 #[allow(clippy::should_implement_trait)]
157 pub fn div(self, other: Self) -> Self {
158 check!(TensorCheck::binary_ops_ew("Div", &self, &other));
159 Self::new(K::div(self.primitive, other.primitive))
160 }
161
162 /// Applies element wise division operation with a scalar.
163 ///
164 /// `y = x / s`
165 ///
166 /// # Arguments
167 ///
168 /// * `other` - The scalar to divide, element wise.
169 ///
170 /// # Example
171 ///
172 /// ```rust
173 /// use burn_tensor::backend::Backend;
174 /// use burn_tensor::{Tensor, Shape};
175 ///
176 /// fn example<B: Backend>() {
177 /// let device = B::Device::default();
178 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
179 /// let scalar = 2.0;
180 /// let tensor = tensor / scalar;
181 /// println!("{tensor}");
182 /// // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
183 /// }
184 /// ```
185 pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
186 Self::new(K::div_scalar::<E>(self.primitive, other))
187 }
188
189 /// Applies element wise the remainder operation with a scalar.
190 ///
191 /// `y = x2 % x1`
192 pub fn remainder(self, other: Self) -> Self {
193 Self::new(K::remainder(self.primitive, other.primitive))
194 }
195
196 /// Applies element wise the remainder operation with a scalar.
197 ///
198 /// `y = x2 % x1`
199 ///
200 /// # Arguments
201 ///
202 /// * `other` - The scalar to divide, element wise.
203 ///
204 /// # Example
205 ///
206 /// ```rust
207 /// use burn_tensor::backend::Backend;
208 /// use burn_tensor::{Tensor, Shape};
209 ///
210 /// fn example<B: Backend>() {
211 /// let device = B::Device::default();
212 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
213 /// let scalar = 2.0;
214 /// let tensor = tensor1 % scalar;
215 /// println!("{tensor}");
216 /// // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
217 /// }
218 /// ```
219 pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
220 Self::new(K::remainder_scalar::<E>(self.primitive, other))
221 }
222
223 /// Applies element wise multiplication operation.
224 ///
225 /// `y = x2 * x1`
226 ///
227 /// # Arguments
228 ///
229 /// * `other` - The tensor to multiply.
230 ///
231 /// # Example
232 ///
233 /// ```rust
234 /// use burn_tensor::backend::Backend;
235 /// use burn_tensor::{Tensor, Shape};
236 ///
237 /// fn example<B: Backend>() {
238 /// let device = B::Device::default();
239 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
240 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
241 /// let tensor = tensor1 * tensor2;
242 /// println!("{tensor}");
243 /// // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
244 /// }
245 /// ```
246 #[allow(clippy::should_implement_trait)]
247 pub fn mul(self, other: Self) -> Self {
248 check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
249 Self::new(K::mul(self.primitive, other.primitive))
250 }
251
252 /// Applies element wise multiplication operation with a scalar.
253 ///
254 /// `y = x * s`
255 ///
256 /// # Arguments
257 ///
258 /// * `other` - The scalar to multiply, element wise.
259 ///
260 /// # Example
261 ///
262 /// ```rust
263 /// use burn_tensor::backend::Backend;
264 /// use burn_tensor::{Tensor, Shape};
265 ///
266 /// fn example<B: Backend>() {
267 /// let device = B::Device::default();
268 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
269 /// let scalar = 2.0;
270 /// let tensor = tensor * scalar;
271 /// println!("{tensor}");
272 /// // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
273 /// }
274 /// ```
275 pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
276 Self::new(K::mul_scalar::<E>(self.primitive, other))
277 }
278
279 /// Switch sign of each element in the tensor.
280 ///
281 /// `y = -x`
282 ///
283 /// # Example
284 ///
285 /// ```rust
286 /// use burn_tensor::backend::Backend;
287 /// use burn_tensor::{Tensor, Shape};
288 ///
289 /// fn example<B: Backend>() {
290 /// let device = B::Device::default();
291 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
292 /// let tensor = -tensor;
293 /// println!("{tensor}");
294 /// // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
295 /// }
296 /// ```
297 #[allow(clippy::should_implement_trait)]
298 pub fn neg(self) -> Self {
299 Self::new(K::neg(self.primitive))
300 }
301
302 /// Returns the signs of the elements of the input tensor.
303 ///
304 /// # Example
305 ///
306 /// ```rust
307 /// use burn_tensor::backend::Backend;
308 /// use burn_tensor::{Tensor, Shape};
309 ///
310 /// fn example<B: Backend>() {
311 /// let device = B::Device::default();
312 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
313 /// let tensor = tensor.sign();
314 /// println!("{tensor}");
315 /// // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
316 /// }
317 /// ```
318 pub fn sign(self) -> Self {
319 Self::new(K::sign(self.primitive))
320 }
321
322 /// Create a tensor of the given shape where each element is zero.
323 ///
324 /// # Example
325 ///
326 /// ```rust
327 /// use burn_tensor::backend::Backend;
328 /// use burn_tensor::{Tensor, Shape};
329 ///
330 /// fn example<B: Backend>() {
331 /// let device = B::Device::default();
332 /// let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
333 /// println!("{tensor}");
334 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
335 /// }
336 /// ```
337 pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
338 let shape = shape.into();
339 check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
340 Self::new(K::zeros(shape, device))
341 }
342
343 /// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
344 ///
345 /// # Example
346 ///
347 /// ```rust
348 /// use burn_tensor::backend::Backend;
349 /// use burn_tensor::{Tensor, Shape};
350 ///
351 /// fn example<B: Backend>() {
352 /// let device = B::Device::default();
353 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
354 /// let tensor = tensor.zeros_like();
355 /// println!("{tensor}");
356 /// // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
357 /// }
358 /// ```
359 pub fn zeros_like(&self) -> Self {
360 Self::zeros(self.shape(), &self.device())
361 }
362
363 /// Create a tensor of the given shape where each element is one.
364 ///
365 /// # Example
366 ///
367 /// ```rust
368 /// use burn_tensor::backend::Backend;
369 /// use burn_tensor::{Tensor, Shape};
370 ///
371 /// fn example<B: Backend>() {
372 /// let device = B::Device::default();
373 /// let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
374 /// println!("{tensor}");
375 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
376 /// }
377 /// ```
378 pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
379 let shape = shape.into();
380 check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
381 Self::new(K::ones(shape, device))
382 }
383
384 /// Returns a new tensor with the same shape and device as the current tensor filled with ones.
385 ///
386 /// # Example
387 ///
388 /// ```rust
389 /// use burn_tensor::backend::Backend;
390 /// use burn_tensor::{Tensor, Shape};
391 ///
392 /// fn example<B: Backend>() {
393 /// let device = B::Device::default();
394 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
395 /// let tensor = tensor.ones_like();
396 /// println!("{tensor}");
397 /// // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
398 /// }
399 /// ```
400 pub fn ones_like(&self) -> Self {
401 Self::ones(self.shape(), &self.device())
402 }
403
404 /// Create a tensor of the given shape where each element is equal to the provided value.
405 ///
406 /// # Example
407 ///
408 /// ```rust
409 /// use burn_tensor::backend::Backend;
410 /// use burn_tensor::{Tensor, Shape};
411 ///
412 /// fn example<B: Backend>() {
413 /// let device = B::Device::default();
414 /// let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
415 /// println!("{tensor}");
416 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
417 /// }
418 /// ```
419 pub fn full<S: Into<Shape>, E: ElementConversion>(
420 shape: S,
421 fill_value: E,
422 device: &B::Device,
423 ) -> Self {
424 let shape = shape.into();
425 check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
426 Self::new(K::full(shape, fill_value, device))
427 }
428
429 ///Returns a new tensor with the same shape and device as the current tensor filled with the provided value.
430 ///
431 /// # Example
432 ///
433 /// ```rust
434 /// use burn_tensor::backend::Backend;
435 /// use burn_tensor::{Tensor, Shape};
436 ///
437 /// fn example<B: Backend>() {
438 /// let device = B::Device::default();
439 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
440 /// let tensor = tensor.full_like(5.0);
441 /// println!("{tensor}");
442 /// // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
443 /// }
444 /// ```
445 pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
446 Self::full(self.shape(), fill_value, &self.device())
447 }
448
449 /// Aggregate all elements in the tensor with the mean operation.
450 ///
451 /// # Example
452 ///
453 /// ```rust
454 /// use burn_tensor::backend::Backend;
455 /// use burn_tensor::{Tensor, Shape};
456 ///
457 /// fn example<B: Backend>() {
458 /// let device = B::Device::default();
459 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
460 /// let tensor = tensor.mean();
461 /// println!("{tensor}");
462 /// // [3.6666667]
463 /// }
464 /// ```
465 pub fn mean(self) -> Tensor<B, 1, K> {
466 Tensor::new(K::mean(self.primitive))
467 }
468
469 /// Aggregate all elements in the tensor with the sum operation.
470 ///
471 /// # Example
472 ///
473 /// ```rust
474 /// use burn_tensor::backend::Backend;
475 /// use burn_tensor::{Tensor, Shape};
476 ///
477 /// fn example<B: Backend>() {
478 /// let device = B::Device::default();
479 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
480 /// let tensor = tensor.sum();
481 /// println!("{tensor}");
482 /// // [22.0]
483 /// }
484 /// ```
485 pub fn sum(self) -> Tensor<B, 1, K> {
486 Tensor::new(K::sum(self.primitive))
487 }
488
489 /// Aggregate all elements along the given *dimension* or *axis*
490 /// in the tensor with the mean operation.
491 ///
492 /// # Arguments
493 ///
494 /// * `dim` - The dimension or axis along which to aggregate the elements.
495 ///
496 /// # Example
497 ///
498 /// ```rust
499 /// use burn_tensor::backend::Backend;
500 /// use burn_tensor::{Tensor, Shape};
501 ///
502 /// fn example<B: Backend>() {
503 /// let device = B::Device::default();
504 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
505 /// let tensor = tensor.clone().mean_dim(0);
506 /// println!("{tensor}");
507 /// // [[3.0, 3.5, 4.5]]
508 /// let tensor = tensor.clone().mean_dim(1);
509 /// println!("{tensor}");
510 /// // [[0.6666667], [6.6666665]]
511 /// }
512 /// ```
513 pub fn mean_dim(self, dim: usize) -> Self {
514 check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
515 Self::new(K::mean_dim(self.primitive, dim))
516 }
517
518 /// Aggregate all elements along the given *dimension* or *axis*
519 /// in the tensor with the sum operation.
520 ///
521 /// # Arguments
522 ///
523 /// * `dim` - The dimension or axis along which to aggregate the elements.
524 ///
525 /// # Example
526 ///
527 /// ```rust
528 /// use burn_tensor::backend::Backend;
529 /// use burn_tensor::{Tensor, Shape};
530 ///
531 /// fn example<B: Backend>() {
532 /// let device = B::Device::default();
533 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
534 /// let tensor = tensor.clone().sum_dim(0);
535 /// println!("{tensor}");
536 /// let tensor = tensor.clone().sum_dim(1);
537 /// // [[6.0, 7.0, 9.0]]
538 /// println!("{tensor}");
539 /// // [[2.0], [20.0]]
540 /// }
541 /// ```
542 pub fn sum_dim(self, dim: usize) -> Self {
543 check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
544 Self::new(K::sum_dim(self.primitive, dim))
545 }
546
547 /// Aggregate all elements along the given *dimension* or *axis*
548 /// in the tensor with the product operation.
549 ///
550 /// # Example
551 ///
552 /// ```rust
553 /// use burn_tensor::backend::Backend;
554 /// use burn_tensor::{Tensor, Shape};
555 ///
556 /// fn example<B: Backend>() {
557 /// let device = B::Device::default();
558 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
559 /// let tensor = tensor.prod();
560 /// println!("{tensor}");
561 /// // [-1620.0]
562 /// }
563 /// ```
564 pub fn prod(self) -> Tensor<B, 1, K> {
565 Tensor::new(K::prod(self.primitive))
566 }
567
568 /// Aggregate all elements along the given *dimension* or *axis*
569 /// in the tensor with the product operation.
570 ///
571 /// # Arguments
572 ///
573 /// * `dim` - The dimension or axis along which to aggregate the elements.
574 ///
575 /// # Example
576 ///
577 /// ```rust
578 /// use burn_tensor::backend::Backend;
579 /// use burn_tensor::{Tensor, Shape};
580 ///
581 /// fn example<B: Backend>() {
582 /// let device = B::Device::default();
583 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
584 /// let tensor = tensor.clone().prod_dim(0);
585 /// println!("{tensor}");
586 /// // [[5.0, -18.0, 18.0]]
587 /// let tensor = tensor.clone().prod_dim(1);
588 /// println!("{tensor}");
589 /// // [[-6.0], [270.0]]
590 /// }
591 /// ```
592 pub fn prod_dim(self, dim: usize) -> Self {
593 check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
594 Self::new(K::prod_dim(self.primitive, dim))
595 }
596
597 /// Applies element wise equal comparison and returns a boolean tensor.
598 ///
599 /// # Arguments
600 ///
601 /// * `other` - The element to compare.
602 ///
603 /// # Example
604 ///
605 /// ```rust
606 /// use burn_tensor::backend::Backend;
607 /// use burn_tensor::{Tensor, Shape};
608 ///
609 /// fn example<B: Backend>() {
610 /// let device = B::Device::default();
611 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
612 /// let tensor = tensor.equal_elem(3.0);
613 /// println!("{tensor}");
614 /// // [[false, false, true], [false, false, false]]
615 /// }
616 /// ```
617 pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
618 Tensor::new(K::equal_elem(self.primitive, other.elem()))
619 }
620
621 /// Applies element wise non-equality comparison and returns a boolean tensor.
622 ///
623 /// # Arguments
624 ///
625 /// * `other` - The element to compare.
626 ///
627 /// # Example
628 ///
629 /// ```rust
630 /// use burn_tensor::backend::Backend;
631 /// use burn_tensor::{Tensor, Shape};
632 ///
633 /// fn example<B: Backend>() {
634 /// let device = B::Device::default();
635 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
636 /// let tensor = tensor.not_equal_elem(3.0);
637 /// println!("{tensor}");
638 /// // [[true, true, false], [true, true, true]]
639 /// }
640 /// ```
641 pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
642 Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
643 }
644
645 /// Applies element wise greater comparison and returns a boolean tensor.
646 ///
647 /// # Panics
648 ///
649 /// If the two tensors don't have the same shape.
650 ///
651 /// # Example
652 ///
653 /// ```rust
654 /// use burn_tensor::backend::Backend;
655 /// use burn_tensor::{Tensor, Shape};
656 ///
657 /// fn example<B: Backend>() {
658 /// let device = B::Device::default();
659 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
660 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
661 /// let tensor = tensor1.greater(tensor2);
662 /// println!("{tensor}");
663 /// // [[false, false, false], [true, true, true]]
664 /// }
665 /// ```
666 pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
667 check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
668 Tensor::new(K::greater(self.primitive, other.primitive))
669 }
670
671 /// Applies element wise greater-equal comparison and returns a boolean tensor.
672 ///
673 /// # Panics
674 ///
675 /// If the two tensors don't have the same shape.
676 ///
677 /// # Example
678 ///
679 /// ```rust
680 /// use burn_tensor::backend::Backend;
681 /// use burn_tensor::{Tensor, Shape};
682 ///
683 /// fn example<B: Backend>() {
684 /// let device = B::Device::default();
685 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
686 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
687 /// let tensor = tensor1.greater_equal(tensor2);
688 /// println!("{tensor}");
689 /// // [[false, false, false], [true, true, true]]
690 /// }
691 /// ```
692 pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
693 check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
694 Tensor::new(K::greater_equal(self.primitive, other.primitive))
695 }
696
697 /// Applies element wise lower comparison and returns a boolean tensor.
698 ///
699 /// # Panics
700 ///
701 /// If the two tensors don't have the same shape.
702 ///
703 /// # Example
704 ///
705 /// ```rust
706 /// use burn_tensor::backend::Backend;
707 /// use burn_tensor::{Tensor, Shape};
708 ///
709 /// fn example<B: Backend>() {
710 /// let device = B::Device::default();
711 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
712 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
713 /// let tensor = tensor1.lower(tensor2);
714 /// println!("{tensor}");
715 /// // [[true, true, true], [false, false, false]]
716 /// }
717 /// ```
718 pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
719 check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
720 Tensor::new(K::lower(self.primitive, other.primitive))
721 }
722
723 /// Applies element wise lower-equal comparison and returns a boolean tensor.
724 ///
725 /// # Panics
726 ///
727 /// If the two tensors don't have the same shape.
728 ///
729 /// # Example
730 ///
731 /// ```rust
732 /// use burn_tensor::backend::Backend;
733 /// use burn_tensor::{Tensor, Shape};
734 ///
735 /// fn example<B: Backend>() {
736 /// let device = B::Device::default();
737 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
738 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
739 /// let tensor = tensor1.lower_equal(tensor2);
740 /// println!("{tensor}");
741 /// // [[true, true, true], [false, false, false]]
742 /// }
743 /// ```
744 pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
745 check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
746 Tensor::new(K::lower_equal(self.primitive, other.primitive))
747 }
748
749 /// Applies element wise greater comparison and returns a boolean tensor.
750 ///
751 /// # Arguments
752 ///
753 /// * `other` - The element to compare.
754 ///
755 /// # Example
756 ///
757 /// ```rust
758 /// use burn_tensor::backend::Backend;
759 /// use burn_tensor::{Tensor, Shape};
760 ///
761 /// fn example<B: Backend>() {
762 /// let device = B::Device::default();
763 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
764 /// let tensor = tensor.greater_elem(3.0);
765 /// println!("{tensor}");
766 /// // [[false, false, true], [true, true, true]]
767 /// }
768 /// ```
769 pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
770 Tensor::new(K::greater_elem(self.primitive, other.elem()))
771 }
772
773 /// Applies element wise greater-equal comparison and returns a boolean tensor.
774 ///
775 /// # Arguments
776 ///
777 /// * `other` - The element to compare.
778 ///
779 /// # Example
780 ///
781 /// ```rust
782 /// use burn_tensor::backend::Backend;
783 /// use burn_tensor::{Tensor, Shape};
784 ///
785 /// fn example<B: Backend>() {
786 /// let device = B::Device::default();
787 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
788 /// let tensor = tensor.greater_equal_elem(3.0);
789 /// println!("{tensor}");
790 /// // [[false, false, true], [true, true, true]]
791 /// }
792 /// ```
793 pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
794 Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
795 }
796
797 /// Applies element wise lower comparison and returns a boolean tensor.
798 ///
799 /// # Arguments
800 ///
801 /// * `other` - The element to compare.
802 ///
803 /// # Example
804 ///
805 /// ```rust
806 /// use burn_tensor::backend::Backend;
807 /// use burn_tensor::{Tensor, Shape};
808 ///
809 /// fn example<B: Backend>() {
810 /// let device = B::Device::default();
811 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
812 /// let tensor = tensor.lower_elem(3.0);
813 /// println!("{tensor}");
814 /// // [[true, true, false], [false, false, false]]
815 /// }
816 /// ```
817 pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
818 Tensor::new(K::lower_elem(self.primitive, other.elem()))
819 }
820
821 /// Applies element wise lower-equal comparison and returns a boolean tensor.
822 ///
823 /// # Arguments
824 ///
825 /// * `other` - The element to compare.
826 ///
827 /// # Example
828 ///
829 /// ```rust
830 /// use burn_tensor::backend::Backend;
831 /// use burn_tensor::{Tensor, Shape};
832 ///
833 /// fn example<B: Backend>() {
834 /// let device = B::Device::default();
835 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
836 /// let tensor = tensor.lower_equal_elem(3.0);
837 /// println!("{tensor}");
838 /// // [[true, true, true], [false, false, false]]
839 /// }
840 /// ```
841 pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
842 Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
843 }
844
845 /// Update the given tensor with the value tensor where the mask is true.
846 ///
847 /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
848 /// a scalar.
849 ///
850 /// # Example
851 ///
852 /// ```rust
853 /// use burn_tensor::backend::Backend;
854 /// use burn_tensor::{Tensor, Shape, Bool};
855 ///
856 /// fn example<B: Backend>() {
857 /// let device = B::Device::default();
858 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
859 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
860 /// let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
861 /// let tensor = tensor.mask_where(mask, value);
862 /// println!("{tensor}");
863 /// // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
864 /// }
865 /// ```
866 pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
867 Self::new(K::mask_where(
868 self.primitive,
869 mask.primitive,
870 value.primitive,
871 ))
872 }
873
874 /// Update the given tensor with the value where the mask is true.
875 ///
876 /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
877 /// a tensor.
878 ///
879 /// # Example
880 ///
881 /// ```rust
882 /// use burn_tensor::backend::Backend;
883 /// use burn_tensor::{Tensor, Shape, Bool};
884 ///
885 /// fn example<B: Backend>() {
886 /// let device = B::Device::default();
887 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
888 /// let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
889 /// let tensor = tensor.mask_fill(mask, 3.0);
890 /// println!("{tensor}");
891 /// // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
892 /// }
893 /// ```
894 pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
895 Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
896 }
897
898 /// Gather tensor elements corresponding to the given indices from the specified dim.
899 ///
900 /// Example using a 3D tensor:
901 ///
902 /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
903 /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
904 /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
905 ///
906 /// # Notes
907 ///
908 /// The index tensor should have the same shape as the original tensor except for the dim
909 /// specified.
910 ///
911 /// # Warning
912 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
913 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
914 pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
915 check!(TensorCheck::gather::<D>(
916 dim,
917 &self.shape(),
918 &indices.shape()
919 ));
920
921 Self::new(K::gather(dim, self.primitive, indices.primitive))
922 }
923
924 /// Assign the gathered elements corresponding to the given indices along the specified dimension
925 /// from the value tensor to the original tensor using sum reduction.
926 ///
927 /// Example using a 3D tensor:
928 ///
929 /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
930 /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
931 /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
932 ///
933 /// # Notes
934 ///
935 /// The index tensor should have the same shape as the original tensor except for the specified
936 /// dimension. The value and index tensors should have the same shape.
937 ///
938 /// Other references to the input tensor will not be modified by this operation.
939 ///
940 /// # Warning
941 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
942 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
943 pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
944 check!(TensorCheck::scatter::<D>(
945 dim,
946 &self.shape(),
947 &indices.shape(),
948 &values.shape()
949 ));
950
951 Self::new(K::scatter(
952 dim,
953 self.primitive,
954 indices.primitive,
955 values.primitive,
956 ))
957 }
958
959 /// Select the tensor elements along the given dimension corresponding to the given indices.
960 ///
961 /// Example using a 3D tensor:
962 ///
963 /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0`
964 /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1`
965 /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2`
966 ///
967 /// # Warning
968 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
969 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
970 ///
971 /// # Example
972 ///
973 /// ```rust
974 /// use burn_tensor::backend::Backend;
975 /// use burn_tensor::{Tensor, Shape, Int};
976 ///
977 /// fn example<B: Backend>() {
978 /// let device = B::Device::default();
979 /// let tensor = Tensor::<B, 3>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
980 /// let indices = Tensor::<B, 1, Int>::from_data([0], &device);
981 /// let tensor = tensor.select(0, indices);
982 /// println!("{tensor}");
983 /// // [[1.0, -2.0, 3.0]]
984 /// }
985 /// ```
986 pub fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
987 check!(TensorCheck::select::<D>(dim));
988 Self::new(K::select(self.primitive, dim, indices))
989 }
990
991 /// Assign the selected elements along the given dimension corresponding to the given indices
992 /// from the value tensor to the original tensor using sum reduction.
993 ///
994 /// Example using a 3D tensor:
995 ///
996 /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
997 /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
998 /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
999 ///
1000 /// # Warning
1001 /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1002 /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1003 pub fn select_assign(
1004 self,
1005 dim: usize,
1006 indices: Tensor<B, 1, Int>,
1007 values: Tensor<B, D, K>,
1008 ) -> Self {
1009 check!(TensorCheck::select_assign::<D>(dim));
1010
1011 Self::new(K::select_assign(
1012 self.primitive,
1013 dim,
1014 indices,
1015 values.primitive,
1016 ))
1017 }
1018
1019 /// Applies the argmax function along the given dimension and returns an integer tensor.
1020 ///
1021 /// # Example
1022 ///
1023 /// ```rust
1024 /// use burn_tensor::backend::Backend;
1025 /// use burn_tensor::{Tensor, Shape};
1026 ///
1027 /// fn example<B: Backend>() {
1028 /// let device = B::Device::default();
1029 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1030 /// let tensor = tensor.argmax(1);
1031 /// println!("{:?}", tensor.shape());
1032 /// // Shape { dims: [2, 1, 3] }
1033 /// }
1034 /// ```
1035 pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
1036 Tensor::new(K::argmax(self.primitive, dim))
1037 }
1038
1039 /// Find the maximum value.
1040 ///
1041 /// # Example
1042 ///
1043 /// ```rust
1044 /// use burn_tensor::backend::Backend;
1045 /// use burn_tensor::{Tensor, Shape};
1046 ///
1047 /// fn example<B: Backend>() {
1048 /// let device = B::Device::default();
1049 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1050 /// let tensor = tensor.max();
1051 /// println!("{tensor}");
1052 /// // [9.0]
1053 /// }
1054 /// ```
1055 pub fn max(self) -> Tensor<B, 1, K> {
1056 Tensor::new(K::max(self.primitive))
1057 }
1058
1059 /// Find the maximum value along the given dimension.
1060 ///
1061 /// # Example
1062 ///
1063 /// ```rust
1064 /// use burn_tensor::backend::Backend;
1065 /// use burn_tensor::{Tensor, Shape};
1066 ///
1067 /// fn example<B: Backend>() {
1068 /// let device = B::Device::default();
1069 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1070 /// let tensor = tensor.max_dim(0);
1071 /// println!("{tensor}");
1072 /// // [[5.0, 9.0, 6.0]]
1073 /// }
1074 /// ```
1075 pub fn max_dim(self, dim: usize) -> Tensor<B, D, K> {
1076 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1077
1078 Tensor::new(K::max_dim(self.primitive, dim))
1079 }
1080
1081 /// Find the maximum value along the given dimension.
1082 ///
1083 /// Also returns the indices.
1084 ///
1085 /// # Example
1086 ///
1087 /// ```rust
1088 /// use burn_tensor::backend::Backend;
1089 /// use burn_tensor::{Tensor, Shape};
1090 ///
1091 /// fn example<B: Backend>() {
1092 /// let device = B::Device::default();
1093 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1094 /// let (tensor, index) = tensor.max_dim_with_indices(0);
1095 /// // [[5.0, 9.0, 6.0]]
1096 /// println!("{tensor}");
1097 /// // [[1, 1, 1]]
1098 /// println!("{index}");
1099 /// }
1100 /// ```
1101 pub fn max_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1102 check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1103
1104 let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
1105
1106 let tensor = Tensor::new(tensor);
1107 let index = Tensor::new(index);
1108
1109 (tensor, index)
1110 }
1111
1112 /// Finds the maximum pair wise values with another Tensor
1113 ///
1114 /// # Arguments
1115 ///
1116 /// * `other` - Other tensor to find maximum elements with
1117 ///
1118 /// # Returns
1119 ///
1120 /// A tensor with the same shape as the input tensors containing the maximum value found
1121 /// in the input tensors.
1122 ///
1123 /// # Example
1124 ///
1125 /// ```rust
1126 /// use burn_tensor::backend::Backend;
1127 /// use burn_tensor::{Tensor, Shape};
1128 ///
1129 /// fn example<B: Backend>() {
1130 /// let device = B::Device::default();
1131 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1132 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1133 /// let tensor = tensor1.max_pair(tensor2);
1134 /// println!("{tensor}");
1135 /// // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
1136 /// }
1137 /// ```
1138 pub fn max_pair(self, other: Self) -> Self {
1139 let mask = self.clone().lower(other.clone());
1140 self.mask_where(mask, other)
1141 }
1142
1143 /// Applies the argmin function along the given dimension and returns an integer tensor.
1144 ///
1145 /// # Example
1146 ///
1147 /// ```rust
1148 /// use burn_tensor::backend::Backend;
1149 /// use burn_tensor::{Tensor, Shape};
1150 ///
1151 /// fn example<B: Backend>() {
1152 /// let device = Default::default();
1153 /// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1154 /// let tensor = tensor.argmin(1);
1155 /// println!("{:?}", tensor.shape());
1156 /// // Shape { dims: [2, 1, 3] }
1157 /// }
1158 /// ```
1159 pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
1160 Tensor::new(K::argmin(self.primitive, dim))
1161 }
1162
1163 /// Find the minimum value.
1164 ///
1165 /// # Example
1166 ///
1167 /// ```rust
1168 /// use burn_tensor::backend::Backend;
1169 /// use burn_tensor::{Tensor, Shape};
1170 ///
1171 /// fn example<B: Backend>() {
1172 /// let device = B::Device::default();
1173 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1174 /// let tensor = tensor.min();
1175 /// println!("{tensor}");
1176 /// // [-2.0]
1177 /// }
1178 /// ```
1179 pub fn min(self) -> Tensor<B, 1, K> {
1180 Tensor::new(K::min(self.primitive))
1181 }
1182
1183 /// Find the minimum value along the given dimension.
1184 ///
1185 /// # Example
1186 ///
1187 /// ```rust
1188 /// use burn_tensor::backend::Backend;
1189 /// use burn_tensor::{Tensor, Shape};
1190 ///
1191 /// fn example<B: Backend>() {
1192 /// let device = B::Device::default();
1193 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1194 /// let tensor = tensor.min_dim(0);
1195 /// println!("{tensor}");
1196 /// // [[1.0, -2.0, 3.0]]
1197 /// }
1198 /// ```
1199 pub fn min_dim(self, dim: usize) -> Tensor<B, D, K> {
1200 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1201 Tensor::new(K::min_dim(self.primitive, dim))
1202 }
1203
1204 /// Find the minimum value along the given dimension.
1205 ///
1206 /// Also returns the indices.
1207 ///
1208 /// # Example
1209 ///
1210 /// ```rust
1211 /// use burn_tensor::backend::Backend;
1212 /// use burn_tensor::{Tensor, Shape};
1213 ///
1214 /// fn example<B: Backend>() {
1215 /// let device = B::Device::default();
1216 /// let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1217 /// let (tensor, index) = tensor.min_dim_with_indices(0);
1218 /// println!("{tensor}");
1219 /// // [[1.0, -2.0, 3.0]]
1220 /// println!("{}", index);
1221 /// // [[1, 0, 0]]
1222 /// }
1223 /// ```
1224 pub fn min_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1225 check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1226
1227 let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
1228
1229 let tensor = Tensor::new(tensor);
1230 let index = Tensor::new(index);
1231
1232 (tensor, index)
1233 }
1234
1235 /// Finds the minimum pair wise values with another Tensor
1236 ///
1237 /// # Arguments
1238 ///
1239 /// * `other` - Other tensor to find minimum elements with
1240 ///
1241 /// # Returns
1242 ///
1243 /// A tensor with the same shape as the input tensors containing the minimum value found
1244 /// between each element of the two source tensors.
1245 ///
1246 /// # Example
1247 ///
1248 /// ```rust
1249 /// use burn_tensor::backend::Backend;
1250 /// use burn_tensor::{Tensor, Shape};
1251 ///
1252 /// fn example<B: Backend>() {
1253 /// let device = B::Device::default();
1254 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1255 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1256 /// let tensor = tensor1.min_pair(tensor2);
1257 /// println!("{tensor}");
1258 /// // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
1259 /// }
1260 pub fn min_pair(self, other: Self) -> Self {
1261 let mask = other.clone().lower(self.clone());
1262 self.mask_where(mask, other)
1263 }
1264
1265 /// Clamp the tensor between the given min and max values.
1266 ///
1267 /// # Arguments
1268 ///
1269 /// * `min` - The minimum value.
1270 /// * `max` - The maximum value.
1271 ///
1272 /// # Returns
1273 ///
1274 /// A new tensor with the values clamped between the given min and max values.
1275 ///
1276 /// # Example
1277 ///
1278 /// ```rust
1279 /// use burn_tensor::backend::Backend;
1280 /// use burn_tensor::{Int, Tensor};
1281 ///
1282 /// fn example<B: Backend>() {
1283 /// let device = Default::default();
1284 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1285 /// [
1286 /// [1, 2, 3],
1287 /// [4, 5, 6],
1288 /// [7, 8, 9]
1289 /// ],
1290 /// &device);
1291 /// let tensor = tensor.clamp(2, 6);
1292 /// println!("{tensor}");
1293 /// // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
1294 /// }
1295 /// ```
1296 pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
1297 Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
1298 }
1299
1300 /// Clamps a tensor under a minimum value.
1301 ///
1302 /// # Arguments
1303 ///
1304 /// * `tensor` - The tensor to clamp.
1305 /// * `min` - The minimum value.
1306 ///
1307 /// # Returns
1308 ///
1309 /// A new tensor with the values clamped under the given min value.
1310 ///
1311 /// # Example
1312 ///
1313 /// ```rust
1314 /// use burn_tensor::backend::Backend;
1315 /// use burn_tensor::{Int, Tensor};
1316 ///
1317 /// fn example<B: Backend>() {
1318 /// let device = Default::default();
1319 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1320 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1321 /// &device);
1322 /// let tensor = tensor.clamp_min(4);
1323 /// println!("{tensor}");
1324 /// // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1325 /// }
1326 /// ```
1327 pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1328 Self::new(K::clamp_min(self.primitive, min.elem()))
1329 }
1330
1331 /// Clamps a tensor over a maximum value.
1332 ///
1333 /// # Arguments
1334 ///
1335 /// * `tensor` - The tensor to clamp.
1336 /// * `max` - The maximum value.
1337 ///
1338 /// # Returns
1339 ///
1340 /// A new tensor with the values clamped over the given max value.
1341 ///
1342 /// # Example
1343 ///
1344 /// ```rust
1345 /// use burn_tensor::backend::Backend;
1346 /// use burn_tensor::{Int, Tensor};
1347 ///
1348 /// fn example<B: Backend>() {
1349 /// let device = Default::default();
1350 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1351 /// [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1352 /// &device);
1353 /// let tensor = tensor.clamp_max(5);
1354 /// println!("{tensor}");
1355 /// // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1356 /// }
1357 /// ```
1358 pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1359 Self::new(K::clamp_max(self.primitive, max.elem()))
1360 }
1361
1362 /// Apply element wise absolute value operation
1363 ///
1364 /// # Example
1365 ///
1366 /// ```rust
1367 /// use burn_tensor::backend::Backend;
1368 /// use burn_tensor::{Int, Tensor};
1369 ///
1370 /// fn example<B: Backend>() {
1371 /// let device = Default::default();
1372 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
1373 /// let tensor = tensor.abs();
1374 /// println!("{tensor}");
1375 /// // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1376 /// }
1377 /// ```
1378 pub fn abs(self) -> Self {
1379 Self::new(K::abs(self.primitive))
1380 }
1381
1382 /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
1383 /// the other elements of the result tensor out are set to 0.
1384 ///
1385 /// # Example
1386 /// ```rust
1387 /// use burn_tensor::backend::Backend;
1388 /// use burn_tensor::{Int, Tensor};
1389 ///
1390 /// fn example<B: Backend>() {
1391 /// let device = Default::default();
1392 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1393 /// [
1394 /// [1, 2, 3],
1395 /// [4, 5, 6],
1396 /// [7, 8, 9]
1397 /// ],
1398 /// &device
1399 /// );
1400 /// let tensor = tensor.triu(1);
1401 /// println!("{tensor}");
1402 /// // [
1403 /// // [0, 2, 3],
1404 /// // [0, 0, 6],
1405 /// // [0, 0, 0]
1406 /// // ]
1407 /// }
1408 /// ```
1409 pub fn triu(self, diagonal: i64) -> Self {
1410 check!(TensorCheck::tri::<{ D }>());
1411
1412 // last two dimensions
1413 let shape = &self.shape().dims[D - 2..].to_owned();
1414
1415 let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
1416 self.mask_fill(mask, 0)
1417 }
1418
1419 /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
1420 /// the other elements of the result tensor out are set to 0.
1421 ///
1422 /// # Example
1423 /// ```rust
1424 /// use burn_tensor::backend::Backend;
1425 /// use burn_tensor::{Int, Tensor};
1426 ///
1427 /// fn example<B: Backend>() {
1428 /// let device = Default::default();
1429 /// let tensor = Tensor::<B, 2, Int>::from_ints(
1430 /// [
1431 /// [1, 2, 3],
1432 /// [4, 5, 6],
1433 /// [7, 8, 9]
1434 /// ],
1435 /// &device
1436 /// );
1437 ///
1438 /// let tensor = tensor.tril(-1);
1439 /// println!("{tensor}");
1440 /// // [
1441 /// // [0, 0, 0],
1442 /// // [4, 0, 0],
1443 /// // [7, 8, 0]
1444 /// // ]
1445 /// }
1446 /// ```
1447 pub fn tril(self, diagonal: i64) -> Self {
1448 check!(TensorCheck::tri::<{ D }>());
1449
1450 // last two dimensions
1451 let shape = &self.shape().dims[D - 2..].to_owned();
1452
1453 let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
1454 self.mask_fill(mask, 0)
1455 }
1456
1457 /// Applies element wise power operation with a float Tensor
1458 ///
1459 /// # Arguments
1460 ///
1461 /// * `other` - The tensor to apply the power operation with.
1462 ///
1463 /// # Example
1464 ///
1465 /// ```rust
1466 /// use burn_tensor::backend::Backend;
1467 /// use burn_tensor::{Tensor, Shape};
1468 ///
1469 /// fn example<B: Backend>() {
1470 /// let device = B::Device::default();
1471 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1472 /// let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1473 /// let tensor = tensor1.powf(tensor2);
1474 /// println!("{tensor}");
1475 /// // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]]
1476 /// }
1477 /// ```
1478 pub fn powf(self, other: Self) -> Self {
1479 Self::new(K::powf(self.primitive, other.primitive))
1480 }
1481
1482 /// Applies element wise power operation with a float scalar
1483 ///
1484 /// # Arguments
1485 ///
1486 /// * `other` - The scalar to apply the power operation with.
1487 ///
1488 /// # Example
1489 ///
1490 /// ```rust
1491 /// use burn_tensor::backend::Backend;
1492 /// use burn_tensor::{Tensor, Shape};
1493 ///
1494 /// fn example<B: Backend>() {
1495 /// let device = B::Device::default();
1496 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1497 /// let tensor = tensor.powf_scalar(2.0);
1498 /// println!("{tensor}");
1499 /// // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
1500 /// }
1501 /// ```
1502 pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
1503 Self::new(K::powf_scalar::<E>(self.primitive, other))
1504 }
1505
1506 /// Applies element wise power operation with a integer Tensor
1507 ///
1508 /// # Arguments
1509 ///
1510 /// * `other` - The tensor to apply the power operation with.
1511 ///
1512 /// # Example
1513 ///
1514 /// ```rust
1515 /// use burn_tensor::backend::Backend;
1516 /// use burn_tensor::{Tensor, Shape, Int};
1517 ///
1518 /// fn example<B: Backend>() {
1519 /// let device = B::Device::default();
1520 /// let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1521 /// let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
1522 /// let tensor = tensor1.powi(tensor2);
1523 /// println!("{tensor}");
1524 /// // [[1, -8, 81], [5, 81, 216]]
1525 /// }
1526 /// ```
1527 pub fn powi(self, other: Self) -> Self {
1528 Self::new(K::powi(self.primitive, other.primitive))
1529 }
1530
1531 /// Applies element wise power operation with a integer scalar
1532 ///
1533 /// # Arguments
1534 ///
1535 /// * `other` - The scalar to apply the power operation with.
1536 ///
1537 /// # Example
1538 ///
1539 /// ```rust
1540 /// use burn_tensor::backend::Backend;
1541 /// use burn_tensor::{Tensor, Shape, Int};
1542 ///
1543 /// fn example<B: Backend>() {
1544 /// let device = B::Device::default();
1545 /// let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1546 /// let tensor = tensor.powi_scalar(2);
1547 /// println!("{tensor}");
1548 /// // [[1, 4, 9], [25, 81, 36]]
1549 /// }
1550 /// ```
1551 pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
1552 Self::new(K::powi_scalar::<E>(self.primitive, other))
1553 }
1554
1555 /// Checks element wise if the tensor is close to another tensor.
1556 ///
1557 /// The tolerance is defined by the following equation:
1558 ///
1559 /// ```text
1560 /// abs(a - b) <= (atol + rtol * abs(b))
1561 ///
1562 /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
1563 /// and `atol` is the absolute tolerance.
1564 /// ```
1565 ///
1566 /// # Arguments
1567 ///
1568 /// * `other` - The tensor to compare with.
1569 /// * `rtol` - Optional relative tolerance. Default is 1e-5.
1570 /// * `atol` - Optional absolute tolerance. Default is 1e-8.
1571 ///
1572 /// # Returns
1573 ///
1574 /// A boolean tensor with the same shape as the input tensors.
1575 ///
1576 /// # Example
1577 ///
1578 /// ```rust
1579 /// use burn_tensor::backend::Backend;
1580 /// use burn_tensor::{Tensor, Shape};
1581 ///
1582 /// fn example<B: Backend>() {
1583 /// let device = B::Device::default();
1584 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1585 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1586 /// let tensor = tensor1.is_close(tensor2, None, None);
1587 /// println!("{tensor}");
1588 /// // [[true, true, true], [true, true, true]]
1589 /// }
1590 /// ```
1591 pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
1592 let rtol = rtol.unwrap_or(1e-5);
1593 let atol = atol.unwrap_or(1e-8);
1594
1595 Tensor::new(K::lower_equal(
1596 K::abs(K::sub(self.primitive, other.primitive.clone())),
1597 K::add_scalar(K::mul_scalar(K::abs(other.primitive), rtol), atol),
1598 ))
1599 }
1600
1601 /// Checks if all elements are close to another tensor.
1602 ///
1603 /// The tolerance is defined by the following equation:
1604 ///
1605 /// ```text
1606 ///
1607 /// abs(a - b) <= (atol + rtol * abs(b))
1608 ///
1609 /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
1610 /// and `atol` is the absolute tolerance.
1611 ///
1612 /// ```
1613 ///
1614 /// # Arguments
1615 ///
1616 /// * `other` - The tensor to compare with.
1617 /// * `rtol` - Optional relative tolerance. Default is 1e-5.
1618 /// * `atol` - Optional absolute tolerance. Default is 1e-8.
1619 ///
1620 /// # Returns
1621 ///
1622 /// A boolean scalar.
1623 ///
1624 /// # Remarks
1625 ///
1626 /// # Example
1627 ///
1628 /// ```rust
1629 /// use burn_tensor::backend::Backend;
1630 /// use burn_tensor::{Tensor, Shape};
1631 ///
1632 /// fn example<B: Backend>() {
1633 /// let device = B::Device::default();
1634 /// let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1635 /// let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1636 /// let result = tensor1.all_close(tensor2, None, None);
1637 /// println!("{}", result);
1638 /// // true
1639 /// }
1640 /// ```
1641 pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
1642 self.is_close(other, rtol, atol).all().into_scalar()
1643 }
1644
1645 /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
1646 ///
1647 /// # Returns
1648 ///
1649 /// A boolean tensor with the same shape as the input tensor.
1650 ///
1651 /// # Example
1652 ///
1653 /// ```rust
1654 /// use burn_tensor::backend::Backend;
1655 /// use burn_tensor::{Tensor, Shape};
1656 ///
1657 /// fn example<B: Backend>() {
1658 /// let device = B::Device::default();
1659 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
1660 /// let tensor = tensor.bool();
1661 /// println!("{tensor}");
1662 /// // [
1663 /// // [true, true, true],
1664 /// // [false, true, true]
1665 /// // ]
1666 /// }
1667 pub fn bool(self) -> Tensor<B, D, Bool> {
1668 Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
1669 }
1670
1671 /// Create a random tensor of the given shape on the given device where each element is
1672 /// sampled from the given distribution.
1673 ///
1674 /// # Arguments
1675 ///
1676 /// * `shape` - The shape of the tensor.
1677 /// * `distribution` - The distribution to sample from.
1678 /// * `device` - The device to create the tensor on.
1679 ///
1680 /// # Returns
1681 ///
1682 /// A new tensor with the given shape and elements sampled from the given distribution.
1683 ///
1684 /// # Example
1685 ///
1686 /// ```rust
1687 /// use burn_tensor::backend::Backend;
1688 /// use burn_tensor::{Tensor, Shape, Distribution};
1689 ///
1690 /// fn example<B: Backend>() {
1691 /// let device = B::Device::default();
1692 /// let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
1693 /// let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
1694 /// println!("{tensor}");
1695 /// // [
1696 /// // [0.08347523, 0.70498955, 0.60332155],
1697 /// // [0.08173251, 0.18028641, 0.97942924]
1698 /// // ]
1699 /// }
1700 /// ```
1701 pub fn random<S: Into<Shape>>(
1702 shape: S,
1703 distribution: Distribution,
1704 device: &B::Device,
1705 ) -> Self {
1706 Self::new(K::random(shape.into(), distribution, device))
1707 }
1708
1709 /// Sort the elements by value in ascending order along a given dimension.
1710 ///
1711 /// This sort is unstable (i.e., may reorder equal elements).
1712 ///
1713 /// # Arguments
1714 ///
1715 /// * `dim` - The dimension to sort along.
1716 ///
1717 /// # Returns
1718 ///
1719 /// A new tensor with the elements sorted in ascending order along the given dimension.
1720 ///
1721 /// # Example
1722 ///
1723 /// ```rust
1724 /// use burn_tensor::backend::Backend;
1725 /// use burn_tensor::{Tensor, Shape};
1726 ///
1727 /// fn example<B: Backend>() {
1728 /// let device = B::Device::default();
1729 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1730 /// let tensor = tensor.sort(0);
1731 /// println!("{tensor}");
1732 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1733 /// let tensor = tensor.sort(1);
1734 /// println!("{tensor}");
1735 /// // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
1736 /// }
1737 /// ```
1738 pub fn sort(self, dim: usize) -> Tensor<B, D, K> {
1739 check!(TensorCheck::sort_dim::<D>("Sort", dim));
1740 Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
1741 }
1742
1743 /// Sort the elements by value in descending order along a given dimension.
1744 ///
1745 /// This sort is unstable (i.e., may reorder equal elements).
1746 ///
1747 /// # Arguments
1748 ///
1749 /// * `dim` - The dimension to sort along.
1750 ///
1751 /// # Returns
1752 ///
1753 /// A new tensor with the elements sorted in descending order along the given dimension.
1754 ///
1755 /// # Example
1756 ///
1757 /// ```rust
1758 /// use burn_tensor::backend::Backend;
1759 /// use burn_tensor::{Tensor, Shape};
1760 ///
1761 /// fn example<B: Backend>() {
1762 /// let device = B::Device::default();
1763 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1764 /// let tensor = tensor.sort_descending(0);
1765 /// println!("{tensor}");
1766 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1767 /// let tensor = tensor.sort_descending(1);
1768 /// println!("{tensor}");
1769 /// // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
1770 /// }
1771 /// ```
1772 pub fn sort_descending(self, dim: usize) -> Tensor<B, D, K> {
1773 check!(TensorCheck::sort_dim::<D>("Sort", dim));
1774 Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
1775 }
1776
1777 /// Sort the elements by value in ascending order along a given dimension.
1778 /// Also returns the indices.
1779 ///
1780 /// This sort is unstable (i.e., may reorder equal elements).
1781 ///
1782 /// # Arguments
1783 ///
1784 /// * `dim` - The dimension to sort along.
1785 ///
1786 /// # Returns
1787 ///
1788 /// A tuple containing the sorted tensor and the indices tensor.
1789 ///
1790 /// # Example
1791 ///
1792 /// ```rust
1793 /// use burn_tensor::backend::Backend;
1794 /// use burn_tensor::{Tensor, Shape};
1795 ///
1796 /// fn example<B: Backend>() {
1797 /// let device = B::Device::default();
1798 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1799 /// let (tensor, indices) = tensor.sort_with_indices(0);
1800 /// println!("{tensor}");
1801 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1802 /// println!("{}", indices);
1803 /// // [[1, 0, 0], [0, 1, 1]]
1804 /// }
1805 /// ```
1806 pub fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1807 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1808 let (values, indices) =
1809 K::sort_with_indices(self.primitive, dim, /*descending*/ false);
1810 (Tensor::new(values), Tensor::new(indices))
1811 }
1812
1813 /// Sort the elements by value in descending order along a given dimension.
1814 /// Also returns the indices.
1815 ///
1816 /// This sort is unstable (i.e., may reorder equal elements).
1817 ///
1818 /// # Arguments
1819 ///
1820 /// * `dim` - The dimension to sort along.
1821 ///
1822 /// # Example
1823 ///
1824 /// ```rust
1825 /// use burn_tensor::backend::Backend;
1826 /// use burn_tensor::{Tensor, Shape};
1827 ///
1828 /// fn example<B: Backend>() {
1829 /// let device = B::Device::default();
1830 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1831 /// let (tensor, indices) = tensor.sort_descending_with_indices(0);
1832 /// println!("{tensor}");
1833 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1834 /// println!("{}", indices);
1835 /// // [[0, 1, 1], [1, 0, 0]]
1836 /// }
1837 /// ```
1838 pub fn sort_descending_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1839 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1840 let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
1841 (Tensor::new(values), Tensor::new(indices))
1842 }
1843
1844 /// Returns the indices that sort the elements by value in ascending order along a given dimension.
1845 ///
1846 /// This sort is unstable (i.e., may reorder equal elements).
1847 ///
1848 /// # Arguments
1849 ///
1850 /// * `dim` - The dimension to sort along.
1851 ///
1852 /// # Example
1853 ///
1854 /// ```rust
1855 /// use burn_tensor::backend::Backend;
1856 /// use burn_tensor::{Tensor, Shape};
1857 ///
1858 /// fn example<B: Backend>() {
1859 /// let device = B::Device::default();
1860 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1861 /// let tensor = tensor.argsort(0);
1862 /// println!("{tensor}");
1863 /// // [[1, 0, 0], [0, 1, 1]]
1864 /// }
1865 /// ```
1866 pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
1867 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1868 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
1869 }
1870
1871 /// Returns the indices that sort the elements by value in descending order along a given dimension.
1872 ///
1873 /// This sort is unstable (i.e., may reorder equal elements).
1874 ///
1875 /// # Arguments
1876 ///
1877 /// * `dim` - The dimension to sort along.
1878 ///
1879 /// # Example
1880 ///
1881 /// ```rust
1882 /// use burn_tensor::backend::Backend;
1883 /// use burn_tensor::{Tensor, Shape};
1884 ///
1885 /// fn example<B: Backend>() {
1886 /// let device = B::Device::default();
1887 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1888 /// let tensor = tensor.argsort_descending(0);
1889 /// println!("{tensor}");
1890 /// // [[0, 1, 1], [1, 0, 0]]
1891 /// let tensor = tensor.argsort_descending(1);
1892 /// println!("{tensor}");
1893 /// // [[0, 2, 1], [2, 0, 1]]
1894 /// }
1895 /// ```
1896 pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
1897 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1898 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
1899 }
1900
1901 /// Returns the `k` largest elements of the given input tensor along a given dimension.
1902 ///
1903 /// # Arguments
1904 ///
1905 /// * `k` - The number of elements to return.
1906 ///
1907 /// # Returns
1908 ///
1909 /// A new tensor with the `k` largest elements along the given dimension.
1910 ///
1911 /// # Example
1912 ///
1913 /// ```rust
1914 /// use burn_tensor::backend::Backend;
1915 /// use burn_tensor::{Tensor, Shape};
1916 ///
1917 /// fn example<B: Backend>() {
1918 /// let device = B::Device::default();
1919 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1920 /// let tensor = tensor.topk(2, 0);
1921 /// println!("{tensor}");
1922 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1923 /// let tensor = tensor.topk(1, 1);
1924 /// println!("{tensor}");
1925 /// // [[12.0], [6.0]]
1926 /// }
1927 /// ```
1928 pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
1929 let k_indices = Tensor::arange(0..k as i64, &self.device());
1930 self.sort_descending(dim).select(dim, k_indices)
1931 }
1932
1933 /// Returns the `k` largest elements of the given input tensor along a given dimension.
1934 /// Also returns the indices.
1935 ///
1936 /// # Arguments
1937 ///
1938 /// * `k` - The number of elements to return.
1939 /// * `dim` - The dimension to sort along.
1940 ///
1941 /// # Example
1942 ///
1943 /// ```rust
1944 /// use burn_tensor::backend::Backend;
1945 /// use burn_tensor::{Tensor, Shape};
1946 ///
1947 /// fn example<B: Backend>() {
1948 /// let device = B::Device::default();
1949 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1950 /// let (tensor, indices) = tensor.topk_with_indices(2, 0);
1951 /// println!("{tensor}");
1952 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1953 /// println!("{}", indices);
1954 /// // [[0, 1, 1], [1, 0, 0]]
1955 /// let (tensor, indices) = tensor.topk_with_indices(1, 1);
1956 /// println!("{tensor}");
1957 /// // [[12.0], [6.0]]
1958 /// println!("{indices}");
1959 /// // [[0], [2]]
1960 /// }
1961 /// ```
1962 pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1963 let k_indices = Tensor::arange(0..k as i64, &self.device());
1964 let (values, indices) = self.sort_descending_with_indices(dim);
1965 (
1966 values.select(dim, k_indices.clone()),
1967 indices.select(dim, k_indices),
1968 )
1969 }
1970
1971 /// Pad the tensor of rank two or higher with the given value on the last two dimensions.
1972 ///
1973 /// # Arguments
1974 ///
1975 /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
1976 /// * `value` - The value to pad the tensor with.
1977 ///
1978 /// # Returns
1979 ///
1980 /// A new tensor with the given padding.
1981 ///
1982 /// # Example
1983 ///
1984 /// ```rust
1985 /// use burn_tensor::backend::Backend;
1986 /// use burn_tensor::{Tensor, Shape};
1987 ///
1988 /// fn example<B: Backend<FloatElem: From<f32>>>() {
1989 /// let device = B::Device::default();
1990 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1991 /// let tensor = tensor.pad((1, 1, 1, 1), 0.0);
1992 /// println!("{tensor}");
1993 /// // [
1994 /// // [0.0, 0.0, 0.0, 0.0, 0.0],
1995 /// // [0.0, 12.0, -2.0, 3.0, 0.0],
1996 /// // [0.0, 5.0, 3.0, 6.0, 0.0],
1997 /// // [0.0, 0.0, 0.0, 0.0, 0.0]
1998 /// // ]
1999 /// }
2000 /// ```
2001 pub fn pad<E: ElementConversion>(
2002 self,
2003 padding: (usize, usize, usize, usize),
2004 value: E,
2005 ) -> Tensor<B, D, K> {
2006 let (left, right, top, bottom) = padding;
2007
2008 let mut padded_dims: [usize; D] = self.dims();
2009
2010 // Update the last two dimensions with padding
2011 padded_dims[D - 2] += top + bottom;
2012 padded_dims[D - 1] += left + right;
2013
2014 // Create the ranges for the padded tensor
2015 let ranges: [core::ops::Range<usize>; D] = padded_dims
2016 .iter()
2017 .enumerate()
2018 .map(|(i, &dim)| {
2019 if i == D - 2 {
2020 top..dim - bottom
2021 } else if i == D - 1 {
2022 left..dim - right
2023 } else {
2024 0..dim
2025 }
2026 })
2027 .collect::<Vec<core::ops::Range<usize>>>()
2028 .try_into()
2029 .unwrap();
2030
2031 // Create the padded tensor
2032 let padded_tensor = Tensor::full(padded_dims, value, &self.device());
2033
2034 // Assign the original tensor data to the appropriate slice of the padded tensor
2035 padded_tensor.slice_assign(ranges, self)
2036 }
2037
2038 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
2039 ///
2040 /// # Returns
2041 ///
2042 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
2043 ///
2044 /// # Example
2045 ///
2046 /// ```rust
2047 /// use burn_tensor::backend::Backend;
2048 /// use burn_tensor::{Tensor, Bool, Shape};
2049 ///
2050 /// fn example<B: Backend>() {
2051 /// let device = B::Device::default();
2052 /// let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
2053 /// let tensor = tensor.is_nan();
2054 /// println!("{tensor}");
2055 /// // [[false, true, false], [false, false, false]]
2056 /// }
2057 /// ```
2058 pub fn is_nan(&self) -> Tensor<B, D, Bool> {
2059 // Check if the input tensor is NaN by comparing it to itself
2060 // NaN is the only value that is not equal to itself
2061 Tensor::new(K::not_equal(self.primitive.clone(), self.primitive.clone()))
2062 }
2063
2064 /// Checks if the tensor contains any NaN values.
2065 ///
2066 /// # Returns
2067 ///
2068 /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
2069 ///
2070 /// # Example
2071 ///
2072 /// ```rust
2073 /// use burn_tensor::backend::Backend;
2074 /// use burn_tensor::{Tensor, Bool, Shape};
2075 ///
2076 /// fn example<B: Backend>() {
2077 /// let device = B::Device::default();
2078 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
2079 /// let tensor = tensor.contains_nan();
2080 /// println!("{tensor}");
2081 /// // [true]
2082 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2083 /// let tensor = tensor.contains_nan();
2084 /// println!("{tensor}");
2085 /// // [false]
2086 /// }
2087 /// ```
2088 pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
2089 // Summing the tensor will result in NaN if the tensor contains any NaN values
2090 // This is faster than checking each element individually
2091 // because it rolls up the NaN values into a single value
2092 let sum = K::sum(self.primitive.clone());
2093
2094 // Check if the sum is NaN by comparing it to itself
2095 Tensor::new(K::not_equal(sum.clone(), sum))
2096 }
2097}
2098
2099impl<B, K> Tensor<B, 2, K>
2100where
2101 B: Backend,
2102 K: Numeric<B>,
2103 K::Elem: Element,
2104{
2105 /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2106 ///
2107 /// # Arguments
2108 ///
2109 /// * `size` - The size of the square matrix.
2110 pub fn eye(size: usize, device: &B::Device) -> Self {
2111 let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2112 let ones = K::ones([1, size].into(), device);
2113 let zeros = K::zeros([size, size].into(), device);
2114 Self::new(K::scatter(0, zeros, indices.primitive, ones))
2115 }
2116}
2117
/// Trait that lists all operations that can be applied on all numerical tensors.
2119///
2120/// # Warnings
2121///
2122/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
2123pub trait Numeric<B: Backend>: BasicOps<B>
2124where
2125 Self::Elem: Element,
2126{
2127 /// Adds two tensors together.
2128 ///
2129 /// # Arguments
2130 ///
2131 /// * `lhs` - The left hand side tensor.
2132 /// * `rhs` - The right hand side tensor.
2133 ///
2134 /// # Returns
2135 ///
2136 /// The sum of the two tensors.
2137 ///
2138 /// # Remarks
2139 ///
2140 /// This is a low-level function used internally by the library to call different backend functions
2141 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2142 /// or use this function directly.
2143 ///
2144 /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
2145 /// which is more high-level and designed for public use.
2146 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2147
2148 /// Adds a scalar to a tensor element-wise.
2149 ///
2150 /// # Arguments
2151 ///
2152 /// * `lhs` - The left hand side tensor.
2153 /// * `rhs` - The right hand side scalar.
2154 ///
2155 /// # Returns
2156 ///
2157 /// The sum of the tensor and the scalar.
2158 ///
2159 /// # Remarks
2160 ///
2161 /// This is a low-level function used internally by the library to call different backend functions
2162 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2163 /// or use this function directly.
2164 ///
2165 /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
2166 /// which is more high-level and designed for public use.
2167 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2168
2169 /// Subtracts two tensors.
2170 ///
2171 /// # Arguments
2172 ///
2173 /// * `lhs` - The left hand side tensor.
2174 /// * `rhs` - The right hand side tensor.
2175 ///
2176 /// # Returns
2177 ///
2178 /// The difference of the two tensors.
2179 ///
2180 /// # Remarks
2181 ///
2182 /// This is a low-level function used internally by the library to call different backend functions
2183 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2184 /// or use this function directly.
2185 ///
2186 /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
2187 /// which is more high-level and designed for public use.
2188 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2189
2190 /// Subtracts a scalar from a tensor element-wise.
2191 ///
2192 /// # Arguments
2193 ///
2194 /// * `lhs` - The left hand side tensor.
2195 /// * `rhs` - The right hand side scalar.
2196 ///
2197 /// # Returns
2198 ///
2199 /// The difference of the tensor and the scalar.
2200 ///
2201 /// # Remarks
2202 ///
2203 /// This is a low-level function used internally by the library to call different backend functions
2204 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2205 /// or use this function directly.
2206 ///
2207 /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
2208 /// which is more high-level and designed for public use.
2209 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2210
2211 /// Divides two tensors.
2212 ///
2213 /// # Arguments
2214 ///
2215 /// * `lhs` - The left hand side tensor.
2216 /// * `rhs` - The right hand side tensor.
2217 ///
2218 /// # Returns
2219 ///
2220 /// The quotient of the two tensors.
2221 ///
2222 /// # Remarks
2223 ///
2224 /// This is a low-level function used internally by the library to call different backend functions
2225 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2226 /// or use this function directly.
2227 ///
2228 /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
2229 /// which is more high-level and designed for public use.
2230 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2231
2232 /// Divides a tensor by a scalar element-wise.
2233 ///
2234 /// # Arguments
2235 ///
2236 /// * `lhs` - The left hand side tensor.
2237 /// * `rhs` - The right hand side scalar.
2238 ///
2239 /// # Returns
2240 ///
2241 /// The quotient of the tensor and the scalar.
2242 ///
2243 /// # Remarks
2244 ///
2245 /// This is a low-level function used internally by the library to call different backend functions
2246 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2247 /// or use this function directly.
2248 ///
2249 /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
2250 /// which is more high-level and designed for public use.
2251 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2252
2253 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2254 /// less than that of the divisor.
2255 ///
2256 /// # Arguments
2257 ///
2258 /// * `lhs` - The dividend.
2259 /// * `rhs` - The divisor.
2260 ///
2261 /// # Returns
2262 ///
2263 /// The modulo of the input tensor with the divisor.
2264 ///
2265 /// # Remarks
2266 ///
2267 /// This is a low-level function used internally by the library to call different backend functions
2268 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2269 /// or use this function directly.
2270 ///
2271 /// For performing the modulo operation, users should prefer the [Tensor::remainder](Tensor::remainder) function,
2272 /// which is more high-level and designed for public use.
2273 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2274
2275 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2276 /// less than that of the divisor.
2277 ///
2278 /// # Arguments
2279 ///
2280 /// * `lhs` - The dividend.
2281 /// * `rhs` - The divisor.
2282 ///
2283 /// # Returns
2284 ///
2285 /// The modulo of the input tensor with the divisor.
2286 ///
2287 /// # Remarks
2288 ///
2289 /// This is a low-level function used internally by the library to call different backend functions
2290 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2291 /// or use this function directly.
2292 ///
2293 /// For performing the modulo operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
2294 /// which is more high-level and designed for public use.
2295 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2296
2297 /// Multiplies two tensors.
2298 ///
2299 /// # Arguments
2300 ///
2301 /// * `lhs` - The left hand side tensor.
2302 /// * `rhs` - The right hand side tensor.
2303 ///
2304 /// # Returns
2305 ///
2306 /// The product of the two tensors.
2307 ///
2308 /// # Remarks
2309 ///
2310 /// This is a low-level function used internally by the library to call different backend functions
2311 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2312 /// or use this function directly.
2313 ///
2314 /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
2315 /// which is more high-level and designed for public use.
2316 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2317
2318 /// Multiplies a tensor by a scalar element-wise.
2319 ///
2320 /// # Arguments
2321 ///
2322 /// * `lhs` - The left hand side tensor.
2323 /// * `rhs` - The right hand side scalar.
2324 ///
2325 /// # Returns
2326 ///
2327 /// The product of the tensor and the scalar.
2328 ///
2329 /// # Remarks
2330 ///
2331 /// This is a low-level function used internally by the library to call different backend functions
2332 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2333 /// or use this function directly.
2334 ///
2335 /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
2336 /// which is more high-level and designed for public use.
2337 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2338
2339 /// Negates a tensor.
2340 ///
2341 /// # Arguments
2342 ///
2343 /// * `tensor` - The tensor to negate.
2344 ///
2345 /// # Returns
2346 ///
2347 /// The negated tensor.
2348 ///
2349 /// # Remarks
2350 ///
2351 /// This is a low-level function used internally by the library to call different backend functions
2352 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2353 /// or use this function directly.
2354 ///
2355 /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
2356 /// which is more high-level and designed for public use.
2357 fn neg(tensor: Self::Primitive) -> Self::Primitive;
2358
2359 /// Returns the signs of the elements of a tensor.
2360 ///
2361 /// # Arguments
2362 ///
2363 /// * `tensor` - The tensor.
2364 ///
2365 /// # Returns
2366 ///
2367 /// The signs of the elements of the tensor.
2368 ///
2369 /// # Remarks
2370 ///
2371 /// This is a low-level function used internally by the library to call different backend functions
2372 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2373 /// or use this function directly.
2374 ///
2375 /// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
2376 /// which is more high-level and designed for public use.
2377 fn sign(tensor: Self::Primitive) -> Self::Primitive;
2378
2379 /// Creates a tensor filled with zeros.
2380 ///
2381 /// # Arguments
2382 ///
2383 /// * `shape` - The shape of the tensor.
2384 /// * `device` - The device on which the tensor will be allocated.
2385 ///
2386 /// # Returns
2387 ///
2388 /// The tensor filled with zeros.
2389 ///
2390 /// # Remarks
2391 ///
2392 /// This is a low-level function used internally by the library to call different backend functions
2393 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2394 /// or use this function directly.
2395 ///
2396 /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
2397 /// which is more high-level and designed for public use.
2398 fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive;
2399
2400 /// Creates a tensor filled with ones.
2401 ///
2402 /// # Arguments
2403 ///
2404 /// * `shape` - The shape of the tensor.
2405 /// * `device` - The device on which the tensor will be allocated.
2406 ///
2407 /// # Returns
2408 ///
2409 /// The tensor filled with ones.
2410 ///
2411 /// # Remarks
2412 ///
2413 /// This is a low-level function used internally by the library to call different backend functions
2414 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2415 /// or use this function directly.
2416 ///
2417 /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
2418 /// which is more high-level and designed for public use.
2419 fn ones(shape: Shape, device: &B::Device) -> Self::Primitive;
2420
2421 /// Creates a tensor filled with elements equal to the given value.
2422 ///
2423 /// # Arguments
2424 ///
2425 /// * `shape` - The shape of the tensor.
2426 /// * `fill_value` - The value with which to fill the tensor
2427 /// * `device` - The device on which the tensor will be allocated.
2428 ///
2429 /// # Returns
2430 ///
2431 /// The tensor filled with elements equal to the given value
2432 ///
2433 /// # Remarks
2434 ///
2435 /// This is a low-level function used internally by the library to call different backend functions
2436 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2437 /// or use this function directly.
2438 ///
2439 /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function,
2440 /// which is more high-level and designed for public use.
2441 fn full<E: ElementConversion>(
2442 shape: Shape,
2443 fill_value: E,
2444 device: &B::Device,
2445 ) -> Self::Primitive;
2446
2447 /// Sums all the elements of the tensor.
2448 ///
2449 /// # Arguments
2450 ///
2451 /// * `tensor` - The tensor to sum.
2452 ///
2453 /// # Returns
2454 ///
2455 /// The sum of all the elements of the tensor.
2456 ///
2457 /// # Remarks
2458 ///
2459 /// This is a low-level function used internally by the library to call different backend functions
2460 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2461 /// or use this function directly.
2462 ///
2463 /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
2464 /// which is more high-level and designed for public use.
2465 fn sum(tensor: Self::Primitive) -> Self::Primitive;
2466
2467 /// Sums all the elements of the tensor along a dimension.
2468 ///
2469 /// # Arguments
2470 ///
2471 /// * `tensor` - The tensor to sum.
2472 /// * `dim` - The dimension along which to sum.
2473 ///
2474 /// # Returns
2475 ///
2476 /// The sum of all the elements of the tensor along the specified dimension.
2477 ///
2478 /// # Remarks
2479 ///
2480 /// This is a low-level function used internally by the library to call different backend functions
2481 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2482 /// or use this function directly.
2483 ///
2484 /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
2485 /// which is more high-level and designed for public use.
2486 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2487
2488 /// Computes the product of all the elements of the tensor.
2489 ///
2490 /// # Arguments
2491 ///
2492 /// * `tensor` - The tensor to compute the product of.
2493 ///
2494 /// # Returns
2495 ///
2496 /// The product of all the elements of the tensor.
2497 ///
2498 /// # Remarks
2499 ///
2500 /// This is a low-level function used internally by the library to call different backend functions
2501 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2502 /// or use this function directly.
2503 ///
2504 /// For computing the product of all the elements of a tensor, users should prefer the
2505 /// [Tensor::prod](Tensor::prod) function,
2506 /// which is more high-level and designed for public use.
2507 fn prod(tensor: Self::Primitive) -> Self::Primitive;
2508
2509 /// Computes the product of all the elements of the tensor along a dimension.
2510 ///
2511 /// # Arguments
2512 ///
2513 /// * `tensor` - The tensor to compute the product of.
2514 /// * `dim` - The dimension along which to compute the product.
2515 ///
2516 /// # Returns
2517 ///
2518 /// The product of all the elements of the tensor along the specified dimension.
2519 ///
2520 /// # Remarks
2521 ///
2522 /// This is a low-level function used internally by the library to call different backend functions
2523 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2524 /// or use this function directly.
2525 ///
    /// For computing the product of all the elements of a tensor along a dimension, users should
    /// prefer the [Tensor::prod_dim](Tensor::prod_dim) function, which is more high-level and
    /// designed for public use.
2531 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2532
2533 /// Computes the mean of all the elements of the tensor.
2534 ///
2535 /// # Arguments
2536 ///
2537 /// * `tensor` - The tensor to compute the mean of.
2538 ///
2539 /// # Returns
2540 ///
2541 /// The mean of all the elements of the tensor.
2542 ///
2543 /// # Remarks
2544 ///
2545 /// This is a low-level function used internally by the library to call different backend functions
2546 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2547 /// or use this function directly.
2548 ///
2549 /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
2550 /// which is more high-level and designed for public use.
2551 fn mean(tensor: Self::Primitive) -> Self::Primitive;
2552
2553 /// Computes the mean of all the elements of the tensor along a dimension.
2554 ///
2555 /// # Arguments
2556 ///
2557 /// * `tensor` - The tensor to compute the mean of.
2558 /// * `dim` - The dimension along which to compute the mean.
2559 ///
2560 /// # Returns
2561 ///
2562 /// The mean of all the elements of the tensor along the specified dimension.
2563 ///
2564 /// # Remarks
2565 ///
2566 /// This is a low-level function used internally by the library to call different backend functions
2567 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2568 /// or use this function directly.
2569 ///
2570 /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
2571 /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
2572 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2573
    /// Element-wise equality between a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the input tensor is equal to the scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise equality between a tensor and a scalar, users should prefer the
    /// [Tensor::equal_elem](Tensor::equal_elem) function, which is more high-level and designed for public use.
2594 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2595
    /// Element-wise non-equality between a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the input tensor is not equal to the scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise non-equality between a tensor and a scalar, users should prefer the
    /// [Tensor::not_equal_elem](Tensor::not_equal_elem) function, which is more high-level and designed for public use.
2616 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2617
2618 /// Element-wise greater than comparison between two tensors.
2619 ///
2620 /// # Arguments
2621 ///
2622 /// * `lhs` - The left hand side tensor.
2623 /// * `rhs` - The right hand side tensor.
2624 ///
2625 /// # Returns
2626 ///
2627 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2628 /// corresponding element of the left hand side tensor is greater than the corresponding element
2629 /// of the right hand side tensor, and false otherwise.
2630 ///
2631 /// # Remarks
2632 ///
2633 /// This is a low-level function used internally by the library to call different backend functions
2634 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2635 /// or use this function directly.
2636 ///
2637 /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
2638 /// which is more high-level and designed for public use.
2639 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2640
2641 /// Element-wise greater than comparison between a tensor and a scalar.
2642 ///
2643 /// # Arguments
2644 ///
2645 /// * `lhs` - The left hand side tensor.
2646 /// * `rhs` - The right hand side scalar.
2647 ///
2648 /// # Returns
2649 ///
2650 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2651 /// corresponding element of the left hand side tensor is greater than the right hand side
2652 /// scalar, and false otherwise.
2653 ///
2654 /// # Remarks
2655 ///
2656 /// This is a low-level function used internally by the library to call different backend functions
2657 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2658 /// or use this function directly.
2659 ///
2660 /// For element-wise greater than comparison between a tensor and a scalar, users should prefer
2661 /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
2662 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2663
2664 /// Element-wise greater than or equal comparison between two tensors.
2665 ///
2666 /// # Arguments
2667 ///
2668 /// * `lhs` - The left hand side tensor.
2669 /// * `rhs` - The right hand side tensor.
2670 ///
2671 /// # Returns
2672 ///
2673 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2674 /// corresponding element of the left hand side tensor is greater than or equal to the
2675 /// corresponding element of the right hand side tensor, and false otherwise.
2676 ///
2677 /// # Remarks
2678 ///
2679 /// This is a low-level function used internally by the library to call different backend functions
2680 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2681 /// or use this function directly.
2682 ///
2683 /// For element-wise greater than or equal comparison between two tensors, users should prefer
2684 /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
2685 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2686
2687 /// Element-wise greater than or equal comparison between a tensor and a scalar.
2688 ///
2689 /// # Arguments
2690 ///
2691 /// * `lhs` - The left hand side tensor.
2692 /// * `rhs` - The right hand side scalar.
2693 ///
2694 /// # Returns
2695 ///
2696 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2697 /// corresponding element of the left hand side tensor is greater than or equal to the right
2698 /// hand side scalar, and false otherwise.
2699 ///
2700 /// # Remarks
2701 ///
2702 /// This is a low-level function used internally by the library to call different backend functions
2703 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2704 /// or use this function directly.
2705 ///
2706 /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
2707 /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
2708 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2709
2710 /// Element-wise less than comparison between two tensors.
2711 ///
2712 /// # Arguments
2713 ///
2714 /// * `lhs` - The left hand side tensor.
2715 /// * `rhs` - The right hand side tensor.
2716 ///
2717 /// # Returns
2718 ///
2719 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2720 /// corresponding element of the left hand side tensor is less than the corresponding element of
2721 /// the right hand side tensor, and false otherwise.
2722 ///
2723 /// # Remarks
2724 ///
2725 /// This is a low-level function used internally by the library to call different backend functions
2726 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2727 /// or use this function directly.
2728 ///
2729 /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
2730 /// which is more high-level and designed for public use.
2731 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2732
2733 /// Element-wise less than comparison between a tensor and a scalar.
2734 ///
2735 /// # Arguments
2736 ///
2737 /// * `lhs` - The left hand side tensor.
2738 /// * `rhs` - The right hand side scalar.
2739 ///
2740 /// # Returns
2741 ///
2742 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2743 /// corresponding element of the left hand side tensor is less than the right hand side scalar,
2744 /// and false otherwise.
2745 ///
2746 /// # Remarks
2747 ///
2748 /// This is a low-level function used internally by the library to call different backend functions
2749 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2750 /// or use this function directly.
2751 ///
2752 /// For element-wise less than comparison between a tensor and a scalar, users should prefer
2753 /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
2754 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2755
    /// Element-wise less than or equal comparison between two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
    /// corresponding element of the left hand side tensor is less than or equal to the corresponding
    /// element of the right hand side tensor, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise less than or equal comparison between two tensors, users should prefer
    /// the [`Tensor::lower_equal`] function, which is more high-level and designed for public use.
    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2778
    /// Element-wise less than or equal comparison between a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the left hand side tensor is less than or equal to the right hand
    /// side scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
    /// the [`Tensor::lower_equal_elem`] function, which is more high-level and designed for public use.
    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2801
    /// Selects elements from one of two tensors based on a boolean mask.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The base tensor.
    /// * `mask` - The boolean mask that decides, per element, which tensor the element comes from.
    /// * `source` - The tensor supplying the alternative elements.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors, where each element is taken from either
    /// `tensor` or `source` depending on the corresponding element of the mask.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For selecting elements from a tensor based on a boolean mask, users should prefer the
    /// [`Tensor::mask_where`] function, which is more high-level and designed for public use.
    // NOTE(review): an earlier version of this doc said elements come from `tensor` where the mask
    // is true and from `source` where it is false. `mask_fill` uses the replacement value where the
    // mask is true, so the direction may be the opposite (`source` where true) — confirm against the
    // backend `*_mask_where` ops before documenting a direction.
    fn mask_where(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        source: Self::Primitive,
    ) -> Self::Primitive;
2829
    /// Fills elements of a tensor based on a boolean mask.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements are overwritten with `value` where the
    ///   corresponding element of the mask is true.
    /// * `mask` - The boolean mask to use for filling elements.
    /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where each element is kept unmodified if
    /// the corresponding element of the mask is false, and set to `value` otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For filling elements of a tensor based on a boolean mask, users should prefer the
    /// [`Tensor::mask_fill`] function, which is more high-level and designed for public use.
    fn mask_fill(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        value: Self::Elem,
    ) -> Self::Primitive;
2858
    /// Gathers elements from a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `dim` - The axis along which to gather elements.
    /// * `tensor` - The tensor to gather elements from.
    /// * `indices` - The indices of the elements to gather.
    ///
    /// # Returns
    ///
    /// A tensor where each element is read from the input tensor at the position given by the
    /// corresponding entry of `indices` along the specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For gathering elements from a tensor along an axis, users should prefer the
    /// [`Tensor::gather`] function, which is more high-level and designed for public use.
    fn gather(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
    ) -> Self::Primitive;
2885
    /// Scatters elements into a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `dim` - The axis along which to scatter elements.
    /// * `tensor` - The tensor to scatter elements into.
    /// * `indices` - The indices of the elements to scatter.
    /// * `values` - The values to scatter into the tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, in which the positions addressed by `indices`
    /// along the specified axis are updated from the corresponding elements of `values`; all
    /// other elements are unchanged.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For scattering elements into a tensor along an axis, users should prefer the [`Tensor::scatter`] function,
    /// which is more high-level and designed for public use.
    fn scatter(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
        values: Self::Primitive,
    ) -> Self::Primitive;
2916
    /// Select tensor elements along the given dimension corresponding to the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select elements from.
    /// * `dim` - The axis along which to select elements.
    /// * `indices` - The indices of the elements to select.
    ///
    /// # Returns
    ///
    /// A tensor containing, along the specified axis, the elements of the input tensor located at
    /// the given `indices`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For selecting elements from a tensor along an axis, users should prefer the
    /// [`Tensor::select`] function, which is more high-level and designed for public use.
    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive;
2939
    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// from the value tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to assign elements to.
    /// * `dim` - The axis along which to assign elements.
    /// * `indices` - The indices of the elements to assign.
    /// * `values` - The values to assign to the tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as `tensor`, in which the elements at `indices` along the
    /// specified axis are updated from the corresponding elements of `values`; all other elements
    /// are unchanged.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For assigning elements to a tensor along an axis, users should prefer the
    /// [`Tensor::select_assign`] function, which is more high-level and designed for public use.
    fn select_assign(
        tensor: Self::Primitive,
        dim: usize,
        indices: Tensor<B, 1, Int>,
        values: Self::Primitive,
    ) -> Self::Primitive;
2971
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the indices of the maximum elements from.
    /// * `dim` - The axis along which to get the indices of the maximum elements.
    ///
    /// # Returns
    ///
    /// An integer tensor where each element is the index, along the specified axis, of the
    /// maximum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
    /// [`Tensor::argmax`] function, which is more high-level and designed for public use.
    fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
2993
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the indices of the minimum elements from.
    /// * `dim` - The axis along which to get the indices of the minimum elements.
    ///
    /// # Returns
    ///
    /// An integer tensor where each element is the index, along the specified axis, of the
    /// minimum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
    /// [`Tensor::argmin`] function, which is more high-level and designed for public use.
    fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3015
    /// Gets the maximum element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum element from.
    ///
    /// # Returns
    ///
    /// A single-element tensor containing the maximum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum element of a tensor, users should prefer the
    /// [`Tensor::max`] function, which is more high-level and designed for public use.
    fn max(tensor: Self::Primitive) -> Self::Primitive;
3035
    /// Gets the maximum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements from.
    /// * `dim` - The axis along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor where each element is the maximum element of the input tensor along the
    /// specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum elements of a tensor along an axis, users should prefer the
    /// [`Tensor::max_dim`] function, which is more high-level and designed for public use.
    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3056
    /// Gets the maximum elements of a tensor along an axis, together with their indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements from.
    /// * `dim` - The axis along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tuple containing a tensor of the maximum elements of the input tensor along the specified
    /// axis, and an integer tensor holding the index of each maximum element along that axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum elements of a tensor along an axis, users should prefer the
    /// [`Tensor::max_dim_with_indices`] function, which is more high-level and designed for public use.
    fn max_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, B::IntTensorPrimitive);
3082
    /// Gets the minimum element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum element from.
    ///
    /// # Returns
    ///
    /// A single-element tensor containing the minimum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum element of a tensor, users should prefer the
    /// [`Tensor::min`] function, which is more high-level and designed for public use.
    fn min(tensor: Self::Primitive) -> Self::Primitive;
3102
    /// Gets the minimum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor where each element is the minimum element of the input tensor along the
    /// specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum elements of a tensor along an axis, users should prefer the
    /// [`Tensor::min_dim`] function, which is more high-level and designed for public use.
    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3124
    /// Gets the minimum elements and indices of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tuple containing a tensor of the minimum elements of the input tensor along the specified
    /// axis, and an integer tensor holding the index of each minimum element along that axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum elements of a tensor along an axis, users should prefer the
    /// [`Tensor::min_dim_with_indices`] function, which is more high-level and designed for public use.
    fn min_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, B::IntTensorPrimitive);
3149
    /// Clamp the tensor between the given min and max values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped between the given min and max values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor between the given min and max values, users should prefer the
    /// [`Tensor::clamp`] function, which is more high-level and designed for public use.
    fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
3169
    /// Clamps a tensor from below so that no element is smaller than the minimum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    ///
    /// # Returns
    ///
    /// A new tensor in which every element smaller than `min` is replaced by `min`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor under a minimum value, users should prefer the
    /// [`Tensor::clamp_min`] function, which is more high-level and designed for public use.
    fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
3189
    /// Clamps a tensor from above so that no element is larger than the maximum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor in which every element larger than `max` is replaced by `max`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor over a maximum value, users should prefer the
    /// [`Tensor::clamp_max`] function, which is more high-level and designed for public use.
    fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
3209
    /// Calculate the absolute value of all elements of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to apply abs to.
    ///
    /// # Returns
    ///
    /// A tensor with absolute values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For calculating abs of the elements of a tensor, users should prefer the [`Tensor::abs`] function,
    /// which is more high-level and designed for public use.
    fn abs(tensor: Self::Primitive) -> Self::Primitive;
3229
    /// Element-wise power of a tensor, where the exponents in `rhs` are treated as
    /// floating-point values.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply the power to.
    /// * `rhs` - The tensor of exponents, applied element-wise.
    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3236
    /// Element-wise power of a tensor with integer exponents.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply the power to.
    /// * `rhs` - The tensor of exponents, applied element-wise.
    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3243
    /// Element-wise power of a tensor to a scalar float.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply the power to.
    /// * `rhs` - The scalar exponent; any type implementing `ElementConversion`.
    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3250
    /// Element-wise power of a tensor to a scalar int.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The tensor to apply the power to.
    /// * `rhs` - The scalar exponent; any type implementing `ElementConversion`.
    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3257
    /// Create a random tensor.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the output tensor.
    /// * `distribution` - The distribution used to sample.
    /// * `device` - The device to use.
    ///
    /// # Returns
    ///
    /// A new tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [`Tensor::random`] function,
    /// which is more high-level and designed for public use.
    fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
3279
    /// Sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [`Tensor::sort`] function,
    /// which is more high-level and designed for public use.
    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
3302
    /// Sort the elements of the input `tensor` by value along a given dimension, also returning
    /// the originating indices.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor and corresponding indices, where
    /// the elements are sorted by value and the indices map back to the original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For sorting the elements of a tensor, users should prefer the
    /// [`Tensor::sort_with_indices`] function, which is more high-level
    /// and designed for public use.
    fn sort_with_indices(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive);
3331
    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, containing the indices that map each
    /// sorted position back to the original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [`Tensor::argsort`] function,
    /// which is more high-level and designed for public use.
    fn argsort(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> <Int as TensorKind<B>>::Primitive;
3358}
3359
3360impl<B: Backend> Numeric<B> for Int {
    /// Integer implementation of `Numeric::add`; forwards to `B::int_add`.
    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
        B::int_add(lhs, rhs)
    }
3364 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3365 B::int_add_scalar(lhs, rhs.elem())
3366 }
    /// Integer implementation of `Numeric::sub`; forwards to `B::int_sub`.
    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
        B::int_sub(lhs, rhs)
    }
3370 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3371 B::int_sub_scalar(lhs, rhs.elem())
3372 }
    /// Integer implementation of `Numeric::div`; forwards to `B::int_div`.
    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
        B::int_div(lhs, rhs)
    }
3376 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3377 B::int_div_scalar(lhs, rhs.elem())
3378 }
    /// Integer implementation of `Numeric::remainder`; forwards to `B::int_remainder`.
    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        B::int_remainder(lhs, rhs)
    }
3382 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3383 B::int_remainder_scalar(lhs, rhs.elem())
3384 }
    /// Integer implementation of `Numeric::mul`; forwards to `B::int_mul`.
    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
        B::int_mul(lhs, rhs)
    }
3388 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3389 B::int_mul_scalar(lhs, rhs.elem())
3390 }
    /// Integer implementation of `Numeric::neg`; forwards to `B::int_neg`.
    fn neg(tensor: Self::Primitive) -> Self::Primitive {
        B::int_neg(tensor)
    }
    /// Integer implementation of `Numeric::zeros`; forwards to `B::int_zeros`.
    fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
        B::int_zeros(shape, device)
    }
    /// Integer implementation of `Numeric::ones`; forwards to `B::int_ones`.
    fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
        B::int_ones(shape, device)
    }
3400 fn full<E: ElementConversion>(
3401 shape: Shape,
3402 fill_value: E,
3403 device: &B::Device,
3404 ) -> Self::Primitive {
3405 B::int_full(shape, fill_value.elem(), device)
3406 }
3407
    /// Integer implementation of `Numeric::sum`; forwards to `B::int_sum`.
    fn sum(tensor: Self::Primitive) -> Self::Primitive {
        B::int_sum(tensor)
    }
3411
    /// Integer implementation of `Numeric::sum_dim`; forwards to `B::int_sum_dim`.
    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_sum_dim(tensor, dim)
    }
3415
    /// Integer implementation of `Numeric::prod`; forwards to `B::int_prod`.
    fn prod(tensor: Self::Primitive) -> Self::Primitive {
        B::int_prod(tensor)
    }
3419
    /// Integer implementation of `Numeric::prod_dim`; forwards to `B::int_prod_dim`.
    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_prod_dim(tensor, dim)
    }
3423
    /// Integer implementation of `Numeric::mean`; forwards to `B::int_mean`.
    // NOTE(review): the result is an integer tensor, so rounding of the mean is whatever
    // `B::int_mean` does — confirm before relying on a specific rounding mode.
    fn mean(tensor: Self::Primitive) -> Self::Primitive {
        B::int_mean(tensor)
    }
    /// Integer implementation of `Numeric::mean_dim`; forwards to `B::int_mean_dim`.
    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_mean_dim(tensor, dim)
    }
3430
    /// Integer implementation of `Numeric::equal_elem`; forwards to `B::int_equal_elem`.
    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
        B::int_equal_elem(lhs, rhs)
    }
    /// Integer implementation of `Numeric::not_equal_elem`; forwards to `B::int_not_equal_elem`.
    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
        B::int_not_equal_elem(lhs, rhs)
    }
    /// Integer implementation of `Numeric::greater`; forwards to `B::int_greater`.
    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_greater(lhs, rhs)
    }
3440
    /// Integer implementation of `Numeric::greater_elem`; forwards to `B::int_greater_elem`.
    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
        B::int_greater_elem(lhs, rhs)
    }
3444
    /// Integer implementation of `Numeric::greater_equal`; forwards to `B::int_greater_equal`.
    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_greater_equal(lhs, rhs)
    }
3448
    /// Integer implementation of `Numeric::greater_equal_elem`; forwards to
    /// `B::int_greater_equal_elem`.
    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
        B::int_greater_equal_elem(lhs, rhs)
    }
3452
    /// Integer implementation of `Numeric::lower`; forwards to `B::int_lower`.
    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_lower(lhs, rhs)
    }
3456
    /// Integer implementation of `Numeric::lower_elem`; forwards to `B::int_lower_elem`.
    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
        B::int_lower_elem(lhs, rhs)
    }
3460
    /// Integer implementation of `Numeric::lower_equal`; forwards to `B::int_lower_equal`.
    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
        B::int_lower_equal(lhs, rhs)
    }
3464
    /// Integer implementation of `Numeric::lower_equal_elem`; forwards to
    /// `B::int_lower_equal_elem`.
    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
        B::int_lower_equal_elem(lhs, rhs)
    }
3468
    /// Integer implementation of `Numeric::mask_where`; forwards to `B::int_mask_where`
    /// with the arguments in the same order.
    fn mask_where(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        source: Self::Primitive,
    ) -> Self::Primitive {
        B::int_mask_where(tensor, mask, source)
    }
3476
    /// Integer implementation of `Numeric::mask_fill`; forwards to `B::int_mask_fill`.
    fn mask_fill(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        value: Self::Elem,
    ) -> Self::Primitive {
        B::int_mask_fill(tensor, mask, value)
    }
3484
3485 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3486 B::int_select(tensor, dim, indices.primitive)
3487 }
3488
3489 fn select_assign(
3490 tensor: Self::Primitive,
3491 dim: usize,
3492 indices: Tensor<B, 1, Int>,
3493 values: Self::Primitive,
3494 ) -> Self::Primitive {
3495 B::int_select_assign(tensor, dim, indices.primitive, values)
3496 }
    /// Integer implementation of `Numeric::gather`; forwards to `B::int_gather`.
    fn gather(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
    ) -> Self::Primitive {
        B::int_gather(dim, tensor, indices)
    }
3504
    /// Integer implementation of `Numeric::scatter`; forwards to `B::int_scatter`.
    fn scatter(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
        values: Self::Primitive,
    ) -> Self::Primitive {
        B::int_scatter(dim, tensor, indices, values)
    }
3513
    /// Integer implementation of `Numeric::argmax`; forwards to `B::int_argmax`.
    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        B::int_argmax(tensor, dim)
    }
3517
    /// Integer implementation of `Numeric::argmin`; forwards to `B::int_argmin`.
    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
        B::int_argmin(tensor, dim)
    }
3521
    /// Integer implementation of `Numeric::max`; forwards to `B::int_max`.
    fn max(tensor: Self::Primitive) -> Self::Primitive {
        B::int_max(tensor)
    }
3525
    /// Integer implementation of `Numeric::max_dim`; forwards to `B::int_max_dim`.
    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
        B::int_max_dim(tensor, dim)
    }
3529
    /// Integer implementation of `Numeric::max_dim_with_indices`; forwards to
    /// `B::int_max_dim_with_indices`.
    fn max_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, IntTensor<B>) {
        B::int_max_dim_with_indices(tensor, dim)
    }
3536
3537 fn min(tensor: Self::Primitive) -> Self::Primitive {
3538 B::int_min(tensor)
3539 }
3540
3541 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3542 B::int_min_dim(tensor, dim)
3543 }
3544
3545 fn min_dim_with_indices(
3546 tensor: Self::Primitive,
3547 dim: usize,
3548 ) -> (Self::Primitive, IntTensor<B>) {
3549 B::int_min_dim_with_indices(tensor, dim)
3550 }
3551
3552 fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
3553 B::int_clamp(tensor, min, max)
3554 }
3555
3556 fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
3557 B::int_clamp_min(tensor, min)
3558 }
3559
3560 fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
3561 B::int_clamp_max(tensor, max)
3562 }
3563
3564 fn abs(tensor: Self::Primitive) -> Self::Primitive {
3565 B::int_abs(tensor)
3566 }
3567
3568 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3569 B::int_powf(lhs, B::int_into_float(rhs))
3570 }
3571
3572 fn powf_scalar<E: ElementConversion>(
3573 lhs: Self::Primitive,
3574 rhs: E,
3575 ) -> <Int as TensorKind<B>>::Primitive {
3576 B::int_powf_scalar(lhs, rhs.elem())
3577 }
3578
3579 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3580 B::int_powi(lhs, rhs)
3581 }
3582
3583 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3584 B::int_powf_scalar(lhs, rhs.elem())
3585 }
3586
3587 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
3588 B::int_random(shape, distribution, device)
3589 }
3590
3591 fn sign(tensor: Self::Primitive) -> Self::Primitive {
3592 B::int_sign(tensor)
3593 }
3594
3595 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
3596 B::int_sort(tensor, dim, descending)
3597 }
3598
3599 fn sort_with_indices(
3600 tensor: Self::Primitive,
3601 dim: usize,
3602 descending: bool,
3603 ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
3604 B::int_sort_with_indices(tensor, dim, descending)
3605 }
3606
3607 fn argsort(
3608 tensor: Self::Primitive,
3609 dim: usize,
3610 descending: bool,
3611 ) -> <Int as TensorKind<B>>::Primitive {
3612 B::int_argsort(tensor, dim, descending)
3613 }
3614}
3615
3616impl<B: Backend> Numeric<B> for Float {
3617 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3618 match (lhs, rhs) {
3619 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3620 TensorPrimitive::Float(B::float_add(lhs, rhs))
3621 }
3622 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3623 TensorPrimitive::QFloat(B::q_add(lhs, rhs))
3624 }
3625 _ => panic!("Primitive type mismatch for lhs and rhs"),
3626 }
3627 }
3628 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3629 match lhs {
3630 TensorPrimitive::Float(lhs) => {
3631 TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
3632 }
3633 TensorPrimitive::QFloat(lhs) => {
3634 TensorPrimitive::QFloat(B::q_add_scalar(lhs, rhs.elem()))
3635 }
3636 }
3637 }
3638 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3639 match (lhs, rhs) {
3640 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3641 TensorPrimitive::Float(B::float_sub(lhs, rhs))
3642 }
3643 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3644 TensorPrimitive::QFloat(B::q_sub(lhs, rhs))
3645 }
3646 _ => panic!("Primitive type mismatch for lhs and rhs"),
3647 }
3648 }
3649 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3650 match lhs {
3651 TensorPrimitive::Float(lhs) => {
3652 TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
3653 }
3654 TensorPrimitive::QFloat(lhs) => {
3655 TensorPrimitive::QFloat(B::q_sub_scalar(lhs, rhs.elem()))
3656 }
3657 }
3658 }
3659 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3660 match (lhs, rhs) {
3661 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3662 TensorPrimitive::Float(B::float_div(lhs, rhs))
3663 }
3664 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3665 TensorPrimitive::QFloat(B::q_div(lhs, rhs))
3666 }
3667 _ => panic!("Primitive type mismatch for lhs and rhs"),
3668 }
3669 }
3670 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3671 match lhs {
3672 TensorPrimitive::Float(lhs) => {
3673 TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
3674 }
3675 TensorPrimitive::QFloat(lhs) => {
3676 TensorPrimitive::QFloat(B::q_div_scalar(lhs, rhs.elem()))
3677 }
3678 }
3679 }
3680 fn remainder(
3681 lhs: Self::Primitive,
3682 rhs: Self::Primitive,
3683 ) -> <Float as TensorKind<B>>::Primitive {
3684 match (lhs, rhs) {
3685 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3686 TensorPrimitive::Float(B::float_remainder(lhs, rhs))
3687 }
3688 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3689 TensorPrimitive::QFloat(B::q_remainder(lhs, rhs))
3690 }
3691 _ => panic!("Primitive type mismatch for lhs and rhs"),
3692 }
3693 }
3694 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3695 match lhs {
3696 TensorPrimitive::Float(lhs) => {
3697 TensorPrimitive::Float(B::float_remainder_scalar(lhs, rhs.elem()))
3698 }
3699 TensorPrimitive::QFloat(lhs) => {
3700 TensorPrimitive::QFloat(B::q_remainder_scalar(lhs, rhs.elem()))
3701 }
3702 }
3703 }
3704 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3705 match (lhs, rhs) {
3706 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3707 TensorPrimitive::Float(B::float_mul(lhs, rhs))
3708 }
3709 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3710 TensorPrimitive::QFloat(B::q_mul(lhs, rhs))
3711 }
3712 _ => panic!("Primitive type mismatch for lhs and rhs"),
3713 }
3714 }
3715 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3716 match lhs {
3717 TensorPrimitive::Float(lhs) => {
3718 TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
3719 }
3720 TensorPrimitive::QFloat(lhs) => {
3721 TensorPrimitive::QFloat(B::q_mul_scalar(lhs, rhs.elem()))
3722 }
3723 }
3724 }
3725 fn neg(tensor: Self::Primitive) -> Self::Primitive {
3726 match tensor {
3727 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
3728 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_neg(tensor)),
3729 }
3730 }
3731 fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
3732 TensorPrimitive::Float(B::float_zeros(shape, device))
3733 }
3734 fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
3735 TensorPrimitive::Float(B::float_ones(shape, device))
3736 }
3737
3738 fn full<E: ElementConversion>(
3739 shape: Shape,
3740 fill_value: E,
3741 device: &B::Device,
3742 ) -> Self::Primitive {
3743 TensorPrimitive::Float(B::float_full(shape, fill_value.elem(), device))
3744 }
3745
3746 fn sum(tensor: Self::Primitive) -> Self::Primitive {
3747 match tensor {
3748 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
3749 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum(tensor)),
3750 }
3751 }
3752
3753 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3754 match tensor {
3755 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
3756 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum_dim(tensor, dim)),
3757 }
3758 }
3759
3760 fn prod(tensor: Self::Primitive) -> Self::Primitive {
3761 match tensor {
3762 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
3763 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod(tensor)),
3764 }
3765 }
3766
3767 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3768 match tensor {
3769 TensorPrimitive::Float(tensor) => {
3770 TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
3771 }
3772 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod_dim(tensor, dim)),
3773 }
3774 }
3775
3776 fn mean(tensor: Self::Primitive) -> Self::Primitive {
3777 match tensor {
3778 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
3779 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean(tensor)),
3780 }
3781 }
3782
3783 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3784 match tensor {
3785 TensorPrimitive::Float(tensor) => {
3786 TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
3787 }
3788 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean_dim(tensor, dim)),
3789 }
3790 }
3791
3792 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3793 B::float_equal_elem(lhs.tensor(), rhs)
3794 }
3795 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3796 B::float_not_equal_elem(lhs.tensor(), rhs)
3797 }
3798 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3799 B::float_greater(lhs.tensor(), rhs.tensor())
3800 }
3801
3802 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3803 B::float_greater_elem(lhs.tensor(), rhs)
3804 }
3805
3806 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3807 B::float_greater_equal(lhs.tensor(), rhs.tensor())
3808 }
3809
3810 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3811 B::float_greater_equal_elem(lhs.tensor(), rhs)
3812 }
3813
3814 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3815 B::float_lower(lhs.tensor(), rhs.tensor())
3816 }
3817
3818 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3819 B::float_lower_elem(lhs.tensor(), rhs)
3820 }
3821
3822 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3823 B::float_lower_equal(lhs.tensor(), rhs.tensor())
3824 }
3825
3826 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3827 B::float_lower_equal_elem(lhs.tensor(), rhs)
3828 }
3829
3830 fn mask_where(
3831 tensor: Self::Primitive,
3832 mask: B::BoolTensorPrimitive,
3833 source: Self::Primitive,
3834 ) -> Self::Primitive {
3835 match (tensor, source) {
3836 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(source)) => {
3837 TensorPrimitive::Float(B::float_mask_where(tensor, mask, source))
3838 }
3839 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(source)) => {
3840 TensorPrimitive::QFloat(B::q_mask_where(tensor, mask, source))
3841 }
3842 _ => panic!("Primitive type mismatch for tensor and source"),
3843 }
3844 }
3845
3846 fn mask_fill(
3847 tensor: Self::Primitive,
3848 mask: B::BoolTensorPrimitive,
3849 value: Self::Elem,
3850 ) -> Self::Primitive {
3851 match tensor {
3852 TensorPrimitive::Float(tensor) => {
3853 TensorPrimitive::Float(B::float_mask_fill(tensor, mask, value))
3854 }
3855 TensorPrimitive::QFloat(tensor) => {
3856 TensorPrimitive::QFloat(B::q_mask_fill(tensor, mask, value))
3857 }
3858 }
3859 }
3860
3861 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3862 match tensor {
3863 TensorPrimitive::Float(tensor) => {
3864 TensorPrimitive::Float(B::float_select(tensor, dim, indices.primitive))
3865 }
3866 TensorPrimitive::QFloat(tensor) => {
3867 TensorPrimitive::QFloat(B::q_select(tensor, dim, indices.primitive))
3868 }
3869 }
3870 }
3871
3872 fn select_assign(
3873 tensor: Self::Primitive,
3874 dim: usize,
3875 indices: Tensor<B, 1, Int>,
3876 values: Self::Primitive,
3877 ) -> Self::Primitive {
3878 match (tensor, values) {
3879 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
3880 TensorPrimitive::Float(B::float_select_assign(
3881 tensor,
3882 dim,
3883 indices.primitive,
3884 values,
3885 ))
3886 }
3887 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
3888 TensorPrimitive::QFloat(B::q_select_assign(tensor, dim, indices.primitive, values))
3889 }
3890 _ => panic!("Primitive type mismatch for tensor and values"),
3891 }
3892 }
3893
3894 fn gather(
3895 dim: usize,
3896 tensor: Self::Primitive,
3897 indices: B::IntTensorPrimitive,
3898 ) -> Self::Primitive {
3899 match tensor {
3900 TensorPrimitive::Float(tensor) => {
3901 TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
3902 }
3903 TensorPrimitive::QFloat(tensor) => {
3904 TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
3905 }
3906 }
3907 }
3908
3909 fn scatter(
3910 dim: usize,
3911 tensor: Self::Primitive,
3912 indices: B::IntTensorPrimitive,
3913 values: Self::Primitive,
3914 ) -> Self::Primitive {
3915 match (tensor, values) {
3916 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
3917 TensorPrimitive::Float(B::float_scatter(dim, tensor, indices, values))
3918 }
3919 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
3920 TensorPrimitive::QFloat(B::q_scatter(dim, tensor, indices, values))
3921 }
3922 _ => panic!("Primitive type mismatch for tensor and values"),
3923 }
3924 }
3925
3926 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3927 match tensor {
3928 TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
3929 TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
3930 }
3931 }
3932
3933 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3934 match tensor {
3935 TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
3936 TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
3937 }
3938 }
3939
3940 fn max(tensor: Self::Primitive) -> Self::Primitive {
3941 match tensor {
3942 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
3943 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
3944 }
3945 }
3946
3947 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3948 match tensor {
3949 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
3950 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
3951 }
3952 }
3953
3954 fn max_dim_with_indices(
3955 tensor: Self::Primitive,
3956 dim: usize,
3957 ) -> (Self::Primitive, IntTensor<B>) {
3958 match tensor {
3959 TensorPrimitive::Float(tensor) => {
3960 let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
3961 (TensorPrimitive::Float(values), indices)
3962 }
3963 TensorPrimitive::QFloat(tensor) => {
3964 let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
3965 (TensorPrimitive::QFloat(values), indices)
3966 }
3967 }
3968 }
3969
3970 fn min(tensor: Self::Primitive) -> Self::Primitive {
3971 match tensor {
3972 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
3973 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
3974 }
3975 }
3976
3977 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3978 match tensor {
3979 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
3980 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
3981 }
3982 }
3983
3984 fn min_dim_with_indices(
3985 tensor: Self::Primitive,
3986 dim: usize,
3987 ) -> (Self::Primitive, IntTensor<B>) {
3988 match tensor {
3989 TensorPrimitive::Float(tensor) => {
3990 let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
3991 (TensorPrimitive::Float(values), indices)
3992 }
3993 TensorPrimitive::QFloat(tensor) => {
3994 let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
3995 (TensorPrimitive::QFloat(values), indices)
3996 }
3997 }
3998 }
3999
4000 fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
4001 match tensor {
4002 TensorPrimitive::Float(tensor) => {
4003 TensorPrimitive::Float(B::float_clamp(tensor, min, max))
4004 }
4005 TensorPrimitive::QFloat(tensor) => {
4006 TensorPrimitive::QFloat(B::q_clamp(tensor, min, max))
4007 }
4008 }
4009 }
4010
4011 fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
4012 match tensor {
4013 TensorPrimitive::Float(tensor) => {
4014 TensorPrimitive::Float(B::float_clamp_min(tensor, min))
4015 }
4016 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_min(tensor, min)),
4017 }
4018 }
4019
4020 fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
4021 match tensor {
4022 TensorPrimitive::Float(tensor) => {
4023 TensorPrimitive::Float(B::float_clamp_max(tensor, max))
4024 }
4025 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_max(tensor, max)),
4026 }
4027 }
4028
4029 fn abs(tensor: Self::Primitive) -> Self::Primitive {
4030 match tensor {
4031 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
4032 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
4033 }
4034 }
4035
4036 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4037 match (lhs, rhs) {
4038 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4039 TensorPrimitive::Float(B::float_powf(lhs, rhs))
4040 }
4041 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
4042 TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
4043 }
4044 _ => panic!("Primitive type mismatch for lhs and rhs"),
4045 }
4046 }
4047
4048 fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4049 match lhs {
4050 TensorPrimitive::Float(lhs) => {
4051 TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4052 }
4053 TensorPrimitive::QFloat(lhs) => {
4054 TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
4055 }
4056 }
4057 }
4058
4059 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4060 match (lhs, rhs) {
4061 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4062 TensorPrimitive::Float(B::float_powf(lhs, rhs))
4063 }
4064 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
4065 TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
4066 }
4067 _ => panic!("Primitive type mismatch for lhs and rhs"),
4068 }
4069 }
4070
4071 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4072 match lhs {
4073 TensorPrimitive::Float(lhs) => {
4074 TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4075 }
4076 TensorPrimitive::QFloat(lhs) => {
4077 TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
4078 }
4079 }
4080 }
4081
4082 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
4083 TensorPrimitive::Float(B::float_random(shape, distribution, device))
4084 }
4085
4086 fn sign(tensor: Self::Primitive) -> Self::Primitive {
4087 TensorPrimitive::Float(B::float_sign(tensor.tensor()))
4088 }
4089
4090 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
4091 match tensor {
4092 TensorPrimitive::Float(tensor) => {
4093 TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
4094 }
4095 TensorPrimitive::QFloat(tensor) => {
4096 TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
4097 }
4098 }
4099 }
4100
4101 fn sort_with_indices(
4102 tensor: Self::Primitive,
4103 dim: usize,
4104 descending: bool,
4105 ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
4106 match tensor {
4107 TensorPrimitive::Float(tensor) => {
4108 let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
4109 (TensorPrimitive::Float(values), indices)
4110 }
4111 TensorPrimitive::QFloat(tensor) => {
4112 let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
4113 (TensorPrimitive::QFloat(values), indices)
4114 }
4115 }
4116 }
4117
4118 fn argsort(
4119 tensor: Self::Primitive,
4120 dim: usize,
4121 descending: bool,
4122 ) -> <Int as TensorKind<B>>::Primitive {
4123 match tensor {
4124 TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
4125 TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
4126 }
4127 }
4128}
4129
4130impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
4131where
4132 B: Backend,
4133 K: Numeric<B>,
4134 K::Elem: Element,
4135{
4136 type Output = Self;
4137
4138 fn add(self, rhs: Tensor<B, D, K>) -> Self {
4139 Self::add(self, rhs)
4140 }
4141}
4142
4143impl<E, const D: usize, B, K> core::ops::Add<E> for Tensor<B, D, K>
4144where
4145 E: ElementConversion,
4146 B: Backend,
4147 K: Numeric<B>,
4148 K::Elem: Element,
4149{
4150 type Output = Self;
4151
4152 fn add(self, other: E) -> Self {
4153 Tensor::add_scalar(self, other)
4154 }
4155}
4156
4157impl<B, const D: usize, K> core::ops::Sub<Tensor<B, D, K>> for Tensor<B, D, K>
4158where
4159 B: Backend,
4160 K: Numeric<B>,
4161 K::Elem: Element,
4162{
4163 type Output = Self;
4164
4165 fn sub(self, rhs: Tensor<B, D, K>) -> Self {
4166 Tensor::sub(self, rhs)
4167 }
4168}
4169
4170impl<E, const D: usize, B, K> core::ops::Sub<E> for Tensor<B, D, K>
4171where
4172 E: ElementConversion,
4173 B: Backend,
4174 K: Numeric<B>,
4175 K::Elem: Element,
4176{
4177 type Output = Self;
4178
4179 fn sub(self, other: E) -> Self {
4180 Tensor::sub_scalar(self, other)
4181 }
4182}
4183
4184impl<B, const D: usize, K> core::ops::Div<Tensor<B, D, K>> for Tensor<B, D, K>
4185where
4186 B: Backend,
4187 K: Numeric<B>,
4188 K::Elem: Element,
4189{
4190 type Output = Self;
4191
4192 fn div(self, rhs: Tensor<B, D, K>) -> Self {
4193 Tensor::div(self, rhs)
4194 }
4195}
4196
4197impl<E, const D: usize, B, K> core::ops::Div<E> for Tensor<B, D, K>
4198where
4199 E: ElementConversion,
4200 B: Backend,
4201 K: Numeric<B>,
4202 K::Elem: Element,
4203{
4204 type Output = Self;
4205
4206 fn div(self, other: E) -> Self {
4207 Tensor::div_scalar(self, other)
4208 }
4209}
4210
4211impl<const D: usize, B, K> core::ops::Rem<Tensor<B, D, K>> for Tensor<B, D, K>
4212where
4213 B: Backend,
4214 K: Numeric<B>,
4215 K::Elem: Element,
4216{
4217 type Output = Self;
4218 fn rem(self, rhs: Tensor<B, D, K>) -> Self::Output {
4219 Tensor::remainder(self, rhs)
4220 }
4221}
4222
4223impl<E, const D: usize, B, K> core::ops::Rem<E> for Tensor<B, D, K>
4224where
4225 E: ElementConversion,
4226 B: Backend,
4227 K: Numeric<B>,
4228 K::Elem: Element,
4229{
4230 type Output = Self;
4231
4232 fn rem(self, other: E) -> Self {
4233 Tensor::remainder_scalar(self, other)
4234 }
4235}
4236
4237impl<B, const D: usize, K> core::ops::Mul<Tensor<B, D, K>> for Tensor<B, D, K>
4238where
4239 B: Backend,
4240 K: Numeric<B>,
4241 K::Elem: Element,
4242{
4243 type Output = Self;
4244
4245 fn mul(self, rhs: Tensor<B, D, K>) -> Self {
4246 Tensor::mul(self, rhs)
4247 }
4248}
4249
4250impl<E, const D: usize, B, K> core::ops::Mul<E> for Tensor<B, D, K>
4251where
4252 E: ElementConversion,
4253 B: Backend,
4254 K: Numeric<B>,
4255 K::Elem: Element,
4256{
4257 type Output = Self;
4258
4259 fn mul(self, other: E) -> Self {
4260 Tensor::mul_scalar(self, other)
4261 }
4262}
4263
4264impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
4265where
4266 B: Backend,
4267 K: Numeric<B>,
4268 K::Elem: Element,
4269{
4270 type Output = Self;
4271
4272 fn neg(self) -> Self {
4273 Tensor::neg(self)
4274 }
4275}