burn_tensor/tensor/api/
numeric.rs

1use alloc::vec::Vec;
2
3use crate::alloc::borrow::ToOwned;
4
5use crate::TensorPrimitive;
6use crate::{
7    backend::Backend,
8    check,
9    check::TensorCheck,
10    ops::{Device, IntTensor},
11    BasicOps, Bool, Distribution, Element, ElementConversion, Float, Int, Shape, Tensor,
12    TensorKind,
13};
14
15impl<B, const D: usize, K> Tensor<B, D, K>
16where
17    B: Backend,
18    K: Numeric<B>,
19    K::Elem: Element,
20{
21    /// Applies element wise addition operation.
22    ///
23    /// `y = x2 + x1`
24    ///
25    /// # Arguments
26    ///
27    /// * `other` - The tensor to add.
28    ///
29    /// # Example
30    ///
31    /// ```rust
32    /// use burn_tensor::backend::Backend;
33    /// use burn_tensor::{Tensor, Shape};
34    ///
35    /// fn example<B: Backend>() {
36    ///    let device = B::Device::default();
37    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
38    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
39    ///    let tensor = tensor1 + tensor2;
40    ///    println!("{tensor}");
41    ///    // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]
42    /// }
43    /// ```
44    #[allow(clippy::should_implement_trait)]
45    pub fn add(self, other: Self) -> Self {
46        check!(TensorCheck::binary_ops_ew("Add", &self, &other));
47        Self::new(K::add(self.primitive, other.primitive))
48    }
49
50    /// Applies element wise addition operation with a scalar.
51    ///
52    /// `y = x + s`
53    ///
54    /// # Arguments
55    ///
56    /// * `other` - The scalar to add, element wise.
57    ///
58    /// # Example
59    ///
60    /// ```rust
61    /// use burn_tensor::backend::Backend;
62    /// use burn_tensor::{Tensor, Shape};
63    ///
64    /// fn example<B: Backend>() {
65    ///   let device = B::Device::default();
66    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
67    ///   let scalar = 2.0;
68    ///   let tensor = tensor + scalar;
69    ///   println!("{tensor}");
70    ///   // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]
71    /// }
72    /// ```
73    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {
74        Self::new(K::add_scalar::<E>(self.primitive, other))
75    }
76
    /// Applies element wise subtraction operation.
    ///
    /// `y = x1 - x2` (i.e. `self - other`, as the example output shows)
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to subtract from `self`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///   let tensor = tensor1 - tensor2;
    ///   println!("{tensor}");
    ///   // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Sub", &self, &other));
        Self::new(K::sub(self.primitive, other.primitive))
    }
105
106    /// Applies element wise subtraction operation with a scalar.
107    ///
108    /// `y = x - s`
109    ///
110    /// # Arguments
111    ///
112    /// * `other` - The scalar to subtract, element wise.
113    ///
114    /// # Example
115    ///
116    /// ```rust
117    /// use burn_tensor::backend::Backend;
118    /// use burn_tensor::{Tensor, Shape};
119    ///
120    /// fn example<B: Backend>() {
121    ///    let device = B::Device::default();
122    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
123    ///    let scalar = 2.0;
124    ///    let tensor = tensor - scalar;
125    ///    println!("{tensor}");
126    ///    // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]
127    /// }
128    /// ```
129    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
130        Self::new(K::sub_scalar::<E>(self.primitive, other))
131    }
132
    /// Applies element wise division operation.
    ///
    /// `y = x1 / x2` (i.e. `self / other`, as the example output shows)
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to divide `self` by.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1 / tensor2;
    ///    println!("{tensor}");
    ///    // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]
    /// }
    /// ```
    #[allow(clippy::should_implement_trait)]
    pub fn div(self, other: Self) -> Self {
        check!(TensorCheck::binary_ops_ew("Div", &self, &other));
        Self::new(K::div(self.primitive, other.primitive))
    }
161
162    /// Applies element wise division operation with a scalar.
163    ///
164    /// `y = x / s`
165    ///
166    /// # Arguments
167    ///
168    /// * `other` - The scalar to divide, element wise.
169    ///
170    /// # Example
171    ///
172    /// ```rust
173    /// use burn_tensor::backend::Backend;
174    /// use burn_tensor::{Tensor, Shape};
175    ///
176    /// fn example<B: Backend>() {
177    ///    let device = B::Device::default();
178    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
179    ///    let scalar = 2.0;
180    ///    let tensor = tensor / scalar;
181    ///    println!("{tensor}");
182    ///    // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]
183    /// }
184    /// ```
185    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {
186        Self::new(K::div_scalar::<E>(self.primitive, other))
187    }
188
189    /// Applies element wise the remainder operation with a scalar.
190    ///
191    /// `y = x2 % x1`
192    pub fn remainder(self, other: Self) -> Self {
193        Self::new(K::remainder(self.primitive, other.primitive))
194    }
195
    /// Applies the element wise remainder operation with a scalar.
    ///
    /// `y = x % s`
    ///
    /// # Arguments
    ///
    /// * `other` - The scalar divisor; the remainder of each element divided by it is returned.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let scalar = 2.0;
    ///    let tensor = tensor1 % scalar;
    ///    println!("{tensor}");
    ///    // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
    /// }
    /// ```
    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {
        Self::new(K::remainder_scalar::<E>(self.primitive, other))
    }
222
223    /// Applies element wise multiplication operation.
224    ///
225    /// `y = x2 * x1`
226    ///
227    /// # Arguments
228    ///
229    /// * `other` - The tensor to multiply.
230    ///
231    /// # Example
232    ///
233    /// ```rust
234    /// use burn_tensor::backend::Backend;
235    /// use burn_tensor::{Tensor, Shape};
236    ///
237    /// fn example<B: Backend>() {
238    ///    let device = B::Device::default();
239    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
240    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
241    ///    let tensor = tensor1 * tensor2;
242    ///    println!("{tensor}");
243    ///    // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]
244    /// }
245    /// ```
246    #[allow(clippy::should_implement_trait)]
247    pub fn mul(self, other: Self) -> Self {
248        check!(TensorCheck::binary_ops_ew("Mul", &self, &other));
249        Self::new(K::mul(self.primitive, other.primitive))
250    }
251
252    /// Applies element wise multiplication operation with a scalar.
253    ///
254    /// `y = x * s`
255    ///
256    /// # Arguments
257    ///
258    /// * `other` - The scalar to multiply, element wise.
259    ///
260    /// # Example
261    ///
262    /// ```rust
263    /// use burn_tensor::backend::Backend;
264    /// use burn_tensor::{Tensor, Shape};
265    ///
266    /// fn example<B: Backend>() {
267    ///    let device = B::Device::default();
268    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
269    ///    let scalar = 2.0;
270    ///    let tensor = tensor * scalar;
271    ///    println!("{tensor}");
272    ///    // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]
273    /// }
274    /// ```
275    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {
276        Self::new(K::mul_scalar::<E>(self.primitive, other))
277    }
278
279    /// Switch sign of each element in the tensor.
280    ///
281    /// `y = -x`
282    ///
283    /// # Example
284    ///
285    /// ```rust
286    /// use burn_tensor::backend::Backend;
287    /// use burn_tensor::{Tensor, Shape};
288    ///
289    /// fn example<B: Backend>() {
290    ///    let device = B::Device::default();
291    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
292    ///    let tensor = -tensor;
293    ///    println!("{tensor}");
294    ///    // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]
295    /// }
296    /// ```
297    #[allow(clippy::should_implement_trait)]
298    pub fn neg(self) -> Self {
299        Self::new(K::neg(self.primitive))
300    }
301
302    /// Returns the signs of the elements of the input tensor.
303    ///
304    /// # Example
305    ///
306    /// ```rust
307    /// use burn_tensor::backend::Backend;
308    /// use burn_tensor::{Tensor, Shape};
309    ///
310    /// fn example<B: Backend>() {
311    ///    let device = B::Device::default();
312    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
313    ///    let tensor = tensor.sign();
314    ///    println!("{tensor}");
315    ///    // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]
316    /// }
317    /// ```
318    pub fn sign(self) -> Self {
319        Self::new(K::sign(self.primitive))
320    }
321
322    /// Create a tensor of the given shape where each element is zero.
323    ///
324    /// # Example
325    ///
326    /// ```rust
327    /// use burn_tensor::backend::Backend;
328    /// use burn_tensor::{Tensor, Shape};
329    ///
330    /// fn example<B: Backend>() {
331    ///    let device = B::Device::default();
332    ///    let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);
333    ///    println!("{tensor}");
334    ///    // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
335    /// }
336    /// ```
337    pub fn zeros<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
338        let shape = shape.into();
339        check!(TensorCheck::creation_ops::<D>("Zeros", &shape.dims));
340        Self::new(K::zeros(shape, device))
341    }
342
343    /// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
344    ///
345    /// # Example
346    ///
347    /// ```rust
348    /// use burn_tensor::backend::Backend;
349    /// use burn_tensor::{Tensor, Shape};
350    ///
351    /// fn example<B: Backend>() {
352    ///   let device = B::Device::default();
353    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
354    ///   let tensor = tensor.zeros_like();
355    ///   println!("{tensor}");
356    ///   // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
357    /// }
358    /// ```
359    pub fn zeros_like(&self) -> Self {
360        Self::zeros(self.shape(), &self.device())
361    }
362
363    /// Create a tensor of the given shape where each element is one.
364    ///
365    /// # Example
366    ///
367    /// ```rust
368    /// use burn_tensor::backend::Backend;
369    /// use burn_tensor::{Tensor, Shape};
370    ///
371    /// fn example<B: Backend>() {
372    ///   let device = B::Device::default();
373    ///   let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);
374    ///   println!("{tensor}");
375    ///   // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
376    /// }
377    /// ```
378    pub fn ones<S: Into<Shape>>(shape: S, device: &B::Device) -> Self {
379        let shape = shape.into();
380        check!(TensorCheck::creation_ops::<D>("Ones", &shape.dims));
381        Self::new(K::ones(shape, device))
382    }
383
384    /// Returns a new tensor with the same shape and device as the current tensor filled with ones.
385    ///
386    /// # Example
387    ///
388    /// ```rust
389    /// use burn_tensor::backend::Backend;
390    /// use burn_tensor::{Tensor, Shape};
391    ///
392    /// fn example<B: Backend>() {
393    ///    let device = B::Device::default();
394    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
395    ///    let tensor = tensor.ones_like();
396    ///    println!("{tensor}");
397    ///    // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
398    /// }
399    /// ```
400    pub fn ones_like(&self) -> Self {
401        Self::ones(self.shape(), &self.device())
402    }
403
404    /// Create a tensor of the given shape where each element is equal to the provided value.
405    ///
406    /// # Example
407    ///
408    /// ```rust
409    /// use burn_tensor::backend::Backend;
410    /// use burn_tensor::{Tensor, Shape};
411    ///
412    /// fn example<B: Backend>() {
413    ///   let device = B::Device::default();
414    ///   let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);
415    ///   println!("{tensor}");
416    ///   // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
417    /// }
418    /// ```
419    pub fn full<S: Into<Shape>, E: ElementConversion>(
420        shape: S,
421        fill_value: E,
422        device: &B::Device,
423    ) -> Self {
424        let shape = shape.into();
425        check!(TensorCheck::creation_ops::<D>("Full", &shape.dims));
426        Self::new(K::full(shape, fill_value, device))
427    }
428
429    ///Returns a new tensor with the same shape and device as the current tensor filled with the provided value.
430    ///
431    /// # Example
432    ///
433    /// ```rust
434    /// use burn_tensor::backend::Backend;
435    /// use burn_tensor::{Tensor, Shape};
436    ///
437    /// fn example<B: Backend>() {
438    ///    let device = B::Device::default();
439    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
440    ///    let tensor = tensor.full_like(5.0);
441    ///    println!("{tensor}");
442    ///    // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]
443    /// }
444    /// ```
445    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {
446        Self::full(self.shape(), fill_value, &self.device())
447    }
448
449    /// Aggregate all elements in the tensor with the mean operation.
450    ///
451    /// # Example
452    ///
453    /// ```rust
454    /// use burn_tensor::backend::Backend;
455    /// use burn_tensor::{Tensor, Shape};
456    ///
457    /// fn example<B: Backend>() {
458    ///    let device = B::Device::default();
459    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
460    ///    let tensor = tensor.mean();
461    ///    println!("{tensor}");
462    ///    // [3.6666667]
463    /// }
464    /// ```
465    pub fn mean(self) -> Tensor<B, 1, K> {
466        Tensor::new(K::mean(self.primitive))
467    }
468
469    /// Aggregate all elements in the tensor with the sum operation.
470    ///
471    /// # Example
472    ///
473    /// ```rust
474    /// use burn_tensor::backend::Backend;
475    /// use burn_tensor::{Tensor, Shape};
476    ///
477    /// fn example<B: Backend>() {
478    ///   let device = B::Device::default();
479    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
480    ///   let tensor = tensor.sum();
481    ///   println!("{tensor}");
482    ///   // [22.0]
483    /// }
484    /// ```
485    pub fn sum(self) -> Tensor<B, 1, K> {
486        Tensor::new(K::sum(self.primitive))
487    }
488
    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the mean operation.
    ///
    /// The reduced dimension is kept with size 1.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let mean_dim0 = tensor.clone().mean_dim(0);
    ///   println!("{mean_dim0}");
    ///   // [[3.0, 3.5, 4.5]]
    ///   let mean_dim1 = tensor.mean_dim(1);
    ///   println!("{mean_dim1}");
    ///   // [[0.6666667], [6.6666665]]
    /// }
    /// ```
    pub fn mean_dim(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("Mean", dim));
        Self::new(K::mean_dim(self.primitive, dim))
    }
517
    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the sum operation.
    ///
    /// The reduced dimension is kept with size 1.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let sum_dim0 = tensor.clone().sum_dim(0);
    ///    println!("{sum_dim0}");
    ///    // [[6.0, 7.0, 9.0]]
    ///    let sum_dim1 = tensor.sum_dim(1);
    ///    println!("{sum_dim1}");
    ///    // [[2.0], [20.0]]
    /// }
    /// ```
    pub fn sum_dim(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("Sum", dim));
        Self::new(K::sum_dim(self.primitive, dim))
    }
