burn_tensor/tensor/api/numeric.rs
use alloc::vec::Vec;

use crate::{alloc::borrow::ToOwned, cast::ToElement};

use crate::TensorPrimitive;
use crate::{
    BasicOps, Bool, Distribution, Element, ElementConversion, Float, Int, Shape, Tensor,
    TensorKind,
    backend::Backend,
    check,
    check::TensorCheck,
    ops::{Device, IntTensor},
};

/// Default RTOL value for `is_close` and `all_close`.
pub const DEFAULT_RTOL: f64 = 1e-5;

/// Default ATOL value for `is_close` and `all_close`.
pub const DEFAULT_ATOL: f64 = 1e-8;

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    /// Applies element wise addition operation.
    ///
    /// `y = x1 + x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to add.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1 + tensor2;
    ///     println!("{tensor}");
    ///     // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Add", &self, &other));
        Self::new(K::add(self.primitive, other.primitive))
    }

    /// Applies element wise addition operation with a scalar.
    ///
    /// `y = x + s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to add, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let scalar = 2.0;
    ///     let tensor = tensor + scalar;
    ///     println!("{tensor}");
    ///     // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
    /// }
    /// ```
    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::add_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise subtraction operation.
    ///
    /// `y = x1 - x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to subtract.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1 - tensor2;
    ///     println!("{tensor}");
    ///     // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
        Self::new(K::sub(self.primitive, other.primitive))
    }

    /// Applies element wise subtraction operation with a scalar.
    ///
    /// `y = x - s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to subtract, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let scalar = 2.0;
    ///     let tensor = tensor - scalar;
    ///     println!("{tensor}");
    ///     // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
    /// }
    /// ```
    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::sub_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise division operation.
    ///
    /// `y = x1 / x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to divide by.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1 / tensor2;
    ///     println!("{tensor}");
    ///     // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn div(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Div", &self, &other));
        Self::new(K::div(self.primitive, other.primitive))
    }

    /// Applies element wise division operation with a scalar.
    ///
    /// `y = x / s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to divide by, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let scalar = 2.0;
    ///     let tensor = tensor / scalar;
    ///     println!("{tensor}");
    ///     // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
    /// }
    /// ```
    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::div_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise remainder operation.
    ///
    /// `y = x1 % x2`
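    ///
    /// # Example
    ///
    /// A minimal sketch; the values are illustrative only and are kept positive,
    /// where all remainder conventions agree.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[4.0, 7.0, 9.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 4.0]], &device);
    ///     let tensor = tensor1.remainder(tensor2);
    ///     println!("{tensor}");
    ///     // [[0.0, 1.0, 1.0], [0.0, 1.0, 2.0]]
    /// }
    /// ```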
    pub fn remainder(self, other: Self) -> Self {
        Self::new(K::remainder(self.primitive, other.primitive))
    }

    /// Applies element wise remainder operation with a scalar.
    ///
    /// `y = x % s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar divisor, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let scalar = 2.0;
    ///     let tensor = tensor % scalar;
    ///     println!("{tensor}");
    ///     // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
    /// }
    /// ```
    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::remainder_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise multiplication operation.
    ///
    /// `y = x1 * x2`
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to multiply.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1 * tensor2;
    ///     println!("{tensor}");
    ///     // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
        Self::new(K::mul(self.primitive, other.primitive))
    }

    /// Applies element wise multiplication operation with a scalar.
    ///
    /// `y = x * s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to multiply, element wise.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let scalar = 2.0;
    ///     let tensor = tensor * scalar;
    ///     println!("{tensor}");
    ///     // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
    /// }
    /// ```
    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::mul_scalar::<E>(self.primitive, other))
    }

    /// Switches the sign of each element in the tensor.
    ///
    /// `y = -x`
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = -tensor;
    ///     println!("{tensor}");
    ///     // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn neg(self) -> Self {
        Self::new(K::neg(self.primitive))
    }

    /// Returns the signs of the elements of the input tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.sign();
    ///     println!("{tensor}");
    ///     // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn sign(self) -> Self {
        Self::new(K::sign(self.primitive))
    }

    /// Create a tensor of the given shape where each element is zero.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
    ///     println!("{tensor}");
    ///     // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
        Self::new(K::zeros(shape, device))
    }

    /// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.zeros_like();
    ///     println!("{tensor}");
    ///     // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
    /// }
    /// ```
    pub fn zeros_like(&self) -> Self {
        Self::zeros(self.shape(), &self.device())
    }

    /// Create a tensor of the given shape where each element is one.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
    ///     println!("{tensor}");
    ///     // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
        Self::new(K::ones(shape, device))
    }

    /// Returns a new tensor with the same shape and device as the current tensor filled with ones.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.ones_like();
    ///     println!("{tensor}");
    ///     // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
    /// }
    /// ```
    pub fn ones_like(&self) -> Self {
        Self::ones(self.shape(), &self.device())
    }

    /// Create a tensor of the given shape where each element is equal to the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
    ///     println!("{tensor}");
    ///     // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full<S: Into<Shape>, E: ElementConversion>(
        shape: S,
        fill_value: E,
        device: &B::Device,
    ) -> Self {
        let shape = shape.into();
        check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
        Self::new(K::full(shape, fill_value, device))
    }

    /// Returns a new tensor with the same shape and device as the current tensor filled with the provided value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.full_like(5.0);
    ///     println!("{tensor}");
    ///     // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```
    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
        Self::full(self.shape(), fill_value, &self.device())
    }

    /// Aggregate all elements in the tensor with the mean operation.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.mean();
    ///     println!("{tensor}");
    ///     // [3.6666667]
    /// }
    /// ```
    pub fn mean(self) -> Tensor<B, 1, K> {
        Tensor::new(K::mean(self.primitive))
    }

    /// Aggregate all elements in the tensor with the sum operation.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.sum();
    ///     println!("{tensor}");
    ///     // [22.0]
    /// }
    /// ```
    pub fn sum(self) -> Tensor<B, 1, K> {
        Tensor::new(K::sum(self.primitive))
    }

    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the mean operation.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let mean_dim0 = tensor.clone().mean_dim(0);
    ///     println!("{mean_dim0}");
    ///     // [[3.0, 3.5, 4.5]]
    ///     let mean_dim1 = tensor.mean_dim(1);
    ///     println!("{mean_dim1}");
    ///     // [[0.6666667], [6.6666665]]
    /// }
    /// ```
    pub fn mean_dim(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
        Self::new(K::mean_dim(self.primitive, dim))
    }

    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the sum operation.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let sum_dim0 = tensor.clone().sum_dim(0);
    ///     println!("{sum_dim0}");
    ///     // [[6.0, 7.0, 9.0]]
    ///     let sum_dim1 = tensor.sum_dim(1);
    ///     println!("{sum_dim1}");
    ///     // [[2.0], [20.0]]
    /// }
    /// ```
    pub fn sum_dim(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
        Self::new(K::sum_dim(self.primitive, dim))
    }

    /// Aggregate all elements in the tensor with the product operation.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.prod();
    ///     println!("{tensor}");
    ///     // [-1620.0]
    /// }
    /// ```
    pub fn prod(self) -> Tensor<B, 1, K> {
        Tensor::new(K::prod(self.primitive))
    }

    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the product operation.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let prod_dim0 = tensor.clone().prod_dim(0);
    ///     println!("{prod_dim0}");
    ///     // [[5.0, -18.0, 18.0]]
    ///     let prod_dim1 = tensor.prod_dim(1);
    ///     println!("{prod_dim1}");
    ///     // [[-6.0], [270.0]]
    /// }
    /// ```
    pub fn prod_dim(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
        Self::new(K::prod_dim(self.primitive, dim))
    }

    /// Applies element wise equal comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.equal_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[false, false, true], [false, false, false]]
    /// }
    /// ```
    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::equal_elem(self.primitive, other.elem()))
    }

    /// Applies element wise non-equality comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.not_equal_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[true, true, false], [true, true, true]]
    /// }
    /// ```
    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
    }

    /// Applies element wise greater comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.greater(tensor2);
    ///     println!("{tensor}");
    ///     // [[false, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
        Tensor::new(K::greater(self.primitive, other.primitive))
    }

    /// Applies element wise greater-equal comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.greater_equal(tensor2);
    ///     println!("{tensor}");
    ///     // [[true, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
        Tensor::new(K::greater_equal(self.primitive, other.primitive))
    }

    /// Applies element wise lower comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.lower(tensor2);
    ///     println!("{tensor}");
    ///     // [[false, true, true], [false, false, false]]
    /// }
    /// ```
    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
        Tensor::new(K::lower(self.primitive, other.primitive))
    }

    /// Applies element wise lower-equal comparison and returns a boolean tensor.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have the same shape.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.lower_equal(tensor2);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [false, false, false]]
    /// }
    /// ```
    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
        Tensor::new(K::lower_equal(self.primitive, other.primitive))
    }

    /// Applies greater than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.greater_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[false, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::greater_elem(self.primitive, other.elem()))
    }

    /// Applies greater-equal than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.greater_equal_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[false, false, true], [true, true, true]]
    /// }
    /// ```
    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
    }

    /// Applies lower than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.lower_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[true, true, false], [false, false, false]]
    /// }
    /// ```
    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::lower_elem(self.primitive, other.elem()))
    }

    /// Applies lower-equal than `other` comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.lower_equal_elem(3.0);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [false, false, false]]
    /// }
    /// ```
    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
    }

    /// Update the given tensor with the value tensor where the mask is true.
    ///
    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
    /// a scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///     let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor.mask_where(mask, value);
    ///     println!("{tensor}");
    ///     // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
    /// }
    /// ```
    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
        Self::new(K::mask_where(
            self.primitive,
            mask.primitive,
            value.primitive,
        ))
    }

    /// Update the given tensor with the value where the mask is true.
    ///
    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
    /// a tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Bool};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
    ///     let tensor = tensor.mask_fill(mask, 3.0);
    ///     println!("{tensor}");
    ///     // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
    /// }
    /// ```
    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
        Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
    }

    /// Gather tensor elements corresponding to the given indices from the specified dim.
    ///
    /// Example using a 3D tensor:
    ///
    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
    ///
    /// # Notes
    ///
    /// The index tensor should have the same shape as the original tensor except for the dim
    /// specified.
    ///
    /// # Warning
    ///
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
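    ///
    /// # Example
    ///
    /// A minimal sketch of gathering along `dim = 1`; the values are illustrative
    /// only, and the expected output follows the formula above.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     let indices = Tensor::<B, 2, Int>::from_data([[2, 0], [1, 1]], &device);
    ///     // output[i, j] = input[i, indices[i, j]]
    ///     let tensor = tensor.gather(1, indices);
    ///     println!("{tensor}");
    ///     // [[3.0, 1.0], [5.0, 5.0]]
    /// }
    /// ```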
    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
        check!(TensorCheck::gather::<D>(
            dim,
            &self.shape(),
            &indices.shape()
        ));

        Self::new(K::gather(dim, self.primitive, indices.primitive))
    }

    /// Assign the gathered elements corresponding to the given indices along the specified dimension
    /// from the value tensor to the original tensor using sum reduction.
    ///
    /// Example using a 3D tensor:
    ///
    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
    ///
    /// # Notes
    ///
    /// The index tensor should have the same shape as the original tensor except for the specified
    /// dimension. The value and index tensors should have the same shape.
    ///
    /// Other references to the input tensor will not be modified by this operation.
    ///
    /// # Warning
    ///
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
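    ///
    /// # Example
    ///
    /// A minimal sketch of scattering along `dim = 1`; the values are illustrative
    /// only, and the expected output follows the formula above.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
    ///     let indices = Tensor::<B, 2, Int>::from_data([[0, 2], [1, 0]], &device);
    ///     let values = Tensor::<B, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
    ///     // input[i, indices[i, j]] += values[i, j]
    ///     let tensor = tensor.scatter(1, indices, values);
    ///     println!("{tensor}");
    ///     // [[1.0, 0.0, 2.0], [4.0, 3.0, 0.0]]
    /// }
    /// ```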
    pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
        check!(TensorCheck::scatter::<D>(
            dim,
            &self.shape(),
            &indices.shape(),
            &values.shape()
        ));

        Self::new(K::scatter(
            dim,
            self.primitive,
            indices.primitive,
            values.primitive,
        ))
    }

    /// Select the tensor elements along the given dimension corresponding to the given indices.
    ///
    /// Example using a 3D tensor:
    ///
    /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0`
    /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1`
    /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2`
    ///
    /// # Warning
    ///
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let indices = Tensor::<B, 1, Int>::from_data([0], &device);
    ///     let tensor = tensor.select(0, indices);
    ///     println!("{tensor}");
    ///     // [[1.0, -2.0, 3.0]]
    /// }
    /// ```
    pub fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
        check!(TensorCheck::select::<D>(dim));
        Self::new(K::select(self.primitive, dim, indices))
    }

    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// from the value tensor to the original tensor using sum reduction.
    ///
    /// Example using a 3D tensor:
    ///
    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
    ///
    /// # Warning
    ///
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
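    ///
    /// # Example
    ///
    /// A minimal sketch of assigning along `dim = 0`; the values are illustrative
    /// only, and the expected output follows the formula above.
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    ///     let indices = Tensor::<B, 1, Int>::from_data([0, 0], &device);
    ///     let values = Tensor::<B, 2>::from_data([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], &device);
    ///     // input[indices[i], j] += values[i, j]; both rows of `values` land in row 0.
    ///     let tensor = tensor.select_assign(0, indices, values);
    ///     println!("{tensor}");
    ///     // [[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]]
    /// }
    /// ```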
    pub fn select_assign(
        self,
        dim: usize,
        indices: Tensor<B, 1, Int>,
        values: Tensor<B, D, K>,
    ) -> Self {
        check!(TensorCheck::select_assign::<D>(dim));

        Self::new(K::select_assign(
            self.primitive,
            dim,
            indices,
            values.primitive,
        ))
    }

    /// Applies the argmax function along the given dimension and returns an integer tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
    ///     let tensor = tensor.argmax(1);
    ///     println!("{:?}", tensor.shape());
    ///     // Shape { dims: [2, 1, 3] }
    /// }
    /// ```
    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
        Tensor::new(K::argmax(self.primitive, dim))
    }

    /// Find the maximum value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.max();
    ///     println!("{tensor}");
    ///     // [9.0]
    /// }
    /// ```
    pub fn max(self) -> Tensor<B, 1, K> {
        Tensor::new(K::max(self.primitive))
    }

    /// Find the maximum value along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.max_dim(0);
    ///     println!("{tensor}");
    ///     // [[5.0, 9.0, 6.0]]
    /// }
    /// ```
    pub fn max_dim(self, dim: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        Tensor::new(K::max_dim(self.primitive, dim))
    }

    /// Find the maximum value along the given dimension.
    ///
    /// Also returns the indices.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let (tensor, index) = tensor.max_dim_with_indices(0);
    ///     println!("{tensor}");
    ///     // [[5.0, 9.0, 6.0]]
    ///     println!("{index}");
    ///     // [[1, 1, 1]]
    /// }
    /// ```
    pub fn max_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
        check!(TensorCheck::aggregate_dim::<D>("Max", dim));

        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }

    /// Finds the maximum pair wise values with another tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - Other tensor to find maximum elements with.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors containing the maximum value found
    /// in the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.max_pair(tensor2);
    ///     println!("{tensor}");
    ///     // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
    /// }
    /// ```
    pub fn max_pair(self, other: Self) -> Self {
        let mask = self.clone().lower(other.clone());
        self.mask_where(mask, other)
    }

    /// Find the maximum absolute value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
    ///     let tensor = tensor.max_abs();
    ///     println!("{tensor}");
    ///     // [7.0]
    /// }
    /// ```
    pub fn max_abs(self) -> Tensor<B, 1, K> {
        Tensor::new(K::max_abs(self.primitive))
    }

    /// Find the maximum absolute value along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);
    ///     let tensor = tensor.max_abs_dim(0);
    ///     println!("{tensor}");
    ///     // [[5.0, 7.0, 6.0]]
    /// }
    /// ```
    pub fn max_abs_dim(self, dim: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::aggregate_dim::<D>("MaxAbs", dim));

        Tensor::new(K::max_abs_dim(self.primitive, dim))
    }

    /// Applies the argmin function along the given dimension and returns an integer tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
    ///     let tensor = tensor.argmin(1);
    ///     println!("{:?}", tensor.shape());
    ///     // Shape { dims: [2, 1, 3] }
    /// }
    /// ```
    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
        Tensor::new(K::argmin(self.primitive, dim))
    }

    /// Find the minimum value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.min();
    ///     println!("{tensor}");
    ///     // [-2.0]
    /// }
    /// ```
    pub fn min(self) -> Tensor<B, 1, K> {
        Tensor::new(K::min(self.primitive))
    }

    /// Find the minimum value along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.min_dim(0);
    ///     println!("{tensor}");
    ///     // [[1.0, -2.0, 3.0]]
    /// }
    /// ```
    pub fn min_dim(self, dim: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
        Tensor::new(K::min_dim(self.primitive, dim))
    }

    /// Find the minimum value along the given dimension.
    ///
    /// Also returns the indices.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let (tensor, index) = tensor.min_dim_with_indices(0);
    ///     println!("{tensor}");
    ///     // [[5.0, -2.0, 3.0]]
    ///     println!("{index}");
    ///     // [[1, 0, 0]]
    /// }
    /// ```
    pub fn min_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));

        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }

    /// Finds the minimum pair wise values with another tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - Other tensor to find minimum elements with.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors containing the minimum value found
    /// between each element of the two source tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.min_pair(tensor2);
    ///     println!("{tensor}");
    ///     // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
    /// }
    /// ```
    pub fn min_pair(self, other: Self) -> Self {
        let mask = other.clone().lower(self.clone());
        self.mask_where(mask, other)
    }

    /// Clamp element wise between the given min and max values.
    ///
    /// # Arguments
    ///
    /// * `min` - The minimum value.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped between the given min and max values.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
    ///         [
    ///             [1, 2, 3],
    ///             [4, 5, 6],
    ///             [7, 8, 9]
    ///         ],
    ///         &device);
    ///     let tensor = tensor.clamp(2, 6);
    ///     println!("{tensor}");
    ///     // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
    /// }
    /// ```
    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
        Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
    }

    /// Clamp element wise to a minimum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    ///
    /// # Returns
    ///
    /// A new tensor with all values at or above the given min value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
    ///         [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ///         &device);
    ///     let tensor = tensor.clamp_min(4);
    ///     println!("{tensor}");
    ///     // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
    /// }
    /// ```
    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
        Self::new(K::clamp_min(self.primitive, min.elem()))
    }

    /// Clamp element wise to a maximum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor with all values at or below the given max value.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
    ///         [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    ///         &device);
    ///     let tensor = tensor.clamp_max(5);
    ///     println!("{tensor}");
    ///     // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
    /// }
    /// ```
    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
        Self::new(K::clamp_max(self.primitive, max.elem()))
    }

    /// Apply element wise absolute value operation.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
    ///     let tensor = tensor.abs();
    ///     println!("{tensor}");
    ///     // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    /// }
    /// ```
    pub fn abs(self) -> Self {
        Self::new(K::abs(self.primitive))
    }

    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices;
    /// the other elements of the result tensor are set to 0.
    ///
    /// See also [`triu_mask`](Tensor::triu_mask).
    ///
    /// # Arguments
    ///
    /// * `diagonal` - The offset from the main diagonal, where 0 means the main diagonal, and
    ///   positive values shift towards the upper triangle.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
    ///         [
    ///             [1, 2, 3],
    ///             [4, 5, 6],
    ///             [7, 8, 9]
    ///         ],
    ///         &device
    ///     );
    ///     let tensor = tensor.triu(1);
    ///     println!("{tensor}");
    ///     // [
    ///     //     [0, 2, 3],
    ///     //     [0, 0, 6],
    ///     //     [0, 0, 0]
    ///     // ]
    /// }
    /// ```
    pub fn triu(self, diagonal: i64) -> Self {
        check!(TensorCheck::tri::<{ D }>());

        // last two dimensions
        let shape = &self.shape().dims[D - 2..].to_owned();

        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
        self.mask_fill(mask, 0)
    }

    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices;
    /// the other elements of the result tensor are set to 0.
    ///
    /// See also [`tril_mask`](Tensor::tril_mask).
    ///
    /// # Arguments
    ///
    /// * `diagonal` - The offset from the main diagonal, where 0 means the main diagonal, and
    ///   positive values shift towards the upper triangle.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = Default::default();
    ///     let tensor = Tensor::<B, 2, Int>::from_ints(
    ///         [
    ///             [1, 2, 3],
    ///             [4, 5, 6],
    ///             [7, 8, 9]
    ///         ],
    ///         &device
    ///     );
    ///
    ///     let tensor = tensor.tril(-1);
    ///     println!("{tensor}");
    ///     // [
    ///     //     [0, 0, 0],
    ///     //     [4, 0, 0],
    ///     //     [7, 8, 0]
    ///     // ]
    /// }
    /// ```
    pub fn tril(self, diagonal: i64) -> Self {
        check!(TensorCheck::tri::<{ D }>());

        // last two dimensions
        let shape = &self.shape().dims[D - 2..].to_owned();
        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();

        self.mask_fill(mask, 0)
    }

    /// Applies element wise power operation with a float tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///     let tensor = tensor1.powf(tensor2);
    ///     println!("{tensor}");
    ///     // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
    /// }
    /// ```
    pub fn powf(self, other: Self) -> Self {
        Self::new(K::powf(self.primitive, other.primitive))
    }

    /// Applies element wise power operation with a float scalar.
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.powf_scalar(2.0);
    ///     println!("{tensor}");
    ///     // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
    /// }
    /// ```
    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::powf_scalar::<E>(self.primitive, other))
    }

    /// Applies element wise power operation with an integer tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
    ///     let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
    ///     let tensor = tensor1.powi(tensor2);
    ///     println!("{tensor}");
    ///     // [[1, -8, 81], [5, 81, 216]]
    /// }
    /// ```
    pub fn powi(self, other: Self) -> Self {
        Self::new(K::powi(self.primitive, other.primitive))
    }

    /// Applies element wise power operation with an integer scalar.
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Int};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
    ///     let tensor = tensor.powi_scalar(2);
    ///     println!("{tensor}");
    ///     // [[1, 4, 9], [25, 81, 36]]
    ///
    ///     let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);
    ///     let tensor = tensor.powi_scalar(2);
    ///     println!("{tensor}");
    ///     // [[2.25, 4., 9.], [25., 81., 36.]]
    /// }
    /// ```
    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::powi_scalar::<E>(self.primitive, other))
    }

    /// Checks element wise if the tensor is close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor1.is_close(tensor2, None, None);
    ///     println!("{tensor}");
    ///     // [[true, true, true], [true, true, true]]
    /// }
    /// ```
    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
        let rtol = rtol.unwrap_or(DEFAULT_RTOL);
        let atol = atol.unwrap_or(DEFAULT_ATOL);

        Tensor::new(K::lower_equal(
            K::abs(K::sub(self.primitive, other.primitive.clone())),
            K::add_scalar(K::mul_scalar(K::abs(other.primitive), rtol), atol),
        ))
    }

    /// Checks if all elements are close to another tensor.
    ///
    /// The tolerance is defined by the following equation:
    ///
    /// ```text
    /// abs(a - b) <= (atol + rtol * abs(b))
    /// ```
    ///
    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
    /// and `atol` is the absolute tolerance.
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to compare with.
    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.
    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.
    ///
    /// # Returns
    ///
    /// A boolean scalar.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let result = tensor1.all_close(tensor2, None, None);
    ///     println!("{result}");
    ///     // true
    /// }
    /// ```
    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
        self.is_close(other, rtol, atol)
            .all()
            .into_scalar()
            .to_bool()
    }

    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
    ///     let tensor = tensor.bool();
    ///     println!("{tensor}");
    ///     // [
    ///     //     [true, true, true],
    ///     //     [false, true, true]
    ///     // ]
    /// }
    /// ```
    pub fn bool(self) -> Tensor<B, D, Bool> {
        Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
    }

    /// Create a random tensor of the given shape on the given device where each element is
    /// sampled from the given distribution.
    ///
    /// See also [`random_like`](Tensor::random_like).
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `distribution` - The distribution to sample from.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// A new tensor with the given shape and elements sampled from the given distribution.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Distribution};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
    ///     let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
    ///     println!("{tensor}");
    ///     // [
    ///     //     [0.08347523, 0.70498955, 0.60332155],
    ///     //     [0.08173251, 0.18028641, 0.97942924]
    ///     // ]
    /// }
    /// ```
    pub fn random<S: Into<Shape>>(
        shape: S,
        distribution: Distribution,
        device: &B::Device,
    ) -> Self {
        Self::new(K::random(shape.into(), distribution, device))
    }

1780 /// Sort the elements by value in ascending order along a given dimension.
1781 ///
1782 /// This sort is unstable (i.e., may reorder equal elements).
1783 ///
1784 /// # Arguments
1785 ///
1786 /// * `dim` - The dimension to sort along.
1787 ///
1788 /// # Returns
1789 ///
1790 /// A new tensor with the elements sorted in ascending order along the given dimension.
1791 ///
1792 /// # Example
1793 ///
1794 /// ```rust
1795 /// use burn_tensor::backend::Backend;
1796 /// use burn_tensor::{Tensor, Shape};
1797 ///
1798 /// fn example<B: Backend>() {
1799 /// let device = B::Device::default();
1800 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let sorted = tensor.clone().sort(0);
/// println!("{sorted}");
/// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
/// let sorted = tensor.sort(1);
/// println!("{sorted}");
/// // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]
1807 /// }
1808 /// ```
1809 pub fn sort(self, dim: usize) -> Tensor<B, D, K> {
1810 check!(TensorCheck::sort_dim::<D>("Sort", dim));
1811 Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
1812 }
1813
1814 /// Sort the elements by value in descending order along a given dimension.
1815 ///
1816 /// This sort is unstable (i.e., may reorder equal elements).
1817 ///
1818 /// # Arguments
1819 ///
1820 /// * `dim` - The dimension to sort along.
1821 ///
1822 /// # Returns
1823 ///
1824 /// A new tensor with the elements sorted in descending order along the given dimension.
1825 ///
1826 /// # Example
1827 ///
1828 /// ```rust
1829 /// use burn_tensor::backend::Backend;
1830 /// use burn_tensor::{Tensor, Shape};
1831 ///
1832 /// fn example<B: Backend>() {
1833 /// let device = B::Device::default();
1834 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let sorted = tensor.clone().sort_descending(0);
/// println!("{sorted}");
/// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
/// let sorted = tensor.sort_descending(1);
/// println!("{sorted}");
/// // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]
1841 /// }
1842 /// ```
1843 pub fn sort_descending(self, dim: usize) -> Tensor<B, D, K> {
1844 check!(TensorCheck::sort_dim::<D>("Sort", dim));
1845 Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
1846 }
1847
1848 /// Sort the elements by value in ascending order along a given dimension.
1849 /// Also returns the indices.
1850 ///
1851 /// This sort is unstable (i.e., may reorder equal elements).
1852 ///
1853 /// # Arguments
1854 ///
1855 /// * `dim` - The dimension to sort along.
1856 ///
1857 /// # Returns
1858 ///
1859 /// A tuple containing the sorted tensor and the indices tensor.
1860 ///
1861 /// # Example
1862 ///
1863 /// ```rust
1864 /// use burn_tensor::backend::Backend;
1865 /// use burn_tensor::{Tensor, Shape};
1866 ///
1867 /// fn example<B: Backend>() {
1868 /// let device = B::Device::default();
1869 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1870 /// let (tensor, indices) = tensor.sort_with_indices(0);
1871 /// println!("{tensor}");
1872 /// // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1873 /// println!("{}", indices);
1874 /// // [[1, 0, 0], [0, 1, 1]]
1875 /// }
1876 /// ```
1877 pub fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1878 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1879 let (values, indices) =
1880 K::sort_with_indices(self.primitive, dim, /*descending*/ false);
1881 (Tensor::new(values), Tensor::new(indices))
1882 }
1883
1884 /// Sort the elements by value in descending order along a given dimension.
1885 /// Also returns the indices.
1886 ///
1887 /// This sort is unstable (i.e., may reorder equal elements).
1888 ///
1889 /// # Arguments
1890 ///
1891 /// * `dim` - The dimension to sort along.
1892 ///
1893 /// # Example
1894 ///
1895 /// ```rust
1896 /// use burn_tensor::backend::Backend;
1897 /// use burn_tensor::{Tensor, Shape};
1898 ///
1899 /// fn example<B: Backend>() {
1900 /// let device = B::Device::default();
1901 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1902 /// let (tensor, indices) = tensor.sort_descending_with_indices(0);
1903 /// println!("{tensor}");
1904 /// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1905 /// println!("{}", indices);
1906 /// // [[0, 1, 1], [1, 0, 0]]
1907 /// }
1908 /// ```
1909 pub fn sort_descending_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1910 check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1911 let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
1912 (Tensor::new(values), Tensor::new(indices))
1913 }
1914
1915 /// Returns the indices that sort the elements by value in ascending order along a given dimension.
1916 ///
1917 /// This sort is unstable (i.e., may reorder equal elements).
1918 ///
1919 /// # Arguments
1920 ///
1921 /// * `dim` - The dimension to sort along.
1922 ///
1923 /// # Example
1924 ///
1925 /// ```rust
1926 /// use burn_tensor::backend::Backend;
1927 /// use burn_tensor::{Tensor, Shape};
1928 ///
1929 /// fn example<B: Backend>() {
1930 /// let device = B::Device::default();
1931 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1932 /// let tensor = tensor.argsort(0);
1933 /// println!("{tensor}");
1934 /// // [[1, 0, 0], [0, 1, 1]]
1935 /// }
1936 /// ```
1937 pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
1938 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1939 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
1940 }
1941
1942 /// Returns the indices that sort the elements by value in descending order along a given dimension.
1943 ///
1944 /// This sort is unstable (i.e., may reorder equal elements).
1945 ///
1946 /// # Arguments
1947 ///
1948 /// * `dim` - The dimension to sort along.
1949 ///
1950 /// # Example
1951 ///
1952 /// ```rust
1953 /// use burn_tensor::backend::Backend;
1954 /// use burn_tensor::{Tensor, Shape};
1955 ///
1956 /// fn example<B: Backend>() {
1957 /// let device = B::Device::default();
1958 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let indices = tensor.clone().argsort_descending(0);
/// println!("{indices}");
/// // [[0, 1, 1], [1, 0, 0]]
/// let indices = tensor.argsort_descending(1);
/// println!("{indices}");
/// // [[0, 2, 1], [2, 0, 1]]
1965 /// }
1966 /// ```
1967 pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
1968 check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1969 Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
1970 }
1971
1972 /// Returns the `k` largest elements of the given input tensor along a given dimension.
1973 ///
1974 /// # Arguments
1975 ///
/// * `k` - The number of elements to return.
/// * `dim` - The dimension along which to take the top `k` elements.
1977 ///
1978 /// # Returns
1979 ///
1980 /// A new tensor with the `k` largest elements along the given dimension.
1981 ///
1982 /// # Example
1983 ///
1984 /// ```rust
1985 /// use burn_tensor::backend::Backend;
1986 /// use burn_tensor::{Tensor, Shape};
1987 ///
1988 /// fn example<B: Backend>() {
1989 /// let device = B::Device::default();
1990 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let top = tensor.clone().topk(2, 0);
/// println!("{top}");
/// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
/// let top = tensor.topk(1, 1);
/// println!("{top}");
/// // [[12.0], [6.0]]
1997 /// }
1998 /// ```
1999 pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
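// A simple implementation: sort everything in descending order, then keep the
// first `k` entries along `dim` (assumes `k <= self.dims()[dim]`).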
2000 let k_indices = Tensor::arange(0..k as i64, &self.device());
2001 self.sort_descending(dim).select(dim, k_indices)
2002 }
2003
2004 /// Returns the `k` largest elements of the given input tensor along a given dimension.
2005 /// Also returns the indices.
2006 ///
2007 /// # Arguments
2008 ///
2009 /// * `k` - The number of elements to return.
2010 /// * `dim` - The dimension to sort along.
2011 ///
2012 /// # Example
2013 ///
2014 /// ```rust
2015 /// use burn_tensor::backend::Backend;
2016 /// use burn_tensor::{Tensor, Shape};
2017 ///
2018 /// fn example<B: Backend>() {
2019 /// let device = B::Device::default();
2020 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
/// let (values, indices) = tensor.clone().topk_with_indices(2, 0);
/// println!("{values}");
/// // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
/// println!("{indices}");
/// // [[0, 1, 1], [1, 0, 0]]
/// let (values, indices) = tensor.topk_with_indices(1, 1);
/// println!("{values}");
/// // [[12.0], [6.0]]
/// println!("{indices}");
/// // [[0], [2]]
2031 /// }
2032 /// ```
2033 pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
2034 let k_indices = Tensor::arange(0..k as i64, &self.device());
2035 let (values, indices) = self.sort_descending_with_indices(dim);
2036 (
2037 values.select(dim, k_indices.clone()),
2038 indices.select(dim, k_indices),
2039 )
2040 }
2041
2042 /// Pad the tensor of rank two or higher with the given value on the last two dimensions.
2043 ///
2044 /// # Arguments
2045 ///
2046 /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
2047 /// * `value` - The value to pad the tensor with.
2048 ///
2049 /// # Returns
2050 ///
2051 /// A new tensor with the given padding.
2052 ///
2053 /// # Example
2054 ///
2055 /// ```rust
2056 /// use burn_tensor::backend::Backend;
2057 /// use burn_tensor::{Tensor, Shape};
2058 ///
2059 /// fn example<B: Backend<FloatElem: From<f32>>>() {
2060 /// let device = B::Device::default();
2061 /// let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
2062 /// let tensor = tensor.pad((1, 1, 1, 1), 0.0);
2063 /// println!("{tensor}");
2064 /// // [
2065 /// // [0.0, 0.0, 0.0, 0.0, 0.0],
2066 /// // [0.0, 12.0, -2.0, 3.0, 0.0],
2067 /// // [0.0, 5.0, 3.0, 6.0, 0.0],
2068 /// // [0.0, 0.0, 0.0, 0.0, 0.0]
2069 /// // ]
2070 /// }
2071 /// ```
2072 pub fn pad<E: ElementConversion>(
2073 self,
2074 padding: (usize, usize, usize, usize),
2075 value: E,
2076 ) -> Tensor<B, D, K> {
2077 let (left, right, top, bottom) = padding;
2078
2079 let mut padded_dims: [usize; D] = self.dims();
2080
2081 // Update the last two dimensions with padding
2082 padded_dims[D - 2] += top + bottom;
2083 padded_dims[D - 1] += left + right;
2084
2085 // Create the ranges for the padded tensor
2086 let ranges: [core::ops::Range<usize>; D] = padded_dims
2087 .iter()
2088 .enumerate()
2089 .map(|(i, &dim)| {
2090 if i == D - 2 {
2091 top..dim - bottom
2092 } else if i == D - 1 {
2093 left..dim - right
2094 } else {
2095 0..dim
2096 }
2097 })
2098 .collect::<Vec<core::ops::Range<usize>>>()
2099 .try_into()
2100 .unwrap();
2101
2102 // Create the padded tensor
2103 let padded_tensor = Tensor::full(padded_dims, value, &self.device());
2104
2105 // Assign the original tensor data to the appropriate slice of the padded tensor
2106 padded_tensor.slice_assign(ranges, self)
}

/// Create a one hot tensor.
2109 ///
2110 /// # Example
2111 ///
2112 /// ```rust
2113 /// use burn_tensor::backend::Backend;
2114 /// use burn_tensor::Tensor;
2115 ///
/// fn example<B: Backend>() {
2117 /// let device = Default::default();
2118 /// let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);
2119 /// let one_hot: Tensor<B, 2> = indices.one_hot(4);
2120 /// println!("{}", one_hot.to_data());
2121 /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
2122 /// }
2123 /// ```
2124 pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
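// One-hot encode via `one_hot_fill` with on/off values of 1.0/0.0, appending the
// class dimension as the last axis (axis = -1).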
2125 check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
2126 self.one_hot_fill(num_classes, 1.0, 0.0, -1)
2127 }
2128
/// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`,
/// and `axis`, supporting inputs of any rank.
2130 ///
2131 /// # Arguments
2132 ///
2133 /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.
2134 /// * `on_value`: The value to assign for active positions (corresponding to indices).
2135 /// * `off_value`: The value to assign for inactive positions.
2136 /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.
2137 ///
2138 /// # Returns
2139 ///
2140 /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.
2141 ///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Float};
///
/// fn example<B: Backend<FloatElem: From<f32>>>() {
/// let device = B::Device::default();
/// let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);
/// // One-hot encoding
/// let tensor: Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0, 0.0, -1);
2151 /// println!("{tensor}");
2152 /// // [[[5.0, 0.0, 0.0],
2153 /// // [0.0, 0.0, 5.0]],
2154 /// // [[0.0, 5.0, 0.0],
2155 /// // [0.0, 0.0, 5.0]]]
2156 /// }
2157 /// ```
2158 pub fn one_hot_fill<const D2: usize>(
2159 self,
2160 num_classes: usize,
2161 on_value: f32,
2162 off_value: f32,
2163 axis: i64,
2164 ) -> Tensor<B, D2, K> {
2165 check!(TensorCheck::one_hot_tensor_rank::<D, D2>());
2166 // Initialize shape from the current tensor dimensions and prepare for modification
2167 let mut shape = self.shape().dims::<D>().to_vec();
2168 let device = self.device();
2169 let rank = self.dims().len();
2170
2171 // Adjust negative axis to a positive index
2172 let axis = if axis < 0 {
2173 axis + rank as i64 + 1
2174 } else {
2175 axis
2176 };
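// e.g. for a rank-2 input, axis = -1 resolves to 2, appending the one-hot
// dimension as the last axis of the rank-3 output.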
2177
2178 // Ensure axis is within valid range
2179 if axis < 0 || axis > rank as i64 {
2180 panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).");
2181 }
2182 // Convert the input tensor to integer indices
2183 let indices: Tensor<B, D, Int> =
2184 Tensor::from_data(self.to_data().convert::<i64>(), &device);
2185 // Insert the new dimension for the one-hot representation
2186 shape.insert(axis as usize, num_classes);
// Map indices into the valid range: a negative index `i` becomes `num_classes + i`,
// while non-negative indices pass through unchanged.
let adjusted_indices = indices
.clone()
.mask_fill(self.clone().lower_elem(0), num_classes as i64) // `num_classes` where the index is negative
.add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // the original value where non-positive, 0 otherwise
2192 // Unsqueeze the indices tensor along the specified axis
2193 let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);
2194
2195 // Initialize the output tensor with the off_value
2196 let output = Tensor::full(shape.clone(), off_value, &device);
2197
// Prepare the scatter values: `scatter` accumulates, so writing `on_value - off_value`
// on top of the `off_value` background yields `on_value` at the active positions
let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)
- Tensor::full(indices_unsqueezed.shape(), off_value, &device);
2201
2202 // Scatter on_value at the appropriate indices to create the one-hot representation
2203 output.scatter(axis as usize, indices_unsqueezed, scatter_on_values)
2204 }
2205
2206 /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
2207 ///
2208 /// # Returns
2209 ///
2210 /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
2211 ///
2212 /// # Example
2213 ///
2214 /// ```rust
2215 /// use burn_tensor::backend::Backend;
2216 /// use burn_tensor::{Tensor, Bool, Shape};
2217 ///
2218 /// fn example<B: Backend>() {
2219 /// let device = B::Device::default();
2220 /// let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
2221 /// let tensor = tensor.is_nan();
2222 /// println!("{tensor}");
2223 /// // [[false, true, false], [false, false, false]]
2224 /// }
2225 /// ```
2226 pub fn is_nan(&self) -> Tensor<B, D, Bool> {
2227 // Check if the input tensor is NaN by comparing it to itself
2228 // NaN is the only value that is not equal to itself
2229 Tensor::new(K::not_equal(self.primitive.clone(), self.primitive.clone()))
2230 }
2231
2232 /// Checks if the tensor contains any NaN values.
2233 ///
2234 /// # Returns
2235 ///
2236 /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
2237 ///
2238 /// # Example
2239 ///
2240 /// ```rust
2241 /// use burn_tensor::backend::Backend;
2242 /// use burn_tensor::{Tensor, Bool, Shape};
2243 ///
2244 /// fn example<B: Backend>() {
2245 /// let device = B::Device::default();
2246 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
2247 /// let tensor = tensor.contains_nan();
2248 /// println!("{tensor}");
2249 /// // [true]
2250 /// let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2251 /// let tensor = tensor.contains_nan();
2252 /// println!("{tensor}");
2253 /// // [false]
2254 /// }
2255 /// ```
2256 pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
2257 // Summing the tensor will result in NaN if the tensor contains any NaN values
2258 // This is faster than checking each element individually
2259 // because it rolls up the NaN values into a single value
2260 let sum = K::sum(self.primitive.clone());
2261
2262 // Check if the sum is NaN by comparing it to itself
2263 Tensor::new(K::not_equal(sum.clone(), sum))
2264 }
2265}
2266
2267impl<B, K> Tensor<B, 2, K>
2268where
2269 B: Backend,
2270 K: Numeric<B>,
2271 K::Elem: Element,
2272{
2273 /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2274 ///
2275 /// # Arguments
2276 ///
/// * `size` - The size of the square matrix.
/// * `device` - The device on which the tensor will be allocated.
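///
/// # Example
///
/// A minimal sketch (printed values assume a float backend):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let tensor = Tensor::<B, 2>::eye(3, &device);
/// println!("{tensor}");
/// // [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
/// }
/// ```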
2278 pub fn eye(size: usize, device: &B::Device) -> Self {
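// Scatter a [1, size] row of ones into a [size, size] zero matrix along dim 0:
// with indices [[0, 1, ..., size - 1]], column j receives a one at row j, i.e.
// on the diagonal.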
2279 let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2280 let ones = K::ones([1, size].into(), device);
2281 let zeros = K::zeros([size, size].into(), device);
2282
2283 Self::new(K::scatter(0, zeros, indices.primitive, ones))
2284 }
2285}
2286
/// Trait that lists all operations that can be applied to numerical tensors.
///
/// # Warnings
///
/// This is an internal trait; use the public API provided by the [tensor struct](Tensor).
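///
/// # Example
///
/// Rather than calling these functions directly, go through the tensor API, which
/// dispatches here internally (a minimal sketch):
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
/// let device = B::Device::default();
/// let lhs = Tensor::<B, 1>::from_data([1.0, 2.0], &device);
/// let rhs = Tensor::<B, 1>::from_data([3.0, 4.0], &device);
/// let sum = lhs + rhs; // forwards to `Numeric::add` under the hood
/// println!("{sum}");
/// // [4.0, 6.0]
/// }
/// ```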
2292pub trait Numeric<B: Backend>: BasicOps<B>
2293where
2294 Self::Elem: Element,
2295{
2296 /// Adds two tensors together.
2297 ///
2298 /// # Arguments
2299 ///
2300 /// * `lhs` - The left hand side tensor.
2301 /// * `rhs` - The right hand side tensor.
2302 ///
2303 /// # Returns
2304 ///
2305 /// The sum of the two tensors.
2306 ///
2307 /// # Remarks
2308 ///
2309 /// This is a low-level function used internally by the library to call different backend functions
2310 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2311 /// or use this function directly.
2312 ///
2313 /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
2314 /// which is more high-level and designed for public use.
2315 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2316
2317 /// Adds a scalar to a tensor element-wise.
2318 ///
2319 /// # Arguments
2320 ///
2321 /// * `lhs` - The left hand side tensor.
2322 /// * `rhs` - The right hand side scalar.
2323 ///
2324 /// # Returns
2325 ///
2326 /// The sum of the tensor and the scalar.
2327 ///
2328 /// # Remarks
2329 ///
2330 /// This is a low-level function used internally by the library to call different backend functions
2331 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2332 /// or use this function directly.
2333 ///
2334 /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
2335 /// which is more high-level and designed for public use.
2336 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2337
2338 /// Subtracts two tensors.
2339 ///
2340 /// # Arguments
2341 ///
2342 /// * `lhs` - The left hand side tensor.
2343 /// * `rhs` - The right hand side tensor.
2344 ///
2345 /// # Returns
2346 ///
2347 /// The difference of the two tensors.
2348 ///
2349 /// # Remarks
2350 ///
2351 /// This is a low-level function used internally by the library to call different backend functions
2352 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2353 /// or use this function directly.
2354 ///
2355 /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
2356 /// which is more high-level and designed for public use.
2357 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2358
2359 /// Subtracts a scalar from a tensor element-wise.
2360 ///
2361 /// # Arguments
2362 ///
2363 /// * `lhs` - The left hand side tensor.
2364 /// * `rhs` - The right hand side scalar.
2365 ///
2366 /// # Returns
2367 ///
2368 /// The difference of the tensor and the scalar.
2369 ///
2370 /// # Remarks
2371 ///
2372 /// This is a low-level function used internally by the library to call different backend functions
2373 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2374 /// or use this function directly.
2375 ///
2376 /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
2377 /// which is more high-level and designed for public use.
2378 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2379
2380 /// Divides two tensors.
2381 ///
2382 /// # Arguments
2383 ///
2384 /// * `lhs` - The left hand side tensor.
2385 /// * `rhs` - The right hand side tensor.
2386 ///
2387 /// # Returns
2388 ///
2389 /// The quotient of the two tensors.
2390 ///
2391 /// # Remarks
2392 ///
2393 /// This is a low-level function used internally by the library to call different backend functions
2394 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2395 /// or use this function directly.
2396 ///
2397 /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
2398 /// which is more high-level and designed for public use.
2399 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2400
2401 /// Divides a tensor by a scalar element-wise.
2402 ///
2403 /// # Arguments
2404 ///
2405 /// * `lhs` - The left hand side tensor.
2406 /// * `rhs` - The right hand side scalar.
2407 ///
2408 /// # Returns
2409 ///
2410 /// The quotient of the tensor and the scalar.
2411 ///
2412 /// # Remarks
2413 ///
2414 /// This is a low-level function used internally by the library to call different backend functions
2415 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2416 /// or use this function directly.
2417 ///
2418 /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
2419 /// which is more high-level and designed for public use.
2420 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2421
2422 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2423 /// less than that of the divisor.
2424 ///
2425 /// # Arguments
2426 ///
2427 /// * `lhs` - The dividend.
2428 /// * `rhs` - The divisor.
2429 ///
2430 /// # Returns
2431 ///
2432 /// The modulo of the input tensor with the divisor.
2433 ///
2434 /// # Remarks
2435 ///
2436 /// This is a low-level function used internally by the library to call different backend functions
2437 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2438 /// or use this function directly.
2439 ///
2440 /// For performing the modulo operation, users should prefer the [Tensor::remainder](Tensor::remainder) function,
2441 /// which is more high-level and designed for public use.
2442 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2443
2444 /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2445 /// less than that of the divisor.
2446 ///
2447 /// # Arguments
2448 ///
2449 /// * `lhs` - The dividend.
2450 /// * `rhs` - The divisor.
2451 ///
2452 /// # Returns
2453 ///
2454 /// The modulo of the input tensor with the divisor.
2455 ///
2456 /// # Remarks
2457 ///
2458 /// This is a low-level function used internally by the library to call different backend functions
2459 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2460 /// or use this function directly.
2461 ///
2462 /// For performing the modulo operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
2463 /// which is more high-level and designed for public use.
2464 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2465
2466 /// Multiplies two tensors.
2467 ///
2468 /// # Arguments
2469 ///
2470 /// * `lhs` - The left hand side tensor.
2471 /// * `rhs` - The right hand side tensor.
2472 ///
2473 /// # Returns
2474 ///
2475 /// The product of the two tensors.
2476 ///
2477 /// # Remarks
2478 ///
2479 /// This is a low-level function used internally by the library to call different backend functions
2480 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2481 /// or use this function directly.
2482 ///
2483 /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
2484 /// which is more high-level and designed for public use.
2485 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2486
2487 /// Multiplies a tensor by a scalar element-wise.
2488 ///
2489 /// # Arguments
2490 ///
2491 /// * `lhs` - The left hand side tensor.
2492 /// * `rhs` - The right hand side scalar.
2493 ///
2494 /// # Returns
2495 ///
2496 /// The product of the tensor and the scalar.
2497 ///
2498 /// # Remarks
2499 ///
2500 /// This is a low-level function used internally by the library to call different backend functions
2501 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2502 /// or use this function directly.
2503 ///
2504 /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
2505 /// which is more high-level and designed for public use.
2506 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2507
2508 /// Negates a tensor.
2509 ///
2510 /// # Arguments
2511 ///
2512 /// * `tensor` - The tensor to negate.
2513 ///
2514 /// # Returns
2515 ///
2516 /// The negated tensor.
2517 ///
2518 /// # Remarks
2519 ///
2520 /// This is a low-level function used internally by the library to call different backend functions
2521 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2522 /// or use this function directly.
2523 ///
2524 /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
2525 /// which is more high-level and designed for public use.
2526 fn neg(tensor: Self::Primitive) -> Self::Primitive;
2527
2528 /// Returns the signs of the elements of a tensor.
2529 ///
2530 /// # Arguments
2531 ///
2532 /// * `tensor` - The tensor.
2533 ///
2534 /// # Returns
2535 ///
2536 /// The signs of the elements of the tensor.
2537 ///
2538 /// # Remarks
2539 ///
2540 /// This is a low-level function used internally by the library to call different backend functions
2541 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2542 /// or use this function directly.
2543 ///
2544 /// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
2545 /// which is more high-level and designed for public use.
2546 fn sign(tensor: Self::Primitive) -> Self::Primitive;
2547
2548 /// Creates a tensor filled with zeros.
2549 ///
2550 /// # Arguments
2551 ///
2552 /// * `shape` - The shape of the tensor.
2553 /// * `device` - The device on which the tensor will be allocated.
2554 ///
2555 /// # Returns
2556 ///
2557 /// The tensor filled with zeros.
2558 ///
2559 /// # Remarks
2560 ///
2561 /// This is a low-level function used internally by the library to call different backend functions
2562 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2563 /// or use this function directly.
2564 ///
2565 /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
2566 /// which is more high-level and designed for public use.
2567 fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive;
2568
2569 /// Creates a tensor filled with ones.
2570 ///
2571 /// # Arguments
2572 ///
2573 /// * `shape` - The shape of the tensor.
2574 /// * `device` - The device on which the tensor will be allocated.
2575 ///
2576 /// # Returns
2577 ///
2578 /// The tensor filled with ones.
2579 ///
2580 /// # Remarks
2581 ///
2582 /// This is a low-level function used internally by the library to call different backend functions
2583 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2584 /// or use this function directly.
2585 ///
2586 /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
2587 /// which is more high-level and designed for public use.
2588 fn ones(shape: Shape, device: &B::Device) -> Self::Primitive;
2589
2590 /// Creates a tensor filled with elements equal to the given value.
2591 ///
2592 /// # Arguments
2593 ///
2594 /// * `shape` - The shape of the tensor.
/// * `fill_value` - The value with which to fill the tensor.
2596 /// * `device` - The device on which the tensor will be allocated.
2597 ///
2598 /// # Returns
2599 ///
/// The tensor filled with elements equal to the given value.
2601 ///
2602 /// # Remarks
2603 ///
2604 /// This is a low-level function used internally by the library to call different backend functions
2605 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2606 /// or use this function directly.
2607 ///
2608 /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function,
2609 /// which is more high-level and designed for public use.
2610 fn full<E: ElementConversion>(
2611 shape: Shape,
2612 fill_value: E,
2613 device: &B::Device,
2614 ) -> Self::Primitive;
2615
2616 /// Sums all the elements of the tensor.
2617 ///
2618 /// # Arguments
2619 ///
2620 /// * `tensor` - The tensor to sum.
2621 ///
2622 /// # Returns
2623 ///
2624 /// The sum of all the elements of the tensor.
2625 ///
2626 /// # Remarks
2627 ///
2628 /// This is a low-level function used internally by the library to call different backend functions
2629 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2630 /// or use this function directly.
2631 ///
2632 /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
2633 /// which is more high-level and designed for public use.
2634 fn sum(tensor: Self::Primitive) -> Self::Primitive;
2635
2636 /// Sums all the elements of the tensor along a dimension.
2637 ///
2638 /// # Arguments
2639 ///
2640 /// * `tensor` - The tensor to sum.
2641 /// * `dim` - The dimension along which to sum.
2642 ///
2643 /// # Returns
2644 ///
2645 /// The sum of all the elements of the tensor along the specified dimension.
2646 ///
2647 /// # Remarks
2648 ///
2649 /// This is a low-level function used internally by the library to call different backend functions
2650 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2651 /// or use this function directly.
2652 ///
2653 /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
2654 /// which is more high-level and designed for public use.
2655 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2656
2657 /// Computes the product of all the elements of the tensor.
2658 ///
2659 /// # Arguments
2660 ///
2661 /// * `tensor` - The tensor to compute the product of.
2662 ///
2663 /// # Returns
2664 ///
2665 /// The product of all the elements of the tensor.
2666 ///
2667 /// # Remarks
2668 ///
2669 /// This is a low-level function used internally by the library to call different backend functions
2670 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2671 /// or use this function directly.
2672 ///
2673 /// For computing the product of all the elements of a tensor, users should prefer the
2674 /// [Tensor::prod](Tensor::prod) function,
2675 /// which is more high-level and designed for public use.
2676 fn prod(tensor: Self::Primitive) -> Self::Primitive;
2677
2678 /// Computes the product of all the elements of the tensor along a dimension.
2679 ///
2680 /// # Arguments
2681 ///
2682 /// * `tensor` - The tensor to compute the product of.
2683 /// * `dim` - The dimension along which to compute the product.
2684 ///
2685 /// # Returns
2686 ///
2687 /// The product of all the elements of the tensor along the specified dimension.
2688 ///
2689 /// # Remarks
2690 ///
2691 /// This is a low-level function used internally by the library to call different backend functions
2692 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2693 /// or use this function directly.
2694 ///
2695 /// For computing the product of all the elements of a tensor along a dimension, users should
2696 /// prefer the [Tensor::prod_dim](Tensor::prod_dim) function,
2697 /// which is more high-level and designed for public use.
2700 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2701
2702 /// Computes the mean of all the elements of the tensor.
2703 ///
2704 /// # Arguments
2705 ///
2706 /// * `tensor` - The tensor to compute the mean of.
2707 ///
2708 /// # Returns
2709 ///
2710 /// The mean of all the elements of the tensor.
2711 ///
2712 /// # Remarks
2713 ///
2714 /// This is a low-level function used internally by the library to call different backend functions
2715 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2716 /// or use this function directly.
2717 ///
2718 /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
2719 /// which is more high-level and designed for public use.
2720 fn mean(tensor: Self::Primitive) -> Self::Primitive;
2721
2722 /// Computes the mean of all the elements of the tensor along a dimension.
2723 ///
2724 /// # Arguments
2725 ///
2726 /// * `tensor` - The tensor to compute the mean of.
2727 /// * `dim` - The dimension along which to compute the mean.
2728 ///
2729 /// # Returns
2730 ///
2731 /// The mean of all the elements of the tensor along the specified dimension.
2732 ///
2733 /// # Remarks
2734 ///
2735 /// This is a low-level function used internally by the library to call different backend functions
2736 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2737 /// or use this function directly.
2738 ///
2739 /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
2740 /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
2741 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2742
/// Element-wise equality comparison between a tensor and a scalar.
2744 ///
2745 /// # Arguments
2746 ///
2747 /// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
2749 ///
2750 /// # Returns
2751 ///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the tensor is equal to the scalar, and false otherwise.
2754 ///
2755 /// # Remarks
2756 ///
2757 /// This is a low-level function used internally by the library to call different backend functions
2758 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2759 /// or use this function directly.
2760 ///
/// For element-wise equality comparison between a tensor and a scalar, users should prefer the [Tensor::equal_elem](Tensor::equal_elem)
2762 /// function, which is more high-level and designed for public use.
2763 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2764
/// Element-wise inequality comparison between a tensor and a scalar.
2766 ///
2767 /// # Arguments
2768 ///
2769 /// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
2771 ///
2772 /// # Returns
2773 ///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the tensor is not equal to the scalar, and false otherwise.
2776 ///
2777 /// # Remarks
2778 ///
2779 /// This is a low-level function used internally by the library to call different backend functions
2780 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2781 /// or use this function directly.
2782 ///
/// For element-wise inequality comparison between a tensor and a scalar, users should prefer the [Tensor::not_equal_elem](Tensor::not_equal_elem)
2784 /// function, which is more high-level and designed for public use.
2785 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2786
2787 /// Element-wise greater than comparison between two tensors.
2788 ///
2789 /// # Arguments
2790 ///
2791 /// * `lhs` - The left hand side tensor.
2792 /// * `rhs` - The right hand side tensor.
2793 ///
2794 /// # Returns
2795 ///
2796 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2797 /// corresponding element of the left hand side tensor is greater than the corresponding element
2798 /// of the right hand side tensor, and false otherwise.
2799 ///
2800 /// # Remarks
2801 ///
2802 /// This is a low-level function used internally by the library to call different backend functions
2803 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2804 /// or use this function directly.
2805 ///
2806 /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
2807 /// which is more high-level and designed for public use.
2808 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2809
2810 /// Element-wise greater than comparison between a tensor and a scalar.
2811 ///
2812 /// # Arguments
2813 ///
2814 /// * `lhs` - The left hand side tensor.
2815 /// * `rhs` - The right hand side scalar.
2816 ///
2817 /// # Returns
2818 ///
2819 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2820 /// corresponding element of the left hand side tensor is greater than the right hand side
2821 /// scalar, and false otherwise.
2822 ///
2823 /// # Remarks
2824 ///
2825 /// This is a low-level function used internally by the library to call different backend functions
2826 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2827 /// or use this function directly.
2828 ///
2829 /// For element-wise greater than comparison between a tensor and a scalar, users should prefer
2830 /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
2831 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2832
2833 /// Element-wise greater than or equal comparison between two tensors.
2834 ///
2835 /// # Arguments
2836 ///
2837 /// * `lhs` - The left hand side tensor.
2838 /// * `rhs` - The right hand side tensor.
2839 ///
2840 /// # Returns
2841 ///
2842 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2843 /// corresponding element of the left hand side tensor is greater than or equal to the
2844 /// corresponding element of the right hand side tensor, and false otherwise.
2845 ///
2846 /// # Remarks
2847 ///
2848 /// This is a low-level function used internally by the library to call different backend functions
2849 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2850 /// or use this function directly.
2851 ///
2852 /// For element-wise greater than or equal comparison between two tensors, users should prefer
2853 /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
2854 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2855
2856 /// Element-wise greater than or equal comparison between a tensor and a scalar.
2857 ///
2858 /// # Arguments
2859 ///
2860 /// * `lhs` - The left hand side tensor.
2861 /// * `rhs` - The right hand side scalar.
2862 ///
2863 /// # Returns
2864 ///
2865 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2866 /// corresponding element of the left hand side tensor is greater than or equal to the right
2867 /// hand side scalar, and false otherwise.
2868 ///
2869 /// # Remarks
2870 ///
2871 /// This is a low-level function used internally by the library to call different backend functions
2872 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2873 /// or use this function directly.
2874 ///
2875 /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
2876 /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
2877 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2878
2879 /// Element-wise less than comparison between two tensors.
2880 ///
2881 /// # Arguments
2882 ///
2883 /// * `lhs` - The left hand side tensor.
2884 /// * `rhs` - The right hand side tensor.
2885 ///
2886 /// # Returns
2887 ///
2888 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2889 /// corresponding element of the left hand side tensor is less than the corresponding element of
2890 /// the right hand side tensor, and false otherwise.
2891 ///
2892 /// # Remarks
2893 ///
2894 /// This is a low-level function used internally by the library to call different backend functions
2895 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2896 /// or use this function directly.
2897 ///
2898 /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
2899 /// which is more high-level and designed for public use.
2900 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2901
2902 /// Element-wise less than comparison between a tensor and a scalar.
2903 ///
2904 /// # Arguments
2905 ///
2906 /// * `lhs` - The left hand side tensor.
2907 /// * `rhs` - The right hand side scalar.
2908 ///
2909 /// # Returns
2910 ///
2911 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2912 /// corresponding element of the left hand side tensor is less than the right hand side scalar,
2913 /// and false otherwise.
2914 ///
2915 /// # Remarks
2916 ///
2917 /// This is a low-level function used internally by the library to call different backend functions
2918 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2919 /// or use this function directly.
2920 ///
2921 /// For element-wise less than comparison between a tensor and a scalar, users should prefer
2922 /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
2923 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2924
2925 /// Element-wise less than or equal comparison between two tensors.
2926 ///
2927 /// # Arguments
2928 ///
2929 /// * `lhs` - The left hand side tensor.
2930 /// * `rhs` - The right hand side tensor.
2931 ///
2932 /// # Returns
2933 ///
2934 /// A boolean tensor with the same shape as the input tensors, where each element is true if the
2935 /// corresponding element of the left hand side tensor is less than or equal to the corresponding
2936 /// element of the right hand side tensor, and false otherwise.
2937 ///
2938 /// # Remarks
2939 ///
2940 /// This is a low-level function used internally by the library to call different backend functions
2941 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2942 /// or use this function directly.
2943 ///
2944 /// For element-wise less than or equal comparison between two tensors, users should prefer
2945 /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use.
2946 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2947
2948 /// Element-wise less than or equal comparison between a tensor and a scalar.
2949 ///
2950 /// # Arguments
2951 ///
2952 /// * `lhs` - The left hand side tensor.
2953 /// * `rhs` - The right hand side scalar.
2954 ///
2955 /// # Returns
2956 ///
2957 /// A boolean tensor with the same shape as the input tensor, where each element is true if the
2958 /// corresponding element of the left hand side tensor is less than or equal to the right hand
2959 /// side scalar, and false otherwise.
2960 ///
2961 /// # Remarks
2962 ///
2963 /// This is a low-level function used internally by the library to call different backend functions
2964 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2965 /// or use this function directly.
2966 ///
2967 /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
2968 /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
2969 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2970
2971 /// Selects elements from a tensor based on a boolean mask.
2972 ///
2973 /// # Arguments
2974 ///
/// * `tensor` - The tensor to select elements from where the corresponding element of the mask is false.
/// * `mask` - The boolean mask to use for selecting elements.
/// * `source` - The tensor to select elements from where the corresponding element of the mask is true.
2978 ///
2979 /// # Returns
2980 ///
/// A tensor with the same shape as the input tensors, where each element is taken from the
/// corresponding element of the source tensor if the corresponding element of the mask
/// is true, and from the corresponding element of the input tensor otherwise.
2984 ///
2985 /// # Remarks
2986 ///
2987 /// This is a low-level function used internally by the library to call different backend functions
2988 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2989 /// or use this function directly.
2990 ///
2991 /// For selecting elements from a tensor based on a boolean mask, users should prefer the
2992 /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
2993 fn mask_where(
2994 tensor: Self::Primitive,
2995 mask: B::BoolTensorPrimitive,
2996 source: Self::Primitive,
2997 ) -> Self::Primitive;
2998
2999 /// Fills elements of a tensor based on a boolean mask.
3000 ///
3001 /// # Arguments
3002 ///
/// * `tensor` - The tensor whose elements will be overwritten with the value
/// when the corresponding element of the mask is true.
3005 /// * `mask` - The boolean mask to use for filling elements.
3006 /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
3007 ///
3008 /// # Returns
3009 ///
3010 /// A tensor with the same shape as the input tensors, where each element is taken from the
3011 /// corresponding element unmodified if the corresponding element of the mask is false, and
3012 /// filled with the value otherwise.
3013 ///
3014 /// # Remarks
3015 ///
3016 /// This is a low-level function used internally by the library to call different backend functions
3017 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3018 /// or use this function directly.
3019 ///
3020 /// For filling elements of a tensor based on a boolean mask, users should prefer the
3021 /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
3022 fn mask_fill(
3023 tensor: Self::Primitive,
3024 mask: B::BoolTensorPrimitive,
3025 value: Self::Elem,
3026 ) -> Self::Primitive;
3027
3028 /// Gathers elements from a tensor along an axis.
3029 ///
3030 /// # Arguments
3031 ///
3032 /// * `dim` - The axis along which to gather elements.
3033 /// * `tensor` - The tensor to gather elements from.
3034 /// * `indices` - The indices of the elements to gather.
3035 ///
3036 /// # Returns
3037 ///
/// A tensor with the same shape as the indices tensor, where each element is taken from the
/// input tensor at the index given by the corresponding element of `indices` along the specified axis.
3040 ///
3041 /// # Remarks
3042 ///
3043 /// This is a low-level function used internally by the library to call different backend functions
3044 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3045 /// or use this function directly.
3046 ///
3047 /// For gathering elements from a tensor along an axis, users should prefer the
3048 /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use.
3049 fn gather(
3050 dim: usize,
3051 tensor: Self::Primitive,
3052 indices: B::IntTensorPrimitive,
3053 ) -> Self::Primitive;
3054
3055 /// Scatters elements into a tensor along an axis.
3056 ///
3057 /// # Arguments
3058 ///
3059 /// * `dim` - The axis along which to scatter elements.
3060 /// * `tensor` - The tensor to scatter elements into.
3061 /// * `indices` - The indices of the elements to scatter.
3062 /// * `values` - The values to scatter into the tensor.
3063 ///
3064 /// # Returns
3065 ///
/// A tensor with the same shape as the input tensor, where the elements of the values tensor
/// are accumulated into the input tensor at the specified indices along the given axis; all
/// other elements are left unchanged.
3070 ///
3071 /// # Remarks
3072 ///
3073 /// This is a low-level function used internally by the library to call different backend functions
3074 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3075 /// or use this function directly.
3076 ///
3077 /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function,
3078 /// which is more high-level and designed for public use.
3079 fn scatter(
3080 dim: usize,
3081 tensor: Self::Primitive,
3082 indices: B::IntTensorPrimitive,
3083 values: Self::Primitive,
3084 ) -> Self::Primitive;
3085
/// Select tensor elements along the given dimension corresponding to the given indices.
3087 ///
3088 /// # Arguments
3089 ///
3090 /// * `tensor` - The tensor to select elements from.
3091 /// * `dim` - The axis along which to select elements.
3092 /// * `indices` - The indices of the elements to select.
3093 ///
3094 /// # Returns
3095 ///
/// A tensor with the same shape as the input tensor, except along the given dimension, whose
/// size becomes the number of indices; the elements are taken from the input tensor at those indices.
3098 ///
3099 /// # Remarks
3100 ///
3101 /// This is a low-level function used internally by the library to call different backend functions
3102 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3103 /// or use this function directly.
3104 ///
3105 /// For selecting elements from a tensor along an axis, users should prefer the
3106 /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use.
3107 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive;
3108
3109 /// Assign the selected elements along the given dimension corresponding to the given indices
3110 /// from the value tensor.
3111 ///
3112 /// # Arguments
3113 ///
3114 /// * `tensor` - The tensor to assign elements to.
3115 /// * `dim` - The axis along which to assign elements.
3116 /// * `indices` - The indices of the elements to assign.
3117 /// * `values` - The values to assign to the tensor.
3118 ///
3119 /// # Returns
3120 ///
3121 /// A tensor with the same shape as the input tensor, where each element is taken from the
3122 /// corresponding element of the input tensor at the corresponding index along the specified axis,
3123 /// except for the elements at the specified indices, which are taken from the corresponding
3124 /// element of the values tensor.
3125 ///
3126 /// # Remarks
3127 ///
3128 /// This is a low-level function used internally by the library to call different backend functions
3129 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3130 /// or use this function directly.
3131 ///
3132 /// For assigning elements to a tensor along an axis, users should prefer the
3133 /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use.
3134 fn select_assign(
3135 tensor: Self::Primitive,
3136 dim: usize,
3137 indices: Tensor<B, 1, Int>,
3138 values: Self::Primitive,
3139 ) -> Self::Primitive;
3140
3141 /// Gets the indices of the maximum elements of a tensor along an axis.
3142 ///
3143 /// # Arguments
3144 ///
3145 /// * `dim` - The axis along which to get the indices of the maximum elements.
3146 /// * `tensor` - The tensor to get the indices of the maximum elements from.
3147 ///
3148 /// # Returns
3149 ///
/// A tensor with the same rank as the input tensor, with the given dimension reduced to size 1,
/// where each element is the index of the maximum element along that dimension.
3152 ///
3153 /// # Remarks
3154 ///
3155 /// This is a low-level function used internally by the library to call different backend functions
3156 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3157 /// or use this function directly.
3158 ///
3159 /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
3160 /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
3161 fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3162
3163 /// Gets the indices of the minimum elements of a tensor along an axis.
3164 ///
3165 /// # Arguments
3166 ///
3167 /// * `dim` - The axis along which to get the indices of the minimum elements.
3168 /// * `tensor` - The tensor to get the indices of the minimum elements from.
3169 ///
3170 /// # Returns
3171 ///
    /// A tensor with the same rank as the input tensor, with the given dim set to a size of 1.
    /// Each element is the index of the minimum element of the input tensor along that axis.
3174 ///
3175 /// # Remarks
3176 ///
3177 /// This is a low-level function used internally by the library to call different backend functions
3178 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3179 /// or use this function directly.
3180 ///
3181 /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
3182 /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
3183 fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3184
    /// Gets the maximum element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum element from.
3190 ///
3191 /// # Returns
3192 ///
3193 /// A single-element tensor containing the maximum element of the input tensor.
3194 ///
3195 /// # Remarks
3196 ///
3197 /// This is a low-level function used internally by the library to call different backend functions
3198 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3199 /// or use this function directly.
3200 ///
    /// For getting the maximum element of a tensor, users should prefer the
3202 /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
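    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     // The overall maximum, as a single-element tensor of shape [1].
    ///     let max = tensor.max();
    ///     println!("{max}");
    ///     // [9.0]
    /// }
    /// ```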
3203 fn max(tensor: Self::Primitive) -> Self::Primitive;
3204
3205 /// Gets the maximum elements of a tensor along an axis.
3206 ///
3207 /// # Arguments
3208 ///
3209 /// * `tensor` - The tensor to get the maximum elements from.
3210 /// * `dim` - The axis along which to get the maximum elements.
3211 ///
3212 /// # Returns
3213 ///
    /// A tensor with the same rank as the input tensor, with the given dim set to a size of 1.
    /// Each element is the maximum of the input values along that axis.
3216 ///
3217 /// # Remarks
3218 ///
3219 /// This is a low-level function used internally by the library to call different backend functions
3220 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3221 /// or use this function directly.
3222 ///
3223 /// For getting the maximum elements of a tensor along an axis, users should prefer the
3224 /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
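    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     // Maxima along dim 1; the dim is kept with a size of 1, so the shape is [2, 1].
    ///     let max = tensor.max_dim(1);
    ///     println!("{max}");
    ///     // [[3.0], [9.0]]
    /// }
    /// ```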
3225 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3226
3227 /// Gets the maximum elements of a tensor along an axis.
3228 ///
3229 /// # Arguments
3230 ///
3231 /// * `tensor` - The tensor to get the maximum elements from.
3232 /// * `dim` - The axis along which to get the maximum elements.
3233 ///
3234 /// # Returns
3235 ///
    /// A tuple of two tensors, each with the same rank as the input tensor and the given dim set
    /// to a size of 1: the first contains the maximum elements along the axis, and the second
    /// contains the indices at which they occur.
3239 ///
3240 /// # Remarks
3241 ///
3242 /// This is a low-level function used internally by the library to call different backend functions
3243 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3244 /// or use this function directly.
3245 ///
3246 /// For getting the maximum elements of a tensor along an axis, users should prefer the
3247 /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
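    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     let (values, indices) = tensor.max_dim_with_indices(1);
    ///     // values:  [[3.0], [9.0]]
    ///     // indices: [[2], [1]]
    /// }
    /// ```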
3248 fn max_dim_with_indices(
3249 tensor: Self::Primitive,
3250 dim: usize,
3251 ) -> (Self::Primitive, B::IntTensorPrimitive);
3252
    /// Gets the maximum absolute element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum absolute element from.
3258 ///
3259 /// # Returns
3260 ///
3261 /// A single-element tensor containing the maximum absolute element of the input tensor.
3262 ///
3263 /// # Remarks
3264 ///
3265 /// This is a low-level function used internally by the library to call different backend functions
3266 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3267 /// or use this function directly.
3268 ///
    /// For getting the maximum absolute element of a tensor, users should prefer the
3270 /// [Tensor::max_abs](Tensor::max_abs) function, which is more high-level and designed for public use.
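    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, 2.0, 6.0]], &device);
    ///     // The largest absolute value, as a single-element tensor of shape [1].
    ///     let max_abs = tensor.max_abs();
    ///     println!("{max_abs}");
    ///     // [7.0]
    /// }
    /// ```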
3271 fn max_abs(tensor: Self::Primitive) -> Self::Primitive;
3272
    /// Gets the maximum absolute elements of a tensor along an axis.
3274 ///
3275 /// # Arguments
3276 ///
3277 /// * `tensor` - The tensor to get the maximum elements from.
3278 /// * `dim` - The axis along which to get the maximum elements.
3279 ///
3280 /// # Returns
3281 ///
    /// A tensor with the same rank as the input tensor, with the given dim set to a size of 1.
    /// Each element is the maximum absolute value of the input values along that axis.
3284 ///
3285 /// # Remarks
3286 ///
3287 /// This is a low-level function used internally by the library to call different backend functions
3288 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3289 /// or use this function directly.
3290 ///
    /// For getting the maximum absolute elements of a tensor along an axis, users should prefer the
3292 /// [Tensor::max_abs_dim](Tensor::max_abs_dim) function, which is more high-level and designed for public use.
3293 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3294
    /// Gets the minimum element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum element from.
3300 ///
3301 /// # Returns
3302 ///
3303 /// A single-element tensor containing the minimum element of the input tensor.
3304 ///
3305 /// # Remarks
3306 ///
3307 /// This is a low-level function used internally by the library to call different backend functions
3308 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3309 /// or use this function directly.
3310 ///
    /// For getting the minimum element of a tensor, users should prefer the
3312 /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
3313 fn min(tensor: Self::Primitive) -> Self::Primitive;
3314
3315 /// Gets the minimum elements of a tensor along an axis.
3316 ///
3317 /// # Arguments
3318 ///
3319 /// * `tensor` - The tensor to get the minimum elements from.
3320 /// * `dim` - The axis along which to get the minimum elements.
3321 ///
3322 /// # Returns
3323 ///
    /// A tensor with the same rank as the input tensor, with the given dim set to a size of 1.
    /// Each element is the minimum of the input values along that axis.
3326 ///
3327 /// # Remarks
3328 ///
3329 /// This is a low-level function used internally by the library to call different backend functions
3330 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3331 /// or use this function directly.
3332 ///
3333 /// For getting the minimum elements of a tensor along an axis, users should prefer the
3334 /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use.
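    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     // Minima along dim 1; the dim is kept with a size of 1, so the shape is [2, 1].
    ///     let min = tensor.min_dim(1);
    ///     println!("{min}");
    ///     // [[-2.0], [5.0]]
    /// }
    /// ```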
3335 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3336
3337 /// Gets the minimum elements and indices of a tensor along an axis.
3338 ///
3339 /// # Arguments
3340 ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
3342 ///
3343 /// # Returns
3344 ///
    /// A tuple of two tensors, each with the same rank as the input tensor and the given dim set
    /// to a size of 1: the first contains the minimum elements along the axis, and the second
    /// contains the indices at which they occur.
3348 ///
3349 /// # Remarks
3350 ///
3351 /// This is a low-level function used internally by the library to call different backend functions
3352 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3353 /// or use this function directly.
3354 ///
3355 /// For getting the minimum elements of a tensor along an axis, users should prefer the
3356 /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use.
3357 fn min_dim_with_indices(
3358 tensor: Self::Primitive,
3359 dim: usize,
3360 ) -> (Self::Primitive, B::IntTensorPrimitive);
3361
    /// Clamps the tensor between the given min and max values.
3363 ///
3364 /// # Arguments
3365 ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    /// * `max` - The maximum value.
3368 ///
3369 /// # Returns
3370 ///
3371 /// A new tensor with the values clamped between the given min and max values.
3372 ///
3373 /// # Remarks
3374 ///
3375 /// This is a low-level function used internally by the library to call different backend functions
3376 /// with static dispatch. It is not designed for direct usage by users.
3377 ///
3378 /// For clamping a tensor between the given min and max values, users should prefer the
3379 /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use.
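    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///     // Limit every value to the range [0.0, 5.0].
    ///     let clamped = tensor.clamp(0.0, 5.0);
    ///     println!("{clamped}");
    ///     // [[1.0, 0.0, 3.0], [5.0, 5.0, 5.0]]
    /// }
    /// ```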
3380 fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
3381
    /// Clamps a tensor to a minimum value: every element smaller than `min` is set to `min`.
3383 ///
3384 /// # Arguments
3385 ///
3386 /// * `tensor` - The tensor to clamp.
3387 /// * `min` - The minimum value.
3388 ///
3389 /// # Returns
3390 ///
    /// A new tensor with no value smaller than the given min value.
3392 ///
3393 /// # Remarks
3394 ///
3395 /// This is a low-level function used internally by the library to call different backend functions
3396 /// with static dispatch. It is not designed for direct usage by users.
3397 ///
    /// For clamping a tensor to a minimum value, users should prefer the
3399 /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use.
3400 fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
3401
    /// Clamps a tensor to a maximum value: every element greater than `max` is set to `max`.
3403 ///
3404 /// # Arguments
3405 ///
3406 /// * `tensor` - The tensor to clamp.
3407 /// * `max` - The maximum value.
3408 ///
3409 /// # Returns
3410 ///
    /// A new tensor with no value greater than the given max value.
3412 ///
3413 /// # Remarks
3414 ///
3415 /// This is a low-level function used internally by the library to call different backend functions
3416 /// with static dispatch. It is not designed for direct usage by users.
3417 ///
    /// For clamping a tensor to a maximum value, users should prefer the
3419 /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use.
3420 fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
3421
    /// Calculates the absolute value of each element of a tensor.
3423 ///
3424 /// # Arguments
3425 ///
3426 /// * `tensor` - The tensor to apply abs to.
3427 ///
3428 /// # Returns
3429 ///
    /// A tensor containing the element-wise absolute values.
3431 ///
3432 /// # Remarks
3433 ///
3434 /// This is a low-level function used internally by the library to call different backend functions
3435 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3436 /// or use this function directly.
3437 ///
3438 /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function,
3439 /// which is more high-level and designed for public use.
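    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [-5.0, 9.0, -6.0]], &device);
    ///     let abs = tensor.abs();
    ///     println!("{abs}");
    ///     // [[1.0, 2.0, 3.0], [5.0, 9.0, 6.0]]
    /// }
    /// ```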
3440 fn abs(tensor: Self::Primitive) -> Self::Primitive;
3441
    /// Element-wise power of a tensor raised to a float exponent tensor.
    ///
    /// # Arguments
    /// * `lhs` - The base tensor.
    /// * `rhs` - The exponent tensor.
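    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, with hypothetical data:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let base = Tensor::<B, 2>::from_data([[2.0, 3.0], [4.0, 9.0]], &device);
    ///     let exp = Tensor::<B, 2>::from_data([[2.0, 2.0], [1.0, 0.5]], &device);
    ///     let powered = base.powf(exp);
    ///     // [[4.0, 9.0], [4.0, 3.0]]
    /// }
    /// ```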
3447 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3448
    /// Element-wise power of a tensor raised to an integer exponent tensor.
    ///
    /// # Arguments
    /// * `lhs` - The base tensor.
    /// * `rhs` - The exponent tensor.
3454 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3455
    /// Element-wise power of a tensor raised to a scalar float exponent.
    ///
    /// # Arguments
    /// * `lhs` - The base tensor.
    /// * `rhs` - The scalar exponent.
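    ///
    /// # Example
    ///
    /// For instance, squaring every element through the high-level API (a sketch):
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0], [3.0, 4.0]], &device);
    ///     let squared = tensor.powf_scalar(2.0);
    ///     // [[1.0, 4.0], [9.0, 16.0]]
    /// }
    /// ```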
3461 fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3462
    /// Element-wise power of a tensor raised to a scalar integer exponent.
    ///
    /// # Arguments
    /// * `lhs` - The base tensor.
    /// * `rhs` - The scalar exponent.
3468 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3469
3470 /// Create a random tensor.
3471 ///
3472 /// # Arguments
3473 ///
3474 /// * `shape` - The shape of the output tensor.
3475 /// * `distribution` - The distribution used to sample.
3476 /// * `device` - The device to use.
3477 ///
3478 /// # Returns
3479 ///
    /// A new tensor of the given shape, with values sampled from the given distribution.
3481 ///
3482 /// # Remarks
3483 ///
3484 /// This is a low-level function used internally by the library to call different backend functions
3485 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3486 /// or use this function directly.
3487 ///
3488 /// Users should prefer the [Tensor::random](Tensor::random) function,
3489 /// which is more high-level and designed for public use.
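    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Distribution, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     // A 2x3 tensor sampled from the backend's default distribution.
    ///     let tensor = Tensor::<B, 2>::random([2, 3], Distribution::Default, &device);
    ///     println!("{tensor}");
    /// }
    /// ```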
3490 fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
3491
3492 /// Sort the elements of the input `tensor` by value along a given dimension.
3493 ///
3494 /// This sort is unstable (i.e., may reorder equal elements).
3495 ///
3496 /// # Arguments
3497 ///
3498 /// * `tensor` - The input tensor.
3499 /// * `dim` - The axis along which to sort.
3500 /// * `descending` - The sorting order.
3501 ///
3502 /// # Returns
3503 ///
3504 /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
3505 ///
3506 /// # Remarks
3507 /// This is a low-level function used internally by the library to call different backend functions
3508 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3509 /// or use this function directly.
3510 ///
3511 /// Users should prefer the [Tensor::sort](Tensor::sort) function,
3512 /// which is more high-level and designed for public use.
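    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, which sorts in ascending order:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]], &device);
    ///     let sorted = tensor.sort(1);
    ///     // [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    /// }
    /// ```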
3513 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
3514
3515 /// Sort the elements of the input `tensor` by value along a given dimension.
3516 ///
3517 /// This sort is unstable (i.e., may reorder equal elements).
3518 ///
3519 /// # Arguments
3520 ///
3521 /// * `tensor` - The input tensor.
3522 /// * `dim` - The axis along which to sort.
3523 /// * `descending` - The sorting order.
3524 ///
3525 /// # Returns
3526 ///
    /// A tuple containing a tensor with the same shape as the input, where the elements are
    /// sorted by value, and an index tensor whose entries map each sorted element back to its
    /// position in the original input tensor.
3529 ///
3530 /// # Remarks
3531 /// This is a low-level function used internally by the library to call different backend functions
3532 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3533 /// or use this function directly.
3534 ///
3535 /// For sorting the elements of a tensor, users should prefer the
3536 /// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level
3537 /// and designed for public use.
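    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, which sorts in ascending order:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 1>::from_data([3.0, 1.0, 2.0], &device);
    ///     let (sorted, indices) = tensor.sort_with_indices(0);
    ///     // sorted:  [1.0, 2.0, 3.0]
    ///     // indices: [1, 2, 0]
    /// }
    /// ```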
3538 fn sort_with_indices(
3539 tensor: Self::Primitive,
3540 dim: usize,
3541 descending: bool,
3542 ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive);
3543
3544 /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
3545 ///
3546 /// This sort is unstable (i.e., may reorder equal elements).
3547 ///
3548 /// # Arguments
3549 ///
3550 /// * `tensor` - The input tensor.
3551 /// * `dim` - The axis along which to sort.
3552 /// * `descending` - The sorting order.
3553 ///
3554 /// # Returns
3555 ///
    /// A tensor with the same shape as the input tensor, containing the indices that sort the
    /// input along the given dimension; the indices map back to the original input tensor.
3557 ///
3558 /// # Remarks
3559 /// This is a low-level function used internally by the library to call different backend functions
3560 /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
3561 /// or use this function directly.
3562 ///
3563 /// Users should prefer the [Tensor::argsort](Tensor::argsort) function,
3564 /// which is more high-level and designed for public use.
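    ///
    /// # Example
    ///
    /// A minimal sketch of the high-level counterpart, which sorts in ascending order:
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::Tensor;
    ///
    /// fn example<B: Backend>() {
    ///     let device = B::Device::default();
    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]], &device);
    ///     // Indices that would sort each row in ascending order.
    ///     let indices = tensor.argsort(1);
    ///     // [[1, 2, 0], [1, 2, 0]]
    /// }
    /// ```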
3565 fn argsort(
3566 tensor: Self::Primitive,
3567 dim: usize,
3568 descending: bool,
3569 ) -> <Int as TensorKind<B>>::Primitive;
3570}
3571
3572impl<B: Backend> Numeric<B> for Int {
3573 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3574 B::int_add(lhs, rhs)
3575 }
3576 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3577 B::int_add_scalar(lhs, rhs.elem())
3578 }
3579 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3580 B::int_sub(lhs, rhs)
3581 }
3582 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3583 B::int_sub_scalar(lhs, rhs.elem())
3584 }
3585 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3586 B::int_div(lhs, rhs)
3587 }
3588 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3589 B::int_div_scalar(lhs, rhs.elem())
3590 }
3591 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3592 B::int_remainder(lhs, rhs)
3593 }
3594 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3595 B::int_remainder_scalar(lhs, rhs.elem())
3596 }
3597 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3598 B::int_mul(lhs, rhs)
3599 }
3600 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3601 B::int_mul_scalar(lhs, rhs.elem())
3602 }
3603 fn neg(tensor: Self::Primitive) -> Self::Primitive {
3604 B::int_neg(tensor)
3605 }
3606 fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
3607 B::int_zeros(shape, device)
3608 }
3609 fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
3610 B::int_ones(shape, device)
3611 }
3612 fn full<E: ElementConversion>(
3613 shape: Shape,
3614 fill_value: E,
3615 device: &B::Device,
3616 ) -> Self::Primitive {
3617 B::int_full(shape, fill_value.elem(), device)
3618 }
3619
3620 fn sum(tensor: Self::Primitive) -> Self::Primitive {
3621 B::int_sum(tensor)
3622 }
3623
3624 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3625 B::int_sum_dim(tensor, dim)
3626 }
3627
3628 fn prod(tensor: Self::Primitive) -> Self::Primitive {
3629 B::int_prod(tensor)
3630 }
3631
3632 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3633 B::int_prod_dim(tensor, dim)
3634 }
3635
3636 fn mean(tensor: Self::Primitive) -> Self::Primitive {
3637 B::int_mean(tensor)
3638 }
3639 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3640 B::int_mean_dim(tensor, dim)
3641 }
3642
3643 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3644 B::int_equal_elem(lhs, rhs)
3645 }
3646 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3647 B::int_not_equal_elem(lhs, rhs)
3648 }
3649 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3650 B::int_greater(lhs, rhs)
3651 }
3652
3653 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3654 B::int_greater_elem(lhs, rhs)
3655 }
3656
3657 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3658 B::int_greater_equal(lhs, rhs)
3659 }
3660
3661 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3662 B::int_greater_equal_elem(lhs, rhs)
3663 }
3664
3665 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3666 B::int_lower(lhs, rhs)
3667 }
3668
3669 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3670 B::int_lower_elem(lhs, rhs)
3671 }
3672
3673 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3674 B::int_lower_equal(lhs, rhs)
3675 }
3676
3677 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3678 B::int_lower_equal_elem(lhs, rhs)
3679 }
3680
3681 fn mask_where(
3682 tensor: Self::Primitive,
3683 mask: B::BoolTensorPrimitive,
3684 source: Self::Primitive,
3685 ) -> Self::Primitive {
3686 B::int_mask_where(tensor, mask, source)
3687 }
3688
3689 fn mask_fill(
3690 tensor: Self::Primitive,
3691 mask: B::BoolTensorPrimitive,
3692 value: Self::Elem,
3693 ) -> Self::Primitive {
3694 B::int_mask_fill(tensor, mask, value)
3695 }
3696
3697 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3698 B::int_select(tensor, dim, indices.primitive)
3699 }
3700
3701 fn select_assign(
3702 tensor: Self::Primitive,
3703 dim: usize,
3704 indices: Tensor<B, 1, Int>,
3705 values: Self::Primitive,
3706 ) -> Self::Primitive {
3707 B::int_select_assign(tensor, dim, indices.primitive, values)
3708 }
3709 fn gather(
3710 dim: usize,
3711 tensor: Self::Primitive,
3712 indices: B::IntTensorPrimitive,
3713 ) -> Self::Primitive {
3714 B::int_gather(dim, tensor, indices)
3715 }
3716
3717 fn scatter(
3718 dim: usize,
3719 tensor: Self::Primitive,
3720 indices: B::IntTensorPrimitive,
3721 values: Self::Primitive,
3722 ) -> Self::Primitive {
3723 B::int_scatter(dim, tensor, indices, values)
3724 }
3725
3726 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3727 B::int_argmax(tensor, dim)
3728 }
3729
3730 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3731 B::int_argmin(tensor, dim)
3732 }
3733
3734 fn max(tensor: Self::Primitive) -> Self::Primitive {
3735 B::int_max(tensor)
3736 }
3737
3738 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3739 B::int_max_dim(tensor, dim)
3740 }
3741
3742 fn max_dim_with_indices(
3743 tensor: Self::Primitive,
3744 dim: usize,
3745 ) -> (Self::Primitive, IntTensor<B>) {
3746 B::int_max_dim_with_indices(tensor, dim)
3747 }
3748
3749 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
3750 B::int_max_abs(tensor)
3751 }
3752
3753 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3754 B::int_max_abs_dim(tensor, dim)
3755 }
3756
3757 fn min(tensor: Self::Primitive) -> Self::Primitive {
3758 B::int_min(tensor)
3759 }
3760
3761 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3762 B::int_min_dim(tensor, dim)
3763 }
3764
3765 fn min_dim_with_indices(
3766 tensor: Self::Primitive,
3767 dim: usize,
3768 ) -> (Self::Primitive, IntTensor<B>) {
3769 B::int_min_dim_with_indices(tensor, dim)
3770 }
3771
3772 fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
3773 B::int_clamp(tensor, min, max)
3774 }
3775
3776 fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
3777 B::int_clamp_min(tensor, min)
3778 }
3779
3780 fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
3781 B::int_clamp_max(tensor, max)
3782 }
3783
3784 fn abs(tensor: Self::Primitive) -> Self::Primitive {
3785 B::int_abs(tensor)
3786 }
3787
3788 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3789 B::int_powf(lhs, B::int_into_float(rhs))
3790 }
3791
3792 fn powf_scalar<E: ElementConversion>(
3793 lhs: Self::Primitive,
3794 rhs: E,
3795 ) -> <Int as TensorKind<B>>::Primitive {
3796 B::int_powf_scalar(lhs, rhs.elem())
3797 }
3798
3799 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3800 B::int_powi(lhs, rhs)
3801 }
3802
3803 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3804 B::int_powf_scalar(lhs, rhs.elem())
3805 }
3806
3807 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
3808 B::int_random(shape, distribution, device)
3809 }
3810
3811 fn sign(tensor: Self::Primitive) -> Self::Primitive {
3812 B::int_sign(tensor)
3813 }
3814
3815 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
3816 B::int_sort(tensor, dim, descending)
3817 }
3818
3819 fn sort_with_indices(
3820 tensor: Self::Primitive,
3821 dim: usize,
3822 descending: bool,
3823 ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
3824 B::int_sort_with_indices(tensor, dim, descending)
3825 }
3826
3827 fn argsort(
3828 tensor: Self::Primitive,
3829 dim: usize,
3830 descending: bool,
3831 ) -> <Int as TensorKind<B>>::Primitive {
3832 B::int_argsort(tensor, dim, descending)
3833 }
3834}
3835
3836impl<B: Backend> Numeric<B> for Float {
3837 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3838 match (lhs, rhs) {
3839 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3840 TensorPrimitive::Float(B::float_add(lhs, rhs))
3841 }
3842 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3843 TensorPrimitive::QFloat(B::q_add(lhs, rhs))
3844 }
3845 _ => panic!("Primitive type mismatch for lhs and rhs"),
3846 }
3847 }
3848 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3849 match lhs {
3850 TensorPrimitive::Float(lhs) => {
3851 TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
3852 }
3853 TensorPrimitive::QFloat(lhs) => {
3854 TensorPrimitive::QFloat(B::q_add_scalar(lhs, rhs.elem()))
3855 }
3856 }
3857 }
3858 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3859 match (lhs, rhs) {
3860 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3861 TensorPrimitive::Float(B::float_sub(lhs, rhs))
3862 }
3863 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3864 TensorPrimitive::QFloat(B::q_sub(lhs, rhs))
3865 }
3866 _ => panic!("Primitive type mismatch for lhs and rhs"),
3867 }
3868 }
3869 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3870 match lhs {
3871 TensorPrimitive::Float(lhs) => {
3872 TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
3873 }
3874 TensorPrimitive::QFloat(lhs) => {
3875 TensorPrimitive::QFloat(B::q_sub_scalar(lhs, rhs.elem()))
3876 }
3877 }
3878 }
3879 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3880 match (lhs, rhs) {
3881 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3882 TensorPrimitive::Float(B::float_div(lhs, rhs))
3883 }
3884 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3885 TensorPrimitive::QFloat(B::q_div(lhs, rhs))
3886 }
3887 _ => panic!("Primitive type mismatch for lhs and rhs"),
3888 }
3889 }
3890 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3891 match lhs {
3892 TensorPrimitive::Float(lhs) => {
3893 TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
3894 }
3895 TensorPrimitive::QFloat(lhs) => {
3896 TensorPrimitive::QFloat(B::q_div_scalar(lhs, rhs.elem()))
3897 }
3898 }
3899 }
3900 fn remainder(
3901 lhs: Self::Primitive,
3902 rhs: Self::Primitive,
3903 ) -> <Float as TensorKind<B>>::Primitive {
3904 match (lhs, rhs) {
3905 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3906 TensorPrimitive::Float(B::float_remainder(lhs, rhs))
3907 }
3908 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3909 TensorPrimitive::QFloat(B::q_remainder(lhs, rhs))
3910 }
3911 _ => panic!("Primitive type mismatch for lhs and rhs"),
3912 }
3913 }
3914 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3915 match lhs {
3916 TensorPrimitive::Float(lhs) => {
3917 TensorPrimitive::Float(B::float_remainder_scalar(lhs, rhs.elem()))
3918 }
3919 TensorPrimitive::QFloat(lhs) => {
3920 TensorPrimitive::QFloat(B::q_remainder_scalar(lhs, rhs.elem()))
3921 }
3922 }
3923 }
3924 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3925 match (lhs, rhs) {
3926 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3927 TensorPrimitive::Float(B::float_mul(lhs, rhs))
3928 }
3929 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3930 TensorPrimitive::QFloat(B::q_mul(lhs, rhs))
3931 }
3932 _ => panic!("Primitive type mismatch for lhs and rhs"),
3933 }
3934 }
3935 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3936 match lhs {
3937 TensorPrimitive::Float(lhs) => {
3938 TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
3939 }
3940 TensorPrimitive::QFloat(lhs) => {
3941 TensorPrimitive::QFloat(B::q_mul_scalar(lhs, rhs.elem()))
3942 }
3943 }
3944 }
3945 fn neg(tensor: Self::Primitive) -> Self::Primitive {
3946 match tensor {
3947 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
3948 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_neg(tensor)),
3949 }
3950 }
3951 fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
3952 TensorPrimitive::Float(B::float_zeros(shape, device))
3953 }
3954 fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
3955 TensorPrimitive::Float(B::float_ones(shape, device))
3956 }
3957
3958 fn full<E: ElementConversion>(
3959 shape: Shape,
3960 fill_value: E,
3961 device: &B::Device,
3962 ) -> Self::Primitive {
3963 TensorPrimitive::Float(B::float_full(shape, fill_value.elem(), device))
3964 }
3965
3966 fn sum(tensor: Self::Primitive) -> Self::Primitive {
3967 match tensor {
3968 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
3969 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum(tensor)),
3970 }
3971 }
3972
3973 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3974 match tensor {
3975 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
3976 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum_dim(tensor, dim)),
3977 }
3978 }
3979
3980 fn prod(tensor: Self::Primitive) -> Self::Primitive {
3981 match tensor {
3982 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
3983 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod(tensor)),
3984 }
3985 }
3986
3987 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3988 match tensor {
3989 TensorPrimitive::Float(tensor) => {
3990 TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
3991 }
3992 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod_dim(tensor, dim)),
3993 }
3994 }
3995
3996 fn mean(tensor: Self::Primitive) -> Self::Primitive {
3997 match tensor {
3998 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
3999 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean(tensor)),
4000 }
4001 }
4002
4003 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4004 match tensor {
4005 TensorPrimitive::Float(tensor) => {
4006 TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
4007 }
4008 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean_dim(tensor, dim)),
4009 }
4010 }
4011
4012 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4013 B::float_equal_elem(lhs.tensor(), rhs)
4014 }
4015 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4016 B::float_not_equal_elem(lhs.tensor(), rhs)
4017 }
4018 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4019 B::float_greater(lhs.tensor(), rhs.tensor())
4020 }
4021
4022 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4023 B::float_greater_elem(lhs.tensor(), rhs)
4024 }
4025
4026 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4027 B::float_greater_equal(lhs.tensor(), rhs.tensor())
4028 }
4029
4030 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4031 B::float_greater_equal_elem(lhs.tensor(), rhs)
4032 }
4033
4034 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4035 B::float_lower(lhs.tensor(), rhs.tensor())
4036 }
4037
4038 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4039 B::float_lower_elem(lhs.tensor(), rhs)
4040 }
4041
4042 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
4043 B::float_lower_equal(lhs.tensor(), rhs.tensor())
4044 }
4045
4046 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
4047 B::float_lower_equal_elem(lhs.tensor(), rhs)
4048 }
4049
4050 fn mask_where(
4051 tensor: Self::Primitive,
4052 mask: B::BoolTensorPrimitive,
4053 source: Self::Primitive,
4054 ) -> Self::Primitive {
4055 match (tensor, source) {
4056 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(source)) => {
4057 TensorPrimitive::Float(B::float_mask_where(tensor, mask, source))
4058 }
4059 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(source)) => {
4060 TensorPrimitive::QFloat(B::q_mask_where(tensor, mask, source))
4061 }
4062 _ => panic!("Primitive type mismatch for tensor and source"),
4063 }
4064 }
4065
4066 fn mask_fill(
4067 tensor: Self::Primitive,
4068 mask: B::BoolTensorPrimitive,
4069 value: Self::Elem,
4070 ) -> Self::Primitive {
4071 match tensor {
4072 TensorPrimitive::Float(tensor) => {
4073 TensorPrimitive::Float(B::float_mask_fill(tensor, mask, value))
4074 }
4075 TensorPrimitive::QFloat(tensor) => {
4076 TensorPrimitive::QFloat(B::q_mask_fill(tensor, mask, value))
4077 }
4078 }
4079 }
4080
4081 fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
4082 match tensor {
4083 TensorPrimitive::Float(tensor) => {
4084 TensorPrimitive::Float(B::float_select(tensor, dim, indices.primitive))
4085 }
4086 TensorPrimitive::QFloat(tensor) => {
4087 TensorPrimitive::QFloat(B::q_select(tensor, dim, indices.primitive))
4088 }
4089 }
4090 }
4091
4092 fn select_assign(
4093 tensor: Self::Primitive,
4094 dim: usize,
4095 indices: Tensor<B, 1, Int>,
4096 values: Self::Primitive,
4097 ) -> Self::Primitive {
4098 match (tensor, values) {
4099 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
4100 TensorPrimitive::Float(B::float_select_assign(
4101 tensor,
4102 dim,
4103 indices.primitive,
4104 values,
4105 ))
4106 }
4107 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
4108 TensorPrimitive::QFloat(B::q_select_assign(tensor, dim, indices.primitive, values))
4109 }
4110 _ => panic!("Primitive type mismatch for tensor and values"),
4111 }
4112 }
4113
4114 fn gather(
4115 dim: usize,
4116 tensor: Self::Primitive,
4117 indices: B::IntTensorPrimitive,
4118 ) -> Self::Primitive {
4119 match tensor {
4120 TensorPrimitive::Float(tensor) => {
4121 TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
4122 }
4123 TensorPrimitive::QFloat(tensor) => {
4124 TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
4125 }
4126 }
4127 }
4128
4129 fn scatter(
4130 dim: usize,
4131 tensor: Self::Primitive,
4132 indices: B::IntTensorPrimitive,
4133 values: Self::Primitive,
4134 ) -> Self::Primitive {
4135 match (tensor, values) {
4136 (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
4137 TensorPrimitive::Float(B::float_scatter(dim, tensor, indices, values))
4138 }
4139 (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
4140 TensorPrimitive::QFloat(B::q_scatter(dim, tensor, indices, values))
4141 }
4142 _ => panic!("Primitive type mismatch for tensor and values"),
4143 }
4144 }
4145
4146 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
4147 match tensor {
4148 TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
4149 TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
4150 }
4151 }
4152
4153 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
4154 match tensor {
4155 TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
4156 TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
4157 }
4158 }
4159
4160 fn max(tensor: Self::Primitive) -> Self::Primitive {
4161 match tensor {
4162 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
4163 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
4164 }
4165 }
4166
4167 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4168 match tensor {
4169 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
4170 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
4171 }
4172 }
4173
4174 fn max_dim_with_indices(
4175 tensor: Self::Primitive,
4176 dim: usize,
4177 ) -> (Self::Primitive, IntTensor<B>) {
4178 match tensor {
4179 TensorPrimitive::Float(tensor) => {
4180 let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
4181 (TensorPrimitive::Float(values), indices)
4182 }
4183 TensorPrimitive::QFloat(tensor) => {
4184 let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
4185 (TensorPrimitive::QFloat(values), indices)
4186 }
4187 }
4188 }
4189
4190 fn min(tensor: Self::Primitive) -> Self::Primitive {
4191 match tensor {
4192 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
4193 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
4194 }
4195 }
4196
4197 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4198 match tensor {
4199 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
4200 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
4201 }
4202 }
4203
4204 fn min_dim_with_indices(
4205 tensor: Self::Primitive,
4206 dim: usize,
4207 ) -> (Self::Primitive, IntTensor<B>) {
4208 match tensor {
4209 TensorPrimitive::Float(tensor) => {
4210 let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
4211 (TensorPrimitive::Float(values), indices)
4212 }
4213 TensorPrimitive::QFloat(tensor) => {
4214 let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
4215 (TensorPrimitive::QFloat(values), indices)
4216 }
4217 }
4218 }
4219
4220 fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
4221 match tensor {
4222 TensorPrimitive::Float(tensor) => {
4223 TensorPrimitive::Float(B::float_clamp(tensor, min, max))
4224 }
4225 TensorPrimitive::QFloat(tensor) => {
4226 TensorPrimitive::QFloat(B::q_clamp(tensor, min, max))
4227 }
4228 }
4229 }
4230
4231 fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
4232 match tensor {
4233 TensorPrimitive::Float(tensor) => {
4234 TensorPrimitive::Float(B::float_clamp_min(tensor, min))
4235 }
4236 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_min(tensor, min)),
4237 }
4238 }
4239
4240 fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
4241 match tensor {
4242 TensorPrimitive::Float(tensor) => {
4243 TensorPrimitive::Float(B::float_clamp_max(tensor, max))
4244 }
4245 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_max(tensor, max)),
4246 }
4247 }
4248
4249 fn abs(tensor: Self::Primitive) -> Self::Primitive {
4250 match tensor {
4251 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
4252 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
4253 }
4254 }
4255
4256 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4257 match (lhs, rhs) {
4258 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4259 TensorPrimitive::Float(B::float_powf(lhs, rhs))
4260 }
4261 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
4262 TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
4263 }
4264 _ => panic!("Primitive type mismatch for lhs and rhs"),
4265 }
4266 }
4267
4268 fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4269 match lhs {
4270 TensorPrimitive::Float(lhs) => {
4271 TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4272 }
4273 TensorPrimitive::QFloat(lhs) => {
4274 TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
4275 }
4276 }
4277 }
4278
4279 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4280 match (lhs, rhs) {
4281 (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4282 TensorPrimitive::Float(B::float_powf(lhs, rhs))
4283 }
4284 (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
4285 TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
4286 }
4287 _ => panic!("Primitive type mismatch for lhs and rhs"),
4288 }
4289 }
4290
4291 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4292 match lhs {
4293 TensorPrimitive::Float(lhs) => {
4294 TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4295 }
4296 TensorPrimitive::QFloat(lhs) => {
4297 TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
4298 }
4299 }
4300 }
4301
4302 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
4303 TensorPrimitive::Float(B::float_random(shape, distribution, device))
4304 }
4305
4306 fn sign(tensor: Self::Primitive) -> Self::Primitive {
4307 TensorPrimitive::Float(B::float_sign(tensor.tensor()))
4308 }
4309
4310 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
4311 match tensor {
4312 TensorPrimitive::Float(tensor) => {
4313 TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
4314 }
4315 TensorPrimitive::QFloat(tensor) => {
4316 TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
4317 }
4318 }
4319 }
4320
4321 fn sort_with_indices(
4322 tensor: Self::Primitive,
4323 dim: usize,
4324 descending: bool,
4325 ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
4326 match tensor {
4327 TensorPrimitive::Float(tensor) => {
4328 let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
4329 (TensorPrimitive::Float(values), indices)
4330 }
4331 TensorPrimitive::QFloat(tensor) => {
4332 let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
4333 (TensorPrimitive::QFloat(values), indices)
4334 }
4335 }
4336 }
4337
4338 fn argsort(
4339 tensor: Self::Primitive,
4340 dim: usize,
4341 descending: bool,
4342 ) -> <Int as TensorKind<B>>::Primitive {
4343 match tensor {
4344 TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
4345 TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
4346 }
4347 }
4348
4349 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
4350 match tensor {
4351 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),
4352 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),
4353 }
4354 }
4355
4356 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
4357 match tensor {
4358 TensorPrimitive::Float(tensor) => {
4359 TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))
4360 }
4361 TensorPrimitive::QFloat(tensor) => {
4362 TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))
4363 }
4364 }
4365 }
4366}
4367
4368impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
4369where
4370 B: Backend,
4371 K: Numeric<B>,
4372 K::Elem: Element,
4373{
4374 type Output = Self;
4375
4376 fn add(self, rhs: Tensor<B, D, K>) -> Self {
4377 Self::add(self, rhs)
4378 }
4379}
4380
4381impl<E, const D: usize, B, K> core::ops::Add<E> for Tensor<B, D, K>
4382where
4383 E: ElementConversion,
4384 B: Backend,
4385 K: Numeric<B>,
4386 K::Elem: Element,
4387{
4388 type Output = Self;
4389
4390 fn add(self, other: E) -> Self {
4391 Tensor::add_scalar(self, other)
4392 }
4393}
4394
4395impl<B, const D: usize, K> core::ops::Sub<Tensor<B, D, K>> for Tensor<B, D, K>
4396where
4397 B: Backend,
4398 K: Numeric<B>,
4399 K::Elem: Element,
4400{
4401 type Output = Self;
4402
4403 fn sub(self, rhs: Tensor<B, D, K>) -> Self {
4404 Tensor::sub(self, rhs)
4405 }
4406}
4407
4408impl<E, const D: usize, B, K> core::ops::Sub<E> for Tensor<B, D, K>
4409where
4410 E: ElementConversion,
4411 B: Backend,
4412 K: Numeric<B>,
4413 K::Elem: Element,
4414{
4415 type Output = Self;
4416
4417 fn sub(self, other: E) -> Self {
4418 Tensor::sub_scalar(self, other)
4419 }
4420}
4421
4422impl<B, const D: usize, K> core::ops::Div<Tensor<B, D, K>> for Tensor<B, D, K>
4423where
4424 B: Backend,
4425 K: Numeric<B>,
4426 K::Elem: Element,
4427{
4428 type Output = Self;
4429
4430 fn div(self, rhs: Tensor<B, D, K>) -> Self {
4431 Tensor::div(self, rhs)
4432 }
4433}
4434
4435impl<E, const D: usize, B, K> core::ops::Div<E> for Tensor<B, D, K>
4436where
4437 E: ElementConversion,
4438 B: Backend,
4439 K: Numeric<B>,
4440 K::Elem: Element,
4441{
4442 type Output = Self;
4443
4444 fn div(self, other: E) -> Self {
4445 Tensor::div_scalar(self, other)
4446 }
4447}
4448
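/// Enables the `%` operator between tensors, delegating to [Tensor::remainder]. A minimal
/// sketch with hypothetical data:
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::Tensor;
///
/// fn example<B: Backend>() {
///     let device = B::Device::default();
///     let lhs = Tensor::<B, 1>::from_data([7.0, 8.0, 9.0], &device);
///     let rhs = Tensor::<B, 1>::from_data([4.0, 3.0, 2.0], &device);
///     let rem = lhs % rhs;
///     // [3.0, 2.0, 1.0]
/// }
/// ```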
4449impl<const D: usize, B, K> core::ops::Rem<Tensor<B, D, K>> for Tensor<B, D, K>
4450where
4451 B: Backend,
4452 K: Numeric<B>,
4453 K::Elem: Element,
4454{
4455 type Output = Self;
4456 fn rem(self, rhs: Tensor<B, D, K>) -> Self::Output {
4457 Tensor::remainder(self, rhs)
4458 }
4459}
4460
4461impl<E, const D: usize, B, K> core::ops::Rem<E> for Tensor<B, D, K>
4462where
4463 E: ElementConversion,
4464 B: Backend,
4465 K: Numeric<B>,
4466 K::Elem: Element,
4467{
4468 type Output = Self;
4469
4470 fn rem(self, other: E) -> Self {
4471 Tensor::remainder_scalar(self, other)
4472 }
4473}
4474
4475impl<B, const D: usize, K> core::ops::Mul<Tensor<B, D, K>> for Tensor<B, D, K>
4476where
4477 B: Backend,
4478 K: Numeric<B>,
4479 K::Elem: Element,
4480{
4481 type Output = Self;
4482
4483 fn mul(self, rhs: Tensor<B, D, K>) -> Self {
4484 Tensor::mul(self, rhs)
4485 }
4486}
4487
4488impl<E, const D: usize, B, K> core::ops::Mul<E> for Tensor<B, D, K>
4489where
4490 E: ElementConversion,
4491 B: Backend,
4492 K: Numeric<B>,
4493 K::Elem: Element,
4494{
4495 type Output = Self;
4496
4497 fn mul(self, other: E) -> Self {
4498 Tensor::mul_scalar(self, other)
4499 }
4500}
4501
4502impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
4503where
4504 B: Backend,
4505 K: Numeric<B>,
4506 K::Elem: Element,
4507{
4508 type Output = Self;
4509
4510 fn neg(self) -> Self {
4511 Tensor::neg(self)
4512 }
4513}