546
    /// Aggregate all elements in the tensor with the product operation,
    /// returning a single-element tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.prod();
    ///    println!("{tensor}");
    ///    // [-1620.0]
    /// }
    /// ```
    pub fn prod(self) -> Tensor<B, 1, K> {
        Tensor::new(K::prod(self.primitive))
    }
567
    /// Aggregate all elements along the given *dimension* or *axis*
    /// in the tensor with the product operation.
    ///
    /// The reduced dimension is kept with size 1.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension or axis along which to aggregate the elements.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let prod_dim0 = tensor.clone().prod_dim(0);
    ///    println!("{prod_dim0}");
    ///    // [[5.0, -18.0, 18.0]]
    ///    let prod_dim1 = tensor.prod_dim(1);
    ///    println!("{prod_dim1}");
    ///    // [[-6.0], [270.0]]
    /// }
    /// ```
    pub fn prod_dim(self, dim: usize) -> Self {
        check!(TensorCheck::aggregate_dim::<D>("Prod", dim));
        Self::new(K::prod_dim(self.primitive, dim))
    }
596
597    /// Applies element wise equal comparison and returns a boolean tensor.
598    ///
599    /// # Arguments
600    ///
601    /// * `other` - The element to compare.
602    ///
603    /// # Example
604    ///
605    /// ```rust
606    /// use burn_tensor::backend::Backend;
607    /// use burn_tensor::{Tensor, Shape};
608    ///
609    /// fn example<B: Backend>() {
610    ///    let device = B::Device::default();
611    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
612    ///    let tensor = tensor.equal_elem(3.0);
613    ///    println!("{tensor}");
614    ///    // [[false, false, true], [false, false, false]]
615    /// }
616    /// ```
617    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
618        Tensor::new(K::equal_elem(self.primitive, other.elem()))
619    }
620
621    /// Applies element wise non-equality comparison and returns a boolean tensor.
622    ///
623    /// # Arguments
624    ///
625    /// * `other` - The element to compare.
626    ///
627    /// # Example
628    ///
629    /// ```rust
630    /// use burn_tensor::backend::Backend;
631    /// use burn_tensor::{Tensor, Shape};
632    ///
633    /// fn example<B: Backend>() {
634    ///    let device = B::Device::default();
635    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
636    ///    let tensor = tensor.not_equal_elem(3.0);
637    ///    println!("{tensor}");
638    ///    // [[true, true, false], [true, true, true]]
639    /// }
640    /// ```
641    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {
642        Tensor::new(K::not_equal_elem(self.primitive, other.elem()))
643    }
644
645    /// Applies element wise greater comparison and returns a boolean tensor.
646    ///
647    /// # Panics
648    ///
649    /// If the two tensors don't have the same shape.
650    ///
651    /// # Example
652    ///
653    /// ```rust
654    /// use burn_tensor::backend::Backend;
655    /// use burn_tensor::{Tensor, Shape};
656    ///
657    /// fn example<B: Backend>() {
658    ///   let device = B::Device::default();
659    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
660    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
661    ///   let tensor = tensor1.greater(tensor2);
662    ///   println!("{tensor}");
663    ///   // [[false, false, false], [true, true, true]]
664    /// }
665    /// ```
666    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {
667        check!(TensorCheck::binary_ops_ew("Greater", &self, &other));
668        Tensor::new(K::greater(self.primitive, other.primitive))
669    }
670
671    /// Applies element wise greater-equal comparison and returns a boolean tensor.
672    ///
673    /// # Panics
674    ///
675    /// If the two tensors don't have the same shape.
676    ///
677    /// # Example
678    ///
679    /// ```rust
680    /// use burn_tensor::backend::Backend;
681    /// use burn_tensor::{Tensor, Shape};
682    ///
683    /// fn example<B: Backend>() {
684    ///    let device = B::Device::default();
685    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
686    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
687    ///    let tensor = tensor1.greater_equal(tensor2);
688    ///    println!("{tensor}");
689    ///    // [[false, false, false], [true, true, true]]
690    /// }
691    /// ```
692    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {
693        check!(TensorCheck::binary_ops_ew("Greater_equal", &self, &other));
694        Tensor::new(K::greater_equal(self.primitive, other.primitive))
695    }
696
697    /// Applies element wise lower comparison and returns a boolean tensor.
698    ///
699    /// # Panics
700    ///
701    /// If the two tensors don't have the same shape.
702    ///
703    /// # Example
704    ///
705    /// ```rust
706    /// use burn_tensor::backend::Backend;
707    /// use burn_tensor::{Tensor, Shape};
708    ///
709    /// fn example<B: Backend>() {
710    ///    let device = B::Device::default();
711    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
712    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
713    ///    let tensor = tensor1.lower(tensor2);
714    ///    println!("{tensor}");
715    ///    // [[true, true, true], [false, false, false]]
716    /// }
717    /// ```
718    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {
719        check!(TensorCheck::binary_ops_ew("Lower", &self, &other));
720        Tensor::new(K::lower(self.primitive, other.primitive))
721    }
722
723    /// Applies element wise lower-equal comparison and returns a boolean tensor.
724    ///
725    /// # Panics
726    ///
727    /// If the two tensors don't have the same shape.
728    ///
729    /// # Example
730    ///
731    /// ```rust
732    /// use burn_tensor::backend::Backend;
733    /// use burn_tensor::{Tensor, Shape};
734    ///
735    /// fn example<B: Backend>() {
736    ///    let device = B::Device::default();
737    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
738    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
739    ///    let tensor = tensor1.lower_equal(tensor2);
740    ///    println!("{tensor}");
741    ///    // [[true, true, true], [false, false, false]]
742    /// }
743    /// ```
744    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {
745        check!(TensorCheck::binary_ops_ew("Lower_equal", &self, &other));
746        Tensor::new(K::lower_equal(self.primitive, other.primitive))
747    }
748
    /// Applies element wise greater comparison and returns a boolean tensor.
    ///
    /// # Arguments
    ///
    /// * `other` - The element to compare.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor = tensor.greater_elem(3.0);
    ///    println!("{tensor}");
    ///    // [[false, false, false], [true, true, true]]
    /// }
    /// ```
    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
        Tensor::new(K::greater_elem(self.primitive, other.elem()))
    }
772
773    /// Applies element wise greater-equal comparison and returns a boolean tensor.
774    ///
775    /// # Arguments
776    ///
777    /// * `other` - The element to compare.
778    ///
779    /// # Example
780    ///
781    /// ```rust
782    /// use burn_tensor::backend::Backend;
783    /// use burn_tensor::{Tensor, Shape};
784    ///
785    /// fn example<B: Backend>() {
786    ///    let device = B::Device::default();
787    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
788    ///    let tensor = tensor.greater_equal_elem(3.0);
789    ///    println!("{tensor}");
790    ///    // [[false, false, true], [true, true, true]]
791    /// }
792    /// ```
793    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
794        Tensor::new(K::greater_equal_elem(self.primitive, other.elem()))
795    }
796
797    /// Applies element wise lower comparison and returns a boolean tensor.
798    ///
799    /// # Arguments
800    ///
801    /// * `other` - The element to compare.
802    ///
803    /// # Example
804    ///
805    /// ```rust
806    /// use burn_tensor::backend::Backend;
807    /// use burn_tensor::{Tensor, Shape};
808    ///
809    /// fn example<B: Backend>() {
810    ///     let device = B::Device::default();
811    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
812    ///     let tensor = tensor.lower_elem(3.0);
813    ///     println!("{tensor}");
814    ///     // [[true, true, false], [false, false, false]]
815    /// }
816    /// ```
817    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
818        Tensor::new(K::lower_elem(self.primitive, other.elem()))
819    }
820
821    /// Applies element wise lower-equal comparison and returns a boolean tensor.
822    ///
823    /// # Arguments
824    ///
825    /// * `other` - The element to compare.
826    ///
827    /// # Example
828    ///
829    /// ```rust
830    /// use burn_tensor::backend::Backend;
831    /// use burn_tensor::{Tensor, Shape};
832    ///
833    /// fn example<B: Backend>() {
834    ///    let device = B::Device::default();
835    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
836    ///    let tensor = tensor.lower_equal_elem(3.0);
837    ///    println!("{tensor}");
838    ///    // [[true, true, true], [false, false, false]]
839    /// }
840    /// ```
841    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
842        Tensor::new(K::lower_equal_elem(self.primitive, other.elem()))
843    }
844
845    /// Update the given tensor with the value tensor where the mask is true.
846    ///
847    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of
848    /// a scalar.
849    ///
850    /// # Example
851    ///
852    /// ```rust
853    /// use burn_tensor::backend::Backend;
854    /// use burn_tensor::{Tensor, Shape, Bool};
855    ///
856    /// fn example<B: Backend>() {
857    ///   let device = B::Device::default();
858    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
859    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
860    ///   let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
861    ///   let tensor = tensor.mask_where(mask, value);
862    ///   println!("{tensor}");
863    ///   // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]
864    /// }
865    /// ```
866    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {
867        Self::new(K::mask_where(
868            self.primitive,
869            mask.primitive,
870            value.primitive,
871        ))
872    }
873
874    /// Update the given tensor with the value where the mask is true.
875    ///
876    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of
877    /// a tensor.
878    ///
879    /// # Example
880    ///
881    /// ```rust
882    /// use burn_tensor::backend::Backend;
883    /// use burn_tensor::{Tensor, Shape, Bool};
884    ///
885    /// fn example<B: Backend>() {
886    ///   let device = B::Device::default();
887    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
888    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);
889    ///   let tensor = tensor.mask_fill(mask, 3.0);
890    ///   println!("{tensor}");
891    ///   // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]
892    /// }
893    /// ```
894    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {
895        Self::new(K::mask_fill(self.primitive, mask.primitive, value.elem()))
896    }
897
898    /// Gather tensor elements corresponding to the given indices from the specified dim.
899    ///
900    /// Example using a 3D tensor:
901    ///
902    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
903    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
904    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
905    ///
906    /// # Notes
907    ///
908    /// The index tensor should have the same shape as the original tensor except for the dim
909    /// specified.
910    ///
911    /// # Warning
912    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
913    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
914    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
915        check!(TensorCheck::gather::<D>(
916            dim,
917            &self.shape(),
918            &indices.shape()
919        ));
920
921        Self::new(K::gather(dim, self.primitive, indices.primitive))
922    }
923
924    /// Assign the gathered elements corresponding to the given indices along the specified dimension
925    /// from the value tensor to the original tensor using sum reduction.
926    ///
927    /// Example using a 3D tensor:
928    ///
929    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
930    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
931    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
932    ///
933    /// # Notes
934    ///
935    /// The index tensor should have the same shape as the original tensor except for the specified
936    /// dimension. The value and index tensors should have the same shape.
937    ///
938    /// Other references to the input tensor will not be modified by this operation.
939    ///
940    /// # Warning
941    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
942    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
943    pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
944        check!(TensorCheck::scatter::<D>(
945            dim,
946            &self.shape(),
947            &indices.shape(),
948            &values.shape()
949        ));
950
951        Self::new(K::scatter(
952            dim,
953            self.primitive,
954            indices.primitive,
955            values.primitive,
956        ))
957    }
958
    /// Select the tensor elements along the given dimension corresponding to the given indices.
    ///
    /// Example using a 3D tensor:
    ///
    /// `output[i, j, k] = input[indices[i], j, k]; // dim = 0`
    /// `output[i, j, k] = input[i, indices[j], k]; // dim = 1`
    /// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2`
    ///
    /// # Warning
    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.
    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape, Int};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///   let indices = Tensor::<B, 1, Int>::from_data([0], &device);
    ///   let tensor = tensor.select(0, indices);
    ///   println!("{tensor}");
    ///   //  [[1.0, -2.0, 3.0]]
    /// }
    /// ```
    pub fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
        check!(TensorCheck::select::<D>(dim));
        Self::new(K::select(self.primitive, dim, indices))
    }
990
991    /// Assign the selected elements along the given dimension corresponding to the given indices
992    /// from the value tensor to the original tensor using sum reduction.
993    ///
994    /// Example using a 3D tensor:
995    ///
996    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
997    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
998    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
999    ///
1000    /// # Warning
1001    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.
1002    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.
1003    pub fn select_assign(
1004        self,
1005        dim: usize,
1006        indices: Tensor<B, 1, Int>,
1007        values: Tensor<B, D, K>,
1008    ) -> Self {
1009        check!(TensorCheck::select_assign::<D>(dim));
1010
1011        Self::new(K::select_assign(
1012            self.primitive,
1013            dim,
1014            indices,
1015            values.primitive,
1016        ))
1017    }
1018
1019    /// Applies the argmax function along the given dimension and returns an integer tensor.
1020    ///
1021    /// # Example
1022    ///
1023    /// ```rust
1024    /// use burn_tensor::backend::Backend;
1025    /// use burn_tensor::{Tensor, Shape};
1026    ///
1027    /// fn example<B: Backend>() {
1028    ///     let device = B::Device::default();
1029    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1030    ///     let tensor = tensor.argmax(1);
1031    ///     println!("{:?}", tensor.shape());
1032    ///     // Shape { dims: [2, 1, 3] }
1033    /// }
1034    /// ```
1035    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {
1036        Tensor::new(K::argmax(self.primitive, dim))
1037    }
1038
1039    /// Find the maximum value.
1040    ///
1041    /// # Example
1042    ///
1043    /// ```rust
1044    /// use burn_tensor::backend::Backend;
1045    /// use burn_tensor::{Tensor, Shape};
1046    ///
1047    /// fn example<B: Backend>() {
1048    ///   let device = B::Device::default();
1049    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1050    ///   let tensor = tensor.max();
1051    ///   println!("{tensor}");
1052    ///   // [9.0]
1053    /// }
1054    /// ```
1055    pub fn max(self) -> Tensor<B, 1, K> {
1056        Tensor::new(K::max(self.primitive))
1057    }
1058
1059    /// Find the maximum value along the given dimension.
1060    ///
1061    /// # Example
1062    ///
1063    /// ```rust
1064    /// use burn_tensor::backend::Backend;
1065    /// use burn_tensor::{Tensor, Shape};
1066    ///
1067    /// fn example<B: Backend>() {
1068    ///   let device = B::Device::default();
1069    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1070    ///   let tensor = tensor.max_dim(0);
1071    ///   println!("{tensor}");
1072    ///   // [[5.0, 9.0, 6.0]]
1073    /// }
1074    /// ```
1075    pub fn max_dim(self, dim: usize) -> Tensor<B, D, K> {
1076        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1077
1078        Tensor::new(K::max_dim(self.primitive, dim))
1079    }
1080
1081    /// Find the maximum value along the given dimension.
1082    ///
1083    /// Also returns the indices.
1084    ///
1085    /// # Example
1086    ///
1087    /// ```rust
1088    /// use burn_tensor::backend::Backend;
1089    /// use burn_tensor::{Tensor, Shape};
1090    ///
1091    /// fn example<B: Backend>() {
1092    ///    let device = B::Device::default();
1093    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1094    ///    let (tensor, index) = tensor.max_dim_with_indices(0);
1095    ///    // [[5.0, 9.0, 6.0]]
1096    ///    println!("{tensor}");
1097    ///    // [[1, 1, 1]]
1098    ///    println!("{index}");
1099    /// }
1100    /// ```
1101    pub fn max_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1102        check!(TensorCheck::aggregate_dim::<D>("Max", dim));
1103
1104        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
1105
1106        let tensor = Tensor::new(tensor);
1107        let index = Tensor::new(index);
1108
1109        (tensor, index)
1110    }
1111
1112    /// Finds the maximum pair wise values with another Tensor
1113    ///
1114    /// # Arguments
1115    ///
1116    /// * `other` - Other tensor to find maximum elements with
1117    ///
1118    /// # Returns
1119    ///
1120    /// A tensor with the same shape as the input tensors containing the maximum value found
1121    /// in the input tensors.
1122    ///
1123    /// # Example
1124    ///
1125    /// ```rust
1126    /// use burn_tensor::backend::Backend;
1127    /// use burn_tensor::{Tensor, Shape};
1128    ///
1129    /// fn example<B: Backend>() {
1130    ///    let device = B::Device::default();
1131    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1132    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
1133    ///    let tensor = tensor1.max_pair(tensor2);
1134    ///    println!("{tensor}");
1135    ///    // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]
1136    /// }
1137    /// ```
1138    pub fn max_pair(self, other: Self) -> Self {
1139        let mask = self.clone().lower(other.clone());
1140        self.mask_where(mask, other)
1141    }
1142
1143    /// Applies the argmin function along the given dimension and returns an integer tensor.
1144    ///
1145    /// # Example
1146    ///
1147    /// ```rust
1148    /// use burn_tensor::backend::Backend;
1149    /// use burn_tensor::{Tensor, Shape};
1150    ///
1151    /// fn example<B: Backend>() {
1152    ///     let device = Default::default();
1153    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);
1154    ///     let tensor = tensor.argmin(1);
1155    ///     println!("{:?}", tensor.shape());
1156    ///     // Shape { dims: [2, 1, 3] }
1157    /// }
1158    /// ```
1159    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {
1160        Tensor::new(K::argmin(self.primitive, dim))
1161    }
1162
1163    /// Find the minimum value.
1164    ///
1165    /// # Example
1166    ///
1167    /// ```rust
1168    /// use burn_tensor::backend::Backend;
1169    /// use burn_tensor::{Tensor, Shape};
1170    ///
1171    /// fn example<B: Backend>() {
1172    ///    let device = B::Device::default();
1173    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1174    ///    let tensor = tensor.min();
1175    ///    println!("{tensor}");
1176    ///    // [-2.0]
1177    /// }
1178    /// ```
1179    pub fn min(self) -> Tensor<B, 1, K> {
1180        Tensor::new(K::min(self.primitive))
1181    }
1182
1183    /// Find the minimum value along the given dimension.
1184    ///
1185    /// # Example
1186    ///
1187    /// ```rust
1188    /// use burn_tensor::backend::Backend;
1189    /// use burn_tensor::{Tensor, Shape};
1190    ///
1191    /// fn example<B: Backend>() {
1192    ///    let device = B::Device::default();
1193    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1194    ///    let tensor = tensor.min_dim(0);
1195    ///    println!("{tensor}");
1196    ///    // [[1.0, -2.0, 3.0]]
1197    /// }
1198    /// ```
1199    pub fn min_dim(self, dim: usize) -> Tensor<B, D, K> {
1200        check!(TensorCheck::aggregate_dim::<D>("Min", dim));
1201        Tensor::new(K::min_dim(self.primitive, dim))
1202    }
1203
    /// Find the minimum value along the given dimension.
    ///
    /// Also returns the indices.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let (tensor, index) = tensor.min_dim_with_indices(0);
    ///    println!("{tensor}");
    ///    // [[5.0, -2.0, 3.0]]
    ///    println!("{}", index);
    ///    // [[1, 0, 0]]
    /// }
    /// ```
    pub fn min_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
        check!(TensorCheck::aggregate_dim::<D>("Min", dim));

        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);

        let tensor = Tensor::new(tensor);
        let index = Tensor::new(index);

        (tensor, index)
    }
1234
    /// Finds the minimum pair wise values with another Tensor
    ///
    /// # Arguments
    ///
    /// * `other` - Other tensor to find minimum elements with
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors containing the minimum value found
    /// between each element of the two source tensors.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.min_pair(tensor2);
    ///    println!("{tensor}");
    ///    // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]
    /// }
    /// ```
    pub fn min_pair(self, other: Self) -> Self {
        let mask = other.clone().lower(self.clone());
        self.mask_where(mask, other)
    }
1264
1265    /// Clamp the tensor between the given min and max values.
1266    ///
1267    /// # Arguments
1268    ///
1269    /// * `min` - The minimum value.
1270    /// * `max` - The maximum value.
1271    ///
1272    /// # Returns
1273    ///
1274    /// A new tensor with the values clamped between the given min and max values.
1275    ///
1276    /// # Example
1277    ///
1278    /// ```rust
1279    /// use burn_tensor::backend::Backend;
1280    /// use burn_tensor::{Int, Tensor};
1281    ///
1282    /// fn example<B: Backend>() {
1283    ///   let device = Default::default();
1284    ///   let tensor = Tensor::<B, 2, Int>::from_ints(
1285    ///    [
1286    ///     [1, 2, 3],
1287    ///     [4, 5, 6],
1288    ///     [7, 8, 9]
1289    ///    ],
1290    ///    &device);
1291    ///    let tensor = tensor.clamp(2, 6);
1292    ///    println!("{tensor}");
1293    ///    // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]
1294    /// }
1295    /// ```
1296    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {
1297        Self::new(K::clamp(self.primitive, min.elem(), max.elem()))
1298    }
1299
1300    /// Clamps a tensor under a minimum value.
1301    ///
1302    /// # Arguments
1303    ///
1304    /// * `tensor` - The tensor to clamp.
1305    /// * `min` - The minimum value.
1306    ///
1307    /// # Returns
1308    ///
1309    /// A new tensor with the values clamped under the given min value.
1310    ///
1311    /// # Example
1312    ///
1313    /// ```rust
1314    /// use burn_tensor::backend::Backend;
1315    /// use burn_tensor::{Int, Tensor};
1316    ///
1317    /// fn example<B: Backend>() {
1318    ///    let device = Default::default();
1319    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1320    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1321    ///    &device);
1322    ///    let tensor = tensor.clamp_min(4);
1323    ///    println!("{tensor}");
1324    ///    // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]
1325    /// }
1326    /// ```
1327    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {
1328        Self::new(K::clamp_min(self.primitive, min.elem()))
1329    }
1330
1331    /// Clamps a tensor over a maximum value.
1332    ///
1333    /// # Arguments
1334    ///
1335    /// * `tensor` - The tensor to clamp.
1336    /// * `max` - The maximum value.
1337    ///
1338    /// # Returns
1339    ///
1340    /// A new tensor with the values clamped over the given max value.
1341    ///
1342    /// # Example
1343    ///
1344    /// ```rust
1345    /// use burn_tensor::backend::Backend;
1346    /// use burn_tensor::{Int, Tensor};
1347    ///
1348    /// fn example<B: Backend>() {
1349    ///    let device = Default::default();
1350    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
1351    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1352    ///    &device);
1353    ///    let tensor = tensor.clamp_max(5);
1354    ///    println!("{tensor}");
1355    ///    // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]
1356    /// }
1357    /// ```
1358    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {
1359        Self::new(K::clamp_max(self.primitive, max.elem()))
1360    }
1361
1362    /// Apply element wise absolute value operation
1363    ///
1364    /// # Example
1365    ///
1366    /// ```rust
1367    /// use burn_tensor::backend::Backend;
1368    /// use burn_tensor::{Int, Tensor};
1369    ///
1370    /// fn example<B: Backend>() {
1371    ///   let device = Default::default();
1372    ///   let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);
1373    ///   let tensor = tensor.abs();
1374    ///   println!("{tensor}");
1375    ///   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1376    /// }
1377    /// ```
1378    pub fn abs(self) -> Self {
1379        Self::new(K::abs(self.primitive))
1380    }
1381
    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,
    /// the other elements of the result tensor out are set to 0.
    ///
    /// `diagonal` selects which diagonal to keep from: 0 is the main diagonal, positive values
    /// shift it upward (excluding more entries), negative values shift it downward.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
    ///        [
    ///          [1, 2, 3],
    ///          [4, 5, 6],
    ///          [7, 8, 9]
    ///        ],
    ///        &device
    ///    );
    ///    let tensor = tensor.triu(1);
    ///    println!("{tensor}");
    ///    // [
    ///    //   [0, 2, 3],
    ///    //   [0, 0, 6],
    ///    //   [0, 0, 0]
    ///    // ]
    /// }
    /// ```
    pub fn triu(self, diagonal: i64) -> Self {
        // Requires at least two dimensions; enforced at runtime by the check macro.
        check!(TensorCheck::tri::<{ D }>());

        // last two dimensions — the triangular mask depends only on the matrix dims.
        let shape = &self.shape().dims[D - 2..].to_owned();

        // Build a 2-D mask of the entries to zero out and broadcast it over any
        // leading (batch) dimensions via unsqueeze, then fill those entries with 0.
        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();
        self.mask_fill(mask, 0)
    }
1418
    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,
    /// the other elements of the result tensor out are set to 0.
    ///
    /// `diagonal` selects which diagonal to keep up to: 0 is the main diagonal, negative values
    /// shift it downward (excluding more entries), positive values shift it upward.
    ///
    /// # Example
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Int, Tensor};
    ///
    /// fn example<B: Backend>() {
    ///    let device = Default::default();
    ///    let tensor = Tensor::<B, 2, Int>::from_ints(
    ///        [
    ///          [1, 2, 3],
    ///          [4, 5, 6],
    ///          [7, 8, 9]
    ///        ],
    ///        &device
    ///    );
    ///
    ///    let tensor = tensor.tril(-1);
    ///    println!("{tensor}");
    ///    // [
    ///    //   [0, 0, 0],
    ///    //   [4, 0, 0],
    ///    //   [7, 8, 0]
    ///    // ]
    /// }
    /// ```
    pub fn tril(self, diagonal: i64) -> Self {
        // Requires at least two dimensions; enforced at runtime by the check macro.
        check!(TensorCheck::tri::<{ D }>());

        // last two dimensions — the triangular mask depends only on the matrix dims.
        let shape = &self.shape().dims[D - 2..].to_owned();

        // Build a 2-D mask of the entries to zero out and broadcast it over any
        // leading (batch) dimensions via unsqueeze, then fill those entries with 0.
        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
        self.mask_fill(mask, 0)
    }
1456
    /// Applies element wise power operation with a float Tensor
    ///
    /// # Arguments
    ///
    /// * `other` - The tensor to apply the power operation with.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);
    ///    let tensor = tensor1.powf(tensor2);
    ///    println!("{tensor}");
    ///    // [[1.0, -8.0, 81.0], [5.0, 81.0, 216.0]]
    /// }
    /// ```
    pub fn powf(self, other: Self) -> Self {
        Self::new(K::powf(self.primitive, other.primitive))
    }
1481
1482    /// Applies element wise power operation with a float scalar
1483    ///
1484    /// # Arguments
1485    ///
1486    /// * `other` - The scalar to apply the power operation with.
1487    ///
1488    /// # Example
1489    ///
1490    /// ```rust
1491    /// use burn_tensor::backend::Backend;
1492    /// use burn_tensor::{Tensor, Shape};
1493    ///
1494    /// fn example<B: Backend>() {
1495    ///    let device = B::Device::default();
1496    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1497    ///    let tensor = tensor.powf_scalar(2.0);
1498    ///    println!("{tensor}");
1499    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]
1500    /// }
1501    /// ```
1502    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
1503        Self::new(K::powf_scalar::<E>(self.primitive, other))
1504    }
1505
1506    /// Applies element wise power operation with a integer Tensor
1507    ///
1508    /// # Arguments
1509    ///
1510    /// * `other` - The tensor to apply the power operation with.
1511    ///
1512    /// # Example
1513    ///
1514    /// ```rust
1515    /// use burn_tensor::backend::Backend;
1516    /// use burn_tensor::{Tensor, Shape, Int};
1517    ///
1518    /// fn example<B: Backend>() {
1519    ///    let device = B::Device::default();
1520    ///    let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1521    ///    let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);
1522    ///    let tensor = tensor1.powi(tensor2);
1523    ///    println!("{tensor}");
1524    ///    // [[1, -8, 81], [5, 81, 216]]
1525    /// }
1526    /// ```
1527    pub fn powi(self, other: Self) -> Self {
1528        Self::new(K::powi(self.primitive, other.primitive))
1529    }
1530
1531    /// Applies element wise power operation with a integer scalar
1532    ///
1533    /// # Arguments
1534    ///
1535    /// * `other` - The scalar to apply the power operation with.
1536    ///
1537    /// # Example
1538    ///
1539    /// ```rust
1540    /// use burn_tensor::backend::Backend;
1541    /// use burn_tensor::{Tensor, Shape, Int};
1542    ///
1543    /// fn example<B: Backend>() {
1544    ///    let device = B::Device::default();
1545    ///    let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);
1546    ///    let tensor = tensor.powi_scalar(2);
1547    ///    println!("{tensor}");
1548    ///    // [[1, 4, 9], [25, 81, 36]]
1549    /// }
1550    /// ```
1551    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {
1552        Self::new(K::powi_scalar::<E>(self.primitive, other))
1553    }
1554
1555    /// Checks element wise if the tensor is close to another tensor.
1556    ///
1557    /// The tolerance is defined by the following equation:
1558    ///
1559    /// ```text
1560    /// abs(a - b) <= (atol + rtol * abs(b))
1561    ///
1562    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
1563    /// and `atol` is the absolute tolerance.
1564    /// ```
1565    ///
1566    /// # Arguments
1567    ///
1568    /// * `other` - The tensor to compare with.
1569    /// * `rtol` - Optional relative tolerance. Default is 1e-5.
1570    /// * `atol` - Optional absolute tolerance. Default is 1e-8.
1571    ///
1572    /// # Returns
1573    ///
1574    /// A boolean tensor with the same shape as the input tensors.
1575    ///
1576    /// # Example
1577    ///
1578    /// ```rust
1579    /// use burn_tensor::backend::Backend;
1580    /// use burn_tensor::{Tensor, Shape};
1581    ///
1582    /// fn example<B: Backend>() {
1583    ///    let device = B::Device::default();
1584    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1585    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1586    ///    let tensor = tensor1.is_close(tensor2, None, None);
1587    ///    println!("{tensor}");
1588    ///    // [[true, true, true], [true, true, true]]
1589    /// }
1590    /// ```
1591    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {
1592        let rtol = rtol.unwrap_or(1e-5);
1593        let atol = atol.unwrap_or(1e-8);
1594
1595        Tensor::new(K::lower_equal(
1596            K::abs(K::sub(self.primitive, other.primitive.clone())),
1597            K::add_scalar(K::mul_scalar(K::abs(other.primitive), rtol), atol),
1598        ))
1599    }
1600
1601    /// Checks if all elements are close to another tensor.
1602    ///
1603    /// The tolerance is defined by the following equation:
1604    ///
1605    /// ```text
1606    ///
1607    /// abs(a - b) <= (atol + rtol * abs(b))
1608    ///
1609    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,
1610    /// and `atol` is the absolute tolerance.
1611    ///
1612    /// ```
1613    ///
1614    /// # Arguments
1615    ///
1616    /// * `other` - The tensor to compare with.
1617    /// * `rtol` - Optional relative tolerance. Default is 1e-5.
1618    /// * `atol` - Optional absolute tolerance. Default is 1e-8.
1619    ///
1620    /// # Returns
1621    ///
1622    /// A boolean scalar.
1623    ///
1624    /// # Remarks
1625    ///
1626    /// # Example
1627    ///
1628    /// ```rust
1629    /// use burn_tensor::backend::Backend;
1630    /// use burn_tensor::{Tensor, Shape};
1631    ///
1632    /// fn example<B: Backend>() {
1633    ///    let device = B::Device::default();
1634    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1635    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
1636    ///    let result = tensor1.all_close(tensor2, None, None);
1637    ///    println!("{}", result);
1638    ///    // true
1639    /// }
1640    /// ```
1641    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {
1642        self.is_close(other, rtol, atol).all().into_scalar()
1643    }
1644
    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);
    ///   let tensor = tensor.bool();
    ///   println!("{tensor}");
    ///   // [
    ///   //   [true, true, true],
    ///   //   [false, true, true]
    ///   // ]
    /// }
    /// ```
    pub fn bool(self) -> Tensor<B, D, Bool> {
        Tensor::new(K::not_equal_elem(self.primitive, 0.elem()))
    }
1670
1671    /// Create a random tensor of the given shape on the given device where each element is
1672    /// sampled from the given distribution.
1673    ///
1674    /// # Arguments
1675    ///
1676    /// * `shape` - The shape of the tensor.
1677    /// * `distribution` - The distribution to sample from.
1678    /// * `device` - The device to create the tensor on.
1679    ///
1680    /// # Returns
1681    ///
1682    /// A new tensor with the given shape and elements sampled from the given distribution.
1683    ///
1684    /// # Example
1685    ///
1686    /// ```rust
1687    /// use burn_tensor::backend::Backend;
1688    /// use burn_tensor::{Tensor, Shape, Distribution};
1689    ///
1690    /// fn example<B: Backend>() {
1691    ///   let device = B::Device::default();
1692    ///   let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0
1693    ///   let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);
1694    ///   println!("{tensor}");
1695    ///   // [
1696    ///   //   [0.08347523, 0.70498955, 0.60332155],
1697    ///   //   [0.08173251, 0.18028641, 0.97942924]
1698    ///   // ]
1699    /// }
1700    /// ```
1701    pub fn random<S: Into<Shape>>(
1702        shape: S,
1703        distribution: Distribution,
1704        device: &B::Device,
1705    ) -> Self {
1706        Self::new(K::random(shape.into(), distribution, device))
1707    }
1708
    /// Sort the elements by value in ascending order along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to sort along.
    ///
    /// # Returns
    ///
    /// A new tensor with the elements sorted in ascending order along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///   let device = B::Device::default();
    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///   let tensor = tensor.sort(0);
    ///   println!("{tensor}");
    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
    ///   let tensor = tensor.sort(1);
    ///   println!("{tensor}");
    ///   // [[-2.0, 3.0, 5.0], [3.0, 6.0, 12.0]]
    /// }
    /// ```
    pub fn sort(self, dim: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::sort_dim::<D>("Sort", dim));
        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))
    }
1742
    /// Sort the elements by value in descending order along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to sort along.
    ///
    /// # Returns
    ///
    /// A new tensor with the elements sorted in descending order along the given dimension.
    ///
    /// # Example
    ///
    /// ```rust
    /// use burn_tensor::backend::Backend;
    /// use burn_tensor::{Tensor, Shape};
    ///
    /// fn example<B: Backend>() {
    ///    let device = B::Device::default();
    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
    ///    let tensor = tensor.sort_descending(0);
    ///    println!("{tensor}");
    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
    ///    let tensor = tensor.sort_descending(1);
    ///    println!("{tensor}");
    ///    // [[12.0, 6.0, 3.0], [5.0, 3.0, -2.0]]
    /// }
    /// ```
    pub fn sort_descending(self, dim: usize) -> Tensor<B, D, K> {
        check!(TensorCheck::sort_dim::<D>("Sort", dim));
        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))
    }
1776
1777    /// Sort the elements by value in ascending order along a given dimension.
1778    /// Also returns the indices.
1779    ///
1780    /// This sort is unstable (i.e., may reorder equal elements).
1781    ///
1782    /// # Arguments
1783    ///
1784    /// * `dim` - The dimension to sort along.
1785    ///
1786    /// # Returns
1787    ///
1788    /// A tuple containing the sorted tensor and the indices tensor.
1789    ///
1790    /// # Example
1791    ///
1792    /// ```rust
1793    /// use burn_tensor::backend::Backend;
1794    /// use burn_tensor::{Tensor, Shape};
1795    ///
1796    /// fn example<B: Backend>() {
1797    ///   let device = B::Device::default();
1798    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1799    ///   let (tensor, indices) = tensor.sort_with_indices(0);
1800    ///   println!("{tensor}");
1801    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]
1802    ///   println!("{}", indices);
1803    ///   // [[1, 0, 0], [0, 1, 1]]
1804    /// }
1805    /// ```
1806    pub fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1807        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1808        let (values, indices) =
1809            K::sort_with_indices(self.primitive, dim, /*descending*/ false);
1810        (Tensor::new(values), Tensor::new(indices))
1811    }
1812
1813    /// Sort the elements by value in descending order along a given dimension.
1814    /// Also returns the indices.
1815    ///
1816    /// This sort is unstable (i.e., may reorder equal elements).
1817    ///
1818    /// # Arguments
1819    ///
1820    /// * `dim` - The dimension to sort along.
1821    ///
1822    /// # Example
1823    ///
1824    /// ```rust
1825    /// use burn_tensor::backend::Backend;
1826    /// use burn_tensor::{Tensor, Shape};
1827    ///
1828    /// fn example<B: Backend>() {
1829    ///    let device = B::Device::default();
1830    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1831    ///    let (tensor, indices) = tensor.sort_descending_with_indices(0);
1832    ///    println!("{tensor}");
1833    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1834    ///    println!("{}", indices);
1835    ///    // [[0, 1, 1], [1, 0, 0]]
1836    /// }
1837    /// ```
1838    pub fn sort_descending_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1839        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
1840        let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);
1841        (Tensor::new(values), Tensor::new(indices))
1842    }
1843
1844    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
1845    ///
1846    /// This sort is unstable (i.e., may reorder equal elements).
1847    ///
1848    /// # Arguments
1849    ///
1850    /// * `dim` - The dimension to sort along.
1851    ///
1852    /// # Example
1853    ///
1854    /// ```rust
1855    /// use burn_tensor::backend::Backend;
1856    /// use burn_tensor::{Tensor, Shape};
1857    ///
1858    /// fn example<B: Backend>() {
1859    ///    let device = B::Device::default();
1860    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1861    ///    let tensor = tensor.argsort(0);
1862    ///    println!("{tensor}");
1863    ///    // [[1, 0, 0], [0, 1, 1]]
1864    /// }
1865    /// ```
1866    pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
1867        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1868        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))
1869    }
1870
1871    /// Returns the indices that sort the elements by value in descending order along a given dimension.
1872    ///
1873    /// This sort is unstable (i.e., may reorder equal elements).
1874    ///
1875    /// # Arguments
1876    ///
1877    /// * `dim` - The dimension to sort along.
1878    ///
1879    /// # Example
1880    ///
1881    /// ```rust
1882    /// use burn_tensor::backend::Backend;
1883    /// use burn_tensor::{Tensor, Shape};
1884    ///
1885    /// fn example<B: Backend>() {
1886    ///    let device = B::Device::default();
1887    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1888    ///    let tensor = tensor.argsort_descending(0);
1889    ///    println!("{tensor}");
1890    ///    // [[0, 1, 1], [1, 0, 0]]
1891    ///    let tensor = tensor.argsort_descending(1);
1892    ///    println!("{tensor}");
1893    ///    // [[0, 2, 1], [2, 0, 1]]
1894    /// }
1895    /// ```
1896    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
1897        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
1898        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
1899    }
1900
1901    /// Returns the `k` largest elements of the given input tensor along a given dimension.
1902    ///
1903    /// # Arguments
1904    ///
1905    /// * `k` - The number of elements to return.
1906    ///
1907    /// # Returns
1908    ///
1909    /// A new tensor with the `k` largest elements along the given dimension.
1910    ///
1911    /// # Example
1912    ///
1913    /// ```rust
1914    /// use burn_tensor::backend::Backend;
1915    /// use burn_tensor::{Tensor, Shape};
1916    ///
1917    /// fn example<B: Backend>() {
1918    ///   let device = B::Device::default();
1919    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1920    ///   let tensor = tensor.topk(2, 0);
1921    ///   println!("{tensor}");
1922    ///   // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1923    ///   let tensor = tensor.topk(1, 1);
1924    ///   println!("{tensor}");   
1925    ///   // [[12.0], [6.0]]
1926    /// }
1927    /// ```
1928    pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
1929        let k_indices = Tensor::arange(0..k as i64, &self.device());
1930        self.sort_descending(dim).select(dim, k_indices)
1931    }
1932
1933    /// Returns the `k` largest elements of the given input tensor along a given dimension.
1934    /// Also returns the indices.
1935    ///
1936    /// # Arguments
1937    ///
1938    /// * `k` - The number of elements to return.
1939    /// * `dim` - The dimension to sort along.
1940    ///
1941    /// # Example
1942    ///
1943    /// ```rust
1944    /// use burn_tensor::backend::Backend;
1945    /// use burn_tensor::{Tensor, Shape};
1946    ///
1947    /// fn example<B: Backend>() {
1948    ///    let device = B::Device::default();
1949    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1950    ///    let (tensor, indices) = tensor.topk_with_indices(2, 0);
1951    ///    println!("{tensor}");
1952    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]
1953    ///    println!("{}", indices);
1954    ///    // [[0, 1, 1], [1, 0, 0]]
1955    ///    let (tensor, indices) = tensor.topk_with_indices(1, 1);
1956    ///    println!("{tensor}");
1957    ///    // [[12.0], [6.0]]
1958    ///    println!("{indices}");
1959    ///    // [[0], [2]]
1960    /// }
1961    /// ```
1962    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
1963        let k_indices = Tensor::arange(0..k as i64, &self.device());
1964        let (values, indices) = self.sort_descending_with_indices(dim);
1965        (
1966            values.select(dim, k_indices.clone()),
1967            indices.select(dim, k_indices),
1968        )
1969    }
1970
1971    /// Pad the tensor of rank two or higher with the given value on the last two dimensions.
1972    ///
1973    /// # Arguments
1974    ///
1975    /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
1976    /// * `value` - The value to pad the tensor with.
1977    ///
1978    /// # Returns
1979    ///
1980    /// A new tensor with the given padding.
1981    ///
1982    /// # Example
1983    ///
1984    /// ```rust
1985    /// use burn_tensor::backend::Backend;
1986    /// use burn_tensor::{Tensor, Shape};
1987    ///
1988    /// fn example<B: Backend<FloatElem: From<f32>>>() {
1989    ///    let device = B::Device::default();
1990    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
1991    ///    let tensor = tensor.pad((1, 1, 1, 1), 0.0);
1992    ///    println!("{tensor}");
1993    ///    // [
1994    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0],
1995    ///    //   [0.0, 12.0, -2.0, 3.0, 0.0],
1996    ///    //   [0.0, 5.0, 3.0, 6.0, 0.0],
1997    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0]
1998    ///    // ]
1999    /// }
2000    /// ```
2001    pub fn pad<E: ElementConversion>(
2002        self,
2003        padding: (usize, usize, usize, usize),
2004        value: E,
2005    ) -> Tensor<B, D, K> {
2006        let (left, right, top, bottom) = padding;
2007
2008        let mut padded_dims: [usize; D] = self.dims();
2009
2010        // Update the last two dimensions with padding
2011        padded_dims[D - 2] += top + bottom;
2012        padded_dims[D - 1] += left + right;
2013
2014        // Create the ranges for the padded tensor
2015        let ranges: [core::ops::Range<usize>; D] = padded_dims
2016            .iter()
2017            .enumerate()
2018            .map(|(i, &dim)| {
2019                if i == D - 2 {
2020                    top..dim - bottom
2021                } else if i == D - 1 {
2022                    left..dim - right
2023                } else {
2024                    0..dim
2025                }
2026            })
2027            .collect::<Vec<core::ops::Range<usize>>>()
2028            .try_into()
2029            .unwrap();
2030
2031        // Create the padded tensor
2032        let padded_tensor = Tensor::full(padded_dims, value, &self.device());
2033
2034        // Assign the original tensor data to the appropriate slice of the padded tensor
2035        padded_tensor.slice_assign(ranges, self)
2036    }
2037
2038    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
2039    ///
2040    /// # Returns
2041    ///
2042    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
2043    ///
2044    /// # Example
2045    ///
2046    /// ```rust
2047    /// use burn_tensor::backend::Backend;
2048    /// use burn_tensor::{Tensor, Bool, Shape};
2049    ///
2050    /// fn example<B: Backend>() {
2051    ///    let device = B::Device::default();
2052    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);
2053    ///    let tensor = tensor.is_nan();
2054    ///    println!("{tensor}");
2055    ///    // [[false, true, false], [false, false, false]]
2056    /// }
2057    /// ```
2058    pub fn is_nan(&self) -> Tensor<B, D, Bool> {
2059        // Check if the input tensor is NaN by comparing it to itself
2060        // NaN is the only value that is not equal to itself
2061        Tensor::new(K::not_equal(self.primitive.clone(), self.primitive.clone()))
2062    }
2063
2064    /// Checks if the tensor contains any NaN values.
2065    ///
2066    /// # Returns
2067    ///
2068    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
2069    ///
2070    /// # Example
2071    ///
2072    /// ```rust
2073    /// use burn_tensor::backend::Backend;
2074    /// use burn_tensor::{Tensor, Bool, Shape};
2075    ///
2076    /// fn example<B: Backend>() {
2077    ///   let device = B::Device::default();
2078    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);
2079    ///   let tensor = tensor.contains_nan();
2080    ///   println!("{tensor}");
2081    ///   // [true]
2082    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);
2083    ///   let tensor = tensor.contains_nan();
2084    ///   println!("{tensor}");
2085    ///   // [false]
2086    /// }
2087    /// ```
2088    pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
2089        // Summing the tensor will result in NaN if the tensor contains any NaN values
2090        // This is faster than checking each element individually
2091        // because it rolls up the NaN values into a single value
2092        let sum = K::sum(self.primitive.clone());
2093
2094        // Check if the sum is NaN by comparing it to itself
2095        Tensor::new(K::not_equal(sum.clone(), sum))
2096    }
2097}
2098
2099impl<B, K> Tensor<B, 2, K>
2100where
2101    B: Backend,
2102    K: Numeric<B>,
2103    K::Elem: Element,
2104{
2105    /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.
2106    ///
2107    /// # Arguments
2108    ///
2109    /// * `size` - The size of the square matrix.
2110    pub fn eye(size: usize, device: &B::Device) -> Self {
2111        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
2112        let ones = K::ones([1, size].into(), device);
2113        let zeros = K::zeros([size, size].into(), device);
2114        Self::new(K::scatter(0, zeros, indices.primitive, ones))
2115    }
2116}
2117
2118/// Trait that list all operations that can be applied on all numerical tensors.
2119///
2120/// # Warnings
2121///
2122/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
2123pub trait Numeric<B: Backend>: BasicOps<B>
2124where
2125    Self::Elem: Element,
2126{
2127    /// Adds two tensors together.
2128    ///
2129    /// # Arguments
2130    ///
2131    /// * `lhs` - The left hand side tensor.
2132    /// * `rhs` - The right hand side tensor.
2133    ///
2134    /// # Returns
2135    ///
2136    /// The sum of the two tensors.
2137    ///
2138    /// # Remarks
2139    ///
2140    /// This is a low-level function used internally by the library to call different backend functions
2141    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2142    /// or use this function directly.
2143    ///
2144    /// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
2145    /// which is more high-level and designed for public use.
2146    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2147
2148    /// Adds a scalar to a tensor element-wise.
2149    ///
2150    /// # Arguments
2151    ///
2152    /// * `lhs` - The left hand side tensor.
2153    /// * `rhs` - The right hand side scalar.
2154    ///
2155    /// # Returns
2156    ///
2157    /// The sum of the tensor and the scalar.
2158    ///
2159    /// # Remarks
2160    ///
2161    /// This is a low-level function used internally by the library to call different backend functions
2162    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2163    /// or use this function directly.
2164    ///
2165    /// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
2166    /// which is more high-level and designed for public use.
2167    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2168
2169    /// Subtracts two tensors.
2170    ///
2171    /// # Arguments
2172    ///
2173    /// * `lhs` - The left hand side tensor.
2174    /// * `rhs` - The right hand side tensor.
2175    ///
2176    /// # Returns
2177    ///
2178    /// The difference of the two tensors.
2179    ///
2180    /// # Remarks
2181    ///
2182    /// This is a low-level function used internally by the library to call different backend functions
2183    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2184    /// or use this function directly.
2185    ///
2186    /// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
2187    /// which is more high-level and designed for public use.
2188    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2189
2190    /// Subtracts a scalar from a tensor element-wise.
2191    ///
2192    /// # Arguments
2193    ///
2194    /// * `lhs` - The left hand side tensor.
2195    /// * `rhs` - The right hand side scalar.
2196    ///
2197    /// # Returns
2198    ///
2199    /// The difference of the tensor and the scalar.
2200    ///
2201    /// # Remarks
2202    ///
2203    /// This is a low-level function used internally by the library to call different backend functions
2204    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2205    /// or use this function directly.
2206    ///
2207    /// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
2208    /// which is more high-level and designed for public use.
2209    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2210
2211    /// Divides two tensors.
2212    ///
2213    /// # Arguments
2214    ///
2215    /// * `lhs` - The left hand side tensor.
2216    /// * `rhs` - The right hand side tensor.
2217    ///
2218    /// # Returns
2219    ///
2220    /// The quotient of the two tensors.
2221    ///
2222    /// # Remarks
2223    ///
2224    /// This is a low-level function used internally by the library to call different backend functions
2225    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2226    /// or use this function directly.
2227    ///
2228    /// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
2229    /// which is more high-level and designed for public use.
2230    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2231
2232    /// Divides a tensor by a scalar element-wise.
2233    ///
2234    /// # Arguments
2235    ///
2236    /// * `lhs` - The left hand side tensor.
2237    /// * `rhs` - The right hand side scalar.
2238    ///
2239    /// # Returns
2240    ///
2241    /// The quotient of the tensor and the scalar.
2242    ///
2243    /// # Remarks
2244    ///
2245    /// This is a low-level function used internally by the library to call different backend functions
2246    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2247    /// or use this function directly.
2248    ///
2249    /// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
2250    /// which is more high-level and designed for public use.
2251    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2252
2253    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2254    /// less than that of the divisor.
2255    ///
2256    /// # Arguments
2257    ///
2258    /// * `lhs` - The dividend.
2259    /// * `rhs` - The divisor.
2260    ///
2261    /// # Returns
2262    ///
2263    /// The modulo of the input tensor with the divisor.
2264    ///
2265    /// # Remarks
2266    ///
2267    /// This is a low-level function used internally by the library to call different backend functions
2268    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2269    /// or use this function directly.
2270    ///
2271    /// For performing the modulo operation, users should prefer the [Tensor::remainder](Tensor::remainder) function,
2272    /// which is more high-level and designed for public use.
2273    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2274
2275    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is
2276    /// less than that of the divisor.
2277    ///
2278    /// # Arguments
2279    ///
2280    /// * `lhs` - The dividend.
2281    /// * `rhs` - The divisor.
2282    ///
2283    /// # Returns
2284    ///
2285    /// The modulo of the input tensor with the divisor.
2286    ///
2287    /// # Remarks
2288    ///
2289    /// This is a low-level function used internally by the library to call different backend functions
2290    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2291    /// or use this function directly.
2292    ///
2293    /// For performing the modulo operation, users should prefer the [Tensor::remainder_scalar](Tensor::remainder_scalar) function,
2294    /// which is more high-level and designed for public use.
2295    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2296
2297    /// Multiplies two tensors.
2298    ///
2299    /// # Arguments
2300    ///
2301    /// * `lhs` - The left hand side tensor.
2302    /// * `rhs` - The right hand side tensor.
2303    ///
2304    /// # Returns
2305    ///
2306    /// The product of the two tensors.
2307    ///
2308    /// # Remarks
2309    ///
2310    /// This is a low-level function used internally by the library to call different backend functions
2311    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2312    /// or use this function directly.
2313    ///
2314    /// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
2315    /// which is more high-level and designed for public use.
2316    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
2317
2318    /// Multiplies a tensor by a scalar element-wise.
2319    ///
2320    /// # Arguments
2321    ///
2322    /// * `lhs` - The left hand side tensor.
2323    /// * `rhs` - The right hand side scalar.
2324    ///
2325    /// # Returns
2326    ///
2327    /// The product of the tensor and the scalar.
2328    ///
2329    /// # Remarks
2330    ///
2331    /// This is a low-level function used internally by the library to call different backend functions
2332    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2333    /// or use this function directly.
2334    ///
2335    /// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
2336    /// which is more high-level and designed for public use.
2337    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
2338
2339    /// Negates a tensor.
2340    ///
2341    /// # Arguments
2342    ///
2343    /// * `tensor` - The tensor to negate.
2344    ///
2345    /// # Returns
2346    ///
2347    /// The negated tensor.
2348    ///
2349    /// # Remarks
2350    ///
2351    /// This is a low-level function used internally by the library to call different backend functions
2352    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2353    /// or use this function directly.
2354    ///
2355    /// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
2356    /// which is more high-level and designed for public use.
2357    fn neg(tensor: Self::Primitive) -> Self::Primitive;
2358
2359    /// Returns the signs of the elements of a tensor.
2360    ///
2361    /// # Arguments
2362    ///
2363    /// * `tensor` - The tensor.
2364    ///
2365    /// # Returns
2366    ///
2367    /// The signs of the elements of the tensor.
2368    ///
2369    /// # Remarks
2370    ///
2371    /// This is a low-level function used internally by the library to call different backend functions
2372    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2373    /// or use this function directly.
2374    ///
2375    /// For getting the signs of the elements of a tensor, users should prefer the [Tensor::sign](Tensor::sign) function,
2376    /// which is more high-level and designed for public use.
2377    fn sign(tensor: Self::Primitive) -> Self::Primitive;
2378
2379    /// Creates a tensor filled with zeros.
2380    ///
2381    /// # Arguments
2382    ///
2383    /// * `shape` - The shape of the tensor.
2384    /// * `device` - The device on which the tensor will be allocated.
2385    ///
2386    /// # Returns
2387    ///
2388    /// The tensor filled with zeros.
2389    ///
2390    /// # Remarks
2391    ///
2392    /// This is a low-level function used internally by the library to call different backend functions
2393    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2394    /// or use this function directly.
2395    ///
2396    /// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
2397    /// which is more high-level and designed for public use.
2398    fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive;
2399
2400    /// Creates a tensor filled with ones.
2401    ///
2402    /// # Arguments
2403    ///
2404    /// * `shape` - The shape of the tensor.
2405    /// * `device` - The device on which the tensor will be allocated.
2406    ///
2407    /// # Returns
2408    ///
2409    /// The tensor filled with ones.
2410    ///
2411    /// # Remarks
2412    ///
2413    /// This is a low-level function used internally by the library to call different backend functions
2414    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2415    /// or use this function directly.
2416    ///
2417    /// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
2418    /// which is more high-level and designed for public use.
2419    fn ones(shape: Shape, device: &B::Device) -> Self::Primitive;
2420
2421    /// Creates a tensor filled with elements equal to the given value.
2422    ///
2423    /// # Arguments
2424    ///
2425    /// * `shape` - The shape of the tensor.
2426    /// * `fill_value` - The value with which to fill the tensor
2427    /// * `device` - The device on which the tensor will be allocated.
2428    ///
2429    /// # Returns
2430    ///
2431    /// The tensor filled with elements equal to the given value
2432    ///
2433    /// # Remarks
2434    ///
2435    /// This is a low-level function used internally by the library to call different backend functions
2436    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2437    /// or use this function directly.
2438    ///
2439    /// For creating a tensor filled with a specific value, users should prefer the [Tensor::full](Tensor::full) function,
2440    /// which is more high-level and designed for public use.
2441    fn full<E: ElementConversion>(
2442        shape: Shape,
2443        fill_value: E,
2444        device: &B::Device,
2445    ) -> Self::Primitive;
2446
2447    /// Sums all the elements of the tensor.
2448    ///
2449    /// # Arguments
2450    ///
2451    /// * `tensor` - The tensor to sum.
2452    ///
2453    /// # Returns
2454    ///
2455    /// The sum of all the elements of the tensor.
2456    ///
2457    /// # Remarks
2458    ///
2459    /// This is a low-level function used internally by the library to call different backend functions
2460    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2461    /// or use this function directly.
2462    ///
2463    /// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
2464    /// which is more high-level and designed for public use.
2465    fn sum(tensor: Self::Primitive) -> Self::Primitive;
2466
2467    /// Sums all the elements of the tensor along a dimension.
2468    ///
2469    /// # Arguments
2470    ///
2471    /// * `tensor` - The tensor to sum.
2472    /// * `dim` - The dimension along which to sum.
2473    ///
2474    /// # Returns
2475    ///
2476    /// The sum of all the elements of the tensor along the specified dimension.
2477    ///
2478    /// # Remarks
2479    ///
2480    /// This is a low-level function used internally by the library to call different backend functions
2481    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2482    /// or use this function directly.
2483    ///
2484    /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
2485    /// which is more high-level and designed for public use.
2486    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2487
2488    /// Computes the product of all the elements of the tensor.
2489    ///
2490    /// # Arguments
2491    ///
2492    /// * `tensor` - The tensor to compute the product of.
2493    ///
2494    /// # Returns
2495    ///
2496    /// The product of all the elements of the tensor.
2497    ///
2498    /// # Remarks
2499    ///
2500    /// This is a low-level function used internally by the library to call different backend functions
2501    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2502    /// or use this function directly.
2503    ///
2504    /// For computing the product of all the elements of a tensor, users should prefer the
2505    /// [Tensor::prod](Tensor::prod) function,
2506    /// which is more high-level and designed for public use.
2507    fn prod(tensor: Self::Primitive) -> Self::Primitive;
2508
2509    /// Computes the product of all the elements of the tensor along a dimension.
2510    ///
2511    /// # Arguments
2512    ///
2513    /// * `tensor` - The tensor to compute the product of.
2514    /// * `dim` - The dimension along which to compute the product.
2515    ///
2516    /// # Returns
2517    ///
2518    /// The product of all the elements of the tensor along the specified dimension.
2519    ///
2520    /// # Remarks
2521    ///
2522    /// This is a low-level function used internally by the library to call different backend functions
2523    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2524    /// or use this function directly.
2525    ///
2526    /// For computing the product of all the elements of a tensor along a dimension, users should
2527    /// prefer the [Tensor::prod_dim](Tensor::prod_dim) function,
2528    /// which is more high-level and designed for public use.
2529    ///
2530    ///
2531    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2532
2533    /// Computes the mean of all the elements of the tensor.
2534    ///
2535    /// # Arguments
2536    ///
2537    /// * `tensor` - The tensor to compute the mean of.
2538    ///
2539    /// # Returns
2540    ///
2541    /// The mean of all the elements of the tensor.
2542    ///
2543    /// # Remarks
2544    ///
2545    /// This is a low-level function used internally by the library to call different backend functions
2546    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
2547    /// or use this function directly.
2548    ///
2549    /// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
2550    /// which is more high-level and designed for public use.
2551    fn mean(tensor: Self::Primitive) -> Self::Primitive;
2552
    /// Computes the mean of all the elements of the tensor along a dimension.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to compute the mean of.
    /// * `dim` - The dimension along which to compute the mean.
    ///
    /// # Returns
    ///
    /// The mean of all the elements of the tensor along the specified dimension.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For computing the mean of all the elements of a tensor along a dimension, users should prefer
    /// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
2573
    /// Element-wise equality comparison between a tensor and a scalar element.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar element.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the input tensor is equal to the scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise equality between a tensor and a scalar, users should prefer the [Tensor::equal_elem](Tensor::equal_elem)
    /// function, which is more high-level and designed for public use.
    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2595
    /// Element-wise non-equality comparison between a tensor and a scalar element.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar element.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the input tensor is NOT equal to the scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise non-equality between a tensor and a scalar, users should prefer the [Tensor::not_equal_elem](Tensor::not_equal_elem)
    /// function, which is more high-level and designed for public use.
    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2617
    /// Element-wise greater than comparison between two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
    /// corresponding element of the left hand side tensor is greater than the corresponding element
    /// of the right hand side tensor, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
    /// which is more high-level and designed for public use.
    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2640
    /// Element-wise greater than comparison between a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the left hand side tensor is greater than the right hand side
    /// scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise greater than comparison between a tensor and a scalar, users should prefer
    /// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2663
    /// Element-wise greater than or equal comparison between two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
    /// corresponding element of the left hand side tensor is greater than or equal to the
    /// corresponding element of the right hand side tensor, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise greater than or equal comparison between two tensors, users should prefer
    /// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2686
    /// Element-wise greater than or equal comparison between a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the left hand side tensor is greater than or equal to the right
    /// hand side scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
    /// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2709
    /// Element-wise less than comparison between two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
    /// corresponding element of the left hand side tensor is less than the corresponding element of
    /// the right hand side tensor, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
    /// which is more high-level and designed for public use.
    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2732
    /// Element-wise less than comparison between a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the left hand side tensor is less than the right hand side scalar,
    /// and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise less than comparison between a tensor and a scalar, users should prefer
    /// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2755
    /// Element-wise less than or equal comparison between two tensors.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side tensor.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensors, where each element is true if the
    /// corresponding element of the left hand side tensor is less than or equal to the corresponding
    /// element of the right hand side tensor, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise less than or equal comparison between two tensors, users should prefer
    /// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use.
    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;
2778
    /// Element-wise less than or equal comparison between a tensor and a scalar.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The left hand side tensor.
    /// * `rhs` - The right hand side scalar.
    ///
    /// # Returns
    ///
    /// A boolean tensor with the same shape as the input tensor, where each element is true if the
    /// corresponding element of the left hand side tensor is less than or equal to the right hand
    /// side scalar, and false otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
    /// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive;
2801
    /// Selects elements from a tensor based on a boolean mask.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
    /// * `mask` - The boolean mask to use for selecting elements.
    /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors, where each element is taken from the
    /// corresponding element of `tensor` if the corresponding element of the mask is true, and
    /// from the corresponding element of `source` otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For selecting elements from a tensor based on a boolean mask, users should prefer the
    /// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
    fn mask_where(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        source: Self::Primitive,
    ) -> Self::Primitive;
2829
    /// Fills elements of a tensor based on a boolean mask.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose elements will be overwritten with the value
    ///              when the corresponding element of the mask is true.
    /// * `mask` - The boolean mask to use for filling elements.
    /// * `value` - The value to fill elements with when the corresponding element of the mask is true.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensors, where each element is taken from the
    /// corresponding element unmodified if the corresponding element of the mask is false, and
    /// filled with the value otherwise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For filling elements of a tensor based on a boolean mask, users should prefer the
    /// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
    fn mask_fill(
        tensor: Self::Primitive,
        mask: B::BoolTensorPrimitive,
        value: Self::Elem,
    ) -> Self::Primitive;
2858
    /// Gathers elements from a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `dim` - The axis along which to gather elements.
    /// * `tensor` - The tensor to gather elements from.
    /// * `indices` - The indices of the elements to gather.
    ///
    /// # Returns
    ///
    /// A tensor where each element is taken from the input tensor at the corresponding
    /// index along the specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For gathering elements from a tensor along an axis, users should prefer the
    /// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use.
    fn gather(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
    ) -> Self::Primitive;
2885
    /// Scatters elements into a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `dim` - The axis along which to scatter elements.
    /// * `tensor` - The tensor to scatter elements into.
    /// * `indices` - The indices of the elements to scatter.
    /// * `values` - The values to scatter into the tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where each element is taken from the
    /// corresponding element of the input tensor at the corresponding index along the specified axis,
    /// except for the elements at the specified indices, which are taken from the corresponding
    /// element of the values tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function,
    /// which is more high-level and designed for public use.
    fn scatter(
        dim: usize,
        tensor: Self::Primitive,
        indices: B::IntTensorPrimitive,
        values: Self::Primitive,
    ) -> Self::Primitive;
2916
    /// Select tensor elements along the given dimension corresponding for the given indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to select elements from.
    /// * `dim` - The axis along which to select elements.
    /// * `indices` - The indices of the elements to select.
    ///
    /// # Returns
    ///
    /// A tensor where each element is taken from the input tensor at the corresponding
    /// index along the specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For selecting elements from a tensor along an axis, users should prefer the
    /// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use.
    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive;
2939
    /// Assign the selected elements along the given dimension corresponding to the given indices
    /// from the value tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to assign elements to.
    /// * `dim` - The axis along which to assign elements.
    /// * `indices` - The indices of the elements to assign.
    /// * `values` - The values to assign to the tensor.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where each element is taken from the
    /// corresponding element of the input tensor at the corresponding index along the specified axis,
    /// except for the elements at the specified indices, which are combined with the corresponding
    /// element of the values tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For assigning elements to a tensor along an axis, users should prefer the
    /// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use.
    fn select_assign(
        tensor: Self::Primitive,
        dim: usize,
        indices: Tensor<B, 1, Int>,
        values: Self::Primitive,
    ) -> Self::Primitive;
2971
    /// Gets the indices of the maximum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the indices of the maximum elements from.
    /// * `dim` - The axis along which to get the indices of the maximum elements.
    ///
    /// # Returns
    ///
    /// An integer tensor containing, for each position along the specified axis, the index of
    /// the maximum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
    /// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
    fn argmax(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
2993
    /// Gets the indices of the minimum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the indices of the minimum elements from.
    /// * `dim` - The axis along which to get the indices of the minimum elements.
    ///
    /// # Returns
    ///
    /// An integer tensor containing, for each position along the specified axis, the index of
    /// the minimum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
    /// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
    fn argmin(tensor: Self::Primitive, dim: usize) -> B::IntTensorPrimitive;
3015
    /// Gets the maximum element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum element from.
    ///
    /// # Returns
    ///
    /// A single-element tensor containing the maximum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum element of a tensor, users should prefer the
    /// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
    fn max(tensor: Self::Primitive) -> Self::Primitive;
3035
    /// Gets the maximum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements from.
    /// * `dim` - The axis along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tensor where each element is the maximum element of the input tensor at the
    /// corresponding index along the specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum elements of a tensor along an axis, users should prefer the
    /// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3056
    /// Gets the maximum elements of a tensor along an axis, together with their indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the maximum elements from.
    /// * `dim` - The axis along which to get the maximum elements.
    ///
    /// # Returns
    ///
    /// A tuple containing the maximum elements of the input tensor along the specified axis,
    /// and an integer tensor where each element is the index of the maximum element of the
    /// input tensor at the corresponding index along the specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the maximum elements of a tensor along an axis, users should prefer the
    /// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
    fn max_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, B::IntTensorPrimitive);
3082
    /// Gets the minimum element of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum element from.
    ///
    /// # Returns
    ///
    /// A single-element tensor containing the minimum element of the input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum element of a tensor, users should prefer the
    /// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
    fn min(tensor: Self::Primitive) -> Self::Primitive;
3102
    /// Gets the minimum elements of a tensor along an axis.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tensor where each element is the minimum element of the input tensor at the
    /// corresponding index along the specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum elements of a tensor along an axis, users should prefer the
    /// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use.
    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
3124
    /// Gets the minimum elements of a tensor along an axis, together with their indices.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the minimum elements from.
    /// * `dim` - The axis along which to get the minimum elements.
    ///
    /// # Returns
    ///
    /// A tuple containing the minimum elements of the input tensor along the specified axis,
    /// and an integer tensor where each element is the index of the minimum element of the
    /// input tensor at the corresponding index along the specified axis.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For getting the minimum elements of a tensor along an axis, users should prefer the
    /// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use.
    fn min_dim_with_indices(
        tensor: Self::Primitive,
        dim: usize,
    ) -> (Self::Primitive, B::IntTensorPrimitive);
3149
    /// Clamp the tensor between the given min and max values.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped between the given min and max values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor between the given min and max values, users should prefer the
    /// [Tensor::clamp](Tensor::clamp) function, which is more high-level and designed for public use.
    fn clamp(tensor: Self::Primitive, min: Self::Elem, max: Self::Elem) -> Self::Primitive;
3169
    /// Clamps a tensor under a minimum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `min` - The minimum value.
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped so that no element is below the given min value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor under a minimum value, users should prefer the
    /// [Tensor::clamp_min](Tensor::clamp_min) function, which is more high-level and designed for public use.
    fn clamp_min(tensor: Self::Primitive, min: Self::Elem) -> Self::Primitive;
3189
    /// Clamps a tensor over a maximum value.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to clamp.
    /// * `max` - The maximum value.
    ///
    /// # Returns
    ///
    /// A new tensor with the values clamped so that no element is above the given max value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// For clamping a tensor over a maximum value, users should prefer the
    /// [Tensor::clamp_max](Tensor::clamp_max) function, which is more high-level and designed for public use.
    fn clamp_max(tensor: Self::Primitive, max: Self::Elem) -> Self::Primitive;
3209
    /// Calculates the absolute value of all elements of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to apply abs to.
    ///
    /// # Returns
    ///
    /// A tensor with absolute values.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function,
    /// which is more high-level and designed for public use.
    fn abs(tensor: Self::Primitive) -> Self::Primitive;
3229
    /// Element-wise power of a tensor to a float tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The base tensor.
    /// * `rhs` - The exponent tensor, applied element wise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// Users should prefer the [Tensor::powf](Tensor::powf) function,
    /// which is more high-level and designed for public use.
    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3236
    /// Element-wise power of a tensor.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The base tensor.
    /// * `rhs` - The exponent tensor, applied element wise.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// Users should prefer the [Tensor::powi](Tensor::powi) function,
    /// which is more high-level and designed for public use.
    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;
3243
    /// Element-wise power of a tensor to a scalar float.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The base tensor.
    /// * `rhs` - The scalar exponent, applied to every element.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// Users should prefer the [Tensor::powf_scalar](Tensor::powf_scalar) function,
    /// which is more high-level and designed for public use.
    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3250
    /// Element-wise power of a tensor to a scalar int.
    ///
    /// # Arguments
    ///
    /// * `lhs` - The base tensor.
    /// * `rhs` - The scalar exponent, applied to every element.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users.
    ///
    /// Users should prefer the [Tensor::powi_scalar](Tensor::powi_scalar) function,
    /// which is more high-level and designed for public use.
    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;
3257
    /// Creates a tensor with values sampled from the given distribution.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the output tensor.
    /// * `distribution` - The distribution used to sample.
    /// * `device` - The device to use.
    ///
    /// # Returns
    ///
    /// A new tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::random](Tensor::random) function,
    /// which is more high-level and designed for public use.
    fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;
3279
    /// Sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::sort](Tensor::sort) function,
    /// which is more high-level and designed for public use.
    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;
3302
    /// Sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor and corresponding indices, where
    /// the elements are sorted by value and the indices map back to the original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// For sorting the elements of a tensor, users should prefer the
    /// [Tensor::sort_with_indices](Tensor::sort_with_indices) function, which is more high-level
    /// and designed for public use.
    fn sort_with_indices(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive);
3331
    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `dim` - The axis along which to sort.
    /// * `descending` - The sorting order.
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input tensor, where the indices map back to the
    /// original input tensor.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::argsort](Tensor::argsort) function,
    /// which is more high-level and designed for public use.
    fn argsort(
        tensor: Self::Primitive,
        dim: usize,
        descending: bool,
    ) -> <Int as TensorKind<B>>::Primitive;
3358}
3359
3360impl<B: Backend> Numeric<B> for Int {
3361    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3362        B::int_add(lhs, rhs)
3363    }
3364    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3365        B::int_add_scalar(lhs, rhs.elem())
3366    }
3367    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3368        B::int_sub(lhs, rhs)
3369    }
3370    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3371        B::int_sub_scalar(lhs, rhs.elem())
3372    }
3373    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3374        B::int_div(lhs, rhs)
3375    }
3376    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3377        B::int_div_scalar(lhs, rhs.elem())
3378    }
3379    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3380        B::int_remainder(lhs, rhs)
3381    }
3382    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3383        B::int_remainder_scalar(lhs, rhs.elem())
3384    }
3385    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Int as TensorKind<B>>::Primitive {
3386        B::int_mul(lhs, rhs)
3387    }
3388    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3389        B::int_mul_scalar(lhs, rhs.elem())
3390    }
3391    fn neg(tensor: Self::Primitive) -> Self::Primitive {
3392        B::int_neg(tensor)
3393    }
3394    fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
3395        B::int_zeros(shape, device)
3396    }
3397    fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
3398        B::int_ones(shape, device)
3399    }
3400    fn full<E: ElementConversion>(
3401        shape: Shape,
3402        fill_value: E,
3403        device: &B::Device,
3404    ) -> Self::Primitive {
3405        B::int_full(shape, fill_value.elem(), device)
3406    }
3407
3408    fn sum(tensor: Self::Primitive) -> Self::Primitive {
3409        B::int_sum(tensor)
3410    }
3411
3412    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3413        B::int_sum_dim(tensor, dim)
3414    }
3415
3416    fn prod(tensor: Self::Primitive) -> Self::Primitive {
3417        B::int_prod(tensor)
3418    }
3419
3420    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3421        B::int_prod_dim(tensor, dim)
3422    }
3423
3424    fn mean(tensor: Self::Primitive) -> Self::Primitive {
3425        B::int_mean(tensor)
3426    }
3427    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3428        B::int_mean_dim(tensor, dim)
3429    }
3430
3431    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3432        B::int_equal_elem(lhs, rhs)
3433    }
3434    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3435        B::int_not_equal_elem(lhs, rhs)
3436    }
3437    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3438        B::int_greater(lhs, rhs)
3439    }
3440
3441    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3442        B::int_greater_elem(lhs, rhs)
3443    }
3444
3445    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3446        B::int_greater_equal(lhs, rhs)
3447    }
3448
3449    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3450        B::int_greater_equal_elem(lhs, rhs)
3451    }
3452
3453    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3454        B::int_lower(lhs, rhs)
3455    }
3456
3457    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3458        B::int_lower_elem(lhs, rhs)
3459    }
3460
3461    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3462        B::int_lower_equal(lhs, rhs)
3463    }
3464
3465    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3466        B::int_lower_equal_elem(lhs, rhs)
3467    }
3468
3469    fn mask_where(
3470        tensor: Self::Primitive,
3471        mask: B::BoolTensorPrimitive,
3472        source: Self::Primitive,
3473    ) -> Self::Primitive {
3474        B::int_mask_where(tensor, mask, source)
3475    }
3476
3477    fn mask_fill(
3478        tensor: Self::Primitive,
3479        mask: B::BoolTensorPrimitive,
3480        value: Self::Elem,
3481    ) -> Self::Primitive {
3482        B::int_mask_fill(tensor, mask, value)
3483    }
3484
3485    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3486        B::int_select(tensor, dim, indices.primitive)
3487    }
3488
3489    fn select_assign(
3490        tensor: Self::Primitive,
3491        dim: usize,
3492        indices: Tensor<B, 1, Int>,
3493        values: Self::Primitive,
3494    ) -> Self::Primitive {
3495        B::int_select_assign(tensor, dim, indices.primitive, values)
3496    }
3497    fn gather(
3498        dim: usize,
3499        tensor: Self::Primitive,
3500        indices: B::IntTensorPrimitive,
3501    ) -> Self::Primitive {
3502        B::int_gather(dim, tensor, indices)
3503    }
3504
3505    fn scatter(
3506        dim: usize,
3507        tensor: Self::Primitive,
3508        indices: B::IntTensorPrimitive,
3509        values: Self::Primitive,
3510    ) -> Self::Primitive {
3511        B::int_scatter(dim, tensor, indices, values)
3512    }
3513
3514    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3515        B::int_argmax(tensor, dim)
3516    }
3517
3518    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3519        B::int_argmin(tensor, dim)
3520    }
3521
3522    fn max(tensor: Self::Primitive) -> Self::Primitive {
3523        B::int_max(tensor)
3524    }
3525
3526    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3527        B::int_max_dim(tensor, dim)
3528    }
3529
3530    fn max_dim_with_indices(
3531        tensor: Self::Primitive,
3532        dim: usize,
3533    ) -> (Self::Primitive, IntTensor<B>) {
3534        B::int_max_dim_with_indices(tensor, dim)
3535    }
3536
3537    fn min(tensor: Self::Primitive) -> Self::Primitive {
3538        B::int_min(tensor)
3539    }
3540
3541    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3542        B::int_min_dim(tensor, dim)
3543    }
3544
3545    fn min_dim_with_indices(
3546        tensor: Self::Primitive,
3547        dim: usize,
3548    ) -> (Self::Primitive, IntTensor<B>) {
3549        B::int_min_dim_with_indices(tensor, dim)
3550    }
3551
3552    fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
3553        B::int_clamp(tensor, min, max)
3554    }
3555
3556    fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
3557        B::int_clamp_min(tensor, min)
3558    }
3559
3560    fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
3561        B::int_clamp_max(tensor, max)
3562    }
3563
3564    fn abs(tensor: Self::Primitive) -> Self::Primitive {
3565        B::int_abs(tensor)
3566    }
3567
3568    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3569        B::int_powf(lhs, B::int_into_float(rhs))
3570    }
3571
3572    fn powf_scalar<E: ElementConversion>(
3573        lhs: Self::Primitive,
3574        rhs: E,
3575    ) -> <Int as TensorKind<B>>::Primitive {
3576        B::int_powf_scalar(lhs, rhs.elem())
3577    }
3578
3579    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
3580        B::int_powi(lhs, rhs)
3581    }
3582
3583    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3584        B::int_powf_scalar(lhs, rhs.elem())
3585    }
3586
3587    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
3588        B::int_random(shape, distribution, device)
3589    }
3590
3591    fn sign(tensor: Self::Primitive) -> Self::Primitive {
3592        B::int_sign(tensor)
3593    }
3594
3595    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
3596        B::int_sort(tensor, dim, descending)
3597    }
3598
3599    fn sort_with_indices(
3600        tensor: Self::Primitive,
3601        dim: usize,
3602        descending: bool,
3603    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
3604        B::int_sort_with_indices(tensor, dim, descending)
3605    }
3606
3607    fn argsort(
3608        tensor: Self::Primitive,
3609        dim: usize,
3610        descending: bool,
3611    ) -> <Int as TensorKind<B>>::Primitive {
3612        B::int_argsort(tensor, dim, descending)
3613    }
3614}
3615
3616impl<B: Backend> Numeric<B> for Float {
3617    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3618        match (lhs, rhs) {
3619            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3620                TensorPrimitive::Float(B::float_add(lhs, rhs))
3621            }
3622            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3623                TensorPrimitive::QFloat(B::q_add(lhs, rhs))
3624            }
3625            _ => panic!("Primitive type mismatch for lhs and rhs"),
3626        }
3627    }
3628    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3629        match lhs {
3630            TensorPrimitive::Float(lhs) => {
3631                TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem()))
3632            }
3633            TensorPrimitive::QFloat(lhs) => {
3634                TensorPrimitive::QFloat(B::q_add_scalar(lhs, rhs.elem()))
3635            }
3636        }
3637    }
3638    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3639        match (lhs, rhs) {
3640            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3641                TensorPrimitive::Float(B::float_sub(lhs, rhs))
3642            }
3643            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3644                TensorPrimitive::QFloat(B::q_sub(lhs, rhs))
3645            }
3646            _ => panic!("Primitive type mismatch for lhs and rhs"),
3647        }
3648    }
3649    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3650        match lhs {
3651            TensorPrimitive::Float(lhs) => {
3652                TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem()))
3653            }
3654            TensorPrimitive::QFloat(lhs) => {
3655                TensorPrimitive::QFloat(B::q_sub_scalar(lhs, rhs.elem()))
3656            }
3657        }
3658    }
3659    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3660        match (lhs, rhs) {
3661            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3662                TensorPrimitive::Float(B::float_div(lhs, rhs))
3663            }
3664            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3665                TensorPrimitive::QFloat(B::q_div(lhs, rhs))
3666            }
3667            _ => panic!("Primitive type mismatch for lhs and rhs"),
3668        }
3669    }
3670    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3671        match lhs {
3672            TensorPrimitive::Float(lhs) => {
3673                TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem()))
3674            }
3675            TensorPrimitive::QFloat(lhs) => {
3676                TensorPrimitive::QFloat(B::q_div_scalar(lhs, rhs.elem()))
3677            }
3678        }
3679    }
3680    fn remainder(
3681        lhs: Self::Primitive,
3682        rhs: Self::Primitive,
3683    ) -> <Float as TensorKind<B>>::Primitive {
3684        match (lhs, rhs) {
3685            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3686                TensorPrimitive::Float(B::float_remainder(lhs, rhs))
3687            }
3688            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3689                TensorPrimitive::QFloat(B::q_remainder(lhs, rhs))
3690            }
3691            _ => panic!("Primitive type mismatch for lhs and rhs"),
3692        }
3693    }
3694    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3695        match lhs {
3696            TensorPrimitive::Float(lhs) => {
3697                TensorPrimitive::Float(B::float_remainder_scalar(lhs, rhs.elem()))
3698            }
3699            TensorPrimitive::QFloat(lhs) => {
3700                TensorPrimitive::QFloat(B::q_remainder_scalar(lhs, rhs.elem()))
3701            }
3702        }
3703    }
3704    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> <Float as TensorKind<B>>::Primitive {
3705        match (lhs, rhs) {
3706            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
3707                TensorPrimitive::Float(B::float_mul(lhs, rhs))
3708            }
3709            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
3710                TensorPrimitive::QFloat(B::q_mul(lhs, rhs))
3711            }
3712            _ => panic!("Primitive type mismatch for lhs and rhs"),
3713        }
3714    }
3715    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
3716        match lhs {
3717            TensorPrimitive::Float(lhs) => {
3718                TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem()))
3719            }
3720            TensorPrimitive::QFloat(lhs) => {
3721                TensorPrimitive::QFloat(B::q_mul_scalar(lhs, rhs.elem()))
3722            }
3723        }
3724    }
3725    fn neg(tensor: Self::Primitive) -> Self::Primitive {
3726        match tensor {
3727            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),
3728            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_neg(tensor)),
3729        }
3730    }
3731    fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive {
3732        TensorPrimitive::Float(B::float_zeros(shape, device))
3733    }
3734    fn ones(shape: Shape, device: &B::Device) -> Self::Primitive {
3735        TensorPrimitive::Float(B::float_ones(shape, device))
3736    }
3737
3738    fn full<E: ElementConversion>(
3739        shape: Shape,
3740        fill_value: E,
3741        device: &B::Device,
3742    ) -> Self::Primitive {
3743        TensorPrimitive::Float(B::float_full(shape, fill_value.elem(), device))
3744    }
3745
3746    fn sum(tensor: Self::Primitive) -> Self::Primitive {
3747        match tensor {
3748            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),
3749            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum(tensor)),
3750        }
3751    }
3752
3753    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3754        match tensor {
3755            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),
3756            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum_dim(tensor, dim)),
3757        }
3758    }
3759
3760    fn prod(tensor: Self::Primitive) -> Self::Primitive {
3761        match tensor {
3762            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),
3763            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod(tensor)),
3764        }
3765    }
3766
3767    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3768        match tensor {
3769            TensorPrimitive::Float(tensor) => {
3770                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))
3771            }
3772            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod_dim(tensor, dim)),
3773        }
3774    }
3775
3776    fn mean(tensor: Self::Primitive) -> Self::Primitive {
3777        match tensor {
3778            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),
3779            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean(tensor)),
3780        }
3781    }
3782
3783    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3784        match tensor {
3785            TensorPrimitive::Float(tensor) => {
3786                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))
3787            }
3788            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean_dim(tensor, dim)),
3789        }
3790    }
3791
3792    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3793        B::float_equal_elem(lhs.tensor(), rhs)
3794    }
3795    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3796        B::float_not_equal_elem(lhs.tensor(), rhs)
3797    }
3798    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3799        B::float_greater(lhs.tensor(), rhs.tensor())
3800    }
3801
3802    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3803        B::float_greater_elem(lhs.tensor(), rhs)
3804    }
3805
3806    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3807        B::float_greater_equal(lhs.tensor(), rhs.tensor())
3808    }
3809
3810    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3811        B::float_greater_equal_elem(lhs.tensor(), rhs)
3812    }
3813
3814    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3815        B::float_lower(lhs.tensor(), rhs.tensor())
3816    }
3817
3818    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3819        B::float_lower_elem(lhs.tensor(), rhs)
3820    }
3821
3822    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
3823        B::float_lower_equal(lhs.tensor(), rhs.tensor())
3824    }
3825
3826    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
3827        B::float_lower_equal_elem(lhs.tensor(), rhs)
3828    }
3829
3830    fn mask_where(
3831        tensor: Self::Primitive,
3832        mask: B::BoolTensorPrimitive,
3833        source: Self::Primitive,
3834    ) -> Self::Primitive {
3835        match (tensor, source) {
3836            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(source)) => {
3837                TensorPrimitive::Float(B::float_mask_where(tensor, mask, source))
3838            }
3839            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(source)) => {
3840                TensorPrimitive::QFloat(B::q_mask_where(tensor, mask, source))
3841            }
3842            _ => panic!("Primitive type mismatch for tensor and source"),
3843        }
3844    }
3845
3846    fn mask_fill(
3847        tensor: Self::Primitive,
3848        mask: B::BoolTensorPrimitive,
3849        value: Self::Elem,
3850    ) -> Self::Primitive {
3851        match tensor {
3852            TensorPrimitive::Float(tensor) => {
3853                TensorPrimitive::Float(B::float_mask_fill(tensor, mask, value))
3854            }
3855            TensorPrimitive::QFloat(tensor) => {
3856                TensorPrimitive::QFloat(B::q_mask_fill(tensor, mask, value))
3857            }
3858        }
3859    }
3860
3861    fn select(tensor: Self::Primitive, dim: usize, indices: Tensor<B, 1, Int>) -> Self::Primitive {
3862        match tensor {
3863            TensorPrimitive::Float(tensor) => {
3864                TensorPrimitive::Float(B::float_select(tensor, dim, indices.primitive))
3865            }
3866            TensorPrimitive::QFloat(tensor) => {
3867                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices.primitive))
3868            }
3869        }
3870    }
3871
3872    fn select_assign(
3873        tensor: Self::Primitive,
3874        dim: usize,
3875        indices: Tensor<B, 1, Int>,
3876        values: Self::Primitive,
3877    ) -> Self::Primitive {
3878        match (tensor, values) {
3879            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
3880                TensorPrimitive::Float(B::float_select_assign(
3881                    tensor,
3882                    dim,
3883                    indices.primitive,
3884                    values,
3885                ))
3886            }
3887            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
3888                TensorPrimitive::QFloat(B::q_select_assign(tensor, dim, indices.primitive, values))
3889            }
3890            _ => panic!("Primitive type mismatch for tensor and values"),
3891        }
3892    }
3893
3894    fn gather(
3895        dim: usize,
3896        tensor: Self::Primitive,
3897        indices: B::IntTensorPrimitive,
3898    ) -> Self::Primitive {
3899        match tensor {
3900            TensorPrimitive::Float(tensor) => {
3901                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))
3902            }
3903            TensorPrimitive::QFloat(tensor) => {
3904                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))
3905            }
3906        }
3907    }
3908
3909    fn scatter(
3910        dim: usize,
3911        tensor: Self::Primitive,
3912        indices: B::IntTensorPrimitive,
3913        values: Self::Primitive,
3914    ) -> Self::Primitive {
3915        match (tensor, values) {
3916            (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => {
3917                TensorPrimitive::Float(B::float_scatter(dim, tensor, indices, values))
3918            }
3919            (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => {
3920                TensorPrimitive::QFloat(B::q_scatter(dim, tensor, indices, values))
3921            }
3922            _ => panic!("Primitive type mismatch for tensor and values"),
3923        }
3924    }
3925
3926    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3927        match tensor {
3928            TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),
3929            TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),
3930        }
3931    }
3932
3933    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
3934        match tensor {
3935            TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),
3936            TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),
3937        }
3938    }
3939
3940    fn max(tensor: Self::Primitive) -> Self::Primitive {
3941        match tensor {
3942            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),
3943            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),
3944        }
3945    }
3946
3947    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3948        match tensor {
3949            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),
3950            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),
3951        }
3952    }
3953
3954    fn max_dim_with_indices(
3955        tensor: Self::Primitive,
3956        dim: usize,
3957    ) -> (Self::Primitive, IntTensor<B>) {
3958        match tensor {
3959            TensorPrimitive::Float(tensor) => {
3960                let (values, indices) = B::float_max_dim_with_indices(tensor, dim);
3961                (TensorPrimitive::Float(values), indices)
3962            }
3963            TensorPrimitive::QFloat(tensor) => {
3964                let (values, indices) = B::q_max_dim_with_indices(tensor, dim);
3965                (TensorPrimitive::QFloat(values), indices)
3966            }
3967        }
3968    }
3969
3970    fn min(tensor: Self::Primitive) -> Self::Primitive {
3971        match tensor {
3972            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),
3973            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),
3974        }
3975    }
3976
3977    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
3978        match tensor {
3979            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),
3980            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),
3981        }
3982    }
3983
3984    fn min_dim_with_indices(
3985        tensor: Self::Primitive,
3986        dim: usize,
3987    ) -> (Self::Primitive, IntTensor<B>) {
3988        match tensor {
3989            TensorPrimitive::Float(tensor) => {
3990                let (values, indices) = B::float_min_dim_with_indices(tensor, dim);
3991                (TensorPrimitive::Float(values), indices)
3992            }
3993            TensorPrimitive::QFloat(tensor) => {
3994                let (values, indices) = B::q_min_dim_with_indices(tensor, dim);
3995                (TensorPrimitive::QFloat(values), indices)
3996            }
3997        }
3998    }
3999
4000    fn clamp(tensor: Self::Primitive, min: B::FloatElem, max: B::FloatElem) -> Self::Primitive {
4001        match tensor {
4002            TensorPrimitive::Float(tensor) => {
4003                TensorPrimitive::Float(B::float_clamp(tensor, min, max))
4004            }
4005            TensorPrimitive::QFloat(tensor) => {
4006                TensorPrimitive::QFloat(B::q_clamp(tensor, min, max))
4007            }
4008        }
4009    }
4010
4011    fn clamp_min(tensor: Self::Primitive, min: B::FloatElem) -> Self::Primitive {
4012        match tensor {
4013            TensorPrimitive::Float(tensor) => {
4014                TensorPrimitive::Float(B::float_clamp_min(tensor, min))
4015            }
4016            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_min(tensor, min)),
4017        }
4018    }
4019
4020    fn clamp_max(tensor: Self::Primitive, max: B::FloatElem) -> Self::Primitive {
4021        match tensor {
4022            TensorPrimitive::Float(tensor) => {
4023                TensorPrimitive::Float(B::float_clamp_max(tensor, max))
4024            }
4025            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_max(tensor, max)),
4026        }
4027    }
4028
4029    fn abs(tensor: Self::Primitive) -> Self::Primitive {
4030        match tensor {
4031            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),
4032            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),
4033        }
4034    }
4035
4036    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
4037        match (lhs, rhs) {
4038            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
4039                TensorPrimitive::Float(B::float_powf(lhs, rhs))
4040            }
4041            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
4042                TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
4043            }
4044            _ => panic!("Primitive type mismatch for lhs and rhs"),
4045        }
4046    }
4047
4048    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
4049        match lhs {
4050            TensorPrimitive::Float(lhs) => {
4051                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
4052            }
4053            TensorPrimitive::QFloat(lhs) => {
4054                TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
4055            }
4056        }
4057    }
4058
    /// Integer-exponent element-wise power.
    ///
    /// NOTE(review): this delegates to the `powf` kernels
    /// (`float_powf` / `q_powf`) rather than dedicated `powi` ones —
    /// presumably because `Self::Primitive` is float-typed here, making the
    /// two equivalent; confirm this is intentional and matches `powi_scalar`.
    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
        match (lhs, rhs) {
            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
                TensorPrimitive::Float(B::float_powf(lhs, rhs))
            }
            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
                TensorPrimitive::QFloat(B::q_powf(lhs, rhs))
            }
            // Mixed float/quantized operands are not supported.
            _ => panic!("Primitive type mismatch for lhs and rhs"),
        }
    }
4070
    /// Integer-exponent element-wise power with a scalar exponent.
    ///
    /// NOTE(review): delegates to the `powf_scalar` kernels
    /// (`float_powf_scalar` / `q_powf_scalar`) rather than dedicated `powi`
    /// ones — presumably equivalent for float elements; confirm intentional.
    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
        match lhs {
            TensorPrimitive::Float(lhs) => {
                TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem()))
            }
            TensorPrimitive::QFloat(lhs) => {
                TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem()))
            }
        }
    }
4081
    /// Creates a tensor of the given `shape` on `device`, sampled from
    /// `distribution`.
    ///
    /// Always produces a plain float primitive; this constructor never
    /// creates a quantized (`QFloat`) tensor.
    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
        TensorPrimitive::Float(B::float_random(shape, distribution, device))
    }
4085
    /// Computes the element-wise sign of the tensor.
    ///
    /// NOTE(review): unlike the surrounding ops there is no `q_sign` branch;
    /// `tensor.tensor()` converts the primitive to a float tensor first, so a
    /// quantized input yields a plain float result — confirm this loss of the
    /// quantized representation is intended.
    fn sign(tensor: Self::Primitive) -> Self::Primitive {
        TensorPrimitive::Float(B::float_sign(tensor.tensor()))
    }
4089
4090    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
4091        match tensor {
4092            TensorPrimitive::Float(tensor) => {
4093                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))
4094            }
4095            TensorPrimitive::QFloat(tensor) => {
4096                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))
4097            }
4098        }
4099    }
4100
4101    fn sort_with_indices(
4102        tensor: Self::Primitive,
4103        dim: usize,
4104        descending: bool,
4105    ) -> (Self::Primitive, <Int as TensorKind<B>>::Primitive) {
4106        match tensor {
4107            TensorPrimitive::Float(tensor) => {
4108                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);
4109                (TensorPrimitive::Float(values), indices)
4110            }
4111            TensorPrimitive::QFloat(tensor) => {
4112                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);
4113                (TensorPrimitive::QFloat(values), indices)
4114            }
4115        }
4116    }
4117
4118    fn argsort(
4119        tensor: Self::Primitive,
4120        dim: usize,
4121        descending: bool,
4122    ) -> <Int as TensorKind<B>>::Primitive {
4123        match tensor {
4124            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),
4125            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),
4126        }
4127    }
4128}
4129
4130impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
4131where
4132    B: Backend,
4133    K: Numeric<B>,
4134    K::Elem: Element,
4135{
4136    type Output = Self;
4137
4138    fn add(self, rhs: Tensor<B, D, K>) -> Self {
4139        Self::add(self, rhs)
4140    }
4141}
4142
4143impl<E, const D: usize, B, K> core::ops::Add<E> for Tensor<B, D, K>
4144where
4145    E: ElementConversion,
4146    B: Backend,
4147    K: Numeric<B>,
4148    K::Elem: Element,
4149{
4150    type Output = Self;
4151
4152    fn add(self, other: E) -> Self {
4153        Tensor::add_scalar(self, other)
4154    }
4155}
4156
4157impl<B, const D: usize, K> core::ops::Sub<Tensor<B, D, K>> for Tensor<B, D, K>
4158where
4159    B: Backend,
4160    K: Numeric<B>,
4161    K::Elem: Element,
4162{
4163    type Output = Self;
4164
4165    fn sub(self, rhs: Tensor<B, D, K>) -> Self {
4166        Tensor::sub(self, rhs)
4167    }
4168}
4169
4170impl<E, const D: usize, B, K> core::ops::Sub<E> for Tensor<B, D, K>
4171where
4172    E: ElementConversion,
4173    B: Backend,
4174    K: Numeric<B>,
4175    K::Elem: Element,
4176{
4177    type Output = Self;
4178
4179    fn sub(self, other: E) -> Self {
4180        Tensor::sub_scalar(self, other)
4181    }
4182}
4183
4184impl<B, const D: usize, K> core::ops::Div<Tensor<B, D, K>> for Tensor<B, D, K>
4185where
4186    B: Backend,
4187    K: Numeric<B>,
4188    K::Elem: Element,
4189{
4190    type Output = Self;
4191
4192    fn div(self, rhs: Tensor<B, D, K>) -> Self {
4193        Tensor::div(self, rhs)
4194    }
4195}
4196
4197impl<E, const D: usize, B, K> core::ops::Div<E> for Tensor<B, D, K>
4198where
4199    E: ElementConversion,
4200    B: Backend,
4201    K: Numeric<B>,
4202    K::Elem: Element,
4203{
4204    type Output = Self;
4205
4206    fn div(self, other: E) -> Self {
4207        Tensor::div_scalar(self, other)
4208    }
4209}
4210
4211impl<const D: usize, B, K> core::ops::Rem<Tensor<B, D, K>> for Tensor<B, D, K>
4212where
4213    B: Backend,
4214    K: Numeric<B>,
4215    K::Elem: Element,
4216{
4217    type Output = Self;
4218    fn rem(self, rhs: Tensor<B, D, K>) -> Self::Output {
4219        Tensor::remainder(self, rhs)
4220    }
4221}
4222
4223impl<E, const D: usize, B, K> core::ops::Rem<E> for Tensor<B, D, K>
4224where
4225    E: ElementConversion,
4226    B: Backend,
4227    K: Numeric<B>,
4228    K::Elem: Element,
4229{
4230    type Output = Self;
4231
4232    fn rem(self, other: E) -> Self {
4233        Tensor::remainder_scalar(self, other)
4234    }
4235}
4236
4237impl<B, const D: usize, K> core::ops::Mul<Tensor<B, D, K>> for Tensor<B, D, K>
4238where
4239    B: Backend,
4240    K: Numeric<B>,
4241    K::Elem: Element,
4242{
4243    type Output = Self;
4244
4245    fn mul(self, rhs: Tensor<B, D, K>) -> Self {
4246        Tensor::mul(self, rhs)
4247    }
4248}
4249
4250impl<E, const D: usize, B, K> core::ops::Mul<E> for Tensor<B, D, K>
4251where
4252    E: ElementConversion,
4253    B: Backend,
4254    K: Numeric<B>,
4255    K::Elem: Element,
4256{
4257    type Output = Self;
4258
4259    fn mul(self, other: E) -> Self {
4260        Tensor::mul_scalar(self, other)
4261    }
4262}
4263
4264impl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>
4265where
4266    B: Backend,
4267    K: Numeric<B>,
4268    K::Elem: Element,
4269{
4270    type Output = Self;
4271
4272    fn neg(self) -> Self {
4273        Tensor::neg(self)
4274    }
4275}